In [1]:
import numpy as np
import copy
import matplotlib
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

import torch
import torch.nn as nn
import torch.nn.functional as F

import time, pathlib, os, sys

## 2D dataset

In [2]:
class LipschitzNet(nn.Module):
    """Wrap ``net`` with a soft Lipschitz penalty and a per-sample gradient clipper.

    Every forward pass stores the input (as a fresh leaf that requires grad)
    and the raw network output, so the input gradient ``dy/dx`` can be taken
    afterwards.  From the per-sample L2 norm of that gradient the class derives

    * ``cond = 1 - ||dy/dx|| / K`` (positive while the norm is below ``K``),
    * ``gp``: a penalty on samples whose ``cond`` falls below about 0.1, and
    * ``gclipper``: a per-sample bound applied element-wise to the gradient
      that flows back through the tensor returned by ``forward``.

    Expected call order per training step:
    ``forward`` -> ``compute_penalty_and_clipper`` -> ``backward`` on a loss
    that includes ``gp``.

    Args:
        net: the module to wrap.
        K: Lipschitz constant the input-gradient norm is compared against.
        lamda: multiplier applied to the penalty.
    """

    def __init__(self, net, K=1., lamda=1.):
        super().__init__()

        self.net = net
        self.lamda = lamda
        self.K = K

        # Populated by forward() / get_dydx() / get_gradient_penalty().
        self.X = None
        self.Y = None
        self.dydx = None
        self.cond = None

        self.gp = 0
        # Large default bound: clipping is effectively off until
        # get_gradient_clipper() has been called.
        self.gclipper = torch.Tensor([999])

    def forward(self, x):
        """Run ``net`` on ``x`` and return its output with the clip hook attached."""
        # Detached leaf that requires grad, so dy/dx can be taken w.r.t. it
        # (replacement for the deprecated torch.autograd.Variable).
        self.X = x.detach().requires_grad_(True)

        self.Y = self.net(self.X)
        # `+ 0.` creates a separate node: the hook then only touches gradients
        # arriving through the returned tensor, not those of the penalty,
        # which is built from self.Y.
        y = self.Y + 0.
        # Under torch.no_grad() the output carries no graph and
        # register_hook would raise; skip it so inference still works.
        if y.requires_grad:
            y.register_hook(self.scale_gradient_back)
        return y

    def scale_gradient_back(self, grad):
        """Backward hook: clamp ``grad`` element-wise to ``[-gclipper, gclipper]``."""
        # Match device/dtype of the incoming gradient so the default CPU
        # bound also works when the model lives elsewhere.
        limit = self.gclipper.to(device=grad.device, dtype=grad.dtype)
        return torch.minimum(torch.maximum(grad, -limit), limit)

    def get_dydx(self):
        """Compute, store and return d(sum of outputs)/d(input) of the last forward.

        The graph is kept (``create_graph=True``) so the penalty built from
        this gradient can itself be back-propagated.
        """
        self.dydx = torch.autograd.grad(outputs=self.Y, inputs=self.X,
                                        grad_outputs=torch.ones_like(self.Y),
                                        only_inputs=True, retain_graph=True,
                                        create_graph=True)[0]
        return self.dydx

    def get_gradient_penalty(self):
        """Compute ``cond`` and the penalty ``gp`` from the stored ``dydx``.

        Requires ``get_dydx`` to have been called for the current forward pass.
        """
        dydx_norm = torch.norm(self.dydx, p=2, dim=1, keepdim=True)
        self.cond = -dydx_norm/self.K + 1.

        # softplus with negative beta is a smooth min(cond - 0.1, 0): it is
        # ~0 while cond > 0.1 and ~(cond - 0.1) once the margin is violated.
        intolerables = F.softplus(self.cond - 0.1, beta=-20)

        self.gp = (self.smooth_l1(intolerables*5)).mean()*self.lamda
        return self.gp

    def get_gradient_clipper(self):
        """Derive the per-sample clip bound ``gclipper`` from the stored ``cond``.

        The bound is a softplus-smoothed ``1 - (1.05*(cond-1))**4`` for small
        ``cond`` and the line ``3*cond - 0.0844560006`` above
        ``cond = 0.14845``; both pieces evaluate to ~0.3609 at that threshold.
        No gradient is tracked through the bound.
        """
        with torch.no_grad():
            cond = self.cond.detach()
            a = 20.
            curve = F.softplus(1 - (1.05*(cond - 1))**4, beta=a)
            line = 3*cond - 0.0844560006
            self.gclipper = torch.where(cond > 0.14845, line, curve)

        return self.gclipper

    def smooth_l1(self, x, beta=1):
        """Element-wise smooth-L1 (Huber-style) value of ``x``.

        Quadratic ``0.5*x**2/beta`` where ``|x| < beta`` and linear
        ``|x| - 0.5*beta`` elsewhere.  The branch test uses ``|x|``: testing
        the signed value would send every negative input down the quadratic
        branch and leave the linear one unreachable for the (always
        non-positive) penalty input.
        """
        magnitude = torch.abs(x)
        return torch.where(magnitude < beta,
                           0.5*(x**2)/beta,
                           magnitude - 0.5*beta)

    def compute_penalty_and_clipper(self):
        """Refresh ``dydx``, ``gp`` and ``gclipper`` for the last forward pass."""
        self.get_dydx()
        self.get_gradient_penalty()
        self.get_gradient_clipper()
        return

