In [1]:
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
from torchvision import models
from PIL import Image
import timm  # Import timm library
from tqdm import tqdm
from sklearn.utils import shuffle
from sklearn.model_selection import StratifiedShuffleSplit
from sklearn.metrics import accuracy_score
import pandas as pd
from sklearn.model_selection import train_test_split

In [2]:
dd = pd.read_csv("./data/data_with_only_image_dir.csv")
dd.label.value_counts()

-1    7452
 1    2373
 0    1719
Name: label, dtype: int64

In [3]:
dd_filtered = dd[dd['label'] != 0]

# Sample equal number of -1 and 1 labels
dd_negative = dd_filtered[dd_filtered['label'] == -1].sample(n=2000, random_state=42)
dd_positive = dd_filtered[dd_filtered['label'] == 1].sample(n=2000, random_state=42)

# Concatenate the samples to create the final dataset
balanced_dataset = pd.concat([dd_negative, dd_positive])

# Shuffle the dataset
balanced_dataset = balanced_dataset.sample(frac=1, random_state=42).reset_index(drop=True)

balanced_dataset['label'] = balanced_dataset['label'].replace(-1, 0)

balanced_dataset.to_csv('./data/balanced_dataset.csv', index=False)

In [4]:
class CustomDataset(Dataset):
    def __init__(self, csv_file, transform=None, data_length=None):
        # Assuming csv_file is a CSV file with columns "image_sample" and "label"
        self.data = pd.read_csv(csv_file)
        self.data = shuffle(self.data, random_state=42)
        if data_length != None:
            data_shape = self.data.shape
            included_data, excluded_data = train_test_split(self.data, test_size=1-(data_length/data_shape[0]), random_state=42, stratify=self.data.label)
            print(included_data.label.value_counts())
            self.data = included_data
#             self.data = self.data.iloc[:data_length]
            
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        img_path = self.data.iloc[idx, 0]  # assuming the image path is in the first column
        image = Image.open(img_path).convert('RGB')

        label = int(self.data.iloc[idx, 1])  # assuming the label is in the second column

        if self.transform:
            image = self.transform(image)

        return image, label

In [11]:
# Transformations
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Split the dataset
dataset = CustomDataset(csv_file='./data/balanced_dataset.csv', transform=transform)
# train_size = int(0.8 * len(dataset))
# val_size = len(dataset) - train_size
# train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])


# Assuming you have your labels in the "label" column
labels = dataset.data['label'].values

# Use StratifiedShuffleSplit to split the dataset while maintaining class ratios
sss = StratifiedShuffleSplit(n_splits=1, test_size=0.2, random_state=42)

for train_index, val_index in sss.split(range(len(dataset)), labels):
    train_dataset = torch.utils.data.Subset(dataset, train_index)
    val_dataset = torch.utils.data.Subset(dataset, val_index)
    
# Size of the training dataset
print("Size of the training dataset:", len(train_dataset))
print("Size of the testing dataset:", len(val_dataset))

# Optionally, you can convert train_dataset and val_dataset to DataLoader if needed
batch_size = 32


train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

Size of the training dataset: 3200
Size of the testing dataset: 800


In [12]:
train_dataset.dataset.data.label.value_counts()

1    2000
0    2000
Name: label, dtype: int64

In [13]:
# Define the Vision Transformer Model
model = timm.create_model('vit_base_patch16_224', pretrained=True)

# Modify the Model Head
num_classes = 2
model.head = nn.Linear(model.head.in_features, num_classes)

### Stage 1 training (3 epochs)

In [14]:
# Loss Function and Optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

best_loss = float('inf')  # Initialize with a large value

# Training
num_epochs = 10
# train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)

for epoch in tqdm(range(3)):
    model.train()
    running_loss = 0.0
    
    for inputs, labels in tqdm(train_loader, desc='Epoch {}'.format(epoch), leave=False):
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        
        running_loss += loss.item()
        
    epoch_loss = running_loss / len(train_loader)
    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss}")
    
    # Save the model if it has the best loss so far
    if epoch_loss < best_loss:
        best_loss = epoch_loss
        torch.save(model.state_dict(), 'best_model_weights.pth')

  0%|                                                     | 0/3 [00:00<?, ?it/s]
