In [150]:
import numpy as np
import torch
import torch.nn as nn # nn class becomes the parent class later when defining the model
import torch.nn.functional as F # functional class consists of various methods of loss functions
import torch.utils.data as data # Used to load dataset
import torchvision
from torchvision import transforms # used to perform various transformations on the input image
import matplotlib.pyplot as plt 
import torch.optim as optim

In [151]:
# transforms.Compose - it composes several transforms together
transform_img = transforms.Compose([
    transforms.Resize((30, 30)), # resized to 100 * 100
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5],   # input[channel] = (input[channel] - mean[channel]) / std[channel]
                        std=[0.5, 0.5, 0.5])])

In [152]:
train_data_path = "/home/sayan/Desktop/Flower Classification/train/"
test_data_path = "/home/sayan/Desktop/Flower Classification/test/"

In [153]:
train_data = torchvision.datasets.ImageFolder(root = train_data_path, transform = transform_img)
# num_workers means the number of sub-processes employed in order to load the data. Optimum - 2, 4
train_data_loader = data.DataLoader(train_data, batch_size = 32, shuffle=True,  num_workers=4)

test_data = torchvision.datasets.ImageFolder(root = test_data_path, transform = transform_img)
test_data_loader  = data.DataLoader(test_data, batch_size = 32, shuffle=True, num_workers=4) 

In [154]:
class MyNetwork(nn.Module):
    # Class MyNetwork will inherit nn.Module
    def __init__(self):
        super().__init__();
        # super().__init__() will call the constructor of nn.Module. The parameters of the methods
        # that we will use later (of nn class) will then get initialized
        self.conv1 = nn.Conv2d(3, 6, 5); # where 5 is the size of the filter, 3 - input channels, 6 - o/p channels
        self.pool = nn.MaxPool2d(2, 2);
        self.conv2 = nn.Conv2d(6, 10, 5);
        self.fc1 = nn.Linear(4 * 4 * 10, 120); # Input units, o/p units in layer
        self.fc2 = nn.Linear(120, 100);
        self.fc3 = nn.Linear(100, 5);
        
    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)));
        x = self.pool(F.relu(self.conv2(x)));
        x = x.view(-1, 4 * 4 * 10); # in pytorch, images are of the form [number of channels, height, width]
        x = F.relu(self.fc1(x));
        x = F.relu(self.fc2(x));
        x = self.fc3(x);
        return x

mynetwork = MyNetwork(); # Creating an instance of the above defined class

In [155]:
# Here we define which type of loss function we'll use and what type of optimization we will do
criterion = nn.CrossEntropyLoss() # -ylog(y_) for softmax
optimizer = optim.Adam(mynetwork.parameters(), lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0, amsgrad=False) # Momentum is generally taken as 0.9
# When net.parameters() is called, automatically requires_grad is set True for the parameters in order to calculate 
# their gradients

In [161]:
for epoch in range(20): 
    for i, data in enumerate(train_data_loader):
        inputs, labels = data;
        optimizer.zero_grad(); # Zero the parameter gradeints so that in the next cycle, those values don't create
        # trouble
        outputs = mynetwork.forward(inputs); # Passes inputs to the forward method inside MyNetwork clas
        loss = criterion(outputs, labels);
        loss.backward(); # Tells that we have to calculate the gradient of Loss w.r.t the parameters and the gradients
        # are stored in the .grad property of the parameters
        optimizer.step(); # This leads to parameter update
        print(loss.item()); # Prints the scalar value of loss
    # As the training data is divided into mini-batches, per 32 inputs,
    # the cost will be calculated and then updated.

0.7348100543022156
0.959075391292572
0.9521585702896118
0.8821877837181091
0.8556429743766785
0.7477111220359802
0.8863551616668701
0.8528202176094055
0.9524240493774414
0.9511318802833557
0.8443264961242676
0.853919506072998
0.9461005926132202
0.7295598983764648
1.041833758354187
0.9568425416946411
1.0941493511199951
1.0970656871795654
1.0236526727676392
0.8804437518119812
0.9120051264762878
0.7824024558067322
1.0916215181350708
0.8814027309417725
1.1785292625427246
1.154835820198059
0.7599840760231018
0.9520627856254578
1.029695987701416
0.8250048160552979
0.9713821411132812
1.0274405479431152
0.962011456489563
0.8962452411651611
0.9573663473129272
0.7289483547210693
0.7892572283744812
0.8012480139732361
1.05342698097229
0.9824444055557251
0.7906793355941772
0.9337444305419922
0.9091656804084778
0.755405843257904
0.8689055442810059
0.8882148861885071
0.8765568733215332
1.0377211570739746
1.0793445110321045
1.0742703676223755
1.0170527696609497
0.912051796913147
0.8788064122200012
0.8

