In [1]:
import numpy as np
import copy
import matplotlib
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import random, os, pathlib

import torch
import torch.nn as nn
from torch.utils import data

## MNIST dataset

In [2]:
import mylibrary.datasets as datasets
import mylibrary.nnlib as tnn
from classes import *

In [3]:
# Load MNIST through the project-local dataset helper. The trailing-underscore
# names keep the labels exactly as returned by `mnist.load()`.
mnist = datasets.MNIST()
train_data, train_label_, test_data, test_label_ = mnist.load()

# Scale inputs by 1/255 (presumably raw 8-bit pixel intensities -> [0, 1];
# TODO confirm against mylibrary.datasets).
train_data = train_data / 255.
test_data = test_data / 255.

# train_label = tnn.Logits.index_to_logit(train_label_)
train_size = len(train_label_)

In [4]:
## converting data to pytorch format
# Inputs become float tensors; labels become integer (Long) tensors so they can
# be compared against an integer class index later on.
train_data = torch.Tensor(train_data)
test_data = torch.Tensor(test_data)
train_label = torch.LongTensor(train_label_)
test_label = torch.LongTensor(test_label_)

In [5]:
def shuffle_data():
    """Permute the global training inputs and labels with one shared random order.

    Rebinds the module-level `train_data` and `train_label` to re-indexed
    copies; the permutation is drawn from Python's `random` module.
    """
    global train_data, train_label
    n_samples = len(train_label)
    permutation = random.sample(range(n_samples), k=n_samples)
    # Both right-hand sides use the same permutation, keeping pairs aligned.
    train_data, train_label = train_data[permutation], train_label[permutation]

In [6]:
# Problem dimensions used throughout the notebook.
input_size = 784
output_size = 10

# NOTE: `learning_rate` is reassigned to 0.005 in the hyper-parameter cell
# further down, before any optimizer in this notebook is constructed.
learning_rate = 0.0001
batch_size = 50

In [7]:
# Candidate seeds; only the single `network_seed` below is read by the
# dataset class and the training cells visible in this notebook.
network_seeds = [147, 258, 369]
# network_seeds = [369]
network_seed = 369

# Number of training epochs per one-vs-rest classifier.
EPOCHS = 20

# Activation *class* (not an instance); the CNN constructors call it once per layer.
actf = nn.LeakyReLU
# actf = nn.ELU

learning_rate = 0.005
# Passed as the third argument to BasicInvexNet in the invex training cell.
lambda_ = 2
# The networks return raw scores; `sigmoid` is applied explicitly before BCELoss.
criterion = nn.BCELoss()
sigmoid = nn.Sigmoid()

use_mixup = True

In [8]:
class MNIST_OneClass_Balanced(data.Dataset):
    """Balanced one-vs-rest view over a labelled dataset.

    The dataset exposes ``2 * P`` items, where ``P`` is the number of samples
    whose label equals ``class_index``:

    * indices ``[0, P)`` are the positive samples (target ``1.0``);
    * indices ``[P, 2P)`` are the first ``P`` entries of a shuffled pool of all
      remaining samples (target ``0.0``).

    After ``P`` negative samples have been served the negative pool is
    reshuffled, so a different subset of negatives is exposed afterwards.

    Args:
        data: indexable collection of inputs, aligned with ``label``.
        label: integer tensor of class labels.
        class_index: the class treated as the positive class.

    Note:
        The constructor seeds Python's global ``random`` module with the
        module-level ``network_seed``.
        NOTE(review): when wrapped in a ``DataLoader`` with ``num_workers > 0``
        each worker presumably operates on its own copy of this object, so the
        ``count``/reshuffle state would not be shared across workers — confirm
        if the exact negative-sampling schedule matters.
    """

    def __init__(self, data, label, class_index):
        self.data = data
        self.class_index = class_index

        mask = (label == class_index)
        # Binary float targets of shape (N, 1): 1.0 for the chosen class.
        self.label = mask.type(torch.float32).reshape(-1, 1)
        # Positions of positive / negative samples inside `data`.
        self.class_data = torch.nonzero(mask).reshape(-1)
        self.other_data = torch.nonzero(~mask).reshape(-1)

        random.seed(network_seed)
        self._shuffle_data_()
        self.count = 0

    def __len__(self):
        """Return twice the number of positive samples (positives + equally many negatives)."""
        return 2 * len(self.class_data)

    def _shuffle_data_(self):
        """Reorder the negative pool using Python's ``random`` module."""
        randidx = random.sample(range(len(self.other_data)), k=len(self.other_data))
        self.other_data = self.other_data[randidx]

    def __getitem__(self, idx):
        """Return ``(input, target)`` for balanced index ``idx``.

        Raises:
            IndexError: if ``idx >= len(self)``. Without this guard, plain
                iteration over the dataset would continue into the rest of
                the negative pool instead of stopping at ``len(self)``.
        """
        n_class = len(self.class_data)
        if idx >= 2 * n_class:
            raise IndexError(
                f"index {idx} is out of range for dataset of length {2 * n_class}"
            )

        if idx < n_class:
            idx = self.class_data[idx]
            img, lbl = self.data[idx], self.label[idx]
        else:
            idx = self.other_data[idx - n_class]
            img, lbl = self.data[idx], self.label[idx]
            # Count served negatives; once a full "epoch" worth has been drawn,
            # reshuffle so a fresh subset of negatives is exposed.
            self.count += 1
            if self.count >= n_class:
                self._shuffle_data_()
                self.count = 0
        return img, lbl