Epoch 0:   0%|                                          | 0/100 [00:00<?, ?it/s][A
Epoch 0:   1%|▎                                 | 1/100 [00:19<32:07, 19.47s/it][A
Epoch 0:   2%|▋                                 | 2/100 [00:40<33:22, 20.44s/it][A
Epoch 0:   3%|█                                 | 3/100 [00:58<31:05, 19.23s/it][A
Epoch 0:   4%|█▎                                | 4/100 [01:16<29:46, 18.61s/it][A
Epoch 0:   5%|█▋                                | 5/100 [01:33<28:47, 18.18s/it][A
Epoch 0:   6%|██                                | 6/100 [01:51<28:17, 18.05s/it][A
Epoch 0:   7%|██▍                               | 7/100 [02:09<27:57, 18.04s/it][A
Epoch 0:   8%|██▋                               | 8/100 [02:26<27:21, 17.84s/it][A
Epoch 0:   9%|███                               | 9/100 [02:44<26:48, 17.68s/it][A
Epoch 0:  10%|███▎                             | 10/100 [03:01<26:25, 17.62s/it

Epoch 1/3, Loss: 1.3809153044223785


 33%|█████████████▋                           | 1/3 [30:09<1:00:19, 1809.65s/it]
Epoch 1:   0%|                                          | 0/100 [00:00<?, ?it/s][A
Epoch 1:   1%|▎                                 | 1/100 [00:18<30:12, 18.31s/it][A
Epoch 1:   2%|▋                                 | 2/100 [00:36<30:05, 18.42s/it][A
Epoch 1:   3%|█                                 | 3/100 [00:54<29:35, 18.30s/it][A
Epoch 1:   4%|█▎                                | 4/100 [01:13<29:32, 18.47s/it][A
Epoch 1:   5%|█▋                                | 5/100 [01:31<29:02, 18.34s/it][A
Epoch 1:   6%|██                                | 6/100 [01:50<28:40, 18.31s/it][A
Epoch 1:   7%|██▍                               | 7/100 [02:08<28:22, 18.31s/it][A
Epoch 1:   8%|██▋                               | 8/100 [02:26<28:08, 18.35s/it][A
Epoch 1:   9%|███                               | 9/100 [02:45<27:55, 18.42s/it][A
Epoch 1:  10%|███▎                             | 10/100 [03:03<27:36, 18.41s/it

Epoch 2/3, Loss: 0.580397612452507


 67%|███████████████████████████▎             | 2/3 [1:01:42<30:58, 1858.45s/it]
Epoch 2:   0%|                                          | 0/100 [00:00<?, ?it/s][A
Epoch 2:   1%|▎                                 | 1/100 [00:18<30:45, 18.64s/it][A
Epoch 2:   2%|▋                                 | 2/100 [00:37<31:00, 18.98s/it][A
Epoch 2:   3%|█                                 | 3/100 [00:57<30:53, 19.11s/it][A
Epoch 2:   4%|█▎                                | 4/100 [01:16<30:32, 19.09s/it][A
Epoch 2:   5%|█▋                                | 5/100 [01:35<30:33, 19.30s/it][A
Epoch 2:   6%|██                                | 6/100 [01:55<30:21, 19.38s/it][A
Epoch 2:   7%|██▍                               | 7/100 [02:14<30:05, 19.41s/it][A
Epoch 2:   8%|██▋                               | 8/100 [02:34<29:51, 19.47s/it][A
Epoch 2:   9%|███                               | 9/100 [02:54<29:38, 19.54s/it][A
Epoch 2:  10%|███▎                             | 10/100 [03:13<29:18, 19.53s/it

Epoch 3/3, Loss: 0.5505284512042999


100%|█████████████████████████████████████████| 3/3 [1:35:33<00:00, 1911.31s/it]


In [21]:
torch.save(model.state_dict(), 'best_model_weights1.pth')

### Stage 2 training (3 epochs)

In [22]:
model.load_state_dict(torch.load('best_model_weights1.pth'))

for epoch in tqdm(range(3, 6)):
    model.train()
    running_loss = 0.0
    
    for inputs, labels in tqdm(train_loader, desc='Epoch {}'.format(epoch), leave=False):
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        
        running_loss += loss.item()
        
    epoch_loss = running_loss / len(train_loader)
    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss}")
    
    # Save the model if it has the best loss so far
    if epoch_loss < best_loss:
        best_loss = epoch_loss
        torch.save(model.state_dict(), 'best_model_weights.pth')

  0%|                                                     | 0/3 [00:00<?, ?it/s]
Epoch 3:   0%|                                          | 0/100 [00:00<?, ?it/s][A
Epoch 3:   1%|▎                                 | 1/100 [00:21<35:20, 21.42s/it][A
Epoch 3:   2%|▋                                 | 2/100 [00:39<31:26, 19.25s/it][A
Epoch 3:   3%|█                                 | 3/100 [00:58<31:01, 19.19s/it][A
Epoch 3:   4%|█▎                                | 4/100 [01:18<31:25, 19.64s/it][A
Epoch 3:   5%|█▋                                | 5/100 [01:37<30:53, 19.51s/it][A
Epoch 3:   6%|██                                | 6/100 [01:56<30:20, 19.36s/it][A
Epoch 3:   7%|██▍                               | 7/100 [02:15<29:26, 19.00s/it][A
Epoch 3:   8%|██▋                               | 8/100 [02:34<29:12, 19.05s/it][A
Epoch 3:   9%|███                               | 9/100 [02:53<29:04, 19.17s/it][A
Epoch 3:  10%|███▎                             | 10/100 [03:12<28:36, 19.08s/it

Epoch 3:  96%|███████████████████████████████▋ | 96/100 [32:31<01:24, 21.11s/it][A
Epoch 3:  97%|████████████████████████████████ | 97/100 [32:52<01:03, 21.22s/it][A
Epoch 3:  98%|████████████████████████████████▎| 98/100 [33:14<00:42, 21.28s/it][A
Epoch 3:  99%|████████████████████████████████▋| 99/100 [33:35<00:21, 21.32s/it][A
Epoch 3: 100%|████████████████████████████████| 100/100 [33:57<00:00, 21.63s/it][A
                                                                                [A

Epoch 4/10, Loss: 0.5336129349470139


 33%|█████████████▋                           | 1/3 [33:58<1:07:57, 2038.80s/it]
Epoch 4:   0%|                                          | 0/100 [00:00<?, ?it/s][A
Epoch 4:   1%|▎                                 | 1/100 [00:22<36:31, 22.13s/it][A
Epoch 4:   2%|▋                                 | 2/100 [00:43<35:08, 21.52s/it][A
Epoch 4:   3%|█                                 | 3/100 [01:04<34:19, 21.23s/it][A
Epoch 4:   4%|█▎                                | 4/100 [01:24<33:31, 20.95s/it][A
Epoch 4:   5%|█▋                                | 5/100 [01:46<33:38, 21.25s/it][A
Epoch 4:   6%|██                                | 6/100 [02:07<33:14, 21.22s/it][A
Epoch 4:   7%|██▍                               | 7/100 [02:28<32:50, 21.19s/it][A
Epoch 4:   8%|██▋                               | 8/100 [02:49<32:28, 21.18s/it][A
Epoch 4:   9%|███                               | 9/100 [03:11<32:08, 21.19s/it][A
Epoch 4:  10%|███▎                             | 10/100 [03:32<31:46, 21.19s/it

Epoch 4:  96%|███████████████████████████████▋ | 96/100 [34:07<01:26, 21.68s/it][A
Epoch 4:  97%|████████████████████████████████ | 97/100 [34:29<01:04, 21.59s/it][A
Epoch 4:  98%|████████████████████████████████▎| 98/100 [34:50<00:43, 21.55s/it][A
Epoch 4:  99%|████████████████████████████████▋| 99/100 [35:12<00:21, 21.56s/it][A
Epoch 4: 100%|████████████████████████████████| 100/100 [35:33<00:00, 21.50s/it][A
 67%|███████████████████████████▎             | 2/3 [1:09:32<34:54, 2094.57s/it][A

Epoch 5/10, Loss: 0.5397738146781922



Epoch 5:   0%|                                          | 0/100 [00:00<?, ?it/s][A
Epoch 5:   1%|▎                                 | 1/100 [00:21<36:11, 21.93s/it][A
Epoch 5:   2%|▋                                 | 2/100 [00:43<35:06, 21.50s/it][A
Epoch 5:   3%|█                                 | 3/100 [01:04<34:47, 21.52s/it][A
Epoch 5:   4%|█▎                                | 4/100 [01:26<34:48, 21.75s/it][A
Epoch 5:   5%|█▋                                | 5/100 [01:47<33:58, 21.46s/it][A
Epoch 5:   6%|██                                | 6/100 [02:09<33:40, 21.49s/it][A
Epoch 5:   7%|██▍                               | 7/100 [02:29<32:54, 21.24s/it][A
Epoch 5:   8%|██▋                               | 8/100 [02:51<32:31, 21.21s/it][A
Epoch 5:   9%|███                               | 9/100 [03:12<32:16, 21.28s/it][A
Epoch 5:  10%|███▎                             | 10/100 [03:33<31:46, 21.18s/it][A
Epoch 5:  11%|███▋                             | 11/100 [03:55<31:36, 21.31

Epoch 5:  97%|████████████████████████████████ | 97/100 [33:36<01:02, 20.84s/it][A
Epoch 5:  98%|████████████████████████████████▎| 98/100 [33:57<00:41, 20.89s/it][A
Epoch 5:  99%|████████████████████████████████▋| 99/100 [34:17<00:20, 20.79s/it][A
Epoch 5: 100%|████████████████████████████████| 100/100 [34:38<00:00, 20.77s/it][A
100%|█████████████████████████████████████████| 3/3 [1:44:10<00:00, 2083.58s/it][A

Epoch 6/10, Loss: 0.6201720744371414





In [23]:
torch.save(model.state_dict(), 'best_model_weights2.pth')

### Stage 3 training (4 epochs)

In [32]:
model.load_state_dict(torch.load('best_model_weights2.pth'))

for epoch in tqdm(range(6, num_epochs)):
    model.train()
    running_loss = 0.0
    
    for inputs, labels in tqdm(train_loader, desc='Epoch {}'.format(epoch), leave=False):
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        
        running_loss += loss.item()
        
    epoch_loss = running_loss / len(train_loader)
    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss}")
    
    model.eval()
    all_preds = []
    all_labels = []

    with torch.no_grad():
        for images, labels in tqdm(val_loader):

            outputs = model(images)
            _, preds = torch.max(outputs, 1)

            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
            
    
    # Calculate accuracy
    accuracy = accuracy_score(all_labels, all_preds)
    print(f"Epoch {epoch}, Test Accuracy: {accuracy * 100:.4f}%")
    
    # Save the model if it has the best loss so far
    if epoch_loss < best_loss:
        best_loss = epoch_loss
        torch.save(model.state_dict(), 'best_model_weights.pth')

  0%|                                                     | 0/4 [00:00<?, ?it/s]
Epoch 6:   0%|                                          | 0/100 [00:00<?, ?it/s][A
Epoch 6:   1%|▎                                 | 1/100 [00:18<30:01, 18.20s/it][A
Epoch 6:   2%|▋                                 | 2/100 [00:37<30:41, 18.79s/it][A
Epoch 6:   3%|█                                 | 3/100 [00:57<31:35, 19.54s/it][A
Epoch 6:   4%|█▎                                | 4/100 [01:18<31:59, 20.00s/it][A
Epoch 6:   5%|█▋                                | 5/100 [01:39<32:02, 20.24s/it][A
Epoch 6:   6%|██                                | 6/100 [02:00<32:01, 20.44s/it][A
Epoch 6:   7%|██▍                               | 7/100 [02:20<31:53, 20.57s/it][A
Epoch 6:   8%|██▋                               | 8/100 [02:42<31:54, 20.81s/it][A
Epoch 6:   9%|███                               | 9/100 [03:03<31:52, 21.02s/it][A
Epoch 6:  10%|███▎                             | 10/100 [03:24<31:33, 21.04s/it

Epoch 6:  96%|███████████████████████████████▋ | 96/100 [35:27<01:33, 23.34s/it][A
Epoch 6:  97%|████████████████████████████████ | 97/100 [35:51<01:10, 23.59s/it][A
Epoch 6:  98%|████████████████████████████████▎| 98/100 [36:16<00:47, 23.98s/it][A
Epoch 6:  99%|████████████████████████████████▋| 99/100 [36:41<00:24, 24.11s/it][A
Epoch 6: 100%|████████████████████████████████| 100/100 [37:05<00:00, 24.10s/it][A
                                                                                [A

Epoch 7/10, Loss: 0.5898178651928901



  0%|                                                    | 0/25 [00:00<?, ?it/s][A
  4%|█▊                                          | 1/25 [00:06<02:34,  6.44s/it][A
  8%|███▌                                        | 2/25 [00:13<02:33,  6.67s/it][A
 12%|█████▎                                      | 3/25 [00:19<02:27,  6.69s/it][A
 16%|███████                                     | 4/25 [00:26<02:19,  6.67s/it][A
 20%|████████▊                                   | 5/25 [00:33<02:13,  6.66s/it][A
 24%|██████████▌                                 | 6/25 [00:39<02:05,  6.58s/it][A
 28%|████████████▎                               | 7/25 [00:46<01:58,  6.60s/it][A
 32%|██████████████                              | 8/25 [00:52<01:52,  6.61s/it][A
 36%|███████████████▊                            | 9/25 [00:59<01:45,  6.61s/it][A
 40%|█████████████████▏                         | 10/25 [01:06<01:38,  6.58s/it][A
 44%|██████████████████▉                        | 11/25 [01:12<01:31,  6.56

Epoch 6, Test Accuracy: 73.6250%



Epoch 7:   0%|                                          | 0/100 [00:00<?, ?it/s][A
Epoch 7:   1%|▎                                 | 1/100 [00:25<41:52, 25.38s/it][A
Epoch 7:   2%|▋                                 | 2/100 [00:49<40:01, 24.51s/it][A
Epoch 7:   3%|█                                 | 3/100 [01:12<38:54, 24.07s/it][A
Epoch 7:   4%|█▎                                | 4/100 [01:36<37:57, 23.72s/it][A
Epoch 7:   5%|█▋                                | 5/100 [01:58<37:01, 23.38s/it][A
Epoch 7:   6%|██                                | 6/100 [02:22<36:52, 23.54s/it][A
Epoch 7:   7%|██▍                               | 7/100 [02:45<36:20, 23.45s/it][A
Epoch 7:   8%|██▋                               | 8/100 [03:08<35:43, 23.30s/it][A
Epoch 7:   9%|███                               | 9/100 [03:31<35:07, 23.16s/it][A
Epoch 7:  10%|███▎                             | 10/100 [03:55<35:00, 23.34s/it][A
Epoch 7:  11%|███▋                             | 11/100 [04:19<34:53, 23.53

Epoch 7:  97%|████████████████████████████████ | 97/100 [39:01<01:15, 25.03s/it][A
Epoch 7:  98%|████████████████████████████████▎| 98/100 [39:26<00:50, 25.04s/it][A
Epoch 7:  99%|████████████████████████████████▋| 99/100 [39:51<00:24, 24.89s/it][A
Epoch 7: 100%|████████████████████████████████| 100/100 [40:15<00:00, 24.65s/it][A
                                                                                [A

Epoch 8/10, Loss: 0.540996582210064



  0%|                                                    | 0/25 [00:00<?, ?it/s][A
  4%|█▊                                          | 1/25 [00:06<02:27,  6.16s/it][A
  8%|███▌                                        | 2/25 [00:12<02:27,  6.41s/it][A
 12%|█████▎                                      | 3/25 [00:19<02:20,  6.40s/it][A
 16%|███████                                     | 4/25 [00:25<02:16,  6.50s/it][A
 20%|████████▊                                   | 5/25 [00:32<02:10,  6.51s/it][A
 24%|██████████▌                                 | 6/25 [00:39<02:05,  6.59s/it][A
 28%|████████████▎                               | 7/25 [00:45<01:59,  6.64s/it][A
 32%|██████████████                              | 8/25 [00:52<01:52,  6.60s/it][A
 36%|███████████████▊                            | 9/25 [00:58<01:45,  6.62s/it][A
 40%|█████████████████▏                         | 10/25 [01:05<01:38,  6.58s/it][A
 44%|██████████████████▉                        | 11/25 [01:12<01:32,  6.63

Epoch 7, Test Accuracy: 76.5000%



Epoch 8:   0%|                                          | 0/100 [00:00<?, ?it/s][A
Epoch 8:   1%|▎                                 | 1/100 [00:24<40:42, 24.67s/it][A
Epoch 8:   2%|▋                                 | 2/100 [00:49<40:03, 24.52s/it][A
Epoch 8:   3%|█                                 | 3/100 [01:13<39:29, 24.43s/it][A
Epoch 8:   4%|█▎                                | 4/100 [01:37<39:01, 24.40s/it][A
Epoch 8:   5%|█▋                                | 5/100 [02:02<38:53, 24.57s/it][A
Epoch 8:   6%|██                                | 6/100 [02:27<38:41, 24.70s/it][A
Epoch 8:   7%|██▍                               | 7/100 [02:51<37:57, 24.49s/it][A
Epoch 8:   8%|██▋                               | 8/100 [03:16<37:36, 24.53s/it][A
Epoch 8:   9%|███                               | 9/100 [03:40<36:51, 24.31s/it][A
Epoch 8:  10%|███▎                             | 10/100 [04:04<36:42, 24.47s/it][A
Epoch 8:  11%|███▋                             | 11/100 [04:29<36:34, 24.66

Epoch 8:  97%|████████████████████████████████ | 97/100 [42:19<01:20, 26.85s/it][A
Epoch 8:  98%|████████████████████████████████▎| 98/100 [42:46<00:53, 26.90s/it][A
Epoch 8:  99%|████████████████████████████████▋| 99/100 [43:13<00:26, 26.94s/it][A
Epoch 8: 100%|████████████████████████████████| 100/100 [43:41<00:00, 27.08s/it][A
                                                                                [A

Epoch 9/10, Loss: 0.5063545736670494



  0%|                                                    | 0/25 [00:00<?, ?it/s][A
  4%|█▊                                          | 1/25 [00:06<02:34,  6.44s/it][A
  8%|███▌                                        | 2/25 [00:13<02:42,  7.08s/it][A
 12%|█████▎                                      | 3/25 [00:21<02:41,  7.34s/it][A
 16%|███████                                     | 4/25 [00:28<02:29,  7.11s/it][A
 20%|████████▊                                   | 5/25 [00:35<02:19,  6.98s/it][A
 24%|██████████▌                                 | 6/25 [00:41<02:09,  6.84s/it][A
 28%|████████████▎                               | 7/25 [00:48<02:02,  6.79s/it][A
 32%|██████████████                              | 8/25 [00:54<01:54,  6.72s/it][A
 36%|███████████████▊                            | 9/25 [01:01<01:47,  6.73s/it][A
 40%|█████████████████▏                         | 10/25 [01:08<01:41,  6.75s/it][A
 44%|██████████████████▉                        | 11/25 [01:15<01:34,  6.73

Epoch 8, Test Accuracy: 71.1250%


 75%|██████████████████████████████▊          | 3/4 [2:09:19<43:53, 2633.32s/it]
Epoch 9:   0%|                                          | 0/100 [00:00<?, ?it/s][A
Epoch 9:   1%|▎                                 | 1/100 [00:28<47:28, 28.77s/it][A
Epoch 9:   2%|▋                                 | 2/100 [00:56<45:53, 28.09s/it][A
Epoch 9:   3%|█                                 | 3/100 [01:23<44:43, 27.67s/it][A
Epoch 9:   4%|█▎                                | 4/100 [01:50<43:46, 27.36s/it][A
Epoch 9:   5%|█▋                                | 5/100 [02:18<43:41, 27.59s/it][A
Epoch 9:   6%|██                                | 6/100 [02:46<43:18, 27.65s/it][A
Epoch 9:   7%|██▍                               | 7/100 [03:12<42:24, 27.36s/it][A
Epoch 9:   8%|██▋                               | 8/100 [03:40<42:09, 27.50s/it][A
Epoch 9:   9%|███                               | 9/100 [04:08<41:37, 27.44s/it][A
Epoch 9:  10%|███▎                             | 10/100 [04:35<41:03, 27.37s/it

Epoch 9:  96%|███████████████████████████████▋ | 96/100 [43:51<01:48, 27.16s/it][A
Epoch 9:  97%|████████████████████████████████ | 97/100 [44:18<01:21, 27.02s/it][A
Epoch 9:  98%|████████████████████████████████▎| 98/100 [44:45<00:53, 26.94s/it][A
Epoch 9:  99%|████████████████████████████████▋| 99/100 [45:11<00:26, 26.89s/it][A
Epoch 9: 100%|████████████████████████████████| 100/100 [45:39<00:00, 26.99s/it][A
                                                                                [A

Epoch 10/10, Loss: 0.5107413464784623



  0%|                                                    | 0/25 [00:00<?, ?it/s][A
  4%|█▊                                          | 1/25 [00:05<02:22,  5.95s/it][A
  8%|███▌                                        | 2/25 [00:12<02:21,  6.16s/it][A
 12%|█████▎                                      | 3/25 [00:18<02:18,  6.29s/it][A
 16%|███████                                     | 4/25 [00:24<02:11,  6.28s/it][A
 20%|████████▊                                   | 5/25 [00:31<02:05,  6.28s/it][A
 24%|██████████▌                                 | 6/25 [00:37<01:58,  6.26s/it][A
 28%|████████████▎                               | 7/25 [00:43<01:54,  6.35s/it][A
 32%|██████████████                              | 8/25 [00:50<01:46,  6.29s/it][A
 36%|███████████████▊                            | 9/25 [00:56<01:40,  6.29s/it][A
 40%|█████████████████▏                         | 10/25 [01:02<01:34,  6.31s/it][A
 44%|██████████████████▉                        | 11/25 [01:09<01:28,  6.29

Epoch 9, Test Accuracy: 67.3750%





In [33]:
torch.save(model.state_dict(), 'best_model_weights3.pth')

In [34]:
print(best_loss)

0.5063545736670494


In [35]:
# Evaluation

# Load the saved weights
model.load_state_dict(torch.load('best_model_weights3.pth'))

model.eval()
all_preds = []
all_labels = []

with torch.no_grad():
    for images, labels in tqdm(val_loader):

        outputs = model(images)
        _, preds = torch.max(outputs, 1)

        all_preds.extend(preds.cpu().numpy())
        all_labels.extend(labels.cpu().numpy())

100%|███████████████████████████████████████████| 25/25 [02:49<00:00,  6.79s/it]


In [36]:

# Calculate accuracy
accuracy = accuracy_score(all_labels, all_preds)
print(f"Test Accuracy: {accuracy * 100:.4f}%")

Test Accuracy: 67.3750%
