# Training an image classifier

In [2]:
import torch
import torchvision
import torchvision.transforms as transforms

In [3]:
# Preprocessing applied to every image: convert to a tensor, then normalize
# each of the three channels with mean 0.5 and std 0.5 (this maps ToTensor's
# [0, 1] range onto [-1, 1]).
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])

# Training split: stored under ./data (downloaded if needed) and served in
# shuffled mini-batches of 4 by two worker processes.
trainset = torchvision.datasets.CIFAR10(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)
trainloader = torch.utils.data.DataLoader(
    dataset=trainset,
    batch_size=4,
    shuffle=True,
    num_workers=2,
)

# Test split: same preprocessing and batch size, but kept in dataset order.
testset = torchvision.datasets.CIFAR10(
    root='./data',
    train=False,
    download=True,
    transform=transform,
)
testloader = torch.utils.data.DataLoader(
    dataset=testset,
    batch_size=4,
    shuffle=False,
    num_workers=2,
)

# Human-readable class names, looked up by integer label.
classes = (
    'plane', 'car', 'bird', 'cat', 'deer',
    'dog', 'frog', 'horse', 'ship', 'truck',
)





Files already downloaded and verified
Files already downloaded and verified


In [4]:
import matplotlib.pyplot as plt
import numpy as np

def imshow(img):
    """Display an image tensor with matplotlib.

    The tensor is rescaled with ``img / 2 + 0.5`` (the inverse of a
    mean-0.5 / std-0.5 normalization), converted to a NumPy array, and its
    axes are reordered from (0, 1, 2) to (1, 2, 0) before being plotted.
    """
    unnormalized = img / 2 + 0.5
    channels_last = np.transpose(unnormalized.numpy(), (1, 2, 0))
    plt.imshow(channels_last)
    plt.show()

# Pull one mini-batch from the training loader and inspect it.
dataiter = iter(trainloader)
# Use the built-in next(): the iterator's Python-2-style .next() alias is not
# available in newer PyTorch releases, while next() works on every version.
images, labels = next(dataiter)
print(images)
print(labels)
# Show the whole batch as a single image grid, followed by its class names.
imshow(torchvision.utils.make_grid(images))
# Iterate over the labels themselves so this line follows the loader's batch
# size instead of repeating it as a hard-coded count.
print(' '.join('%5s' % classes[label] for label in labels))

