In [58]:
import numpy as np
import pandas as pd
import torch

# Show every column when a DataFrame is displayed instead of eliding the middle ones.
pd.set_option("display.max_columns", None)

In [108]:
# Row ranges that define the training and cross-validation subsets.
TRAIN_ROWS = slice(0, 5000)
CV_ROWS = slice(20000, 22000)

trainset = pd.read_csv("datasets/train.csv")
testset = pd.read_csv("datasets/test.csv")

# One-hot encode the label column once, over the whole frame, so that the
# training and cross-validation targets are guaranteed to share the same
# class columns (encoding each subset separately could drop a class).
labels_onehot = pd.get_dummies(trainset['label']).to_numpy()

# Features are every column after the first one.
# NOTE(review): this assumes 'label' is column 0 of train.csv - confirm.
X_train = trainset.iloc[TRAIN_ROWS, 1:].to_numpy()
Y_train = labels_onehot[TRAIN_ROWS, :]

X_cv = trainset.iloc[CV_ROWS, 1:].to_numpy()
Y_cv = labels_onehot[CV_ROWS, :]

print(X_train.shape)
print(Y_train.shape)
print(X_cv.shape)
print(Y_cv.shape)

(5000, 784)
(5000, 10)
(2000, 784)
(2000, 10)


In [121]:
from torch.nn import Module, Linear, BatchNorm1d, Sigmoid, BCELoss, ReLU
from torch.optim import SGD
from torchmetrics import Accuracy

class Classifier(Module):
    """Fully connected classifier with sigmoid outputs.

    Architecture: ``input_dim -> 512 -> 512 -> 512 -> output_dim`` with ReLU
    after each of the first three linear steps and a sigmoid on the output,
    so every output unit is an independent value in ``[0, 1]``.

    Args:
        input_dim: Number of features per sample.
        output_dim: Number of output units.
    """

    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.input = Linear(input_dim, 512)
        self.hidden = Linear(512, 512)
        # Honour the requested output size (previously hard-coded to 10).
        self.output = Linear(512, output_dim)
        # Registered but not applied in forward(); kept so the module's
        # parameter set (and any saved state_dict) stays unchanged.
        self.bn = BatchNorm1d(input_dim)
        self.sigmoid = Sigmoid()
        self.relu = ReLU()

    def forward(self, x):
        """Map a ``(batch, input_dim)`` tensor to ``(batch, output_dim)`` sigmoid scores."""
        # Batch norm on the raw inputs is currently disabled:
        # x = self.bn(x)
        x = self.input(x)
        x = self.relu(x)

        # NOTE(review): the same `self.hidden` layer is applied twice, so the
        # two hidden steps share one set of weights - confirm this is intended.
        x = self.hidden(x)
        x = self.relu(x)

        x = self.hidden(x)
        x = self.relu(x)

        x = self.output(x)
        x = self.sigmoid(x)
        return x

    @staticmethod
    def _elementwise_accuracy(scores, targets):
        """Return the fraction of output units whose rounded score equals the target.

        The comparison is made per element, not per sample: with one-hot
        targets over C classes an all-zero prediction already scores (C-1)/C.
        Works on any device and never tracks gradients.
        """
        with torch.no_grad():
            hits = scores.round().int() == targets.int()
            return hits.float().mean().item()

    def fit(self, X_train, Y_train, loss_fn, opt, X_cv=None, Y_cv=None, epochs=1):
        """Train with full-batch gradient steps and print metrics every iteration.

        Each iteration runs one forward/backward pass over the whole of
        ``X_train`` and takes a single optimizer step. The printed loss and
        accuracy are measured before that step is applied.

        Args:
            X_train: Training inputs, shape ``(n_samples, input_dim)``.
            Y_train: Training targets, same shape as the model output.
            loss_fn: Callable ``loss_fn(predictions, targets)`` returning a scalar tensor.
            opt: Optimizer built over this model's parameters.
            X_cv: Optional validation inputs; used only if both X_cv and Y_cv are tensors.
            Y_cv: Optional validation targets.
            epochs: Number of full-batch iterations to run.

        Returns:
            None. Progress is reported through ``print``.
        """
        has_validation = torch.is_tensor(X_cv) and torch.is_tensor(Y_cv)

        for i in range(epochs):
            self.train()
            opt.zero_grad()

            Y_pred = self(X_train)
            cost = loss_fn(Y_pred, Y_train)

            train_acc = self._elementwise_accuracy(Y_pred, Y_train)
            msg = f"Iter: {i}    loss: {cost.item(): .4f}    accuracy: {train_acc: .4f}    "

            if has_validation:
                # Evaluate without building an autograd graph and in eval
                # mode, then restore train mode so the caller sees the same
                # module state as before.
                self.eval()
                with torch.no_grad():
                    Y_pred_val = self(X_cv)
                    cost_val = loss_fn(Y_pred_val, Y_cv)
                self.train()
                val_acc = self._elementwise_accuracy(Y_pred_val, Y_cv)
                msg += f"cv_loss: {cost_val.item(): .4f}    cv_accuracy: {val_acc: .4f}"

            print(msg)

            cost.backward()
            opt.step()
        

