In [3]:
from torchvision import transforms
import numpy as np
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import datasets
from torchvision.transforms import ToTensor
import matplotlib.pyplot as plt
import os
from PIL import Image # import PIL library
import torch
device = "cuda"

In [4]:
# Define a custom dataset class that inherits from torch.utils.data.Dataset
class CustomDataset(Dataset):
    def __init__(self, root):
        # Root is the directory that contains all the class folders
        self.root = root
        self.folders_paths = [os.path.join(root, folder) for folder in os.listdir(root)]
        self.classes = [folder.split("/")[-1] for folder in self.folders_paths]
        self.files = []
        for label, folder_path in enumerate(self.folders_paths):
            for file_name in os.listdir(folder_path):
                file_path = os.path.join(folder_path, file_name)
                self.files.append((file_path, label))

    def __getitem__(self, idx):
        # Return a sample and its corresponding label
        file_path, label = self.files[idx]
        # Open the image using PIL
        image = Image.open(file_path).convert('RGB')
        # Note: This returns an image in the form of a PIL object
        return image, label

    def __len__(self):
        # Return the number of samples
        return len(self.files)

    
# Create an instance of the custom dataset with a given root directory
root_dir = "final_data"
data = CustomDataset(root_dir)

# Split the dataset into train (60%), test (20%) and validation (20%) sets
train_size = int(0.6 * len(data))
test_size = int(0.2 * len(data))
val_size = len(data) - train_size - test_size
train_data, test_data, val_data = random_split(data, [train_size, test_size, val_size])

# Define different transformations for each dataset
train_transform  =   transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.1, hue=0.1),
        transforms.RandomAffine(degrees=40, translate=None, scale=(1, 2), shear=15),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
])

transform_test = transforms.Compose([
    transforms.Resize((224,224)),
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
])

# Create a wrapper class to apply the transformations on the fly
class TransformDataset(Dataset):
    def __init__(self, subset, transform=None):
        self.subset = subset
        self.transform = transform

    def __getitem__(self, index):
        x, y = self.subset[index]
        if self.transform:
            x = self.transform(x)
        return x, y

    def __len__(self):
        return len(self.subset)

# Create transformed datasets using the wrapper class
train_data = TransformDataset(train_data, transform=train_transform)
test_data = TransformDataset(test_data, transform=transform_test)
val_data = TransformDataset(val_data, transform=transform_test)

# Create data loaders for each dataset with a given batch size
batch_size = 32
train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_data, batch_size=batch_size, shuffle=True)



### Define a Convolutional Neural Network (CNN)

In [9]:
# TODO: Build and train your network
from torchvision.models import vgg16
from torch import nn
import torch.optim as optim
import numpy as np
from tqdm import tqdm

model = vgg16(weights=None)

classifier = nn.Sequential(
          nn.Linear(in_features=25088, out_features=512, bias=True),
          nn.ReLU(),
          nn.Dropout(p=0.5),
          nn.Linear(in_features=512, out_features=8, bias=True)
        )

for param in model.features:
    param.requires_grad = True
model.classifier = classifier
model.to(device)
net = model

criterion = nn.CrossEntropyLoss()
optimiser = optim.SGD(net.parameters(), lr=0.05)

### Train the network

In [7]:
def validation(net,dl_test):
    net.eval()
    correct=0
    total = 0
    for i, data in enumerate(dl_test, 0):
        # Get data from each batch. Data is in the format of [inputs, labels]
        inputs, labels = data
        inputs = inputs.cuda()
        labels = labels.cuda()

        # Reset the parameter gradients
        optimiser.zero_grad()

        # Forward pass, backward pass, and optimise
        outputs = net(inputs)
        
        correct += (torch.argmax(outputs,-1) == labels).float().sum()
        total += len(labels)

    return correct/total

In [21]:
import copy
from tqdm import tqdm 


max_accuracy = 0
for epoch in range(200):  # Train for 200 epochs
    net.train()
    running_loss = 0.0
    train_acc = 0.0
    correct = 0
    pbar = tqdm(train_loader)

    for i, data in enumerate(pbar, 0):
        # Get data from each batch. Data is in the format of [inputs, labels]
        inputs, labels = data
        inputs = inputs.cuda()
        labels = labels.cuda()

        # Reset the parameter gradients
        optimiser.zero_grad()

        # Forward pass, backward pass, and optimise
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimiser.step()

        # Print 
        pbar.set_description(f"loss: {str(loss.item())}")

        running_loss += loss.item()
        correct += (torch.argmax(outputs,-1) == labels).float().sum()
    
    train_acc = correct/len(train_data)
    validation_acc = validation(net,val_loader)

    print(f'[Epoch: {epoch + 1}] train_loss: {running_loss / 127:.3f}, train_acc: {train_acc}, val_acc: {validation_acc}')

    if  validation_acc > max_accuracy:
        max_accuracy = validation_acc
        best_model = copy.deepcopy(net)
print('Finished Training')

loss: 1.9844987392425537: 100%|██████████| 127/127 [00:29<00:00,  4.26it/s]


[epoch: 1] loss: 2.067,val_acc: 0.1596452295780182


loss: 2.0244505405426025: 100%|██████████| 127/127 [00:29<00:00,  4.25it/s]


[epoch: 2] loss: 2.046,val_acc: 0.190687358379364


loss: 2.046872615814209: 100%|██████████| 127/127 [00:34<00:00,  3.64it/s] 


