## Revised Linear Layer Architecture without VIB

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
import torch.optim as optim
from torch.autograd import Variable
from torch.optim import lr_scheduler
from torch.utils.data import DataLoader
from torchvision import transforms
from tensorboardX import SummaryWriter
#from utils import cuda
import pdb
import time
from numbers import Number
import numpy as np
import joblib
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix
from sklearn.metrics import classification_report
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report

In [2]:
class ToyNet(nn.Module):
    '''
    Construct a MLP that is used to train the model 
    param[in]: X_train, output_features
    param[out]: output 
    
    Note: initialize the weight with a self-defined method
    '''

    def __init__(self,output_features=8):
        super(ToyNet, self).__init__()
        self.encode = nn.Sequential(
            nn.Linear(40,128),
            nn.ReLU(True),
            nn.Linear(128, 128),
            nn.ReLU(True),
            nn.Linear(128, output_features))
        #self.optim = optim.Adam(self.toynet.parameters(),lr=self.lr,betas=(0.5,0.999))
        #self.scheduler = lr_scheduler.ExponentialLR(self.optim,gamma=0.97
    def forward(self, X_train):
        output=self.encode(X_train)
        #prediction = F.softmax(output,dim=1).max(1)[1]
        
        #print(prediction)
        return output
    def weight_init(self):
        for m in self._modules:
            xavier_init(self._modules[m])
def xavier_init(ms):
    for m in ms :
        if isinstance(m, nn.Linear) or isinstance(m, nn.Conv2d):
            nn.init.xavier_uniform(m.weight,gain=nn.init.calculate_gain('relu'))
            m.bias.data.zero_()

In [3]:
import numpy as np
import torch
import argparse
import joblib
from torch.utils.data import Dataset, DataLoader

In [4]:
def cuda(tensor, is_cuda):
    if is_cuda : return tensor.cuda()
    else : return tensor

In [5]:
params={
    'epoch':50,
    'batch_size':16,
    'lr':0.01
}

In [6]:
class CustomDataset(Dataset):
    """
    construct dataset from numpy and split it 
    
    """
    def __init__(self, data, target, transform=None):
        self.data = torch.from_numpy(data).float()
        self.target = torch.from_numpy(target).long()
        self.transform = transform
        
    def __getitem__(self, index):
        x = self.data[index]
        y = self.target[index]
        
        if self.transform:
            x = self.transform(x)
        
        return x, y
    
    def __len__(self):
        return len(self.data)
    
    
X = joblib.load('./joblib_features/X.joblib')
y = joblib.load('./joblib_features/y.joblib')
full_dataset = CustomDataset(X, y)
train_size = int(0.8 * len(full_dataset))
test_size = len(full_dataset) - train_size
train_dataset, test_dataset = torch.utils.data.random_split(full_dataset, 
                                                            [train_size, test_size])
                                                              

train_dataloader = DataLoader(train_dataset, batch_size=params['batch_size'], shuffle=True, num_workers=0)
test_dataloader = DataLoader(test_dataset, batch_size=params['batch_size'], shuffle=True, num_workers=0)

