In [1]:
import numpy as np
import copy
import matplotlib
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

import torch
import torch.nn as nn
import torch.nn.functional as F

import time, pathlib, os, sys

## 2D dataset

In [2]:
class LipschitzNet(nn.Module):
    
    def __init__(self, net, K=1., lamda=1.):
        super().__init__()
        
        self.net = net
        self.lamda = lamda
        self.X = None
        
        self.Y = None
        self.dydx = None
        self.K = K
        
        self.gp = 0
        self.gclipper = torch.Tensor([999])
        
    def forward(self, x):
        self.X = torch.autograd.Variable(x, requires_grad=True)
        
        self.Y = self.net(self.X)
        y = self.Y+0.
        y.register_hook(self.scale_gradient_back)
        return y
    
    def scale_gradient_back(self, grad):
#         print("Old", grad.shape)
#         grad = torch.min(torch.max(grad, -self.gclipper), self.gclipper)
#         print("New", grad.shape)
#         return grad
#         return torch.min(torch.max(grad, -self.gclipper), self.gclipper)
        return torch.minimum(torch.maximum(grad, -self.gclipper), self.gclipper)

    
    def get_dydx(self):
        self.dydx = torch.autograd.grad(outputs=self.Y, inputs=self.X,
                                    grad_outputs=torch.ones_like(self.Y),
                                    only_inputs=True, retain_graph=True, create_graph=True)[0]
        return self.dydx
    
    def get_gradient_penalty(self):
#         m = self.dydx.shape[0]
        dydx_norm = torch.norm(self.dydx, p=2, dim=1, keepdim=True)
#         self.cond = -(torch.abs(dydx_norm/self.K) -1.)
        self.cond = -dydx_norm/self.K +1.
        
#         a=-20
#         intolerables = torch.log(torch.exp(a*(self.cond-0.1))+1)/a
#         dydx_norm = torch.norm(self.dydx.data, p=2, dim=1, keepdim=True)
#         self.gp = 0.5*((intolerables*5)**2).mean()*self.lamda

#         intolerables = torch.clamp(F.softplus(self.cond-0.1, beta=-20), -1, 1)
        intolerables = F.softplus(self.cond-0.1, beta=-20)
#         intolerables = F.softplus(F.softplus(self.cond-0.1, beta=-20)+2, beta=5)-2

        self.gp = (self.smooth_l1(intolerables*5)).mean()*self.lamda
        
    
#         self.gp = 0.5*(intolerables**2).sum()*self.lamda
        
#         intolerables = torch.min(self.cond-1e-1, torch.zeros_like(self.cond))
#         self.gp = 0.5*(intolerables**2).mean()*self.lamda
#         self.gp = torch.abs(intolerables).mean()*self.lamda
#         self.gp = 0.5*(torch.abs(intolerables)+intolerables**2).mean()*self.lamda

        return self.gp
    
    ##
    def get_gradient_clipper(self):
        with torch.no_grad():
            cond = self.cond.data
            linear_mask = cond>0.14845
#             print(dydx.shape, linear_mask.shape)
            a = 20.
            gclipper = -((1.05*(cond-1))**4)+1
            gclipper = torch.log(torch.exp(a*gclipper)+1)/a
#             print(gclipper.shape)
            gc2 = 3*cond-0.0844560006
#             print(gc2.shape)
            gclipper[linear_mask] = gc2[linear_mask]
#             print(gclipper.shape)
#             gclipper = torch.clamp(gclipper, min=0.01)
            self.gclipper = gclipper
            
        return self.gclipper

    def smooth_l1(self, x, beta=1):
        mask = x<beta
        y = torch.empty_like(x)
        y[mask] = 0.5*(x[mask]**2)/beta
        y[~mask] = torch.abs(x[~mask])-0.5*beta
        return y
    
    def compute_penalty_and_clipper(self):
        self.get_dydx()
        self.get_gradient_penalty()
        self.get_gradient_clipper()
        return

In [3]:
class GradientPenaltyNet(nn.Module):
    
    def __init__(self, net, K=1., lamda=1.):
        super().__init__()
        
        self.net = net
        self.lamda = lamda
        self.X = None
        
        self.Y = None
        self.dydx = None
        self.K = K
        
        self.gp = 0
        
    def forward(self, x):
        self.X = torch.autograd.Variable(x, requires_grad=True)
        
        self.Y = self.net(self.X)
        y = self.Y+0.
        return y

    def get_dydx(self):
        self.dydx = torch.autograd.grad(outputs=self.Y, inputs=self.X,
                                    grad_outputs=torch.ones_like(self.Y),
                                    only_inputs=True, retain_graph=True, create_graph=True)[0]
        return self.dydx
    
    def get_gradient_penalty(self):
        dydx_norm = torch.norm(self.dydx, p=2, dim=1, keepdim=True)
        
        self.gp = ((dydx_norm-1)**2).mean()*self.lamda
        return self.gp
    
    def MSE_loss(self, diff):
        return 0.5*(diff**2).mean()
    
    def compute_penalty(self):
        self.get_dydx()
        self.get_gradient_penalty()
        return

In [4]:
class LipschitzPenaltyNet(nn.Module):
    
    def __init__(self, net, K=1., lamda=1.):
        super().__init__()
        
        self.net = net
        self.lamda = lamda
        self.X = None
        
        self.Y = None
        self.dydx = None
        self.K = K
        
        self.gp = 0
        
    def forward(self, x):
        self.X = torch.autograd.Variable(x, requires_grad=True)
        
        self.Y = self.net(self.X)
        y = self.Y+0.
        return y

    def get_dydx(self):
        self.dydx = torch.autograd.grad(outputs=self.Y, inputs=self.X,
                                    grad_outputs=torch.ones_like(self.Y),
                                    only_inputs=True, retain_graph=True, create_graph=True)[0]
        return self.dydx
    
    def get_gradient_penalty(self):
        dydx_norm = torch.norm(self.dydx, p=2, dim=1, keepdim=True)
        
        self.gp = ((torch.maximum(dydx_norm-1, torch.Tensor([0])))**2).mean()*self.lamda
#         self.gp = ((dydx_norm-1)**2).mean()*self.lamda
        return self.gp
    
    def MSE_loss(self, diff):
        return 0.5*(diff**2).mean()
    
    def compute_penalty(self):
        self.get_dydx()
        self.get_gradient_penalty()
        return