[epoch: 3] loss: 2.032,val_acc: 0.1936437487602234


loss: 1.982298493385315: 100%|██████████| 127/127 [00:44<00:00,  2.87it/s] 


[epoch: 4] loss: 1.999,val_acc: 0.21877309679985046


loss: 2.0973434448242188: 100%|██████████| 127/127 [00:44<00:00,  2.88it/s]


[epoch: 5] loss: 1.934,val_acc: 0.190687358379364


loss: 2.0404157638549805: 100%|██████████| 127/127 [00:44<00:00,  2.88it/s]


[epoch: 6] loss: 1.947,val_acc: 0.2328159660100937


loss: 1.7642492055892944: 100%|██████████| 127/127 [00:44<00:00,  2.88it/s]


[epoch: 7] loss: 1.891,val_acc: 0.2793791592121124


loss: 1.5657365322113037: 100%|██████████| 127/127 [00:44<00:00,  2.88it/s]


[epoch: 8] loss: 1.827,val_acc: 0.3311160206794739


loss: 1.8868738412857056: 100%|██████████| 127/127 [00:44<00:00,  2.88it/s]


[epoch: 9] loss: 1.794,val_acc: 0.31707316637039185


loss: 1.6710972785949707: 100%|██████████| 127/127 [00:44<00:00,  2.88it/s]


[epoch: 10] loss: 1.747,val_acc: 0.3414634168148041


loss: 1.745048999786377: 100%|██████████| 127/127 [00:44<00:00,  2.88it/s] 


[epoch: 11] loss: 1.713,val_acc: 0.3695491552352905


loss: 1.6447104215621948: 100%|██████████| 127/127 [00:44<00:00,  2.88it/s]


[epoch: 12] loss: 1.659,val_acc: 0.42645972967147827


loss: 1.515755295753479: 100%|██████████| 127/127 [00:44<00:00,  2.88it/s] 


[epoch: 13] loss: 1.643,val_acc: 0.37841832637786865


loss: 1.6002787351608276: 100%|██████████| 127/127 [00:44<00:00,  2.88it/s]


[epoch: 14] loss: 1.630,val_acc: 0.3555062711238861


loss: 1.606266975402832: 100%|██████████| 127/127 [00:44<00:00,  2.88it/s] 


[epoch: 15] loss: 1.619,val_acc: 0.3311160206794739


loss: 1.3232972621917725: 100%|██████████| 127/127 [00:44<00:00,  2.88it/s]


[epoch: 16] loss: 1.599,val_acc: 0.43976348638534546


loss: 1.469947338104248: 100%|██████████| 127/127 [00:44<00:00,  2.88it/s] 


[epoch: 17] loss: 1.529,val_acc: 0.421286016702652


loss: 1.8836394548416138: 100%|██████████| 127/127 [00:44<00:00,  2.88it/s]


[epoch: 18] loss: 1.496,val_acc: 0.37694013118743896


loss: 1.3153098821640015: 100%|██████████| 127/127 [00:44<00:00,  2.88it/s]


[epoch: 19] loss: 1.497,val_acc: 0.4693274199962616


loss: 1.213143229484558: 100%|██████████| 127/127 [00:44<00:00,  2.88it/s] 


[epoch: 20] loss: 1.492,val_acc: 0.4663710296154022


loss: 1.6989365816116333: 100%|██████████| 127/127 [00:44<00:00,  2.88it/s]


[epoch: 21] loss: 1.443,val_acc: 0.4286770224571228


loss: 1.2776758670806885: 100%|██████████| 127/127 [00:44<00:00,  2.88it/s]


[epoch: 22] loss: 1.448,val_acc: 0.47524020075798035


loss: 1.3902928829193115: 100%|██████████| 127/127 [00:44<00:00,  2.88it/s]


[epoch: 23] loss: 1.398,val_acc: 0.5114560127258301


loss: 1.601660966873169: 100%|██████████| 127/127 [00:44<00:00,  2.88it/s] 


[epoch: 24] loss: 1.376,val_acc: 0.4951958656311035


loss: 1.3210678100585938: 100%|██████████| 127/127 [00:44<00:00,  2.88it/s]


[epoch: 25] loss: 1.343,val_acc: 0.5269770622253418


loss: 2.1912119388580322: 100%|██████████| 127/127 [00:37<00:00,  3.36it/s]


[epoch: 26] loss: 1.317,val_acc: 0.26237988471984863


loss: 1.488350510597229: 100%|██████████| 127/127 [00:29<00:00,  4.31it/s] 


[epoch: 27] loss: 1.365,val_acc: 0.5351071357727051


loss: 1.1443825960159302: 100%|██████████| 127/127 [00:29<00:00,  4.31it/s]


[epoch: 28] loss: 1.290,val_acc: 0.486326664686203


loss: 0.75357586145401: 100%|██████████| 127/127 [00:29<00:00,  4.31it/s]  


[epoch: 29] loss: 1.259,val_acc: 0.4589800536632538


loss: 1.608860969543457: 100%|██████████| 127/127 [00:29<00:00,  4.31it/s] 


[epoch: 30] loss: 1.257,val_acc: 0.4611973464488983


loss: 0.8235606551170349: 100%|██████████| 127/127 [00:29<00:00,  4.31it/s]


[epoch: 31] loss: 1.242,val_acc: 0.5321507453918457


