In [14]:
 # cnn model trained on CIFAR-10 dataset

In [15]:
import matplotlib.pyplot as plt
import numpy as np

import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import Dataset,DataLoader

import torch.nn.functional as F

In [16]:
# transformations
composed = transforms.Compose([
                         transforms.ToTensor(),
                         transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5))
                        ])

In [17]:
# CIFAR-10 dataset
train_data = torchvision.datasets.CIFAR10(root='./data',train=True,
                                          download=True,transform=composed)

test_data = torchvision.datasets.CIFAR10(root='./data',train=False,
                                         download=True,transform=composed)

Files already downloaded and verified
Files already downloaded and verified


In [18]:
# hyperparameters
batch_size = 4
num_epochs = 4
learning_rate = 0.01

# device
device = None
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

print(device)

cuda


In [19]:
# DataLoaders

train_loader = DataLoader(dataset=train_data,
                          batch_size=batch_size,
                          shuffle=True,
                          num_workers=2)

test_loader = DataLoader(dataset=test_data,
                         batch_size=batch_size,
                         shuffle=False,
                         num_workers=2)

In [20]:
len(train_loader),len(test_loader)

(12500, 2500)

In [21]:
# label classes
classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

num_classes = len(classes)
print(num_classes)

10


In [25]:
# model class
class ConvNet(nn.Module):
    def __init__(self,num_classes):
        super(ConvNet,self).__init__()
        
        # cnn layers
        self.conv1 = nn.Conv2d(in_channels=3,out_channels=6,kernel_size=5)
        self.pool = nn.MaxPool2d(kernel_size=2,stride=2)
        self.conv2 = nn.Conv2d(in_channels=6,out_channels=16,kernel_size=5)
        
        #fc layers
        self.linear1 = nn.Linear(in_features=16*5*5,out_features=120)
        self.linear2 = nn.Linear(in_features=120,out_features=84)
        self.linear3 = nn.Linear(in_features=84,out_features=num_classes)

    def forward(self,x):
        x1 = self.pool(F.relu(self.conv1(x)))
        x2 = self.pool(F.relu(self.conv2(x1)))
        x2_flat = x2.view(-1,16*5*5) # Flatten
        lin1 = self.linear1(x2_flat)
        x3 = F.relu(lin1)
        x4 = F.relu(self.linear2(x3))
        x5 = F.relu(self.linear3(x4))
        return lin1,x5

In [26]:
# model defining
model = ConvNet(num_classes)
model = model.to(device)

# loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(),lr=learning_rate)

In [27]:
n_total_steps = len(train_loader)

# training loop
for epoch in range(num_epochs):
    for ii,(images,labels) in enumerate(train_loader):
        # original shape = [4,3,32,32] : 4 := batch_size, 3 := RGB, dims =32x32
        # input_layer = 3 input_channels, 6 output_channels, 5 kernel_size
        images = images.to(device)
        labels = labels.to(device)

        # forward
        _,outputs = model(images)
        loss = criterion(outputs,labels)

        # backward and updates
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        if (ii+1)%2000 == 0:
            print(f'Epoch [{epoch+1}/{num_epochs}], Step [{ii+1}/{n_total_steps}], Loss = {loss.item():.6f}')

    print('----------------------------------------')

