<a href="https://colab.research.google.com/github/yussif-issah/css54FinalProject/blob/main/mangrove_training.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import os
from PIL import Image
import torch
from torch.utils.data import Dataset,DataLoader
from torchvision import transforms,models,datasets
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import random_split

In [None]:
# Use the first CUDA GPU when one is visible to PyTorch, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

In [None]:
NUM_CLASSES = 3
TRAINING_DATA_PATH = "train_files/train_files"
LABELS_FILE_PATH = "labels.txt"
TEST_DATA_PATH = "/content/drive/MyDrive/testDataFolder"

In [None]:
def buildModel(nn, fc_layer1, drop_out):
    """Return a pretrained ResNet-50 whose final layer is a new classifier head.

    The original ``fc`` layer is replaced by
    Linear(in_features -> fc_layer1) -> ReLU -> Dropout(drop_out)
    -> Linear(fc_layer1 -> NUM_CLASSES), and the model is moved to ``device``.

    Args:
        nn: The ``torch.nn`` module (callers pass ``torch.nn``); used to
            build the head layers.
        fc_layer1: Width of the hidden Linear layer in the head.
        drop_out: Dropout probability applied after the hidden layer.

    Returns:
        The modified ResNet-50 on ``device``.
    """
    backbone = models.resnet50(pretrained=True)
    feature_dim = backbone.fc.in_features
    backbone.fc = nn.Sequential(
        nn.Linear(feature_dim, fc_layer1),
        nn.ReLU(),
        nn.Dropout(drop_out),
        nn.Linear(fc_layer1, NUM_CLASSES),
    )
    return backbone.to(device)

In [None]:
def buildDataLoader(batchSize, image_folder, label_file, test_batch_size=24):
    """Build train/validation/test DataLoaders from one labelled image folder.

    The data is split 90/10 into (train+val)/test, then the 90% part is split
    80/20 into train/val. Fixed seeds (42, then 32) make the split
    reproducible. Only the training subset receives random augmentation.
    Validation and test images are only resized and converted to tensors, so
    their metrics are deterministic.

    Args:
        batchSize: Batch size for the training and validation loaders.
        image_folder: Image directory passed to ``CustomDataset``.
        label_file: Label file passed to ``CustomDataset``.
        test_batch_size: Batch size for the test loader (default 24, the
            value previously hard-coded).

    Returns:
        ``(data_loaders, image_datasets)``: dicts keyed by ``"train"``,
        ``"val"`` and ``"test"`` holding the DataLoaders and the underlying
        ``Subset`` objects respectively.

    Raises:
        ValueError: if the two ``CustomDataset`` views report different
            lengths for the same inputs.
    """
    from torch.utils.data import Subset

    train_transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(degrees=30),
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
        transforms.ToTensor(),
    ])
    # Evaluation pipeline: same resize and tensor conversion as training but
    # without the random augmentations, so val/test results are repeatable.
    eval_transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
    ])

    # Two views of the same files that differ only in their transform.
    # Assumes CustomDataset orders samples identically on every construction
    # - TODO confirm.
    augmented_data = CustomDataset(image_folder, label_file, train_transform)
    eval_data = CustomDataset(image_folder, label_file, eval_transform)
    if len(eval_data) != len(augmented_data):
        raise ValueError("CustomDataset returned different lengths for the same inputs")

    training_validation = int(0.9 * len(augmented_data))
    test_size = len(augmented_data) - training_validation
    training_size = int(0.8 * training_validation)
    val_size = training_validation - training_size

    # Same lengths, seeds and call order as before, so the index split is
    # unchanged for a given dataset.
    generator1 = torch.Generator().manual_seed(42)
    generator2 = torch.Generator().manual_seed(32)
    training_validation_dataset, test_split = random_split(
        augmented_data, [training_validation, test_size], generator=generator1)
    train_data, val_split = random_split(
        training_validation_dataset, [training_size, val_size], generator=generator2)

    # Re-point val/test at the non-augmented view. val_split indexes into
    # training_validation_dataset, so map its indices back to the full dataset.
    outer_indices = training_validation_dataset.indices
    val_data = Subset(eval_data, [outer_indices[i] for i in val_split.indices])
    test_data = Subset(eval_data, list(test_split.indices))

    train_data_loader = DataLoader(train_data, batch_size=batchSize, shuffle=True)
    # Evaluation does not need shuffling; a fixed order keeps runs comparable.
    val_data_loader = DataLoader(val_data, batch_size=batchSize, shuffle=False)
    test_data_loader = DataLoader(test_data, batch_size=test_batch_size, shuffle=False)

    image_datasets = {"train": train_data, "val": val_data, "test": test_data}
    data_loaders = {"train": train_data_loader, "val": val_data_loader, "test": test_data_loader}

    return data_loaders, image_datasets

In [None]:
def buildOptimizer(model, learning_rate):
    """Create an SGD optimizer with momentum 0.9 over all of ``model``'s parameters.

    Args:
        model: The ``torch.nn.Module`` whose parameters will be optimized.
        learning_rate: Step size passed to SGD as ``lr``.

    Returns:
        A ``torch.optim.SGD`` instance.
    """
    params = model.parameters()
    return optim.SGD(params, momentum=0.9, lr=learning_rate)


In [None]:
# Loss shared by every training and validation step. The model head ends in a
# Linear layer, so it receives raw logits. reduction="mean" (the default)
# yields the per-batch average; labels are presumably class indices.
criterion = nn.CrossEntropyLoss(reduction="mean")

In [None]:
def _run_phase(model, loader, optimizer, train):
    """Run one full pass of ``model`` over ``loader``; return (mean_loss, accuracy).

    When ``train`` is true the model is put in training mode and updated with
    ``optimizer`` after every batch. Otherwise it is put in eval mode and run
    with gradients disabled. Returns ``(nan, nan)`` if the loader yields no
    samples.
    """
    if train:
        model.train()
    else:
        model.eval()

    running_loss = 0.0
    correct = 0
    seen = 0
    for inputs, labels in loader:
        inputs, labels = inputs.to(device), labels.to(device)
        if train:
            optimizer.zero_grad()
        with torch.set_grad_enabled(train):
            outputs = model(inputs)
            _, preds = torch.max(outputs, 1)
            loss = criterion(outputs, labels)
            if train:
                loss.backward()
                optimizer.step()
        # criterion averages over the batch; re-weight by batch size so the
        # epoch figure is a true per-sample mean.
        batch_size = inputs.size(0)
        running_loss += loss.item() * batch_size
        correct += (preds == labels).sum().item()
        seen += batch_size

    if seen == 0:
        return float("nan"), float("nan")
    return running_loss / seen, correct / seen