In [122]:
# Run configuration.
SEED = 0
LEARNING_RATE = 0.01
EPOCHS = 1000
# Use the GPU when one is available, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Take both layer sizes from the data instead of hard-coding the class count.
input_dim, output_dim = X_train.shape[1], Y_train.shape[1]

# Seed before building the model so the random weight initialisation is repeatable.
torch.manual_seed(SEED)

model_v1 = Classifier(input_dim, output_dim).to(DEVICE)
optimizer = SGD(model_v1.parameters(), lr=LEARNING_RATE)
criterion = BCELoss().to(DEVICE)

In [123]:
# Place the data on whatever device the model's parameters live on, so the
# cell works unchanged with or without a GPU.
model_device = next(model_v1.parameters()).device

# Cast to float32: the linear layers and the loss expect floating-point tensors.
X_train_t = torch.from_numpy(X_train).float().to(model_device)
Y_train_t = torch.from_numpy(Y_train).float().to(model_device)
X_cv_t = torch.from_numpy(X_cv).float().to(model_device)
Y_cv_t = torch.from_numpy(Y_cv).float().to(model_device)

model_v1.fit(X_train_t, Y_train_t, criterion, optimizer, X_cv=X_cv_t, Y_cv=Y_cv_t, epochs=EPOCHS)

Iter: 0    loss:  1.7290    accuracy:  0.3656    cv_loss:  1.7326    cv_accuracy:  0.3663
Iter: 1    loss:  0.9735    accuracy:  0.9000    cv_loss:  0.9627    cv_accuracy:  0.9000
Iter: 2    loss:  0.6223    accuracy:  0.8970    cv_loss:  0.6153    cv_accuracy:  0.8972
Iter: 3    loss:  0.4560    accuracy:  0.8944    cv_loss:  0.4526    cv_accuracy:  0.8942
Iter: 4    loss:  0.3614    accuracy:  0.8968    cv_loss:  0.3613    cv_accuracy:  0.8973
Iter: 5    loss:  0.3048    accuracy:  0.8990    cv_loss:  0.3070    cv_accuracy:  0.8991
Iter: 6    loss:  0.2731    accuracy:  0.9034    cv_loss:  0.2768    cv_accuracy:  0.9023
Iter: 7    loss:  0.2542    accuracy:  0.9087    cv_loss:  0.2590    cv_accuracy:  0.9073
Iter: 8    loss:  0.2411    accuracy:  0.9135    cv_loss:  0.2466    cv_accuracy:  0.9119
Iter: 9    loss:  0.2305    accuracy:  0.9182    cv_loss:  0.2367    cv_accuracy:  0.9158
Iter: 10    loss:  0.2213    accuracy:  0.9215    cv_loss:  0.2280    cv_accuracy:  0.9192
Iter: 11 

Iter: 92    loss:  0.0801    accuracy:  0.9745    cv_loss:  0.0975    cv_accuracy:  0.9655
Iter: 93    loss:  0.0797    accuracy:  0.9747    cv_loss:  0.0971    cv_accuracy:  0.9656
Iter: 94    loss:  0.0792    accuracy:  0.9748    cv_loss:  0.0967    cv_accuracy:  0.9657
Iter: 95    loss:  0.0788    accuracy:  0.9750    cv_loss:  0.0963    cv_accuracy:  0.9658
Iter: 96    loss:  0.0784    accuracy:  0.9752    cv_loss:  0.0960    cv_accuracy:  0.9660
Iter: 97    loss:  0.0780    accuracy:  0.9754    cv_loss:  0.0956    cv_accuracy:  0.9663
Iter: 98    loss:  0.0775    accuracy:  0.9755    cv_loss:  0.0953    cv_accuracy:  0.9666
Iter: 99    loss:  0.0771    accuracy:  0.9756    cv_loss:  0.0950    cv_accuracy:  0.9668
Iter: 100    loss:  0.0767    accuracy:  0.9758    cv_loss:  0.0946    cv_accuracy:  0.9669
Iter: 101    loss:  0.0763    accuracy:  0.9759    cv_loss:  0.0943    cv_accuracy:  0.9670
Iter: 102    loss:  0.0759    accuracy:  0.9761    cv_loss:  0.0940    cv_accuracy:  0.9

