# Model Training

In [1]:
from torchvision import transforms
import numpy as np
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import datasets
from torchvision.transforms import ToTensor
import matplotlib.pyplot as plt
import os
from PIL import Image # import PIL library

In [2]:
# Define a custom dataset class that inherits from torch.utils.data.Dataset
class CustomDataset(Dataset):
    def __init__(self, root):
        # Root is the directory that contains all the class folders
        self.root = root
        self.folders_paths = [os.path.join(root, folder) for folder in os.listdir(root)]
        self.classes = [folder.split("/")[-1] for folder in self.folders_paths]
        self.files = []
        for label, folder_path in enumerate(self.folders_paths):
            for file_name in os.listdir(folder_path):
                file_path = os.path.join(folder_path, file_name)
                self.files.append((file_path, label))

    def __getitem__(self, idx):
        # Return a sample and its corresponding label
        file_path, label = self.files[idx]
        # Open the image using PIL
        image = Image.open(file_path).convert('RGB')
        # Note: This returns an image in the form of a PIL object
        return image, label

    def __len__(self):
        # Return the number of samples
        return len(self.files)

    
# Create an instance of the custom dataset with a given root directory
root_dir = "final_data"
data = CustomDataset(root_dir)

# Split the dataset into train (60%), test (20%) and validation (20%) sets
train_size = int(0.6 * len(data))
test_size = int(0.2 * len(data))
val_size = len(data) - train_size - test_size
train_data, test_data, val_data = random_split(data, [train_size, test_size, val_size])

# Define different transformations for each dataset
train_transform  =   transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.1, hue=0.1),
        transforms.RandomAffine(degrees=40, translate=None, scale=(1, 2), shear=15),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
])

transform_test = transforms.Compose([
    transforms.Resize((224,224)),
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
])

# Create a wrapper class to apply the transformations on the fly
class TransformDataset(Dataset):
    def __init__(self, subset, transform=None):
        self.subset = subset
        self.transform = transform

    def __getitem__(self, index):
        x, y = self.subset[index]
        if self.transform:
            x = self.transform(x)
        return x, y

    def __len__(self):
        return len(self.subset)

# Create transformed datasets using the wrapper class
train_data = TransformDataset(train_data, transform=train_transform)
test_data = TransformDataset(test_data, transform=transform_test)
val_data = TransformDataset(val_data, transform=transform_test)

# Create data loaders for each dataset with a given batch size
batch_size = 32
train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_data, batch_size=batch_size, shuffle=True)



In [3]:
# TODO: Build and train your network
from torch import nn
import torch.optim as optim
import torch
import numpy as np
from tqdm import tqdm
device = "cuda"

class LogisticRegression(nn.Module):
    def __init__(self, input_size, num_classes):
        super(LogisticRegression, self).__init__()
        self.linear = nn.Linear(input_size, num_classes)

    def forward(self, x):
        x = x.view(x.size(0), -1)  # Flatten the input tensor
        out = self.linear(x)
        return out

# Example usage
# instantiate the model
input_size = 224 * 224 * 3
num_classes = 8
net = LogisticRegression(input_size, num_classes).to(device)

# Define the loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimiser = optim.SGD(net.parameters(), lr=0.05)

In [4]:
def validation(net,dl_test):
    net.eval()
    correct=0
    total = 0
    for i, data in enumerate(dl_test, 0):
        # Get data from each batch. Data is in the format of [inputs, labels]
        inputs, labels = data
        inputs = inputs.cuda()
        labels = labels.cuda()

        # Reset the parameter gradients
        optimiser.zero_grad()

        # Forward pass, backward pass, and optimise
        outputs = net(inputs)
        
        correct += (torch.argmax(outputs,-1) == labels).float().sum()
        total += len(labels)

    return correct/total

In [5]:
import copy
from tqdm import tqdm 


max_accuracy = 0
for epoch in range(200):  # Train for 200 epochs
    net.train()
    running_loss = 0.0
    train_acc = 0.0
    correct = 0
    pbar = tqdm(train_loader)

    for i, data in enumerate(pbar, 0):
        # Get data from each batch. Data is in the format of [inputs, labels]
        inputs, labels = data
        inputs = inputs.cuda()
        labels = labels.cuda()

        # Reset the parameter gradients
        optimiser.zero_grad()

        # Forward pass, backward pass, and optimise
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimiser.step()

        # Print 
        pbar.set_description(f"loss: {str(loss.item())}")

        running_loss += loss.item()
        correct += (torch.argmax(outputs,-1) == labels).float().sum()
    
    train_acc = correct/len(train_data)
    validation_acc = validation(net,val_loader)

    print(f'[Epoch: {epoch + 1}] train_loss: {running_loss / 127:.3f}, train_acc: {train_acc}, val_acc: {validation_acc}')

    if  validation_acc > max_accuracy:
        max_accuracy = validation_acc
        best_model = copy.deepcopy(net)
print('Finished Training')

loss: 525.00244140625: 100%|██████████| 127/127 [00:19<00:00,  6.36it/s]   


[Epoch: 1] train_loss: 630.218, train_acc: 0.15225425362586975, val_acc: 0.17073170840740204


loss: 439.0318603515625: 100%|██████████| 127/127 [00:19<00:00,  6.50it/s] 


[Epoch: 2] train_loss: 631.825, train_acc: 0.1697462499141693, val_acc: 0.1648189127445221


loss: 564.4920654296875: 100%|██████████| 127/127 [00:19<00:00,  6.43it/s] 


[Epoch: 3] train_loss: 565.850, train_acc: 0.17319537699222565, val_acc: 0.17812268435955048


loss: 769.2908935546875: 100%|██████████| 127/127 [00:19<00:00,  6.44it/s] 


[Epoch: 4] train_loss: 617.431, train_acc: 0.17294900119304657, val_acc: 0.14486326277256012


loss: 464.1841735839844: 100%|██████████| 127/127 [00:19<00:00,  6.42it/s] 


[Epoch: 5] train_loss: 632.493, train_acc: 0.17294900119304657, val_acc: 0.1596452295780182


loss: 562.1242065429688: 100%|██████████| 127/127 [00:33<00:00,  3.74it/s]


[Epoch: 6] train_loss: 620.530, train_acc: 0.16555802524089813, val_acc: 0.18773096799850464


loss: 955.9105224609375: 100%|██████████| 127/127 [00:33<00:00,  3.75it/s] 


[Epoch: 7] train_loss: 618.968, train_acc: 0.1697462499141693, val_acc: 0.13525499403476715


