In [1]:
import torch
import numpy as np
import os.path
import utils
import time

import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

### Download the CIFAR dataset 
* 50000 32 * 32 RGB training image 
* 10000 32 * 32 RGB test image

In [2]:
from utils import check_cifar_dataset_exists
data_path=check_cifar_dataset_exists()

train_data=torch.load(data_path+'cifar/train_data.pt')
train_label=torch.load(data_path+'cifar/train_label.pt')
test_data=torch.load(data_path+'cifar/test_data.pt')
test_label=torch.load(data_path+'cifar/test_label.pt')

print(train_data.size())
print(train_label.size())
print(test_data.size())

torch.Size([50000, 3, 32, 32])
torch.Size([50000])
torch.Size([10000, 3, 32, 32])


### Compute average pixel intensity over all training set and all channels

In [3]:
mean= train_data.mean()

print(mean)

tensor(0.4733)


### Compute standard deviation

In [4]:
std= train_data.std()

print(std)

tensor(0.2516)


### Make a Resnet convnet Class

In [None]:
# The BasicBlock is the repeated block in Resnet.
class BasicBlock(nn.Module):
    
    def __init__(self, in_channels):
        super().__init__()
        
        # block 1 :  channel x 32 x 32 -> channel x 32 x 32 -> channel x 32 x 32 (2 layers)
        self.residual_function = nn.Sequential(
            nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(in_channels),
            nn.ReLU(inplace = True),
            nn.Conv2d(in_channels,in_channels, kernel_size=3, stride = 1, padding=1, bias=False),
            nn.BatchNorm2d(in_channels)
        )

        # identity shorcut
        self.shortcut = nn.Sequential()
        
    def forward(self, x):
        return nn.ReLU(inplace=True)(self.residual_function(x) + self.shortcut(x))  

In [None]:
class ResNet(nn.Module):

    def __init__(self, num_classes=10):
        super().__init__()
        
        # income conv 
        self.conv1 = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, padding=1, stride = 1, bias = False),
            nn.BatchNorm2d(16),
            nn.ReLU(inplace=True))
        
        self.conv2 = BasicBlock(16)
        self.conv2_1 = BasicBlock(16)
        self.conv2_2 = BasicBlock(16)
        
        self.conv3 = nn.Conv2d(16, 32, kernel_size = 3, stride = 2, padding = 1) # subsampling using a stride of 2.
        self.conv4 = BasicBlock(32)
        self.conv4_1 = BasicBlock(32)
        self.conv4_2 = BasicBlock(32)
        
        self.conv5 = nn.Conv2d(32, 64, kernel_size = 3, stride = 2, padding = 1) # subsampling
        self.conv6 = BasicBlock(64)
        self.conv6_1 = BasicBlock(64)
        self.conv6_2 = BasicBlock(64)
        
#         self.avg_pool = nn.AdaptiveAvgPool2d((8, 8))

        # linear layers:   64 x 8 x 8 --> 4096 --> 10
        self.fc = nn.Linear(4096, 10)
        

    def forward(self, x):
        output = self.conv1(x)
        
        output = self.conv2(output)
        output = self.conv2_1(output)
        output = self.conv2_2(output)
        
        output = self.conv3(output)
        
        output = self.conv4(output)
        output = self.conv4_1(output)
        output = self.conv4_2(output)
        
        output = self.conv5(output)
        
        output = self.conv6(output)
        output = self.conv6_1(output)
        output = self.conv6_2(output)
        output = output.view(output.size(0), -1)
        # bs x 4096 ->  bs * 10
#         print(output.size())

        x = self.fc(output)
        x = F.log_softmax(x, dim =1)

        return x 

In [None]:
model = ResNet()
# print(model)

In [None]:
utils.display_num_param(model)

There are 356218 (0.36 million) parameters in this neural network


In [None]:
bs=5
x=torch.rand(bs,3,32,32)
y = model(x)
print(y.size())

torch.Size([5, 10])


### Put the network to GPU

In [None]:
gpu_id = 0
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id)

device = torch.device('cuda')
print(device)

cuda


### Send the weights of the networks to the GPU (as well as the mean and std)

In [None]:
model = model.to(device)

mean=mean.to(device)

std=std.to(device)