In [5]:
log_file = './data_collection/00_Gradient_Comparision/diff_grads_auto_.txt'

In [6]:
%matplotlib tk

for data_indx in range(3):
# for data_indx in [2]:
    if data_indx == 1:
        # The two-dimensional domain of the fit.....
        ########https://scipython.com/blog/non-linear-least-squares-fitting-of-a-two-dimensional-data/#########
        x1min, x1max, nx1 = -5, 6, 75
        x2min, x2max, nx2 = -3, 7, 75
        x1, x2 = np.linspace(x1min, x1max, nx1), np.linspace(x2min, x2max, nx2)
        X1, X2 = np.meshgrid(x1, x2)

        # Our function to fit is going to be a sum of two-dimensional Gaussians
        def gaussian(x1, x2, x10, x20, x1alpha, x2alpha, A):
            return A * np.exp( -((x1-x10)/x1alpha)**2 -((x2-x20)/x2alpha)**2)

        # A list of the Gaussian parameters: x10, x20, x1alpha, x2alpha, A
        gprms = [(0, 2, 2.5, 5.4, 1.5),
                 (-1, 4, 6, 2.5, 1.8),
                 (-3, -0.5, 1, 2, 4),
                 (3, 0.5, 2, 1, 5)
                ]

        # Standard deviation of normally-distributed noise to add in generating
        # our test function to fit.
        # The function to be fit is Z.
        Y = np.zeros(X1.shape)
        for p in gprms:
            Y += gaussian(X1, X2, *p)
        ### Adding noise to the data
        # noise_sigma = 0.1
        # Z += noise_sigma * np.random.randn(*Z.shape)

        ####Scaling the data to range -1,1
        X1 = 2*(X1 - X1.min())/(X1.max() - X1.min()) -1
        X2 = 2*(X2 - X2.min())/(X2.max() - X2.min()) -1
        Y = 2*(Y - Y.min())/(Y.max() - Y.min()) -1
        Y = Y/2

        x1 = X1.reshape(-1)
        x2 = X2.reshape(-1)

        xx = torch.Tensor(np.c_[x1, x2])
        yy = torch.Tensor(Y.reshape(-1,1))
        
    elif data_indx == 0:
        num_points = 50
        X1 = np.linspace(-2.5, 1.5, num_points)
        # X1 = np.linspace(-2.5, 0, num_points)
        X2 = np.linspace(-2, 4, num_points)
        # X2 = np.linspace(-2, 2, num_points)
        X1, X2 = np.meshgrid(X1, X2)
        Y = np.sin(np.sqrt(X1**2 + X2**2))*2-1.

        ####Scaling the data to range -1,1
        X1 = 2*(X1 - X1.min())/(X1.max() - X1.min()) -1
        X2 = 2*(X2 - X2.min())/(X2.max() - X2.min()) -1
        Y = 2*(Y - Y.min())/(Y.max() - Y.min()) -1
        Y = Y

        x1 = X1.reshape(-1)
        x2 = X2.reshape(-1)

        xx = torch.Tensor(np.c_[x1, x2])
        yy = torch.Tensor(Y.reshape(-1,1))
        
    elif data_indx == 2:
        def twospirals(n_points, noise=.5 , ang=720):
            """
             Returns the two spirals dataset.
            """
            n = np.sqrt(np.random.rand(n_points,1)) * ang * (2*np.pi)/360
            d1x = -np.cos(n)*n + np.random.rand(n_points,1) * noise
            d1y = np.sin(n)*n + np.random.rand(n_points,1) * noise
            return (np.vstack((np.hstack((d1x,d1y)),np.hstack((-d1x,-d1y)))), 
                    np.hstack((np.zeros(n_points),np.ones(n_points))))

        np.random.seed(987)
        # x, y = twospirals(200, ang=420)
        x, y = twospirals(200, ang=400)
        x, Y = x/10, y.reshape(-1)
        X1, X2 = x[:,0], x[:,1]
        
        ####Scaling the data to range -1,1
        X1 = 2*(X1 - X1.min())/(X1.max() - X1.min()) -1
        X2 = 2*(X2 - X2.min())/(X2.max() - X2.min()) -1

        x1 = X1.reshape(-1)
        x2 = X2.reshape(-1)

        xx = torch.Tensor(np.c_[x1, x2])
        yy = torch.Tensor(Y.reshape(-1,1))
    else:
        raise NotImplementedError("Not implemented Error")
        
        
    
    network_seeds = [147, 258, 369]
    learning_rate = 0.005 #0.01 0.005
    EPOCHS = 7500 #5000 #7500

    

    if data_indx == 2:
        use_sigmoid = True
        criterion = nn.BCELoss()
        actf = nn.LeakyReLU
        
        class Sigmoid4(nn.Module):
            def forward(self,x):
                return torch.sigmoid(x*4)

        nn.Sigmoid = Sigmoid4
    else:
        criterion = nn.MSELoss()
        use_sigmoid = False
        actf = nn.ELU

    
    for lambda_ in [1, 3]:
        
        original_stdout = sys.stdout
        sys.stdout = open(log_file, 'a')
        print("===================================================")
        print(f"Dataset {data_indx}; lambda_={lambda_}")
        print("===================================================")
            # Reset the standard output
        sys.stdout.close()
        sys.stdout = original_stdout 

        #################################################
        per_step_time = []
        info_per_seed = []

        fig = plt.figure(figsize=(15,6))
        ax = fig.add_subplot(121,projection='3d')
        ax2 = fig.add_subplot(122)
        # ax.view_init(28,20)

        for ns in network_seeds:
            torch.manual_seed(ns)

            net_lips = nn.Sequential(nn.Linear(2,10),
                                     actf(),
                                     nn.Linear(10,10),
                                     actf(),
                                     nn.Linear(10,1),
                                     nn.Sigmoid() if use_sigmoid else nn.Identity())

            lipsNet = LipschitzNet(net_lips, K=1, lamda=lambda_)


            optimizer = torch.optim.Adam(lipsNet.parameters(), lr=learning_rate)

            for epoch in range(EPOCHS):
                start = time.time()

                yout = lipsNet(xx)    
                lipsNet.compute_penalty_and_clipper()
                loss = criterion(yout, yy) + lipsNet.gp

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                per_step_time.append(time.time()-start)

                if epoch%100 == 0:
                    min_val = float(lipsNet.cond.min())
                    max_k = float(torch.norm(lipsNet.dydx, p=2, dim=1, keepdim=True).max())
                    min_k = float(torch.norm(lipsNet.dydx, p=2, dim=1, keepdim=True).min())
                    print(f'Epoch: {epoch}, Loss:{float(loss-lipsNet.gp)}, MinCond: {min_val}, MaxK: {max_k}, MinK: {min_k}')

                    ax.clear()
                    ax.scatter(X1, X2, yy.data.numpy().reshape(-1), marker= '.')
                    ax.scatter(X1, X2, yout.data.numpy().reshape(-1), color='r', marker='.')

                    ax2.clear()
                    if use_sigmoid:
                        ax2.scatter(X1, X2, c=yout.data.numpy().reshape(Y.shape), s=70, edgecolors='k', lw=0.5)
                        ax2.scatter(X1, X2, c=Y, marker='.')
                    else:
                        ax2.contourf(X1, X2, yout.data.numpy().reshape(Y.shape), levels=20)

                    fig.canvas.draw()
                    plt.pause(0.01)

            if use_sigmoid:
                acc = float(((yout>0.5).type(torch.float32)==yy).type(torch.float32).mean()*100)
                info_per_seed.append(
                    f'Epoch: {epoch}, Loss:{float(loss-lipsNet.gp)}, MinCond: {min_val}, MaxK: {max_k}, MinK: {min_k}, Acc: {acc}'
                )
            else:
                info_per_seed.append(f'Epoch: {epoch}, Loss:{float(loss-lipsNet.gp)}, MinCond: {min_val}, MaxK: {max_k}, MinK: {min_k}')

                
        original_stdout = sys.stdout
        sys.stdout = open(log_file, 'a')
        print("###############################################")
        print("GC-GP")
        for i in range(len(info_per_seed)):
            print(i, network_seeds[i], info_per_seed[i])
        print(np.mean(per_step_time), np.std(per_step_time))
        print("###############################################")
        # Reset the standard output
        sys.stdout.close()
        sys.stdout = original_stdout 
        
        
        
        plt.close()
        #################################################    
        
        if lambda_ == 1:
            per_step_time = []
            info_per_seed = []

            fig = plt.figure(figsize=(15,6))
            ax = fig.add_subplot(121,projection='3d')
            ax2 = fig.add_subplot(122)
            # ax.view_init(28,20)

            for ns in network_seeds:
                torch.manual_seed(ns)

                net_lips = nn.Sequential(nn.utils.spectral_norm(nn.Linear(2,10)),
                                         actf(),
                                         nn.utils.spectral_norm(nn.Linear(10,10)),
                                         actf(),
                                         nn.utils.spectral_norm(nn.Linear(10,1)),
                                         nn.Sigmoid() if use_sigmoid else nn.Identity())

                ## no clipper and loss used..., just for getting gradient norm.
                snNet = LipschitzNet(net_lips)

                optimizer = torch.optim.Adam(snNet.parameters(), lr=learning_rate)

                for epoch in range(EPOCHS):
                    start = time.time()

                    yout = snNet(xx)
                    loss = criterion(yout, yy)

                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()

                    per_step_time.append(time.time()-start)

                    if epoch%100 == 0:
                        yout = snNet(xx)
                        snNet.get_dydx()

                        max_k = float(torch.norm(snNet.dydx, p=2, dim=1, keepdim=True).max())
                        min_k = float(torch.norm(snNet.dydx, p=2, dim=1, keepdim=True).min())
                        print(f'Epoch: {epoch}, Loss:{float(loss)}, MaxK: {max_k}, MinK: {min_k}')

                        ax.clear()
                        ax.scatter(X1, X2, yy.data.numpy().reshape(-1), marker= '.')
                        ax.scatter(X1, X2, yout.data.numpy().reshape(-1), color='r', marker='.')

                        ax2.clear()
                        if use_sigmoid:
                            ax2.scatter(X1, X2, c=yout.data.numpy().reshape(Y.shape), s=70, edgecolors='k', lw=0.5)
                            ax2.scatter(X1, X2, c=Y, marker='.')
                        else:
                            ax2.contourf(X1, X2, yout.data.numpy().reshape(Y.shape), levels=20)

                        fig.canvas.draw()
                        plt.pause(0.01)
                        print("\n")

                if use_sigmoid:
                    acc = float(((yout>0.5).type(torch.float32)==yy).type(torch.float32).mean()*100)
                    info_per_seed.append(
                        f'Epoch: {epoch}, Loss:{float(loss)}, MaxK: {max_k}, MinK: {min_k}, Acc: {acc}'
                    )
                else:
                    info_per_seed.append(f'Epoch: {epoch}, Loss:{float(loss)}, MaxK: {max_k}, MinK: {min_k}')


            original_stdout = sys.stdout
            sys.stdout = open(log_file, 'a')
            print("###############################################")
            print("SN")
            for i in range(len(info_per_seed)):
                print(i, network_seeds[i], info_per_seed[i])
            print(np.mean(per_step_time), np.std(per_step_time))
            print("###############################################")
            # Reset the standard output
            sys.stdout.close()
            sys.stdout = original_stdout 

            plt.close()
        #################################################    
        
        per_step_time = []
        info_per_seed = []

        fig = plt.figure(figsize=(15,6))
        ax = fig.add_subplot(121,projection='3d')
        ax2 = fig.add_subplot(122)
        # ax.view_init(28,20)

        for ns in network_seeds:
            torch.manual_seed(ns)

            net_lips = nn.Sequential(nn.Linear(2,10),
                                     actf(),
                                     nn.Linear(10,10),
                                     actf(),
                                     nn.Linear(10,1),
                                     nn.Sigmoid() if use_sigmoid else nn.Identity())

            gpNet = GradientPenaltyNet(net_lips, K=1, lamda=lambda_)

            optimizer = torch.optim.Adam(gpNet.parameters(), lr=learning_rate)
            for epoch in range(EPOCHS):
                start = time.time()

                yout = gpNet(xx)    
                gpNet.compute_penalty()
                loss = criterion(yout, yy) + gpNet.gp


                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                per_step_time.append(time.time()-start)
                if epoch%100 == 0:
                    allk = torch.norm(gpNet.dydx, p=2, dim=1, keepdim=True)
                    max_k = float(allk.max())
                    min_k = float(allk.min())

                    print(f'Epoch: {epoch}, Loss:{float(loss-gpNet.gp)}, MaxK: {max_k} MinK: {min_k}')

                    ax.clear()
                    ax.scatter(X1, X2, yy.data.numpy().reshape(-1), marker= '.')
                    ax.scatter(X1, X2, yout.data.numpy().reshape(-1), color='r', marker='.')

                    ax2.clear()
                    if use_sigmoid:
                        ax2.scatter(X1, X2, c=yout.data.numpy().reshape(Y.shape), s=70, edgecolors='k', lw=0.5)
                        ax2.scatter(X1, X2, c=Y, marker='.')
                    else:
                        ax2.contourf(X1, X2, yout.data.numpy().reshape(Y.shape), levels=20)

                    fig.canvas.draw()
                    plt.pause(0.01)
                    print("\n")

            if use_sigmoid:
                acc = float(((yout>0.5).type(torch.float32)==yy).type(torch.float32).mean()*100)
                info_per_seed.append(
                    f'Epoch: {epoch}, Loss:{float(loss-gpNet.gp)}, MaxK: {max_k}, MinK: {min_k}, Acc: {acc}'
                )
            else:
                info_per_seed.append(f'Epoch: {epoch}, Loss:{float(loss-gpNet.gp)}, MaxK: {max_k}, MinK: {min_k}')

                
        original_stdout = sys.stdout
        sys.stdout = open(log_file, 'a')
        print("###############################################")
        print("GP")
        for i in range(len(info_per_seed)):
            print(i, network_seeds[i], info_per_seed[i])
        print(np.mean(per_step_time), np.std(per_step_time))
        print("###############################################")
        # Reset the standard output
        sys.stdout.close()
        sys.stdout = original_stdout 
        
        plt.close()
        #################################################    

        per_step_time = []
        info_per_seed = []

        fig = plt.figure(figsize=(15,6))
        ax = fig.add_subplot(121,projection='3d')
        ax2 = fig.add_subplot(122)
        # ax.view_init(28,20)

        for ns in network_seeds:
            torch.manual_seed(ns)

            net_lips = nn.Sequential(nn.Linear(2,10),
                                     actf(),
                                     nn.Linear(10,10),
                                     actf(),
                                     nn.Linear(10,1),
                                     nn.Sigmoid() if use_sigmoid else nn.Identity())

            lpNet = LipschitzPenaltyNet(net_lips, K=1, lamda=lambda_)

            optimizer = torch.optim.Adam(lpNet.parameters(), lr=learning_rate)

            for epoch in range(EPOCHS):
                start = time.time()

                yout = lpNet(xx)    
                lpNet.compute_penalty()
                loss = criterion(yout, yy) + lpNet.gp


                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                per_step_time.append(time.time()-start)

                if epoch%100 == 0:
                    allk = torch.norm(lpNet.dydx, p=2, dim=1, keepdim=True)
                    max_k = float(allk.max())
                    min_k = float(allk.min())

                    print(f'Epoch: {epoch}, Loss:{float(loss-lpNet.gp)}, MaxK: {max_k} MinK: {min_k}')

                    ax.clear()
                    ax.scatter(X1, X2, yy.data.numpy().reshape(-1), marker= '.')
                    ax.scatter(X1, X2, yout.data.numpy().reshape(-1), color='r', marker='.')

                    ax2.clear()
                    if use_sigmoid:
                        ax2.scatter(X1, X2, c=yout.data.numpy().reshape(Y.shape), s=70, edgecolors='k', lw=0.5)
                        ax2.scatter(X1, X2, c=Y, marker='.')
                    else:
                        ax2.contourf(X1, X2, yout.data.numpy().reshape(Y.shape), levels=20)

                    fig.canvas.draw()
                    plt.pause(0.01)
                    print("\n")
            if use_sigmoid:
                acc = float(((yout>0.5).type(torch.float32)==yy).type(torch.float32).mean()*100)
                info_per_seed.append(
                    f'Epoch: {epoch}, Loss:{float(loss-lpNet.gp)}, MaxK: {max_k}, MinK: {min_k}, Acc: {acc}'
                )
            else:
                info_per_seed.append(f'Epoch: {epoch}, Loss:{float(loss-lpNet.gp)}, MaxK: {max_k}, MinK: {min_k}')

                
        original_stdout = sys.stdout
        sys.stdout = open(log_file, 'a')
        print("###############################################")
        print("LP")
        for i in range(len(info_per_seed)):
            print(i, network_seeds[i], info_per_seed[i])
        print(np.mean(per_step_time), np.std(per_step_time))
        print("###############################################")
        # Reset the standard output
        sys.stdout.close()
        sys.stdout = original_stdout 
        
        plt.close()

        #################################################    
                

