In [1]:
import pandas as pd
import numpy as np
import torch
from torchvision import transforms, datasets
from torch.utils.data import DataLoader
from matplotlib import pyplot as plt
import os
from torch.utils.data import Subset
import math
from torch.functional import F
import pennylane as qml

# Hyperparameters
batch_size = 16
load_saved_model = False
training_epoch = 150
learning_rate = 0.0008
momentum = 0.9
weight_decay = 0.000005
device = "cuda"

In [2]:
fraction_size = 0.3

root_folder = "AIDER"
transform = transforms.Compose([
    transforms.Resize((224, 224)),  # Resize images to 256x256
    transforms.ToTensor(),          # Convert images to Tensor
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),  # Normalize with ImageNet stats
])
dataset = datasets.ImageFolder(root=root_folder, transform=transform)

# Split train and test and put in dataloader
full_dataset_size = len(dataset)
subset_size = int(fraction_size * full_dataset_size)
subset_indices = torch.randperm(full_dataset_size)[:subset_size].tolist()
subset = Subset(dataset, subset_indices)

train_size = int(0.9 * len(subset))
test_size = len(subset) - train_size

train_dataset, test_dataset = torch.utils.data.random_split(subset, [train_size, test_size])
dataloader_train = DataLoader(train_dataset, batch_size=32, shuffle=True)
dataloader_test = DataLoader(test_dataset, batch_size=32, shuffle=False)

print("Train size: ", len(train_dataset), ", Test size: ", len(test_dataset))

Train size:  1736 , Test size:  193


In [3]:
# QUANTUM BLOCK
n_qubits = 4
dev = qml.device("default.qubit", wires=n_qubits)
n_layers = 4
weight_shapes = {"weights": (n_layers, n_qubits)}
dev_quantum = torch.device(device)

# Define the quantum node
@qml.qnode(dev)
def qnode(inputs, weights):
    qml.AmplitudeEmbedding(inputs, wires=range(n_qubits), normalize=True, pad_with=0.0)
    qml.BasicEntanglerLayers(weights, wires=range(n_qubits))
    return [qml.expval(qml.PauliZ(wires=i)) for i in range(n_qubits)]

expanded_circuit = qml.transforms.broadcast_expand(qnode)
class QNet(torch.nn.Module):
    def __init__(self, n_embd):
        super().__init__()
        self.qlayer = qml.qnn.TorchLayer(expanded_circuit, weight_shapes)

    def forward(self, x):
        
        x = self.qlayer(x.to('cpu')).to(dev_quantum)  # Ensure compatibility with quantum layer

        return x

In [4]:
# from architectures.vit import ViT
from vit_pytorch import ViT
from torch import optim
from torch import nn
from torchvision import models

#self.features = models.resnet34(pretrained =True)

class QCNN(nn.Module): 
    def __init__(self) -> None:
        super(QCNN, self).__init__()
        # self.features = models.resnet34(pretrained =True)
        self.vit = ViT(
            image_size = 256,
            patch_size = 32,
            num_classes = 2,
            dim = 1024,
            depth = 10,
            heads = 16,
            mlp_dim = 2048,
            dropout = 0.1,
            emb_dropout = 0.1
        )
        self.qnet = QNet(4) # From 2 to 16
        self.seq = nn.Sequential(
            nn.ReLU(inplace=True),
            nn.Linear(16, 5)
        ) # From 16 to 5

    def forward(self, x):
        x = self.vit(x)
        x = self.qnet(x)
        x = self.seq(x)
        return x


model = ViT(
    image_size = 256,
    patch_size = 16,
    num_classes = 5,
    dim = 1024,
    depth = 10,
    heads = 16,
    mlp_dim = 2048,
    dropout = 0.1,
    emb_dropout = 0.1
).to(device)

optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
example = next(iter(dataloader_train))
print(model(example[0].to(device)).shape)

torch.Size([32, 5])


In [5]:
from tqdm import tqdm
from sklearn.metrics import accuracy_score, roc_auc_score, roc_curve, auc
import torch.nn.functional as F