### Choose the criterion, learning rate, and batch size.

In [None]:
criterion = nn.NLLLoss()

my_lr=0.1 

bs= 128

### Divide the data to 45k train data and 5k dev data

In [None]:
dev_data = train_data[45000:50000]
dev_label = train_label[45000:50000]

In [None]:
train_data = train_data[0:45000]
train_label = train_label[0:45000]

In [None]:
print(dev_data.size())
print(dev_label.size())

torch.Size([5000, 3, 32, 32])
torch.Size([5000])


### Function to evaluate the network on the test set

In [None]:
def eval_on_dev_set():

    running_error=0
    num_batches=0

    for i in range(0,5000,bs):

        minibatch_data =  dev_data[i:i+bs]
        minibatch_label = dev_label[i:i+bs]

        minibatch_data=minibatch_data.to(device)
        minibatch_label=minibatch_label.to(device)
        
        inputs = (minibatch_data - mean)/std

        scores= model( inputs ) 

        error = utils.get_error( scores , minibatch_label)

        running_error += error.item()

        num_batches+=1

    total_error = running_error/num_batches
    print( 'error rate on dev set =', total_error*100 ,'percent')

### Do 64k passes through the training set. Divide the learning rate by 10 at epoch 32k and 48k

In [None]:
start=time.time()

for epoch in range(1,64000):
    
    # divide the learning rate by 10 at epoch 32k and 48k
    if epoch==32000 or epoch == 48000:
        my_lr = my_lr / 10
    
    # create a new optimizer at the beginning of each epoch: give the current learning rate. 
    optimizer=torch.optim.SGD( model.parameters() , lr=my_lr, momentum = 0.9, weight_decay = 0.0001 )
        
    # set the running quatities to zero at the beginning of the epoch
    running_loss=0
    running_error=0
    num_batches=0
    
    # set the order in which to visit the image from the training set
    shuffled_indices = torch.randperm(45000)
 
    for count in range(0,45000,bs):
    
        # Set the gradients to zeros
        optimizer.zero_grad()
        
        # create a minibatch       
        indices = shuffled_indices[count:count+bs]
        minibatch_data = train_data[indices]
        minibatch_label = train_label[indices]
        
        # send them to the gpu
        minibatch_data=minibatch_data.to(device)
        minibatch_label=minibatch_label.to(device)
        
        # normalize the minibatch (this is the only difference compared to before!)
        inputs = (minibatch_data - mean)/std
        
        # tell Pytorch to start tracking all operations that will be done on "inputs"
        inputs.requires_grad_()

        # forward the minibatch through the net 
        scores=model( inputs ) 

        # Compute the average of the losses of the data points in the minibatch
        loss =  criterion( scores , minibatch_label) 
        
        # backward pass to compute dL/dU, dL/dV and dL/dW   
        loss.backward()

        # do one step of stochastic gradient descent: U=U-lr(dL/dU), V=V-lr(dL/dU), ...
        optimizer.step()
        

        # START COMPUTING STATS
        
        # add the loss of this batch to the running loss
        running_loss += loss.detach().item()
        
        # compute the error made on this batch and add it to the running error       
        error = utils.get_error( scores.detach() , minibatch_label)
        running_error += error.item()
        
        num_batches+=1        
    
    
    # compute stats for the full training set
    total_loss = running_loss/num_batches
    total_error = running_error/num_batches
    elapsed = (time.time()-start)/60
    

    print('epoch=',epoch, '\t time=', elapsed,'min','\t lr=', my_lr  ,'\t loss=', total_loss , '\t error=', total_error*100 ,'percent')
    eval_on_dev_set() 
    print(' ')
    
           
    

epoch= 1 	 time= 0.10245123306910196 min 	 lr= 0.1 	 loss= 4.065308667719364 	 error= 90.03092448481105 percent
error rate on dev set = 89.609375 percent
 
epoch= 2 	 time= 0.2031666874885559 min 	 lr= 0.1 	 loss= 2.3041388473727484 	 error= 89.8037997159091 percent
error rate on dev set = 90.52734375 percent
 
epoch= 3 	 time= 0.30463571548461915 min 	 lr= 0.1 	 loss= 2.304587746208364 	 error= 89.98283616859804 percent
error rate on dev set = 89.58984375 percent
 