Epoch: 0, Loss:0.6975633502006531, MinCond: 0.6876180171966553, MaxK: 0.3123819828033447, MinK: 0.0034458984155207872
Epoch: 100, Loss:0.3903731107711792, MinCond: -0.00606846809387207, MaxK: 1.006068468093872, MinK: 0.1219935417175293
Epoch: 200, Loss:0.37399786710739136, MinCond: -0.00011873245239257812, MaxK: 1.0001187324523926, MinK: 0.10634154081344604
Epoch: 300, Loss:0.3659294843673706, MinCond: -0.01186370849609375, MaxK: 1.0118637084960938, MinK: 0.06986920535564423
Epoch: 400, Loss:0.3440655469894409, MinCond: -0.029628992080688477, MaxK: 1.0296289920806885, MinK: 0.060880642384290695
Epoch: 500, Loss:0.34734129905700684, MinCond: -0.016026854515075684, MaxK: 1.0160268545150757, MinK: 0.15523108839988708
Epoch: 600, Loss:0.3418020009994507, MinCond: -0.0063048601150512695, MaxK: 1.0063048601150513, MinK: 0.1394931823015213
Epoch: 700, Loss:0.32673439383506775, MinCond: -0.10525310039520264, MaxK: 1.1052531003952026, MinK: 0.1277502477169037
Epoch: 800, Loss:0.3331665098667145