def trainModel(epochs, batch_size, dropout, fc_layer, learning_rate, train_data_path, label_path,
               save_path="resnet.pth"):
    """Train the ResNet-50 classifier and save its weights.

    Each epoch runs a training pass and then a validation pass. After each
    pass, one line with that pass's mean loss and accuracy is printed.

    Args:
        epochs: Number of epochs to train.
        batch_size: Batch size for the train/val loaders.
        dropout: Dropout probability in the classifier head.
        fc_layer: Hidden width of the classifier head.
        learning_rate: SGD learning rate.
        train_data_path: Image folder passed to ``buildDataLoader``.
        label_path: Label file passed to ``buildDataLoader``.
        save_path: Where the final ``state_dict`` is written
            (default ``"resnet.pth"``, the previous hard-coded path).

    Returns:
        ``(data_loaders, model)``: the loader dict from ``buildDataLoader``
        and the trained model.
    """
    data_loaders, _ = buildDataLoader(batch_size, train_data_path, label_path)
    model = buildModel(torch.nn, fc_layer, dropout)
    optimizer = buildOptimizer(model, learning_rate)

    for epoch in range(epochs):
        for phase, label in (("train", "training"), ("val", "validation")):
            epoch_loss, epoch_acc = _run_phase(
                model, data_loaders[phase], optimizer, train=(phase == "train"))
            # Reported once per phase, after the whole pass; previously this
            # printed partial cumulative figures after every batch.
            print(f"epoch: {epoch+1} {label} loss {epoch_loss}, {label} accuracy {epoch_acc}")

    torch.save(model.state_dict(), save_path)
    return data_loaders, model

In [None]:
# Hyper-parameters for this run. The values 20,48,0.5,256,0.001640 were noted
# alongside it, presumably an alternative configuration that was tried.
data_loaders, model = trainModel(
    epochs=20,
    batch_size=32,
    dropout=0.4,
    fc_layer=256,
    learning_rate=0.001640,
    train_data_path=TRAINING_DATA_PATH,
    label_path=LABELS_FILE_PATH,
)

epoch: 1 training loss 0.028487378161100415, training accuracy 0.0
epoch: 1 training loss 0.05689232704198756, training accuracy 0.0
epoch: 1 training loss 0.0850008417644772, training accuracy 0.0
epoch: 1 training loss 0.11317041767716973, training accuracy 0.0
epoch: 1 training loss 0.14129135638051687, training accuracy 0.0
epoch: 1 training loss 0.1692509858190166, training accuracy 0.0
epoch: 1 training loss 0.19726058309112116, training accuracy 0.0001895734597156398
epoch: 1 training loss 0.22506057811574348, training accuracy 0.0001895734597156398
epoch: 1 training loss 0.25269641749666766, training accuracy 0.0001895734597156398
epoch: 1 training loss 0.28046263202106786, training accuracy 0.0001895734597156398
epoch: 1 training loss 0.3079006393938833, training accuracy 0.0009478672985781991
epoch: 1 training loss 0.3355031784563833, training accuracy 0.0013270142180094786
epoch: 1 training loss 0.36254647928391587, training accuracy 0.0022748815165876774
epoch: 1 training l

epoch: 1 training loss 2.4189534755905657, training accuracy 0.13232227488151657
epoch: 1 training loss 2.435259487640236, training accuracy 0.13516587677725117
epoch: 1 training loss 2.4556573327231748, training accuracy 0.13687203791469194
epoch: 1 training loss 2.4739997147835826, training accuracy 0.13895734597156398
epoch: 1 training loss 2.4946026828277734, training accuracy 0.1404739336492891
epoch: 1 training loss 2.5144612027570536, training accuracy 0.14255924170616113
epoch: 1 training loss 2.5348122330073495, training accuracy 0.14445497630331752
epoch: 1 training loss 2.554265366685334, training accuracy 0.1461611374407583
epoch: 1 training loss 2.574148627547856, training accuracy 0.1476777251184834
epoch: 1 training loss 2.590279810033138, training accuracy 0.150521327014218
epoch: 1 training loss 2.611499215799485, training accuracy 0.15184834123222749
epoch: 1 training loss 2.629372766865373, training accuracy 0.15393364928909953
epoch: 1 training loss 2.65054434102858

epoch: 2 training loss 0.02094812185278436, training accuracy 0.0015165876777251184
epoch: 2 training loss 0.03863276477108635, training accuracy 0.0037914691943127963
epoch: 2 training loss 0.05628854760626481, training accuracy 0.005876777251184834
epoch: 2 training loss 0.07342095108393809, training accuracy 0.008151658767772511
epoch: 2 training loss 0.09027934720730894, training accuracy 0.010995260663507108
epoch: 2 training loss 0.11192343616937574, training accuracy 0.012511848341232227
epoch: 2 training loss 0.12761829972832123, training accuracy 0.015355450236966824
epoch: 2 training loss 0.14507243947395215, training accuracy 0.01800947867298578
epoch: 2 training loss 0.16001820930372482, training accuracy 0.021232227488151657
epoch: 2 training loss 0.1766943634177836, training accuracy 0.023317535545023697
epoch: 2 training loss 0.1915268754281139, training accuracy 0.026350710900473934
epoch: 2 training loss 0.20601618183732598, training accuracy 0.02938388625592417
epoch:

epoch: 2 training loss 1.7244785225673875, training accuracy 0.2621800947867299
epoch: 2 training loss 1.740419231975248, training accuracy 0.2652132701421801
epoch: 2 training loss 1.7558932466190573, training accuracy 0.267867298578199
epoch: 2 training loss 1.7717255925906212, training accuracy 0.270521327014218
epoch: 2 training loss 1.7869745335872704, training accuracy 0.2735545023696682
epoch: 2 training loss 1.7986063677891735, training accuracy 0.27734597156398105
epoch: 2 training loss 1.8159468751157064, training accuracy 0.27943127962085307
epoch: 2 training loss 1.8299901953240707, training accuracy 0.2824644549763033
epoch: 2 training loss 1.84500167485097, training accuracy 0.28511848341232227
epoch: 2 training loss 1.8618119313027621, training accuracy 0.2870142180094787
epoch: 2 training loss 1.8772243481676725, training accuracy 0.2900473933649289
epoch: 2 training loss 1.895172032360782, training accuracy 0.29251184834123223
epoch: 2 training loss 1.913623608051318, 

epoch: 2 validation loss 1.900886165454046, validation accuracy 0.5837755875663382
epoch: 2 validation loss 1.955879500637098, validation accuracy 0.5966641394996209
epoch: 2 validation loss 1.9649655968063795, validation accuracy 0.6004548900682335
epoch: 3 training loss 0.013322130537711049, training accuracy 0.003033175355450237
epoch: 3 training loss 0.02700426291515477, training accuracy 0.0064454976303317535
epoch: 3 training loss 0.041846439307334865, training accuracy 0.00928909952606635
epoch: 3 training loss 0.05624930286859449, training accuracy 0.012511848341232227
epoch: 3 training loss 0.07189531516124852, training accuracy 0.015355450236966824
epoch: 3 training loss 0.0872756654278362, training accuracy 0.01781990521327014
epoch: 3 training loss 0.09664569438916247, training accuracy 0.022559241706161137
epoch: 3 training loss 0.11148985803974748, training accuracy 0.025402843601895733
epoch: 3 training loss 0.1253811009122297, training accuracy 0.02862559241706161
epoch

epoch: 3 training loss 1.1331872677916035, training accuracy 0.3764928909952607
epoch: 3 training loss 1.1438172207511432, training accuracy 0.38028436018957346
epoch: 3 training loss 1.1560764587546977, training accuracy 0.38350710900473933
epoch: 3 training loss 1.1627330592006304, training accuracy 0.388436018957346
epoch: 3 training loss 1.1714235441379637, training accuracy 0.3929857819905213
epoch: 3 training loss 1.1811482097752286, training accuracy 0.3967772511848341
epoch: 3 training loss 1.1886730309798255, training accuracy 0.40208530805687204
epoch: 3 training loss 1.198526364367155, training accuracy 0.4058767772511848
epoch: 3 training loss 1.209002912259215, training accuracy 0.40928909952606635
epoch: 3 training loss 1.216892395381114, training accuracy 0.41345971563981043
epoch: 3 training loss 1.22474162928866, training accuracy 0.4181990521327014
epoch: 3 training loss 1.2331386645710298, training accuracy 0.42218009478672985
epoch: 3 training loss 1.242740238063143

