In [7]:
from torchvision import transforms
import numpy as np
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import datasets
from torchvision.transforms import ToTensor
import matplotlib.pyplot as plt
import os
from PIL import Image # import PIL library

In [8]:
# Define a custom dataset class that inherits from torch.utils.data.Dataset
class CustomDataset(Dataset):
    def __init__(self, root):
        # Root is the directory that contains all the class folders
        self.root = root
        self.folders_paths = [os.path.join(root, folder) for folder in os.listdir(root)]
        self.classes = [folder.split("/")[-1] for folder in self.folders_paths]
        self.files = []
        for label, folder_path in enumerate(self.folders_paths):
            for file_name in os.listdir(folder_path):
                file_path = os.path.join(folder_path, file_name)
                self.files.append((file_path, label))

    def __getitem__(self, idx):
        # Return a sample and its corresponding label
        file_path, label = self.files[idx]
        # Open the image using PIL
        image = Image.open(file_path).convert('RGB')
        # Note: This returns an image in the form of a PIL object
        return image, label

    def __len__(self):
        # Return the number of samples
        return len(self.files)

    
# Create an instance of the custom dataset with a given root directory
root_dir = "final_data"
data = CustomDataset(root_dir)

# Split the dataset into train (60%), test (20%) and validation (20%) sets
train_size = int(0.6 * len(data))
test_size = int(0.2 * len(data))
val_size = len(data) - train_size - test_size
train_data, test_data, val_data = random_split(data, [train_size, test_size, val_size])

# Define different transformations for each dataset
train_transform  =   transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.1, hue=0.1),
        transforms.RandomAffine(degrees=40, translate=None, scale=(1, 2), shear=15),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
])

transform_test = transforms.Compose([
    transforms.Resize((224,224)),
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
])

# Create a wrapper class to apply the transformations on the fly
class TransformDataset(Dataset):
    def __init__(self, subset, transform=None):
        self.subset = subset
        self.transform = transform

    def __getitem__(self, index):
        x, y = self.subset[index]
        if self.transform:
            x = self.transform(x)
        return x, y

    def __len__(self):
        return len(self.subset)

# Create transformed datasets using the wrapper class
train_data = TransformDataset(train_data, transform=train_transform)
test_data = TransformDataset(test_data, transform=transform_test)
val_data = TransformDataset(val_data, transform=transform_test)

# Create data loaders for each dataset with a given batch size
batch_size = 32
train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_data, batch_size=batch_size, shuffle=True)



### Defining the architecture for the model

For this project, our team considered the use of multiple architectures, including VGG, ResNet, and DenseNet. In the end, we decided to use ResNet-18 and submit this model as it performs better in our experiments.

In [9]:
import torch
torch.cuda.is_available()

True

In [10]:
# TODO: Build and train your network
from torchvision.models import resnet18
from torch import nn
import torch.optim as optim
import torch
import numpy as np
from tqdm import tqdm
device = "cuda"