In [7]:
class Solver(object):
#train the model

    def __init__(self):
        self.cuda =torch.cuda.is_available ()
        self.epoch = params['epoch']
        self.batch_size = params['batch_size']
        self.lr = params['lr']
        self.toynet = cuda(ToyNet(), self.cuda)
        self.toynet.weight_init()
        self.optim = optim.Adam(self.toynet.parameters(),
                                lr=self.lr,
                                betas=(0.5,0.999))
        self.criterion = nn.CrossEntropyLoss()
    
    def train(self):
        for epc in range(self.epoch):  # loop over the dataset multiple times

            running_loss = 0.0
            for i, data in enumerate(train_dataloader):
                # get the inputs; data is a list of [inputs, labels]
                inputs, labels = data

                # zero the parameter gradients
                self.optim.zero_grad()

                # forward + backward + optimize
                outputs = self.toynet.forward(inputs)
                loss = self.criterion(outputs,labels)
                loss.backward()
                self.optim.step()
                
                prediction = F.softmax(outputs,dim=1).max(1)[1]
                accuracy = torch.eq(prediction,labels).float().mean()
                avg_accuracy = Variable(cuda(torch.zeros(accuracy.size()), self.cuda))
                

                # print statistics
                running_loss += loss.item()
                if i % self.batch_size == 0:    # print every 30 mini-batches
                    print('[%d, %5d] loss: %.3f' %
                            (epc + 1, i + 1, running_loss / self.batch_size))
                    print('acc:{:.4f} '
                            .format(accuracy.item(), end=' '))
                    print('err:{:.4f} '
                            .format(1-accuracy.item()))
                    running_loss = 0.0
        
            self.test()
        print('Finished Training',(epc+1))
    
    def test(self):
        self.cuda =torch.cuda.is_available ()
        self.epoch = params['epoch']
        self.batch_size = params['batch_size']
        self.lr = params['lr']
        self.toynet = cuda(ToyNet(), self.cuda)
        self.toynet.weight_init()
        self.optim = optim.Adam(self.toynet.parameters(),
                                lr=self.lr,
                                betas=(0.5,0.999))
        self.criterion = nn.CrossEntropyLoss()
      
        loss=0
        correct = 0
        total_num= 0
        y_real=torch.randn([0])
        y_hat=torch.randn([0])
        
        for i, data in enumerate(test_dataloader):
            # get the inputs; data is a list of [inputs, labels]
            inputs, labels = data
            outputs = self.toynet.forward(inputs)
            #loss
            total_num += labels.shape[0]
            loss += self.criterion(outputs,labels)
            prediction = F.softmax(outputs,dim=1).max(1)[1]
            y_real=torch.cat([y_real,labels],dim=0)
            y_hat=torch.cat([y_hat,prediction],dim=0)
            correct += torch.eq(prediction,labels).float().sum()
            avg_correct = Variable(cuda(torch.zeros(correct.size()), self.cuda))
        accuracy = correct/total_num
        avg_accuracy = avg_correct/total_num
        
        print('[TEST RESULT]')
        print('acc:{:.4f} '
                .format(accuracy.item(),end=' '))
        print('err:{:.4f}'
                .format(1-accuracy.item()))
        print(classification_report(y_real,y_hat))
   



In [8]:
net=Solver()
net.train()



[1,     1] loss: 10.071
acc:0.0625 
err:0.9375 
[1,    17] loss: 140.323
acc:0.2500 
err:0.7500 
[1,    33] loss: 28.560
acc:0.1875 
err:0.8125 
[1,    49] loss: 15.465
acc:0.5625 
err:0.4375 
[1,    65] loss: 5.482
acc:0.3750 
err:0.6250 
[1,    81] loss: 2.983
acc:0.3750 
err:0.6250 
[1,    97] loss: 3.956
acc:0.5000 
err:0.5000 
[1,   113] loss: 2.879
acc:0.3750 
err:0.6250 
[1,   129] loss: 2.743
acc:0.6875 
err:0.3125 
[1,   145] loss: 1.421
acc:0.6250 
err:0.3750 
[1,   161] loss: 1.156
acc:0.7500 
err:0.2500 
[1,   177] loss: 2.181
acc:0.4375 
err:0.5625 
[1,   193] loss: 1.760
acc:0.2500 
err:0.7500 
[1,   209] loss: 1.297
acc:0.4375 
err:0.5625 
[1,   225] loss: 1.227
acc:0.5625 
err:0.4375 
[1,   241] loss: 1.092
acc:0.6250 
err:0.3750 
[1,   257] loss: 1.225
acc:0.8125 
err:0.1875 
[TEST RESULT]
acc:0.1589 
err:0.8411
              precision    recall  f1-score   support

         0.0       0.00      0.00      0.00       115
         1.0       0.00      0.00      0.00       

  'precision', 'predicted', average, warn_for)