In [3]:
class GradientPenaltyNet(nn.Module):
    """Wrap ``net`` with a two-sided gradient penalty.

    The penalty is ``lamda * mean((||dy/dx|| / K - 1)**2)``: it pulls the
    per-sample input-gradient norm towards ``K`` from both sides.  With the
    default ``K=1`` this is ``(||dy/dx|| - 1)**2``.

    Expected call order per training step:
    ``forward`` -> ``compute_penalty`` -> ``backward`` on a loss including ``gp``.

    Args:
        net: the module to wrap.
        K: target norm of the input gradient.
        lamda: multiplier applied to the penalty.
    """

    def __init__(self, net, K=1., lamda=1.):
        super().__init__()

        self.net = net
        self.lamda = lamda
        self.K = K

        # Populated by forward() / get_dydx().
        self.X = None
        self.Y = None
        self.dydx = None

        self.gp = 0

    def forward(self, x):
        """Run ``net`` on ``x``, remembering input and output for ``get_dydx``."""
        # Detached leaf that requires grad (replacement for the deprecated
        # torch.autograd.Variable).
        self.X = x.detach().requires_grad_(True)

        self.Y = self.net(self.X)
        return self.Y + 0.

    def get_dydx(self):
        """Compute, store and return d(sum of outputs)/d(input) of the last forward.

        The graph is kept (``create_graph=True``) so the penalty can be
        back-propagated.
        """
        self.dydx = torch.autograd.grad(outputs=self.Y, inputs=self.X,
                                        grad_outputs=torch.ones_like(self.Y),
                                        only_inputs=True, retain_graph=True,
                                        create_graph=True)[0]
        return self.dydx

    def get_gradient_penalty(self):
        """Compute, store and return the penalty from the stored ``dydx``."""
        dydx_norm = torch.norm(self.dydx, p=2, dim=1, keepdim=True)

        # Normalise by K so the constructor argument is honoured
        # (previously the target norm was hard-coded to 1).
        self.gp = ((dydx_norm/self.K - 1)**2).mean()*self.lamda
        return self.gp

    def MSE_loss(self, diff):
        """Return half the mean squared value of ``diff``."""
        return 0.5*(diff**2).mean()

    def compute_penalty(self):
        """Refresh ``dydx`` and ``gp`` for the last forward pass."""
        self.get_dydx()
        self.get_gradient_penalty()
        return

In [4]:
class LipschitzPenaltyNet(nn.Module):
    """Wrap ``net`` with a one-sided (Lipschitz) gradient penalty.

    The penalty is ``lamda * mean(max(||dy/dx|| / K - 1, 0)**2)``: only
    samples whose input-gradient norm exceeds ``K`` contribute.  With the
    default ``K=1`` this is ``max(||dy/dx|| - 1, 0)**2``.

    Expected call order per training step:
    ``forward`` -> ``compute_penalty`` -> ``backward`` on a loss including ``gp``.

    Args:
        net: the module to wrap.
        K: largest input-gradient norm that is left unpenalised.
        lamda: multiplier applied to the penalty.
    """

    def __init__(self, net, K=1., lamda=1.):
        super().__init__()

        self.net = net
        self.lamda = lamda
        self.K = K

        # Populated by forward() / get_dydx().
        self.X = None
        self.Y = None
        self.dydx = None

        self.gp = 0

    def forward(self, x):
        """Run ``net`` on ``x``, remembering input and output for ``get_dydx``."""
        # Detached leaf that requires grad (replacement for the deprecated
        # torch.autograd.Variable).
        self.X = x.detach().requires_grad_(True)

        self.Y = self.net(self.X)
        return self.Y + 0.

    def get_dydx(self):
        """Compute, store and return d(sum of outputs)/d(input) of the last forward.

        The graph is kept (``create_graph=True``) so the penalty can be
        back-propagated.
        """
        self.dydx = torch.autograd.grad(outputs=self.Y, inputs=self.X,
                                        grad_outputs=torch.ones_like(self.Y),
                                        only_inputs=True, retain_graph=True,
                                        create_graph=True)[0]
        return self.dydx

    def get_gradient_penalty(self):
        """Compute, store and return the penalty from the stored ``dydx``."""
        dydx_norm = torch.norm(self.dydx, p=2, dim=1, keepdim=True)

        # clamp(min=0) needs no helper tensor, so it works on any device;
        # dividing by K honours the constructor argument (the threshold was
        # previously hard-coded to 1).
        excess = torch.clamp(dydx_norm/self.K - 1, min=0)
        self.gp = (excess**2).mean()*self.lamda
        return self.gp

    def MSE_loss(self, diff):
        """Return half the mean squared value of ``diff``."""
        return 0.5*(diff**2).mean()

    def compute_penalty(self):
        """Refresh ``dydx`` and ``gp`` for the last forward pass."""
        self.get_dydx()
        self.get_gradient_penalty()
        return

In [5]:
# Text file that the experiment cell appends its per-run summaries to.
log_file = './lipschitz_out/LOG_grads_constraint.txt'
# The log is opened in append mode and figures are saved next to it; both
# fail with FileNotFoundError on a fresh checkout unless the folder exists.
os.makedirs(os.path.dirname(log_file), exist_ok=True)

In [6]:
%matplotlib tk

# ---------------------------------------------------------------------------
# Experiment driver.
#
# For each of three 2-D datasets and each penalty weight, train the same
# small MLP under four constraints (GC-GP, SN, GP, LP) with three seeds and
# append a summary to `log_file`.
#
# Changes relative to the earlier copy-pasted version:
#   * log output goes through `with open(...)`; sys.stdout is never swapped,
#     so an exception can no longer leave stdout redirected or a handle open;
#   * the steep sigmoid is used directly instead of overwriting `nn.Sigmoid`
#     for the rest of the session;
#   * the four identical training loops share one helper;
#   * the last epoch is always evaluated, so the per-seed summary reports
#     loss and gradient norms from the same epoch;
#   * console progress lines use one format for every variant.
# ---------------------------------------------------------------------------

network_seeds = [147, 258, 369]
learning_rate = 0.005 #0.01 0.005
EPOCHS = 7500 #5000 #7500

# Figures and the log are written here; create it up front.
os.makedirs('./lipschitz_out', exist_ok=True)


class Sigmoid4(nn.Module):
    """Sigmoid with a 4x steeper slope: ``sigmoid(4 * x)``."""

    def forward(self,x):
        return torch.sigmoid(x*4)


def _rescale(a):
    """Affinely map array ``a`` onto [-1, 1] using its own min and max."""
    return 2*(a - a.min())/(a.max() - a.min()) -1


def _append_log(lines):
    """Append ``lines`` (one per row) to ``log_file``; the file is always closed."""
    with open(log_file, 'a') as fh:
        for line in lines:
            print(line, file=fh)


def _save_surface_and_contour(X1, X2, Y, surface_path, contour_path):
    """Save a 3-D surface plot and a labelled contour plot of ``Y`` over the grid."""
    fig = plt.figure()
    ax = fig.add_subplot(projection='3d')
    ax.plot_surface(X1, X2, Y, cmap='viridis')
    ax.set_xlabel('X1')
    ax.set_ylabel('X2')
    ax.set_zlabel('Y')
    fig.savefig(surface_path, bbox_inches='tight')
    plt.close(fig)

    fig = plt.figure(figsize=(6/6*5,6.2/6*5))
    plt.axis('equal')
    lvls = 12
    plt.contourf(X1, X2, Y, levels=lvls)#, cmap=matplotlib.cm.bwr)

    # One label colour per contour level.
    clrs = ['white']*7 + ['k']*5
    cs = plt.contour(X1, X2, Y, levels=lvls, linestyles="None", colors="k", linewidths=1)
    plt.clabel(cs, cs.levels, inline=True, fontsize=10, fmt="%1.1f", colors=clrs)

    plt.locator_params(axis='x', nbins=5)
    plt.locator_params(axis='y', nbins=5)
    plt.xlabel("x1")
    plt.ylabel("x2")
    plt.xlim(-1,1)
    plt.ylim(-1,1)

    fig.savefig(contour_path, bbox_inches='tight')
    plt.close(fig)