class Engine(object):
    def __init__(self, model, optimizer, device, ema=None):
        # Initialize the Engine with the model, optimizer, and the device it's running on.
        self.model = model
        self.optimizer = optimizer
        self.device = device
        # Current epoch of training.
        self.cur_epoch = 0
        # Number of iterations the training has run.
        self.cur_iter = 0
        # The best validation epoch, used to track the epoch with the best validation performance.
        self.bestval_epoch = 0
        # Lists to track the training and validation losses.
        self.train_loss = []
        self.val_loss = []
        # Criterion for calculating loss. Here, it's Mean Squared Error Loss for regression tasks.
        self.criterion = torch.nn.CrossEntropyLoss()

    """ Block to begin training """
    def train(self, dataloader_train):
        loss_epoch = 0.
        num_batches = 0
        # Set the model to training mode.
        self.model.train()
        
        # Train loop
        # tqdm is used to display the training progress for each epoch.
        pbar = tqdm(dataloader_train, desc='Train Epoch {}'.format(self.cur_epoch))
        for data in pbar:
            # efficiently zero gradients
            # Zero the gradients before running the backward pass.
            self.optimizer.zero_grad(set_to_none=True)
            images = data[0].to(self.device, dtype=torch.float32)   # Image that will be fed into network
            gt_label = data[1].to(self.device, dtype=torch.long)

            # Pass the images through the model to get predictions.
            pred_label = self.model(images)
            
            # Calculate the loss, backpropagation, and optimization
            loss = self.criterion(pred_label, gt_label)
            loss.backward()
            
            # Perform a single optimization step (parameter update).
            self.optimizer.step()

            # Aggregate the loss for the epoch
            loss_epoch += float(loss.item())
            num_batches += 1
            pbar.set_description("Loss: {:.4f}".format(loss.item()))
            
        pbar.close()
        avg_loss = loss_epoch / num_batches
        self.train_loss.append(avg_loss)
        
        self.cur_epoch += 1
        pbar.set_description("Epoch: {}, Average Loss: {:.4f}".format(self.cur_epoch, avg_loss))

    def test(self, dataloader_test):
        # self.model.eval()  # Set the model to evaluation mode
        loss_epoch = 0.
        num_batches = 0
        
        # Prepare to collect predictions and ground truth
        predictions = []
        ground_truths = []
        
        with torch.no_grad():  # No need to calculate gradients
            pbar = tqdm(dataloader_test, desc='Test Epoch {}'.format(self.cur_epoch))
            for data in pbar:
                images = data[0].to(self.device, dtype=torch.float32)   # Image that will be fed into network
                gt_label = data[1].to(self.device, dtype=torch.long)  # GT_label

                # Pass the images through the model to get predictions.
                pred_label = self.model(images)

                # Calculate the loss, backpropagation, and optimization
                loss = self.criterion(pred_label, gt_label)
                loss_epoch += float(loss.item())
                num_batches += 1
                
                # We want to put this back on the CPU to calculate the metrics
                predictions.extend(pred_label.argmax(dim=1).cpu().numpy().flatten())
                ground_truths.extend(gt_label.cpu().numpy().flatten())
                pbar.set_description("Test Loss: {:.4f}".format(loss.item()))

        avg_loss = loss_epoch / num_batches
        self.val_loss.append(avg_loss)

        # Print the accuracy here
        accuracy = accuracy_score(ground_truths, predictions)
        
        print(f"Test Epoch: {self.cur_epoch}, Average Loss: {avg_loss:.4f}, Accuracy: {accuracy:.4f}")
        
        return avg_loss


In [6]:
trainer = Engine(model, optimizer, device, ema=None)

# Load the saved model if load_saved_model is set to True
if load_saved_model:
	model.load_state_dict(torch.load('logs/final_model.pth'))
 
# Count the total number of trainable parameters
model_parameters = filter(lambda p: p.requires_grad, model.parameters())
params = sum([np.prod(p.size()) for p in model_parameters])
print ('======Total trainable parameters: ', params)

for epoch in range(trainer.cur_epoch, training_epoch):
	trainer.train(dataloader_train)

	# Test the model every 20 epochs and save it to logs folder
	if (epoch) % 10 == 0:
		trainer.test(dataloader_test)
		torch.save(model.state_dict(), os.path.join('logs', 'final_model.pth'))




Loss: 0.3416: 100%|██████████| 55/55 [00:38<00:00,  1.44it/s]
Test Loss: 0.0575: 100%|██████████| 7/7 [00:02<00:00,  3.08it/s]


Test Epoch: 1, Average Loss: 0.8695, Accuracy: 0.7150