In [9]:
# class_idx = 0
# train_dataset = MNIST_OneClass_Balanced(train_data, train_label, class_idx)
# test_dataset = MNIST_OneClass_Balanced(test_data, test_label, class_idx)

In [10]:
# len(train_dataset), len(test_dataset)

In [11]:
# train_loader_all = data.DataLoader(dataset=train_dataset, num_workers=4, batch_size=batch_size, shuffle=True)
# test_loader_all = data.DataLoader(dataset=test_dataset, num_workers=4, batch_size=batch_size, shuffle=False)

In [12]:
# %matplotlib inline
# # img, lbl = train_dataset[11010]
# img, lbl = test_dataset[10]
# print(lbl)
# plt.imshow(img.reshape(28,28))

## Convex

In [13]:
class UnivariateCNN(nn.Module):
    """Small strided CNN that maps a 28x28 single-channel image to one scalar.

    Args:
        channels: channel widths; consecutive pairs define one 5x5, stride-2,
            padding-1 convolution each. Must contain at least two entries.
        actf: activation class, instantiated once after every convolution.
    """

    def __init__(self, channels:list, actf=nn.LeakyReLU):
        super().__init__()
        assert len(channels)>1

        stack = []
        for c_in, c_out in zip(channels[:-1], channels[1:]):
            conv = nn.Conv2d(c_in, c_out, kernel_size=(5,5), stride=2, padding=1)
            stack.extend([conv, actf()])
        # Global average pooling collapses the spatial dims to 1x1.
        stack.append(nn.AdaptiveAvgPool2d(1))

        self.features = nn.Sequential(*stack)
        self.fc = nn.Sequential(nn.Linear(channels[-1], 1))

    def forward(self, x):
        """Reshape `x` to (-1, 1, 28, 28) and return one score per sample, shape (B, 1)."""
        images = x.reshape(-1, 1, 28, 28)
        pooled = self.features(images)
        batch, width = pooled.shape[0], pooled.shape[1]
        return self.fc(pooled.reshape(batch, width))
    

class ConvexCNN(nn.Module):
    """Strided CNN whose non-first layers are kept non-negative on every forward pass.

    Same layout as a plain stack of 5x5 / stride-2 / padding-1 convolutions
    followed by global average pooling and a linear read-out, with two twists:

    * weights of every convolution after the first are scaled by 0.2 at
      construction time;
    * each call to `forward` replaces the weights of those convolutions and of
      the linear layer with their absolute values, in place.

    Args:
        channels: channel widths (at least two entries).
        actf: activation class, instantiated once after every convolution.
    """

    def __init__(self, channels:list, actf=nn.LeakyReLU):
        super().__init__()
        assert len(channels)>1

        stack = []
        for depth, (c_in, c_out) in enumerate(zip(channels[:-1], channels[1:])):
            conv = nn.Conv2d(c_in, c_out, kernel_size=(5,5), stride=2, padding=1)
            if depth > 0:
                # Shrink the initial weights of every layer except the first.
                conv.weight.data *= 0.2
            stack.extend([conv, actf()])
        stack.append(nn.AdaptiveAvgPool2d(1))

        self.features = nn.Sequential(*stack)
        self.fc = nn.Sequential(nn.Linear(channels[-1], 1))

    def forward(self, x):
        """Project constrained weights to |w| in place, then score `x`; returns shape (B, 1)."""
        images = x.reshape(-1, 1, 28, 28)

        # `features` is [conv, act, conv, act, ..., pool]: positions 2, 4, ...
        # (excluding the final pool) are the convolutions after the first one.
        for conv in list(self.features)[2:-1:2]:
            conv.weight.data.abs_()
        for linear in list(self.fc)[::2]:
            linear.weight.data.abs_()

        pooled = self.features(images)
        batch, width = pooled.shape[0], pooled.shape[1]
        return self.fc(pooled.reshape(batch, width))
    


In [14]:
# Smoke test: build a three-convolution ConvexCNN and push two random
# 28x28 images through it; the result has one score per image.
# cnn = UnivariateCNN([1, 16, 32, 64])
cnn = ConvexCNN([1, 16, 32, 64])
cnn(torch.randn(2,1,28,28))

tensor([[3.5953],
        [3.6228]], grad_fn=<AddmmBackward>)

In [15]:
# Shape experiment on 32x32 input: compare two stacked stride-2 convolutions
# (`a`) with a single stride-2 convolution using dilation=2 (`b`).
# In the recorded output the spatial sizes differ (8x8 for `a`, 14x14 for `b`).
a = nn.Sequential(nn.Conv2d(1, 1, kernel_size=5, padding=2, stride=2),
             nn.Conv2d(1,1, kernel_size=5, padding=2, stride=2))
# b = nn.Conv2d(1, 1, kernel_size=5, padding=2, stride=4)
b = nn.Conv2d(1, 1, kernel_size=5, padding=2, stride=2, dilation=2)
x = torch.randn(2,1,32,32)

a(x).shape, b(x).shape

(torch.Size([2, 1, 8, 8]), torch.Size([2, 1, 14, 14]))

In [16]:
# Shape experiment on 28x28 input: three stacked stride-2 convolutions (`a`)
# versus one convolution with stride=4 and dilation=4 (`b`).
# In the recorded output both produce a 4x4 spatial map.
a = nn.Sequential(nn.Conv2d(1, 1, kernel_size=5, padding=2, stride=2),
                 nn.Conv2d(1,1, kernel_size=5, padding=2, stride=2),
                 nn.Conv2d(1,1, kernel_size=5, padding=2, stride=2))
b = nn.Conv2d(1, 1, kernel_size=5, padding=2, stride=4, dilation=4)
x = torch.randn(2,1,28,28)
# nn.Conv2d()