Epoch: 6800, Loss:0.31670668721199036, MinCond: -0.05692088603973389, MaxK: 1.0569208860397339, MinK: 0.21828223764896393
Epoch: 6900, Loss:0.3264968991279602, MinCond: -0.02300560474395752, MaxK: 1.0230056047439575, MinK: 0.25197485089302063
Epoch: 7000, Loss:0.3254600763320923, MinCond: -0.25430142879486084, MaxK: 1.2543014287948608, MinK: 0.13625076413154602
Epoch: 7100, Loss:0.3266500234603882, MinCond: -0.135756254196167, MaxK: 1.135756254196167, MinK: 0.1560475379228592
Epoch: 7200, Loss:0.31079649925231934, MinCond: -0.2703385353088379, MaxK: 1.270338535308838, MinK: 0.19870898127555847
Epoch: 7300, Loss:0.3095064163208008, MinCond: -0.23524773120880127, MaxK: 1.2352477312088013, MinK: 0.12787607312202454
Epoch: 7400, Loss:0.3117102086544037, MinCond: -0.196435809135437, MaxK: 1.196435809135437, MinK: 0.18211457133293152
Epoch: 0, Loss:0.7046768069267273, MinCond: 0.9105192422866821, MaxK: 0.08948075771331787, MinK: 0.008457088842988014
Epoch: 100, Loss:0.3881162703037262, MinCo

