In [1]:
import copy
import os
import sys
import time

import matplotlib.pyplot as plt
import numpy as np 
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.metrics import accuracy_score, confusion_matrix
from torch.autograd import Variable
from torch.optim import lr_scheduler
from torch.utils.data import Dataset, DataLoader
from torchvision import datasets, models, transforms

In [None]:
sys.path.append('data_ingestion')
from data_ingestion import data_ingestion_for_big_dataset, data_ingestion_pipeline_smaller_dataset

In [8]:
# Checkpoint files for the fine-tuned ResNet (one per dataset variant).
INFERENCE_PATH = './trained_pytorch_model/resnet_finetuned_smaller_dataset.pth'
INFERENCE_PATH_FOR_BIGGER_DATASET = './trained_pytorch_model/resnet_finetuned_bigger_dataset.pth'
# torch.save does not create missing parent directories, so make sure the
# checkpoint folders exist before training tries to write into them.
os.makedirs(os.path.dirname(INFERENCE_PATH), exist_ok=True)
os.makedirs(os.path.dirname(INFERENCE_PATH_FOR_BIGGER_DATASET), exist_ok=True)

In [3]:
# ImageNet-pretrained ResNet-18 (torchvision v0.10.0 via torch.hub) whose
# 1000-way classifier is swapped for a freshly initialised 2-way linear head.
model = torch.hub.load('pytorch/vision:v0.10.0', 'resnet18', pretrained=True)
num_ftrs = model.fc.in_features
model.fc = nn.Linear(in_features=num_ftrs, out_features=2)
model = model.cuda() if torch.cuda.is_available() else model

# Cross-entropy over the two output logits.
criterion = nn.CrossEntropyLoss()
# Fine-tune every parameter (backbone and new head), not just the head.
optimizer = optim.SGD(params=model.parameters(), lr=0.001, momentum=0.9)
# Multiply the learning rate by 0.1 every 7 scheduler steps
# (train_model steps the scheduler once per training epoch).
exp_lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)

Using cache found in /Users/pyuvraj/.cache/torch/hub/pytorch_vision_v0.10.0


In [4]:
# check if CUDA is available
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
use_cuda = torch.cuda.is_available()
device = torch.device('cuda') if use_cuda else torch.device('cpu')

In [33]:
# Generators - Small Nudity Dataset
# Bind the small dataset's splits and loaders to the names the training cells read.
# This cell and the large-dataset cell rebind the same names, so whichever ran
# last decides what train_model trains on.
training_set = data_ingestion_pipeline_smaller_dataset.training_dataset
training_generator = data_ingestion_pipeline_smaller_dataset.train_dataloader
validation_set = data_ingestion_pipeline_smaller_dataset.val_dataset
validation_generator = data_ingestion_pipeline_smaller_dataset.val_dataloader
dataloaders = {'train': training_generator, 'val': validation_generator}

In [5]:
# Generators - Large Nudity Dataset
# Expose the large dataset's splits and loaders under the names the training
# cells read (this rebinds the same names as the small-dataset cell).
training_set, training_generator = (
    data_ingestion_for_big_dataset.training_dataset,
    data_ingestion_for_big_dataset.train_dataloader,
)
validation_set, validation_generator = (
    data_ingestion_for_big_dataset.val_dataset,
    data_ingestion_for_big_dataset.val_dataloader,
)
dataloaders = dict(train=training_generator, val=validation_generator)