epoch: 3 validation loss 0.9865730566179508, validation accuracy 0.714177407126611
epoch: 3 validation loss 1.0157389449205607, validation accuracy 0.7323730098559514
epoch: 3 validation loss 1.0418955229556046, validation accuracy 0.7505686125852918
epoch: 3 validation loss 1.0682893930916717, validation accuracy 0.7695223654283548
epoch: 3 validation loss 1.07522426372771, validation accuracy 0.7740712661106899
epoch: 4 training loss 0.010029108852006812, training accuracy 0.004170616113744076
epoch: 4 training loss 0.017192800078911805, training accuracy 0.00909952606635071
epoch: 4 training loss 0.025158220625601673, training accuracy 0.013459715639810426
epoch: 4 training loss 0.0319583520392106, training accuracy 0.01819905213270142
epoch: 4 training loss 0.03985523965121446, training accuracy 0.021990521327014217
epoch: 4 training loss 0.04762029132571831, training accuracy 0.026350710900473934
epoch: 4 training loss 0.06032967228460086, training accuracy 0.030142180094786728
ep

epoch: 4 training loss 0.6779782798730931, training accuracy 0.4390521327014218
epoch: 4 training loss 0.6846092954518106, training accuracy 0.44341232227488153
epoch: 4 training loss 0.6928088161956643, training accuracy 0.4472037914691943
epoch: 4 training loss 0.6966437733003878, training accuracy 0.45213270142180095
epoch: 4 training loss 0.7026532372931169, training accuracy 0.4566824644549763
epoch: 4 training loss 0.7086089986647475, training accuracy 0.461042654028436
epoch: 4 training loss 0.715634247115438, training accuracy 0.46559241706161136
epoch: 4 training loss 0.7255791592710956, training accuracy 0.46938388625592414
epoch: 4 training loss 0.7304051479683104, training accuracy 0.4743127962085308
epoch: 4 training loss 0.735571502034698, training accuracy 0.47962085308056873
epoch: 4 training loss 0.7400029960740799, training accuracy 0.4843601895734597
epoch: 4 training loss 0.7436757088611476, training accuracy 0.48985781990521327
epoch: 4 training loss 0.746595307300

epoch: 4 validation loss 0.5548111530214299, validation accuracy 0.6868840030326004
epoch: 4 validation loss 0.5693293890327644, validation accuracy 0.7081122062168309
epoch: 4 validation loss 0.5872847525977655, validation accuracy 0.7285822592873389
epoch: 4 validation loss 0.5986245605781099, validation accuracy 0.7505686125852918
epoch: 4 validation loss 0.6145583052270064, validation accuracy 0.7710386656557998
epoch: 4 validation loss 0.6277013110608383, validation accuracy 0.7915087187263078
epoch: 4 validation loss 0.6382321407616003, validation accuracy 0.8127369219105383
epoch: 4 validation loss 0.6526051507563732, validation accuracy 0.8339651250947687
epoch: 4 validation loss 0.6558101670682385, validation accuracy 0.8392721758908264
epoch: 5 training loss 0.004333188025307316, training accuracy 0.0051184834123222745
epoch: 5 training loss 0.010105095090459308, training accuracy 0.01004739336492891
epoch: 5 training loss 0.013949968600160137, training accuracy 0.01516587677

epoch: 5 training loss 0.4118827841519179, training accuracy 0.4699526066350711
epoch: 5 training loss 0.4160827427000796, training accuracy 0.47507109004739334
epoch: 5 training loss 0.4193482508365577, training accuracy 0.48075829383886254
epoch: 5 training loss 0.4231594429196905, training accuracy 0.4856872037914692
epoch: 5 training loss 0.42515560692520504, training accuracy 0.4911848341232227
epoch: 5 training loss 0.42919313783329244, training accuracy 0.49668246445497627
epoch: 5 training loss 0.4317719554449145, training accuracy 0.5019905213270142
epoch: 5 training loss 0.4351923439062037, training accuracy 0.5069194312796208
epoch: 5 training loss 0.4396030017436963, training accuracy 0.5120379146919432
epoch: 5 training loss 0.4426788749514033, training accuracy 0.517345971563981
epoch: 5 training loss 0.44592822540427834, training accuracy 0.5222748815165876
epoch: 5 training loss 0.45081026050151807, training accuracy 0.52739336492891
epoch: 5 training loss 0.45437283267

epoch: 5 validation loss 0.31306458061802467, validation accuracy 0.6580742987111448
epoch: 5 validation loss 0.32376564371486427, validation accuracy 0.6785443517816527
epoch: 5 validation loss 0.3359522826749326, validation accuracy 0.6997725549658832
epoch: 5 validation loss 0.3446195496131833, validation accuracy 0.7202426080363912
epoch: 5 validation loss 0.35778315452303827, validation accuracy 0.7399545109931767
epoch: 5 validation loss 0.3688284557276011, validation accuracy 0.7634571645185747
epoch: 5 validation loss 0.37327588661951583, validation accuracy 0.7869598180439727
epoch: 5 validation loss 0.38443297016701977, validation accuracy 0.8081880212282032
epoch: 5 validation loss 0.40544160267364987, validation accuracy 0.8286580742987112
epoch: 5 validation loss 0.4184748528128415, validation accuracy 0.8483699772554966
epoch: 5 validation loss 0.4256238037207707, validation accuracy 0.8703563305534495
epoch: 5 validation loss 0.4303882499822077, validation accuracy 0.894

epoch: 6 training loss 0.2756310477414967, training accuracy 0.4767772511848341
epoch: 6 training loss 0.2781457010603629, training accuracy 0.48227488151658765
epoch: 6 training loss 0.28211182734412604, training accuracy 0.4875829383886256
epoch: 6 training loss 0.2870212508829849, training accuracy 0.4928909952606635
epoch: 6 training loss 0.2896106885394779, training accuracy 0.4985781990521327
epoch: 6 training loss 0.2932008943060563, training accuracy 0.503696682464455
epoch: 6 training loss 0.29752785153863553, training accuracy 0.5082464454976303
epoch: 6 training loss 0.2984488612893633, training accuracy 0.5143127962085308
epoch: 6 training loss 0.3028058240989938, training accuracy 0.5196208530805687
epoch: 6 training loss 0.30461065626822376, training accuracy 0.5254976303317536
epoch: 6 training loss 0.30824697286596797, training accuracy 0.5304265402843602
epoch: 6 training loss 0.31106405104506074, training accuracy 0.5359241706161137
epoch: 6 training loss 0.3147757970