Epoch: 6100, Loss:0.35700523853302, MinCond: -0.02976822853088379, MaxK: 1.0297682285308838, MinK: 0.07303755730390549
Epoch: 6200, Loss:0.3569834530353546, MinCond: -0.0370326042175293, MaxK: 1.0370326042175293, MinK: 0.07236931473016739
Epoch: 6300, Loss:0.3570425808429718, MinCond: -0.03145945072174072, MaxK: 1.0314594507217407, MinK: 0.07353591173887253
Epoch: 6400, Loss:0.35603591799736023, MinCond: -0.03693842887878418, MaxK: 1.0369384288787842, MinK: 0.07354677468538284
Epoch: 6500, Loss:0.3561100661754608, MinCond: -0.037582993507385254, MaxK: 1.0375829935073853, MinK: 0.07042764127254486
Epoch: 6600, Loss:0.3564469516277313, MinCond: -0.035990357398986816, MaxK: 1.0359903573989868, MinK: 0.0719381794333458
Epoch: 6700, Loss:0.35601332783699036, MinCond: -0.03369283676147461, MaxK: 1.0336928367614746, MinK: 0.07301080971956253
Epoch: 6800, Loss:0.35668179392814636, MinCond: -0.03609728813171387, MaxK: 1.0360972881317139, MinK: 0.07464022934436798
Epoch: 6900, Loss:0.35718199610

Epoch: 5500, Loss:0.3063303828239441, MinCond: -0.049677491188049316, MaxK: 1.0496774911880493, MinK: 0.3296522796154022
Epoch: 5600, Loss:0.3041852116584778, MinCond: -0.055362582206726074, MaxK: 1.055362582206726, MinK: 0.22002571821212769
Epoch: 5700, Loss:0.2954280972480774, MinCond: -0.09045791625976562, MaxK: 1.0904579162597656, MinK: 0.25747427344322205
Epoch: 5800, Loss:0.2912251949310303, MinCond: -0.36633622646331787, MaxK: 1.3663362264633179, MinK: 0.134017214179039
Epoch: 5900, Loss:0.30570274591445923, MinCond: -0.20059239864349365, MaxK: 1.2005923986434937, MinK: 0.16574810445308685
Epoch: 6000, Loss:0.2977698743343353, MinCond: -0.2048705816268921, MaxK: 1.204870581626892, MinK: 0.10116113722324371
Epoch: 6100, Loss:0.29729729890823364, MinCond: -0.15353333950042725, MaxK: 1.1535333395004272, MinK: 0.1736096441745758
Epoch: 6200, Loss:0.3004051446914673, MinCond: -0.13797259330749512, MaxK: 1.1379725933074951, MinK: 0.09618805348873138
Epoch: 6300, Loss:0.299214601516723



Epoch: 6300, Loss:0.4307621717453003, MaxK: 0.8288027048110962, MinK: 0.18565933406352997


Epoch: 6400, Loss:0.42999228835105896, MaxK: 0.8386061191558838, MinK: 0.19332851469516754


Epoch: 6500, Loss:0.4328664541244507, MaxK: 0.7815640568733215, MinK: 0.20083048939704895


Epoch: 6600, Loss:0.43156570196151733, MaxK: 0.8155107498168945, MinK: 0.19839385151863098


Epoch: 6700, Loss:0.4300651550292969, MaxK: 0.8357378840446472, MinK: 0.19534337520599365


Epoch: 6800, Loss:0.43570664525032043, MaxK: 0.8086941242218018, MinK: 0.19929978251457214


Epoch: 6900, Loss:0.42819744348526, MaxK: 0.8239133954048157, MinK: 0.19157205522060394


Epoch: 7000, Loss:0.43471986055374146, MaxK: 0.7788218259811401, MinK: 0.2032061070203781


Epoch: 7100, Loss:0.42973193526268005, MaxK: 0.846164345741272, MinK: 0.19492430984973907


Epoch: 7200, Loss:0.43561050295829773, MaxK: 0.7829843163490295, MinK: 0.191166952252388


Epoch: 7300, Loss:0.43423178791999817, MaxK: 0.7823600769042969, MinK: 0.17571



Epoch: 300, Loss:0.44125521183013916, MaxK: 0.8708763718605042, MinK: 0.18905971944332123


Epoch: 400, Loss:0.4341660737991333, MaxK: 0.9549058675765991, MinK: 0.1803920418024063


Epoch: 500, Loss:0.4348195791244507, MaxK: 1.0105252265930176, MinK: 0.12839165329933167


Epoch: 600, Loss:0.4358968436717987, MaxK: 0.9362443089485168, MinK: 0.1887371689081192


Epoch: 700, Loss:0.44001683592796326, MaxK: 0.9396427869796753, MinK: 0.1737593412399292


Epoch: 800, Loss:0.44197237491607666, MaxK: 0.8673434257507324, MinK: 0.18315038084983826