a(x).shape, b(x).shape

(torch.Size([2, 1, 4, 4]), torch.Size([2, 1, 4, 4]))

In [17]:
## Train one one-vs-rest ConvexCNN per digit; trained networks are kept in `net_list`
## and a one-line summary per class in `stat_per_class`.
stat_per_class = []
net_list = []
for class_idx in range(10):
    print(class_idx)
    train_dataset = MNIST_OneClass_Balanced(train_data, train_label, class_idx)
    test_dataset = MNIST_OneClass_Balanced(test_data, test_label, class_idx)

    train_loader = data.DataLoader(dataset=train_dataset, num_workers=4, batch_size=batch_size, shuffle=True)
    test_loader = data.DataLoader(dataset=test_dataset, num_workers=4, batch_size=batch_size, shuffle=False)

    # Re-seeding before construction gives every class the same initial weights.
    torch.manual_seed(network_seed)
    Net = ConvexCNN([1, 16, 32], actf)
    optimizer = torch.optim.Adam(Net.parameters(), lr=learning_rate)
    losses = []       # per-step training loss
    train_accs = []   # per-epoch training accuracy (%)
    test_accs = []    # per-epoch test accuracy (%)

    index = 0  # running step counter, only used in the log line
    # Epoch count comes from the shared EPOCHS setting rather than a literal.
    for epoch in range(EPOCHS):
        train_acc = 0
        train_count = 0
        for xx, yy in train_loader:
            index += 1
            if use_mixup:
                # Light mixup: each sample keeps a weight in (0.9, 1] and is
                # blended with a randomly chosen partner from the same batch.
                # The (B, 1) weight assumes xx is a 2-D (B, features) batch.
                rand_indx = np.random.permutation(len(xx))
                rand_lambda = 1-torch.rand(len(xx), 1)*0.1
                x_mix = rand_lambda*xx+(1-rand_lambda)*xx[rand_indx]
                y_mix = rand_lambda*yy+(1-rand_lambda)*yy[rand_indx]
            else:
                x_mix = xx
                y_mix = yy

            yout = sigmoid(Net(x_mix))
            loss = criterion(yout, y_mix)
            losses.append(float(loss))

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Training accuracy is scored against the un-mixed labels `yy`.
            outputs = (yout.data.cpu().numpy() > 0.5).astype(float)
            correct = (outputs == yy.data.cpu().numpy()).astype(float).sum()
            train_acc += correct
            train_count += len(outputs)

        train_accs.append(float(train_acc)/train_count*100)

        # The logged loss is that of the last training batch of the epoch.
        print(f'Epoch: {epoch}:{index},  Loss:{float(loss)}')
        test_count = 0
        test_acc = 0
        for xx, yy in test_loader:
            with torch.no_grad():
                yout = sigmoid(Net(xx))
            outputs = (yout.data.cpu().numpy() > 0.5).astype(float)
            correct = (outputs == yy.data.cpu().numpy()).astype(float).sum()
            test_acc += correct
            test_count += len(xx)
        test_accs.append(float(test_acc)/test_count*100)
        print(f'Train Acc:{train_accs[-1]:.2f}%, Test Acc:{test_accs[-1]:.2f}%')
        print()

    ### after each class index is finished training
    # The summary reports the best epoch (max), not the final one.
    stat_per_class.append(
        f'Class: {class_idx} -> Train Acc {max(train_accs)} ; Test Acc {max(test_accs)}'
    )
    print(stat_per_class[-1], '\n')
    net_list.append(Net)

0
Epoch: 0:237,  Loss:0.502179741859436
Train Acc:71.64%, Test Acc:74.69%

Epoch: 1:474,  Loss:0.41291874647140503
Train Acc:84.88%, Test Acc:87.40%

Epoch: 2:711,  Loss:0.2867847979068756
Train Acc:89.68%, Test Acc:75.87%

Epoch: 3:948,  Loss:0.23302306234836578
Train Acc:91.80%, Test Acc:87.96%

Epoch: 4:1185,  Loss:0.28443029522895813
Train Acc:92.55%, Test Acc:88.98%

Epoch: 5:1422,  Loss:0.2573814392089844
Train Acc:92.99%, Test Acc:94.54%

Epoch: 6:1659,  Loss:0.20436663925647736
Train Acc:93.89%, Test Acc:90.31%

Epoch: 7:1896,  Loss:0.2413509041070938
Train Acc:93.53%, Test Acc:91.38%

Epoch: 8:2133,  Loss:0.1579991579055786
Train Acc:94.61%, Test Acc:90.51%

Epoch: 9:2370,  Loss:0.17704105377197266
Train Acc:94.43%, Test Acc:95.00%

Epoch: 10:2607,  Loss:0.25889015197753906
Train Acc:94.83%, Test Acc:93.88%

Epoch: 11:2844,  Loss:0.1945742964744568
Train Acc:95.03%, Test Acc:96.33%

Epoch: 12:3081,  Loss:0.11924849450588226
Train Acc:95.13%, Test Acc:91.94%

Epoch: 13:3318,  L

Train Acc:93.29%, Test Acc:93.83%

Epoch: 4:1085,  Loss:0.24745318293571472
Train Acc:94.83%, Test Acc:96.13%

Epoch: 5:1302,  Loss:0.20287036895751953
Train Acc:95.03%, Test Acc:95.80%

Epoch: 6:1519,  Loss:0.1972137689590454
Train Acc:95.42%, Test Acc:96.30%

Epoch: 7:1736,  Loss:0.25880295038223267
Train Acc:95.67%, Test Acc:95.63%