In [None]:
def train_model(model, criterion, optimizer, scheduler, num_epochs=25,
                dataloaders=None, dataset_sizes=None, device=None,
                save_path=None, log_every=5):
    """Fine-tune ``model`` and return it loaded with its best validation weights.

    Each epoch runs a 'train' phase (forward, backward, optimizer step, then a
    single ``scheduler.step()``) followed by a 'val' phase (forward only).
    Whenever validation accuracy improves, the weights are snapshotted in memory
    and the whole model object is written to ``save_path`` with ``torch.save``.

    Args:
        model: ``nn.Module`` to train; it is modified in place.
        criterion: loss function called as ``criterion(outputs, labels)``.
        optimizer: optimizer over ``model``'s parameters.
        scheduler: learning-rate scheduler, stepped once per training epoch.
        num_epochs: number of train/val epochs.
        dataloaders: dict with 'train' and 'val' iterables yielding
            ``(inputs, labels)`` batches. Defaults to the notebook-level
            ``dataloaders``.
        dataset_sizes: dict with 'train' and 'val' sample counts used to turn
            summed loss/correct counts into per-sample means. Defaults to the
            notebook-level ``dataset_sizes``.
        device: device each batch is moved to. Defaults to the notebook-level
            ``device``.
        save_path: file the best model is saved to. Defaults to the
            notebook-level ``INFERENCE_PATH``; pass
            ``INFERENCE_PATH_FOR_BIGGER_DATASET`` when training on the large
            dataset so the small-dataset checkpoint is not overwritten.
        log_every: print the training loss every ``log_every`` batches
            (0 or None disables the per-batch log).

    Returns:
        ``model`` (the same object) with the best-validation-accuracy weights loaded.
    """
    # Fall back to the notebook globals so existing call sites keep working.
    namespace = globals()
    if dataloaders is None:
        dataloaders = namespace['dataloaders']
    if dataset_sizes is None:
        dataset_sizes = namespace['dataset_sizes']
    if device is None:
        device = namespace['device']
    if save_path is None:
        save_path = namespace['INFERENCE_PATH']

    since = time.time()
    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0
    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)
        # Each epoch has a training and a validation phase.
        for phase in ['train', 'val']:
            is_train = phase == 'train'
            # train(False) is exactly what eval() does.
            model.train(is_train)
            running_loss = 0.0
            # Python int accumulator: an empty loader still yields a numeric accuracy.
            running_corrects = 0
            for iteration_no, (inputs, labels) in enumerate(dataloaders[phase], start=1):
                inputs = inputs.to(device)
                labels = labels.to(device)
                optimizer.zero_grad()
                # Only build the autograd graph during training.
                with torch.set_grad_enabled(is_train):
                    outputs = model(inputs)
                    _, preds = torch.max(outputs, 1)
                    loss = criterion(outputs, labels)
                    if is_train:
                        loss.backward()
                        optimizer.step()
                        if log_every and iteration_no % log_every == 0:
                            print("loss iteration -> ", iteration_no, loss.item())
                # The criterion returns a batch mean; re-weight by batch size so the
                # epoch loss is a per-sample mean even when the last batch is smaller.
                running_loss += loss.item() * inputs.size(0)
                running_corrects += int((preds == labels).sum().item())
            if is_train:
                scheduler.step()
            epoch_loss = running_loss / dataset_sizes[phase]
            epoch_acc = running_corrects / dataset_sizes[phase]
            print('{} Loss: {:.4f} Acc: {:.4f}'.format(
                phase, epoch_loss, epoch_acc))
            # Keep, and checkpoint to disk, the weights with the best validation accuracy.
            if not is_train and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())
                torch.save(model, save_path)
        print()
    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
    print('Best val Acc: {:.4f}'.format(best_acc))
    # Restore the best validation weights before handing the model back.
    model.load_state_dict(best_model_wts)
    return model
# Train and evaluate. Per-phase sample counts turn summed loss/correct counts
# into per-sample means inside train_model.
# NOTE(review): train_model checkpoints to INFERENCE_PATH (the small-dataset path)
# regardless of which dataset cell ran last — confirm this is intended.
dataset_sizes = dict(train=len(training_set), val=len(validation_set))
model_ft = train_model(model, criterion, optimizer, exp_lr_scheduler, num_epochs=10)



Epoch 0/9
----------
loss iteration ->  5 0.2023783177137375
loss iteration ->  10 0.13433001935482025
loss iteration ->  15 0.20007950067520142
loss iteration ->  20 0.12698262929916382
loss iteration ->  25 0.3564136326313019
loss iteration ->  30 0.18685854971408844
loss iteration ->  35 0.07465089857578278
loss iteration ->  40 0.2087007462978363
loss iteration ->  45 0.19297254085540771
loss iteration ->  50 0.303440660238266
loss iteration ->  55 0.36875954270362854
loss iteration ->  60 0.10443352907896042
loss iteration ->  65 0.32403597235679626
loss iteration ->  70 0.1242697536945343
loss iteration ->  75 0.16409790515899658
loss iteration ->  80 0.15111762285232544
loss iteration ->  85 0.3254172205924988
loss iteration ->  90 0.31990185379981995
loss iteration ->  95 0.27288588881492615
loss iteration ->  100 0.12771248817443848
loss iteration ->  105 0.29708731174468994
loss iteration ->  110 0.15137290954589844
loss iteration ->  115 0.0774708017706871
loss iteration -> 