def _make_sine_dataset():
    """Dataset 0: ``2*sin(r) - 1`` on a 50x50 grid, inputs and target scaled to [-1, 1]."""
    num_points = 50
    X1 = np.linspace(-2.5, 1.5, num_points)
    X2 = np.linspace(-2, 4, num_points)
    X1, X2 = np.meshgrid(X1, X2)
    Y = np.sin(np.sqrt(X1**2 + X2**2))*2-1.

    X1, X2, Y = _rescale(X1), _rescale(X2), _rescale(Y)
    _save_surface_and_contour(X1, X2, Y,
                              "./lipschitz_out/data_regression_1.pdf",
                              "./lipschitz_out/data_regression_1c.pdf")
    return X1, X2, Y


def _make_gaussian_dataset():
    """Dataset 1: sum of four 2-D Gaussians on a 75x75 grid.

    Inputs are scaled to [-1, 1]; the target is scaled to [-1, 1] and then
    halved.  Grid and parameters follow
    https://scipython.com/blog/non-linear-least-squares-fitting-of-a-two-dimensional-data/
    """
    x1min, x1max, nx1 = -5, 6, 75
    x2min, x2max, nx2 = -3, 7, 75
    X1, X2 = np.meshgrid(np.linspace(x1min, x1max, nx1),
                         np.linspace(x2min, x2max, nx2))

    def gaussian(x1, x2, x10, x20, x1alpha, x2alpha, A):
        return A * np.exp( -((x1-x10)/x1alpha)**2 -((x2-x20)/x2alpha)**2)

    # Gaussian parameters: x10, x20, x1alpha, x2alpha, A
    gprms = [(0, 2, 2.5, 5.4, 1.5),
             (-1, 4, 6, 2.5, 1.8),
             (-3, -0.5, 1, 2, 4),
             (3, 0.5, 2, 1, 5)
            ]

    Y = np.zeros(X1.shape)
    for p in gprms:
        Y += gaussian(X1, X2, *p)

    X1, X2, Y = _rescale(X1), _rescale(X2), _rescale(Y)/2
    _save_surface_and_contour(X1, X2, Y,
                              "./lipschitz_out/data_regression_2.pdf",
                              "./lipschitz_out/data_regression_2c.pdf")
    return X1, X2, Y


def _make_spiral_dataset():
    """Dataset 2: two-spirals binary classification, inputs scaled to [-1, 1]."""
    def twospirals(n_points, noise=.5 , ang=720):
        """
         Returns the two spirals dataset.
        """
        n = np.sqrt(np.random.rand(n_points,1)) * ang * (2*np.pi)/360
        d1x = -np.cos(n)*n + np.random.rand(n_points,1) * noise
        d1y = np.sin(n)*n + np.random.rand(n_points,1) * noise
        return (np.vstack((np.hstack((d1x,d1y)),np.hstack((-d1x,-d1y)))),
                np.hstack((np.zeros(n_points),np.ones(n_points))))

    # Seeds numpy's global generator so the same points are drawn every run.
    np.random.seed(987)
    x, y = twospirals(200, ang=400)
    x, Y = x/10, y.reshape(-1)
    X1, X2 = _rescale(x[:,0]), _rescale(x[:,1])

    fig, ax = plt.subplots()
    ax.scatter(X1, X2, c=Y, s=50, edgecolors='k', lw=0.5)
    ax.grid(True)
    fig.savefig("./lipschitz_out/data_classification_1_invex.pdf", bbox_inches='tight')
    plt.close(fig)
    return X1, X2, Y


def _load_dataset(data_indx):
    """Build dataset ``data_indx`` (0, 1 or 2), save its figures, return ``(X1, X2, Y)``."""
    if data_indx == 0:
        return _make_sine_dataset()
    if data_indx == 1:
        return _make_gaussian_dataset()
    if data_indx == 2:
        return _make_spiral_dataset()
    raise NotImplementedError("Not implemented Error")


def _build_mlp(actf, use_sigmoid, spectral=False):
    """2-10-10-1 MLP; optionally spectral-normalised, optionally ending in ``Sigmoid4``.

    Layers are created in order, so seeding torch right before this call
    gives the same initial weights as building the ``nn.Sequential`` inline.
    """
    wrap = nn.utils.spectral_norm if spectral else (lambda layer: layer)
    return nn.Sequential(wrap(nn.Linear(2,10)),
                         actf(),
                         wrap(nn.Linear(10,10)),
                         actf(),
                         wrap(nn.Linear(10,1)),
                         Sigmoid4() if use_sigmoid else nn.Identity())


def _refresh_live_plot(fig, ax, ax2, X1, X2, Y, yy, yout, use_sigmoid):
    """Redraw target vs. prediction (3-D scatter) and the prediction map."""
    pred = yout.detach().numpy()

    ax.clear()
    ax.scatter(X1, X2, yy.detach().numpy().reshape(X1.shape), marker= '.')
    ax.scatter(X1, X2, pred.reshape(X1.shape), color='r', marker='.')

    ax2.clear()
    if use_sigmoid:
        ax2.scatter(X1, X2, c=pred.reshape(Y.shape), s=70, edgecolors='k', lw=0.5)
        ax2.scatter(X1, X2, c=Y, marker='.')
    else:
        ax2.contourf(X1, X2, pred.reshape(Y.shape), levels=20)

    fig.canvas.draw()
    plt.pause(0.01)