loss: 1.131270408630371: 100%|██████████| 127/127 [00:29<00:00,  4.31it/s] 


[epoch: 32] loss: 1.197,val_acc: 0.5084996223449707


loss: 1.204642415046692: 100%|██████████| 127/127 [00:29<00:00,  4.31it/s] 


[epoch: 33] loss: 1.229,val_acc: 0.5402808785438538


loss: 1.2758979797363281: 100%|██████████| 127/127 [00:29<00:00,  4.30it/s]


[epoch: 34] loss: 1.169,val_acc: 0.5469327569007874


loss: 1.1880745887756348: 100%|██████████| 127/127 [00:29<00:00,  4.31it/s]


[epoch: 35] loss: 1.114,val_acc: 0.5018477439880371


loss: 0.8872246742248535: 100%|██████████| 127/127 [00:29<00:00,  4.31it/s]


[epoch: 36] loss: 1.140,val_acc: 0.6031042337417603


loss: 0.9183523654937744: 100%|██████████| 127/127 [00:29<00:00,  4.32it/s]


[epoch: 37] loss: 1.097,val_acc: 0.5986695885658264


loss: 1.1007908582687378: 100%|██████████| 127/127 [00:29<00:00,  4.31it/s]


[epoch: 38] loss: 1.062,val_acc: 0.6341463327407837


loss: 0.9788706302642822: 100%|██████████| 127/127 [00:29<00:00,  4.31it/s]


[epoch: 39] loss: 1.080,val_acc: 0.5602365136146545


loss: 1.0314724445343018: 100%|██████████| 127/127 [00:29<00:00,  4.31it/s]


[epoch: 40] loss: 1.058,val_acc: 0.5779748558998108


loss: 1.1268296241760254: 100%|██████████| 127/127 [00:29<00:00,  4.31it/s]


[epoch: 41] loss: 1.029,val_acc: 0.6045824289321899


loss: 0.8164306282997131: 100%|██████████| 127/127 [00:29<00:00,  4.31it/s]


[epoch: 42] loss: 1.009,val_acc: 0.5905395150184631


loss: 1.5322171449661255: 100%|██████████| 127/127 [00:29<00:00,  4.31it/s]


[epoch: 43] loss: 0.992,val_acc: 0.4981522560119629


loss: 0.8974185585975647: 100%|██████████| 127/127 [00:29<00:00,  4.31it/s]


[epoch: 44] loss: 0.998,val_acc: 0.6171470880508423


loss: 0.9488456845283508: 100%|██████████| 127/127 [00:29<00:00,  4.32it/s]


[epoch: 45] loss: 0.984,val_acc: 0.6526237726211548


loss: 1.1092253923416138: 100%|██████████| 127/127 [00:29<00:00,  4.31it/s]


[epoch: 46] loss: 0.958,val_acc: 0.6348854303359985


loss: 1.076895833015442: 100%|██████████| 127/127 [00:29<00:00,  4.31it/s] 


[epoch: 47] loss: 0.927,val_acc: 0.6407982110977173


loss: 0.6418919563293457: 100%|██████████| 127/127 [00:29<00:00,  4.31it/s] 


[epoch: 48] loss: 0.898,val_acc: 0.656319260597229


loss: 0.621671736240387: 100%|██████████| 127/127 [00:29<00:00,  4.30it/s] 


[epoch: 49] loss: 0.904,val_acc: 0.6274944543838501


loss: 0.6904114484786987: 100%|██████████| 127/127 [00:29<00:00,  4.29it/s]


[epoch: 50] loss: 0.907,val_acc: 0.6555801630020142


loss: 0.8682537078857422: 100%|██████████| 127/127 [00:29<00:00,  4.29it/s]


[epoch: 51] loss: 0.862,val_acc: 0.6548410654067993


loss: 1.1762701272964478: 100%|██████████| 127/127 [00:29<00:00,  4.31it/s]


[epoch: 52] loss: 0.857,val_acc: 0.6873614192008972


loss: 0.5025816559791565: 100%|██████████| 127/127 [00:29<00:00,  4.31it/s] 


[epoch: 53] loss: 0.826,val_acc: 0.6518846750259399


loss: 1.1915452480316162: 100%|██████████| 127/127 [00:29<00:00,  4.30it/s]


[epoch: 54] loss: 0.838,val_acc: 0.656319260597229


loss: 0.7575940489768982: 100%|██████████| 127/127 [00:29<00:00,  4.31it/s]


[epoch: 55] loss: 0.802,val_acc: 0.6629711985588074


loss: 0.8618336915969849: 100%|██████████| 127/127 [00:29<00:00,  4.30it/s]


[epoch: 56] loss: 0.810,val_acc: 0.6888396143913269


loss: 1.0445371866226196: 100%|██████████| 127/127 [00:29<00:00,  4.31it/s] 


[epoch: 57] loss: 0.764,val_acc: 0.5498891472816467


loss: 0.5157267451286316: 100%|██████████| 127/127 [00:29<00:00,  4.29it/s]


[epoch: 58] loss: 0.770,val_acc: 0.669623076915741


loss: 0.8633652925491333: 100%|██████████| 127/127 [00:29<00:00,  4.30it/s]


[epoch: 59] loss: 0.774,val_acc: 0.656319260597229


loss: 0.6685047745704651: 100%|██████████| 127/127 [00:29<00:00,  4.31it/s]


[epoch: 60] loss: 0.768,val_acc: 0.6489282846450806