loss iteration ->  965 0.2506457269191742
loss iteration ->  970 0.24701598286628723
loss iteration ->  975 0.1068802922964096
loss iteration ->  980 0.23213085532188416
loss iteration ->  985 0.1346988081932068
loss iteration ->  990 0.1859842985868454
loss iteration ->  995 0.17438846826553345
loss iteration ->  1000 0.1087697222828865
loss iteration ->  1005 0.0466451495885849
loss iteration ->  1010 0.19985680282115936
loss iteration ->  1015 0.15210452675819397
loss iteration ->  1020 0.1939258575439453
loss iteration ->  1025 0.10882598906755447
loss iteration ->  1030 0.2519749104976654
loss iteration ->  1035 0.36774611473083496
loss iteration ->  1040 0.6734572052955627
loss iteration ->  1045 0.06230933219194412
loss iteration ->  1050 0.062322065234184265
loss iteration ->  1055 0.21298614144325256
loss iteration ->  1060 0.08944787085056305
loss iteration ->  1065 0.16158899664878845
loss iteration ->  1070 0.18918311595916748
loss iteration ->  1075 0.10481712222099304
los

loss iteration ->  1905 0.06298721581697464
loss iteration ->  1910 0.06345494836568832
loss iteration ->  1915 0.20485331118106842
loss iteration ->  1920 0.09961772710084915
loss iteration ->  1925 0.04484044015407562
loss iteration ->  1930 0.14112739264965057
loss iteration ->  1935 0.22244180738925934
loss iteration ->  1940 0.04445759579539299
loss iteration ->  1945 0.08971220254898071
loss iteration ->  1950 0.0620267353951931
loss iteration ->  1955 0.05902067944407463
loss iteration ->  1960 0.3548986315727234
loss iteration ->  1965 0.09406425058841705
loss iteration ->  1970 0.16036789119243622
loss iteration ->  1975 0.07922035455703735
loss iteration ->  1980 0.04992000013589859
loss iteration ->  1985 0.14464381337165833
loss iteration ->  1990 0.20290061831474304
loss iteration ->  1995 0.10519041866064072
loss iteration ->  2000 0.07647629827260971
loss iteration ->  2005 0.02750280871987343
loss iteration ->  2010 0.057599809020757675
loss iteration ->  2015 0.1308129

loss iteration ->  2840 0.13426627218723297
loss iteration ->  2845 0.12173186242580414
loss iteration ->  2850 0.44073042273521423
loss iteration ->  2855 0.07631254196166992
loss iteration ->  2860 0.11913534253835678
loss iteration ->  2865 0.23984715342521667
loss iteration ->  2870 0.09433590620756149
loss iteration ->  2875 0.07061593234539032
loss iteration ->  2880 0.06401565670967102
loss iteration ->  2885 0.07782591134309769
loss iteration ->  2890 0.02255364879965782
loss iteration ->  2895 0.2079150527715683
loss iteration ->  2900 0.02587568201124668
loss iteration ->  2905 0.23981329798698425
loss iteration ->  2910 0.07017521560192108
loss iteration ->  2915 0.08271531760692596
loss iteration ->  2920 0.1013370156288147
loss iteration ->  2925 0.15460751950740814
loss iteration ->  2930 0.0996214896440506
loss iteration ->  2935 0.029209434986114502
loss iteration ->  2940 0.1951584815979004
loss iteration ->  2945 0.10669827461242676
loss iteration ->  2950 0.263536334