Epoch: 8:1953,  Loss:0.2320770025253296
Train Acc:95.99%, Test Acc:95.91%

Epoch: 9:2170,  Loss:0.246842160820961
Train Acc:96.19%, Test Acc:93.05%

Epoch: 10:2387,  Loss:0.14881765842437744
Train Acc:96.37%, Test Acc:96.64%

Epoch: 11:2604,  Loss:0.22326210141181946
Train Acc:96.26%, Test Acc:96.86%

Epoch: 12:2821,  Loss:0.1039741262793541
Train Acc:96.19%, Test Acc:96.36%

Epoch: 13:3038,  Loss:0.11739006638526917
Train Acc:96.35%, Test Acc:96.75%

Epoch: 14:3255,  Loss:0.22225986421108246
Train Acc:96.55%, Test Acc:96.08%

Epoch: 15:3472,  Loss:0.1546502411365509
Train Acc:96.57%, Test Acc:95.18%

Epoch: 16:3689,  Loss:0.19527067244052887
Train Acc:9

In [18]:
## Multi-class accuracy obtained by combining the per-class networks:
## each test sample is assigned to the class whose network gives the highest score.
acc_test = 0
count_test = 0
with torch.no_grad():
    # Step through the test set by offset so a trailing partial batch is
    # still evaluated when the set size is not a multiple of batch_size.
    for start in range(0, len(test_label), batch_size):
        xx = test_data[start:start + batch_size]
        yy = test_label[start:start + batch_size]
        # One column of sigmoid scores per class network -> (batch, n_classes).
        scores = [sigmoid(net(xx)).reshape(-1) for net in net_list]
        yout = torch.stack(scores, dim=1).argmax(dim=1)
        acc_test += (yout == yy).type(torch.float32).sum()
        count_test += len(xx)

for s in stat_per_class:
    print(s)
print(f"Total Accuracy (Argmax) is : {float(acc_test/count_test)}")

Class: 0 -> Train Acc 96.08306601384433 ; Test Acc 96.3265306122449
Class: 1 -> Train Acc 98.47967962029071 ; Test Acc 98.98678414096916
Class: 2 -> Train Acc 96.07250755287009 ; Test Acc 97.14147286821705
Class: 3 -> Train Acc 96.19148589137171 ; Test Acc 97.32673267326733
Class: 4 -> Train Acc 96.7819239986306 ; Test Acc 97.50509164969449
Class: 5 -> Train Acc 97.12230215827337 ; Test Acc 96.8609865470852
Class: 6 -> Train Acc 97.22034471105103 ; Test Acc 97.54697286012527
Class: 7 -> Train Acc 96.19313647246608 ; Test Acc 96.20622568093385
Class: 8 -> Train Acc 93.98393437019314 ; Test Acc 96.14989733059548
Class: 9 -> Train Acc 92.49453689695747 ; Test Acc 91.97224975222993
Total Accuracy (Argmax) is : 0.9412000179290771


## NN

In [19]:
## Train one one-vs-rest UnivariateCNN (unconstrained baseline) per digit;
## trained networks are kept in `net_list`, one summary line per class in `stat_per_class`.
stat_per_class = []
net_list = []
for class_idx in range(10):
    print(class_idx)
    train_dataset = MNIST_OneClass_Balanced(train_data, train_label, class_idx)
    test_dataset = MNIST_OneClass_Balanced(test_data, test_label, class_idx)

    train_loader = data.DataLoader(dataset=train_dataset, num_workers=4, batch_size=batch_size, shuffle=True)
    test_loader = data.DataLoader(dataset=test_dataset, num_workers=4, batch_size=batch_size, shuffle=False)

    # Re-seeding before construction gives every class the same initial weights.
    torch.manual_seed(network_seed)
    Net = UnivariateCNN([1, 16, 32], actf)
    optimizer = torch.optim.Adam(Net.parameters(), lr=learning_rate)
    losses = []       # per-step training loss
    train_accs = []   # per-epoch training accuracy (%)
    test_accs = []    # per-epoch test accuracy (%)

    index = 0  # running step counter, only used in the log line
    # Epoch count comes from the shared EPOCHS setting rather than a literal.
    for epoch in range(EPOCHS):
        train_acc = 0
        train_count = 0
        for xx, yy in train_loader:
            index += 1

            if use_mixup:
                # Light mixup: each sample keeps a weight in (0.9, 1] and is
                # blended with a randomly chosen partner from the same batch.
                # The (B, 1) weight assumes xx is a 2-D (B, features) batch.
                rand_indx = np.random.permutation(len(xx))
                rand_lambda = 1-torch.rand(len(xx), 1)*0.1
                x_mix = rand_lambda*xx+(1-rand_lambda)*xx[rand_indx]
                y_mix = rand_lambda*yy+(1-rand_lambda)*yy[rand_indx]
            else:
                x_mix = xx
                y_mix = yy

            yout = sigmoid(Net(x_mix))
            loss = criterion(yout, y_mix)
            losses.append(float(loss))

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Training accuracy is scored against the un-mixed labels `yy`.
            outputs = (yout.data.cpu().numpy() > 0.5).astype(float)
            correct = (outputs == yy.data.cpu().numpy()).astype(float).sum()
            train_acc += correct
            train_count += len(outputs)

        train_accs.append(float(train_acc)/train_count*100)

        # The logged loss is that of the last training batch of the epoch.
        print(f'Epoch: {epoch}:{index},  Loss:{float(loss)}')
        test_count = 0
        test_acc = 0
        for xx, yy in test_loader:
            with torch.no_grad():
                yout = sigmoid(Net(xx))
            outputs = (yout.data.cpu().numpy() > 0.5).astype(float)
            correct = (outputs == yy.data.cpu().numpy()).astype(float).sum()
            test_acc += correct
            test_count += len(xx)
        test_accs.append(float(test_acc)/test_count*100)
        print(f'Train Acc:{train_accs[-1]:.2f}%, Test Acc:{test_accs[-1]:.2f}%')
        print()

    ### after each class index is finished training
    # The summary reports the best epoch (max), not the final one.
    stat_per_class.append(
        f'Class: {class_idx} -> Train Acc {max(train_accs)} ; Test Acc {max(test_accs)}'
    )
    print(stat_per_class[-1], '\n')
    net_list.append(Net)