Epoch: 900, Loss:0.43403738737106323, MaxK: 0.8981805443763733, MinK: 0.18625608086585999


Epoch: 1000, Loss:0.43376052379608154, MaxK: 0.9556382894515991, MinK: 0.1832478940486908


Epoch: 1100, Loss:0.43908795714378357, MaxK: 0.9377568364143372, MinK: 0.18406152725219727


Epoch: 1200, Loss:0.4366447329521179, MaxK: 0.8849114179611206, MinK: 0.18961000442504883


Epoch: 1300, Loss:0.4320463538169861, MaxK: 0.9336695075035095, MinK: 0.18825362622



Epoch: 1800, Loss:0.13403502106666565, MaxK: 1.6021976470947266 MinK: 0.13687102496623993


Epoch: 1900, Loss:0.1346910297870636, MaxK: 1.5779409408569336 MinK: 0.17795556783676147


Epoch: 2000, Loss:0.13483142852783203, MaxK: 2.0368285179138184 MinK: 0.2371259480714798


Epoch: 2100, Loss:0.135044664144516, MaxK: 1.9274944067001343 MinK: 0.317592978477478


Epoch: 2200, Loss:0.13566777110099792, MaxK: 1.8004249334335327 MinK: 0.3267371654510498


Epoch: 2300, Loss:0.13395226001739502, MaxK: 1.962682843208313 MinK: 0.4250860810279846


Epoch: 2400, Loss:0.12935177981853485, MaxK: 1.9566841125488281 MinK: 0.41985616087913513


Epoch: 2500, Loss:0.12961487472057343, MaxK: 1.9070508480072021 MinK: 0.4506291449069977


Epoch: 2600, Loss:0.13361622393131256, MaxK: 2.2572996616363525 MinK: 0.42854219675064087


Epoch: 2700, Loss:0.12639373540878296, MaxK: 1.8983817100524902 MinK: 0.41575348377227783


Epoch: 2800, Loss:0.12453000247478485, MaxK: 1.7768391370773315 MinK: 0.4155802130699157



Epoch: 3300, Loss:0.1437026858329773, MaxK: 1.5827834606170654 MinK: 0.01965790055692196


Epoch: 3400, Loss:0.14415797591209412, MaxK: 3.0046732425689697 MinK: 0.017385950312018394


Epoch: 3500, Loss:0.14294776320457458, MaxK: 1.7239083051681519 MinK: 0.016412371769547462


Epoch: 3600, Loss:0.13835948705673218, MaxK: 1.593291997909546 MinK: 0.01439408678561449


Epoch: 3700, Loss:0.13379469513893127, MaxK: 1.5668702125549316 MinK: 0.013788281939923763


Epoch: 3800, Loss:0.13573519885540009, MaxK: 1.6855247020721436 MinK: 0.013694464229047298


Epoch: 3900, Loss:0.13765603303909302, MaxK: 2.0508670806884766 MinK: 0.011616483330726624


Epoch: 4000, Loss:0.13338437676429749, MaxK: 1.6835582256317139 MinK: 0.01109400112181902


Epoch: 4100, Loss:0.13218046724796295, MaxK: 2.037520170211792 MinK: 0.01091101486235857


Epoch: 4200, Loss:0.13472507894039154, MaxK: 1.9439913034439087 MinK: 0.38024428486824036


Epoch: 4300, Loss:0.13076969981193542, MaxK: 1.8360192775726318 MinK: 0.3425



Epoch: 4800, Loss:0.11327000707387924, MaxK: 1.687666893005371 MinK: 0.051610495895147324


Epoch: 4900, Loss:0.11292479932308197, MaxK: 1.4968305826187134 MinK: 0.16026915609836578


Epoch: 5000, Loss:0.11108960211277008, MaxK: 1.5362496376037598 MinK: 0.1959238052368164


Epoch: 5100, Loss:0.11327699571847916, MaxK: 1.558508276939392 MinK: 0.22690333425998688


Epoch: 5200, Loss:0.11240187287330627, MaxK: 1.4301047325134277 MinK: 0.27252113819122314


Epoch: 5300, Loss:0.11162948608398438, MaxK: 1.4507390260696411 MinK: 0.15142735838890076


Epoch: 5400, Loss:0.1119355633854866, MaxK: 1.4488861560821533 MinK: 0.33538490533828735


Epoch: 5500, Loss:0.10812292993068695, MaxK: 1.5299180746078491 MinK: 0.16604119539260864


Epoch: 5600, Loss:0.11943204700946808, MaxK: 1.9823235273361206 MinK: 0.24322620034217834


Epoch: 5700, Loss:0.11645525693893433, MaxK: 1.6786830425262451 MinK: 0.29411137104034424


Epoch: 5800, Loss:0.1180371567606926, MaxK: 1.64787757396698 MinK: 0.229114770889



Epoch: 7200, Loss:0.01814347133040428, MaxK: 1.4290087223052979 MinK: 0.0


Epoch: 7300, Loss:0.0181284062564373, MaxK: 1.3886305093765259 MinK: 0.0


Epoch: 7400, Loss:0.018139677122235298, MaxK: 1.375279188156128 MinK: 0.0


Epoch: 0, Loss:0.7046768069267273, MaxK: 0.08948075771331787 MinK: 0.008457088842988014


Epoch: 100, Loss:0.30767932534217834, MaxK: 1.5249332189559937 MinK: 0.11415022611618042


Epoch: 200, Loss:0.09202246367931366, MaxK: 1.7884386777877808 MinK: 0.02721574902534485


Epoch: 300, Loss:0.06670205295085907, MaxK: 2.8895986080169678 MinK: 0.0032798536121845245


Epoch: 400, Loss:0.054883766919374466, MaxK: 1.6395126581192017 MinK: 5.176010017748922e-05


Epoch: 500, Loss:0.047451432794332504, MaxK: 2.2326390743255615 MinK: 0.0


Epoch: 600, Loss:0.0391235426068306, MaxK: 2.753605604171753 MinK: 0.0


Epoch: 700, Loss:0.034294020384550095, MaxK: 3.199357271194458 MinK: 0.0


Epoch: 800, Loss:0.02932814508676529, MaxK: 1.667211651802063 MinK: 0.0


Epoch: 900, Lo



Epoch: 2800, Loss:0.023557256907224655, MaxK: 1.5503814220428467 MinK: 0.0