loss: 591.9640502929688: 100%|██████████| 127/127 [00:33<00:00,  3.74it/s] 


[Epoch: 8] train_loss: 637.076, train_acc: 0.17590540647506714, val_acc: 0.1648189127445221


loss: 401.4735107421875: 100%|██████████| 127/127 [00:34<00:00,  3.72it/s] 


[Epoch: 9] train_loss: 587.467, train_acc: 0.17196354269981384, val_acc: 0.1699926108121872


loss: 888.2235717773438: 100%|██████████| 127/127 [00:34<00:00,  3.72it/s] 


[Epoch: 10] train_loss: 600.092, train_acc: 0.1724562793970108, val_acc: 0.16260161995887756


loss: 514.1671142578125: 100%|██████████| 127/127 [00:28<00:00,  4.40it/s] 


[Epoch: 11] train_loss: 629.332, train_acc: 0.17565903067588806, val_acc: 0.1633407175540924


loss: 530.9091186523438: 100%|██████████| 127/127 [00:33<00:00,  3.79it/s] 


[Epoch: 12] train_loss: 612.967, train_acc: 0.17590540647506714, val_acc: 0.15890613198280334


loss: 761.3062133789062: 100%|██████████| 127/127 [00:21<00:00,  5.84it/s] 


[Epoch: 13] train_loss: 612.408, train_acc: 0.171224445104599, val_acc: 0.17960087954998016


loss: 325.61895751953125: 100%|██████████| 127/127 [00:33<00:00,  3.75it/s]


[Epoch: 14] train_loss: 611.512, train_acc: 0.17442721128463745, val_acc: 0.16260161995887756


loss: 349.79119873046875: 100%|██████████| 127/127 [00:19<00:00,  6.38it/s]


[Epoch: 15] train_loss: 610.169, train_acc: 0.17738360166549683, val_acc: 0.17590539157390594


loss: 699.1937255859375: 100%|██████████| 127/127 [00:21<00:00,  5.90it/s] 


[Epoch: 16] train_loss: 610.949, train_acc: 0.1724562793970108, val_acc: 0.18107907474040985


loss: 256.6134033203125: 100%|██████████| 127/127 [00:33<00:00,  3.76it/s] 


[Epoch: 17] train_loss: 570.651, train_acc: 0.16826805472373962, val_acc: 0.1648189127445221


loss: 513.6168823242188: 100%|██████████| 127/127 [00:25<00:00,  4.91it/s] 


[Epoch: 18] train_loss: 628.704, train_acc: 0.16900715231895447, val_acc: 0.1426459699869156


loss: 583.4391479492188: 100%|██████████| 127/127 [00:33<00:00,  3.74it/s]


[Epoch: 19] train_loss: 587.180, train_acc: 0.16654348373413086, val_acc: 0.14412416517734528


loss: 484.9656066894531: 100%|██████████| 127/127 [00:33<00:00,  3.74it/s] 


[Epoch: 20] train_loss: 608.444, train_acc: 0.17467357218265533, val_acc: 0.16260161995887756


loss: 807.2803955078125: 100%|██████████| 127/127 [00:24<00:00,  5.24it/s] 


[Epoch: 21] train_loss: 631.802, train_acc: 0.17491993308067322, val_acc: 0.14486326277256012


loss: 468.96771240234375: 100%|██████████| 127/127 [00:28<00:00,  4.47it/s]


[Epoch: 22] train_loss: 597.493, train_acc: 0.1778763234615326, val_acc: 0.15077605843544006


loss: 1028.8826904296875: 100%|██████████| 127/127 [00:29<00:00,  4.25it/s]


[Epoch: 23] train_loss: 610.094, train_acc: 0.17491993308067322, val_acc: 0.15077605843544006


loss: 525.6846923828125: 100%|██████████| 127/127 [00:33<00:00,  3.74it/s] 


[Epoch: 24] train_loss: 569.237, train_acc: 0.17442721128463745, val_acc: 0.1529933512210846


loss: 762.779052734375: 100%|██████████| 127/127 [00:33<00:00,  3.74it/s]  


[Epoch: 25] train_loss: 638.026, train_acc: 0.171224445104599, val_acc: 0.1818181872367859


loss: 418.8675537109375: 100%|██████████| 127/127 [00:34<00:00,  3.73it/s] 


[Epoch: 26] train_loss: 615.132, train_acc: 0.17541266977787018, val_acc: 0.17738358676433563


loss: 614.3466186523438: 100%|██████████| 127/127 [00:34<00:00,  3.71it/s] 


[Epoch: 27] train_loss: 595.438, train_acc: 0.1685144156217575, val_acc: 0.17590539157390594


loss: 472.6477355957031: 100%|██████████| 127/127 [00:34<00:00,  3.72it/s] 


[Epoch: 28] train_loss: 608.896, train_acc: 0.15373244881629944, val_acc: 0.1685144156217575


loss: 360.24920654296875: 100%|██████████| 127/127 [00:34<00:00,  3.72it/s]


[Epoch: 29] train_loss: 606.937, train_acc: 0.17713722586631775, val_acc: 0.15373244881629944


loss: 647.0698852539062: 100%|██████████| 127/127 [00:26<00:00,  4.77it/s] 


[Epoch: 30] train_loss: 605.375, train_acc: 0.16949988901615143, val_acc: 0.17590539157390594


loss: 636.6751098632812: 100%|██████████| 127/127 [00:28<00:00,  4.53it/s] 


[Epoch: 31] train_loss: 604.648, train_acc: 0.16826805472373962, val_acc: 0.1818181872367859


loss: 719.577392578125: 100%|██████████| 127/127 [00:34<00:00,  3.73it/s]  


[Epoch: 32] train_loss: 580.350, train_acc: 0.17171718180179596, val_acc: 0.1766444891691208


loss: 569.102294921875: 100%|██████████| 127/127 [00:33<00:00,  3.74it/s]  


[Epoch: 33] train_loss: 582.186, train_acc: 0.17836906015872955, val_acc: 0.17886178195476532


loss: 709.4117431640625: 100%|██████████| 127/127 [00:20<00:00,  6.12it/s] 


[Epoch: 34] train_loss: 627.760, train_acc: 0.17319537699222565, val_acc: 0.1515151560306549


loss: 330.3431701660156: 100%|██████████| 127/127 [00:19<00:00,  6.47it/s] 


[Epoch: 35] train_loss: 602.383, train_acc: 0.17836906015872955, val_acc: 0.15594974160194397


loss: 472.9670715332031: 100%|██████████| 127/127 [00:19<00:00,  6.51it/s] 