loss: 0.6072579622268677: 100%|██████████| 127/127 [00:29<00:00,  4.31it/s] 


[epoch: 61] loss: 0.728,val_acc: 0.6954914927482605


loss: 0.8931350111961365: 100%|██████████| 127/127 [00:29<00:00,  4.32it/s] 


[epoch: 62] loss: 0.726,val_acc: 0.6629711985588074


loss: 0.7052773833274841: 100%|██████████| 127/127 [00:29<00:00,  4.31it/s] 


[epoch: 63] loss: 0.720,val_acc: 0.5875831246376038


loss: 0.7680718302726746: 100%|██████████| 127/127 [00:29<00:00,  4.31it/s] 


[epoch: 64] loss: 0.733,val_acc: 0.7021433711051941


loss: 0.8293414115905762: 100%|██████████| 127/127 [00:29<00:00,  4.31it/s] 


[epoch: 65] loss: 0.699,val_acc: 0.6659275889396667


loss: 0.514920175075531: 100%|██████████| 127/127 [00:29<00:00,  4.31it/s]  


[epoch: 66] loss: 0.716,val_acc: 0.7280118465423584


loss: 0.4870631694793701: 100%|██████████| 127/127 [00:29<00:00,  4.31it/s] 


[epoch: 67] loss: 0.667,val_acc: 0.6866223216056824


loss: 0.4743445813655853: 100%|██████████| 127/127 [00:29<00:00,  4.31it/s] 


[epoch: 68] loss: 0.671,val_acc: 0.693274199962616


loss: 1.1278188228607178: 100%|██████████| 127/127 [00:29<00:00,  4.31it/s] 


[epoch: 69] loss: 0.650,val_acc: 0.7065779566764832


loss: 0.7856694459915161: 100%|██████████| 127/127 [00:29<00:00,  4.31it/s]


[epoch: 70] loss: 0.661,val_acc: 0.707317054271698


loss: 0.7832642793655396: 100%|██████████| 127/127 [00:29<00:00,  4.31it/s] 


[epoch: 71] loss: 0.632,val_acc: 0.7028824687004089


loss: 0.372341513633728: 100%|██████████| 127/127 [00:29<00:00,  4.31it/s]  


[epoch: 72] loss: 0.630,val_acc: 0.7198817133903503


loss: 0.5487894415855408: 100%|██████████| 127/127 [00:29<00:00,  4.31it/s] 


[epoch: 73] loss: 0.628,val_acc: 0.6984478831291199


loss: 0.4328795075416565: 100%|██████████| 127/127 [00:29<00:00,  4.32it/s] 


[epoch: 74] loss: 0.602,val_acc: 0.7324464321136475


loss: 0.6365184187889099: 100%|██████████| 127/127 [00:29<00:00,  4.31it/s] 


[epoch: 75] loss: 0.596,val_acc: 0.7154471278190613


loss: 0.37710970640182495: 100%|██████████| 127/127 [00:29<00:00,  4.31it/s]


[epoch: 76] loss: 0.571,val_acc: 0.7243162989616394


loss: 0.3772864043712616: 100%|██████████| 127/127 [00:29<00:00,  4.31it/s] 


[epoch: 77] loss: 0.581,val_acc: 0.7420547008514404


loss: 0.6282452940940857: 100%|██████████| 127/127 [00:29<00:00,  4.31it/s] 


[epoch: 78] loss: 0.575,val_acc: 0.7206208109855652


loss: 0.3360954225063324: 100%|██████████| 127/127 [00:29<00:00,  4.31it/s] 


[epoch: 79] loss: 0.578,val_acc: 0.7124907374382019


loss: 0.6102911829948425: 100%|██████████| 127/127 [00:29<00:00,  4.31it/s] 


[epoch: 80] loss: 0.532,val_acc: 0.6917960047721863


loss: 0.45665568113327026: 100%|██████████| 127/127 [00:29<00:00,  4.32it/s]


[epoch: 81] loss: 0.559,val_acc: 0.7280118465423584


loss: 0.4871194064617157: 100%|██████████| 127/127 [00:29<00:00,  4.31it/s] 


[epoch: 82] loss: 0.554,val_acc: 0.7368810176849365


loss: 0.5877224206924438: 100%|██████████| 127/127 [00:29<00:00,  4.31it/s] 


[epoch: 83] loss: 0.513,val_acc: 0.72135990858078


loss: 0.4167799651622772: 100%|██████████| 127/127 [00:29<00:00,  4.31it/s] 


[epoch: 84] loss: 0.513,val_acc: 0.6888396143913269


loss: 0.4040306806564331: 100%|██████████| 127/127 [00:29<00:00,  4.31it/s] 


[epoch: 85] loss: 0.504,val_acc: 0.7198817133903503


loss: 0.3289484679698944: 100%|██████████| 127/127 [00:29<00:00,  4.32it/s] 


[epoch: 86] loss: 0.495,val_acc: 0.7147080302238464


loss: 0.7148594856262207: 100%|██████████| 127/127 [00:29<00:00,  4.31it/s] 


[epoch: 87] loss: 0.526,val_acc: 0.7102734446525574


loss: 0.1546916961669922: 100%|██████████| 127/127 [00:29<00:00,  4.31it/s] 


[epoch: 88] loss: 0.484,val_acc: 0.7383592128753662


loss: 0.5774920582771301: 100%|██████████| 127/127 [00:29<00:00,  4.31it/s] 