epoch: 6 validation loss 0.19366607619018786, validation accuracy 0.5799848369977255
epoch: 6 validation loss 0.2033727150961158, validation accuracy 0.6004548900682335
epoch: 6 validation loss 0.21037313068699348, validation accuracy 0.623199393479909
epoch: 6 validation loss 0.21748455498054048, validation accuracy 0.645185746777862
epoch: 6 validation loss 0.23029325277720372, validation accuracy 0.667172100075815
epoch: 6 validation loss 0.23936001959852055, validation accuracy 0.6884003032600455
epoch: 6 validation loss 0.24453801743635362, validation accuracy 0.7119029567854435
epoch: 6 validation loss 0.2531494470687415, validation accuracy 0.7338893100833965
epoch: 6 validation loss 0.26534290129348487, validation accuracy 0.7558756633813495
epoch: 6 validation loss 0.2693125543673893, validation accuracy 0.78013646702047
epoch: 6 validation loss 0.28095221664075876, validation accuracy 0.800606520090978
epoch: 6 validation loss 0.2888549361109643, validation accuracy 0.8218347

epoch: 7 training loss 0.19089856450591608, training accuracy 0.4724170616113744
epoch: 7 training loss 0.19250782781302647, training accuracy 0.4782938388625592
epoch: 7 training loss 0.1942437764931629, training accuracy 0.48417061611374407
epoch: 7 training loss 0.19634558049423434, training accuracy 0.48947867298578196
epoch: 7 training loss 0.19826066464609443, training accuracy 0.49516587677725116
epoch: 7 training loss 0.20160331861667724, training accuracy 0.5000947867298579
epoch: 7 training loss 0.20498125727142769, training accuracy 0.5050236966824645
epoch: 7 training loss 0.20604510971720186, training accuracy 0.5110900473933649
epoch: 7 training loss 0.2086056370305789, training accuracy 0.5163981042654028
epoch: 7 training loss 0.21060764981672098, training accuracy 0.5218957345971564
epoch: 7 training loss 0.21167831565531509, training accuracy 0.5275829383886256
epoch: 7 training loss 0.21248488358411743, training accuracy 0.533649289099526
epoch: 7 training loss 0.216

epoch: 7 validation loss 0.13701970221148557, validation accuracy 0.5034116755117514
epoch: 7 validation loss 0.14299906520250621, validation accuracy 0.5261561789234268
epoch: 7 validation loss 0.14991559505101132, validation accuracy 0.5481425322213799
epoch: 7 validation loss 0.15358386462705437, validation accuracy 0.5716451857467778
epoch: 7 validation loss 0.15852123141921406, validation accuracy 0.5943896891584534
epoch: 7 validation loss 0.16488768495873488, validation accuracy 0.6163760424564063
epoch: 7 validation loss 0.17316753453969774, validation accuracy 0.6383623957543594
epoch: 7 validation loss 0.1766763011825365, validation accuracy 0.6618650492797574
epoch: 7 validation loss 0.18095986675005052, validation accuracy 0.6838514025777104
epoch: 7 validation loss 0.1937046907813916, validation accuracy 0.7035633055344959
epoch: 7 validation loss 0.20562380204334504, validation accuracy 0.7247915087187263
epoch: 7 validation loss 0.21778735615612432, validation accuracy 0

epoch: 8 training loss 0.12819207618586825, training accuracy 0.4629383886255924
epoch: 8 training loss 0.12856559434773232, training accuracy 0.4690047393364929
epoch: 8 training loss 0.13040536335859254, training accuracy 0.47450236966824644
epoch: 8 training loss 0.13202248602681815, training accuracy 0.48018957345971564
epoch: 8 training loss 0.13404986058366242, training accuracy 0.4854976303317535
epoch: 8 training loss 0.13590672920100497, training accuracy 0.4911848341232227
epoch: 8 training loss 0.13735683475060487, training accuracy 0.4970616113744076
epoch: 8 training loss 0.1400288678345522, training accuracy 0.5023696682464455
epoch: 8 training loss 0.14179296626863885, training accuracy 0.507867298578199
epoch: 8 training loss 0.1433358612557723, training accuracy 0.5137440758293839
epoch: 8 training loss 0.14398254030688679, training accuracy 0.5198104265402843
epoch: 8 training loss 0.14611898157833877, training accuracy 0.5253080568720379
epoch: 8 training loss 0.1464

epoch: 8 validation loss 0.1056313312261551, validation accuracy 0.4094010614101592
epoch: 8 validation loss 0.10803745966051635, validation accuracy 0.43290371493555724
epoch: 8 validation loss 0.11020349046332625, validation accuracy 0.4564063684609553
epoch: 8 validation loss 0.11392999852218802, validation accuracy 0.4791508718726308
epoch: 8 validation loss 0.11864607491395616, validation accuracy 0.5026535253980288
epoch: 8 validation loss 0.12374787749983851, validation accuracy 0.5238817285822592
epoch: 8 validation loss 0.12847395264984895, validation accuracy 0.5458680818802123
epoch: 8 validation loss 0.1344887689716261, validation accuracy 0.5686125852918877
epoch: 8 validation loss 0.14244703798243455, validation accuracy 0.5905989385898408
epoch: 8 validation loss 0.15456719449110515, validation accuracy 0.6110689916603488
epoch: 8 validation loss 0.1614504223795349, validation accuracy 0.6330553449583017
epoch: 8 validation loss 0.1650249719438994, validation accuracy 0.

epoch: 9 training loss 0.09994819650152849, training accuracy 0.44170616113744077
epoch: 9 training loss 0.10040825825731901, training accuracy 0.4477725118483412
epoch: 9 training loss 0.10148853550590045, training accuracy 0.45364928909952607
epoch: 9 training loss 0.10307462588305721, training accuracy 0.45933649289099526
epoch: 9 training loss 0.10462336079204251, training accuracy 0.46521327014218006
epoch: 9 training loss 0.10580375251046854, training accuracy 0.47127962085308056
epoch: 9 training loss 0.10765673302926158, training accuracy 0.47658767772511845
epoch: 9 training loss 0.11018152259537393, training accuracy 0.48208530805687205
epoch: 9 training loss 0.11176477762195171, training accuracy 0.48796208530805685
epoch: 9 training loss 0.11297951969490232, training accuracy 0.4938388625592417
epoch: 9 training loss 0.11416486609038584, training accuracy 0.49971563981042655
epoch: 9 training loss 0.11539439183275846, training accuracy 0.505781990521327
epoch: 9 training lo

epoch: 9 validation loss 0.07818951454914547, validation accuracy 0.3199393479909022
epoch: 9 validation loss 0.07934706793127863, validation accuracy 0.34420015163002277
epoch: 9 validation loss 0.08290346044767437, validation accuracy 0.36770280515542075
epoch: 9 validation loss 0.085671183826166, validation accuracy 0.3912054586808188
epoch: 9 validation loss 0.09560753966570805, validation accuracy 0.4131918119787718
epoch: 9 validation loss 0.09801317607931696, validation accuracy 0.43745261561789234
epoch: 9 validation loss 0.10472148462531239, validation accuracy 0.46019711902956784
epoch: 9 validation loss 0.11271688981522567, validation accuracy 0.48294162244124333
epoch: 9 validation loss 0.11546954618790911, validation accuracy 0.5072024260803639
epoch: 9 validation loss 0.12050552572419555, validation accuracy 0.5299469294920394
epoch: 9 validation loss 0.12380380762445828, validation accuracy 0.5534495830174374
epoch: 9 validation loss 0.1272284821003472, validation accura