Loss: 0.9566: 100%|██████████| 55/55 [00:38<00:00,  1.42it/s]
Loss: 0.7670: 100%|██████████| 55/55 [00:38<00:00,  1.43it/s]
Loss: 0.9525: 100%|██████████| 55/55 [00:37<00:00,  1.45it/s]
Loss: 0.5088: 100%|██████████| 55/55 [00:39<00:00,  1.39it/s]
Loss: 0.6931: 100%|██████████| 55/55 [00:38<00:00,  1.43it/s]
Loss: 1.5559: 100%|██████████| 55/55 [00:37<00:00,  1.47it/s]
Loss: 0.6527: 100%|██████████| 55/55 [00:37<00:00,  1.45it/s]
Loss: 0.5372: 100%|██████████| 55/55 [00:38<00:00,  1.43it/s]
Loss: 0.8040: 100%|██████████| 55/55 [00:38<00:00,  1.41it/s]
Loss: 1.6531: 100%|██████████| 55/55 [00:37<00:00,  1.45it/s]
Test Loss: 0.0542: 100%|██████████| 7/7 [00:01<00:00,  3.90it/s]


Test Epoch: 11, Average Loss: 0.7444, Accuracy: 0.7254


Loss: 1.4593: 100%|██████████| 55/55 [00:38<00:00,  1.44it/s]
Loss: 0.3901: 100%|██████████| 55/55 [00:38<00:00,  1.44it/s]
Loss: 1.3136: 100%|██████████| 55/55 [00:38<00:00,  1.42it/s]
Loss: 0.8036: 100%|██████████| 55/55 [00:35<00:00,  1.53it/s]
Loss: 0.5137: 100%|██████████| 55/55 [00:30<00:00,  1.80it/s]
Loss: 0.8269: 100%|██████████| 55/55 [00:33<00:00,  1.65it/s]
Loss: 1.0333: 100%|██████████| 55/55 [00:35<00:00,  1.55it/s]
Loss: 0.7129: 100%|██████████| 55/55 [00:35<00:00,  1.55it/s]
Loss: 0.9527: 100%|██████████| 55/55 [00:35<00:00,  1.54it/s]
Loss: 0.8012: 100%|██████████| 55/55 [00:35<00:00,  1.56it/s]
Test Loss: 0.1796: 100%|██████████| 7/7 [00:02<00:00,  3.38it/s]


Test Epoch: 21, Average Loss: 0.8468, Accuracy: 0.6891


Loss: 1.1747: 100%|██████████| 55/55 [00:35<00:00,  1.56it/s]
Loss: 0.9523: 100%|██████████| 55/55 [00:35<00:00,  1.57it/s]
Loss: 0.5179: 100%|██████████| 55/55 [00:35<00:00,  1.55it/s]
Loss: 0.5682: 100%|██████████| 55/55 [00:33<00:00,  1.62it/s]
Loss: 0.9054: 100%|██████████| 55/55 [00:35<00:00,  1.55it/s]
Loss: 1.2954: 100%|██████████| 55/55 [00:35<00:00,  1.55it/s]
Loss: 0.9948: 100%|██████████| 55/55 [00:35<00:00,  1.55it/s]
Loss: 1.0224: 100%|██████████| 55/55 [00:35<00:00,  1.55it/s]
Loss: 1.3930: 100%|██████████| 55/55 [00:35<00:00,  1.55it/s]
Loss: 0.9057: 100%|██████████| 55/55 [00:35<00:00,  1.55it/s]
Test Loss: 0.2432: 100%|██████████| 7/7 [00:01<00:00,  3.91it/s]


Test Epoch: 31, Average Loss: 0.8057, Accuracy: 0.7254


Loss: 1.2872: 100%|██████████| 55/55 [00:35<00:00,  1.54it/s]
Loss: 0.3721: 100%|██████████| 55/55 [00:35<00:00,  1.54it/s]
Loss: 0.9986: 100%|██████████| 55/55 [00:35<00:00,  1.54it/s]
Loss: 1.1062: 100%|██████████| 55/55 [00:35<00:00,  1.54it/s]
Loss: 0.8672: 100%|██████████| 55/55 [00:35<00:00,  1.53it/s]
Loss: 0.6407: 100%|██████████| 55/55 [00:35<00:00,  1.57it/s]
Loss: 0.4357: 100%|██████████| 55/55 [00:35<00:00,  1.54it/s]
Loss: 1.0458: 100%|██████████| 55/55 [00:35<00:00,  1.55it/s]
Loss: 1.1952: 100%|██████████| 55/55 [00:35<00:00,  1.54it/s]
Loss: 1.4250: 100%|██████████| 55/55 [00:36<00:00,  1.53it/s]
Test Loss: 0.7898: 100%|██████████| 7/7 [00:01<00:00,  3.92it/s]