[Epoch: 36] train_loss: 605.498, train_acc: 0.1778763234615326, val_acc: 0.18403548002243042


loss: 524.26611328125: 100%|██████████| 127/127 [00:20<00:00,  6.30it/s]   


[Epoch: 37] train_loss: 614.606, train_acc: 0.18452821671962738, val_acc: 0.1869918704032898


loss: 706.6121826171875: 100%|██████████| 127/127 [00:20<00:00,  6.31it/s] 


[Epoch: 38] train_loss: 583.481, train_acc: 0.17491993308067322, val_acc: 0.1699926108121872


loss: 493.126220703125: 100%|██████████| 127/127 [00:19<00:00,  6.43it/s]  


[Epoch: 39] train_loss: 588.783, train_acc: 0.17738360166549683, val_acc: 0.14412416517734528


loss: 584.5145263671875: 100%|██████████| 127/127 [00:19<00:00,  6.50it/s] 


[Epoch: 40] train_loss: 589.708, train_acc: 0.17294900119304657, val_acc: 0.16703622043132782


loss: 510.93023681640625: 100%|██████████| 127/127 [00:19<00:00,  6.49it/s]


[Epoch: 41] train_loss: 633.886, train_acc: 0.17319537699222565, val_acc: 0.19660013914108276


loss: 585.3780517578125: 100%|██████████| 127/127 [00:19<00:00,  6.43it/s] 


[Epoch: 42] train_loss: 605.013, train_acc: 0.17491993308067322, val_acc: 0.1566888391971588


loss: 460.39373779296875: 100%|██████████| 127/127 [00:19<00:00,  6.46it/s]


[Epoch: 43] train_loss: 607.310, train_acc: 0.18231092393398285, val_acc: 0.16186252236366272


loss: 731.5404663085938: 100%|██████████| 127/127 [00:19<00:00,  6.49it/s] 


[Epoch: 44] train_loss: 636.800, train_acc: 0.17442721128463745, val_acc: 0.17442719638347626


loss: 615.0128173828125: 100%|██████████| 127/127 [00:19<00:00,  6.48it/s] 


[Epoch: 45] train_loss: 638.146, train_acc: 0.16925351321697235, val_acc: 0.16777531802654266


loss: 595.888916015625: 100%|██████████| 127/127 [00:19<00:00,  6.48it/s]  


[Epoch: 46] train_loss: 613.021, train_acc: 0.17294900119304657, val_acc: 0.16186252236366272


loss: 684.992431640625: 100%|██████████| 127/127 [00:19<00:00,  6.47it/s]  


[Epoch: 47] train_loss: 596.926, train_acc: 0.17713722586631775, val_acc: 0.18625277280807495


loss: 791.025146484375: 100%|██████████| 127/127 [00:19<00:00,  6.48it/s]  


[Epoch: 48] train_loss: 624.735, train_acc: 0.17664450407028198, val_acc: 0.180339977145195


loss: 746.4118041992188: 100%|██████████| 127/127 [00:19<00:00,  6.52it/s] 


[Epoch: 49] train_loss: 588.217, train_acc: 0.16949988901615143, val_acc: 0.15742793679237366


loss: 689.593017578125: 100%|██████████| 127/127 [00:19<00:00,  6.45it/s]  


[Epoch: 50] train_loss: 624.816, train_acc: 0.17491993308067322, val_acc: 0.1566888391971588


loss: 587.3394775390625: 100%|██████████| 127/127 [00:19<00:00,  6.48it/s] 


[Epoch: 51] train_loss: 602.552, train_acc: 0.17960089445114136, val_acc: 0.18329638242721558


loss: 747.4696044921875: 100%|██████████| 127/127 [00:19<00:00,  6.49it/s] 


[Epoch: 52] train_loss: 619.085, train_acc: 0.17344173789024353, val_acc: 0.1855136752128601


loss: 468.6373291015625: 100%|██████████| 127/127 [00:20<00:00,  6.32it/s] 


[Epoch: 53] train_loss: 635.084, train_acc: 0.17467357218265533, val_acc: 0.1869918704032898


loss: 734.7359008789062: 100%|██████████| 127/127 [00:19<00:00,  6.44it/s] 


[Epoch: 54] train_loss: 627.482, train_acc: 0.17565903067588806, val_acc: 0.15890613198280334


loss: 932.4485473632812: 100%|██████████| 127/127 [00:19<00:00,  6.46it/s] 


[Epoch: 55] train_loss: 626.410, train_acc: 0.16087706387043, val_acc: 0.13377678394317627


loss: 658.461181640625: 100%|██████████| 127/127 [00:20<00:00,  6.27it/s]  


[Epoch: 56] train_loss: 595.361, train_acc: 0.1776299625635147, val_acc: 0.1988174468278885


loss: 791.6452026367188: 100%|██████████| 127/127 [00:19<00:00,  6.41it/s] 


[Epoch: 57] train_loss: 612.070, train_acc: 0.16949988901615143, val_acc: 0.18994826078414917


loss: 523.5369873046875: 100%|██████████| 127/127 [00:19<00:00,  6.45it/s] 


[Epoch: 58] train_loss: 624.737, train_acc: 0.16580438613891602, val_acc: 0.16925351321697235


loss: 543.563720703125: 100%|██████████| 127/127 [00:19<00:00,  6.53it/s]  


[Epoch: 59] train_loss: 616.910, train_acc: 0.1618625372648239, val_acc: 0.14560236036777496


loss: 1083.7818603515625: 100%|██████████| 127/127 [00:19<00:00,  6.49it/s]


[Epoch: 60] train_loss: 639.900, train_acc: 0.16407983005046844, val_acc: 0.1581670343875885


loss: 708.7901611328125: 100%|██████████| 127/127 [00:19<00:00,  6.47it/s] 


[Epoch: 61] train_loss: 626.308, train_acc: 0.1727026402950287, val_acc: 0.1751662939786911


loss: 745.0624389648438: 100%|██████████| 127/127 [00:19<00:00,  6.48it/s] 


[Epoch: 62] train_loss: 621.671, train_acc: 0.17196354269981384, val_acc: 0.1699926108121872


loss: 561.6641845703125: 100%|██████████| 127/127 [00:20<00:00,  6.35it/s] 


[Epoch: 63] train_loss: 589.418, train_acc: 0.17491993308067322, val_acc: 0.18847006559371948


loss: 697.4299926757812: 100%|██████████| 127/127 [00:19<00:00,  6.41it/s] 