Epoch [1/4], Step [2000/12500], Loss = 2.291288
Epoch [1/4], Step [4000/12500], Loss = 1.638216
Epoch [1/4], Step [6000/12500], Loss = 1.212903
Epoch [1/4], Step [8000/12500], Loss = 1.595376
Epoch [1/4], Step [10000/12500], Loss = 1.581806
Epoch [1/4], Step [12000/12500], Loss = 2.580363
----------------------------------------
Epoch [2/4], Step [2000/12500], Loss = 1.167568
Epoch [2/4], Step [4000/12500], Loss = 1.759755
Epoch [2/4], Step [6000/12500], Loss = 1.679082
Epoch [2/4], Step [8000/12500], Loss = 2.505718
Epoch [2/4], Step [10000/12500], Loss = 1.020406
Epoch [2/4], Step [12000/12500], Loss = 0.843665
----------------------------------------
Epoch [3/4], Step [2000/12500], Loss = 1.695767
Epoch [3/4], Step [4000/12500], Loss = 1.212294
Epoch [3/4], Step [6000/12500], Loss = 0.734185
Epoch [3/4], Step [8000/12500], Loss = 1.875462
Epoch [3/4], Step [10000/12500], Loss = 0.649675
Epoch [3/4], Step [12000/12500], Loss = 0.096931
----------------------------------------
Epoch [

In [28]:
# evaluating model and getting features of every image
features = []

with torch.no_grad():
    # for entire test set
    n_correct = 0
    n_samples = 0

    # for each class label
    n_class_correct = [0 for i in range(num_classes)]
    n_class_samples = [0 for i in range(num_classes)]

    for images,labels in test_loader:
        images = images.to(device)
        labels = labels.to(device)

        ftrs,outputs = model(images)
        features.append(ftrs)

        _,preds = torch.max(outputs,1)
        n_samples += labels.size(0)
        n_correct += (preds == labels).sum().item()

        # iterating over the batch 
        for i in range(batch_size):
            label = labels[i]
            pred = preds[i]
            n_class_samples[label] +=1
            if(label == pred):
                n_class_correct[label] += 1
            
    accuracy = n_correct/float(n_samples)
    print(f'Accuracy of model on test set = {accuracy:.4f}')
    
    print('-------------------------------------------')
    
    # printing accuracy per class
    for i in range(num_classes):
        accuracy = n_class_correct[i]/float(n_class_samples[i])
        print(f'Accuracy of {classes[i]} : {accuracy}')

Accuracy of model on test set = 0.5883
-------------------------------------------
Accuracy of plane : 0.53
Accuracy of car : 0.775
Accuracy of bird : 0.33
Accuracy of cat : 0.371
Accuracy of deer : 0.493
Accuracy of dog : 0.683
Accuracy of frog : 0.659
Accuracy of horse : 0.604
Accuracy of ship : 0.66
Accuracy of truck : 0.778


In [51]:
model

ConvNet(
  (conv1): Conv2d(3, 6, kernel_size=(5, 5), stride=(1, 1))
  (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
  (linear1): Linear(in_features=400, out_features=120, bias=True)
  (linear2): Linear(in_features=120, out_features=84, bias=True)
  (linear3): Linear(in_features=84, out_features=10, bias=True)
)

In [37]:
len(features), features[0].shape

(2500, torch.Size([4, 120]))

In [40]:
# convert tensor to numpy
for i in range(len(features)):
    features[i] = features[i].cpu().numpy()

type(features[0])

numpy.ndarray

In [42]:
features = np.array(features)

In [45]:
# reshaping (2500,4,120) into (2500x4,120) as we had taken batch_size=4 during loading
dims = features.shape
features = features.reshape(dims[0]*dims[1],dims[2])

features.shape

(10000, 120)

In [None]:
# convert to dataframe
import pandas as pd

ftrs_df = pd.DataFrame(features)

In [48]:
# convert df to csv
ftrs_df.to_csv('./hopefully_this_is_desired_output.csv',index=False)

In [49]:
# reloading the saved csv into a df

ftrs_csv_df = pd.read_csv('./hopefully_this_is_desired_output.csv')
ftrs_csv_df

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,...,80,81,82,83,84,85,86,87,88,89,90,91,92,93,94,95,96,97,98,99,100,101,102,103,104,105,106,107,108,109,110,111,112,113,114,115,116,117,118,119
0,-4.041168,-4.620143,-1.781208,-0.723168,-3.685660,-2.465509,-1.312625,-3.474266,-0.666465,-6.764115,-1.050546,-1.225386,-5.585804,-6.634646,-1.438204,-2.241868,7.081588,2.125396,-1.947522,-0.275793,3.568234,-1.743907,0.924543,-3.827226,-5.649690,-0.910609,-3.933114,0.920461,-1.389764,-5.669378,-4.391663,-5.216137,-0.258783,-0.894288,-0.410219,-1.349462,1.067147,-1.355117,-1.064318,-3.026964,...,-2.290328,-3.284536,-0.969044,-1.620327,-0.119449,-0.457821,-8.265476,-2.297562,-4.014337,1.185151,1.269268,-1.117636,-2.991186,2.151464,-1.947825,-3.641346,-2.994073,-0.979270,-2.316109,2.873752,0.108960,-1.364055,1.481271,0.531562,0.323296,-5.919993,-5.442319,1.533954,3.005888,-1.049830,-3.512131,-0.805302,-1.204902,-5.971713,-0.976589,1.009387,1.615299,3.179398,-4.912114,-2.614215
1,-9.957033,-7.765330,-1.790383,-0.513650,-7.011168,-6.452775,-4.808393,1.504367,5.637000,-5.802796,-0.197615,-4.988310,-0.849823,-14.711400,-3.160456,-3.262569,-2.373764,2.796687,-4.199276,9.975645,-7.214971,-12.988875,-0.164072,1.181543,-3.065121,-2.355663,-0.256631,2.123674,-14.889133,-13.027492,-10.614709,-8.351180,-1.433402,-3.776729,-0.669025,-2.112906,-1.546241,-1.212393,-2.231576,2.277348,...,-8.920451,-3.659581,-5.063305,-1.569890,2.184623,-4.130387,-5.537583,-4.430237,-3.639200,-6.295833,-3.998492,-3.174933,-2.193366,-2.362017,-12.756830,-3.013829,1.958126,0.675121,-2.301796,-6.873603,1.879667,0.138329,-3.954981,-4.719856,-7.409155,4.867367,4.480630,-3.505374,-0.838354,-2.209542,-2.923599,1.004017,-3.588692,1.592608,-2.782061,-10.455246,-1.201422,-3.773453,8.664276,2.582627
2,-3.259100,-4.725335,-1.732987,-0.350862,-2.510638,-1.016953,-2.638103,0.490408,1.988256,-1.694846,-3.096874,-1.890489,3.303818,-4.203105,-1.144866,-1.718373,-1.212967,-1.540200,-2.000094,6.209831,-2.407688,-2.482808,-3.420059,-2.518446,-1.848594,-0.182664,0.208687,-1.152272,-6.492965,-5.474305,-6.623555,-1.836562,-0.957338,-3.144039,1.365988,-0.301060,-0.020419,0.130050,0.920070,1.425205,...,-4.182159,0.190383,-1.756984,-1.119064,3.226987,-0.699961,-3.431693,-2.847187,-0.579006,-5.512690,-0.847085,-0.234986,-2.013248,1.597455,-3.099414,-0.507322,0.702549,-1.291992,-2.835172,-2.163636,-2.189157,-0.457645,-0.951948,-3.104322,-1.511662,2.800468,2.530758,-1.824971,1.888536,-2.243031,1.017182,-0.313621,-1.285848,0.224540,0.594916,-6.662509,-4.193942,-3.578846,2.050025,4.811572
3,-6.947447,-2.192886,-2.688921,-0.302683,-3.348802,-7.758157,-1.579127,-0.364822,2.110183,-1.836274,4.686403,-2.857338,1.145200,-5.265627,-1.213157,-2.253142,0.193063,1.295935,-1.287893,-1.857092,-1.810665,-6.985744,-0.930199,1.945227,-4.829821,-0.669064,1.183280,2.773684,-9.733597,-6.384218,-5.258634,-4.056843,-2.712468,-0.336427,2.897691,-1.555371,-1.226870,-1.126374,-0.880810,2.099252,...,-4.290386,1.363143,-4.668521,-1.648762,-1.301664,0.492377,-3.948330,-2.044114,-1.675054,-4.437402,-2.133241,-4.287119,-1.410782,0.469959,-6.563445,2.793151,-2.305392,0.656179,-3.242972,-2.408987,-3.219388,-1.459225,-3.233110,-0.403408,-6.117755,1.224759,-2.827114,-2.127799,0.522055,-0.697712,-4.167525,0.337876,-5.808713,0.223810,-1.281335,-7.923990,3.113291,-3.527139,5.438856,-4.269650
4,2.468882,-6.016874,-1.166926,-0.915221,-1.880585,-3.809862,-0.337698,0.184816,-2.649029,0.869480,-2.532718,0.269207,-0.775140,-3.027358,-1.015006,-1.981597,1.547824,-6.011969,-2.795195,-2.756511,-0.840845,2.145902,-3.571432,-3.109736,-2.124403,2.103023,-0.911387,-0.886674,2.154503,-1.383355,3.116836,-3.671776,3.616449,3.479704,1.338850,1.070671,2.403682,-0.434449,-3.212668,-1.211945,...,-2.965212,-1.148267,-0.040053,-1.612430,-0.574172,-1.335602,1.404068,-2.787293,-0.125314,-2.123843,1.292274,-1.485380,-2.191170,0.010183,0.420879,-2.250148,-6.829504,-2.711986,-1.752280,-2.156443,-6.324668,-1.121308,-3.100153,-3.586131,0.400551,-4.432971,-5.778201,-0.116548,3.131870,-1.261272,-1.431039,-1.979913,-2.553624,-2.630749,1.568845,2.589263,-4.305298,0.748601,-5.016461,-1.386496
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
9995,1.152494,-2.459723,3.114920,0.600530,-8.155719,-3.277661,-0.877588,1.536864,-1.993895,-3.974234,-1.573144,5.194475,-6.378981,0.059167,-1.808313,-0.170790,2.307154,-1.142684,-1.195115,-1.413461,-1.657794,-4.234680,-1.838939,-1.607897,-2.163154,0.569179,0.039772,-0.152748,0.960334,-4.885774,-0.353099,-2.546308,1.877114,5.208282,2.417560,-0.638982,3.468194,1.089829,-3.240198,1.866426,...,-2.954657,-3.252283,-5.180446,-1.256014,-0.920557,1.241014,-4.670323,-3.958753,-2.666360,-9.483654,2.186880,-0.178619,-1.545565,4.135842,-1.393055,0.421255,-2.766330,-3.053810,-8.371284,6.512273,-6.389387,0.670623,-2.400495,-2.463803,0.117077,-2.970522,-0.773931,-0.400698,0.625781,-0.561251,-0.954932,-2.985550,-7.688868,-2.418305,-1.875875,0.769634,-0.753602,-1.067290,-3.328650,1.893747
9996,1.308531,-5.397984,-1.504081,-1.233353,-2.431067,-3.145873,0.168304,-4.095574,-1.757717,-2.434891,-2.448132,1.871225,-3.047512,-4.225504,-0.508400,-2.204457,0.785472,-5.116079,0.167971,-3.940649,2.498221,1.795647,-0.732169,-3.331090,-1.050541,1.075578,-2.619204,-0.325528,4.615318,-1.361049,5.594965,-5.647921,3.682876,2.860109,-0.737237,-0.756991,0.555497,-1.668870,-2.010966,-5.122221,...,-1.559589,-3.431186,1.211397,-0.973246,-0.507723,0.054868,-1.610414,-0.382156,-1.082791,-1.273694,1.542790,1.244744,-1.662127,1.710666,1.470491,-3.057882,-4.615407,-1.221290,-3.214601,-0.242143,-7.016522,-1.487170,-2.010388,-1.315545,-0.590313,-8.868636,-9.310840,-0.635598,-1.605992,-1.420662,-4.474795,-1.030718,-1.785142,-4.963767,-1.836323,4.300394,-1.403072,0.542084,-7.364046,-2.414173
9997,-2.659198,-5.368254,-1.268456,-1.687649,-3.829951,0.145288,-0.600923,-1.831481,-6.685695,-1.971934,-4.772157,2.955398,-3.499938,5.001145,-1.603357,-1.210679,-1.599736,-2.211291,2.978359,-5.142384,1.429296,5.235621,-4.506841,0.382831,-0.941524,-0.397104,-6.640833,0.380323,6.913204,2.622923,3.451985,-3.262257,2.703968,3.045755,-0.944608,-3.021467,-2.722683,-1.138738,1.975566,-0.744914,...,-2.065692,-6.239154,-4.224986,-1.629877,-0.675801,-0.765367,-1.568130,2.782643,-1.428552,1.360239,1.469061,6.872408,-1.341854,4.042156,0.810258,-0.644728,-4.126339,-2.093933,-1.841186,3.129622,-8.076990,1.080889,-2.630823,-2.046040,2.839095,-8.060205,-6.073785,-0.454481,-4.391948,0.186558,-0.013707,-0.270892,-1.133284,-8.663404,-3.461194,1.527000,-5.498050,3.337017,-8.181948,-1.483493
9998,3.095540,-7.626342,-1.073295,-0.683592,-3.944719,-5.481817,-3.824585,-0.622524,-0.870088,-3.310616,-6.004220,0.070770,-2.492699,0.654236,-1.293016,-3.320658,-3.930301,-3.411257,2.535506,-2.932298,-5.451705,-3.736032,-2.180752,4.214794,-6.647468,0.574335,-0.296832,-2.282416,-3.304297,-5.441446,-0.654744,-3.530196,1.110014,-4.321532,-0.777180,2.736990,0.152652,0.307377,-2.256501,2.453462,...,-9.550564,-1.364258,-5.714246,0.020975,1.635584,0.686366,-1.336396,-4.086152,-0.881214,-3.977390,0.360865,-2.367365,-3.629470,0.061055,-4.619489,-4.750968,4.389218,-3.101394,-3.155628,1.749637,-7.045060,-0.942176,-0.971745,-6.633436,-3.732319,-0.165383,-2.820301,-0.516245,-4.640238,-3.423304,0.801744,-0.512751,1.603422,-4.749073,1.363733,0.115849,-7.147361,3.005476,1.008182,0.992002


In [50]:
# saving model
MODEL_PATH = 'saved_model.pth'
torch.save(model.state_dict(),MODEL_PATH)

In [None]:
# # printing parameters of the model

# for param in model.parameters():
#     print(param.shape)
#     print(str(param))