epoch= 4 	 time= 0.4024885376294454 min 	 lr= 0.1 	 loss= 2.3043833293698053 	 error= 90.07136798040433 percent
error rate on dev set = 90.72265625 percent
 
epoch= 5 	 time= 0.500160276889801 min 	 lr= 0.1 	 loss= 2.3043890317732636 	 error= 89.91501934149049 percent
error rate on dev set = 89.609375 percent
 
epoch= 6 	 time= 0.6023951967557272 min 	 lr= 0.1 	 loss= 2.304490840570493 	 error= 90.04547426646407 percent
error rate on dev set = 90.3515625 percent
 
epoch= 7 	 time= 0.7010857184727987 min 	 lr= 0.1 	 loss= 2.

error rate on dev set = 89.58984375 percent
 
epoch= 54 	 time= 5.408148169517517 min 	 lr= 0.1 	 loss= 2.304646049033512 	 error= 90.18234099176797 percent
error rate on dev set = 90.52734375 percent
 
epoch= 55 	 time= 5.505988490581513 min 	 lr= 0.1 	 loss= 2.3046786974776876 	 error= 90.13277304104783 percent
error rate on dev set = 89.86328125 percent
 
epoch= 56 	 time= 5.606597630182902 min 	 lr= 0.1 	 loss= 2.30450334941799 	 error= 90.06421638822015 percent
error rate on dev set = 89.58984375 percent
 
epoch= 57 	 time= 5.707600657145182 min 	 lr= 0.1 	 loss= 2.304250618273562 	 error= 90.0706281546842 percent
error rate on dev set = 90.17578125 percent
 
epoch= 58 	 time= 5.808122022946676 min 	 lr= 0.1 	 loss= 2.304236835376783 	 error= 90.06495619700713 percent
error rate on dev set = 89.70703125 percent
 
epoch= 59 	 time= 5.907059601942698 min 	 lr= 0.1 	 loss= 2.3042142885652455 	 error= 90.14411695640196 percent
error rate on dev set = 89.609375 percent
 
epoch= 60 	 ti

epoch= 106 	 time= 10.60258073012034 min 	 lr= 0.1 	 loss= 0.4364836711267179 	 error= 15.423275716602802 percent
error rate on dev set = 28.203125 percent
 
epoch= 107 	 time= 10.701287778218587 min 	 lr= 0.1 	 loss= 0.44212466838176956 	 error= 15.497504344040697 percent
error rate on dev set = 29.12109375 percent
 
epoch= 108 	 time= 10.800606667995453 min 	 lr= 0.1 	 loss= 0.4210280294699425 	 error= 14.652876420454545 percent
error rate on dev set = 29.824218749999996 percent
 
epoch= 109 	 time= 10.902257247765858 min 	 lr= 0.1 	 loss= 0.4119687943549996 	 error= 14.460276592184195 percent
error rate on dev set = 29.47265625 percent
 
epoch= 110 	 time= 11.002518419424693 min 	 lr= 0.1 	 loss= 0.4132235965860838 	 error= 14.613665962083774 percent
error rate on dev set = 30.507812499999996 percent
 
epoch= 111 	 time= 11.101859068870544 min 	 lr= 0.1 	 loss= 0.4081137058409778 	 error= 14.397391880100423 percent
error rate on dev set = 30.019531249999996 percent
 
epoch= 112 	 ti

error rate on dev set = 28.88671875 percent
 
epoch= 157 	 time= 15.690306417147319 min 	 lr= 0.1 	 loss= 0.23683846251912077 	 error= 8.069957386363637 percent
error rate on dev set = 28.73046875 percent
 
epoch= 158 	 time= 15.787585079669952 min 	 lr= 0.1 	 loss= 0.23567162494344468 	 error= 8.073163269595666 percent
error rate on dev set = 30.859375 percent
 
epoch= 159 	 time= 15.886044963200886 min 	 lr= 0.1 	 loss= 0.2422282176240432 	 error= 8.181423609229652 percent
error rate on dev set = 29.47265625 percent
 
epoch= 160 	 time= 15.987226764361063 min 	 lr= 0.1 	 loss= 0.24388803550126878 	 error= 8.41422031887553 percent
error rate on dev set = 29.042968749999996 percent
 