Test Epoch: 41, Average Loss: 0.9868, Accuracy: 0.6269


Loss: 0.8399: 100%|██████████| 55/55 [00:35<00:00,  1.55it/s]
Loss: 1.3606: 100%|██████████| 55/55 [00:33<00:00,  1.62it/s]
Loss: 0.7480: 100%|██████████| 55/55 [00:35<00:00,  1.55it/s]
Loss: 0.8681: 100%|██████████| 55/55 [00:35<00:00,  1.55it/s]
Loss: 0.2427: 100%|██████████| 55/55 [00:36<00:00,  1.53it/s]
Loss: 1.3053: 100%|██████████| 55/55 [00:35<00:00,  1.54it/s]
Loss: 1.3819: 100%|██████████| 55/55 [00:35<00:00,  1.54it/s]
Loss: 0.9123: 100%|██████████| 55/55 [00:35<00:00,  1.55it/s]
Loss: 0.6437: 100%|██████████| 55/55 [00:36<00:00,  1.52it/s]
Loss: 1.4393: 100%|██████████| 55/55 [00:35<00:00,  1.55it/s]
Test Loss: 0.3726: 100%|██████████| 7/7 [00:01<00:00,  3.80it/s]


Test Epoch: 51, Average Loss: 0.8319, Accuracy: 0.7150


Loss: 1.0396: 100%|██████████| 55/55 [00:35<00:00,  1.54it/s]
Loss: 0.8518: 100%|██████████| 55/55 [00:35<00:00,  1.56it/s]
Loss: 1.6713: 100%|██████████| 55/55 [00:36<00:00,  1.52it/s]
Loss: 1.3636: 100%|██████████| 55/55 [00:35<00:00,  1.54it/s]
Loss: 0.6824: 100%|██████████| 55/55 [00:35<00:00,  1.55it/s]
Loss: 0.9915: 100%|██████████| 55/55 [00:35<00:00,  1.54it/s]
Loss: 1.0868: 100%|██████████| 55/55 [00:35<00:00,  1.56it/s]
Loss: 0.9961: 100%|██████████| 55/55 [00:35<00:00,  1.56it/s]
Loss: 0.9784: 100%|██████████| 55/55 [00:35<00:00,  1.55it/s]
Loss: 1.0091: 100%|██████████| 55/55 [00:34<00:00,  1.59it/s]
Test Loss: 0.3034: 100%|██████████| 7/7 [00:01<00:00,  3.59it/s]


Test Epoch: 61, Average Loss: 0.8343, Accuracy: 0.7202


Loss: 0.4285: 100%|██████████| 55/55 [00:35<00:00,  1.56it/s]
Loss: 0.7345: 100%|██████████| 55/55 [00:35<00:00,  1.57it/s]
Loss: 0.7821: 100%|██████████| 55/55 [00:35<00:00,  1.55it/s]
Loss: 0.7528: 100%|██████████| 55/55 [00:35<00:00,  1.54it/s]
Loss: 1.5001: 100%|██████████| 55/55 [00:35<00:00,  1.55it/s]
Loss: 0.6701: 100%|██████████| 55/55 [00:35<00:00,  1.55it/s]
Loss: 0.6333: 100%|██████████| 55/55 [00:35<00:00,  1.54it/s]
Loss: 1.5438: 100%|██████████| 55/55 [00:35<00:00,  1.54it/s]
Loss: 1.2509: 100%|██████████| 55/55 [00:35<00:00,  1.56it/s]
Loss: 0.7251: 100%|██████████| 55/55 [00:35<00:00,  1.54it/s]
Test Loss: 0.1694: 100%|██████████| 7/7 [00:01<00:00,  3.58it/s]


Test Epoch: 71, Average Loss: 0.7983, Accuracy: 0.7254


Loss: 0.2004: 100%|██████████| 55/55 [00:35<00:00,  1.54it/s]
Loss: 0.3786: 100%|██████████| 55/55 [00:35<00:00,  1.53it/s]
Loss: 0.8632: 100%|██████████| 55/55 [00:35<00:00,  1.56it/s]
Loss: 0.1991: 100%|██████████| 55/55 [00:35<00:00,  1.56it/s]
Loss: 0.8392: 100%|██████████| 55/55 [00:35<00:00,  1.54it/s]
Loss: 0.7492: 100%|██████████| 55/55 [00:35<00:00,  1.54it/s]
Loss: 0.7547: 100%|██████████| 55/55 [00:35<00:00,  1.53it/s]
Loss: 1.2887: 100%|██████████| 55/55 [00:34<00:00,  1.60it/s]
Loss: 0.9913: 100%|██████████| 55/55 [00:35<00:00,  1.55it/s]
Loss: 0.3240: 100%|██████████| 55/55 [00:35<00:00,  1.54it/s]
Test Loss: 0.2946: 100%|██████████| 7/7 [00:01<00:00,  3.96it/s]