[2,    17] loss: 97.331
acc:0.5000 
err:0.5000 
[2,    33] loss: 24.863
acc:0.1250 
err:0.8750 
[2,    49] loss: 17.178
acc:0.3125 
err:0.6875 
[2,    65] loss: 10.648
acc:0.3750 
err:0.6250 
[2,    81] loss: 6.198
acc:0.6875 
err:0.3125 
[2,    97] loss: 2.893
acc:0.3750 
err:0.6250 
[2,   113] loss: 3.279
acc:0.5625 
err:0.4375 
[2,   129] loss: 2.306
acc:0.6250 
err:0.3750 
[2,   145] loss: 1.837
acc:0.1250 
err:0.8750 
[2,   161] loss: 2.013
acc:0.7500 
err:0.2500 
[2,   177] loss: 2.354
acc:0.1250 
err:0.8750 
[2,   193] loss: 1.993
acc:0.6875 
err:0.3125 
[2,   209] loss: 1.620
acc:0.8125 
err:0.1875 
[2,   225] loss: 1.384
acc:0.6250 
err:0.3750 
[2,   241] loss: 1.288
acc:0.6875 
err:0.3125 
[2,   257] loss: 1.536
acc:0.3750 
err:0.6250 
[TEST RESULT]
acc:0.1227 
err:0.8773
              precision    recall  f1-score   support

         0.0       0.14      0.48      0.22       115
         1.0       0.00      0.00      0.00        57
         2.0       0.00      0.00      0.00 

[8,   129] loss: 1.869
acc:0.6250 
err:0.3750 
[8,   145] loss: 1.914
acc:0.4375 
err:0.5625 
[8,   161] loss: 2.200
acc:0.3750 
err:0.6250 
[8,   177] loss: 1.329
acc:0.3750 
err:0.6250 
[8,   193] loss: 1.269
acc:0.4375 
err:0.5625 
[8,   209] loss: 1.364
acc:0.7500 
err:0.2500 
[8,   225] loss: 1.141
acc:0.6250 
err:0.3750 
[8,   241] loss: 1.355
acc:0.4375 
err:0.5625 
[8,   257] loss: 1.379
acc:0.7500 
err:0.2500 
[TEST RESULT]
acc:0.1532 
err:0.8468
              precision    recall  f1-score   support

         0.0       0.00      0.00      0.00       115
         1.0       0.00      0.00      0.00        57
         2.0       0.00      0.00      0.00       167
         3.0       0.15      1.00      0.27       161
         4.0       0.00      0.00      0.00       161
         5.0       0.00      0.00      0.00       149
         6.0       0.00      0.00      0.00       124
         7.0       0.00      0.00      0.00       117

    accuracy                           0.15      105

[14,    65] loss: 8.013
acc:0.3750 
err:0.6250 
[14,    81] loss: 5.699
acc:0.3750 
err:0.6250 
[14,    97] loss: 4.859
acc:0.3750 
err:0.6250 
[14,   113] loss: 2.553
acc:0.3125 
err:0.6875 
[14,   129] loss: 3.007
acc:0.3125 
err:0.6875 
[14,   145] loss: 2.303
acc:0.7500 
err:0.2500 
[14,   161] loss: 2.139
acc:0.3125 
err:0.6875 
[14,   177] loss: 1.236
acc:0.5625 
err:0.4375 
[14,   193] loss: 2.320
acc:0.3125 
err:0.6875 
[14,   209] loss: 1.186
acc:0.5000 
err:0.5000 
[14,   225] loss: 1.741
acc:0.5625 
err:0.4375 
[14,   241] loss: 1.084
acc:0.6875 
err:0.3125 
[14,   257] loss: 0.959
acc:0.5625 
err:0.4375 
[TEST RESULT]
acc:0.1094 
err:0.8906
              precision    recall  f1-score   support

         0.0       0.11      1.00      0.20       115
         1.0       0.00      0.00      0.00        57
         2.0       0.00      0.00      0.00       167
         3.0       0.00      0.00      0.00       161
         4.0       0.00      0.00      0.00       161
         5.0  