0
Epoch: 0:237,  Loss:0.1804620772600174
Train Acc:86.23%, Test Acc:94.74%

Epoch: 1:474,  Loss:0.18958808481693268
Train Acc:95.75%, Test Acc:97.04%

Epoch: 2:711,  Loss:0.198486790060997
Train Acc:97.21%, Test Acc:97.24%

Epoch: 3:948,  Loss:0.17238833010196686
Train Acc:97.78%, Test Acc:97.81%

Epoch: 4:1185,  Loss:0.13299264013767242
Train Acc:98.27%, Test Acc:98.83%

Epoch: 5:1422,  Loss:0.17925728857517242
Train Acc:98.51%, Test Acc:97.86%

Epoch: 6:1659,  Loss:0.14914283156394958
Train Acc:98.73%, Test Acc:99.23%

Epoch: 7:1896,  Loss:0.13187280297279358
Train Acc:98.97%, Test Acc:99.34%

Epoch: 8:2133,  Loss:0.12645836174488068
Train Acc:99.09%, Test Acc:99.18%

Epoch: 9:2370,  Loss:0.18057182431221008
Train Acc:99.17%, Test Acc:98.62%

Epoch: 10:2607,  Loss:0.16148313879966736
Train Acc:99.07%, Test Acc:98.98%

Epoch: 11:2844,  Loss:0.15103358030319214
Train Acc:99.21%, Test Acc:98.78%

Epoch: 12:3081,  Loss:0.11444096267223358
Train Acc:99.28%, Test Acc:99.34%

Epoch: 13:3318

Epoch: 3:868,  Loss:0.15785810351371765
Train Acc:97.18%, Test Acc:96.97%

Epoch: 4:1085,  Loss:0.18532609939575195
Train Acc:97.84%, Test Acc:97.81%

Epoch: 5:1302,  Loss:0.1583007574081421
Train Acc:98.17%, Test Acc:98.43%

Epoch: 6:1519,  Loss:0.14335201680660248
Train Acc:98.20%, Test Acc:98.43%

Epoch: 7:1736,  Loss:0.21838518977165222
Train Acc:98.45%, Test Acc:98.37%

Epoch: 8:1953,  Loss:0.16314858198165894
Train Acc:98.59%, Test Acc:98.60%

Epoch: 9:2170,  Loss:0.21017208695411682
Train Acc:98.77%, Test Acc:97.03%

Epoch: 10:2387,  Loss:0.14147599041461945
Train Acc:98.74%, Test Acc:98.77%

Epoch: 11:2604,  Loss:0.15156495571136475
Train Acc:98.83%, Test Acc:99.10%

Epoch: 12:2821,  Loss:0.09501968324184418
Train Acc:98.99%, Test Acc:98.54%

Epoch: 13:3038,  Loss:0.08736466616392136
Train Acc:98.81%, Test Acc:98.77%

Epoch: 14:3255,  Loss:0.1709200143814087
Train Acc:99.05%, Test Acc:98.71%

Epoch: 15:3472,  Loss:0.117319256067276
Train Acc:99.04%, Test Acc:98.71%

Epoch: 16:3

In [20]:
## Multi-class accuracy obtained by combining the per-class networks:
## each test sample is assigned to the class whose network gives the highest score.
acc_test = 0
count_test = 0
with torch.no_grad():
    # Step through the test set by offset so a trailing partial batch is
    # still evaluated when the set size is not a multiple of batch_size.
    for start in range(0, len(test_label), batch_size):
        xx = test_data[start:start + batch_size]
        yy = test_label[start:start + batch_size]
        # One column of sigmoid scores per class network -> (batch, n_classes).
        scores = [sigmoid(net(xx)).reshape(-1) for net in net_list]
        yout = torch.stack(scores, dim=1).argmax(dim=1)
        acc_test += (yout == yy).type(torch.float32).sum()
        count_test += len(xx)

for s in stat_per_class:
    print(s)
print(f"Total Accuracy (Argmax) is : {float(acc_test/count_test)}")

Class: 0 -> Train Acc 99.62012493668749 ; Test Acc 99.64285714285714
Class: 1 -> Train Acc 99.67368733313556 ; Test Acc 99.51541850220265
Class: 2 -> Train Acc 99.11883182275932 ; Test Acc 98.74031007751938
Class: 3 -> Train Acc 99.57592562387865 ; Test Acc 99.15841584158416
Class: 4 -> Train Acc 99.60629921259843 ; Test Acc 99.54175152749491
Class: 5 -> Train Acc 99.2897989300867 ; Test Acc 99.10313901345292
Class: 6 -> Train Acc 99.63670158837445 ; Test Acc 99.58246346555325
Class: 7 -> Train Acc 99.50518754988029 ; Test Acc 99.17315175097276
Class: 8 -> Train Acc 99.29926508289182 ; Test Acc 98.92197125256673
Class: 9 -> Train Acc 99.28559421751555 ; Test Acc 98.26560951437067
Total Accuracy (Argmax) is : 0.9793000221252441