[epoch: 89] loss: 0.479,val_acc: 0.7257945537567139


loss: 0.43075770139694214: 100%|██████████| 127/127 [00:29<00:00,  4.31it/s]


[epoch: 90] loss: 0.486,val_acc: 0.7413156032562256


loss: 0.45061179995536804: 100%|██████████| 127/127 [00:29<00:00,  4.32it/s]


[epoch: 91] loss: 0.455,val_acc: 0.6674057841300964


loss: 0.5980983376502991: 100%|██████████| 127/127 [00:29<00:00,  4.31it/s] 


[epoch: 92] loss: 0.458,val_acc: 0.7110125422477722


loss: 0.309550017118454: 100%|██████████| 127/127 [00:29<00:00,  4.32it/s]  


[epoch: 93] loss: 0.438,val_acc: 0.6903178095817566


loss: 0.3776136636734009: 100%|██████████| 127/127 [00:29<00:00,  4.32it/s] 


[epoch: 94] loss: 0.450,val_acc: 0.7309682369232178


loss: 0.5130547285079956: 100%|██████████| 127/127 [00:29<00:00,  4.29it/s] 


[epoch: 95] loss: 0.427,val_acc: 0.7191426157951355


loss: 0.5710451006889343: 100%|██████████| 127/127 [00:29<00:00,  4.31it/s] 


[epoch: 96] loss: 0.436,val_acc: 0.7250553965568542


loss: 0.13467194139957428: 100%|██████████| 127/127 [00:29<00:00,  4.32it/s]


[epoch: 97] loss: 0.426,val_acc: 0.7235772013664246


loss: 0.3021523058414459: 100%|██████████| 127/127 [00:29<00:00,  4.31it/s] 


[epoch: 98] loss: 0.412,val_acc: 0.7361419200897217


loss: 0.6714264750480652: 100%|██████████| 127/127 [00:29<00:00,  4.31it/s] 


[epoch: 99] loss: 0.414,val_acc: 0.7095343470573425


loss: 0.5220749974250793: 100%|██████████| 127/127 [00:29<00:00,  4.31it/s] 


[epoch: 100] loss: 0.442,val_acc: 0.7243162989616394


loss: 0.8304249048233032: 100%|██████████| 127/127 [00:29<00:00,  4.31it/s] 


[epoch: 101] loss: 0.380,val_acc: 0.7235772013664246


loss: 0.12575307488441467: 100%|██████████| 127/127 [00:29<00:00,  4.30it/s]


[epoch: 102] loss: 0.393,val_acc: 0.758314847946167


loss: 0.1734626442193985: 100%|██████████| 127/127 [00:29<00:00,  4.29it/s] 


[epoch: 103] loss: 0.377,val_acc: 0.772357702255249


loss: 0.2945196330547333: 100%|██████████| 127/127 [00:29<00:00,  4.30it/s] 


[epoch: 104] loss: 0.396,val_acc: 0.7280118465423584


loss: 0.1339768022298813: 100%|██████████| 127/127 [00:29<00:00,  4.30it/s] 


[epoch: 105] loss: 0.404,val_acc: 0.7450110912322998


loss: 0.45750635862350464: 100%|██████████| 127/127 [00:29<00:00,  4.31it/s]


[epoch: 106] loss: 0.357,val_acc: 0.762749433517456


loss: 0.3816896080970764: 100%|██████████| 127/127 [00:29<00:00,  4.30it/s] 


[epoch: 107] loss: 0.360,val_acc: 0.7110125422477722


loss: 0.25044429302215576: 100%|██████████| 127/127 [00:29<00:00,  4.31it/s]


[epoch: 108] loss: 0.380,val_acc: 0.7620103359222412


loss: 0.5893555879592896: 100%|██████████| 127/127 [00:29<00:00,  4.32it/s] 


[epoch: 109] loss: 0.349,val_acc: 0.7265336513519287


loss: 0.3178167939186096: 100%|██████████| 127/127 [00:29<00:00,  4.32it/s] 


[epoch: 110] loss: 0.342,val_acc: 0.7494456768035889


loss: 0.24291256070137024: 100%|██████████| 127/127 [00:29<00:00,  4.31it/s]


[epoch: 111] loss: 0.364,val_acc: 0.6755358576774597


loss: 0.42987731099128723: 100%|██████████| 127/127 [00:29<00:00,  4.31it/s]


[epoch: 112] loss: 0.343,val_acc: 0.7634885311126709


loss: 0.3357078731060028: 100%|██████████| 127/127 [00:29<00:00,  4.29it/s] 


[epoch: 113] loss: 0.330,val_acc: 0.748706579208374


loss: 0.36476248502731323: 100%|██████████| 127/127 [00:29<00:00,  4.31it/s]


[epoch: 114] loss: 0.336,val_acc: 0.7649667263031006


loss: 0.24593617022037506: 100%|██████████| 127/127 [00:29<00:00,  4.30it/s]


[epoch: 115] loss: 0.356,val_acc: 0.758314847946167


loss: 0.1708938032388687: 100%|██████████| 127/127 [00:29<00:00,  4.31it/s] 


[epoch: 116] loss: 0.325,val_acc: 0.7730967998504639


loss: 0.19841943681240082: 100%|██████████| 127/127 [00:29<00:00,  4.31it/s]


[epoch: 117] loss: 0.337,val_acc: 0.7250553965568542