Iter: 184    loss:  0.0547    accuracy:  0.9835    cv_loss:  0.0774    cv_accuracy:  0.9738
Iter: 185    loss:  0.0545    accuracy:  0.9835    cv_loss:  0.0772    cv_accuracy:  0.9739
Iter: 186    loss:  0.0543    accuracy:  0.9836    cv_loss:  0.0771    cv_accuracy:  0.9739
Iter: 187    loss:  0.0542    accuracy:  0.9836    cv_loss:  0.0770    cv_accuracy:  0.9739
Iter: 188    loss:  0.0540    accuracy:  0.9837    cv_loss:  0.0769    cv_accuracy:  0.9740
Iter: 189    loss:  0.0538    accuracy:  0.9837    cv_loss:  0.0767    cv_accuracy:  0.9740
Iter: 190    loss:  0.0536    accuracy:  0.9838    cv_loss:  0.0766    cv_accuracy:  0.9741
Iter: 191    loss:  0.0535    accuracy:  0.9838    cv_loss:  0.0765    cv_accuracy:  0.9742
Iter: 192    loss:  0.0533    accuracy:  0.9839    cv_loss:  0.0764    cv_accuracy:  0.9742
Iter: 193    loss:  0.0531    accuracy:  0.9839    cv_loss:  0.0762    cv_accuracy:  0.9743
Iter: 194    loss:  0.0530    accuracy:  0.9840    cv_loss:  0.0761    cv_accura

Iter: 277    loss:  0.0421    accuracy:  0.9876    cv_loss:  0.0686    cv_accuracy:  0.9772
Iter: 278    loss:  0.0420    accuracy:  0.9876    cv_loss:  0.0686    cv_accuracy:  0.9772
Iter: 279    loss:  0.0418    accuracy:  0.9877    cv_loss:  0.0685    cv_accuracy:  0.9772
Iter: 280    loss:  0.0417    accuracy:  0.9878    cv_loss:  0.0684    cv_accuracy:  0.9772
Iter: 281    loss:  0.0416    accuracy:  0.9878    cv_loss:  0.0684    cv_accuracy:  0.9772
Iter: 282    loss:  0.0415    accuracy:  0.9878    cv_loss:  0.0683    cv_accuracy:  0.9772
Iter: 283    loss:  0.0414    accuracy:  0.9879    cv_loss:  0.0682    cv_accuracy:  0.9772
Iter: 284    loss:  0.0413    accuracy:  0.9879    cv_loss:  0.0682    cv_accuracy:  0.9772
Iter: 285    loss:  0.0412    accuracy:  0.9879    cv_loss:  0.0681    cv_accuracy:  0.9772
Iter: 286    loss:  0.0411    accuracy:  0.9880    cv_loss:  0.0680    cv_accuracy:  0.9773
Iter: 287    loss:  0.0410    accuracy:  0.9880    cv_loss:  0.0680    cv_accura

Iter: 370    loss:  0.0339    accuracy:  0.9903    cv_loss:  0.0636    cv_accuracy:  0.9789
Iter: 371    loss:  0.0338    accuracy:  0.9903    cv_loss:  0.0635    cv_accuracy:  0.9789
Iter: 372    loss:  0.0338    accuracy:  0.9903    cv_loss:  0.0635    cv_accuracy:  0.9789
Iter: 373    loss:  0.0337    accuracy:  0.9904    cv_loss:  0.0634    cv_accuracy:  0.9789
Iter: 374    loss:  0.0336    accuracy:  0.9904    cv_loss:  0.0634    cv_accuracy:  0.9789
Iter: 375    loss:  0.0336    accuracy:  0.9905    cv_loss:  0.0634    cv_accuracy:  0.9790
Iter: 376    loss:  0.0335    accuracy:  0.9905    cv_loss:  0.0633    cv_accuracy:  0.9790
Iter: 377    loss:  0.0334    accuracy:  0.9905    cv_loss:  0.0633    cv_accuracy:  0.9791
Iter: 378    loss:  0.0333    accuracy:  0.9905    cv_loss:  0.0632    cv_accuracy:  0.9792
Iter: 379    loss:  0.0333    accuracy:  0.9906    cv_loss:  0.0632    cv_accuracy:  0.9792
Iter: 380    loss:  0.0332    accuracy:  0.9906    cv_loss:  0.0632    cv_accura