Epoch: 2900, Loss:0.023121679201722145, MaxK: 1.525391936302185 MinK: 0.0


Epoch: 3000, Loss:0.022273950278759003, MaxK: 1.5461722612380981 MinK: 0.0


Epoch: 3100, Loss:0.021460503339767456, MaxK: 1.5410888195037842 MinK: 0.0


Epoch: 3200, Loss:0.021190378814935684, MaxK: 1.6803295612335205 MinK: 0.0


Epoch: 3300, Loss:0.020873794332146645, MaxK: 1.4834572076797485 MinK: 0.0


Epoch: 3400, Loss:0.02045670710504055, MaxK: 1.463491678237915 MinK: 0.0


Epoch: 3500, Loss:0.02008618600666523, MaxK: 1.448476791381836 MinK: 0.0


Epoch: 3600, Loss:0.02009447105228901, MaxK: 1.4195661544799805 MinK: 0.0


Epoch: 3700, Loss:0.01972232572734356, MaxK: 1.4220658540725708 MinK: 0.0


Epoch: 3800, Loss:0.019398340955376625, MaxK: 1.419029712677002 MinK: 0.0


Epoch: 3900, Loss:0.018966902047395706, MaxK: 1.3508355617523193 MinK: 0.0


Epoch: 4000, Loss:0.018976498395204544, MaxK: 1.3643139600753784 MinK: 0.0


Epoch

Epoch: 3900, Loss:0.3247641324996948, MinCond: -0.03429996967315674, MaxK: 1.0342999696731567, MinK: 0.2817598879337311
Epoch: 4000, Loss:0.3237101435661316, MinCond: -0.015129566192626953, MaxK: 1.015129566192627, MinK: 0.11650705337524414
Epoch: 4100, Loss:0.324494868516922, MinCond: -0.02465200424194336, MaxK: 1.0246520042419434, MinK: 0.2201259583234787
Epoch: 4200, Loss:0.32261523604393005, MinCond: -0.008474469184875488, MaxK: 1.0084744691848755, MinK: 0.22551903128623962
Epoch: 4300, Loss:0.3211103081703186, MinCond: -0.06577074527740479, MaxK: 1.0657707452774048, MinK: 0.19680075347423553
Epoch: 4400, Loss:0.32455337047576904, MinCond: -0.007660627365112305, MaxK: 1.0076606273651123, MinK: 0.19989480078220367
Epoch: 4500, Loss:0.3222790062427521, MinCond: -0.03377091884613037, MaxK: 1.0337709188461304, MinK: 0.1437484174966812
Epoch: 4600, Loss:0.32229629158973694, MinCond: -0.0235670804977417, MaxK: 1.0235670804977417, MinK: 0.09208162128925323
Epoch: 4700, Loss:0.322538077831

Epoch: 3200, Loss:0.35806283354759216, MinCond: 0.03386789560317993, MaxK: 0.9661321043968201, MinK: 0.07781177759170532
Epoch: 3300, Loss:0.3580624759197235, MinCond: 0.031392037868499756, MaxK: 0.9686079621315002, MinK: 0.08001334965229034
Epoch: 3400, Loss:0.3579210937023163, MinCond: 0.028768539428710938, MaxK: 0.9712314605712891, MinK: 0.08067397773265839
Epoch: 3500, Loss:0.3579430878162384, MinCond: 0.03226667642593384, MaxK: 0.9677333235740662, MinK: 0.07937687635421753
Epoch: 3600, Loss:0.35830605030059814, MinCond: 0.03624892234802246, MaxK: 0.9637510776519775, MinK: 0.07931926101446152
Epoch: 3700, Loss:0.35845962166786194, MinCond: 0.025668025016784668, MaxK: 0.9743319749832153, MinK: 0.07768602669239044
Epoch: 3800, Loss:0.3580796420574188, MinCond: 0.03644686937332153, MaxK: 0.9635531306266785, MinK: 0.06922909617424011
Epoch: 3900, Loss:0.3576834499835968, MinCond: 0.03466212749481201, MaxK: 0.965337872505188, MinK: 0.06801077723503113
Epoch: 4000, Loss:0.357942342758178

Epoch: 2500, Loss:0.3341093063354492, MinCond: -0.03397822380065918, MaxK: 1.0339782238006592, MinK: 0.07440657913684845
Epoch: 2600, Loss:0.3344569802284241, MinCond: -0.008208155632019043, MaxK: 1.008208155632019, MinK: 0.046111296862363815
Epoch: 2700, Loss:0.33823710680007935, MinCond: 0.0043604373931884766, MaxK: 0.9956395626068115, MinK: 0.12888212502002716
Epoch: 2800, Loss:0.33771082758903503, MinCond: -0.0310516357421875, MaxK: 1.0310516357421875, MinK: 0.03549986332654953
Epoch: 2900, Loss:0.34136122465133667, MinCond: -0.015797853469848633, MaxK: 1.0157978534698486, MinK: 0.08391039818525314
Epoch: 3000, Loss:0.33939310908317566, MinCond: -0.10937047004699707, MaxK: 1.109370470046997, MinK: 0.12434374541044235
Epoch: 3100, Loss:0.33698153495788574, MinCond: -0.021495342254638672, MaxK: 1.0214953422546387, MinK: 0.11057863384485245
Epoch: 3200, Loss:0.33397915959358215, MinCond: -0.044835686683654785, MaxK: 1.0448356866836548, MinK: 0.12473081797361374
Epoch: 3300, Loss:0.336



Epoch: 2400, Loss:0.2430906593799591, MaxK: 1.8486905097961426 MinK: 0.005947418510913849


Epoch: 2500, Loss:0.2392522096633911, MaxK: 3.485140323638916 MinK: 0.006835754960775375


Epoch: 2600, Loss:0.22421303391456604, MaxK: 1.6527119874954224 MinK: 0.018793756142258644


Epoch: 2700, Loss:0.2194056212902069, MaxK: 3.4282684326171875 MinK: 0.01558589842170477


Epoch: 2800, Loss:0.3202536702156067, MaxK: 1.5624362230300903 MinK: 0.02960946038365364


Epoch: 2900, Loss:0.2850857973098755, MaxK: 1.998638391494751 MinK: 0.3678692877292633