epoch: 10 training loss 0.07066107032423335, training accuracy 0.4153554502369668
epoch: 10 training loss 0.0709896518946824, training accuracy 0.4214218009478673
epoch: 10 training loss 0.07153554659884123, training accuracy 0.42748815165876775
epoch: 10 training loss 0.07244149597900174, training accuracy 0.4329857819905213
epoch: 10 training loss 0.07373805797495549, training accuracy 0.43886255924170614
epoch: 10 training loss 0.07506797488831796, training accuracy 0.444739336492891
epoch: 10 training loss 0.07576088553921306, training accuracy 0.45080568720379144
epoch: 10 training loss 0.07638013813732925, training accuracy 0.45687203791469194
epoch: 10 training loss 0.07770754711322875, training accuracy 0.46255924170616114
epoch: 10 training loss 0.07852955932300802, training accuracy 0.46824644549763034
epoch: 10 training loss 0.0794125737850135, training accuracy 0.47412322274881513
epoch: 10 training loss 0.07998148178037309, training accuracy 0.48018957345971564
epoch: 10 t

epoch: 10 validation loss 0.032493580824683524, validation accuracy 0.1379833206974981
epoch: 10 validation loss 0.039742190337163016, validation accuracy 0.1599696739954511
epoch: 10 validation loss 0.04327132251788668, validation accuracy 0.18271417740712662
epoch: 10 validation loss 0.05016617308610131, validation accuracy 0.2047005307050796
epoch: 10 validation loss 0.056375080206974425, validation accuracy 0.2259287338893101
epoch: 10 validation loss 0.057218007918467026, validation accuracy 0.25018953752843065
epoch: 10 validation loss 0.06277159545528246, validation accuracy 0.27293404094010615
epoch: 10 validation loss 0.06703734054449745, validation accuracy 0.29492039423805916
epoch: 10 validation loss 0.06837464192313078, validation accuracy 0.31842304776345715
epoch: 10 validation loss 0.07438320332715871, validation accuracy 0.34040940106141016
epoch: 10 validation loss 0.07754275489703012, validation accuracy 0.3639120545868082
epoch: 10 validation loss 0.0804226337010612

epoch: 11 training loss 0.05237879359891629, training accuracy 0.36947867298578196
epoch: 11 training loss 0.05284364031389426, training accuracy 0.3753554502369668
epoch: 11 training loss 0.053681706433047614, training accuracy 0.38123222748815166
epoch: 11 training loss 0.05414196163557152, training accuracy 0.3872985781990521
epoch: 11 training loss 0.05483615649254966, training accuracy 0.3933649289099526
epoch: 11 training loss 0.05654069665483954, training accuracy 0.3986729857819905
epoch: 11 training loss 0.05755552680571498, training accuracy 0.4043601895734597
epoch: 11 training loss 0.05820159744877386, training accuracy 0.41042654028436015
epoch: 11 training loss 0.05913408799194047, training accuracy 0.416303317535545
epoch: 11 training loss 0.05941572641309404, training accuracy 0.4223696682464455
epoch: 11 training loss 0.060080524598252714, training accuracy 0.4282464454976303
epoch: 11 training loss 0.06065126486864135, training accuracy 0.43412322274881515
epoch: 11 t

epoch: 11 training loss 0.13845134448666144, training accuracy 0.9681516587677725
epoch: 11 validation loss 0.004713735949188764, validation accuracy 0.02350265352539803
epoch: 11 validation loss 0.009241153653414526, validation accuracy 0.04700530705079606
epoch: 11 validation loss 0.011980512812790136, validation accuracy 0.0712661106899166
epoch: 11 validation loss 0.021113786668466562, validation accuracy 0.0932524639878696
epoch: 11 validation loss 0.025868728361498505, validation accuracy 0.1152388172858226
epoch: 11 validation loss 0.028736729437514992, validation accuracy 0.13874147081122062
epoch: 11 validation loss 0.03022392352843845, validation accuracy 0.16224412433661864
epoch: 11 validation loss 0.034332112076609675, validation accuracy 0.18423047763457165
epoch: 11 validation loss 0.037418491737331014, validation accuracy 0.20697498104624715
epoch: 11 validation loss 0.040649498024160347, validation accuracy 0.22896133434420016
epoch: 11 validation loss 0.04362090208749

epoch: 12 training loss 0.041249417637196764, training accuracy 0.3241706161137441
epoch: 12 training loss 0.04185398902938264, training accuracy 0.3300473933649289
epoch: 12 training loss 0.04206524217298246, training accuracy 0.3361137440758294
epoch: 12 training loss 0.042984981864549535, training accuracy 0.3419905213270142
epoch: 12 training loss 0.04310864540072979, training accuracy 0.3480568720379147
epoch: 12 training loss 0.044108016626529786, training accuracy 0.3537440758293839
epoch: 12 training loss 0.04491700697849147, training accuracy 0.35962085308056874
epoch: 12 training loss 0.045484697118189664, training accuracy 0.3656872037914692
epoch: 12 training loss 0.04627254726762455, training accuracy 0.37156398104265403
epoch: 12 training loss 0.04672015914419816, training accuracy 0.3776303317535545
epoch: 12 training loss 0.04748561756305785, training accuracy 0.38350710900473933
epoch: 12 training loss 0.048248831999810386, training accuracy 0.3893838862559242
epoch: 1

epoch: 12 training loss 0.11297870598133142, training accuracy 0.9270142180094787
epoch: 12 training loss 0.11419731599459716, training accuracy 0.9328909952606634
epoch: 12 training loss 0.1152949964887158, training accuracy 0.9387677725118483
epoch: 12 training loss 0.11566662384435464, training accuracy 0.9448341232227488
epoch: 12 training loss 0.11604397862443427, training accuracy 0.9509004739336493
epoch: 12 training loss 0.11739577300740645, training accuracy 0.9565876777251184
epoch: 12 training loss 0.11830921926204627, training accuracy 0.9622748815165877
epoch: 12 training loss 0.1189639656747122, training accuracy 0.9681516587677725
epoch: 12 training loss 0.11930185574773364, training accuracy 0.9732701421800948
epoch: 12 validation loss 0.004623448152809634, validation accuracy 0.02350265352539803
epoch: 12 validation loss 0.007009193877364037, validation accuracy 0.04700530705079606
epoch: 12 validation loss 0.015173157206803583, validation accuracy 0.06974981046247157


epoch: 13 training loss 0.03146007449706019, training accuracy 0.2896682464454976
epoch: 13 training loss 0.03177961134797589, training accuracy 0.2957345971563981
epoch: 13 training loss 0.0320889681775423, training accuracy 0.30180094786729855
epoch: 13 training loss 0.032342565206554826, training accuracy 0.30786729857819906
epoch: 13 training loss 0.032516483164511584, training accuracy 0.3139336492890995
epoch: 13 training loss 0.03312616299678929, training accuracy 0.32
epoch: 13 training loss 0.033897144862260864, training accuracy 0.32606635071090045
epoch: 13 training loss 0.03406035792771109, training accuracy 0.33213270142180096
epoch: 13 training loss 0.03493289996097438, training accuracy 0.33800947867298575
epoch: 13 training loss 0.035788806207937084, training accuracy 0.34369668246445495
epoch: 13 training loss 0.0362377206409147, training accuracy 0.34976303317535545
epoch: 13 training loss 0.03651531602534073, training accuracy 0.3558293838862559
epoch: 13 training lo