Test Epoch: 81, Average Loss: 0.8124, Accuracy: 0.7306


Loss: 0.8076: 100%|██████████| 55/55 [00:35<00:00,  1.55it/s]
Loss: 1.2763: 100%|██████████| 55/55 [00:36<00:00,  1.53it/s]
Loss: 0.7955: 100%|██████████| 55/55 [00:35<00:00,  1.54it/s]
Loss: 0.9396: 100%|██████████| 55/55 [00:35<00:00,  1.55it/s]
Loss: 1.1826: 100%|██████████| 55/55 [00:35<00:00,  1.56it/s]
Loss: 1.1771: 100%|██████████| 55/55 [00:35<00:00,  1.55it/s]
Loss: 0.5826: 100%|██████████| 55/55 [00:35<00:00,  1.55it/s]
Loss: 0.9938: 100%|██████████| 55/55 [00:35<00:00,  1.53it/s]
Loss: 1.5540: 100%|██████████| 55/55 [00:35<00:00,  1.56it/s]
Loss: 1.3578: 100%|██████████| 55/55 [00:35<00:00,  1.55it/s]
Test Loss: 0.4781: 100%|██████████| 7/7 [00:01<00:00,  3.50it/s]


Test Epoch: 91, Average Loss: 0.9161, Accuracy: 0.6943


Loss: 0.8328: 100%|██████████| 55/55 [00:35<00:00,  1.56it/s]
Loss: 1.7198: 100%|██████████| 55/55 [00:35<00:00,  1.55it/s]
Loss: 0.9935: 100%|██████████| 55/55 [00:36<00:00,  1.52it/s]
Loss: 1.1811: 100%|██████████| 55/55 [00:35<00:00,  1.57it/s]
Loss: 0.6408: 100%|██████████| 55/55 [00:34<00:00,  1.58it/s]
Loss: 0.7480: 100%|██████████| 55/55 [00:34<00:00,  1.57it/s]
Loss: 0.2965: 100%|██████████| 55/55 [00:35<00:00,  1.54it/s]
Loss: 0.7953: 100%|██████████| 55/55 [00:35<00:00,  1.55it/s]
Loss: 0.6126: 100%|██████████| 55/55 [00:35<00:00,  1.54it/s]
Loss: 1.2220: 100%|██████████| 55/55 [00:35<00:00,  1.56it/s]
Test Loss: 0.3039: 100%|██████████| 7/7 [00:01<00:00,  3.84it/s]


Test Epoch: 101, Average Loss: 0.8355, Accuracy: 0.7202


Loss: 0.7207: 100%|██████████| 55/55 [00:36<00:00,  1.52it/s]
Loss: 1.0744: 100%|██████████| 55/55 [00:35<00:00,  1.54it/s]
Loss: 0.3260: 100%|██████████| 55/55 [00:35<00:00,  1.53it/s]
Loss: 0.7780: 100%|██████████| 55/55 [00:35<00:00,  1.57it/s]
Loss: 1.2782: 100%|██████████| 55/55 [00:35<00:00,  1.55it/s]
Loss: 0.6831: 100%|██████████| 55/55 [00:35<00:00,  1.55it/s]
Loss: 1.3906: 100%|██████████| 55/55 [00:35<00:00,  1.56it/s]
Loss: 1.4104: 100%|██████████| 55/55 [00:35<00:00,  1.55it/s]
Loss: 0.7414: 100%|██████████| 55/55 [00:35<00:00,  1.56it/s]
Loss: 1.0492: 100%|██████████| 55/55 [00:34<00:00,  1.57it/s]
Test Loss: 0.2267: 100%|██████████| 7/7 [00:01<00:00,  3.69it/s]


Test Epoch: 111, Average Loss: 0.8524, Accuracy: 0.7150