epoch= 161 	 time= 16.0860081354777 min 	 lr= 0.1 	 loss= 0.23248239532536405 	 error= 7.8635475852272725 percent
error rate on dev set = 29.51171875 percent
 
epoch= 162 	 time= 16.18644646803538 min 	 lr= 0.1 	 loss= 0.22145261318126525 	 error= 7.674153640188954 percent
error rate on dev set = 27.929687

error rate on dev set = 28.066406249999996 percent
 
epoch= 208 	 time= 20.75919998884201 min 	 lr= 0.1 	 loss= 0.18371098610276188 	 error= 6.202404912222515 percent
error rate on dev set = 26.308593749999996 percent
 
epoch= 209 	 time= 20.859638174374897 min 	 lr= 0.1 	 loss= 0.16713127741505476 	 error= 5.6122750890525905 percent
error rate on dev set = 27.109375000000004 percent
 
epoch= 210 	 time= 20.963168156147002 min 	 lr= 0.1 	 loss= 0.16800379982768474 	 error= 5.678119087083773 percent
error rate on dev set = 26.816406250000004 percent
 
epoch= 211 	 time= 21.057884081204733 min 	 lr= 0.1 	 loss= 0.15461781255858528 	 error= 5.350132184949788 percent
error rate on dev set = 26.503906249999996 percent
 
epoch= 212 	 time= 21.153980191548666 min 	 lr= 0.1 	 loss= 0.1535467722305012 	 error= 5.230527929961681 percent
error rate on dev set = 28.53515625 percent
 
epoch= 213 	 time= 21.24855632384618 min 	 lr= 0.1 	 loss= 0.1793055742746219 	 error= 6.120778095315804 percent
er

error rate on dev set = 26.621093750000004 percent
 
epoch= 259 	 time= 25.83360050916672 min 	 lr= 0.1 	 loss= 0.14542874123964628 	 error= 4.975783215327696 percent
error rate on dev set = 27.207031250000004 percent
 
epoch= 260 	 time= 25.93466115395228 min 	 lr= 0.1 	 loss= 0.16494230494241824 	 error= 5.598958327688954 percent
error rate on dev set = 26.8359375 percent
 
epoch= 261 	 time= 26.02905919154485 min 	 lr= 0.1 	 loss= 0.15125747994435104 	 error= 5.198962275277485 percent
error rate on dev set = 28.964843750000004 percent
 
epoch= 262 	 time= 26.126772805054983 min 	 lr= 0.1 	 loss= 0.1593916683923453 	 error= 5.323252048004757 percent
error rate on dev set = 27.75390625 percent
 
epoch= 263 	 time= 26.230632527669272 min 	 lr= 0.1 	 loss= 0.15471203918357126 	 error= 5.285767838358879 percent
error rate on dev set = 27.12890625 percent
 
epoch= 264 	 time= 26.33112492163976 min 	 lr= 0.1 	 loss= 0.13753402752759444 	 error= 4.6157374470071355 percent
error rate on dev 

epoch= 309 	 time= 31.47070728937785 min 	 lr= 0.1 	 loss= 0.13529895801647482 	 error= 4.644590429961681 percent
error rate on dev set = 26.8359375 percent
 
epoch= 310 	 time= 31.584053138891857 min 	 lr= 0.1 	 loss= 0.13437528620389375 	 error= 4.487255368043075 percent
error rate on dev set = 26.2890625 percent
 
epoch= 311 	 time= 31.690066119035084 min 	 lr= 0.1 	 loss= 0.13994649111885915 	 error= 4.6860203654928645 percent
error rate on dev set = 25.917968749999996 percent
 
epoch= 312 	 time= 31.795689260959627 min 	 lr= 0.1 	 loss= 0.14478504204783926 	 error= 4.954081824557348 percent
error rate on dev set = 26.425781250000004 percent
 
epoch= 313 	 time= 31.917807006835936 min 	 lr= 0.1 	 loss= 0.15346070435638962 	 error= 5.271464637057348 percent
error rate on dev set = 25.5859375 percent
 
epoch= 314 	 time= 32.042115437984464 min 	 lr= 0.1 	 loss= 0.14314421966925941 	 error= 4.776031794873151 percent
error rate on dev set = 27.65625 percent
 