loss iteration ->  3775 0.05398242920637131
loss iteration ->  3780 0.13315574824810028
loss iteration ->  3785 0.06449830532073975
loss iteration ->  3790 0.06152352690696716
loss iteration ->  3795 0.14320477843284607
loss iteration ->  3800 0.04634175822138786
loss iteration ->  3805 0.04890846833586693
loss iteration ->  3810 0.15524977445602417
loss iteration ->  3815 0.09578672051429749
loss iteration ->  3820 0.27156519889831543
loss iteration ->  3825 0.1782999485731125
loss iteration ->  3830 0.20049814879894257
loss iteration ->  3835 0.07894110679626465
loss iteration ->  3840 0.15200692415237427
loss iteration ->  3845 0.07593519985675812
loss iteration ->  3850 0.1653614342212677
loss iteration ->  3855 0.06058483198285103
loss iteration ->  3860 0.06082502380013466
loss iteration ->  3865 0.05595611035823822
loss iteration ->  3870 0.13828428089618683
loss iteration ->  3875 0.04111397638916969
loss iteration ->  3880 0.10100235044956207
loss iteration ->  3885 0.14021894

loss iteration ->  4710 0.23838697373867035
loss iteration ->  4715 0.10199267417192459
loss iteration ->  4720 0.023828059434890747
loss iteration ->  4725 0.037112992256879807
loss iteration ->  4730 0.1763581484556198
loss iteration ->  4735 0.1264694482088089
loss iteration ->  4740 0.23458276689052582
loss iteration ->  4745 0.0856214165687561
loss iteration ->  4750 0.05894636735320091
loss iteration ->  4755 0.25510600209236145
loss iteration ->  4760 0.20789143443107605
loss iteration ->  4765 0.2423495203256607
loss iteration ->  4770 0.09073301404714584
loss iteration ->  4775 0.18259383738040924
loss iteration ->  4780 0.09866020828485489
loss iteration ->  4785 0.10583851486444473
loss iteration ->  4790 0.16628190875053406
loss iteration ->  4795 0.31669023633003235
loss iteration ->  4800 0.04708659276366234
loss iteration ->  4805 0.09630530327558517
loss iteration ->  4810 0.11624038219451904
loss iteration ->  4815 0.12960422039031982
loss iteration ->  4820 0.17490103

loss iteration ->  5645 0.1864696741104126
loss iteration ->  5650 0.20530781149864197
loss iteration ->  5655 0.09363196790218353
loss iteration ->  5660 0.1366916298866272
loss iteration ->  5665 0.26009970903396606
loss iteration ->  5670 0.07156100124120712
loss iteration ->  5675 0.0814736932516098
loss iteration ->  5680 0.14200907945632935
loss iteration ->  5685 0.06733085215091705
loss iteration ->  5690 0.029764443635940552
loss iteration ->  5695 0.23413510620594025
loss iteration ->  5700 0.24254556000232697
loss iteration ->  5705 0.0919015109539032
loss iteration ->  5710 0.022085700184106827
loss iteration ->  5715 0.1306351125240326
loss iteration ->  5720 0.07099751383066177
loss iteration ->  5725 0.1022028774023056
loss iteration ->  5730 0.09378723055124283
loss iteration ->  5735 0.04591953009366989
loss iteration ->  5740 0.12408434599637985
loss iteration ->  5745 0.027167340740561485
loss iteration ->  5750 0.040782175958156586
loss iteration ->  5755 0.11791527

loss iteration ->  6580 0.09542369097471237
loss iteration ->  6585 0.10086841881275177
loss iteration ->  6590 0.08172443509101868
loss iteration ->  6595 0.26958343386650085
loss iteration ->  6600 0.01649501733481884
loss iteration ->  6605 0.08552893996238708
loss iteration ->  6610 0.14285975694656372
loss iteration ->  6615 0.06990920007228851
loss iteration ->  6620 0.13359545171260834
loss iteration ->  6625 0.25411897897720337
loss iteration ->  6630 0.13165993988513947
loss iteration ->  6635 0.07496748864650726
loss iteration ->  6640 0.07702013105154037
loss iteration ->  6645 0.06144173443317413
loss iteration ->  6650 0.08026310801506042
loss iteration ->  6655 0.015629354864358902
loss iteration ->  6660 0.09025871008634567
loss iteration ->  6665 0.36426207423210144
loss iteration ->  6670 0.4208802282810211
loss iteration ->  6675 0.10986215621232986
loss iteration ->  6680 0.06967071443796158
loss iteration ->  6685 0.06749576330184937
loss iteration ->  6690 0.071987