## Invex

In [21]:
# Switches for the invex training cell below.
use_mixup = True
# When True, the training loop additionally evaluates the network's penalty on
# random inputs during every `check_every`-th epoch, `check_size` points at a time.
use_check = False
check_every = 2
check_size = 100

# Scale (m_) and shift (s_) applied to the uniform random check inputs:
# torch.rand(...) * m_ + s_.
m_,s_ = 1, 0

In [22]:
## Train one one-vs-rest BasicInvexNet (wrapping a UnivariateCNN) per digit;
## trained networks are kept in `net_list`, one summary line per class in `stat_per_class`.
stat_per_class = []
net_list = []
for class_idx in range(10):
    print(class_idx)
    train_dataset = MNIST_OneClass_Balanced(train_data, train_label, class_idx)
    test_dataset = MNIST_OneClass_Balanced(test_data, test_label, class_idx)

    train_loader = data.DataLoader(dataset=train_dataset, num_workers=4, batch_size=batch_size, shuffle=True)
    test_loader = data.DataLoader(dataset=test_dataset, num_workers=4, batch_size=batch_size, shuffle=False)

    # Re-seeding before construction gives every class the same initial weights.
    torch.manual_seed(network_seed)
    lips_net = UnivariateCNN([1, 16, 32], actf)
    Net = BasicInvexNet(784, lips_net, lambda_)
    optimizer = torch.optim.Adam(Net.parameters(), lr=learning_rate)
    losses = []       # per-step training loss (classification loss + penalty)
    train_accs = []   # per-epoch training accuracy (%)
    test_accs = []    # per-epoch test accuracy (%)

    index = 0  # running step counter, only used in the log line
    # Epoch count comes from the shared EPOCHS setting rather than a literal.
    for epoch in range(EPOCHS):
        train_acc = 0
        train_count = 0
        for xx, yy in train_loader:
            index += 1

            optimizer.zero_grad()
            if use_check and epoch%check_every == 0:
                # Optional extra pass: accumulate penalty gradients measured
                # on uniform random inputs before the data batch is processed.
                rand_inp = torch.rand(check_size, 784)*m_+s_
                Net(rand_inp)
                Net.compute_penalty_and_clipper()
                Net.gp.backward(retain_graph=True)

            if use_mixup:
                # Light mixup: each sample keeps a weight in (0.9, 1] and is
                # blended with a randomly chosen partner from the same batch.
                # The (B, 1) weight assumes xx is a 2-D (B, features) batch.
                rand_indx = np.random.permutation(len(xx))
                rand_lambda = 1-torch.rand(len(xx), 1)*0.1
                x_mix = rand_lambda*xx+(1-rand_lambda)*xx[rand_indx]
                y_mix = rand_lambda*yy+(1-rand_lambda)*yy[rand_indx]
            else:
                x_mix = xx
                y_mix = yy

            yout = sigmoid(Net(x_mix))
            # Refreshes Net.gp (and Net.cond) for the batch just processed.
            Net.compute_penalty_and_clipper()
            loss = criterion(yout, y_mix) + Net.gp
            losses.append(float(loss))

            # Training accuracy is scored against the un-mixed labels `yy`.
            outputs = (yout.data.cpu().numpy() > 0.5).astype(float)
            targets = (yy.data.cpu().numpy() > 0.5).astype(float)
            correct = (outputs == targets).astype(float).sum()
            train_acc += correct
            train_count += len(outputs)

            loss.backward()
            optimizer.step()

        train_accs.append(float(train_acc)/train_count*100)

        # Logged values come from the last training batch of the epoch.
        min_val, gp = float(Net.cond.min()) , float(Net.gp)
        print(f'Epoch: {epoch}:{index},  Loss:{float(loss)}, MinVal:{min_val}, gp: {gp}')
        test_count = 0
        test_acc = 0
        for xx, yy in test_loader:
            # NOTE(review): torch.no_grad() was deliberately left disabled here in
            # the original cell; presumably BasicInvexNet's forward relies on
            # autograd — confirm before re-enabling it.
            yout = sigmoid(Net(xx))
            outputs = (yout.data.cpu().numpy() > 0.5).astype(float)
            correct = (outputs == yy.data.cpu().numpy()).astype(float).sum()
            test_acc += correct
            test_count += len(xx)
        test_accs.append(float(test_acc)/test_count*100)
        print(f'Train Acc:{train_accs[-1]:.2f}%, Test Acc:{test_accs[-1]:.2f}%')
        print()

    ### after each class index is finished training
    # The summary reports the best epoch (max), not the final one.
    stat_per_class.append(
        f'Class: {class_idx} -> Train Acc {max(train_accs)} ; Test Acc {max(test_accs)}'
    )
    print(stat_per_class[-1], '\n')
    net_list.append(Net)

0
Epoch: 0:237,  Loss:0.5909543037414551, MinVal:0.8676339983940125, gp: 6.44832484633755e-17
Train Acc:46.17%, Test Acc:76.07%

Epoch: 1:474,  Loss:0.3354240953922272, MinVal:2.0141360759735107, gp: 7.606768226345064e-37
Train Acc:87.41%, Test Acc:93.32%

Epoch: 2:711,  Loss:0.22962957620620728, MinVal:1.5341095924377441, gp: 1.704854733112834e-28
Train Acc:92.77%, Test Acc:94.64%

Epoch: 3:948,  Loss:0.23411162197589874, MinVal:1.2865544557571411, gp: 5.091819069732541e-24
Train Acc:94.08%, Test Acc:95.10%