epoch= 315 	 time= 32.1664

error rate on dev set = 27.675781249999996 percent
 
epoch= 360 	 time= 37.04141092300415 min 	 lr= 0.1 	 loss= 0.12555393829560754 	 error= 4.341017590327696 percent
error rate on dev set = 26.210937499999996 percent
 
epoch= 361 	 time= 37.146773278713226 min 	 lr= 0.1 	 loss= 0.13140583478591658 	 error= 4.378748414191333 percent
error rate on dev set = 25.87890625 percent
 
epoch= 362 	 time= 37.25188084046046 min 	 lr= 0.1 	 loss= 0.1397622749209404 	 error= 4.685033925554969 percent
error rate on dev set = 25.624999999999996 percent
 
epoch= 363 	 time= 37.35673709313075 min 	 lr= 0.1 	 loss= 0.12777937699998307 	 error= 4.3503886427391665 percent
error rate on dev set = 26.796874999999996 percent
 
epoch= 364 	 time= 37.462707277139025 min 	 lr= 0.1 	 loss= 0.14352125350639902 	 error= 4.80636442926797 percent
error rate on dev set = 26.328125000000004 percent
 
epoch= 365 	 time= 37.57619525591532 min 	 lr= 0.1 	 loss= 0.12847424390598794 	 error= 4.270488074557348 percent
erro

error rate on dev set = 25.703125 percent
 


In [None]:
# import matplotlib.pyplot as plt
# def show(X):
#     if X.dim() == 3 and X.size(0) == 3:
#         plt.imshow( np.transpose(  X.numpy() , (1, 2, 0))  )
#         plt.show()
#     elif X.dim() == 2:
#         plt.imshow(   X.numpy() , cmap='gray'  )
#         plt.show()
#     else:
#         print('WRONG TENSOR SIZE')

In [None]:
# def show_prob_cifar(p):


#     p=p.data.squeeze().numpy()

#     ft=15
#     label = ('airplane', 'automobile', 'Bird', 'Cat', 'Deer', 'Dog', 'Frog', 'Horse', 'Ship','Truck' )
#     #p=p.data.squeeze().numpy()
#     y_pos = np.arange(len(p))*1.2
#     target=2
#     width=0.9
#     col= 'blue'
#     #col='darkgreen'

#     plt.rcdefaults()
#     fig, ax = plt.subplots()

#     # the plot
#     ax.barh(y_pos, p, width , align='center', color=col)

#     ax.set_xlim([0, 1.3])
#     #ax.set_ylim([-0.8, len(p)*1.2-1+0.8])

#     # y label
#     ax.set_yticks(y_pos)
#     ax.set_yticklabels(label, fontsize=ft)
#     ax.invert_yaxis()  
#     #ax.set_xlabel('Performance')
#     #ax.set_title('How fast do you want to go today?')

#     # x label
#     ax.set_xticklabels([])
#     ax.set_xticks([])
#     #x_pos=np.array([0, 0.25 , 0.5 , 0.75 , 1])
#     #ax.set_xticks(x_pos)
#     #ax.set_xticklabels( [0, 0.25 , 0.5 , 0.75 , 1] , fontsize=15)

#     ax.spines['right'].set_visible(False)
#     ax.spines['top'].set_visible(False)
#     ax.spines['bottom'].set_visible(False)
#     ax.spines['left'].set_linewidth(4)


#     for i in range(len(p)):
#         str_nb="{0:.2f}".format(p[i])
#         ax.text( p[i] + 0.05 , y_pos[i] ,str_nb ,
#                  horizontalalignment='left', verticalalignment='center',
#                  transform=ax.transData, color= col,fontsize=ft)
#     plt.show()

In [None]:
# # choose a picture at random
# from random import randint
# idx=randint(0, 10000-1)
# im=test_data[idx]

# # diplay the picture
# show(im)

# # send to device, rescale, and view as a batch of 1 
# im = im.to(device)
# im= (im-mean) / std
# im=im.view(1,3,32,32)

# # feed it to the net and display the confidence scores
# scores =  model(im) 
# probs= F.softmax(scores, dim=1)
# show_prob_cifar(probs.cpu())