[Epoch: 64] train_loss: 621.179, train_acc: 0.17565903067588806, val_acc: 0.18403548002243042


loss: 501.7381591796875: 100%|██████████| 127/127 [00:19<00:00,  6.49it/s] 


[Epoch: 65] train_loss: 607.686, train_acc: 0.18058635294437408, val_acc: 0.16112342476844788


loss: 367.2119140625: 100%|██████████| 127/127 [00:19<00:00,  6.47it/s]    


[Epoch: 66] train_loss: 613.957, train_acc: 0.17467357218265533, val_acc: 0.15890613198280334


loss: 495.1138916015625: 100%|██████████| 127/127 [00:19<00:00,  6.39it/s] 


[Epoch: 67] train_loss: 625.370, train_acc: 0.17491993308067322, val_acc: 0.18329638242721558


loss: 511.3877258300781: 100%|██████████| 127/127 [00:19<00:00,  6.51it/s] 


[Epoch: 68] train_loss: 655.360, train_acc: 0.16629712283611298, val_acc: 0.1685144156217575


loss: 1251.4429931640625: 100%|██████████| 127/127 [00:19<00:00,  6.50it/s]


[Epoch: 69] train_loss: 624.122, train_acc: 0.1645725667476654, val_acc: 0.1869918704032898


loss: 808.5316162109375: 100%|██████████| 127/127 [00:19<00:00,  6.49it/s] 


[Epoch: 70] train_loss: 633.528, train_acc: 0.1685144156217575, val_acc: 0.1766444891691208


loss: 531.0460815429688: 100%|██████████| 127/127 [00:19<00:00,  6.43it/s] 


[Epoch: 71] train_loss: 621.188, train_acc: 0.17048534750938416, val_acc: 0.18477457761764526


loss: 341.3890686035156: 100%|██████████| 127/127 [00:19<00:00,  6.49it/s]


[Epoch: 72] train_loss: 607.407, train_acc: 0.1791081577539444, val_acc: 0.15594974160194397


loss: 456.4604797363281: 100%|██████████| 127/127 [00:19<00:00,  6.48it/s]


[Epoch: 73] train_loss: 606.969, train_acc: 0.17615176737308502, val_acc: 0.17590539157390594


loss: 687.5516967773438: 100%|██████████| 127/127 [00:20<00:00,  6.35it/s] 


[Epoch: 74] train_loss: 623.089, train_acc: 0.17491993308067322, val_acc: 0.18255728483200073


loss: 997.7908935546875: 100%|██████████| 127/127 [00:20<00:00,  6.22it/s] 


[Epoch: 75] train_loss: 581.411, train_acc: 0.1697462499141693, val_acc: 0.1581670343875885


loss: 497.8304748535156: 100%|██████████| 127/127 [00:19<00:00,  6.37it/s] 


[Epoch: 76] train_loss: 610.982, train_acc: 0.16949988901615143, val_acc: 0.15521064400672913


loss: 707.9385375976562: 100%|██████████| 127/127 [00:20<00:00,  6.30it/s] 


[Epoch: 77] train_loss: 636.492, train_acc: 0.1727026402950287, val_acc: 0.15742793679237366


loss: 875.0949096679688: 100%|██████████| 127/127 [00:20<00:00,  6.33it/s] 


[Epoch: 78] train_loss: 612.752, train_acc: 0.18083272874355316, val_acc: 0.17073170840740204


loss: 649.0802001953125: 100%|██████████| 127/127 [00:19<00:00,  6.35it/s] 


[Epoch: 79] train_loss: 632.762, train_acc: 0.1724562793970108, val_acc: 0.18329638242721558


loss: 669.0945434570312: 100%|██████████| 127/127 [00:19<00:00,  6.38it/s] 


[Epoch: 80] train_loss: 633.035, train_acc: 0.16555802524089813, val_acc: 0.1566888391971588


loss: 642.967041015625: 100%|██████████| 127/127 [00:19<00:00,  6.37it/s] 


[Epoch: 81] train_loss: 627.354, train_acc: 0.1687607765197754, val_acc: 0.1463414579629898


loss: 496.04241943359375: 100%|██████████| 127/127 [00:19<00:00,  6.38it/s]


[Epoch: 82] train_loss: 638.527, train_acc: 0.18231092393398285, val_acc: 0.17220990359783173


loss: 518.0361938476562: 100%|██████████| 127/127 [00:19<00:00,  6.37it/s] 


[Epoch: 83] train_loss: 611.449, train_acc: 0.16580438613891602, val_acc: 0.15742793679237366


loss: 907.8805541992188: 100%|██████████| 127/127 [00:19<00:00,  6.48it/s] 


[Epoch: 84] train_loss: 594.389, train_acc: 0.1803399920463562, val_acc: 0.15225425362586975


loss: 418.5775146484375: 100%|██████████| 127/127 [00:19<00:00,  6.45it/s] 


[Epoch: 85] train_loss: 615.069, train_acc: 0.17220990359783173, val_acc: 0.1633407175540924


loss: 658.8968505859375: 100%|██████████| 127/127 [00:20<00:00,  6.33it/s] 


[Epoch: 86] train_loss: 603.023, train_acc: 0.17590540647506714, val_acc: 0.15447154641151428


loss: 701.7274169921875: 100%|██████████| 127/127 [00:19<00:00,  6.48it/s] 


[Epoch: 87] train_loss: 604.476, train_acc: 0.17073170840740204, val_acc: 0.1633407175540924


loss: 497.1729736328125: 100%|██████████| 127/127 [00:19<00:00,  6.46it/s] 


[Epoch: 88] train_loss: 590.762, train_acc: 0.18526731431484222, val_acc: 0.15447154641151428


loss: 630.9636840820312: 100%|██████████| 127/127 [00:19<00:00,  6.44it/s] 


[Epoch: 89] train_loss: 618.887, train_acc: 0.16580438613891602, val_acc: 0.1648189127445221


loss: 977.4827880859375: 100%|██████████| 127/127 [00:19<00:00,  6.46it/s] 


[Epoch: 90] train_loss: 599.360, train_acc: 0.1803399920463562, val_acc: 0.16703622043132782


loss: 726.1366577148438: 100%|██████████| 127/127 [00:19<00:00,  6.44it/s] 


[Epoch: 91] train_loss: 600.989, train_acc: 0.18255728483200073, val_acc: 0.1766444891691208


loss: 1005.5104370117188: 100%|██████████| 127/127 [00:20<00:00,  6.35it/s]