def _run_variant(label, mode, make_model, data, criterion, use_sigmoid):
    """Train one constraint variant for every seed and append its summary to the log.

    Args:
        label: heading written to the log ("GC-GP", "SN", "GP", "LP").
        mode: ``'clip'``    - call ``compute_penalty_and_clipper`` and add ``gp``;
              ``'penalty'`` - call ``compute_penalty`` and add ``gp``;
              ``'plain'``   - data loss only (gradient norms are still reported).
        make_model: zero-argument callable returning a freshly built wrapper;
            called right after ``torch.manual_seed`` for each seed.
        data: tuple ``(xx, yy, X1, X2, Y)``.
        criterion: data-loss module.
        use_sigmoid: whether the task is the classification one (adds accuracy).

    Returns:
        The wrapper trained with the last seed.
    """
    if mode not in ('clip', 'penalty', 'plain'):
        raise ValueError(f"unknown mode: {mode!r}")

    xx, yy, X1, X2, Y = data
    per_step_time = []
    info_per_seed = []

    fig = plt.figure(figsize=(15,6))
    ax = fig.add_subplot(121,projection='3d')
    ax2 = fig.add_subplot(122)

    model = None
    for ns in network_seeds:
        torch.manual_seed(ns)
        model = make_model()
        optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

        for epoch in range(EPOCHS):
            start = time.time()

            yout = model(xx)
            if mode == 'clip':
                model.compute_penalty_and_clipper()
            elif mode == 'penalty':
                model.compute_penalty()
            data_loss = criterion(yout, yy)
            loss = data_loss if mode == 'plain' else data_loss + model.gp

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            per_step_time.append(time.time()-start)

            # Also evaluate on the final epoch so the summary below is not
            # built from gradient norms that are up to 99 epochs old.
            if epoch%100 == 0 or epoch == EPOCHS-1:
                if mode == 'plain':
                    # No dy/dx was taken during the step; run a fresh
                    # forward pass so get_dydx has a live graph.
                    yout = model(xx)
                    model.get_dydx()

                allk = torch.norm(model.dydx, p=2, dim=1, keepdim=True)
                stats = f'MaxK: {float(allk.max())}, MinK: {float(allk.min())}'
                if mode == 'clip':
                    stats = f'MinCond: {float(model.cond.min())}, ' + stats
                info = f'Epoch: {epoch}, Loss:{float(data_loss)}, {stats}'
                print(info)

                _refresh_live_plot(fig, ax, ax2, X1, X2, Y, yy, yout, use_sigmoid)

        if use_sigmoid:
            acc = float(((yout>0.5).type(torch.float32)==yy).type(torch.float32).mean()*100)
            info = f'{info}, Acc: {acc}'
        info_per_seed.append(info)

    _append_log(
        ["###############################################", label]
        + [f'{i} {seed} {line}'
           for i, (seed, line) in enumerate(zip(network_seeds, info_per_seed))]
        + [f'{np.mean(per_step_time)} {np.std(per_step_time)}',
           "###############################################"]
    )

    plt.close(fig)
    return model


for data_indx in range(3):
    X1, X2, Y = _load_dataset(data_indx)

    x1 = X1.reshape(-1)
    x2 = X2.reshape(-1)

    xx = torch.Tensor(np.c_[x1, x2])
    yy = torch.Tensor(Y.reshape(-1,1))

    if data_indx == 2:
        use_sigmoid = True
        criterion = nn.BCELoss()
        actf = nn.LeakyReLU
    else:
        criterion = nn.MSELoss()
        use_sigmoid = False
        actf = nn.ELU

    data = (xx, yy, X1, X2, Y)

    for lambda_ in [1, 3]:
        _append_log(["===================================================",
                     f"Dataset {data_indx}; lambda_={lambda_}",
                     "==================================================="])

        # Gradient clipping + gradient penalty.
        lipsNet = _run_variant(
            "GC-GP", 'clip',
            lambda: LipschitzNet(_build_mlp(actf, use_sigmoid), K=1, lamda=lambda_),
            data, criterion, use_sigmoid)

        # Spectral normalisation does not depend on lambda_: run it once.
        if lambda_ == 1:
            ## no clipper and loss used..., just for getting gradient norm.
            snNet = _run_variant(
                "SN", 'plain',
                lambda: LipschitzNet(_build_mlp(actf, use_sigmoid, spectral=True)),
                data, criterion, use_sigmoid)

        # Two-sided gradient penalty.
        gpNet = _run_variant(
            "GP", 'penalty',
            lambda: GradientPenaltyNet(_build_mlp(actf, use_sigmoid), K=1, lamda=lambda_),
            data, criterion, use_sigmoid)

        # One-sided (Lipschitz) penalty.
        lpNet = _run_variant(
            "LP", 'penalty',
            lambda: LipschitzPenaltyNet(_build_mlp(actf, use_sigmoid), K=1, lamda=lambda_),
            data, criterion, use_sigmoid)
                

Epoch: 0, Loss:0.4675413966178894, MinCond: 0.8093889951705933, MaxK: 0.19061100482940674, MinK: 0.09312445670366287
Epoch: 100, Loss:0.1101626455783844, MinCond: 0.06541401147842407, MaxK: 0.9345859885215759, MinK: 0.32371005415916443
Epoch: 200, Loss:0.10048636794090271, MinCond: 0.060855329036712646, MaxK: 0.9391446709632874, MinK: 0.08893952518701553
Epoch: 300, Loss:0.09841816872358322, MinCond: 0.06962788105010986, MaxK: 0.9303721189498901, MinK: 0.003657673019915819
Epoch: 400, Loss:0.09749960154294968, MinCond: 0.062367141246795654, MaxK: 0.9376328587532043, MinK: 0.05337119847536087
Epoch: 500, Loss:0.09665811061859131, MinCond: 0.04284536838531494, MaxK: 0.9571546316146851, MinK: 0.08605284988880157
Epoch: 600, Loss:0.09575740993022919, MinCond: 0.041986703872680664, MaxK: 0.9580132961273193, MinK: 0.024017874151468277
Epoch: 700, Loss:0.09485293924808502, MinCond: 0.0292627215385437, MaxK: 0.9707372784614563, MinK: 0.005432970356196165
Epoch: 800, Loss:0.0937163308262825, Mi

Epoch: 6800, Loss:0.08240922540426254, MinCond: 0.07991290092468262, MaxK: 0.9200870990753174, MinK: 0.012734458781778812
Epoch: 6900, Loss:0.08340948820114136, MinCond: 0.07596248388290405, MaxK: 0.924037516117096, MinK: 0.01448710635304451
Epoch: 7000, Loss:0.08375763893127441, MinCond: 0.07110750675201416, MaxK: 0.9288924932479858, MinK: 0.011840272694826126
Epoch: 7100, Loss:0.0828489288687706, MinCond: 0.0769193172454834, MaxK: 0.9230806827545166, MinK: 0.012198586016893387
Epoch: 7200, Loss:0.08240754157304764, MinCond: 0.08247685432434082, MaxK: 0.9175231456756592, MinK: 0.00430163973942399
Epoch: 7300, Loss:0.08296964317560196, MinCond: 0.07838153839111328, MaxK: 0.9216184616088867, MinK: 0.00330479652620852
Epoch: 7400, Loss:0.0828145295381546, MinCond: 0.08073240518569946, MaxK: 0.9192675948143005, MinK: 0.008122405968606472
Epoch: 0, Loss:0.46714794635772705, MinCond: 0.9141939282417297, MaxK: 0.08580604940652847, MinK: 0.04806589335203171
Epoch: 100, Loss:0.109778493642807,