loss: 0.615487813949585: 100%|██████████| 127/127 [00:29<00:00,  4.31it/s]  


[epoch: 118] loss: 0.311,val_acc: 0.683665931224823


loss: 0.8035545349121094: 100%|██████████| 127/127 [00:29<00:00,  4.31it/s] 


[epoch: 119] loss: 0.330,val_acc: 0.6481891870498657


loss: 0.550852358341217: 100%|██████████| 127/127 [00:29<00:00,  4.31it/s]  


[epoch: 120] loss: 0.315,val_acc: 0.7287509441375732


loss: 0.1844104528427124: 100%|██████████| 127/127 [00:29<00:00,  4.31it/s] 


[epoch: 121] loss: 0.291,val_acc: 0.7753140926361084


loss: 0.3114470839500427: 100%|██████████| 127/127 [00:29<00:00,  4.30it/s] 


[epoch: 122] loss: 0.315,val_acc: 0.7413156032562256


loss: 0.2671585977077484: 100%|██████████| 127/127 [00:29<00:00,  4.31it/s] 


[epoch: 123] loss: 0.302,val_acc: 0.7701404094696045


loss: 0.5575887560844421: 100%|██████████| 127/127 [00:29<00:00,  4.31it/s] 


[epoch: 124] loss: 0.299,val_acc: 0.7095343470573425


loss: 0.1821317970752716: 100%|██████████| 127/127 [00:29<00:00,  4.30it/s] 


[epoch: 125] loss: 0.284,val_acc: 0.7176644206047058


loss: 0.1944882720708847: 100%|██████████| 127/127 [00:29<00:00,  4.30it/s] 


[epoch: 126] loss: 0.289,val_acc: 0.7708795070648193


loss: 0.3900262713432312: 100%|██████████| 127/127 [00:29<00:00,  4.31it/s] 


[epoch: 127] loss: 0.290,val_acc: 0.7745749950408936


loss: 0.27623236179351807: 100%|██████████| 127/127 [00:29<00:00,  4.31it/s]


[epoch: 128] loss: 0.287,val_acc: 0.762749433517456


loss: 0.24120302498340607: 100%|██████████| 127/127 [00:29<00:00,  4.32it/s]


[epoch: 129] loss: 0.286,val_acc: 0.7701404094696045


loss: 0.6718526482582092: 100%|██████████| 127/127 [00:29<00:00,  4.32it/s] 


[epoch: 130] loss: 0.277,val_acc: 0.7324464321136475


loss: 0.4625362455844879: 100%|██████████| 127/127 [00:29<00:00,  4.31it/s] 


[epoch: 131] loss: 0.264,val_acc: 0.7250553965568542


loss: 0.2785458266735077: 100%|██████████| 127/127 [00:29<00:00,  4.31it/s] 


[epoch: 132] loss: 0.305,val_acc: 0.7657058238983154


loss: 0.2354404479265213: 100%|██████████| 127/127 [00:29<00:00,  4.31it/s] 


[epoch: 133] loss: 0.246,val_acc: 0.7834441661834717


loss: 0.1812865287065506: 100%|██████████| 127/127 [00:29<00:00,  4.31it/s] 


[epoch: 134] loss: 0.247,val_acc: 0.7538802623748779


loss: 0.26753297448158264: 100%|██████████| 127/127 [00:29<00:00,  4.31it/s]


[epoch: 135] loss: 0.272,val_acc: 0.7590539455413818


loss: 0.11356991529464722: 100%|██████████| 127/127 [00:29<00:00,  4.31it/s]


[epoch: 136] loss: 0.329,val_acc: 0.7812268733978271


loss: 0.2873603403568268: 100%|██████████| 127/127 [00:29<00:00,  4.32it/s] 


[epoch: 137] loss: 0.255,val_acc: 0.7597930431365967


loss: 0.17373937368392944: 100%|██████████| 127/127 [00:29<00:00,  4.31it/s]


[epoch: 138] loss: 0.238,val_acc: 0.7043606638908386


loss: 0.19773760437965393: 100%|██████████| 127/127 [00:29<00:00,  4.32it/s] 


[epoch: 139] loss: 0.242,val_acc: 0.7538802623748779


loss: 0.1611431986093521: 100%|██████████| 127/127 [00:29<00:00,  4.31it/s]  


[epoch: 140] loss: 0.245,val_acc: 0.758314847946167


loss: 0.0999177023768425: 100%|██████████| 127/127 [00:29<00:00,  4.31it/s]  


[epoch: 141] loss: 0.262,val_acc: 0.7716186046600342


loss: 0.30428239703178406: 100%|██████████| 127/127 [00:29<00:00,  4.31it/s]


[epoch: 142] loss: 0.222,val_acc: 0.7686622142791748


loss: 0.3076491057872772: 100%|██████████| 127/127 [00:29<00:00,  4.31it/s] 


[epoch: 143] loss: 0.226,val_acc: 0.7405765056610107


loss: 0.21917517483234406: 100%|██████████| 127/127 [00:29<00:00,  4.31it/s]


[epoch: 144] loss: 0.261,val_acc: 0.7597930431365967


loss: 0.17988517880439758: 100%|██████████| 127/127 [00:29<00:00,  4.31it/s]


[epoch: 145] loss: 0.255,val_acc: 0.7708795070648193


loss: 0.2809741795063019: 100%|██████████| 127/127 [00:29<00:00,  4.30it/s] 