tensor([[[[ 0.1529,  0.4667,  0.5843,  ...,  0.1294,  0.0745,  0.1922],
          [ 0.0275,  0.1451,  0.2314,  ...,  0.1608,  0.0667,  0.2392],
          [ 0.0431,  0.0902,  0.1843,  ...,  0.2549,  0.0588,  0.2706],
          ...,
          [ 0.4980,  0.1373,  0.1765,  ..., -0.0118, -0.0745, -0.1294],
          [ 0.6549,  0.3020,  0.1922,  ..., -0.0588, -0.1059, -0.1137],
          [ 0.6157,  0.3725,  0.1137,  ..., -0.0510, -0.1451, -0.1451]],

         [[ 0.1294,  0.4980,  0.5843,  ...,  0.1059,  0.0275,  0.1451],
          [-0.0039,  0.1843,  0.2314,  ...,  0.1373,  0.0196,  0.1922],
          [ 0.0118,  0.1216,  0.1843,  ...,  0.2314,  0.0118,  0.2157],
          ...,
          [ 0.4980,  0.0588,  0.0980,  ..., -0.1608, -0.2000, -0.2549],
          [ 0.6627,  0.2235,  0.1137,  ..., -0.2078, -0.2314, -0.2471],
          [ 0.6157,  0.2941,  0.0431,  ..., -0.2000, -0.2706, -0.2706]],

         [[ 0.0902,  0.4431,  0.5451,  ..., -0.0039, -0.0902,  0.0431],
          [-0.0353,  0.1294,  

<Figure size 640x480 with 1 Axes>

 deer   car  ship plane


In [5]:
import torch.nn as nn
import torch.nn.functional as F


class Net(nn.Module):
    """Small convolutional classifier: two conv/pool stages, then three linear layers.

    ``fc1`` takes ``16 * 5 * 5`` features, which is what the two 5x5
    convolutions and 2x2 poolings produce from a batch of 3x32x32 images
    (32 -> 28 -> 14 -> 10 -> 5 per spatial side). The network returns 10 raw,
    un-normalized scores per sample.
    """

    def __init__(self):
        super().__init__()
        # Signature reminder:
        # nn.Conv2d(in_channels, out_channels, kernel_size, stride=1, padding=0,
        #           dilation=1, groups=1, bias=True, padding_mode='zeros')
        self.conv1 = nn.Conv2d(3, 6, 5)
        # One parameter-free pooling layer, reused after both convolutions.
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        """Compute class scores for a batch of images.

        Args:
            x: Tensor of shape ``(N, 3, 32, 32)``; the batch dimension is
                required.

        Returns:
            Tensor of shape ``(N, 10)`` with one raw score per class.

        Raises:
            RuntimeError: If the spatial size of ``x`` does not yield
                ``16 * 5 * 5`` features per sample (raised by ``fc1``).
        """
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        # Flatten everything except the batch dimension. Unlike
        # ``view(-1, 16 * 5 * 5)``, this never folds surplus features into
        # extra batch rows when the input has an unexpected spatial size.
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)
        
# Build the network and print its layer-by-layer summary.
net = Net()
print(net)

Net(
  (conv1): Conv2d(3, 6, kernel_size=(5, 5), stride=(1, 1))
  (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
  (fc1): Linear(in_features=400, out_features=120, bias=True)
  (fc2): Linear(in_features=120, out_features=84, bias=True)
  (fc3): Linear(in_features=84, out_features=10, bias=True)
)


In [6]:
import torch.optim as optim
# Loss function: cross-entropy computed from the network's raw class scores.
criterion = nn.CrossEntropyLoss()

# Optimizer: SGD with momentum over all parameters of `net`.
optimizer = optim.SGD(
    params=net.parameters(),
    lr=0.001,
    momentum=0.9,
)

In [None]:
# Train for two passes over the training set, printing the mean loss of every
# LOG_EVERY mini-batches.
LOG_EVERY = 2000

for epoch in range(2):  # loop over the dataset multiple times
    print("epoch :", epoch + 1)
    running_loss = 0.0
    for i, data in enumerate(trainloader):
        inputs, labels = data

        # Gradients accumulate across backward() calls, so zero them first.
        optimizer.zero_grad()

        # Forward pass, loss, backward pass, parameter update.
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # Accumulate the loss and report its mean once per LOG_EVERY batches,
        # then start a fresh window.
        running_loss += loss.item()
        if i % LOG_EVERY == LOG_EVERY - 1:
            print('[%d, %5d] loss: %.3f' %
                  (epoch + 1, i + 1, running_loss / LOG_EVERY))
            running_loss = 0.0

print('Finished Training')
    

oepoch : 1
2.304302215576172
2.3013668060302734
2.3014652729034424
2.26572847366333
2.328603506088257
2.312377691268921
2.2595434188842773
2.274139404296875
2.2823903560638428
2.2629313468933105
2.3072242736816406
2.293766975402832
2.23091983795166
2.2976596355438232
2.3276047706604004
2.2932534217834473
2.2672972679138184
2.3319895267486572
2.2888858318328857
2.259219169616699
2.322108268737793
2.3462047576904297
2.2869250774383545
2.310164451599121
2.3114025592803955
2.264843463897705
2.3111791610717773
2.335618734359741
2.278434991836548
2.3215675354003906
2.344456672668457
2.27428936958313
2.303309440612793
2.2824289798736572
2.3138864040374756
2.333552598953247
2.305950164794922
2.31304931640625
2.2819972038269043
2.3110079765319824
2.2816085815429688
2.285168170928955
2.280825614929199
2.2987618446350098
2.349733829498291
2.2897303104400635
2.28242826461792
2.3170323371887207
2.2744343280792236
2.2863502502441406
2.3270630836486816
2.354931116104126
2.258035898208618
2.3267140388

2.2540175914764404
2.25264048576355
2.282881498336792
2.2775702476501465
2.301987648010254
2.3180274963378906
2.255504608154297
2.2859482765197754
2.327174186706543
2.269033432006836
2.332383632659912
2.2946267127990723
2.3230934143066406
2.291455030441284
2.326418876647949
2.3086233139038086
2.31907320022583
2.3164749145507812
2.3578057289123535
2.292768955230713
2.3495593070983887
2.2821011543273926
2.2970025539398193
2.264727830886841
2.3052051067352295
2.2308309078216553
2.329159736633301
2.28554105758667
2.295492172241211
2.2979907989501953
2.3106932640075684
2.2647557258605957
2.24389386177063
2.308950424194336
2.289458751678467
2.320801019668579
2.358611822128296
2.317099094390869
2.3027536869049072
2.320436954498291
2.314656972885132
2.318016767501831
2.3259449005126953
2.2786173820495605
2.289088249206543
2.3700320720672607
2.2559165954589844
2.3050832748413086
2.3316197395324707
2.2955784797668457
2.281986713409424
2.309156656265259
2.2301270961761475
2.2678580284118652
2.314

2.2515769004821777
2.0932812690734863
2.3036136627197266
2.3528037071228027
2.353537082672119
2.4218430519104004
2.358306646347046
2.296870231628418
2.1248714923858643
2.2560293674468994
2.2282021045684814
2.295588493347168
2.356796979904175
2.219930410385132
2.300372362136841
2.3309006690979004
2.296473264694214
2.375267744064331
2.342180013656616
2.2402164936065674
2.12402606010437
2.2529985904693604
2.2727112770080566
2.3405303955078125
2.2931249141693115
2.3190104961395264
2.312939167022705
2.3286821842193604
2.3136465549468994
2.355393648147583
2.2770445346832275
2.2555623054504395
1.9755258560180664
2.3489198684692383
2.2356247901916504
2.1612548828125
2.24257230758667
2.2162492275238037
2.2977936267852783
2.299398899078369
2.3056397438049316
2.2997653484344482
2.364225149154663
2.347296953201294
2.226200580596924
2.37087345123291
2.2542457580566406
2.2913174629211426
2.264554262161255
2.1706418991088867
2.375861167907715
2.2377374172210693
2.2986724376678467
2.197077751159668
2.

2.2160959243774414
2.2386820316314697
2.096029281616211
2.0237815380096436
2.1943774223327637
2.041764736175537
1.415391445159912
1.973719835281372
2.082632303237915
1.4458204507827759
1.9861176013946533
2.5553979873657227
2.3667330741882324
2.212695837020874
2.317192554473877
2.2784719467163086
1.8304474353790283
1.8450120687484741
2.523158073425293
2.014669179916382
1.8310611248016357
2.0985665321350098
2.10356068611145
2.319204330444336
2.1912546157836914
1.9701621532440186
2.1438441276550293
2.1097538471221924
2.3778557777404785
2.1338820457458496
2.1231207847595215
1.8889384269714355
1.8527133464813232
2.1388676166534424
1.8640332221984863
2.226893901824951
1.8538320064544678
2.271495819091797
2.201324939727783
2.124542236328125
2.1013662815093994
2.1310811042785645
2.0257649421691895
1.9618966579437256
2.1914732456207275
2.0553228855133057
2.3545329570770264
2.444798707962036
2.053658962249756
1.9956028461456299
2.3499953746795654
2.0946364402770996
1.9751604795455933
2.581478595

1.9412932395935059
2.0552444458007812
2.1294569969177246
1.8496125936508179
2.130091428756714
1.6540923118591309
2.4464473724365234
2.6690685749053955
2.538860559463501
2.1313791275024414
2.209733486175537
1.9894444942474365
2.0418243408203125
2.0491111278533936
2.5391793251037598
2.1262621879577637
2.0187926292419434
1.7852063179016113
1.74172043800354
1.629343032836914
2.5489888191223145
1.9364657402038574
1.8724515438079834
2.3447155952453613
1.765293836593628
1.9912095069885254
2.761138677597046
1.9931235313415527
1.9537911415100098
2.097593069076538
1.830399751663208
2.1630210876464844
2.7212343215942383
2.078726291656494
2.1513805389404297
1.9913181066513062
2.267728328704834
2.168948173522949
2.1709272861480713
2.0872130393981934
2.2514047622680664
1.6208394765853882
2.0912837982177734
2.1341042518615723
1.967591643333435
2.180278778076172
2.066861629486084
1.9292610883712769
2.267798900604248
2.467257261276245
2.079312562942505
2.3720927238464355
2.6602730751037598
2.0923089981

1.5040826797485352
2.118340015411377
1.857600212097168
2.603004217147827
2.389463424682617
1.5991367101669312
2.5354576110839844
2.75386905670166
1.745308756828308
2.0710904598236084
2.2763373851776123
1.8446450233459473
1.7489330768585205
2.2838077545166016
2.1411020755767822
2.180108070373535
2.3891046047210693
1.6717922687530518
2.2271347045898438
1.3914525508880615
1.868302583694458
1.6993056535720825
1.732483148574829
2.1016883850097656
2.0371508598327637
2.059046983718872
1.525623083114624
2.9514102935791016
2.098634958267212
1.3441298007965088
1.8614990711212158
2.771939277648926
2.038499116897583
2.093860626220703
2.193694591522217
2.276064872741699
2.1290347576141357
1.9107543230056763
1.8321462869644165
1.1649872064590454
2.0385794639587402
1.8526852130889893
2.4070043563842773
2.1994130611419678
1.580725908279419
1.702668309211731
2.5299201011657715
1.5784788131713867
2.1483192443847656
1.915101170539856
2.398566484451294
1.9244287014007568
2.389679193496704
1.68173623085021

2.487229585647583
1.960618019104004
2.725994825363159
2.1041438579559326
1.9123482704162598
2.1976206302642822
1.6912955045700073
2.01678466796875
2.693906307220459
2.395582437515259
2.594059944152832
2.014827013015747
2.1054623126983643
1.917040467262268
2.0741939544677734
2.5128276348114014
1.8550076484680176
1.9426554441452026
2.0697073936462402
2.4069619178771973
2.350100517272949
1.9220638275146484
2.3903462886810303
2.0988969802856445
1.8536596298217773
1.7540221214294434
1.7346229553222656
1.6784900426864624
2.004012107849121
2.331522226333618
1.506066918373108
1.9517905712127686
1.8968833684921265
2.156597852706909
2.137336254119873
1.7959775924682617
1.887012004852295
1.663769245147705
2.0438857078552246
1.667121171951294
2.251525402069092
1.6715447902679443
1.9465281963348389
1.8856940269470215
1.406828761100769
1.982750654220581
1.9134471416473389
1.6832441091537476
1.789123773574829
1.2211109399795532
2.078857421875
2.0149295330047607
1.8778963088989258
1.8060449361801147
2

1.4850444793701172
1.5203099250793457
1.557441234588623
1.5955474376678467
1.804128646850586
1.8494848012924194
1.5080430507659912
1.7228690385818481
1.484904170036316
0.984092116355896
1.546126127243042
2.4383792877197266
2.0434272289276123
1.444270372390747
1.7181479930877686
2.455122232437134
1.5176783800125122
1.4957103729248047
1.3625355958938599
1.9515197277069092
1.3605576753616333
1.8322161436080933
1.963118076324463
1.5507328510284424
1.635568380355835
1.77876877784729
2.3269741535186768
1.444800853729248
1.095674991607666
2.6750409603118896
1.7838040590286255
1.7086572647094727
1.3550596237182617
1.3468290567398071
1.9404616355895996
0.8616194128990173
2.8870303630828857
2.2418816089630127
1.8188321590423584
1.4042736291885376
1.275775671005249
1.2977290153503418
2.2792134284973145
1.6414921283721924
1.9870476722717285
1.7986392974853516
1.5326268672943115
1.2960147857666016
2.4020884037017822
1.6522860527038574
1.2025728225708008
1.2723078727722168
1.2738407850265503
1.74657

1.533280372619629
1.7571786642074585
2.0042920112609863
2.2158493995666504
1.8483352661132812
1.3673676252365112
1.596632719039917
2.5794124603271484
1.9340229034423828
1.6360249519348145
1.7530453205108643
3.2102370262145996
1.4640202522277832
1.0475733280181885
1.543553352355957
1.9405765533447266
1.7255967855453491
1.817265510559082
1.4398643970489502
1.712399959564209
2.52303147315979
1.521236777305603
1.8033485412597656
1.7338757514953613
1.8460891246795654
2.056245803833008
1.765760898590088
2.0746922492980957
1.8855774402618408
2.2715752124786377
1.3543699979782104
2.915693759918213
1.9869928359985352
1.4372494220733643
0.9888417720794678
2.4781367778778076
1.1749355792999268
1.9668517112731934
1.9827487468719482
2.4777348041534424
1.8992905616760254
1.61642324924469
1.5756027698516846
1.8989254236221313
1.1907058954238892
1.9745674133300781
1.8216052055358887
1.5892504453659058
1.971139669418335
1.3682215213775635
2.173758029937744
1.7135899066925049
1.9088404178619385
1.263183

2.08040189743042
1.5055617094039917
1.9689040184020996
1.9830915927886963
1.6832408905029297
1.4404690265655518
2.236403465270996
1.687516450881958
1.8031634092330933
2.133333683013916
2.3131892681121826
1.8499099016189575
1.4360778331756592
2.485722064971924
2.4990320205688477
1.444752812385559
1.7028913497924805
1.6308962106704712
1.6127212047576904
1.8596343994140625
1.6346169710159302
1.9024909734725952
1.6290593147277832
2.4702069759368896
1.728471279144287
1.9440526962280273
1.523266315460205
1.7694878578186035
2.253364324569702
1.877744436264038
2.888265609741211
1.1792162656784058
2.7963836193084717
1.5238096714019775
1.8401739597320557
2.4850566387176514
1.8085112571716309
1.008802056312561
2.0528998374938965
1.4954798221588135
1.1325948238372803
1.224321961402893
1.9335265159606934
2.416344165802002
2.2238359451293945
2.0950119495391846
1.506388783454895
1.5249860286712646
1.7351210117340088
2.3799173831939697
1.9368237257003784
1.3459320068359375
2.480085849761963
1.83054184