[20,   129] loss: 2.174
acc:0.5625 
err:0.4375 
[20,   145] loss: 2.427
acc:0.5000 
err:0.5000 
[20,   161] loss: 1.883
acc:0.5625 
err:0.4375 
[20,   177] loss: 1.411
acc:0.3750 
err:0.6250 
[20,   193] loss: 1.269
acc:0.6875 
err:0.3125 
[20,   209] loss: 1.214
acc:0.7500 
err:0.2500 
[20,   225] loss: 1.155
acc:0.7500 
err:0.2500 
[20,   241] loss: 1.219
acc:0.5625 
err:0.4375 
[20,   257] loss: 1.109
acc:0.5000 
err:0.5000 
[TEST RESULT]
acc:0.1294 
err:0.8706
              precision    recall  f1-score   support

         0.0       0.00      0.00      0.00       115
         1.0       0.00      0.00      0.00        57
         2.0       0.14      0.78      0.24       167
         3.0       0.00      0.00      0.00       161
         4.0       0.00      0.00      0.00       161
         5.0       0.00      0.00      0.00       149
         6.0       0.00      0.00      0.00       124
         7.0       0.03      0.04      0.04       117

    accuracy                           0.13

[26,   113] loss: 2.192
acc:0.5625 
err:0.4375 
[26,   129] loss: 2.579
acc:0.3750 
err:0.6250 
[26,   145] loss: 1.488
acc:0.3750 
err:0.6250 
[26,   161] loss: 1.396
acc:0.5625 
err:0.4375 
[26,   177] loss: 1.293
acc:0.5625 
err:0.4375 
[26,   193] loss: 1.385
acc:0.5000 
err:0.5000 
[26,   209] loss: 3.556
acc:0.5625 
err:0.4375 
[26,   225] loss: 4.747
acc:0.3750 
err:0.6250 
[26,   241] loss: 1.417
acc:0.5625 
err:0.4375 
[26,   257] loss: 1.309
acc:0.5000 
err:0.5000 
[TEST RESULT]
acc:0.1532 
err:0.8468
              precision    recall  f1-score   support

         0.0       0.00      0.00      0.00       115
         1.0       0.00      0.00      0.00        57
         2.0       0.00      0.00      0.00       167
         3.0       0.00      0.00      0.00       161
         4.0       0.15      1.00      0.27       161
         5.0       0.00      0.00      0.00       149
         6.0       0.00      0.00      0.00       124
         7.0       0.00      0.00      0.00       

[32,   129] loss: 4.661
acc:0.6875 
err:0.3125 
[32,   145] loss: 3.838
acc:0.5000 
err:0.5000 
[32,   161] loss: 1.274
acc:0.5625 
err:0.4375 
[32,   177] loss: 1.342
acc:0.6875 
err:0.3125 
[32,   193] loss: 0.988
acc:0.7500 
err:0.2500 
[32,   209] loss: 1.294
acc:0.3750 
err:0.6250 
[32,   225] loss: 1.080
acc:0.5000 
err:0.5000 
[32,   241] loss: 1.149
acc:0.6250 
err:0.3750 
[32,   257] loss: 1.186
acc:0.8125 
err:0.1875 
[TEST RESULT]
acc:0.1589 
err:0.8411
              precision    recall  f1-score   support

         0.0       0.00      0.00      0.00       115
         1.0       0.00      0.00      0.00        57
         2.0       0.16      1.00      0.27       167
         3.0       0.00      0.00      0.00       161
         4.0       0.00      0.00      0.00       161
         5.0       0.00      0.00      0.00       149
         6.0       0.00      0.00      0.00       124
         7.0       0.00      0.00      0.00       117

    accuracy                           0.16