loss iteration ->  7515 0.13101783394813538
loss iteration ->  7520 0.04941582307219505
loss iteration ->  7525 0.0344529002904892
loss iteration ->  7530 0.030802415683865547
loss iteration ->  7535 0.026996759697794914
loss iteration ->  7540 0.10724771022796631
loss iteration ->  7545 0.05009084939956665
loss iteration ->  7550 0.26790282130241394
loss iteration ->  7555 0.04886823147535324
loss iteration ->  7560 0.038629576563835144
loss iteration ->  7565 0.03303965926170349
loss iteration ->  7570 0.03172940015792847
loss iteration ->  7575 0.08770565688610077
loss iteration ->  7580 0.08431969583034515
loss iteration ->  7585 0.07696785032749176
loss iteration ->  7590 0.1586330085992813
loss iteration ->  7595 0.21259541809558868




loss iteration ->  7600 0.1906307190656662
loss iteration ->  7605 0.036511559039354324
loss iteration ->  7610 0.17978452146053314
loss iteration ->  7615 0.04934833571314812
loss iteration ->  7620 0.04349154233932495
loss iteration ->  7625 0.12104641646146774
loss iteration ->  7630 0.11010488122701645
loss iteration ->  7635 0.022673096507787704
loss iteration ->  7640 0.11814815551042557
loss iteration ->  7645 0.09683819860219955
loss iteration ->  7650 0.1659368872642517
loss iteration ->  7655 0.04458582028746605
loss iteration ->  7660 0.08651543408632278
loss iteration ->  7665 0.11621250212192535
loss iteration ->  7670 0.15164493024349213
loss iteration ->  7675 0.24689197540283203
loss iteration ->  7680 0.06131646782159805
loss iteration ->  7685 0.12784570455551147
loss iteration ->  7690 0.05297398567199707
loss iteration ->  7695 0.17311929166316986
loss iteration ->  7700 0.2052941769361496
loss iteration ->  7705 0.10608253628015518
loss iteration ->  7710 0.0453771

loss iteration ->  8535 0.06450759619474411
loss iteration ->  8540 0.19247187674045563
loss iteration ->  8545 0.10055376589298248
loss iteration ->  8550 0.20636311173439026
loss iteration ->  8555 0.1475151926279068
loss iteration ->  8560 0.07780508697032928
loss iteration ->  8565 0.09683836996555328
loss iteration ->  8570 0.09604127705097198
loss iteration ->  8575 0.2991284132003784
loss iteration ->  8580 0.021009620279073715
loss iteration ->  8585 0.08532623201608658
loss iteration ->  8590 0.029037490487098694
loss iteration ->  8595 0.02302762120962143
loss iteration ->  8600 0.05373163893818855
loss iteration ->  8605 0.11311748623847961
loss iteration ->  8610 0.18583686649799347
loss iteration ->  8615 0.16462945938110352
loss iteration ->  8620 0.4151102900505066
loss iteration ->  8625 0.11025398969650269
loss iteration ->  8630 0.07747550308704376
loss iteration ->  8635 0.20163573324680328
loss iteration ->  8640 0.15819160640239716
loss iteration ->  8645 0.0772828

loss iteration ->  9470 0.13272321224212646
loss iteration ->  9475 0.058398738503456116
loss iteration ->  9480 0.13326427340507507
loss iteration ->  9485 0.022222526371479034
loss iteration ->  9490 0.08262727409601212
loss iteration ->  9495 0.21095606684684753
loss iteration ->  9500 0.15062817931175232
loss iteration ->  9505 0.13602733612060547
loss iteration ->  9510 0.11317255347967148
loss iteration ->  9515 0.015363767743110657
loss iteration ->  9520 0.04153067246079445
loss iteration ->  9525 0.058646656572818756
loss iteration ->  9530 0.14818626642227173
loss iteration ->  9535 0.04973645508289337
loss iteration ->  9540 0.11808159947395325
loss iteration ->  9545 0.28561022877693176
loss iteration ->  9550 0.14651888608932495
loss iteration ->  9555 0.06485751271247864
loss iteration ->  9560 0.08237138390541077
loss iteration ->  9565 0.10931790620088577
loss iteration ->  9570 0.17916393280029297
loss iteration ->  9575 0.038981933146715164
loss iteration ->  9580 0.1