[epoch: 146] loss: 0.225,val_acc: 0.758314847946167


loss: 0.24407655000686646: 100%|██████████| 127/127 [00:29<00:00,  4.30it/s]


[epoch: 147] loss: 0.219,val_acc: 0.762749433517456


loss: 0.4789600074291229: 100%|██████████| 127/127 [00:29<00:00,  4.31it/s] 


[epoch: 148] loss: 0.221,val_acc: 0.693274199962616


loss: 0.2625228464603424: 100%|██████████| 127/127 [00:29<00:00,  4.31it/s]  


[epoch: 149] loss: 0.226,val_acc: 0.7664449214935303


loss: 0.1672375649213791: 100%|██████████| 127/127 [00:29<00:00,  4.31it/s] 


[epoch: 150] loss: 0.217,val_acc: 0.7184035181999207


loss: 0.24926511943340302: 100%|██████████| 127/127 [00:29<00:00,  4.31it/s]


[epoch: 151] loss: 0.218,val_acc: 0.7790095806121826


loss: 0.0516202375292778: 100%|██████████| 127/127 [00:29<00:00,  4.31it/s] 


[epoch: 152] loss: 0.232,val_acc: 0.7782704830169678


loss: 0.2962280213832855: 100%|██████████| 127/127 [00:29<00:00,  4.30it/s] 


[epoch: 153] loss: 0.194,val_acc: 0.748706579208374


loss: 0.3733045160770416: 100%|██████████| 127/127 [00:29<00:00,  4.31it/s]  


[epoch: 154] loss: 0.205,val_acc: 0.7568366527557373


loss: 0.6728817224502563: 100%|██████████| 127/127 [00:29<00:00,  4.31it/s] 


[epoch: 155] loss: 0.220,val_acc: 0.7250553965568542


loss: 0.24355605244636536: 100%|██████████| 127/127 [00:29<00:00,  4.30it/s]


[epoch: 156] loss: 0.231,val_acc: 0.7546193599700928


loss: 0.3389817476272583: 100%|██████████| 127/127 [00:29<00:00,  4.30it/s]  


[epoch: 157] loss: 0.228,val_acc: 0.7649667263031006


loss: 0.12448137998580933: 100%|██████████| 127/127 [00:29<00:00,  4.31it/s]


[epoch: 158] loss: 0.188,val_acc: 0.7575757503509521


loss: 0.30755603313446045: 100%|██████████| 127/127 [00:29<00:00,  4.30it/s]


[epoch: 159] loss: 0.192,val_acc: 0.7472283840179443


loss: 0.0511028952896595: 100%|██████████| 127/127 [00:29<00:00,  4.30it/s] 


[epoch: 160] loss: 0.199,val_acc: 0.7494456768035889


loss: 0.32510510087013245: 100%|██████████| 127/127 [00:29<00:00,  4.32it/s]


[epoch: 161] loss: 0.191,val_acc: 0.7730967998504639


loss: 0.23809105157852173: 100%|██████████| 127/127 [00:29<00:00,  4.31it/s]


[epoch: 162] loss: 0.188,val_acc: 0.7368810176849365


loss: 0.07908106595277786: 100%|██████████| 127/127 [00:29<00:00,  4.31it/s]


[epoch: 163] loss: 0.228,val_acc: 0.7694013118743896


loss: 0.18095384538173676: 100%|██████████| 127/127 [00:29<00:00,  4.31it/s] 


[epoch: 164] loss: 0.177,val_acc: 0.7568366527557373


loss: 0.32454055547714233: 100%|██████████| 127/127 [00:29<00:00,  4.31it/s]


[epoch: 165] loss: 0.218,val_acc: 0.7568366527557373


loss: 0.12415608018636703: 100%|██████████| 127/127 [00:29<00:00,  4.30it/s] 


[epoch: 166] loss: 0.192,val_acc: 0.7753140926361084


loss: 0.09278431534767151: 100%|██████████| 127/127 [00:29<00:00,  4.31it/s] 


[epoch: 167] loss: 0.181,val_acc: 0.7886179089546204


loss: 0.27493366599082947: 100%|██████████| 127/127 [00:29<00:00,  4.30it/s]


[epoch: 168] loss: 0.204,val_acc: 0.7753140926361084


loss: 0.179804265499115: 100%|██████████| 127/127 [00:29<00:00,  4.30it/s]   


[epoch: 169] loss: 0.179,val_acc: 0.7642276287078857


loss: 0.05818232521414757: 100%|██████████| 127/127 [00:29<00:00,  4.31it/s] 


[epoch: 170] loss: 0.173,val_acc: 0.7708795070648193


loss: 0.039275627583265305: 100%|██████████| 127/127 [00:29<00:00,  4.31it/s]


[epoch: 171] loss: 0.171,val_acc: 0.7804877758026123


loss: 0.05047263950109482: 100%|██████████| 127/127 [00:29<00:00,  4.31it/s] 


[epoch: 172] loss: 0.211,val_acc: 0.7812268733978271


loss: 0.058536138385534286: 100%|██████████| 127/127 [00:29<00:00,  4.31it/s]


[epoch: 173] loss: 0.191,val_acc: 0.7775313854217529


loss: 0.23558984696865082: 100%|██████████| 127/127 [00:29<00:00,  4.31it/s] 


[epoch: 174] loss: 0.173,val_acc: 0.7560975551605225