Iter: 463    loss:  0.0281    accuracy:  0.9927    cv_loss:  0.0603    cv_accuracy:  0.9808
Iter: 464    loss:  0.0280    accuracy:  0.9928    cv_loss:  0.0602    cv_accuracy:  0.9808
Iter: 465    loss:  0.0280    accuracy:  0.9928    cv_loss:  0.0602    cv_accuracy:  0.9808
Iter: 466    loss:  0.0279    accuracy:  0.9928    cv_loss:  0.0602    cv_accuracy:  0.9808
Iter: 467    loss:  0.0278    accuracy:  0.9929    cv_loss:  0.0601    cv_accuracy:  0.9808
Iter: 468    loss:  0.0278    accuracy:  0.9929    cv_loss:  0.0601    cv_accuracy:  0.9808
Iter: 469    loss:  0.0277    accuracy:  0.9929    cv_loss:  0.0601    cv_accuracy:  0.9808
Iter: 470    loss:  0.0277    accuracy:  0.9929    cv_loss:  0.0601    cv_accuracy:  0.9808
Iter: 471    loss:  0.0276    accuracy:  0.9930    cv_loss:  0.0600    cv_accuracy:  0.9808
Iter: 472    loss:  0.0276    accuracy:  0.9930    cv_loss:  0.0600    cv_accuracy:  0.9808
Iter: 473    loss:  0.0275    accuracy:  0.9930    cv_loss:  0.0600    cv_accura

Iter: 554    loss:  0.0237    accuracy:  0.9944    cv_loss:  0.0580    cv_accuracy:  0.9815
Iter: 555    loss:  0.0237    accuracy:  0.9944    cv_loss:  0.0579    cv_accuracy:  0.9815
Iter: 556    loss:  0.0236    accuracy:  0.9944    cv_loss:  0.0579    cv_accuracy:  0.9816
Iter: 557    loss:  0.0236    accuracy:  0.9945    cv_loss:  0.0579    cv_accuracy:  0.9816
Iter: 558    loss:  0.0235    accuracy:  0.9945    cv_loss:  0.0579    cv_accuracy:  0.9815
Iter: 559    loss:  0.0235    accuracy:  0.9945    cv_loss:  0.0579    cv_accuracy:  0.9815
Iter: 560    loss:  0.0235    accuracy:  0.9945    cv_loss:  0.0578    cv_accuracy:  0.9815
Iter: 561    loss:  0.0234    accuracy:  0.9945    cv_loss:  0.0578    cv_accuracy:  0.9815
Iter: 562    loss:  0.0234    accuracy:  0.9945    cv_loss:  0.0578    cv_accuracy:  0.9815
Iter: 563    loss:  0.0233    accuracy:  0.9945    cv_loss:  0.0578    cv_accuracy:  0.9816
Iter: 564    loss:  0.0233    accuracy:  0.9946    cv_loss:  0.0578    cv_accura

Iter: 646    loss:  0.0203    accuracy:  0.9957    cv_loss:  0.0563    cv_accuracy:  0.9819
Iter: 647    loss:  0.0202    accuracy:  0.9957    cv_loss:  0.0563    cv_accuracy:  0.9819
Iter: 648    loss:  0.0202    accuracy:  0.9957    cv_loss:  0.0563    cv_accuracy:  0.9819
Iter: 649    loss:  0.0202    accuracy:  0.9957    cv_loss:  0.0562    cv_accuracy:  0.9819
Iter: 650    loss:  0.0201    accuracy:  0.9957    cv_loss:  0.0562    cv_accuracy:  0.9820
Iter: 651    loss:  0.0201    accuracy:  0.9957    cv_loss:  0.0562    cv_accuracy:  0.9820
Iter: 652    loss:  0.0201    accuracy:  0.9957    cv_loss:  0.0562    cv_accuracy:  0.9820
Iter: 653    loss:  0.0200    accuracy:  0.9958    cv_loss:  0.0562    cv_accuracy:  0.9820
Iter: 654    loss:  0.0200    accuracy:  0.9958    cv_loss:  0.0562    cv_accuracy:  0.9820
Iter: 655    loss:  0.0200    accuracy:  0.9958    cv_loss:  0.0561    cv_accuracy:  0.9820
Iter: 656    loss:  0.0199    accuracy:  0.9958    cv_loss:  0.0561    cv_accura