[Epoch: 92] train_loss: 627.850, train_acc: 0.1818181872367859, val_acc: 0.1648189127445221


loss: 529.6542358398438: 100%|██████████| 127/127 [00:19<00:00,  6.37it/s] 


[Epoch: 93] train_loss: 589.661, train_acc: 0.16629712283611298, val_acc: 0.1426459699869156


loss: 599.66845703125: 100%|██████████| 127/127 [00:19<00:00,  6.46it/s]   


[Epoch: 94] train_loss: 659.145, train_acc: 0.17073170840740204, val_acc: 0.1699926108121872


loss: 480.4923095703125: 100%|██████████| 127/127 [00:19<00:00,  6.44it/s] 


[Epoch: 95] train_loss: 615.233, train_acc: 0.16432619094848633, val_acc: 0.20029564201831818


loss: 522.3002319335938: 100%|██████████| 127/127 [00:19<00:00,  6.39it/s] 


[Epoch: 96] train_loss: 589.755, train_acc: 0.171224445104599, val_acc: 0.16260161995887756


loss: 511.4265441894531: 100%|██████████| 127/127 [00:19<00:00,  6.43it/s] 


[Epoch: 97] train_loss: 609.893, train_acc: 0.17442721128463745, val_acc: 0.1736880987882614


loss: 375.0612487792969: 100%|██████████| 127/127 [00:19<00:00,  6.41it/s] 


[Epoch: 98] train_loss: 609.191, train_acc: 0.17418083548545837, val_acc: 0.15225425362586975


loss: 584.48779296875: 100%|██████████| 127/127 [00:19<00:00,  6.46it/s]   


[Epoch: 99] train_loss: 639.129, train_acc: 0.1648189276456833, val_acc: 0.14929784834384918


loss: 578.4373779296875: 100%|██████████| 127/127 [00:19<00:00,  6.46it/s] 


[Epoch: 100] train_loss: 599.950, train_acc: 0.1830500215291977, val_acc: 0.1648189127445221


loss: 645.54931640625: 100%|██████████| 127/127 [00:19<00:00,  6.46it/s]   


[Epoch: 101] train_loss: 618.810, train_acc: 0.16925351321697235, val_acc: 0.16407981514930725


loss: 413.10992431640625: 100%|██████████| 127/127 [00:19<00:00,  6.40it/s]


[Epoch: 102] train_loss: 606.123, train_acc: 0.17467357218265533, val_acc: 0.16703622043132782


loss: 535.2179565429688: 100%|██████████| 127/127 [00:19<00:00,  6.41it/s] 


[Epoch: 103] train_loss: 608.559, train_acc: 0.17615176737308502, val_acc: 0.1699926108121872


loss: 612.2559814453125: 100%|██████████| 127/127 [00:19<00:00,  6.39it/s] 


[Epoch: 104] train_loss: 642.952, train_acc: 0.1685144156217575, val_acc: 0.15447154641151428


loss: 1170.2635498046875: 100%|██████████| 127/127 [00:19<00:00,  6.48it/s]


[Epoch: 105] train_loss: 646.594, train_acc: 0.1645725667476654, val_acc: 0.16555801033973694


loss: 597.742431640625: 100%|██████████| 127/127 [00:19<00:00,  6.43it/s]  


[Epoch: 106] train_loss: 619.589, train_acc: 0.1727026402950287, val_acc: 0.15373244881629944


loss: 627.3414306640625: 100%|██████████| 127/127 [00:19<00:00,  6.47it/s] 


[Epoch: 107] train_loss: 586.181, train_acc: 0.18132545053958893, val_acc: 0.17442719638347626


loss: 736.943603515625: 100%|██████████| 127/127 [00:19<00:00,  6.44it/s]  


[Epoch: 108] train_loss: 569.484, train_acc: 0.1818181872367859, val_acc: 0.1766444891691208


loss: 196.9973602294922: 100%|██████████| 127/127 [00:19<00:00,  6.48it/s] 


[Epoch: 109] train_loss: 556.485, train_acc: 0.18058635294437408, val_acc: 0.1736880987882614


loss: 959.9995727539062: 100%|██████████| 127/127 [00:19<00:00,  6.45it/s] 


[Epoch: 110] train_loss: 647.940, train_acc: 0.17812269926071167, val_acc: 0.16112342476844788


loss: 499.5235900878906: 100%|██████████| 127/127 [00:19<00:00,  6.42it/s] 


[Epoch: 111] train_loss: 593.563, train_acc: 0.17590540647506714, val_acc: 0.16703622043132782


loss: 625.3121337890625: 100%|██████████| 127/127 [00:19<00:00,  6.48it/s] 


[Epoch: 112] train_loss: 598.488, train_acc: 0.18329638242721558, val_acc: 0.17812268435955048


loss: 407.2275390625: 100%|██████████| 127/127 [00:19<00:00,  6.40it/s]    


[Epoch: 113] train_loss: 605.896, train_acc: 0.17812269926071167, val_acc: 0.14855875074863434


loss: 917.8071899414062: 100%|██████████| 127/127 [00:19<00:00,  6.43it/s] 


[Epoch: 114] train_loss: 620.826, train_acc: 0.17147080600261688, val_acc: 0.16112342476844788


loss: 446.7359313964844: 100%|██████████| 127/127 [00:19<00:00,  6.48it/s] 


[Epoch: 115] train_loss: 616.708, train_acc: 0.16802167892456055, val_acc: 0.1988174468278885


loss: 684.9019775390625: 100%|██████████| 127/127 [00:19<00:00,  6.46it/s] 


[Epoch: 116] train_loss: 616.903, train_acc: 0.17344173789024353, val_acc: 0.14412416517734528


loss: 682.1663818359375: 100%|██████████| 127/127 [00:19<00:00,  6.41it/s] 


[Epoch: 117] train_loss: 636.491, train_acc: 0.1687607765197754, val_acc: 0.1566888391971588


loss: 399.27484130859375: 100%|██████████| 127/127 [00:19<00:00,  6.43it/s]


[Epoch: 118] train_loss: 649.108, train_acc: 0.1569352149963379, val_acc: 0.13821138441562653


loss: 808.4561767578125: 100%|██████████| 127/127 [00:19<00:00,  6.43it/s] 


[Epoch: 119] train_loss: 626.130, train_acc: 0.17147080600261688, val_acc: 0.1529933512210846


loss: 583.0791625976562: 100%|██████████| 127/127 [00:19<00:00,  6.43it/s] 


[Epoch: 120] train_loss: 604.796, train_acc: 0.1763981282711029, val_acc: 0.16407981514930725