0.7643615007400513
0.7157769799232483
0.7535268664360046
0.8335162997245789
0.8626967668533325
0.7333217859268188
0.9245153665542603
0.7373880743980408
0.7830588221549988
0.9849446415901184
0.6154900193214417
0.7663710117340088
0.5970536470413208
0.9001122713088989
0.704396665096283
0.8114616870880127
0.9176390171051025
0.7161515355110168
0.8332082629203796
0.7779104709625244
0.7450958490371704
0.7674359083175659
0.49307942390441895
0.9405238032341003
0.6484438180923462
0.7179285883903503
1.2183551788330078
0.7637178897857666
0.8216544389724731
0.8484858274459839
0.7250684499740601
0.817493200302124
0.8358917832374573
0.6524543166160583
0.8802249431610107
0.6873423457145691
0.6073309779167175
0.627151370048523
0.6266449093818665
0.787632942199707
0.6024858355522156
0.7991888523101807
0.8539181351661682
0.5325865745544434
0.7986096739768982
0.7316000461578369
0.702864408493042
0.815059244632721
0.7199887633323669
0.8822311162948608
0.7518616318702698
0.8572859168052673
0.805166125297546

0.5672928690910339
0.6528237462043762
0.7542437314987183
0.5570303797721863
0.7862213850021362
0.5605388879776001
0.8546202778816223
0.47847968339920044
0.5987043976783752
0.6724840998649597
0.7677370309829712
0.6718410849571228
0.6125321984291077
0.5174435973167419
0.44380778074264526
0.5866560935974121
0.676950991153717
0.5687949657440186
0.5132310390472412
0.6094343066215515
0.5618751049041748
0.8883362412452698
0.6780236959457397
0.7096459865570068
0.8031861782073975
1.022883653640747
0.7970297336578369
0.5248896479606628
0.6926891803741455
0.5863466262817383
0.5829977989196777
0.9391622543334961
0.6583234071731567
0.44145599007606506
0.5779972672462463
0.7220797538757324
0.6635840535163879
0.7475559115409851
0.5876967906951904
0.5684104561805725
0.4904206395149231
0.6633884906768799
1.1082431077957153
0.9061547517776489
0.8501537442207336
0.566076397895813
0.5975976586341858
0.6028360724449158
0.6857886910438538
0.7373375296592712
0.8937684893608093
0.91100013256073
0.766557633876

0.6337586045265198
0.5043509602546692
0.8229944705963135
0.6301050782203674
0.33040568232536316
0.5258469581604004
0.737861692905426
0.7015892863273621
0.542921245098114
0.5257441997528076
0.6675320267677307
0.6080476641654968
0.5106927156448364
0.5586146116256714
0.6068840622901917
0.7061436176300049
0.6661854386329651
0.3008987009525299
0.43685513734817505
0.297700434923172
0.4572077989578247
0.5878034234046936
0.6464812755584717
0.3486301302909851
0.6300308704376221
0.5705858469009399
0.5699657201766968
0.5193184614181519
0.6079363822937012
0.6641585230827332
0.8727771639823914
0.6876307725906372
0.7785288691520691
0.5228519439697266
0.561315655708313
0.39274728298187256
0.6516457200050354
0.3908279836177826
0.6229020357131958
0.8194496035575867
0.49807533621788025
0.4149492681026459
0.9130322337150574
0.48817044496536255
0.4628896713256836
0.6159939765930176
0.38407161831855774
0.6359452605247498
0.44505515694618225
0.5559682250022888
0.5414877533912659
0.55812007188797
0.789419651

0.5210281610488892
0.5545571446418762
0.49206840991973877
0.522364616394043
0.5412177443504333
0.4185607135295868
0.3573162257671356
0.6115187406539917
0.5304718017578125
0.4346909523010254
0.3868936598300934
0.371734082698822
0.6125712394714355
0.4158487617969513
0.41674649715423584
0.41766536235809326
0.42606431245803833
0.394996702671051
0.5556995868682861
0.5818443298339844
0.381769597530365
0.3749253451824188
0.427442729473114
0.5260156989097595
0.5734264850616455
0.6378060579299927
0.5080869793891907
0.36352184414863586
0.411685049533844
0.3310929834842682
0.44830870628356934
0.5042505860328674
0.4100363552570343
0.3751349151134491
0.3864873945713043
0.3730553090572357
0.4000939130783081
0.2931656539440155
0.39569759368896484
0.36730918288230896
0.5784733891487122
0.40990883111953735
0.46537283062934875
0.3733619749546051
0.5478335618972778
0.4376915693283081
0.4899565577507019
0.738584041595459
0.4614816904067993
0.6648389101028442
0.3136909604072571
0.6452866792678833
0.3209633

In [169]:
# Saving the model
PATH = '/home/sayan/Desktop/flower.pth';
torch.save(mynetwork.state_dict(), PATH); # A state_dict is simply a Python dictionary 
# object that maps each layer to its parameter tensor. 

# Loading the data
mynetwork = MyNetwork()
mynetwork.load_state_dict(torch.load(PATH))

<All keys matched successfully>

In [184]:
def corr(x): # To get proper correspondence between the outputs and the labels
    x = x.detach().numpy(); # Detach() was used as one can't convert a pytorch tensor to a numpy array if
    # required_grad is set True for that variable
    x = x.argmax(axis = 1);
    return x;

In [185]:
correct = 0;
total = 0;
for i, data in enumerate(test_data_loader):
    inputs, labels = data;
    outputs = mynetwork.forward(inputs);
    outputs = corr(outputs);
    labels = labels.detach().numpy();
    for i in range(len(outputs)):
        total = total + 1;
        if outputs[i] == labels[i]:
            correct = correct + 1;
print("Hence, the test set accuracy is ", (correct/total) * 100);

Hence, the test set accuracy is  58.077709611451944