epoch: 13 training loss 0.0887196400267253, training accuracy 0.8828436018957345
epoch: 13 training loss 0.0903009007440359, training accuracy 0.8883412322274882
epoch: 13 training loss 0.09075583322353273, training accuracy 0.8942180094786729
epoch: 13 training loss 0.09136845661000619, training accuracy 0.9000947867298578
epoch: 13 training loss 0.09165659077359602, training accuracy 0.9061611374407583
epoch: 13 training loss 0.09208699054627621, training accuracy 0.9120379146919431
epoch: 13 training loss 0.09261472191290833, training accuracy 0.9179146919431279
epoch: 13 training loss 0.09304655653605529, training accuracy 0.9239810426540284
epoch: 13 training loss 0.09379910206907734, training accuracy 0.9296682464454976
epoch: 13 training loss 0.0939274883722242, training accuracy 0.9357345971563981
epoch: 13 training loss 0.09454036572533196, training accuracy 0.9418009478672985
epoch: 13 training loss 0.09483869527753495, training accuracy 0.9478672985781991
epoch: 13 training 

epoch: 14 training loss 0.019536329264889397, training accuracy 0.23260663507109003
epoch: 14 training loss 0.019756740574588143, training accuracy 0.2386729857819905
epoch: 14 training loss 0.020045915214936316, training accuracy 0.24473933649289098
epoch: 14 training loss 0.020493204062583887, training accuracy 0.25061611374407583
epoch: 14 training loss 0.021096903136556183, training accuracy 0.2564928909952607
epoch: 14 training loss 0.02143737756810482, training accuracy 0.26255924170616113
epoch: 14 training loss 0.02194824512535927, training accuracy 0.26862559241706163
epoch: 14 training loss 0.022058625548936745, training accuracy 0.2746919431279621
epoch: 14 training loss 0.02306490164797453, training accuracy 0.28056872037914693
epoch: 14 training loss 0.023588733661796243, training accuracy 0.2864454976303317
epoch: 14 training loss 0.02407738526285542, training accuracy 0.29251184834123223
epoch: 14 training loss 0.02473641109692542, training accuracy 0.2983886255924171
ep

epoch: 14 training loss 0.06515031525309052, training accuracy 0.8303317535545024
epoch: 14 training loss 0.06580608033455944, training accuracy 0.8362085308056871
epoch: 14 training loss 0.06593492739573474, training accuracy 0.8422748815165877
epoch: 14 training loss 0.06632658774253881, training accuracy 0.8481516587677725
epoch: 14 training loss 0.0664108693430209, training accuracy 0.854218009478673
epoch: 14 training loss 0.06650654639113006, training accuracy 0.8602843601895734
epoch: 14 training loss 0.06733221745604022, training accuracy 0.8661611374407583
epoch: 14 training loss 0.06751690171906169, training accuracy 0.8722274881516587
epoch: 14 training loss 0.06792723799203809, training accuracy 0.8781042654028436
epoch: 14 training loss 0.06806643945910919, training accuracy 0.884170616113744
epoch: 14 training loss 0.06830783628174479, training accuracy 0.8902369668246445
epoch: 14 training loss 0.06873899641760153, training accuracy 0.896303317535545
epoch: 14 training l

epoch: 15 training loss 0.012776044861400297, training accuracy 0.1728909952606635
epoch: 15 training loss 0.013265212169755692, training accuracy 0.17876777251184833
epoch: 15 training loss 0.013594848386484299, training accuracy 0.1848341232227488
epoch: 15 training loss 0.013825106112312932, training accuracy 0.19090047393364928
epoch: 15 training loss 0.014170139287885332, training accuracy 0.19696682464454976
epoch: 15 training loss 0.015133705512042295, training accuracy 0.20284360189573458
epoch: 15 training loss 0.015336708355853908, training accuracy 0.20890995260663506
epoch: 15 training loss 0.01651731503518272, training accuracy 0.2147867298578199
epoch: 15 training loss 0.01706171138591676, training accuracy 0.22066350710900473
epoch: 15 training loss 0.01785070540215732, training accuracy 0.22654028436018958
epoch: 15 training loss 0.01861411310485189, training accuracy 0.23260663507109003
epoch: 15 training loss 0.01908799152238674, training accuracy 0.2386729857819905
e

epoch: 15 training loss 0.056390687328944276, training accuracy 0.7717535545023696
epoch: 15 training loss 0.0566861188863691, training accuracy 0.7778199052132702
epoch: 15 training loss 0.05699829429811776, training accuracy 0.7838862559241706
epoch: 15 training loss 0.057130593773313046, training accuracy 0.789952606635071
epoch: 15 training loss 0.0578046300795406, training accuracy 0.7958293838862559
epoch: 15 training loss 0.05791528414211002, training accuracy 0.8018957345971564
epoch: 15 training loss 0.05832791958375, training accuracy 0.8077725118483412
epoch: 15 training loss 0.05886771027510765, training accuracy 0.813649289099526
epoch: 15 training loss 0.05896437256822089, training accuracy 0.8197156398104265
epoch: 15 training loss 0.05979010139596406, training accuracy 0.8255924170616113
epoch: 15 training loss 0.06011568824826823, training accuracy 0.8316587677725118
epoch: 15 training loss 0.06070774097013248, training accuracy 0.8375355450236966
epoch: 15 training lo

epoch: 16 training loss 0.0073406112533045045, training accuracy 0.12625592417061612
epoch: 16 training loss 0.0077287888357424625, training accuracy 0.13232227488151657
epoch: 16 training loss 0.008472647118907405, training accuracy 0.13819905213270142
epoch: 16 training loss 0.009198260527651457, training accuracy 0.14388625592417062
epoch: 16 training loss 0.009469442768684496, training accuracy 0.1499526066350711
epoch: 16 training loss 0.009845630698859409, training accuracy 0.15601895734597157
epoch: 16 training loss 0.010124685995951649, training accuracy 0.16208530805687205
epoch: 16 training loss 0.010240138362369266, training accuracy 0.1681516587677725
epoch: 16 training loss 0.010956940950375598, training accuracy 0.17402843601895734
epoch: 16 training loss 0.011321043973850413, training accuracy 0.17990521327014217
epoch: 16 training loss 0.011850614892362983, training accuracy 0.18597156398104264
epoch: 16 training loss 0.012145059012688732, training accuracy 0.1920379146

epoch: 16 training loss 0.0485625737368778, training accuracy 0.724739336492891
epoch: 16 training loss 0.04947768743004279, training accuracy 0.7304265402843602
epoch: 16 training loss 0.04970874156432129, training accuracy 0.736303317535545
epoch: 16 training loss 0.05010263043557298, training accuracy 0.7421800947867299
epoch: 16 training loss 0.05111724291367553, training accuracy 0.747867298578199
epoch: 16 training loss 0.051247465378865244, training accuracy 0.7539336492890996
epoch: 16 training loss 0.05142837361136884, training accuracy 0.76
epoch: 16 training loss 0.05156813479712789, training accuracy 0.7660663507109005
epoch: 16 training loss 0.05215301353784534, training accuracy 0.7719431279620853
epoch: 16 training loss 0.05266999568419434, training accuracy 0.7780094786729858
epoch: 16 training loss 0.053031245503945376, training accuracy 0.7838862559241706
epoch: 16 training loss 0.05322040599104352, training accuracy 0.789952606635071
epoch: 16 training loss 0.0534399