Epoch: 4:1185,  Loss:0.19331051409244537, MinVal:1.765328288078308, gp: 2.738273829559003e-32
Train Acc:94.93%, Test Acc:96.28%

Epoch: 5:1422,  Loss:0.15277491509914398, MinVal:1.5068694353103638, gp: 5.529063684563874e-28
Train Acc:95.70%, Test Acc:95.31%

Epoch: 6:1659,  Loss:0.15293139219284058, MinVal:1.1833570003509521, gp: 2.7625325774858694e-22
Train Acc:96.48%, Test Acc:96.73%

Epoch: 7:1896,  Loss:0.14416582882404327, MinVal:1.04512357711792, gp: 5.263976116130585e-20
Tr

Train Acc:93.95%, Test Acc:97.13%

Epoch: 2:738,  Loss:0.15849657356739044, MinVal:0.6083126068115234, gp: 1.1074221401008444e-11
Train Acc:96.70%, Test Acc:98.07%

Epoch: 3:984,  Loss:0.19061410427093506, MinVal:0.6396447420120239, gp: 2.1994149123488738e-12
Train Acc:97.67%, Test Acc:98.56%

Epoch: 4:1230,  Loss:0.20631951093673706, MinVal:0.565158486366272, gp: 4.325428556994915e-11
Train Acc:98.16%, Test Acc:98.42%

Epoch: 5:1476,  Loss:0.10381584614515305, MinVal:0.5942614078521729, gp: 1.4537937693959169e-11
Train Acc:98.56%, Test Acc:96.78%

Epoch: 6:1722,  Loss:0.11105737835168839, MinVal:0.523272693157196, gp: 2.3101927315583026e-10
Train Acc:98.51%, Test Acc:98.56%

Epoch: 7:1968,  Loss:0.13773779571056366, MinVal:0.8967038989067078, gp: 1.2154249137208347e-16
Train Acc:98.84%, Test Acc:98.61%

Epoch: 8:2214,  Loss:0.15247158706188202, MinVal:1.018035888671875, gp: 6.036014994969753e-19
Train Acc:98.97%, Test Acc:98.22%

Epoch: 9:2460,  Loss:0.10665667802095413, MinVal:0.8725

Epoch: 3:948,  Loss:0.17350026965141296, MinVal:0.3557813763618469, gp: 6.868783941627044e-08
Train Acc:98.73%, Test Acc:98.28%

Epoch: 4:1185,  Loss:0.1184062659740448, MinVal:0.4782955050468445, gp: 4.97184959868946e-10
Train Acc:98.85%, Test Acc:98.64%

Epoch: 5:1422,  Loss:0.1402582973241806, MinVal:1.0138187408447266, gp: 2.432738698506337e-19
Train Acc:98.99%, Test Acc:99.16%

Epoch: 6:1659,  Loss:0.13905565440654755, MinVal:1.1099740266799927, gp: 4.949206589631595e-21
Train Acc:99.15%, Test Acc:98.43%

Epoch: 7:1896,  Loss:0.1568649709224701, MinVal:1.2778247594833374, gp: 6.127293671143745e-24
Train Acc:99.20%, Test Acc:99.06%

Epoch: 8:2133,  Loss:0.1388392448425293, MinVal:1.4091908931732178, gp: 3.137695483111739e-26
Train Acc:99.14%, Test Acc:98.90%

Epoch: 9:2370,  Loss:0.1756681650876999, MinVal:1.461872935295105, gp: 3.814509558990992e-27
Train Acc:99.29%, Test Acc:99.11%

Epoch: 10:2607,  Loss:0.14185567200183868, MinVal:1.5756696462631226, gp: 4.0231824916102073e-29
T

Train Acc:95.81%, Test Acc:95.29%

Epoch: 5:1428,  Loss:0.1509491354227066, MinVal:0.9487681984901428, gp: 2.7557954309074176e-18
Train Acc:96.50%, Test Acc:96.48%

Epoch: 6:1666,  Loss:0.17523212730884552, MinVal:1.1820836067199707, gp: 3.7217117075877185e-22
Train Acc:96.80%, Test Acc:96.73%

Epoch: 7:1904,  Loss:0.13361099362373352, MinVal:1.1297414302825928, gp: 1.6836848384055777e-21
Train Acc:97.48%, Test Acc:97.03%

Epoch: 8:2142,  Loss:0.15375612676143646, MinVal:1.4209798574447632, gp: 2.574249578590489e-26
Train Acc:97.56%, Test Acc:97.32%

Epoch: 9:2380,  Loss:0.13974452018737793, MinVal:0.925240159034729, gp: 6.0133999804201294e-18
Train Acc:97.69%, Test Acc:97.27%

Epoch: 10:2618,  Loss:0.1262916922569275, MinVal:0.9064548015594482, gp: 1.273751354314582e-17
Train Acc:98.14%, Test Acc:97.72%

Epoch: 11:2856,  Loss:0.1433952897787094, MinVal:1.485651969909668, gp: 1.1101941312397585e-27
Train Acc:98.12%, Test Acc:97.87%

Epoch: 12:3094,  Loss:0.19506920874118805, MinVal:1.1

In [23]:
## Multi-class accuracy obtained by combining the per-class networks:
## each test sample is assigned to the class whose network gives the highest score.
acc_test = 0
count_test = 0
with torch.no_grad():
    # Step through the test set by offset so a trailing partial batch is
    # still evaluated when the set size is not a multiple of batch_size.
    for start in range(0, len(test_label), batch_size):
        xx = test_data[start:start + batch_size]
        yy = test_label[start:start + batch_size]
        # One column of sigmoid scores per class network -> (batch, n_classes).
        scores = [sigmoid(net(xx)).reshape(-1) for net in net_list]
        yout = torch.stack(scores, dim=1).argmax(dim=1)
        acc_test += (yout == yy).type(torch.float32).sum()
        count_test += len(xx)