Epoch: 6100, Loss:0.08353894203901291, MinCond: 0.02047741413116455, MaxK: 0.9795225858688354, MinK: 0.007050294429063797
Epoch: 6200, Loss:0.08360236883163452, MinCond: 0.02302253246307373, MaxK: 0.9769774675369263, MinK: 0.00961578544229269
Epoch: 6300, Loss:0.08531726151704788, MinCond: 0.015578687191009521, MaxK: 0.9844213128089905, MinK: 0.004902179352939129
Epoch: 6400, Loss:0.08574827760457993, MinCond: 0.01796698570251465, MaxK: 0.9820330142974854, MinK: 0.014936643652617931
Epoch: 6500, Loss:0.08367084711790085, MinCond: 0.022210240364074707, MaxK: 0.9777897596359253, MinK: 0.003326146164909005
Epoch: 6600, Loss:0.08337639272212982, MinCond: 0.02182769775390625, MaxK: 0.9781723022460938, MinK: 0.004339209292083979
Epoch: 6700, Loss:0.08335385471582413, MinCond: 0.022084057331085205, MaxK: 0.9779159426689148, MinK: 0.0073048812337219715
Epoch: 6800, Loss:0.08558221161365509, MinCond: 0.01958400011062622, MaxK: 0.9804159998893738, MinK: 0.023269852623343468
Epoch: 6900, Loss:0.0

Epoch: 5400, Loss:0.08401594310998917, MinCond: 0.06106007099151611, MaxK: 0.9389399290084839, MinK: 0.02743564359843731
Epoch: 5500, Loss:0.08266556262969971, MinCond: 0.0668715238571167, MaxK: 0.9331284761428833, MinK: 0.014935730025172234
Epoch: 5600, Loss:0.08365947008132935, MinCond: 0.062144696712493896, MaxK: 0.9378553032875061, MinK: 0.026720372959971428
Epoch: 5700, Loss:0.08265198767185211, MinCond: 0.06637781858444214, MaxK: 0.9336221814155579, MinK: 0.01639643684029579
Epoch: 5800, Loss:0.08491005003452301, MinCond: 0.057819366455078125, MaxK: 0.9421806335449219, MinK: 0.030530771240592003
Epoch: 5900, Loss:0.08421174436807632, MinCond: 0.06021994352340698, MaxK: 0.939780056476593, MinK: 0.02863357774913311
Epoch: 6000, Loss:0.08413416892290115, MinCond: 0.06039971113204956, MaxK: 0.9396002888679504, MinK: 0.028615547344088554
Epoch: 6100, Loss:0.0828285738825798, MinCond: 0.06583744287490845, MaxK: 0.9341625571250916, MinK: 0.021839113906025887
Epoch: 6200, Loss:0.08219843

Epoch: 6200, Loss:0.0915423184633255, MaxK: 0.9967157244682312, MinK: 0.12129813432693481


Epoch: 6300, Loss:0.09212289750576019, MaxK: 0.9887887835502625, MinK: 0.12394104152917862


Epoch: 6400, Loss:0.09157013893127441, MaxK: 0.9995726943016052, MinK: 0.11784820258617401


Epoch: 6500, Loss:0.09176874905824661, MaxK: 0.9887830018997192, MinK: 0.12391801923513412


Epoch: 6600, Loss:0.09272649884223938, MaxK: 0.9802227020263672, MinK: 0.12854284048080444


Epoch: 6700, Loss:0.09115777909755707, MaxK: 0.9994391202926636, MinK: 0.12281535565853119


Epoch: 6800, Loss:0.09361371397972107, MaxK: 0.9672448635101318, MinK: 0.1276865452528


Epoch: 6900, Loss:0.09137604385614395, MaxK: 0.999270498752594, MinK: 0.12278765439987183


Epoch: 7000, Loss:0.09094145148992538, MaxK: 0.9985159635543823, MinK: 0.12396927922964096


Epoch: 7100, Loss:0.09220604598522186, MaxK: 0.9994620084762573, MinK: 0.12149429321289062


Epoch: 7200, Loss:0.09162447601556778, MaxK: 0.9987316727638245, MinK: 0.123

Epoch: 100, Loss:0.11195409297943115, MaxK: 0.948203980922699, MinK: 0.389646053314209


Epoch: 200, Loss:0.10023407638072968, MaxK: 0.9976208806037903, MinK: 0.24503590166568756


Epoch: 300, Loss:0.09560345113277435, MaxK: 0.998683750629425, MinK: 0.17906349897384644


Epoch: 400, Loss:0.09366464614868164, MaxK: 0.9983646273612976, MinK: 0.15265405178070068


Epoch: 500, Loss:0.09376922249794006, MaxK: 0.9974090456962585, MinK: 0.1490304172039032


Epoch: 600, Loss:0.09282767027616501, MaxK: 0.9986919164657593, MinK: 0.14456813037395477


Epoch: 700, Loss:0.09310819208621979, MaxK: 0.9978045225143433, MinK: 0.13046494126319885


Epoch: 800, Loss:0.09219977259635925, MaxK: 0.9984487891197205, MinK: 0.13341790437698364


Epoch: 900, Loss:0.09243680536746979, MaxK: 0.9939364194869995, MinK: 0.129857137799263


Epoch: 1000, Loss:0.09435737133026123, MaxK: 0.9783097505569458, MinK: 0.12744592130184174


Epoch: 1100, Loss:0.09152626246213913, MaxK: 0.9990610480308533, MinK: 0.1280287504196

Epoch: 1600, Loss:0.10043318569660187, MaxK: 1.2643011808395386 MinK: 0.5778688788414001


Epoch: 1700, Loss:0.09565278142690659, MaxK: 1.271422028541565 MinK: 0.48785221576690674


Epoch: 1800, Loss:0.08970429748296738, MaxK: 1.2771517038345337 MinK: 0.3495151400566101


Epoch: 1900, Loss:0.08649793267250061, MaxK: 1.2989165782928467 MinK: 0.3386869430541992


Epoch: 2000, Loss:0.08366058021783829, MaxK: 1.3179330825805664 MinK: 0.25794970989227295