loss: 894.3540649414062: 100%|██████████| 127/127 [00:19<00:00,  6.46it/s] 


[Epoch: 121] train_loss: 618.554, train_acc: 0.1842818558216095, val_acc: 0.14929784834384918


loss: 495.89508056640625: 100%|██████████| 127/127 [00:19<00:00,  6.45it/s]


[Epoch: 122] train_loss: 626.864, train_acc: 0.16777531802654266, val_acc: 0.14190687239170074


loss: 598.428466796875: 100%|██████████| 127/127 [00:19<00:00,  6.49it/s] 


[Epoch: 123] train_loss: 605.528, train_acc: 0.17590540647506714, val_acc: 0.17220990359783173


loss: 854.9688720703125: 100%|██████████| 127/127 [00:19<00:00,  6.44it/s] 


[Epoch: 124] train_loss: 619.092, train_acc: 0.16580438613891602, val_acc: 0.1396895796060562


loss: 285.36395263671875: 100%|██████████| 127/127 [00:20<00:00,  6.29it/s]


[Epoch: 125] train_loss: 576.217, train_acc: 0.17344173789024353, val_acc: 0.16260161995887756


loss: 921.1688842773438: 100%|██████████| 127/127 [00:19<00:00,  6.42it/s] 


[Epoch: 126] train_loss: 622.994, train_acc: 0.1788617968559265, val_acc: 0.1581670343875885


loss: 575.216064453125: 100%|██████████| 127/127 [00:19<00:00,  6.45it/s]  


[Epoch: 127] train_loss: 646.564, train_acc: 0.16703622043132782, val_acc: 0.15890613198280334


loss: 473.0589294433594: 100%|██████████| 127/127 [00:20<00:00,  6.32it/s] 


[Epoch: 128] train_loss: 613.166, train_acc: 0.17097808420658112, val_acc: 0.18329638242721558


loss: 504.7379150390625: 100%|██████████| 127/127 [00:19<00:00,  6.48it/s] 


[Epoch: 129] train_loss: 636.599, train_acc: 0.17664450407028198, val_acc: 0.1736880987882614


loss: 564.6982421875: 100%|██████████| 127/127 [00:19<00:00,  6.43it/s]   


[Epoch: 130] train_loss: 623.471, train_acc: 0.16949988901615143, val_acc: 0.1596452295780182


loss: 544.5391845703125: 100%|██████████| 127/127 [00:20<00:00,  6.33it/s]


[Epoch: 131] train_loss: 599.411, train_acc: 0.16925351321697235, val_acc: 0.17738358676433563


loss: 703.2549438476562: 100%|██████████| 127/127 [00:20<00:00,  6.20it/s] 


[Epoch: 132] train_loss: 586.853, train_acc: 0.18477457761764526, val_acc: 0.15521064400672913


loss: 667.59375: 100%|██████████| 127/127 [00:19<00:00,  6.42it/s]         


[Epoch: 133] train_loss: 649.499, train_acc: 0.16531166434288025, val_acc: 0.16629712283611298


loss: 572.0355834960938: 100%|██████████| 127/127 [00:19<00:00,  6.38it/s] 


[Epoch: 134] train_loss: 610.220, train_acc: 0.17294900119304657, val_acc: 0.1566888391971588


loss: 698.9525756835938: 100%|██████████| 127/127 [00:20<00:00,  6.34it/s] 


[Epoch: 135] train_loss: 604.535, train_acc: 0.17023898661136627, val_acc: 0.15742793679237366


loss: 756.6956176757812: 100%|██████████| 127/127 [00:19<00:00,  6.37it/s] 


[Epoch: 136] train_loss: 614.032, train_acc: 0.1724562793970108, val_acc: 0.16925351321697235


loss: 943.5932006835938: 100%|██████████| 127/127 [00:19<00:00,  6.44it/s] 


[Epoch: 137] train_loss: 626.639, train_acc: 0.17590540647506714, val_acc: 0.1855136752128601


loss: 640.6350708007812: 100%|██████████| 127/127 [00:19<00:00,  6.37it/s] 


[Epoch: 138] train_loss: 595.163, train_acc: 0.17590540647506714, val_acc: 0.17886178195476532


loss: 511.9629211425781: 100%|██████████| 127/127 [00:19<00:00,  6.43it/s] 


[Epoch: 139] train_loss: 585.243, train_acc: 0.17689086496829987, val_acc: 0.16112342476844788


loss: 660.4241333007812: 100%|██████████| 127/127 [00:19<00:00,  6.40it/s]


[Epoch: 140] train_loss: 627.482, train_acc: 0.17861542105674744, val_acc: 0.15521064400672913


loss: 593.1508178710938: 100%|██████████| 127/127 [00:19<00:00,  6.40it/s] 


[Epoch: 141] train_loss: 605.296, train_acc: 0.18083272874355316, val_acc: 0.190687358379364


loss: 674.01025390625: 100%|██████████| 127/127 [00:19<00:00,  6.37it/s]   


[Epoch: 142] train_loss: 628.053, train_acc: 0.171224445104599, val_acc: 0.18329638242721558


loss: 726.2817993164062: 100%|██████████| 127/127 [00:19<00:00,  6.35it/s] 


[Epoch: 143] train_loss: 637.405, train_acc: 0.1724562793970108, val_acc: 0.18847006559371948


loss: 434.0119323730469: 100%|██████████| 127/127 [00:19<00:00,  6.38it/s] 


[Epoch: 144] train_loss: 629.569, train_acc: 0.18058635294437408, val_acc: 0.1751662939786911


loss: 663.766357421875: 100%|██████████| 127/127 [00:23<00:00,  5.33it/s]  


[Epoch: 145] train_loss: 598.285, train_acc: 0.1869918704032898, val_acc: 0.14708055555820465


loss: 532.1539916992188: 100%|██████████| 127/127 [00:20<00:00,  6.29it/s] 


[Epoch: 146] train_loss: 647.320, train_acc: 0.17319537699222565, val_acc: 0.16038432717323303


loss: 451.1240539550781: 100%|██████████| 127/127 [00:20<00:00,  6.32it/s] 


[Epoch: 147] train_loss: 629.403, train_acc: 0.17171718180179596, val_acc: 0.17220990359783173


loss: 1087.54052734375: 100%|██████████| 127/127 [00:20<00:00,  6.30it/s]  


[Epoch: 148] train_loss: 612.901, train_acc: 0.171224445104599, val_acc: 0.16555801033973694