for s in stat_per_class:
    print(s)
print(f"Total Accuracy (Argmax) is : {float(acc_test/count_test)}")

Class: 0 -> Train Acc 99.4850582475097 ; Test Acc 98.9795918367347
Class: 1 -> Train Acc 99.69593592405815 ; Test Acc 99.51541850220265
Class: 2 -> Train Acc 99.513259483048 ; Test Acc 99.07945736434108
Class: 3 -> Train Acc 99.57592562387865 ; Test Acc 99.20792079207921
Class: 4 -> Train Acc 99.62341663813763 ; Test Acc 99.69450101832994
Class: 5 -> Train Acc 99.37280944475188 ; Test Acc 98.99103139013454
Class: 6 -> Train Acc 99.62825278810409 ; Test Acc 99.4258872651357
Class: 7 -> Train Acc 99.4732641660016 ; Test Acc 99.41634241245137
Class: 8 -> Train Acc 99.17108186634763 ; Test Acc 98.61396303901438
Class: 9 -> Train Acc 99.00823667843335 ; Test Acc 98.51337958374629
Total Accuracy (Argmax) is : 0.9812999963760376


In [24]:
## Constraint check on the test and training data only:
## for each class network, report the share of inputs with `net.cond > 0`.
for class_idx in range(10):
    correct = 0
    count = 0
    net = net_list[class_idx]
    # Test split first, then training split (same order as the totals printed below).
    for inputs, n_points in ((test_data, len(test_label)), (train_data, len(train_label))):
        # Offset-based stepping also covers a trailing partial batch.
        for start in range(0, n_points, batch_size):
            xx = inputs[start:start + batch_size]
            # The forward pass followed by compute_penalty_and_clipper()
            # refreshes net.cond for this batch.
            net(xx)
            net.compute_penalty_and_clipper()
            correct += (net.cond>0).type(torch.float32).sum()
            count += len(xx)

    print(f"Class: {class_idx} -> Correct {correct/count*100:.4f}% on {count} input points")

Class: 0 -> Correct 100.0000% on 70000 input points
Class: 1 -> Correct 100.0000% on 70000 input points
Class: 2 -> Correct 100.0000% on 70000 input points
Class: 3 -> Correct 100.0000% on 70000 input points
Class: 4 -> Correct 99.9857% on 70000 input points
Class: 5 -> Correct 100.0000% on 70000 input points
Class: 6 -> Correct 100.0000% on 70000 input points
Class: 7 -> Correct 100.0000% on 70000 input points
Class: 8 -> Correct 100.0000% on 70000 input points
Class: 9 -> Correct 100.0000% on 70000 input points


In [25]:
## Check the constraint on large number of points, including training and test data.
## For each class network, report the share of inputs with `net.cond > 0` over
## the test split, the training split and additional uniform random inputs.
from tqdm import tqdm

# Number of extra batches of uniform random inputs per class network.
n_random_batches = 20000

for class_idx in range(10):
    correct = 0
    count = 0
    net = net_list[class_idx]
    # Test split first, then training split.
    for inputs, n_points in ((test_data, len(test_label)), (train_data, len(train_label))):
        # Offset-based stepping also covers a trailing partial batch.
        for start in range(0, n_points, batch_size):
            xx = inputs[start:start + batch_size]
            # The forward pass followed by compute_penalty_and_clipper()
            # refreshes net.cond for this batch.
            net(xx)
            net.compute_penalty_and_clipper()
            correct += (net.cond>0).type(torch.float32).sum()
            count += len(xx)

    # Off-data probe: uniform random points in [0, 1)^784.
    for i in tqdm(range(n_random_batches)):
        xx = torch.rand(batch_size, 784)
        net(xx)
        net.compute_penalty_and_clipper()
        correct += (net.cond>0).type(torch.float32).sum()
        count += len(xx)

    print(f"Class: {class_idx} -> Correct {correct/count*100:.4f}% on {count} input points")

100%|██████████| 20000/20000 [02:46<00:00, 120.20it/s]


Class: 0 -> Correct 100.0000% on 1070000 input points


100%|██████████| 20000/20000 [02:57<00:00, 112.91it/s]


Class: 1 -> Correct 96.5372% on 1070000 input points


100%|██████████| 20000/20000 [02:44<00:00, 121.90it/s]


Class: 2 -> Correct 100.0000% on 1070000 input points


100%|██████████| 20000/20000 [02:40<00:00, 124.36it/s]


Class: 3 -> Correct 99.9951% on 1070000 input points


100%|██████████| 20000/20000 [02:40<00:00, 124.82it/s]


Class: 4 -> Correct 10.0613% on 1070000 input points


100%|██████████| 20000/20000 [02:39<00:00, 125.66it/s]


Class: 5 -> Correct 100.0000% on 1070000 input points


100%|██████████| 20000/20000 [02:38<00:00, 126.13it/s]


Class: 6 -> Correct 99.9995% on 1070000 input points


100%|██████████| 20000/20000 [02:39<00:00, 125.55it/s]


Class: 7 -> Correct 99.9503% on 1070000 input points


100%|██████████| 20000/20000 [02:53<00:00, 115.40it/s]


Class: 8 -> Correct 6.5421% on 1070000 input points


100%|██████████| 20000/20000 [03:01<00:00, 109.96it/s]

Class: 9 -> Correct 100.0000% on 1070000 input points