Epoch: 3000, Loss:0.29759156703948975, MaxK: 1.7728230953216553 MinK: 0.5387115478515625


Epoch: 3100, Loss:0.2932659089565277, MaxK: 1.656814694404602 MinK: 0.5231845378875732


Epoch: 3200, Loss:0.2835172414779663, MaxK: 1.7988507747650146 MinK: 0.40182849764823914


Epoch: 3300, Loss:0.27315109968185425, MaxK: 1.7600555419921875 MinK: 0.49249595403671265


Epoch: 3400, Loss:0.2642378807067871, MaxK: 1.6940150260925293 MinK: 0.3911789357662201




Epoch: 4000, Loss:0.24092580378055573, MaxK: 1.76994788646698 MinK: 0.458808034658432


Epoch: 4100, Loss:0.2429259717464447, MaxK: 1.4497075080871582 MinK: 0.45652762055397034


Epoch: 4200, Loss:0.24273543059825897, MaxK: 1.4450753927230835 MinK: 0.45371824502944946


Epoch: 4300, Loss:0.24092143774032593, MaxK: 1.4673545360565186 MinK: 0.45223045349121094


Epoch: 4400, Loss:0.24224600195884705, MaxK: 1.4576692581176758 MinK: 0.4582096040248871


Epoch: 4500, Loss:0.24174706637859344, MaxK: 1.4683669805526733 MinK: 0.4575141370296478


Epoch: 4600, Loss:0.2422686219215393, MaxK: 1.4463266134262085 MinK: 0.4557419419288635


Epoch: 4700, Loss:0.24132101237773895, MaxK: 1.455305576324463 MinK: 0.4538086950778961


Epoch: 4800, Loss:0.2421373426914215, MaxK: 1.430065393447876 MinK: 0.4526137709617615


Epoch: 4900, Loss:0.23745599389076233, MaxK: 1.608162522315979 MinK: 0.452982097864151


Epoch: 5000, Loss:0.2397467941045761, MaxK: 1.452893853187561 MinK: 0.45511576533317566


Epoch



Epoch: 5500, Loss:0.1828266680240631, MaxK: 2.5173463821411133 MinK: 0.02092810720205307


Epoch: 5600, Loss:0.18924957513809204, MaxK: 1.5338994264602661 MinK: 0.02067267708480358


Epoch: 5700, Loss:0.18114951252937317, MaxK: 2.125422477722168 MinK: 0.02214108221232891


Epoch: 5800, Loss:0.5395994186401367, MaxK: 3.813767194747925 MinK: 0.0002834236656781286


Epoch: 5900, Loss:0.38719844818115234, MaxK: 1.932677149772644 MinK: 0.024378158152103424


Epoch: 6000, Loss:0.2901454567909241, MaxK: 1.470487356185913 MinK: 0.018102167174220085


Epoch: 6100, Loss:0.22551563382148743, MaxK: 1.5028162002563477 MinK: 0.04107236862182617


Epoch: 6200, Loss:0.21368971467018127, MaxK: 1.8866695165634155 MinK: 0.04411919042468071


Epoch: 6300, Loss:0.2162638008594513, MaxK: 1.91727876663208 MinK: 0.04978834092617035


Epoch: 6400, Loss:0.22451940178871155, MaxK: 1.7729922533035278 MinK: 0.048405010253190994


Epoch: 6500, Loss:0.22873330116271973, MaxK: 1.6939326524734497 MinK: 0.04720794782



Epoch: 200, Loss:0.1858268529176712, MaxK: 1.8849388360977173 MinK: 0.11102274060249329


Epoch: 300, Loss:0.1040600836277008, MaxK: 1.98335862159729 MinK: 0.016643546521663666


Epoch: 400, Loss:0.07325851917266846, MaxK: 1.578175663948059 MinK: 0.004185437224805355


Epoch: 500, Loss:0.06175702065229416, MaxK: 1.5854241847991943 MinK: 0.0009094029664993286


Epoch: 600, Loss:0.08006031066179276, MaxK: 1.868546962738037 MinK: 8.761713252170011e-05


Epoch: 700, Loss:0.07713717222213745, MaxK: 2.08536696434021 MinK: 7.995063788257539e-05


Epoch: 800, Loss:0.07866190373897552, MaxK: 1.8949100971221924 MinK: 7.285030005732551e-05


Epoch: 900, Loss:0.07815399765968323, MaxK: 1.8908082246780396 MinK: 0.00010140731319552287


Epoch: 1000, Loss:0.07704973220825195, MaxK: 1.936436653137207 MinK: 0.0001192323979921639


Epoch: 1100, Loss:0.07821150124073029, MaxK: 1.8717767000198364 MinK: 0.00012274539039935917


Epoch: 1200, Loss:0.07868404686450958, MaxK: 1.8848787546157837 MinK: 0.00010



Epoch: 1500, Loss:0.030660737305879593, MaxK: 1.3006328344345093 MinK: 0.0008165091276168823


Epoch: 1600, Loss:0.02884540520608425, MaxK: 1.2358546257019043 MinK: 0.0004864144138991833


Epoch: 1700, Loss:0.02693840302526951, MaxK: 1.230424404144287 MinK: 0.00019538940978236496


Epoch: 1800, Loss:0.02597951330244541, MaxK: 1.2268681526184082 MinK: 0.00013913228758610785


Epoch: 1900, Loss:0.02587110549211502, MaxK: 1.1958621740341187 MinK: 0.00010588533041300252


Epoch: 2000, Loss:0.024475522339344025, MaxK: 1.2219161987304688 MinK: 5.10977661178913e-05


Epoch: 2100, Loss:0.023479195311665535, MaxK: 1.1754316091537476 MinK: 4.2713443690445274e-05


Epoch: 2200, Loss:0.027704529464244843, MaxK: 1.3739187717437744 MinK: 4.211056148051284e-05


Epoch: 2300, Loss:0.02312358096241951, MaxK: 1.1800391674041748 MinK: 3.7855130358366296e-05


Epoch: 2400, Loss:0.02275313436985016, MaxK: 1.1729191541671753 MinK: 3.819368430413306e-05


Epoch: 2500, Loss:0.02245916798710823, MaxK: 1.1685

In [7]:
break the code

SyntaxError: invalid syntax (<ipython-input-7-607481896e13>, line 1)