# Additional import
import torch.nn.functional as F

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Add model here
#device = torch.device("cuda")
class model(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.fc1 = nn.Linear(64 * 56 * 56, 256)  # Update the input size for fc1
        self.fc2 = nn.Linear(256, 10)  # Update the output size for fc2

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

print(model)

# for param in model.features:
    # param.requires_grad = True
#model().to(device)

net = model().to(device)



criterion = nn.CrossEntropyLoss()
optimiser = optim.SGD(net.parameters(), lr=0.05)

<class '__main__.model'>


### Train the network

In [11]:
def validation(net,dl_test):
    net.eval()
    correct=0
    total = 0
    for i, data in enumerate(dl_test, 0):
        # Get data from each batch. Data is in the format of [inputs, labels]
        inputs, labels = data
        inputs = inputs.cuda()
        labels = labels.cuda()

        # Reset the parameter gradients
        optimiser.zero_grad()

        # Forward pass, backward pass, and optimise
        outputs = net(inputs)
        
        correct += (torch.argmax(outputs,-1) == labels).float().sum()
        total += len(labels)

    return correct/total

In [12]:
import copy
from tqdm import tqdm 


max_accuracy = 0
for epoch in range(200):  # Train for 200 epochs
    net.train()
    running_loss = 0.0
    train_acc = 0.0
    correct = 0
    pbar = tqdm(train_loader)

    for i, data in enumerate(pbar, 0):
        # Get data from each batch. Data is in the format of [inputs, labels]
        inputs, labels = data
        inputs = inputs.cuda()
        labels = labels.cuda()

        # Reset the parameter gradients
        optimiser.zero_grad()

        # Forward pass, backward pass, and optimise
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimiser.step()

        # Print 
        pbar.set_description(f"loss: {str(loss.item())}")

        running_loss += loss.item()
        correct += (torch.argmax(outputs,-1) == labels).float().sum()
    
    train_acc = correct/len(train_data)
    validation_acc = validation(net,val_loader)

    print(f'[Epoch: {epoch + 1}] train_loss: {running_loss / 127:.3f}, train_acc: {train_acc}, val_acc: {validation_acc}')

    if  validation_acc > max_accuracy:
        max_accuracy = validation_acc
        best_model = copy.deepcopy(net)
print('Finished Training')

loss: 1.9419134855270386: 100%|██████████| 127/127 [00:31<00:00,  4.03it/s]


[Epoch: 1] train_loss: 2.075, train_acc: 0.21458488702774048, val_acc: 0.2535107135772705


loss: 1.728811502456665: 100%|██████████| 127/127 [00:34<00:00,  3.63it/s] 


[Epoch: 2] train_loss: 1.890, train_acc: 0.28282830119132996, val_acc: 0.2808573544025421


loss: 1.782105565071106: 100%|██████████| 127/127 [00:20<00:00,  6.16it/s] 


[Epoch: 3] train_loss: 1.809, train_acc: 0.30795764923095703, val_acc: 0.3843311071395874


loss: 1.6948846578598022: 100%|██████████| 127/127 [00:20<00:00,  6.20it/s]


[Epoch: 4] train_loss: 1.745, train_acc: 0.34318798780441284, val_acc: 0.37694013118743896


loss: 1.7492437362670898: 100%|██████████| 127/127 [00:20<00:00,  6.18it/s]


[Epoch: 5] train_loss: 1.690, train_acc: 0.3646218478679657, val_acc: 0.4035476744174957


loss: 1.5325556993484497: 100%|██████████| 127/127 [00:34<00:00,  3.66it/s]


[Epoch: 6] train_loss: 1.682, train_acc: 0.3843311369419098, val_acc: 0.3843311071395874


loss: 1.5412477254867554: 100%|██████████| 127/127 [00:29<00:00,  4.24it/s]


[Epoch: 7] train_loss: 1.620, train_acc: 0.39788126945495605, val_acc: 0.4153732359409332


loss: 1.514523983001709: 100%|██████████| 127/127 [00:35<00:00,  3.62it/s] 


[Epoch: 8] train_loss: 1.573, train_acc: 0.42177876830101013, val_acc: 0.397634893655777


loss: 1.6112215518951416: 100%|██████████| 127/127 [00:25<00:00,  5.05it/s]


[Epoch: 9] train_loss: 1.555, train_acc: 0.4313870668411255, val_acc: 0.37989652156829834


loss: 1.563075065612793: 100%|██████████| 127/127 [00:20<00:00,  6.06it/s] 


[Epoch: 10] train_loss: 1.520, train_acc: 0.4375461935997009, val_acc: 0.41611233353614807


loss: 1.1876037120819092: 100%|██████████| 127/127 [00:21<00:00,  6.01it/s]


[Epoch: 11] train_loss: 1.492, train_acc: 0.45947280526161194, val_acc: 0.44198077917099


loss: 1.3563956022262573: 100%|██████████| 127/127 [00:35<00:00,  3.61it/s]


[Epoch: 12] train_loss: 1.446, train_acc: 0.46464648842811584, val_acc: 0.48041388392448425


loss: 1.559416651725769: 100%|██████████| 127/127 [00:35<00:00,  3.61it/s] 


[Epoch: 13] train_loss: 1.447, train_acc: 0.4678492248058319, val_acc: 0.4708056151866913


loss: 1.5460638999938965: 100%|██████████| 127/127 [00:34<00:00,  3.64it/s]


[Epoch: 14] train_loss: 1.409, train_acc: 0.4747474789619446, val_acc: 0.44641536474227905


loss: 1.4903472661972046: 100%|██████████| 127/127 [00:34<00:00,  3.65it/s]


[Epoch: 15] train_loss: 1.379, train_acc: 0.49618133902549744, val_acc: 0.44493716955184937


loss: 1.420577883720398: 100%|██████████| 127/127 [00:35<00:00,  3.63it/s] 


[Epoch: 16] train_loss: 1.380, train_acc: 0.49470314383506775, val_acc: 0.5003695487976074


loss: 1.0125410556793213: 100%|██████████| 127/127 [00:23<00:00,  5.47it/s]


[Epoch: 17] train_loss: 1.346, train_acc: 0.5028332471847534, val_acc: 0.4552845358848572


loss: 0.9993934035301208: 100%|██████████| 127/127 [00:20<00:00,  6.13it/s]


[Epoch: 18] train_loss: 1.316, train_acc: 0.5153979063034058, val_acc: 0.5011086463928223


loss: 1.115802526473999: 100%|██████████| 127/127 [00:20<00:00,  6.16it/s] 


[Epoch: 19] train_loss: 1.307, train_acc: 0.5213106870651245, val_acc: 0.5084996223449707


loss: 1.2422975301742554: 100%|██████████| 127/127 [00:20<00:00,  6.17it/s]


[Epoch: 20] train_loss: 1.283, train_acc: 0.5400345325469971, val_acc: 0.4966740608215332


loss: 1.17451810836792: 100%|██████████| 127/127 [00:20<00:00,  6.17it/s]  


[Epoch: 21] train_loss: 1.263, train_acc: 0.535599946975708, val_acc: 0.5151515007019043


loss: 1.3278160095214844: 100%|██████████| 127/127 [00:20<00:00,  6.10it/s]


[Epoch: 22] train_loss: 1.271, train_acc: 0.5410199761390686, val_acc: 0.4745011031627655


loss: 1.3704752922058105: 100%|██████████| 127/127 [00:23<00:00,  5.43it/s]


[Epoch: 23] train_loss: 1.247, train_acc: 0.5454545617103577, val_acc: 0.521064281463623


loss: 0.9152085185050964: 100%|██████████| 127/127 [00:35<00:00,  3.62it/s]


[Epoch: 24] train_loss: 1.238, train_acc: 0.5503818988800049, val_acc: 0.4996304512023926


loss: 1.7208235263824463: 100%|██████████| 127/127 [00:34<00:00,  3.63it/s]


[Epoch: 25] train_loss: 1.225, train_acc: 0.5602365136146545, val_acc: 0.507021427154541


loss: 1.3380095958709717: 100%|██████████| 127/127 [00:35<00:00,  3.62it/s]


[Epoch: 26] train_loss: 1.196, train_acc: 0.5599901676177979, val_acc: 0.48041388392448425


loss: 1.2990137338638306: 100%|██████████| 127/127 [00:35<00:00,  3.60it/s]


[Epoch: 27] train_loss: 1.205, train_acc: 0.5624538064002991, val_acc: 0.5121951103210449


loss: 0.9232532978057861: 100%|██████████| 127/127 [00:35<00:00,  3.62it/s]


[Epoch: 28] train_loss: 1.185, train_acc: 0.5713229775428772, val_acc: 0.5062823295593262


loss: 1.2883498668670654: 100%|██████████| 127/127 [00:20<00:00,  6.23it/s]


[Epoch: 29] train_loss: 1.157, train_acc: 0.573293924331665, val_acc: 0.521064281463623


loss: 1.2488523721694946: 100%|██████████| 127/127 [00:20<00:00,  6.15it/s]


[Epoch: 30] train_loss: 1.179, train_acc: 0.5703375339508057, val_acc: 0.5099778175354004


loss: 1.4459757804870605: 100%|██████████| 127/127 [00:21<00:00,  6.02it/s]


[Epoch: 31] train_loss: 1.152, train_acc: 0.574525773525238, val_acc: 0.5232815742492676


loss: 0.9583354592323303: 100%|██████████| 127/127 [00:21<00:00,  5.99it/s]


[Epoch: 32] train_loss: 1.120, train_acc: 0.5907859206199646, val_acc: 0.5247597694396973


loss: 1.1433591842651367: 100%|██████████| 127/127 [00:20<00:00,  6.15it/s]


[Epoch: 33] train_loss: 1.125, train_acc: 0.5905395746231079, val_acc: 0.49741315841674805


loss: 1.442553162574768: 100%|██████████| 127/127 [00:20<00:00,  6.18it/s] 


[Epoch: 34] train_loss: 1.109, train_acc: 0.5920177698135376, val_acc: 0.5099778175354004


loss: 1.1207133531570435: 100%|██████████| 127/127 [00:20<00:00,  6.18it/s]


[Epoch: 35] train_loss: 1.100, train_acc: 0.592510461807251, val_acc: 0.486326664686203


loss: 1.1580853462219238: 100%|██████████| 127/127 [00:20<00:00,  6.13it/s]


[Epoch: 36] train_loss: 1.095, train_acc: 0.6028578877449036, val_acc: 0.5491500496864319


loss: 1.0203778743743896: 100%|██████████| 127/127 [00:20<00:00,  6.12it/s]


[Epoch: 37] train_loss: 1.090, train_acc: 0.6031042337417603, val_acc: 0.5151515007019043


loss: 1.2052366733551025: 100%|██████████| 127/127 [00:20<00:00,  6.10it/s]


[Epoch: 38] train_loss: 1.073, train_acc: 0.6117270588874817, val_acc: 0.572062075138092


loss: 0.8511607646942139: 100%|██████████| 127/127 [00:20<00:00,  6.18it/s]


[Epoch: 39] train_loss: 1.064, train_acc: 0.6146834492683411, val_acc: 0.5388026833534241


loss: 1.0906143188476562: 100%|██████████| 127/127 [00:20<00:00,  6.11it/s]


[Epoch: 40] train_loss: 1.043, train_acc: 0.619857132434845, val_acc: 0.5299334526062012


loss: 0.9712615013122559: 100%|██████████| 127/127 [00:20<00:00,  6.08it/s]


[Epoch: 41] train_loss: 1.032, train_acc: 0.6223207712173462, val_acc: 0.5365853905677795


loss: 1.4209725856781006: 100%|██████████| 127/127 [00:21<00:00,  6.03it/s]


[Epoch: 42] train_loss: 1.039, train_acc: 0.6250308156013489, val_acc: 0.4745011031627655


loss: 1.0607036352157593: 100%|██████████| 127/127 [00:20<00:00,  6.06it/s]


[Epoch: 43] train_loss: 1.023, train_acc: 0.6203498840332031, val_acc: 0.5461936593055725


loss: 1.0961008071899414: 100%|██████████| 127/127 [00:20<00:00,  6.16it/s]


[Epoch: 44] train_loss: 1.016, train_acc: 0.632668137550354, val_acc: 0.526237964630127


loss: 0.6907519698143005: 100%|██████████| 127/127 [00:20<00:00,  6.15it/s]


[Epoch: 45] train_loss: 0.999, train_acc: 0.6403055191040039, val_acc: 0.5351071357727051


loss: 0.9419684410095215: 100%|██████████| 127/127 [00:20<00:00,  6.13it/s]


[Epoch: 46] train_loss: 0.985, train_acc: 0.6469573974609375, val_acc: 0.5003695487976074


loss: 0.7830182313919067: 100%|██████████| 127/127 [00:20<00:00,  6.15it/s]


[Epoch: 47] train_loss: 0.996, train_acc: 0.6334072351455688, val_acc: 0.5594974160194397


loss: 0.7334904074668884: 100%|██████████| 127/127 [00:20<00:00,  6.16it/s]


[Epoch: 48] train_loss: 0.990, train_acc: 0.6393200755119324, val_acc: 0.5476718544960022


loss: 1.0579419136047363: 100%|██████████| 127/127 [00:20<00:00,  6.07it/s]


[Epoch: 49] train_loss: 0.962, train_acc: 0.6390736699104309, val_acc: 0.5336289405822754


loss: 1.1165865659713745: 100%|██████████| 127/127 [00:21<00:00,  5.94it/s]


[Epoch: 50] train_loss: 0.955, train_acc: 0.6536092758178711, val_acc: 0.5469327569007874


loss: 1.3634247779846191: 100%|██████████| 127/127 [00:21<00:00,  6.00it/s]


[Epoch: 51] train_loss: 0.975, train_acc: 0.6513919830322266, val_acc: 0.5277161598205566


loss: 1.1170045137405396: 100%|██████████| 127/127 [00:20<00:00,  6.13it/s]


[Epoch: 52] train_loss: 0.927, train_acc: 0.6715940237045288, val_acc: 0.5565410256385803


loss: 1.1962097883224487: 100%|██████████| 127/127 [00:20<00:00,  6.09it/s]


[Epoch: 53] train_loss: 0.951, train_acc: 0.6622321009635925, val_acc: 0.5225424766540527


loss: 0.9493206739425659: 100%|██████████| 127/127 [00:20<00:00,  6.07it/s]


[Epoch: 54] train_loss: 0.919, train_acc: 0.6706085205078125, val_acc: 0.5388026833534241


loss: 0.9657118916511536: 100%|██████████| 127/127 [00:21<00:00,  6.04it/s] 


[Epoch: 55] train_loss: 0.931, train_acc: 0.6619857549667358, val_acc: 0.5654101967811584


loss: 0.7869243025779724: 100%|██████████| 127/127 [00:20<00:00,  6.07it/s]


[Epoch: 56] train_loss: 0.938, train_acc: 0.6644493937492371, val_acc: 0.5424981713294983


loss: 0.8244277238845825: 100%|██████████| 127/127 [00:20<00:00,  6.06it/s]


[Epoch: 57] train_loss: 0.906, train_acc: 0.6671594381332397, val_acc: 0.572062075138092


loss: 0.9055414795875549: 100%|██████████| 127/127 [00:25<00:00,  4.95it/s]


[Epoch: 58] train_loss: 0.898, train_acc: 0.675289511680603, val_acc: 0.5631929039955139


loss: 0.8262976408004761: 100%|██████████| 127/127 [00:20<00:00,  6.14it/s]


[Epoch: 59] train_loss: 0.879, train_acc: 0.6846514344215393, val_acc: 0.6001478433609009


loss: 0.6818248629570007: 100%|██████████| 127/127 [00:20<00:00,  6.08it/s]


[Epoch: 60] train_loss: 0.890, train_acc: 0.6681448817253113, val_acc: 0.55801922082901


loss: 0.6084223985671997: 100%|██████████| 127/127 [00:20<00:00,  6.10it/s]


[Epoch: 61] train_loss: 0.900, train_acc: 0.6747967600822449, val_acc: 0.5764966607093811


loss: 0.7793833613395691: 100%|██████████| 127/127 [00:20<00:00,  6.18it/s]


[Epoch: 62] train_loss: 0.878, train_acc: 0.6839123368263245, val_acc: 0.5713229775428772


loss: 1.1687654256820679: 100%|██████████| 127/127 [00:20<00:00,  6.06it/s] 


[Epoch: 63] train_loss: 0.858, train_acc: 0.6974624395370483, val_acc: 0.5314116477966309


loss: 0.7446322441101074: 100%|██████████| 127/127 [00:20<00:00,  6.10it/s]


[Epoch: 64] train_loss: 0.871, train_acc: 0.6814486384391785, val_acc: 0.5565410256385803


loss: 0.8456396460533142: 100%|██████████| 127/127 [00:21<00:00,  6.01it/s]


[Epoch: 65] train_loss: 0.882, train_acc: 0.6881005167961121, val_acc: 0.5476718544960022


loss: 0.5199100971221924: 100%|██████████| 127/127 [00:21<00:00,  5.99it/s] 


[Epoch: 66] train_loss: 0.871, train_acc: 0.6883469223976135, val_acc: 0.5654101967811584


loss: 1.0865064859390259: 100%|██████████| 127/127 [00:21<00:00,  5.99it/s]


[Epoch: 67] train_loss: 0.828, train_acc: 0.6927815079689026, val_acc: 0.4915003776550293


loss: 0.9928141236305237: 100%|██████████| 127/127 [00:20<00:00,  6.11it/s]


[Epoch: 68] train_loss: 0.841, train_acc: 0.7043607234954834, val_acc: 0.5624538064002991


loss: 0.8158530592918396: 100%|██████████| 127/127 [00:20<00:00,  6.08it/s] 


[Epoch: 69] train_loss: 0.806, train_acc: 0.7149544358253479, val_acc: 0.5476718544960022


loss: 0.8040568828582764: 100%|██████████| 127/127 [00:20<00:00,  6.07it/s] 


[Epoch: 70] train_loss: 0.838, train_acc: 0.6982015371322632, val_acc: 0.5395417809486389


loss: 0.7442015409469604: 100%|██████████| 127/127 [00:20<00:00,  6.14it/s]


[Epoch: 71] train_loss: 0.866, train_acc: 0.6920424103736877, val_acc: 0.5846267342567444


loss: 0.9787978529930115: 100%|██████████| 127/127 [00:20<00:00,  6.18it/s]


[Epoch: 72] train_loss: 0.819, train_acc: 0.7127371430397034, val_acc: 0.5609756112098694


loss: 0.5628281831741333: 100%|██████████| 127/127 [00:20<00:00,  6.18it/s]


[Epoch: 73] train_loss: 0.807, train_acc: 0.7129834890365601, val_acc: 0.5410199761390686


loss: 1.278964877128601: 100%|██████████| 127/127 [00:20<00:00,  6.16it/s]  


[Epoch: 74] train_loss: 0.828, train_acc: 0.7055925130844116, val_acc: 0.5521064400672913


loss: 1.11237633228302: 100%|██████████| 127/127 [00:20<00:00,  6.12it/s]   


[Epoch: 75] train_loss: 0.800, train_acc: 0.7149544358253479, val_acc: 0.516629695892334


loss: 0.9926663637161255: 100%|██████████| 127/127 [00:21<00:00,  5.96it/s]


[Epoch: 76] train_loss: 0.758, train_acc: 0.7344173789024353, val_acc: 0.5646710991859436


loss: 1.0652369260787964: 100%|██████████| 127/127 [00:20<00:00,  6.16it/s] 


[Epoch: 77] train_loss: 0.780, train_acc: 0.7277655005455017, val_acc: 0.5506282448768616


loss: 0.9485701322555542: 100%|██████████| 127/127 [00:20<00:00,  6.16it/s]


[Epoch: 78] train_loss: 0.768, train_acc: 0.7211136221885681, val_acc: 0.5395417809486389


loss: 0.9454540610313416: 100%|██████████| 127/127 [00:20<00:00,  6.18it/s]


[Epoch: 79] train_loss: 0.780, train_acc: 0.7191426753997803, val_acc: 0.5232815742492676


loss: 0.9282891154289246: 100%|██████████| 127/127 [00:20<00:00,  6.13it/s] 


[Epoch: 80] train_loss: 0.764, train_acc: 0.7253018021583557, val_acc: 0.5639320015907288


loss: 0.8139274716377258: 100%|██████████| 127/127 [00:20<00:00,  6.12it/s] 


[Epoch: 81] train_loss: 0.783, train_acc: 0.7164326310157776, val_acc: 0.5609756112098694


loss: 0.7638924717903137: 100%|██████████| 127/127 [00:20<00:00,  6.11it/s] 


[Epoch: 82] train_loss: 0.752, train_acc: 0.7349100708961487, val_acc: 0.5469327569007874


loss: 1.378777027130127: 100%|██████████| 127/127 [00:20<00:00,  6.10it/s] 


[Epoch: 83] train_loss: 0.757, train_acc: 0.7319536805152893, val_acc: 0.5225424766540527


loss: 0.8878802061080933: 100%|██████████| 127/127 [00:20<00:00,  6.17it/s] 


[Epoch: 84] train_loss: 0.719, train_acc: 0.7427937984466553, val_acc: 0.5188469886779785


loss: 0.8472790718078613: 100%|██████████| 127/127 [00:20<00:00,  6.17it/s] 


[Epoch: 85] train_loss: 0.746, train_acc: 0.7317073345184326, val_acc: 0.5432372689247131


loss: 0.5962098240852356: 100%|██████████| 127/127 [00:20<00:00,  6.18it/s]


[Epoch: 86] train_loss: 0.721, train_acc: 0.7435328960418701, val_acc: 0.5469327569007874


loss: 1.191718578338623: 100%|██████████| 127/127 [00:20<00:00,  6.17it/s] 


[Epoch: 87] train_loss: 0.752, train_acc: 0.7366346716880798, val_acc: 0.5447154641151428


loss: 0.5911536812782288: 100%|██████████| 127/127 [00:20<00:00,  6.15it/s] 


[Epoch: 88] train_loss: 0.748, train_acc: 0.7302291393280029, val_acc: 0.5247597694396973


loss: 0.7265192866325378: 100%|██████████| 127/127 [00:20<00:00,  6.17it/s] 


[Epoch: 89] train_loss: 0.702, train_acc: 0.7501847743988037, val_acc: 0.507021427154541


loss: 0.2565404772758484: 100%|██████████| 127/127 [00:20<00:00,  6.09it/s] 


[Epoch: 90] train_loss: 0.698, train_acc: 0.7506775259971619, val_acc: 0.5602365136146545


loss: 0.7009559273719788: 100%|██████████| 127/127 [00:20<00:00,  6.15it/s] 


[Epoch: 91] train_loss: 0.728, train_acc: 0.740330159664154, val_acc: 0.5631929039955139


loss: 0.715508759021759: 100%|██████████| 127/127 [00:20<00:00,  6.21it/s]  


[Epoch: 92] train_loss: 0.715, train_acc: 0.7413156032562256, val_acc: 0.5506282448768616


loss: 0.9496360421180725: 100%|██████████| 127/127 [00:20<00:00,  6.08it/s] 


[Epoch: 93] train_loss: 0.715, train_acc: 0.7526484727859497, val_acc: 0.55801922082901


loss: 1.1265227794647217: 100%|██████████| 127/127 [00:20<00:00,  6.07it/s] 


[Epoch: 94] train_loss: 0.688, train_acc: 0.7560975551605225, val_acc: 0.5351071357727051


loss: 0.7734549045562744: 100%|██████████| 127/127 [00:20<00:00,  6.08it/s] 


[Epoch: 95] train_loss: 0.670, train_acc: 0.7629958391189575, val_acc: 0.5661492943763733


loss: 0.5059919953346252: 100%|██████████| 127/127 [00:20<00:00,  6.14it/s] 


[Epoch: 96] train_loss: 0.655, train_acc: 0.7642276883125305, val_acc: 0.5809312462806702


loss: 0.42310383915901184: 100%|██████████| 127/127 [00:21<00:00,  6.01it/s]


[Epoch: 97] train_loss: 0.671, train_acc: 0.760778546333313, val_acc: 0.5779748558998108


loss: 0.34270939230918884: 100%|██████████| 127/127 [00:21<00:00,  6.00it/s]


[Epoch: 98] train_loss: 0.710, train_acc: 0.7521557211875916, val_acc: 0.5373244881629944


loss: 0.7071079015731812: 100%|██████████| 127/127 [00:21<00:00,  6.00it/s]


[Epoch: 99] train_loss: 0.685, train_acc: 0.7533875703811646, val_acc: 0.5594974160194397


loss: 0.37491071224212646: 100%|██████████| 127/127 [00:20<00:00,  6.18it/s]


[Epoch: 100] train_loss: 0.677, train_acc: 0.7647203803062439, val_acc: 0.5691056847572327


loss: 0.22758008539676666: 100%|██████████| 127/127 [00:20<00:00,  6.17it/s]


[Epoch: 101] train_loss: 0.664, train_acc: 0.7644740343093872, val_acc: 0.5639320015907288


loss: 0.5383146405220032: 100%|██████████| 127/127 [00:21<00:00,  5.98it/s] 


[Epoch: 102] train_loss: 0.626, train_acc: 0.777285099029541, val_acc: 0.5764966607093811


loss: 0.9486035704612732: 100%|██████████| 127/127 [00:21<00:00,  6.00it/s] 


[Epoch: 103] train_loss: 0.685, train_acc: 0.7590540051460266, val_acc: 0.5572801232337952


loss: 0.6744586825370789: 100%|██████████| 127/127 [00:20<00:00,  6.15it/s] 


[Epoch: 104] train_loss: 0.671, train_acc: 0.7649667859077454, val_acc: 0.5728011727333069


loss: 0.40981823205947876: 100%|██████████| 127/127 [00:20<00:00,  6.18it/s]


[Epoch: 105] train_loss: 0.640, train_acc: 0.7738359570503235, val_acc: 0.5587583184242249


loss: 1.2437758445739746: 100%|██████████| 127/127 [00:20<00:00,  6.14it/s]


[Epoch: 106] train_loss: 0.668, train_acc: 0.7622567415237427, val_acc: 0.5225424766540527


loss: 0.5223370790481567: 100%|██████████| 127/127 [00:20<00:00,  6.21it/s] 


[Epoch: 107] train_loss: 0.642, train_acc: 0.769154965877533, val_acc: 0.5757575631141663


loss: 0.5059029459953308: 100%|██████████| 127/127 [00:20<00:00,  6.18it/s] 


[Epoch: 108] train_loss: 0.652, train_acc: 0.7819660305976868, val_acc: 0.5750184655189514


loss: 1.0715066194534302: 100%|██████████| 127/127 [00:20<00:00,  6.21it/s] 


[Epoch: 109] train_loss: 0.653, train_acc: 0.7654594779014587, val_acc: 0.5794530510902405


loss: 0.49728378653526306: 100%|██████████| 127/127 [00:20<00:00,  6.13it/s]


[Epoch: 110] train_loss: 0.622, train_acc: 0.7809805870056152, val_acc: 0.5550628304481506


loss: 0.2979235351085663: 100%|██████████| 127/127 [00:20<00:00,  6.06it/s] 


[Epoch: 111] train_loss: 0.625, train_acc: 0.7814732789993286, val_acc: 0.5779748558998108


loss: 0.4357767105102539: 100%|██████████| 127/127 [00:20<00:00,  6.12it/s] 


[Epoch: 112] train_loss: 0.622, train_acc: 0.777285099029541, val_acc: 0.5691056847572327


loss: 0.5978024005889893: 100%|██████████| 127/127 [00:20<00:00,  6.13it/s] 


[Epoch: 113] train_loss: 0.623, train_acc: 0.7851687669754028, val_acc: 0.5875831246376038


loss: 0.5550731420516968: 100%|██████████| 127/127 [00:20<00:00,  6.15it/s] 


[Epoch: 114] train_loss: 0.595, train_acc: 0.7915742993354797, val_acc: 0.553584635257721


loss: 0.9099032878875732: 100%|██████████| 127/127 [00:20<00:00,  6.22it/s] 


[Epoch: 115] train_loss: 0.605, train_acc: 0.7925597429275513, val_acc: 0.55801922082901


loss: 0.6244757771492004: 100%|██████████| 127/127 [00:20<00:00,  6.13it/s] 


[Epoch: 116] train_loss: 0.619, train_acc: 0.7807341814041138, val_acc: 0.5587583184242249


loss: 0.6344420909881592: 100%|██████████| 127/127 [00:20<00:00,  6.09it/s] 


[Epoch: 117] train_loss: 0.568, train_acc: 0.7974870800971985, val_acc: 0.5920177102088928


loss: 0.6889297962188721: 100%|██████████| 127/127 [00:20<00:00,  6.12it/s] 


[Epoch: 118] train_loss: 0.597, train_acc: 0.788864254951477, val_acc: 0.5779748558998108


loss: 0.4293626844882965: 100%|██████████| 127/127 [00:20<00:00,  6.09it/s] 


[Epoch: 119] train_loss: 0.566, train_acc: 0.800936222076416, val_acc: 0.5646710991859436


loss: 0.9536247849464417: 100%|██████████| 127/127 [00:20<00:00,  6.11it/s] 


[Epoch: 120] train_loss: 0.570, train_acc: 0.7999507784843445, val_acc: 0.5794530510902405


loss: 0.5935372114181519: 100%|██████████| 127/127 [00:20<00:00,  6.19it/s] 


[Epoch: 121] train_loss: 0.596, train_acc: 0.7873860597610474, val_acc: 0.5558019280433655


loss: 0.49994295835494995: 100%|██████████| 127/127 [00:20<00:00,  6.16it/s]


[Epoch: 122] train_loss: 0.563, train_acc: 0.8070953488349915, val_acc: 0.577235758304596


loss: 0.5252664089202881: 100%|██████████| 127/127 [00:20<00:00,  6.11it/s] 


[Epoch: 123] train_loss: 0.587, train_acc: 0.7915742993354797, val_acc: 0.5779748558998108


loss: 0.3560101091861725: 100%|██████████| 127/127 [00:20<00:00,  6.18it/s] 


[Epoch: 124] train_loss: 0.601, train_acc: 0.7859078645706177, val_acc: 0.55801922082901


loss: 0.8770936131477356: 100%|██████████| 127/127 [00:20<00:00,  6.12it/s] 


[Epoch: 125] train_loss: 0.657, train_acc: 0.7696477174758911, val_acc: 0.5683665871620178


loss: 0.742755651473999: 100%|██████████| 127/127 [00:20<00:00,  6.13it/s]  


[Epoch: 126] train_loss: 0.595, train_acc: 0.7930524945259094, val_acc: 0.5801921486854553


loss: 0.5289735794067383: 100%|██████████| 127/127 [00:20<00:00,  6.18it/s] 


[Epoch: 127] train_loss: 0.566, train_acc: 0.8014289736747742, val_acc: 0.5934959053993225


loss: 0.34091123938560486: 100%|██████████| 127/127 [00:20<00:00,  6.11it/s]


[Epoch: 128] train_loss: 0.561, train_acc: 0.7960088849067688, val_acc: 0.5587583184242249


loss: 0.42425838112831116: 100%|██████████| 127/127 [00:20<00:00,  6.14it/s]


[Epoch: 129] train_loss: 0.608, train_acc: 0.7883715629577637, val_acc: 0.5691056847572327


loss: 0.45016780495643616: 100%|██████████| 127/127 [00:20<00:00,  6.15it/s]


[Epoch: 130] train_loss: 0.611, train_acc: 0.784429669380188, val_acc: 0.5609756112098694


loss: 0.43519327044487: 100%|██████████| 127/127 [00:20<00:00,  6.15it/s]   


[Epoch: 131] train_loss: 0.566, train_acc: 0.8075881004333496, val_acc: 0.5661492943763733


loss: 0.36478906869888306: 100%|██████████| 127/127 [00:20<00:00,  6.14it/s]


[Epoch: 132] train_loss: 0.564, train_acc: 0.8024144172668457, val_acc: 0.5654101967811584


loss: 0.28601592779159546: 100%|██████████| 127/127 [00:20<00:00,  6.16it/s]


[Epoch: 133] train_loss: 0.562, train_acc: 0.7960088849067688, val_acc: 0.5646710991859436


loss: 0.5418020486831665: 100%|██████████| 127/127 [00:20<00:00,  6.13it/s] 


[Epoch: 134] train_loss: 0.562, train_acc: 0.8070953488349915, val_acc: 0.5624538064002991


loss: 1.0079087018966675: 100%|██████████| 127/127 [00:20<00:00,  6.11it/s] 


[Epoch: 135] train_loss: 0.553, train_acc: 0.8125154376029968, val_acc: 0.5668883919715881


loss: 0.37387943267822266: 100%|██████████| 127/127 [00:20<00:00,  6.19it/s]


[Epoch: 136] train_loss: 0.577, train_acc: 0.7989652752876282, val_acc: 0.5691056847572327


loss: 0.8509790897369385: 100%|██████████| 127/127 [00:20<00:00,  6.19it/s]


[Epoch: 137] train_loss: 0.823, train_acc: 0.7642276883125305, val_acc: 0.5299334526062012


loss: 0.995140552520752: 100%|██████████| 127/127 [00:20<00:00,  6.22it/s]  


[Epoch: 138] train_loss: 0.675, train_acc: 0.7578221559524536, val_acc: 0.5558019280433655


loss: 0.4465866684913635: 100%|██████████| 127/127 [00:20<00:00,  6.13it/s] 


[Epoch: 139] train_loss: 0.592, train_acc: 0.7876324653625488, val_acc: 0.5750184655189514


loss: 0.5926189422607422: 100%|██████████| 127/127 [00:20<00:00,  6.11it/s] 


[Epoch: 140] train_loss: 0.565, train_acc: 0.8024144172668457, val_acc: 0.5809312462806702


loss: 0.5447360277175903: 100%|██████████| 127/127 [00:20<00:00,  6.16it/s] 


[Epoch: 141] train_loss: 0.587, train_acc: 0.7967479825019836, val_acc: 0.5691056847572327


loss: 0.34603193402290344: 100%|██████████| 127/127 [00:20<00:00,  6.14it/s]


[Epoch: 142] train_loss: 0.568, train_acc: 0.8014289736747742, val_acc: 0.5631929039955139


loss: 0.5973421931266785: 100%|██████████| 127/127 [00:20<00:00,  6.16it/s] 


[Epoch: 143] train_loss: 0.561, train_acc: 0.799704372882843, val_acc: 0.577235758304596


loss: 0.3521367907524109: 100%|██████████| 127/127 [00:20<00:00,  6.13it/s] 


[Epoch: 144] train_loss: 0.565, train_acc: 0.8006898760795593, val_acc: 0.5631929039955139


loss: 0.6087773442268372: 100%|██████████| 127/127 [00:20<00:00,  6.10it/s] 


[Epoch: 145] train_loss: 0.535, train_acc: 0.814979076385498, val_acc: 0.5661492943763733


loss: 0.39329132437705994: 100%|██████████| 127/127 [00:21<00:00,  5.99it/s]


[Epoch: 146] train_loss: 0.531, train_acc: 0.8112835884094238, val_acc: 0.5639320015907288


loss: 0.5086613893508911: 100%|██████████| 127/127 [00:21<00:00,  6.03it/s] 


[Epoch: 147] train_loss: 0.529, train_acc: 0.8132545351982117, val_acc: 0.5779748558998108


loss: 1.009097695350647: 100%|██████████| 127/127 [00:20<00:00,  6.16it/s]  


[Epoch: 148] train_loss: 0.546, train_acc: 0.8090662956237793, val_acc: 0.5513673424720764


loss: 0.7189611792564392: 100%|██████████| 127/127 [00:20<00:00,  6.13it/s] 


[Epoch: 149] train_loss: 0.554, train_acc: 0.8095590472221375, val_acc: 0.5498891472816467


loss: 0.5837669968605042: 100%|██████████| 127/127 [00:20<00:00,  6.21it/s] 


[Epoch: 150] train_loss: 0.569, train_acc: 0.8033998608589172, val_acc: 0.5683665871620178


loss: 0.44269511103630066: 100%|██████████| 127/127 [00:20<00:00,  6.15it/s]


[Epoch: 151] train_loss: 0.551, train_acc: 0.8048780560493469, val_acc: 0.5787139534950256


loss: 0.34989991784095764: 100%|██████████| 127/127 [00:20<00:00,  6.16it/s]


[Epoch: 152] train_loss: 0.535, train_acc: 0.8132545351982117, val_acc: 0.5920177102088928


loss: 0.9331488609313965: 100%|██████████| 127/127 [00:20<00:00,  6.13it/s] 


[Epoch: 153] train_loss: 0.528, train_acc: 0.8154718279838562, val_acc: 0.5639320015907288


loss: 0.3391898572444916: 100%|██████████| 127/127 [00:20<00:00,  6.05it/s] 


[Epoch: 154] train_loss: 0.511, train_acc: 0.820152759552002, val_acc: 0.5668883919715881


loss: 0.4994470179080963: 100%|██████████| 127/127 [00:20<00:00,  6.12it/s] 


[Epoch: 155] train_loss: 0.519, train_acc: 0.8142399787902832, val_acc: 0.5617147088050842


loss: 0.7315052151679993: 100%|██████████| 127/127 [00:20<00:00,  6.17it/s] 


[Epoch: 156] train_loss: 0.512, train_acc: 0.8265582919120789, val_acc: 0.5506282448768616


loss: 0.36956024169921875: 100%|██████████| 127/127 [00:20<00:00,  6.13it/s]


[Epoch: 157] train_loss: 0.502, train_acc: 0.8206455111503601, val_acc: 0.5838876366615295


loss: 0.46921974420547485: 100%|██████████| 127/127 [00:20<00:00,  6.13it/s]


[Epoch: 158] train_loss: 0.530, train_acc: 0.8115299344062805, val_acc: 0.5831485390663147


loss: 0.5200432538986206: 100%|██████████| 127/127 [00:20<00:00,  6.17it/s] 


[Epoch: 159] train_loss: 0.485, train_acc: 0.8275437355041504, val_acc: 0.5787139534950256


loss: 1.289853811264038: 100%|██████████| 127/127 [00:20<00:00,  6.17it/s]  


[Epoch: 160] train_loss: 0.539, train_acc: 0.8112835884094238, val_acc: 0.5728011727333069


loss: 0.3535034954547882: 100%|██████████| 127/127 [00:20<00:00,  6.19it/s] 


[Epoch: 161] train_loss: 0.517, train_acc: 0.8176891207695007, val_acc: 0.5927568078041077


loss: 0.3540011942386627: 100%|██████████| 127/127 [00:20<00:00,  6.20it/s] 


[Epoch: 162] train_loss: 0.517, train_acc: 0.8184282183647156, val_acc: 0.5735402703285217


loss: 0.4946899116039276: 100%|██████████| 127/127 [00:20<00:00,  6.15it/s] 


[Epoch: 163] train_loss: 0.508, train_acc: 0.8265582919120789, val_acc: 0.5365853905677795


loss: 0.283061683177948: 100%|██████████| 127/127 [00:20<00:00,  6.12it/s]  


[Epoch: 164] train_loss: 0.469, train_acc: 0.8373984098434448, val_acc: 0.591278612613678


loss: 0.9434909820556641: 100%|██████████| 127/127 [00:20<00:00,  6.15it/s] 


[Epoch: 165] train_loss: 0.571, train_acc: 0.8083271980285645, val_acc: 0.4759792983531952


loss: 0.5714794993400574: 100%|██████████| 127/127 [00:20<00:00,  6.13it/s] 


[Epoch: 166] train_loss: 0.577, train_acc: 0.7893570065498352, val_acc: 0.5521064400672913


loss: 0.39350563287734985: 100%|██████████| 127/127 [00:20<00:00,  6.15it/s]


[Epoch: 167] train_loss: 0.517, train_acc: 0.8199064135551453, val_acc: 0.5631929039955139


loss: 0.6562323570251465: 100%|██████████| 127/127 [00:20<00:00,  6.18it/s] 


[Epoch: 168] train_loss: 0.499, train_acc: 0.8295146822929382, val_acc: 0.5831485390663147


loss: 0.7012578248977661: 100%|██████████| 127/127 [00:20<00:00,  6.13it/s] 


[Epoch: 169] train_loss: 0.483, train_acc: 0.8272973895072937, val_acc: 0.5705838799476624


loss: 0.409848153591156: 100%|██████████| 127/127 [00:20<00:00,  6.16it/s]  


[Epoch: 170] train_loss: 0.484, train_acc: 0.8339492678642273, val_acc: 0.5550628304481506


loss: 0.2556022107601166: 100%|██████████| 127/127 [00:20<00:00,  6.19it/s] 


[Epoch: 171] train_loss: 0.476, train_acc: 0.8322247266769409, val_acc: 0.5691056847572327


loss: 0.23653703927993774: 100%|██████████| 127/127 [00:20<00:00,  6.17it/s]


[Epoch: 172] train_loss: 0.489, train_acc: 0.8309928774833679, val_acc: 0.591278612613678


loss: 0.6551159024238586: 100%|██████████| 127/127 [00:20<00:00,  6.13it/s] 


[Epoch: 173] train_loss: 0.499, train_acc: 0.8317319750785828, val_acc: 0.5639320015907288


loss: 0.4188595414161682: 100%|██████████| 127/127 [00:20<00:00,  6.17it/s] 


[Epoch: 174] train_loss: 0.472, train_acc: 0.8428184390068054, val_acc: 0.5728011727333069


loss: 0.2389538437128067: 100%|██████████| 127/127 [00:20<00:00,  6.21it/s] 


[Epoch: 175] train_loss: 0.473, train_acc: 0.8337029218673706, val_acc: 0.581670343875885


loss: 0.5553034543991089: 100%|██████████| 127/127 [00:20<00:00,  6.10it/s] 


[Epoch: 176] train_loss: 0.451, train_acc: 0.835427463054657, val_acc: 0.5757575631141663


loss: 0.4872553050518036: 100%|██████████| 127/127 [00:20<00:00,  6.14it/s] 


[Epoch: 177] train_loss: 0.459, train_acc: 0.8425720930099487, val_acc: 0.553584635257721


loss: 0.48775193095207214: 100%|██████████| 127/127 [00:20<00:00,  6.06it/s]


[Epoch: 178] train_loss: 0.455, train_acc: 0.8388766050338745, val_acc: 0.5750184655189514


loss: 0.5220817923545837: 100%|██████████| 127/127 [00:20<00:00,  6.12it/s] 


[Epoch: 179] train_loss: 0.523, train_acc: 0.8243409991264343, val_acc: 0.5728011727333069


loss: 0.2465997040271759: 100%|██████████| 127/127 [00:21<00:00,  6.04it/s] 


[Epoch: 180] train_loss: 0.462, train_acc: 0.8349347114562988, val_acc: 0.5838876366615295


loss: 0.3674168288707733: 100%|██████████| 127/127 [00:20<00:00,  6.11it/s] 


[Epoch: 181] train_loss: 0.429, train_acc: 0.849470317363739, val_acc: 0.5898004174232483


loss: 0.1806570440530777: 100%|██████████| 127/127 [00:21<00:00,  5.83it/s] 


[Epoch: 182] train_loss: 0.458, train_acc: 0.8435575366020203, val_acc: 0.5824094414710999


loss: 0.12127211689949036: 100%|██████████| 127/127 [00:20<00:00,  6.10it/s]


[Epoch: 183] train_loss: 0.427, train_acc: 0.8524267077445984, val_acc: 0.5609756112098694


loss: 0.5052822828292847: 100%|██████████| 127/127 [00:21<00:00,  6.04it/s] 


[Epoch: 184] train_loss: 0.440, train_acc: 0.8484848737716675, val_acc: 0.572062075138092


loss: 0.2575681209564209: 100%|██████████| 127/127 [00:20<00:00,  6.08it/s] 


[Epoch: 185] train_loss: 0.499, train_acc: 0.8272973895072937, val_acc: 0.5942350029945374


loss: 0.597866952419281: 100%|██████████| 127/127 [00:20<00:00,  6.14it/s]  


[Epoch: 186] train_loss: 0.467, train_acc: 0.8364129066467285, val_acc: 0.5757575631141663


loss: 0.911329984664917: 100%|██████████| 127/127 [00:20<00:00,  6.14it/s]  


[Epoch: 187] train_loss: 0.465, train_acc: 0.8408475518226624, val_acc: 0.5424981713294983


loss: 0.3397802710533142: 100%|██████████| 127/127 [00:20<00:00,  6.10it/s] 


[Epoch: 188] train_loss: 0.465, train_acc: 0.8376447558403015, val_acc: 0.5572801232337952


loss: 0.1296302080154419: 100%|██████████| 127/127 [00:20<00:00,  6.13it/s] 


[Epoch: 189] train_loss: 0.466, train_acc: 0.8388766050338745, val_acc: 0.5757575631141663


loss: 0.2964620888233185: 100%|██████████| 127/127 [00:20<00:00,  6.09it/s] 


[Epoch: 190] train_loss: 0.432, train_acc: 0.8514412641525269, val_acc: 0.5757575631141663


loss: 0.6690590381622314: 100%|██████████| 127/127 [00:20<00:00,  6.11it/s] 


[Epoch: 191] train_loss: 0.469, train_acc: 0.8467603325843811, val_acc: 0.5698447823524475


loss: 1.7978882789611816: 100%|██████████| 127/127 [00:20<00:00,  6.13it/s] 


[Epoch: 192] train_loss: 0.443, train_acc: 0.8543976545333862, val_acc: 0.553584635257721


loss: 0.4998726546764374: 100%|██████████| 127/127 [00:20<00:00,  6.12it/s] 


[Epoch: 193] train_loss: 0.483, train_acc: 0.834195613861084, val_acc: 0.5883222222328186


loss: 0.4287574887275696: 100%|██████████| 127/127 [00:20<00:00,  6.19it/s] 


[Epoch: 194] train_loss: 0.446, train_acc: 0.8452821373939514, val_acc: 0.5654101967811584


loss: 0.8617207407951355: 100%|██████████| 127/127 [00:20<00:00,  6.08it/s] 


[Epoch: 195] train_loss: 0.479, train_acc: 0.8371520042419434, val_acc: 0.5314116477966309


loss: 0.2845201790332794: 100%|██████████| 127/127 [00:28<00:00,  4.48it/s] 


[Epoch: 196] train_loss: 0.441, train_acc: 0.8492239713668823, val_acc: 0.5883222222328186


loss: 0.301771342754364: 100%|██████████| 127/127 [00:25<00:00,  4.90it/s]  


[Epoch: 197] train_loss: 0.446, train_acc: 0.8457748293876648, val_acc: 0.5801921486854553


loss: 0.3018282651901245: 100%|██████████| 127/127 [00:33<00:00,  3.75it/s] 


[Epoch: 198] train_loss: 0.439, train_acc: 0.8479921221733093, val_acc: 0.5794530510902405


loss: 0.3959945738315582: 100%|██████████| 127/127 [00:30<00:00,  4.13it/s] 


[Epoch: 199] train_loss: 0.442, train_acc: 0.8531658053398132, val_acc: 0.5728011727333069


loss: 0.2041681855916977: 100%|██████████| 127/127 [00:27<00:00,  4.64it/s] 


[Epoch: 200] train_loss: 0.421, train_acc: 0.8541513085365295, val_acc: 0.5971913933753967
Finished Training


## Save/Load model.

In [13]:
model_dir = os.path.join(os.getcwd(), 'baseline_model', 'CNN')
if not os.path.exists(model_dir):
    os.makedirs(model_dir)
torch.save(best_model, os.path.join(model_dir, 'model.pth'))
best_model = torch.load(os.path.join(model_dir, 'model.pth'))

## Validate the models.

In [14]:
print(f"Accuracy on validation data: {validation(best_model, test_loader)}")

Accuracy on validation data: 0.5750184655189514