loss iteration ->  10395 0.14905601739883423
loss iteration ->  10400 0.0852663516998291
loss iteration ->  10405 0.17084020376205444
loss iteration ->  10410 0.13847079873085022
loss iteration ->  10415 0.0314183235168457
loss iteration ->  10420 0.03999779000878334
loss iteration ->  10425 0.035521287471055984
loss iteration ->  10430 0.12578700482845306
loss iteration ->  10435 0.1575891226530075
loss iteration ->  10440 0.031998682767152786
loss iteration ->  10445 0.1453377902507782
loss iteration ->  10450 0.14822643995285034
loss iteration ->  10455 0.09382081776857376
loss iteration ->  10460 0.07721716910600662
loss iteration ->  10465 0.1859191656112671
loss iteration ->  10470 0.20811542868614197
loss iteration ->  10475 0.030336137861013412
loss iteration ->  10480 0.13073484599590302
loss iteration ->  10485 0.03567839786410332
loss iteration ->  10490 0.09860625118017197
loss iteration ->  10495 0.1532752811908722
loss iteration ->  10500 0.038285691291093826
loss iterati

loss iteration ->  11310 0.0949733555316925
loss iteration ->  11315 0.03202815353870392
loss iteration ->  11320 0.06125709414482117
loss iteration ->  11325 0.07108493149280548
loss iteration ->  11330 0.0418478362262249
loss iteration ->  11335 0.08841736614704132
loss iteration ->  11340 0.0266244038939476
loss iteration ->  11345 0.0638982504606247
loss iteration ->  11350 0.19188709557056427
loss iteration ->  11355 0.16428855061531067
loss iteration ->  11360 0.08224275708198547
loss iteration ->  11365 0.016864387318491936
loss iteration ->  11370 0.021765194833278656
loss iteration ->  11375 0.19938978552818298
loss iteration ->  11380 0.07113750278949738
loss iteration ->  11385 0.2062709480524063
loss iteration ->  11390 0.057162657380104065
loss iteration ->  11395 0.07451138645410538
loss iteration ->  11400 0.03648408129811287
loss iteration ->  11405 0.15443438291549683
loss iteration ->  11410 0.20756515860557556
loss iteration ->  11415 0.13377000391483307
loss iterati

### Testing and evaluation

In [25]:
# Run the fine-tuned model over the held-out test split, collecting the summed
# loss plus per-sample predictions and labels for the metrics computed next.
# NOTE(review): this always uses the small dataset's test split, even if the
# model was trained on the large dataset — confirm that is intended.
test_loader = data_ingestion_pipeline_smaller_dataset.test_dataloader
model_ft.eval()
test_loss = 0
test_preds = []
test_labels = []
with torch.no_grad():
    for inputs, labels in test_loader:
        inputs = inputs.to(device)
        labels = labels.to(device)
        outputs = model_ft(inputs)
        loss = criterion(outputs, labels)
        # The criterion returns a batch mean; weight by batch size to get a per-sample sum.
        test_loss += loss.item() * inputs.size(0)
        _, preds = torch.max(outputs, 1)
        test_preds.extend(preds.cpu().tolist())
        test_labels.extend(labels.cpu().tolist())

In [26]:

# 3. Calculate the test accuracy and confusion matrix
test_acc = accuracy_score(test_labels, test_preds)
conf_mat = confusion_matrix(test_labels, test_preds)
# test_loss is a per-sample sum, so average it over the samples actually evaluated.
print("Test Loss: {:.4f}".format(test_loss / len(test_labels)))
print("Test Accuracy: {:.4f}".format(test_acc))
print("Confusion Matrix:\n", conf_mat)

Test Loss: 0.0294
Test Accuracy: 0.9921
Confusion Matrix:
 [[3186    9]
 [  40 2969]]


### Save and reload the model for inference

In [32]:
# Persist the fine-tuned network as a whole pickled nn.Module, then reload it
# in eval mode; the bare last line displays the reloaded model.
# NOTE(review): torch.load unpickles arbitrary objects, so only load trusted files.
# On PyTorch >= 2.6 torch.load defaults to weights_only=True, which rejects a
# pickled nn.Module; confirm the torch version or switch to state_dict checkpoints.
torch.save(model, INFERENCE_PATH)
inference_model = torch.load(INFERENCE_PATH).eval()
inference_model

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  