epoch: 17 training loss 0.004273876838774478, training accuracy 0.07791469194312796
epoch: 17 training loss 0.004655404011785136, training accuracy 0.08398104265402843
epoch: 17 training loss 0.004956448507534949, training accuracy 0.09004739336492891
epoch: 17 training loss 0.005023575100288572, training accuracy 0.09611374407582939
epoch: 17 training loss 0.005329208035039675, training accuracy 0.10218009478672986
epoch: 17 training loss 0.00573954453400526, training accuracy 0.10824644549763032
epoch: 17 training loss 0.0059474384615206604, training accuracy 0.1143127962085308
epoch: 17 training loss 0.0060133756732488695, training accuracy 0.12037914691943127
epoch: 17 training loss 0.006737646706296369, training accuracy 0.12625592417061612
epoch: 17 training loss 0.006911511229112815, training accuracy 0.13232227488151657
epoch: 17 training loss 0.007122035964405367, training accuracy 0.13838862559241705
epoch: 17 training loss 0.0073001845414039646, training accuracy 0.144454976

epoch: 17 training loss 0.04163535150306485, training accuracy 0.6777251184834123
epoch: 17 training loss 0.04194541149794773, training accuracy 0.6836018957345972
epoch: 17 training loss 0.042087232834919934, training accuracy 0.6896682464454976
epoch: 17 training loss 0.042245744557177285, training accuracy 0.6957345971563981
epoch: 17 training loss 0.04249661197594557, training accuracy 0.7018009478672985
epoch: 17 training loss 0.04267477990891696, training accuracy 0.7078672985781991
epoch: 17 training loss 0.04290757052141343, training accuracy 0.7139336492890995
epoch: 17 training loss 0.043007564030552364, training accuracy 0.72
epoch: 17 training loss 0.04412550400218693, training accuracy 0.7256872037914692
epoch: 17 training loss 0.044344477320169384, training accuracy 0.7317535545023697
epoch: 17 training loss 0.0446627630498172, training accuracy 0.7378199052132701
epoch: 17 training loss 0.04473521821871753, training accuracy 0.7438862559241706
epoch: 17 training loss 0.0

epoch: 18 training loss 0.00196851594753175, training accuracy 0.030142180094786728
epoch: 18 training loss 0.0022031220440615973, training accuracy 0.0362085308056872
epoch: 18 training loss 0.0024466839333846108, training accuracy 0.042274881516587676
epoch: 18 training loss 0.0025723446036967058, training accuracy 0.04834123222748815
epoch: 18 training loss 0.002944291236841283, training accuracy 0.05440758293838863
epoch: 18 training loss 0.00317052136100299, training accuracy 0.060473933649289095
epoch: 18 training loss 0.003390867608418397, training accuracy 0.06654028436018958
epoch: 18 training loss 0.00357826191102159, training accuracy 0.07260663507109004
epoch: 18 training loss 0.003776305648387891, training accuracy 0.07848341232227488
epoch: 18 training loss 0.0038691107243723215, training accuracy 0.08454976303317535
epoch: 18 training loss 0.0040931234653527136, training accuracy 0.09061611374407583
epoch: 18 training loss 0.004534759329393577, training accuracy 0.096492

epoch: 18 training loss 0.03423286086292628, training accuracy 0.6309004739336492
epoch: 18 training loss 0.034624971967737825, training accuracy 0.6369668246445498
epoch: 18 training loss 0.03470507432216716, training accuracy 0.6430331753554502
epoch: 18 training loss 0.03481981277748307, training accuracy 0.6490995260663507
epoch: 18 training loss 0.03494661148408013, training accuracy 0.6551658767772511
epoch: 18 training loss 0.03518611987337682, training accuracy 0.6612322274881517
epoch: 18 training loss 0.03577636662535193, training accuracy 0.6671090047393364
epoch: 18 training loss 0.036140149648155644, training accuracy 0.673175355450237
epoch: 18 training loss 0.03633670104340919, training accuracy 0.6792417061611374
epoch: 18 training loss 0.03693480372146408, training accuracy 0.6851184834123223
epoch: 18 training loss 0.03712674193190173, training accuracy 0.6911848341232227
epoch: 18 training loss 0.03806670322519908, training accuracy 0.6970616113744076
epoch: 18 train

epoch: 18 validation loss 0.14803242199885475, validation accuracy 0.9249431387414708
epoch: 18 validation loss 0.1495419908329065, validation accuracy 0.9484457922668689
epoch: 18 validation loss 0.15126062058607134, validation accuracy 0.9529946929492039
epoch: 19 training loss 0.00011241307190809205, training accuracy 0.006066350710900474
epoch: 19 training loss 0.00018332122061489882, training accuracy 0.012132701421800948
epoch: 19 training loss 0.0007447434149647211, training accuracy 0.01800947867298578
epoch: 19 training loss 0.0007887559225208951, training accuracy 0.024075829383886256
epoch: 19 training loss 0.0010312524762763795, training accuracy 0.030142180094786728
epoch: 19 training loss 0.0013532659425554683, training accuracy 0.03601895734597156
epoch: 19 training loss 0.0014703304027494095, training accuracy 0.04208530805687204
epoch: 19 training loss 0.0016101662101338824, training accuracy 0.04815165876777251
epoch: 19 training loss 0.0018300682758268022, training a

epoch: 19 training loss 0.02743814039851817, training accuracy 0.5836966824644549
epoch: 19 training loss 0.02767549942737507, training accuracy 0.5897630331753554
epoch: 19 training loss 0.027873612453022275, training accuracy 0.595829383886256
epoch: 19 training loss 0.0280426808414866, training accuracy 0.6018957345971564
epoch: 19 training loss 0.028088191174217875, training accuracy 0.6079620853080568
epoch: 19 training loss 0.028158102072245703, training accuracy 0.6140284360189573
epoch: 19 training loss 0.02843632118679336, training accuracy 0.6200947867298578
epoch: 19 training loss 0.028534753641811027, training accuracy 0.6261611374407583
epoch: 19 training loss 0.02898075814778206, training accuracy 0.6320379146919431
epoch: 19 training loss 0.029187021902387176, training accuracy 0.6381042654028436
epoch: 19 training loss 0.02929894939135601, training accuracy 0.644170616113744
epoch: 19 training loss 0.02939730626429427, training accuracy 0.6502369668246445
epoch: 19 trai

epoch: 19 validation loss 0.1334477327735791, validation accuracy 0.7361637604245641
epoch: 19 validation loss 0.13447568812453448, validation accuracy 0.7604245640636846
epoch: 19 validation loss 0.13482995763974628, validation accuracy 0.7846853677028052
epoch: 19 validation loss 0.1360303015533589, validation accuracy 0.8081880212282032
epoch: 19 validation loss 0.14345189703062142, validation accuracy 0.8309325246398787
epoch: 19 validation loss 0.14740965881973075, validation accuracy 0.8536770280515542
epoch: 19 validation loss 0.147457077251952, validation accuracy 0.8779378316906747
epoch: 19 validation loss 0.15037438235018993, validation accuracy 0.9006823351023503
epoch: 19 validation loss 0.1568690996625552, validation accuracy 0.9234268385140257
epoch: 19 validation loss 0.1581333478894353, validation accuracy 0.9469294920394238
epoch: 19 validation loss 0.1583198846894697, validation accuracy 0.9522365428354814
epoch: 20 training loss 0.00017372958468034935, training accu