loss: 0.09724408388137817: 100%|██████████| 127/127 [00:29<00:00,  4.31it/s] 


[epoch: 175] loss: 0.167,val_acc: 0.7560975551605225


loss: 0.07499467581510544: 100%|██████████| 127/127 [00:29<00:00,  4.31it/s]


[epoch: 176] loss: 0.164,val_acc: 0.7708795070648193


loss: 0.231010302901268: 100%|██████████| 127/127 [00:29<00:00,  4.31it/s]   


[epoch: 177] loss: 0.174,val_acc: 0.7095343470573425


loss: 0.04102596640586853: 100%|██████████| 127/127 [00:29<00:00,  4.30it/s] 


[epoch: 178] loss: 0.184,val_acc: 0.7893570065498352


loss: 0.10556390881538391: 100%|██████████| 127/127 [00:29<00:00,  4.32it/s]


[epoch: 179] loss: 0.172,val_acc: 0.7649667263031006


loss: 0.028458217158913612: 100%|██████████| 127/127 [00:29<00:00,  4.31it/s]


[epoch: 180] loss: 0.162,val_acc: 0.7886179089546204


loss: 0.03326983377337456: 100%|██████████| 127/127 [00:29<00:00,  4.31it/s]


[epoch: 181] loss: 0.171,val_acc: 0.786400556564331


loss: 0.17352598905563354: 100%|██████████| 127/127 [00:29<00:00,  4.31it/s] 


[epoch: 182] loss: 0.152,val_acc: 0.7982261776924133


loss: 0.2966863512992859: 100%|██████████| 127/127 [00:29<00:00,  4.31it/s]  


[epoch: 183] loss: 0.152,val_acc: 0.7797486782073975


loss: 0.08470361679792404: 100%|██████████| 127/127 [00:29<00:00,  4.31it/s]


[epoch: 184] loss: 0.145,val_acc: 0.7708795070648193


loss: 0.4851442277431488: 100%|██████████| 127/127 [00:29<00:00,  4.31it/s] 


[epoch: 185] loss: 0.232,val_acc: 0.7597930431365967


loss: 0.11475766450166702: 100%|██████████| 127/127 [00:29<00:00,  4.30it/s] 


[epoch: 186] loss: 0.158,val_acc: 0.7597930431365967


loss: 0.25241199135780334: 100%|██████████| 127/127 [00:29<00:00,  4.31it/s]


[epoch: 187] loss: 0.180,val_acc: 0.762749433517456


loss: 0.3528895676136017: 100%|██████████| 127/127 [00:29<00:00,  4.31it/s] 


[epoch: 188] loss: 0.150,val_acc: 0.7804877758026123


loss: 0.06152652949094772: 100%|██████████| 127/127 [00:29<00:00,  4.31it/s] 


[epoch: 189] loss: 0.142,val_acc: 0.7856614589691162


loss: 0.2688770890235901: 100%|██████████| 127/127 [00:29<00:00,  4.30it/s]  


[epoch: 190] loss: 0.152,val_acc: 0.7760531902313232


loss: 0.19273331761360168: 100%|██████████| 127/127 [00:29<00:00,  4.31it/s]


[epoch: 191] loss: 0.158,val_acc: 0.7701404094696045


loss: 0.08519959449768066: 100%|██████████| 127/127 [00:29<00:00,  4.31it/s] 


[epoch: 192] loss: 0.134,val_acc: 0.7642276287078857


loss: 0.04195187985897064: 100%|██████████| 127/127 [00:29<00:00,  4.30it/s] 


[epoch: 193] loss: 0.157,val_acc: 0.7790095806121826


loss: 0.03276229277253151: 100%|██████████| 127/127 [00:29<00:00,  4.31it/s] 


[epoch: 194] loss: 0.151,val_acc: 0.7538802623748779


loss: 0.20792549848556519: 100%|██████████| 127/127 [00:29<00:00,  4.31it/s]


[epoch: 195] loss: 0.153,val_acc: 0.7509238719940186


loss: 0.1881362348794937: 100%|██████████| 127/127 [00:29<00:00,  4.31it/s]  


[epoch: 196] loss: 0.141,val_acc: 0.7701404094696045


loss: 0.061849165707826614: 100%|██████████| 127/127 [00:29<00:00,  4.31it/s]


[epoch: 197] loss: 0.118,val_acc: 0.7804877758026123


loss: 0.30577489733695984: 100%|██████████| 127/127 [00:29<00:00,  4.31it/s] 


[epoch: 198] loss: 0.156,val_acc: 0.7708795070648193


loss: 0.19176651537418365: 100%|██████████| 127/127 [00:29<00:00,  4.31it/s] 


[epoch: 199] loss: 0.133,val_acc: 0.7753140926361084


loss: 0.07956047356128693: 100%|██████████| 127/127 [00:29<00:00,  4.31it/s] 


[epoch: 200] loss: 0.127,val_acc: 0.7716186046600342
Finished Training


## Save/Load model.

In [5]:
model_dir = os.path.join(os.getcwd(), 'model-100', 'cnn-new-200e')
# if not os.path.exists(model_dir):
    # os.makedirs(model_dir)
# torch.save(best_model, os.path.join(model_dir, 'model.pth'))
best_model = torch.load(os.path.join(model_dir, 'model.pth'))

## Validate the models.

In [10]:
print(f"Accuracy on validation data: {validation(best_model, test_loader)}")



Accuracy on validation data: 0.8854397535324097