loss: 638.0059204101562: 100%|██████████| 127/127 [00:19<00:00,  6.36it/s] 


[Epoch: 149] train_loss: 615.368, train_acc: 0.16802167892456055, val_acc: 0.1529933512210846


loss: 541.0813598632812: 100%|██████████| 127/127 [00:20<00:00,  6.30it/s] 


[Epoch: 150] train_loss: 603.220, train_acc: 0.16531166434288025, val_acc: 0.1855136752128601


loss: 557.7345581054688: 100%|██████████| 127/127 [00:20<00:00,  6.29it/s] 


[Epoch: 151] train_loss: 561.326, train_acc: 0.1842818558216095, val_acc: 0.1699926108121872


loss: 421.41058349609375: 100%|██████████| 127/127 [00:20<00:00,  6.28it/s]


[Epoch: 152] train_loss: 593.072, train_acc: 0.16752895712852478, val_acc: 0.1396895796060562


loss: 256.11041259765625: 100%|██████████| 127/127 [00:19<00:00,  6.36it/s]


[Epoch: 153] train_loss: 606.039, train_acc: 0.1660507619380951, val_acc: 0.16777531802654266


loss: 452.8116455078125: 100%|██████████| 127/127 [00:24<00:00,  5.14it/s] 


[Epoch: 154] train_loss: 621.082, train_acc: 0.1763981282711029, val_acc: 0.17960087954998016


loss: 536.7647705078125: 100%|██████████| 127/127 [00:29<00:00,  4.35it/s] 


[Epoch: 155] train_loss: 614.187, train_acc: 0.18083272874355316, val_acc: 0.1736880987882614


loss: 750.3065795898438: 100%|██████████| 127/127 [00:20<00:00,  6.17it/s] 


[Epoch: 156] train_loss: 634.467, train_acc: 0.17565903067588806, val_acc: 0.16629712283611298


loss: 914.0381469726562: 100%|██████████| 127/127 [00:31<00:00,  4.06it/s] 


[Epoch: 157] train_loss: 588.220, train_acc: 0.18871644139289856, val_acc: 0.180339977145195


loss: 487.6014099121094: 100%|██████████| 127/127 [00:34<00:00,  3.71it/s] 


[Epoch: 158] train_loss: 599.714, train_acc: 0.1842818558216095, val_acc: 0.1736880987882614


loss: 680.040771484375: 100%|██████████| 127/127 [00:34<00:00,  3.73it/s]  


[Epoch: 159] train_loss: 596.135, train_acc: 0.17442721128463745, val_acc: 0.17220990359783173


loss: 608.5690307617188: 100%|██████████| 127/127 [00:34<00:00,  3.71it/s] 


[Epoch: 160] train_loss: 606.673, train_acc: 0.17590540647506714, val_acc: 0.15225425362586975


loss: 733.914306640625: 100%|██████████| 127/127 [00:34<00:00,  3.71it/s]  


[Epoch: 161] train_loss: 613.684, train_acc: 0.17418083548545837, val_acc: 0.1936437487602234


loss: 581.142333984375: 100%|██████████| 127/127 [00:34<00:00,  3.71it/s]  


[Epoch: 162] train_loss: 616.478, train_acc: 0.1736880987882614, val_acc: 0.16038432717323303


loss: 896.1629028320312: 100%|██████████| 127/127 [00:34<00:00,  3.70it/s] 


[Epoch: 163] train_loss: 621.451, train_acc: 0.17220990359783173, val_acc: 0.17147080600261688


loss: 636.47119140625: 100%|██████████| 127/127 [00:34<00:00,  3.72it/s]   


[Epoch: 164] train_loss: 611.307, train_acc: 0.16949988901615143, val_acc: 0.1478196531534195


loss: 565.0867919921875: 100%|██████████| 127/127 [00:34<00:00,  3.73it/s] 


[Epoch: 165] train_loss: 620.407, train_acc: 0.17344173789024353, val_acc: 0.1751662939786911


loss: 340.9088439941406: 100%|██████████| 127/127 [00:34<00:00,  3.72it/s]


[Epoch: 166] train_loss: 597.466, train_acc: 0.17738360166549683, val_acc: 0.17294900119304657


loss: 520.9467163085938: 100%|██████████| 127/127 [00:34<00:00,  3.72it/s] 


[Epoch: 167] train_loss: 617.729, train_acc: 0.1763981282711029, val_acc: 0.1766444891691208


loss: 571.2045288085938: 100%|██████████| 127/127 [00:34<00:00,  3.72it/s] 


[Epoch: 168] train_loss: 642.320, train_acc: 0.17319537699222565, val_acc: 0.1478196531534195


loss: 440.2700500488281: 100%|██████████| 127/127 [00:34<00:00,  3.72it/s] 


[Epoch: 169] train_loss: 613.753, train_acc: 0.17615176737308502, val_acc: 0.14560236036777496


loss: 418.7810974121094: 100%|██████████| 127/127 [00:34<00:00,  3.72it/s] 


[Epoch: 170] train_loss: 603.576, train_acc: 0.17984725534915924, val_acc: 0.16629712283611298


loss: 686.145263671875: 100%|██████████| 127/127 [00:34<00:00,  3.72it/s]  


[Epoch: 171] train_loss: 637.455, train_acc: 0.1727026402950287, val_acc: 0.17812268435955048


loss: 396.2427978515625: 100%|██████████| 127/127 [00:30<00:00,  4.20it/s] 


[Epoch: 172] train_loss: 596.856, train_acc: 0.1778763234615326, val_acc: 0.180339977145195


loss: 479.16448974609375: 100%|██████████| 127/127 [00:33<00:00,  3.77it/s]


[Epoch: 173] train_loss: 617.548, train_acc: 0.181571826338768, val_acc: 0.18403548002243042


loss: 538.620849609375: 100%|██████████| 127/127 [00:27<00:00,  4.62it/s]  


[Epoch: 174] train_loss: 617.321, train_acc: 0.16900715231895447, val_acc: 0.14338506758213043


loss: 457.8464050292969: 100%|██████████| 127/127 [00:20<00:00,  6.32it/s] 


[Epoch: 175] train_loss: 592.748, train_acc: 0.18132545053958893, val_acc: 0.17812268435955048


loss: 1000.5257568359375: 100%|██████████| 127/127 [00:19<00:00,  6.40it/s]


[Epoch: 176] train_loss: 657.748, train_acc: 0.16260163486003876, val_acc: 0.14560236036777496


loss: 653.1442260742188: 100%|██████████| 127/127 [00:19<00:00,  6.47it/s] 