epoch: 20 training loss 0.016602692872427088, training accuracy 0.5380094786729858
epoch: 20 training loss 0.01699092225158384, training accuracy 0.5440758293838862
epoch: 20 training loss 0.01730462806202224, training accuracy 0.5501421800947868
epoch: 20 training loss 0.017485247050981385, training accuracy 0.5562085308056872
epoch: 20 training loss 0.0177562291786004, training accuracy 0.5622748815165877
epoch: 20 training loss 0.01783791614652245, training accuracy 0.5683412322274881
epoch: 20 training loss 0.017906388917240487, training accuracy 0.5744075829383886
epoch: 20 training loss 0.018086497961627365, training accuracy 0.5804739336492891
epoch: 20 training loss 0.01827245649850764, training accuracy 0.5865402843601896
epoch: 20 training loss 0.018563648408623103, training accuracy 0.59260663507109
epoch: 20 training loss 0.018727618978486807, training accuracy 0.5986729857819905
epoch: 20 training loss 0.019553433049346598, training accuracy 0.6045497630331753
epoch: 20 tr

epoch: 20 validation loss 0.080101802267298, validation accuracy 0.5617892342683851
epoch: 20 validation loss 0.08104187813873451, validation accuracy 0.5860500379075056
epoch: 20 validation loss 0.0820795197189532, validation accuracy 0.6103108415466262
epoch: 20 validation loss 0.09141078619590033, validation accuracy 0.6322971948445792
epoch: 20 validation loss 0.09469594440964517, validation accuracy 0.6557998483699773
epoch: 20 validation loss 0.09881051006734687, validation accuracy 0.6785443517816527
epoch: 20 validation loss 0.11115881815338062, validation accuracy 0.6990144048521607
epoch: 20 validation loss 0.11789199236534688, validation accuracy 0.7217589082638363
epoch: 20 validation loss 0.12162728919460522, validation accuracy 0.7445034116755117
epoch: 20 validation loss 0.12531258861751066, validation accuracy 0.7672479150871873
epoch: 20 validation loss 0.12643630789074597, validation accuracy 0.7907505686125853
epoch: 20 validation loss 0.12777978167152476, validation

In [None]:
# Attach Google Drive to this Colab runtime; the checkpoints (*.pth) and the
# test images (TEST_DATA_PATH) loaded below live under /content/drive/MyDrive.
from google.colab import drive

DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
# Recreate the ResNet-50 backbone with its two-layer classification head.
# The hidden width (fc_layer1) must match the trained checkpoint for
# load_state_dict() to succeed; the dropout rate only matters in train mode.
model = buildModel(nn=torch.nn, fc_layer1=256, drop_out=0.2)

In [None]:
# Restore the fine-tuned ResNet-50 weights. Tensors are first mapped onto the
# CPU (so this works on CPU-only runtimes); load_state_dict() then copies them
# into the model's existing parameters, which buildModel already moved to
# `device`.
resnet_checkpoint = torch.load("/content/drive/MyDrive/resnet.pth", map_location="cpu")
model.load_state_dict(resnet_checkpoint)
# Switch to inference mode (dropout off, BatchNorm uses running statistics);
# the returned module is echoed as the cell output.
model.eval()

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 

In [None]:
# Rebuild the MobileNetV3-Large classifier and load its fine-tuned checkpoint.
# NOTE: this rebinds `model`, replacing the ResNet-50 loaded in earlier cells;
# the inference cells below therefore run MobileNetV3.
#
# No ImageNet weights are requested: load_state_dict() below uses its default
# strict=True, so every parameter and buffer is overwritten by the checkpoint
# and downloading the pretrained weights first would be wasted work.
model = models.mobilenet_v3_large()
# The replacement head must have the same layer structure as the one used in
# training so that the checkpoint's state-dict keys and shapes line up.
head = nn.Sequential(
    nn.Linear(model.classifier[-1].in_features, 256),
    nn.ReLU(),
    nn.Dropout(0.3),
    nn.Linear(256, NUM_CLASSES),
)
model.classifier[-1] = head
model = model.to(device)
# map_location keeps loading working on CPU-only runtimes; load_state_dict()
# copies the tensors into the parameters already placed on `device`.
model.load_state_dict(
    torch.load("/content/drive/MyDrive/mobileNetLarge.pth", map_location=torch.device('cpu'))
)
# Inference mode: dropout disabled, BatchNorm uses running statistics.
model.eval()

Downloading: "https://download.pytorch.org/models/mobilenet_v3_large-8738ca79.pth" to /root/.cache/torch/hub/checkpoints/mobilenet_v3_large-8738ca79.pth
100%|██████████| 21.1M/21.1M [00:00<00:00, 103MB/s]


MobileNetV3(
  (features): Sequential(
    (0): Conv2dNormActivation(
      (0): Conv2d(3, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchNorm2d(16, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
      (2): Hardswish()
    )
    (1): InvertedResidual(
      (block): Sequential(
        (0): Conv2dNormActivation(
          (0): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=16, bias=False)
          (1): BatchNorm2d(16, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
          (2): ReLU(inplace=True)
        )
        (1): Conv2dNormActivation(
          (0): Conv2d(16, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (1): BatchNorm2d(16, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
        )
      )
    )
    (2): InvertedResidual(
      (block): Sequential(
        (0): Conv2dNormActivation(
          (0): Conv2d(16, 64, kernel_size=(1, 1), stride=(1, 1), bi

In [None]:
# Deterministic evaluation preprocessing: the same resize + tensor conversion
# as the training pipeline, without its random augmentations.
eval_steps = [
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
]
ttransform = transforms.Compose(eval_steps)

# ImageFolder derives class labels from the sub-directory names of
# TEST_DATA_PATH. shuffle=False keeps the prediction order aligned with
# test_data's sample order.
test_data = datasets.ImageFolder(TEST_DATA_PATH, transform=ttransform)
test_loader = DataLoader(test_data, batch_size=1350, shuffle=False)

In [None]:

# Run the loaded model over the whole test set and keep the logits of EVERY
# batch. Reassigning `preds` inside the loop would keep only the final batch
# whenever the test set spans more than one batch.
model.eval()  # inference mode: dropout off, BatchNorm uses running statistics
batch_preds = []
batch_labels = []
with torch.no_grad():
    for inputs, labels in test_loader:
        inputs = inputs.to(device)
        # Move each batch's logits off the accelerator so device memory does
        # not grow with the number of batches.
        batch_preds.append(model(inputs).cpu())
        batch_labels.append(labels)

# Row i of `preds` and `labels` belongs to the same test sample; with
# shuffle=False that is also the i-th sample of test_data.
preds = torch.cat(batch_preds)
labels = torch.cat(batch_labels)


In [None]:
# Inspect the prediction tensor's dimensions (expected: number of test images
# x NUM_CLASSES). NOTE(review): the recorded output shows 101 columns while
# NUM_CLASSES is 3; it may be stale output from a different model, so confirm.
preds.size()

torch.Size([1350, 101])

In [None]:
# Persist the test-set logits. They are moved to the CPU first (a no-op if they
# are already there) so the file can be torch.load()-ed on machines without
# CUDA; a CUDA tensor saved as-is would need map_location to load elsewhere.
torch.save(preds.cpu(), "predictions.pt")