Loss: 0.8924: 100%|██████████| 55/55 [00:35<00:00,  1.56it/s]
Loss: 0.8318: 100%|██████████| 55/55 [00:35<00:00,  1.54it/s]
Loss: 1.2124: 100%|██████████| 55/55 [00:34<00:00,  1.62it/s]
Loss: 0.8725: 100%|██████████| 55/55 [00:35<00:00,  1.55it/s]
Loss: 0.4923: 100%|██████████| 55/55 [00:35<00:00,  1.55it/s]
Loss: 1.1205: 100%|██████████| 55/55 [00:35<00:00,  1.57it/s]
Loss: 1.0557: 100%|██████████| 55/55 [00:36<00:00,  1.52it/s]
Loss: 0.8521: 100%|██████████| 55/55 [00:35<00:00,  1.54it/s]
Loss: 0.6856: 100%|██████████| 55/55 [00:35<00:00,  1.56it/s]
Loss: 0.6970: 100%|██████████| 55/55 [00:35<00:00,  1.56it/s]
Test Loss: 0.3981: 100%|██████████| 7/7 [00:01<00:00,  3.70it/s]


Test Epoch: 121, Average Loss: 0.8707, Accuracy: 0.7047


Loss: 1.5052: 100%|██████████| 55/55 [00:35<00:00,  1.55it/s]
Loss: 0.5503: 100%|██████████| 55/55 [00:35<00:00,  1.56it/s]
Loss: 0.7069: 100%|██████████| 55/55 [00:35<00:00,  1.55it/s]
Loss: 0.8083: 100%|██████████| 55/55 [00:35<00:00,  1.53it/s]
Loss: 1.2534: 100%|██████████| 55/55 [00:35<00:00,  1.54it/s]
Loss: 1.0558: 100%|██████████| 55/55 [00:35<00:00,  1.57it/s]
Loss: 0.6509: 100%|██████████| 55/55 [00:35<00:00,  1.53it/s]
Loss: 1.0425: 100%|██████████| 55/55 [00:35<00:00,  1.56it/s]
Loss: 0.5612: 100%|██████████| 55/55 [00:35<00:00,  1.53it/s]
Loss: 1.1182: 100%|██████████| 55/55 [00:35<00:00,  1.54it/s]
Test Loss: 0.2376: 100%|██████████| 7/7 [00:01<00:00,  3.55it/s]


Test Epoch: 131, Average Loss: 0.8488, Accuracy: 0.7202


Loss: 0.8417: 100%|██████████| 55/55 [00:34<00:00,  1.61it/s]
Loss: 0.9967: 100%|██████████| 55/55 [00:35<00:00,  1.54it/s]
Loss: 0.9317: 100%|██████████| 55/55 [00:35<00:00,  1.55it/s]
Loss: 1.2118: 100%|██████████| 55/55 [00:35<00:00,  1.56it/s]
Loss: 1.4052: 100%|██████████| 55/55 [00:35<00:00,  1.55it/s]
Loss: 1.1020: 100%|██████████| 55/55 [00:36<00:00,  1.52it/s]
Loss: 0.3338: 100%|██████████| 55/55 [00:35<00:00,  1.56it/s]
Loss: 1.2415: 100%|██████████| 55/55 [00:35<00:00,  1.54it/s]
Loss: 1.1497: 100%|██████████| 55/55 [00:35<00:00,  1.54it/s]
Loss: 1.4609: 100%|██████████| 55/55 [00:35<00:00,  1.55it/s]
Test Loss: 0.3389: 100%|██████████| 7/7 [00:01<00:00,  3.73it/s]


Test Epoch: 141, Average Loss: 0.8958, Accuracy: 0.6943


Loss: 0.8854: 100%|██████████| 55/55 [00:35<00:00,  1.56it/s]
Loss: 0.8666: 100%|██████████| 55/55 [00:35<00:00,  1.55it/s]
Loss: 2.0633: 100%|██████████| 55/55 [00:35<00:00,  1.53it/s]
Loss: 1.6772: 100%|██████████| 55/55 [00:35<00:00,  1.54it/s]
Loss: 1.3364: 100%|██████████| 55/55 [00:35<00:00,  1.55it/s]
Loss: 1.1494: 100%|██████████| 55/55 [00:35<00:00,  1.53it/s]
Loss: 0.8884: 100%|██████████| 55/55 [00:35<00:00,  1.54it/s]
Loss: 0.3485: 100%|██████████| 55/55 [00:35<00:00,  1.56it/s]
Loss: 0.6096: 100%|██████████| 55/55 [00:34<00:00,  1.58it/s]