Epoch: 2100, Loss:0.08186496794223785, MaxK: 1.328800082206726 MinK: 0.2209625244140625


Epoch: 2200, Loss:0.08087659627199173, MaxK: 1.3474324941635132 MinK: 0.1943458914756775


Epoch: 2300, Loss:0.07979638874530792, MaxK: 1.3581516742706299 MinK: 0.16796603798866272


Epoch: 2400, Loss:0.07875625044107437, MaxK: 1.3587926626205444 MinK: 0.13904906809329987


Epoch: 2500, Loss:0.07796722650527954, MaxK: 1.3539186716079712 MinK: 0.12049775570631027


Epoch: 2600, Loss:0.07736992835998535, MaxK: 1.3601170778274536 MinK: 0.071658112108707

Epoch: 3100, Loss:0.07039302587509155, MaxK: 1.22511887550354 MinK: 0.2759552001953125


Epoch: 3200, Loss:0.07019753754138947, MaxK: 1.243615746498108 MinK: 0.3071536123752594


Epoch: 3300, Loss:0.06989124417304993, MaxK: 1.2486913204193115 MinK: 0.29836195707321167


Epoch: 3400, Loss:0.06907409429550171, MaxK: 1.2524847984313965 MinK: 0.29947853088378906


Epoch: 3500, Loss:0.06881817430257797, MaxK: 1.2890268564224243 MinK: 0.3041094243526459


Epoch: 3600, Loss:0.06833042949438095, MaxK: 1.2875488996505737 MinK: 0.2958541214466095


Epoch: 3700, Loss:0.06798698753118515, MaxK: 1.2901288270950317 MinK: 0.2874935567378998


Epoch: 3800, Loss:0.06794923543930054, MaxK: 1.2927014827728271 MinK: 0.27784788608551025


Epoch: 3900, Loss:0.06748946756124496, MaxK: 1.2968318462371826 MinK: 0.2848719358444214


Epoch: 4000, Loss:0.06764347851276398, MaxK: 1.3012841939926147 MinK: 0.28741589188575745


Epoch: 4100, Loss:0.06737696379423141, MaxK: 1.3378945589065552 MinK: 0.28397494554519653



Epoch: 4700, Loss:0.08752164244651794, MaxK: 1.201233148574829 MinK: 0.5972481369972229


Epoch: 4800, Loss:0.08685731887817383, MaxK: 1.2396183013916016 MinK: 0.6174504160881042


Epoch: 4900, Loss:0.08653085678815842, MaxK: 1.2292085886001587 MinK: 0.6108132004737854


Epoch: 5000, Loss:0.08631925284862518, MaxK: 1.2383686304092407 MinK: 0.6039916276931763


Epoch: 5100, Loss:0.0861603170633316, MaxK: 1.2421220541000366 MinK: 0.5958732962608337


Epoch: 5200, Loss:0.08603418618440628, MaxK: 1.2524981498718262 MinK: 0.5927850008010864


Epoch: 5300, Loss:0.08583363890647888, MaxK: 1.249115228652954 MinK: 0.581495463848114


Epoch: 5400, Loss:0.08571689575910568, MaxK: 1.2477585077285767 MinK: 0.5783885717391968


Epoch: 5500, Loss:0.08573242276906967, MaxK: 1.2453497648239136 MinK: 0.5748957395553589


Epoch: 5600, Loss:0.08562499284744263, MaxK: 1.2618991136550903 MinK: 0.5769286155700684


Epoch: 5700, Loss:0.08542142063379288, MaxK: 1.2510308027267456 MinK: 0.5687457323074341


E



Epoch: 6200, Loss:0.05632392689585686, MaxK: 1.196050763130188 MinK: 0.03317934647202492


Epoch: 6300, Loss:0.056701451539993286, MaxK: 1.198144555091858 MinK: 0.03359280526638031


Epoch: 6400, Loss:0.05704859644174576, MaxK: 1.2056162357330322 MinK: 0.03347235545516014


Epoch: 6500, Loss:0.05553802102804184, MaxK: 1.1853846311569214 MinK: 0.02219483070075512


Epoch: 6600, Loss:0.054982542991638184, MaxK: 1.1911317110061646 MinK: 0.017393574118614197


Epoch: 6700, Loss:0.05628959834575653, MaxK: 1.2154500484466553 MinK: 0.009199772961437702


Epoch: 6800, Loss:0.054607830941677094, MaxK: 1.1905544996261597 MinK: 0.011259442195296288


Epoch: 6900, Loss:0.05577049404382706, MaxK: 1.2118847370147705 MinK: 0.010229051113128662


Epoch: 7000, Loss:0.055194102227687836, MaxK: 1.2021254301071167 MinK: 0.017828011885285378


Epoch: 7100, Loss:0.05575239285826683, MaxK: 1.2169712781906128 MinK: 0.011349973268806934


Epoch: 7200, Loss:0.05604613572359085, MaxK: 1.2256981134414673 MinK: 



Epoch: 100, Loss:0.12470800429582596, MaxK: 0.8826879262924194 MinK: 0.5087127089500427


Epoch: 200, Loss:0.08426772058010101, MaxK: 1.1475906372070312 MinK: 0.17859792709350586


Epoch: 300, Loss:0.07615384459495544, MaxK: 1.158318281173706 MinK: 0.07908207923173904


Epoch: 400, Loss:0.07431776821613312, MaxK: 1.1827386617660522 MinK: 0.04411524534225464


Epoch: 500, Loss:0.0736062079668045, MaxK: 1.1888573169708252 MinK: 0.04375787079334259


Epoch: 600, Loss:0.07269873470067978, MaxK: 1.1811437606811523 MinK: 0.03242579847574234


Epoch: 700, Loss:0.07144197821617126, MaxK: 1.1766772270202637 MinK: 0.0038585783913731575


Epoch: 800, Loss:0.06952836364507675, MaxK: 1.1844241619110107 MinK: 0.004191353917121887


Epoch: 900, Loss:0.06645560264587402, MaxK: 1.211020827293396 MinK: 0.004382446873933077


Epoch: 1000, Loss:0.06273017823696136, MaxK: 1.2329165935516357 MinK: 0.029387162998318672


Epoch: 1100, Loss:0.06021171808242798, MaxK: 1.2337640523910522 MinK: 0.04517722502350