[Epoch: 177] train_loss: 600.102, train_acc: 0.17541266977787018, val_acc: 0.1581670343875885


loss: 740.8121948242188: 100%|██████████| 127/127 [00:20<00:00,  6.30it/s] 


[Epoch: 178] train_loss: 631.287, train_acc: 0.16432619094848633, val_acc: 0.16260161995887756


loss: 664.4813842773438: 100%|██████████| 127/127 [00:19<00:00,  6.39it/s] 


[Epoch: 179] train_loss: 619.233, train_acc: 0.16703622043132782, val_acc: 0.15742793679237366


loss: 611.7281494140625: 100%|██████████| 127/127 [00:19<00:00,  6.42it/s] 


[Epoch: 180] train_loss: 592.163, train_acc: 0.1776299625635147, val_acc: 0.14929784834384918


loss: 293.9974060058594: 100%|██████████| 127/127 [00:19<00:00,  6.38it/s] 


[Epoch: 181] train_loss: 587.403, train_acc: 0.1830500215291977, val_acc: 0.17812268435955048


loss: 723.8850708007812: 100%|██████████| 127/127 [00:19<00:00,  6.36it/s] 


[Epoch: 182] train_loss: 573.398, train_acc: 0.1776299625635147, val_acc: 0.16555801033973694


loss: 601.6090698242188: 100%|██████████| 127/127 [00:19<00:00,  6.38it/s] 


[Epoch: 183] train_loss: 593.881, train_acc: 0.17590540647506714, val_acc: 0.14929784834384918


loss: 685.3383178710938: 100%|██████████| 127/127 [00:19<00:00,  6.38it/s] 


[Epoch: 184] train_loss: 598.834, train_acc: 0.17861542105674744, val_acc: 0.15373244881629944


loss: 375.96185302734375: 100%|██████████| 127/127 [00:19<00:00,  6.48it/s]


[Epoch: 185] train_loss: 609.563, train_acc: 0.18058635294437408, val_acc: 0.18477457761764526


loss: 609.2258911132812: 100%|██████████| 127/127 [00:19<00:00,  6.50it/s] 


[Epoch: 186] train_loss: 650.343, train_acc: 0.1724562793970108, val_acc: 0.1699926108121872


loss: 809.82568359375: 100%|██████████| 127/127 [00:19<00:00,  6.39it/s]   


[Epoch: 187] train_loss: 598.930, train_acc: 0.1778763234615326, val_acc: 0.1515151560306549


loss: 682.0501098632812: 100%|██████████| 127/127 [00:19<00:00,  6.43it/s] 


[Epoch: 188] train_loss: 617.832, train_acc: 0.16752895712852478, val_acc: 0.17073170840740204


loss: 525.2120361328125: 100%|██████████| 127/127 [00:19<00:00,  6.40it/s] 


[Epoch: 189] train_loss: 610.593, train_acc: 0.17738360166549683, val_acc: 0.16629712283611298


loss: 455.6199645996094: 100%|██████████| 127/127 [00:19<00:00,  6.40it/s] 


[Epoch: 190] train_loss: 637.421, train_acc: 0.16752895712852478, val_acc: 0.16703622043132782


loss: 930.4839477539062: 100%|██████████| 127/127 [00:19<00:00,  6.42it/s]


[Epoch: 191] train_loss: 572.012, train_acc: 0.181571826338768, val_acc: 0.1699926108121872


loss: 515.6618041992188: 100%|██████████| 127/127 [00:19<00:00,  6.39it/s] 


[Epoch: 192] train_loss: 623.903, train_acc: 0.171224445104599, val_acc: 0.1988174468278885


loss: 709.21484375: 100%|██████████| 127/127 [00:19<00:00,  6.43it/s]      


[Epoch: 193] train_loss: 610.522, train_acc: 0.17713722586631775, val_acc: 0.14708055555820465


loss: 505.7580261230469: 100%|██████████| 127/127 [00:19<00:00,  6.39it/s] 


[Epoch: 194] train_loss: 609.100, train_acc: 0.17220990359783173, val_acc: 0.1766444891691208


loss: 435.166015625: 100%|██████████| 127/127 [00:19<00:00,  6.38it/s]     


[Epoch: 195] train_loss: 581.769, train_acc: 0.1830500215291977, val_acc: 0.17590539157390594


loss: 693.4651489257812: 100%|██████████| 127/127 [00:20<00:00,  6.32it/s] 


[Epoch: 196] train_loss: 610.285, train_acc: 0.17713722586631775, val_acc: 0.15373244881629944


loss: 397.0137939453125: 100%|██████████| 127/127 [00:20<00:00,  6.30it/s] 


[Epoch: 197] train_loss: 565.536, train_acc: 0.1882237046957016, val_acc: 0.1766444891691208


loss: 622.226806640625: 100%|██████████| 127/127 [00:20<00:00,  6.31it/s]  


[Epoch: 198] train_loss: 595.674, train_acc: 0.17713722586631775, val_acc: 0.16555801033973694


loss: 602.2128295898438: 100%|██████████| 127/127 [00:19<00:00,  6.40it/s] 


[Epoch: 199] train_loss: 604.882, train_acc: 0.1803399920463562, val_acc: 0.16925351321697235


loss: 740.0032348632812: 100%|██████████| 127/127 [00:19<00:00,  6.44it/s] 


[Epoch: 200] train_loss: 581.223, train_acc: 0.17467357218265533, val_acc: 0.15890613198280334
Finished Training


## Save/Load model.

In [6]:
model_dir = os.path.join(os.getcwd(), 'baseline_model', 'logistic_regression')
if not os.path.exists(model_dir):
    os.makedirs(model_dir)
torch.save(best_model, os.path.join(model_dir, 'model.pth'))
best_model = torch.load(os.path.join(model_dir, 'model.pth'))

## Validate the models.

In [7]:
print(f"Accuracy on validation data: {validation(best_model, test_loader)}")

Accuracy on validation data: 0.210643008351326


In [8]:
transform_test = transforms.Compose([
    transforms.Resize((224,224)),
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
])

def validation2(net,image_path,transform):
    net.eval()
    correct=0
    total = 0
    image = Image.open(image_path).convert('RGB')
    inputs = transform(image)
    inputs = inputs.cuda()
    inputs = torch.unsqueeze(inputs, 0)
      
        # forward + backward + optimize
    outputs = net(inputs)

    return torch.argmax(outputs,-1)

validation2(best_model, 'IMG_0640.PNG', transform_test)

tensor([0], device='cuda:0')