Iter: 736    loss:  0.0175    accuracy:  0.9967    cv_loss:  0.0551    cv_accuracy:  0.9822
Iter: 737    loss:  0.0175    accuracy:  0.9967    cv_loss:  0.0551    cv_accuracy:  0.9822
Iter: 738    loss:  0.0175    accuracy:  0.9967    cv_loss:  0.0551    cv_accuracy:  0.9822
Iter: 739    loss:  0.0175    accuracy:  0.9967    cv_loss:  0.0551    cv_accuracy:  0.9823
Iter: 740    loss:  0.0174    accuracy:  0.9967    cv_loss:  0.0551    cv_accuracy:  0.9823
Iter: 741    loss:  0.0174    accuracy:  0.9967    cv_loss:  0.0551    cv_accuracy:  0.9823
Iter: 742    loss:  0.0174    accuracy:  0.9968    cv_loss:  0.0550    cv_accuracy:  0.9823
Iter: 743    loss:  0.0173    accuracy:  0.9968    cv_loss:  0.0550    cv_accuracy:  0.9823
Iter: 744    loss:  0.0173    accuracy:  0.9968    cv_loss:  0.0550    cv_accuracy:  0.9823
Iter: 745    loss:  0.0173    accuracy:  0.9968    cv_loss:  0.0550    cv_accuracy:  0.9823
Iter: 746    loss:  0.0173    accuracy:  0.9968    cv_loss:  0.0550    cv_accura

Iter: 826    loss:  0.0153    accuracy:  0.9974    cv_loss:  0.0542    cv_accuracy:  0.9829
Iter: 827    loss:  0.0153    accuracy:  0.9974    cv_loss:  0.0542    cv_accuracy:  0.9829
Iter: 828    loss:  0.0153    accuracy:  0.9974    cv_loss:  0.0542    cv_accuracy:  0.9829
Iter: 829    loss:  0.0152    accuracy:  0.9975    cv_loss:  0.0542    cv_accuracy:  0.9829
Iter: 830    loss:  0.0152    accuracy:  0.9975    cv_loss:  0.0542    cv_accuracy:  0.9829
Iter: 831    loss:  0.0152    accuracy:  0.9975    cv_loss:  0.0542    cv_accuracy:  0.9829
Iter: 832    loss:  0.0152    accuracy:  0.9975    cv_loss:  0.0542    cv_accuracy:  0.9829
Iter: 833    loss:  0.0152    accuracy:  0.9975    cv_loss:  0.0542    cv_accuracy:  0.9829
Iter: 834    loss:  0.0151    accuracy:  0.9975    cv_loss:  0.0542    cv_accuracy:  0.9829
Iter: 835    loss:  0.0151    accuracy:  0.9975    cv_loss:  0.0542    cv_accuracy:  0.9829
Iter: 836    loss:  0.0151    accuracy:  0.9975    cv_loss:  0.0542    cv_accura

Iter: 916    loss:  0.0135    accuracy:  0.9980    cv_loss:  0.0536    cv_accuracy:  0.9833
Iter: 917    loss:  0.0135    accuracy:  0.9980    cv_loss:  0.0536    cv_accuracy:  0.9833
Iter: 918    loss:  0.0134    accuracy:  0.9980    cv_loss:  0.0536    cv_accuracy:  0.9833
Iter: 919    loss:  0.0134    accuracy:  0.9980    cv_loss:  0.0536    cv_accuracy:  0.9833
Iter: 920    loss:  0.0134    accuracy:  0.9980    cv_loss:  0.0536    cv_accuracy:  0.9833
Iter: 921    loss:  0.0134    accuracy:  0.9980    cv_loss:  0.0536    cv_accuracy:  0.9833
Iter: 922    loss:  0.0134    accuracy:  0.9980    cv_loss:  0.0536    cv_accuracy:  0.9833
Iter: 923    loss:  0.0133    accuracy:  0.9980    cv_loss:  0.0536    cv_accuracy:  0.9833
Iter: 924    loss:  0.0133    accuracy:  0.9980    cv_loss:  0.0536    cv_accuracy:  0.9833
Iter: 925    loss:  0.0133    accuracy:  0.9980    cv_loss:  0.0535    cv_accuracy:  0.9834
Iter: 926    loss:  0.0133    accuracy:  0.9980    cv_loss:  0.0535    cv_accura