Epoch: 1200, Loss:0.09362684190273285, MinCond: 0.09721136093139648, MaxK: 0.9027886390686035, MinK: 0.01280919834971428
Epoch: 1300, Loss:0.0929919183254242, MinCond: 0.10192173719406128, MaxK: 0.8980782628059387, MinK: 0.009522262029349804
Epoch: 1400, Loss:0.09252259880304337, MinCond: 0.11016738414764404, MaxK: 0.889832615852356, MinK: 0.0029949063900858164
Epoch: 1500, Loss:0.091232068836689, MinCond: 0.10372459888458252, MaxK: 0.8962754011154175, MinK: 0.019456855952739716
Epoch: 1600, Loss:0.09100087732076645, MinCond: 0.10232514142990112, MaxK: 0.8976748585700989, MinK: 0.03244432434439659
Epoch: 1700, Loss:0.09159181267023087, MinCond: 0.10726732015609741, MaxK: 0.8927326798439026, MinK: 0.010293625295162201
Epoch: 1800, Loss:0.09028280526399612, MinCond: 0.09874856472015381, MaxK: 0.9012514352798462, MinK: 0.026892071589827538
Epoch: 1900, Loss:0.09111017733812332, MinCond: 0.10776901245117188, MaxK: 0.8922309875488281, MinK: 0.021201543509960175
Epoch: 2000, Loss:0.090630896

Epoch: 500, Loss:0.10181567817926407, MinCond: 0.11874961853027344, MaxK: 0.8812503814697266, MinK: 0.05470752343535423
Epoch: 600, Loss:0.10070356726646423, MinCond: 0.10338097810745239, MaxK: 0.8966190218925476, MinK: 0.048011407256126404
Epoch: 700, Loss:0.09986308962106705, MinCond: 0.11010122299194336, MaxK: 0.8898987770080566, MinK: 0.006592744030058384
Epoch: 800, Loss:0.09922134131193161, MinCond: 0.09414857625961304, MaxK: 0.905851423740387, MinK: 0.0022051124833524227
Epoch: 900, Loss:0.09862982481718063, MinCond: 0.09951287508010864, MaxK: 0.9004871249198914, MinK: 0.0137712387368083
Epoch: 1000, Loss:0.09779475629329681, MinCond: 0.10054636001586914, MaxK: 0.8994536399841309, MinK: 0.012024904601275921
Epoch: 1100, Loss:0.09841805696487427, MinCond: 0.0960015058517456, MaxK: 0.9039984941482544, MinK: 0.006044625770300627
Epoch: 1200, Loss:0.09767888486385345, MinCond: 0.10691356658935547, MaxK: 0.8930864334106445, MinK: 0.016649719327688217
Epoch: 1300, Loss:0.0961008965969

Epoch: 7300, Loss:0.08819468319416046, MinCond: 0.1291886568069458, MaxK: 0.8708113431930542, MinK: 0.0048232791014015675
Epoch: 7400, Loss:0.08806006610393524, MinCond: 0.12828993797302246, MaxK: 0.8717100620269775, MinK: 0.00945028755813837
Epoch: 0, Loss:1.0060815811157227, MinCond: 0.794582724571228, MaxK: 0.20541724562644958, MinK: 0.09743992239236832
Epoch: 100, Loss:0.12581850588321686, MinCond: 0.14704018831253052, MaxK: 0.8529598116874695, MinK: 0.5047584772109985
Epoch: 200, Loss:0.10878675431013107, MinCond: 0.13125211000442505, MaxK: 0.868747889995575, MinK: 0.26487118005752563
Epoch: 300, Loss:0.10353346914052963, MinCond: 0.11955082416534424, MaxK: 0.8804491758346558, MinK: 0.1591382473707199
Epoch: 400, Loss:0.10198258608579636, MinCond: 0.12197059392929077, MaxK: 0.8780294060707092, MinK: 0.10026823729276657
Epoch: 500, Loss:0.10141835361719131, MinCond: 0.1289830207824707, MaxK: 0.8710169792175293, MinK: 0.08088178187608719
Epoch: 600, Loss:0.10107934474945068, MinCond

Epoch: 6600, Loss:0.08832797408103943, MinCond: 0.11575466394424438, MaxK: 0.8842453360557556, MinK: 0.026824625208973885
Epoch: 6700, Loss:0.08752374351024628, MinCond: 0.1111527681350708, MaxK: 0.8888472318649292, MinK: 0.030520915985107422
Epoch: 6800, Loss:0.08935288339853287, MinCond: 0.1164543628692627, MaxK: 0.8835456371307373, MinK: 0.022553633898496628
Epoch: 6900, Loss:0.0871300995349884, MinCond: 0.10752081871032715, MaxK: 0.8924791812896729, MinK: 0.03203814849257469
Epoch: 7000, Loss:0.08933063596487045, MinCond: 0.11570829153060913, MaxK: 0.8842917084693909, MinK: 0.02098306268453598
Epoch: 7100, Loss:0.08743354678153992, MinCond: 0.10904818773269653, MaxK: 0.8909518122673035, MinK: 0.028310153633356094
Epoch: 7200, Loss:0.08864959329366684, MinCond: 0.11391890048980713, MaxK: 0.8860810995101929, MinK: 0.02011561580002308
Epoch: 7300, Loss:0.08772437274456024, MinCond: 0.11070752143859863, MaxK: 0.8892924785614014, MinK: 0.02316770888864994
Epoch: 7400, Loss:0.08843170851

Epoch: 400, Loss:0.15542027354240417, MaxK: 1.033403992652893 MinK: 0.8732911348342896


Epoch: 500, Loss:0.15456105768680573, MaxK: 1.03474760055542 MinK: 0.8941870927810669


Epoch: 600, Loss:0.15356144309043884, MaxK: 1.0380467176437378 MinK: 0.9108306765556335


Epoch: 700, Loss:0.1521230787038803, MaxK: 1.0431374311447144 MinK: 0.9232540726661682


Epoch: 800, Loss:0.1494913548231125, MaxK: 1.0530269145965576 MinK: 0.9256128668785095


Epoch: 900, Loss:0.14473535120487213, MaxK: 1.0548874139785767 MinK: 0.9053400158882141


Epoch: 1000, Loss:0.14186938107013702, MaxK: 1.0651800632476807 MinK: 0.8930226564407349


Epoch: 1100, Loss:0.13927693665027618, MaxK: 1.0709211826324463 MinK: 0.8875967860221863


Epoch: 1200, Loss:0.13640831410884857, MaxK: 1.0973095893859863 MinK: 0.8900249004364014


Epoch: 1300, Loss:0.134062722325325, MaxK: 1.1177562475204468 MinK: 0.8685009479522705