[38,   113] loss: 1.507
acc:0.5000 
err:0.5000 
[38,   129] loss: 1.370
acc:0.1875 
err:0.8125 
[38,   145] loss: 1.368
acc:0.5625 
err:0.4375 
[38,   161] loss: 2.814
acc:0.4375 
err:0.5625 
[38,   177] loss: 1.342
acc:0.3750 
err:0.6250 
[38,   193] loss: 1.364
acc:0.1250 
err:0.8750 
[38,   209] loss: 1.246
acc:0.5000 
err:0.5000 
[38,   225] loss: 1.438
acc:0.5000 
err:0.5000 
[38,   241] loss: 3.282
acc:0.2500 
err:0.7500 
[38,   257] loss: 2.477
acc:0.3125 
err:0.6875 
[TEST RESULT]
acc:0.0542 
err:0.9458
              precision    recall  f1-score   support

         0.0       0.00      0.00      0.00       115
         1.0       0.05      1.00      0.10        57
         2.0       0.00      0.00      0.00       167
         3.0       0.00      0.00      0.00       161
         4.0       0.00      0.00      0.00       161
         5.0       0.00      0.00      0.00       149
         6.0       0.00      0.00      0.00       124
         7.0       0.00      0.00      0.00       

[44,    49] loss: 7.529
acc:0.4375 
err:0.5625 
[44,    65] loss: 8.030
acc:0.2500 
err:0.7500 
[44,    81] loss: 2.931
acc:0.6250 
err:0.3750 
[44,    97] loss: 2.136
acc:0.4375 
err:0.5625 
[44,   113] loss: 2.793
acc:0.6250 
err:0.3750 
[44,   129] loss: 3.276
acc:0.3750 
err:0.6250 
[44,   145] loss: 5.368
acc:0.5000 
err:0.5000 
[44,   161] loss: 1.742
acc:0.7500 
err:0.2500 
[44,   177] loss: 2.708
acc:0.3125 
err:0.6875 
[44,   193] loss: 1.529
acc:0.5000 
err:0.5000 
[44,   209] loss: 2.197
acc:0.4375 
err:0.5625 
[44,   225] loss: 1.684
acc:0.6875 
err:0.3125 
[44,   241] loss: 1.158
acc:0.5625 
err:0.4375 
[44,   257] loss: 1.424
acc:0.6250 
err:0.3750 
[TEST RESULT]
acc:0.1532 
err:0.8468
              precision    recall  f1-score   support

         0.0       0.00      0.00      0.00       115
         1.0       0.00      0.00      0.00        57
         2.0       0.00      0.00      0.00       167
         3.0       0.00      0.00      0.00       161
         4.0       0

[TEST RESULT]
acc:0.1180 
err:0.8820
              precision    recall  f1-score   support

         0.0       0.00      0.00      0.00       115
         1.0       0.00      0.00      0.00        57
         2.0       0.00      0.00      0.00       167
         3.0       0.00      0.00      0.00       161
         4.0       0.00      0.00      0.00       161
         5.0       0.00      0.00      0.00       149
         6.0       0.12      1.00      0.21       124
         7.0       0.00      0.00      0.00       117

    accuracy                           0.12      1051
   macro avg       0.01      0.12      0.03      1051
weighted avg       0.01      0.12      0.02      1051

[50,     1] loss: 13.980
acc:0.0625 
err:0.9375 
[50,    17] loss: 133.850
acc:0.1875 
err:0.8125 
[50,    33] loss: 26.240
acc:0.5625 
err:0.4375 
[50,    49] loss: 16.525
acc:0.3750 
err:0.6250 
[50,    65] loss: 12.698
acc:0.4375 
err:0.5625 
[50,    81] loss: 7.132
acc:0.6875 
err:0.3125 
[50,    97] loss: 

In [9]:
 net.test()

[TEST RESULT]
acc:0.1418 
err:0.8582
              precision    recall  f1-score   support

         0.0       0.00      0.00      0.00       115
         1.0       0.00      0.00      0.00        57
         2.0       0.00      0.00      0.00       167
         3.0       0.00      0.00      0.00       161
         4.0       0.00      0.00      0.00       161
         5.0       0.14      1.00      0.25       149
         6.0       0.00      0.00      0.00       124
         7.0       0.00      0.00      0.00       117

    accuracy                           0.14      1051
   macro avg       0.02      0.12      0.03      1051
weighted avg       0.02      0.14      0.04      1051



