# VGG16 FINE-TUNING AND TESTING

In [None]:
import torchvision.models as models
from time import time
from tqdm.auto import tqdm
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from data import input_data
import numpy as np
from collections import OrderedDict

In [None]:
# Load VGG16 with ImageNet-pretrained weights. In torchvision `pretrained` is a
# boolean flag; the original passed the truthy string "imagenet", which behaves
# like True but misleadingly suggests a dataset selector.
# NOTE(review): newer torchvision releases deprecate `pretrained` in favour of
# `weights=models.VGG16_Weights.IMAGENET1K_V1`; switch once the installed
# torchvision version is known.
vgg16 = models.vgg16(pretrained=True)

# Number of target classes produced by the new output layer.
NUM_CLASSES = 10

# Replacement classifier head. Its layer order (Linear, act, Dropout, Linear,
# act, Dropout, Linear) gives state-dict keys that line up with the pretrained
# head, so the first two Linear layers' weights can be copied over by key.
# NOTE(review): the pretrained head presumably uses nn.ReLU; ReLU6 additionally
# clips activations at 6, so the copied weights see a different nonlinearity.
classifier = nn.Sequential(
    nn.Linear(25088, 4096),  # presumably the flattened 512 x 7 x 7 feature map
    nn.ReLU6(True),
    nn.Dropout(0.5, inplace=False),
    nn.Linear(4096, 4096),
    nn.ReLU6(True),
    nn.Dropout(0.5, inplace=False),
    nn.Linear(4096, NUM_CLASSES),
)

In [None]:
# Build the (key, tensor) pairs for the new head's state dict:
# - every entry except the last two comes from the pretrained VGG16 head;
# - the last two entries (weight and bias of the final Linear, whose output size
#   changed) keep the fresh initialisation from `classifier`.
# state_dict() builds a new dict on each call, so each is fetched only once here.
_new_head_state = classifier.state_dict()
_pretrained_head_state = vgg16.classifier.state_dict()
_head_keys = list(_new_head_state.keys())

f = [(key, _pretrained_head_state[key]) for key in _head_keys[:-2]]
f += [(key, _new_head_state[key]) for key in _head_keys[-2:]]

In [None]:
# Install the new head in the model, then load the merged weights into it.
# strict=True turns any missing or unexpected key into an error.
vgg16.classifier = classifier
classifier.load_state_dict(OrderedDict(f), strict=True)

# NOTE(review): GPU index 5 is hard-coded; adjust for the machine this runs on.
device = torch.device("cuda", 5)
vgg16 = vgg16.to(device)  # Module.to moves parameters in place and returns the module
print('')

In [None]:
# vgg16

In [None]:
# Optimiser over every model parameter, plus the classification loss.
optimizer = optim.Adam(params=vgg16.parameters(), lr=1e-4)
criterion_d = nn.CrossEntropyLoss()

# Both loaders use the same batch size and worker count; only training shuffles.
_loader_opts = dict(batch_size=2, num_workers=4)

# NOTE(review): the training split is loaded with type="valid", the same setting
# as the evaluation split. If `input_data` has a training mode (e.g. with
# augmentation), this may be unintended; confirm against `data.input_data`.
input_train = input_data(root_dir="data/train/", type="valid")
train_dl = DataLoader(input_train, shuffle=True, **_loader_opts)

input_valid = input_data(root_dir="data/test/", type="valid")
valid_dl = DataLoader(input_valid, shuffle=False, **_loader_opts)

# Training

In [None]:
# Restore previously saved weights. map_location=device remaps the stored
# tensors onto the current device, so a checkpoint written from a different GPU
# index still loads.
vgg16.load_state_dict(torch.load("weights/vgg16.pth", map_location=device))

In [None]:
# Fine-tune for 50 epochs. After each epoch, evaluate on `valid_dl` and save the
# weights whenever validation accuracy beats the best seen so far (`stat`).
# NOTE(review): `valid_dl` is built from "data/test/", the same data the Testing
# section reports on, so checkpoint selection happens on the test set; confirm
# this is intended.
CHECKPOINT_PATH = "weights/vgg16.pth"  # same path the weights are loaded from
LOG_EVERY = 25  # mini-batches between training-loss printouts

stat = 0
for j in range(50):
    print("start of epoch: ", j+1)

    # ---- Training ----
    running_loss = 0.0
    start = time()
    vgg16.train()
    for i, data in enumerate(train_dl, 0):
        inputs, target, _img_name, _number_of_class = data
        inputs = inputs.float().to(device)
        target = target.to(device)

        optimizer.zero_grad()  # clear gradients left over from the previous step
        out = vgg16(inputs)
        loss = criterion_d(out, target)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()

        if i % LOG_EVERY == LOG_EVERY - 1:
            # Mean loss over the last LOG_EVERY mini-batches (previously the sum).
            print('[%d, %5d] loss: %.3f' % (j + 1, i + 1, running_loss / LOG_EVERY))
            running_loss = 0.0
    end = time()
    print("It took : ", (end - start)/60, " mins for the last training epoch")

    # ---- Validation ----
    running_loss, acc, num, length, n_batches = 0.0, 0, 0, 0, 0
    vgg16.eval()  # switch once per pass; disables dropout in the head
    with torch.no_grad():
        start = time()
        for data in tqdm(valid_dl, total=len(valid_dl), unit="batch", position=0, leave=True):
            inputs, target, _img_name, _number_of_class = data
            inputs = inputs.float().to(device)
            target = target.to(device)

            out = vgg16(inputs)
            running_loss += criterion_d(out, target).item()

            _, predicted = torch.max(out, 1)
            num += (predicted == target).sum().item()  # correct predictions in batch
            length += len(target)
            n_batches += 1
        end = time()

    if length == 0:
        # Avoid dividing by zero; there is nothing to score or checkpoint.
        print("validation set is empty; skipping accuracy and checkpointing")
        continue

    acc = (num/length)*100
    print("accuracy and val loss is : ",acc,",",running_loss/n_batches, " --AND-- ", " It took : ", (end - start), " seconds ")

    if acc > stat:
        stat = acc
        torch.save(vgg16.state_dict(), CHECKPOINT_PATH)

# Testing

In [None]:
# Evaluate the model currently in memory on `valid_dl` (built from "data/test/")
# and print accuracy (%) and mean per-batch loss.
# NOTE(review): after the Training cell this is the last-epoch model, not the
# best checkpoint it saved; reload that checkpoint first if the best model is wanted.
running_loss, acc, num, length, n_batches = 0.0, 0, 0, 0, 0
vgg16.eval()  # switch once, before the loop; disables dropout in the head
with torch.no_grad():
    start = time()
    for data in tqdm(valid_dl, total=len(valid_dl), unit="batch", position=0, leave=True):
        inputs, target, _img_name, _number_of_class = data
        inputs = inputs.float().to(device)
        target = target.to(device)

        out = vgg16(inputs)
        running_loss += criterion_d(out, target).item()

        _, predicted = torch.max(out, 1)
        num += (predicted == target).sum().item()  # correct predictions in batch
        length += len(target)
        n_batches += 1
    end = time()

if length == 0:
    # Avoid dividing by zero on an empty loader.
    print("test set is empty; nothing to evaluate")
else:
    acc = (num/length)*100
    print("accuracy and val loss is : ",np.round(acc,3),",",np.round(running_loss/n_batches,3), " --AND-- ", " It took : ", (end - start), " seconds ")

In [None]:
# torch.save(vgg16.state_dict(),"weights/" + "vgg16" + ".pth")