Epoch: 1400, Loss:0.13224957883358002, MaxK: 1.1367247104644775 MinK: 0.8578178882598877


Epoch: 1500,

Epoch: 2000, Loss:0.13058975338935852, MaxK: 1.0917680263519287 MinK: 0.8565473556518555


Epoch: 2100, Loss:0.1293492615222931, MaxK: 1.1138975620269775 MinK: 0.8448495268821716


Epoch: 2200, Loss:0.12707975506782532, MaxK: 1.1342419385910034 MinK: 0.8211839199066162


Epoch: 2300, Loss:0.1259668916463852, MaxK: 1.1480977535247803 MinK: 0.7933701872825623


Epoch: 2400, Loss:0.1248386800289154, MaxK: 1.1423150300979614 MinK: 0.7545086145401001


Epoch: 2500, Loss:0.1243138313293457, MaxK: 1.1438307762145996 MinK: 0.7276654243469238


Epoch: 2600, Loss:0.12258577346801758, MaxK: 1.11429762840271 MinK: 0.6802395582199097


Epoch: 2700, Loss:0.12245199084281921, MaxK: 1.1297613382339478 MinK: 0.6517759561538696


Epoch: 2800, Loss:0.12180979549884796, MaxK: 1.1241090297698975 MinK: 0.6220524907112122


Epoch: 2900, Loss:0.12166440486907959, MaxK: 1.1298032999038696 MinK: 0.6110888719558716


Epoch: 3000, Loss:0.12131484597921371, MaxK: 1.130399465560913 MinK: 0.6101469397544861


Epoch:

Epoch: 3600, Loss:0.06275954097509384, MaxK: 1.1195846796035767 MinK: 0.026772765442728996


Epoch: 3700, Loss:0.06422679871320724, MaxK: 1.1421928405761719 MinK: 0.027383413165807724


Epoch: 3800, Loss:0.06295429170131683, MaxK: 1.1268649101257324 MinK: 0.030282657593488693


Epoch: 3900, Loss:0.063108429312706, MaxK: 1.1286540031433105 MinK: 0.03235786035656929


Epoch: 4000, Loss:0.06408804655075073, MaxK: 1.1391031742095947 MinK: 0.03667834401130676


Epoch: 4100, Loss:0.06435959786176682, MaxK: 1.1418458223342896 MinK: 0.03743693605065346


Epoch: 4200, Loss:0.06407523155212402, MaxK: 1.1461690664291382 MinK: 0.029402416199445724


Epoch: 4300, Loss:0.06314340978860855, MaxK: 1.1262462139129639 MinK: 0.011826497502624989


Epoch: 4400, Loss:0.06351687014102936, MaxK: 1.140488862991333 MinK: 0.010601851157844067


Epoch: 4500, Loss:0.06298745423555374, MaxK: 1.1355100870132446 MinK: 0.0073956991545856


Epoch: 4600, Loss:0.06393514573574066, MaxK: 1.144128441810608 MinK: 0.0056177

Epoch: 5100, Loss:0.06387538462877274, MaxK: 1.1441816091537476 MinK: 0.03393784537911415


Epoch: 5200, Loss:0.06447946280241013, MaxK: 1.1523550748825073 MinK: 0.03923199698328972


Epoch: 5300, Loss:0.06336845457553864, MaxK: 1.1325312852859497 MinK: 0.03017386421561241


Epoch: 5400, Loss:0.06433518975973129, MaxK: 1.1498913764953613 MinK: 0.04133095592260361


Epoch: 5500, Loss:0.06390949338674545, MaxK: 1.1420162916183472 MinK: 0.04140039160847664


Epoch: 5600, Loss:0.06468665599822998, MaxK: 1.1526873111724854 MinK: 0.04170599952340126


Epoch: 5700, Loss:0.06438612192869186, MaxK: 1.1474031209945679 MinK: 0.044963326305150986


Epoch: 5800, Loss:0.06393054872751236, MaxK: 1.1388463973999023 MinK: 0.045341700315475464


Epoch: 5900, Loss:0.06400755792856216, MaxK: 1.1406809091567993 MinK: 0.04213852807879448


Epoch: 6000, Loss:0.06287841498851776, MaxK: 1.121561050415039 MinK: 0.03267575427889824


Epoch: 6100, Loss:0.06429096311330795, MaxK: 1.1437901258468628 MinK: 0.0433258

Epoch: 6600, Loss:0.06286322325468063, MaxK: 1.126990556716919 MinK: 0.016843637451529503


Epoch: 6700, Loss:0.0637902319431305, MaxK: 1.1434717178344727 MinK: 0.014691635966300964


Epoch: 6800, Loss:0.06396042555570602, MaxK: 1.154653787612915 MinK: 0.007477167062461376


Epoch: 6900, Loss:0.06300686299800873, MaxK: 1.1512763500213623 MinK: 0.007832604460418224


Epoch: 7000, Loss:0.06324641406536102, MaxK: 1.162447214126587 MinK: 0.01505741011351347


Epoch: 7100, Loss:0.06316426396369934, MaxK: 1.19184410572052 MinK: 0.02197723649442196


Epoch: 7200, Loss:0.0634208470582962, MaxK: 1.1986216306686401 MinK: 0.02646363340318203


Epoch: 7300, Loss:0.06426677107810974, MaxK: 1.2237181663513184 MinK: 0.030960995703935623


Epoch: 7400, Loss:0.06293823570013046, MaxK: 1.2071939706802368 MinK: 0.0341397300362587


Epoch: 0, Loss:0.23010878264904022, MinCond: 0.8095769882202148, MaxK: 0.19042302668094635, MinK: 0.09312353283166885
Epoch: 100, Loss:0.030157597735524178, MinCond: 0.7691374

Epoch: 6100, Loss:0.007284392137080431, MinCond: 0.09523904323577881, MaxK: 0.9047609567642212, MinK: 0.026474904268980026
Epoch: 6200, Loss:0.007266454864293337, MinCond: 0.10394018888473511, MaxK: 0.8960598111152649, MinK: 0.014903444796800613
Epoch: 6300, Loss:0.0073552182875573635, MinCond: 0.08526778221130371, MaxK: 0.9147322177886963, MinK: 0.028686530888080597
Epoch: 6400, Loss:0.00725167989730835, MinCond: 0.10490643978118896, MaxK: 0.895093560218811, MinK: 0.01161852478981018
Epoch: 6500, Loss:0.0072674741968512535, MinCond: 0.10466516017913818, MaxK: 0.8953348398208618, MinK: 0.00736223952844739



KeyboardInterrupt



In [None]:
# break the code