In [21]:
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
from torchvision import models
from PIL import Image
import timm  # Import timm library
from tqdm import tqdm
from sklearn.utils import shuffle
from sklearn.model_selection import StratifiedShuffleSplit
from sklearn.metrics import accuracy_score
import pandas as pd
from sklearn.model_selection import train_test_split

In [30]:
class CustomDataset(Dataset):
    def __init__(self, csv_file, transform=None, data_length=None):
        # Assuming csv_file is a CSV file with columns "image_sample" and "label"
        self.data = pd.read_csv(csv_file)
        self.data = shuffle(self.data, random_state=42)
        if data_length != None:
            data_shape = self.data.shape
            included_data, excluded_data = train_test_split(self.data, test_size=1-(data_length/data_shape[0]), random_state=42, stratify=self.data.label)
            print(included_data.label.value_counts())
            self.data = included_data
#             self.data = self.data.iloc[:data_length]
            
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        img_path = self.data.iloc[idx, 0]  # assuming the image path is in the first column
        image = Image.open(img_path).convert('RGB')

        label = int(self.data.iloc[idx, 1])  # assuming the label is in the second column
        label = label + 1

        if self.transform:
            image = self.transform(image)

        return image, label

In [31]:
# Transformations
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Split the dataset
dataset = CustomDataset(csv_file='./data/data_with_only_image_dir.csv', transform=transform, data_length = 2000)
# train_size = int(0.8 * len(dataset))
# val_size = len(dataset) - train_size
# train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])


# Assuming you have your labels in the "label" column
labels = dataset.data['label'].values

# Use StratifiedShuffleSplit to split the dataset while maintaining class ratios
sss = StratifiedShuffleSplit(n_splits=1, test_size=0.2, random_state=42)

for train_index, val_index in sss.split(range(len(dataset)), labels):
    train_dataset = torch.utils.data.Subset(dataset, train_index)
    val_dataset = torch.utils.data.Subset(dataset, val_index)

# Optionally, you can convert train_dataset and val_dataset to DataLoader if needed
batch_size = 16


train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

-1    1291
 1     411
 0     298
Name: label, dtype: int64


In [45]:
train_dataset.dataset.data.label.value_counts()

-1    1291
 1     411
 0     298
Name: label, dtype: int64

In [32]:
# Define the Vision Transformer Model
model = timm.create_model('vit_base_patch16_224', pretrained=True)

# Modify the Model Head
num_classes = 3
model.head = nn.Linear(model.head.in_features, num_classes)

In [33]:
# Loss Function and Optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

best_loss = float('inf')  # Initialize with a large value

# Training
num_epochs = 10
# train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)

for epoch in tqdm(range(num_epochs)):
    model.train()
    running_loss = 0.0
    
    for inputs, labels in tqdm(train_loader, desc='Epoch {}'.format(epoch), leave=False):
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        
        running_loss += loss.item()
        
    epoch_loss = running_loss / len(train_loader)
    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss}")
    
    # Save the model if it has the best loss so far
    if epoch_loss < best_loss:
        best_loss = epoch_loss
        torch.save(model.state_dict(), 'best_model_weights.pth')

  0%|                                                    | 0/10 [00:00<?, ?it/s]
Epoch 0:   0%|                                          | 0/100 [00:00<?, ?it/s][A
Epoch 0:   1%|▎                                 | 1/100 [00:08<13:56,  8.45s/it][A
Epoch 0:   2%|▋                                 | 2/100 [00:17<14:23,  8.81s/it][A
Epoch 0:   3%|█                                 | 3/100 [00:26<14:16,  8.83s/it][A
Epoch 0:   4%|█▎                                | 4/100 [00:35<14:00,  8.76s/it][A
Epoch 0:   5%|█▋                                | 5/100 [00:43<13:48,  8.72s/it][A
Epoch 0:   6%|██                                | 6/100 [00:52<13:44,  8.77s/it][A
Epoch 0:   7%|██▍                               | 7/100 [01:01<13:37,  8.79s/it][A
Epoch 0:   8%|██▋                               | 8/100 [01:09<13:23,  8.74s/it][A
Epoch 0:   9%|███                               | 9/100 [01:19<13:27,  8.87s/it][A
Epoch 0:  10%|███▎                             | 10/100 [01:28<13:23,  8.92s/it

Epoch 0:  96%|███████████████████████████████▋ | 96/100 [14:41<00:37,  9.49s/it][A
Epoch 0:  97%|████████████████████████████████ | 97/100 [14:51<00:29,  9.77s/it][A
Epoch 0:  98%|████████████████████████████████▎| 98/100 [15:01<00:19,  9.75s/it][A
Epoch 0:  99%|████████████████████████████████▋| 99/100 [15:10<00:09,  9.75s/it][A
Epoch 0: 100%|████████████████████████████████| 100/100 [15:20<00:00,  9.68s/it][A
                                                                                [A

Epoch 1/10, Loss: 1.3077141878008842


 10%|████                                     | 1/10 [15:21<2:18:12, 921.35s/it]
Epoch 1:   0%|                                          | 0/100 [00:00<?, ?it/s][A
Epoch 1:   1%|▎                                 | 1/100 [00:09<15:50,  9.60s/it][A
Epoch 1:   2%|▋                                 | 2/100 [00:18<15:23,  9.43s/it][A
Epoch 1:   3%|█                                 | 3/100 [00:28<15:30,  9.59s/it][A
Epoch 1:   4%|█▎                                | 4/100 [00:38<15:24,  9.63s/it][A
Epoch 1:   5%|█▋                                | 5/100 [00:48<15:14,  9.63s/it][A
Epoch 1:   6%|██                                | 6/100 [00:57<15:03,  9.61s/it][A
Epoch 1:   7%|██▍                               | 7/100 [01:07<15:04,  9.72s/it][A
Epoch 1:   8%|██▋                               | 8/100 [01:17<14:54,  9.73s/it][A
Epoch 1:   9%|███                               | 9/100 [01:26<14:39,  9.66s/it][A
Epoch 1:  10%|███▎                             | 10/100 [01:36<14:42,  9.80s/it

Epoch 1:  96%|███████████████████████████████▋ | 96/100 [15:14<00:37,  9.36s/it][A
Epoch 1:  97%|████████████████████████████████ | 97/100 [15:23<00:28,  9.37s/it][A
Epoch 1:  98%|████████████████████████████████▎| 98/100 [15:33<00:18,  9.34s/it][A
Epoch 1:  99%|████████████████████████████████▋| 99/100 [15:42<00:09,  9.38s/it][A
Epoch 1: 100%|████████████████████████████████| 100/100 [15:51<00:00,  9.29s/it][A
                                                                                [A

Epoch 2/10, Loss: 0.8046485349535942


 20%|████████▏                                | 2/10 [31:13<2:05:17, 939.73s/it]
Epoch 2:   0%|                                          | 0/100 [00:00<?, ?it/s][A
Epoch 2:   1%|▎                                 | 1/100 [00:09<15:07,  9.17s/it][A
Epoch 2:   2%|▋                                 | 2/100 [00:18<15:03,  9.22s/it][A
Epoch 2:   3%|█                                 | 3/100 [00:27<14:59,  9.28s/it][A
Epoch 2:   4%|█▎                                | 4/100 [00:37<14:56,  9.34s/it][A
Epoch 2:   5%|█▋                                | 5/100 [00:46<14:38,  9.24s/it][A
Epoch 2:   6%|██                                | 6/100 [00:55<14:32,  9.29s/it][A
Epoch 2:   7%|██▍                               | 7/100 [01:04<14:23,  9.29s/it][A
Epoch 2:   8%|██▋                               | 8/100 [01:14<14:18,  9.33s/it][A
Epoch 2:   9%|███                               | 9/100 [01:23<14:08,  9.32s/it][A
Epoch 2:  10%|███▎                             | 10/100 [01:32<13:56,  9.30s/it

Epoch 2:  96%|███████████████████████████████▋ | 96/100 [15:47<00:41, 10.25s/it][A
Epoch 2:  97%|████████████████████████████████ | 97/100 [15:58<00:30, 10.26s/it][A
Epoch 2:  98%|████████████████████████████████▎| 98/100 [16:08<00:20, 10.33s/it][A
Epoch 2:  99%|████████████████████████████████▋| 99/100 [16:19<00:10, 10.52s/it][A
Epoch 2: 100%|████████████████████████████████| 100/100 [16:30<00:00, 10.52s/it][A
                                                                                [A

Epoch 3/10, Loss: 0.7578918114304543


 30%|████████████▎                            | 3/10 [47:44<1:52:21, 963.13s/it]
Epoch 3:   0%|                                          | 0/100 [00:00<?, ?it/s][A
Epoch 3:   1%|▎                                 | 1/100 [00:10<16:53, 10.23s/it][A
Epoch 3:   2%|▋                                 | 2/100 [00:20<17:13, 10.54s/it][A
Epoch 3:   3%|█                                 | 3/100 [00:31<17:07, 10.59s/it][A
Epoch 3:   4%|█▎                                | 4/100 [00:42<16:51, 10.54s/it][A
Epoch 3:   5%|█▋                                | 5/100 [00:52<16:36, 10.49s/it][A
Epoch 3:   6%|██                                | 6/100 [01:02<16:10, 10.33s/it][A
Epoch 3:   7%|██▍                               | 7/100 [01:13<16:16, 10.50s/it][A
Epoch 3:   8%|██▋                               | 8/100 [01:23<16:05, 10.50s/it][A
Epoch 3:   9%|███                               | 9/100 [01:34<15:52, 10.47s/it][A
Epoch 3:  10%|███▎                             | 10/100 [01:44<15:39, 10.44s/it

Epoch 3:  96%|███████████████████████████████▋ | 96/100 [17:21<00:44, 11.19s/it][A
Epoch 3:  97%|████████████████████████████████ | 97/100 [17:33<00:33, 11.22s/it][A
Epoch 3:  98%|████████████████████████████████▎| 98/100 [17:44<00:22, 11.13s/it][A
Epoch 3:  99%|████████████████████████████████▋| 99/100 [17:54<00:10, 10.93s/it][A
Epoch 3: 100%|████████████████████████████████| 100/100 [18:06<00:00, 11.14s/it][A
                                                                                [A

Epoch 4/10, Loss: 0.7275749033689499


 40%|███████████████▏                      | 4/10 [1:05:52<1:41:12, 1012.10s/it]
Epoch 4:   0%|                                          | 0/100 [00:00<?, ?it/s][A
Epoch 4:   1%|▎                                 | 1/100 [00:10<17:39, 10.70s/it][A
Epoch 4:   2%|▋                                 | 2/100 [00:21<17:24, 10.66s/it][A
Epoch 4:   3%|█                                 | 3/100 [00:32<17:17, 10.70s/it][A
Epoch 4:   4%|█▎                                | 4/100 [00:43<17:27, 10.91s/it][A
Epoch 4:   5%|█▋                                | 5/100 [00:53<17:08, 10.83s/it][A
Epoch 4:   6%|██                                | 6/100 [01:04<16:51, 10.77s/it][A
Epoch 4:   7%|██▍                               | 7/100 [01:14<16:25, 10.60s/it][A
Epoch 4:   8%|██▋                               | 8/100 [01:25<16:21, 10.67s/it][A
Epoch 4:   9%|███                               | 9/100 [01:36<16:07, 10.63s/it][A
Epoch 4:  10%|███▎                             | 10/100 [01:47<16:06, 10.74s/it

Epoch 4:  96%|███████████████████████████████▋ | 96/100 [18:38<00:47, 11.80s/it][A
Epoch 4:  97%|████████████████████████████████ | 97/100 [18:50<00:35, 11.80s/it][A
Epoch 4:  98%|████████████████████████████████▎| 98/100 [19:01<00:23, 11.70s/it][A
Epoch 4:  99%|████████████████████████████████▋| 99/100 [19:13<00:11, 11.79s/it][A
Epoch 4: 100%|████████████████████████████████| 100/100 [19:25<00:00, 11.77s/it][A
                                                                                [A

Epoch 5/10, Loss: 0.6978789108991623


 50%|███████████████████                   | 5/10 [1:25:18<1:28:58, 1067.69s/it]
Epoch 5:   0%|                                          | 0/100 [00:00<?, ?it/s][A
Epoch 5:   1%|▎                                 | 1/100 [00:11<19:12, 11.65s/it][A
Epoch 5:   2%|▋                                 | 2/100 [00:24<20:17, 12.43s/it][A
Epoch 5:   3%|█                                 | 3/100 [00:37<20:24, 12.63s/it][A
Epoch 5:   4%|█▎                                | 4/100 [00:51<20:53, 13.05s/it][A
Epoch 5:   5%|█▋                                | 5/100 [01:03<20:27, 12.92s/it][A
Epoch 5:   6%|██                                | 6/100 [01:17<20:31, 13.10s/it][A
Epoch 5:   7%|██▍                               | 7/100 [01:29<19:47, 12.77s/it][A
Epoch 5:   8%|██▋                               | 8/100 [01:40<18:59, 12.39s/it][A
Epoch 5:   9%|███                               | 9/100 [01:53<18:37, 12.28s/it][A
Epoch 5:  10%|███▎                             | 10/100 [02:04<18:14, 12.17s/it

Epoch 5:  96%|███████████████████████████████▋ | 96/100 [20:09<00:48, 12.21s/it][A
Epoch 5:  97%|████████████████████████████████ | 97/100 [20:21<00:36, 12.24s/it][A
Epoch 5:  98%|████████████████████████████████▎| 98/100 [20:33<00:24, 12.18s/it][A
Epoch 5:  99%|████████████████████████████████▋| 99/100 [20:46<00:12, 12.35s/it][A
Epoch 5: 100%|████████████████████████████████| 100/100 [20:58<00:00, 12.29s/it][A
                                                                                [A

Epoch 6/10, Loss: 0.6906272108852863


 60%|██████████████████████▊               | 6/10 [1:46:17<1:15:31, 1132.83s/it]
Epoch 6:   0%|                                          | 0/100 [00:00<?, ?it/s][A
Epoch 6:   1%|▎                                 | 1/100 [00:13<22:12, 13.46s/it][A
Epoch 6:   2%|▋                                 | 2/100 [00:25<20:46, 12.72s/it][A
Epoch 6:   3%|█                                 | 3/100 [00:38<20:21, 12.60s/it][A
Epoch 6:   4%|█▎                                | 4/100 [00:51<20:25, 12.76s/it][A
Epoch 6:   5%|█▋                                | 5/100 [01:03<19:51, 12.54s/it][A
Epoch 6:   6%|██                                | 6/100 [01:15<19:17, 12.32s/it][A
Epoch 6:   7%|██▍                               | 7/100 [01:27<19:06, 12.33s/it][A
Epoch 6:   8%|██▋                               | 8/100 [01:40<19:02, 12.42s/it][A
Epoch 6:   9%|███                               | 9/100 [01:52<18:44, 12.36s/it][A
Epoch 6:  10%|███▎                             | 10/100 [02:04<18:36, 12.40s/it

Epoch 6:  96%|███████████████████████████████▋ | 96/100 [19:44<00:48, 12.03s/it][A
Epoch 6:  97%|████████████████████████████████ | 97/100 [19:56<00:36, 12.02s/it][A
Epoch 6:  98%|████████████████████████████████▎| 98/100 [20:08<00:23, 11.97s/it][A
Epoch 6:  99%|████████████████████████████████▋| 99/100 [20:21<00:12, 12.28s/it][A
Epoch 6: 100%|████████████████████████████████| 100/100 [20:34<00:00, 12.53s/it][A
 70%|████████████████████████████            | 7/10 [2:06:51<58:18, 1166.00s/it][A

Epoch 7/10, Loss: 0.7045515233278274



Epoch 7:   0%|                                          | 0/100 [00:00<?, ?it/s][A
Epoch 7:   1%|▎                                 | 1/100 [00:12<21:20, 12.93s/it][A
Epoch 7:   2%|▋                                 | 2/100 [00:25<20:55, 12.81s/it][A
Epoch 7:   3%|█                                 | 3/100 [00:37<20:01, 12.39s/it][A
Epoch 7:   4%|█▎                                | 4/100 [00:49<19:46, 12.36s/it][A
Epoch 7:   5%|█▋                                | 5/100 [01:02<19:39, 12.42s/it][A
Epoch 7:   6%|██                                | 6/100 [01:14<19:17, 12.32s/it][A
Epoch 7:   7%|██▍                               | 7/100 [01:26<19:02, 12.28s/it][A
Epoch 7:   8%|██▋                               | 8/100 [01:38<18:47, 12.26s/it][A
Epoch 7:   9%|███                               | 9/100 [01:51<18:46, 12.38s/it][A
Epoch 7:  10%|███▎                             | 10/100 [02:04<19:01, 12.68s/it][A
Epoch 7:  11%|███▋                             | 11/100 [02:17<18:38, 12.57

Epoch 7:  97%|████████████████████████████████ | 97/100 [21:25<00:39, 13.26s/it][A
Epoch 7:  98%|████████████████████████████████▎| 98/100 [21:38<00:26, 13.16s/it][A
Epoch 7:  99%|████████████████████████████████▋| 99/100 [21:50<00:13, 13.01s/it][A
Epoch 7: 100%|████████████████████████████████| 100/100 [22:04<00:00, 13.17s/it][A
                                                                                [A

Epoch 8/10, Loss: 0.6753486797213555


 80%|████████████████████████████████        | 8/10 [2:28:57<40:33, 1216.68s/it]
Epoch 8:   0%|                                          | 0/100 [00:00<?, ?it/s][A
Epoch 8:   1%|▎                                 | 1/100 [00:12<20:13, 12.26s/it][A
Epoch 8:   2%|▋                                 | 2/100 [00:24<20:16, 12.42s/it][A
Epoch 8:   3%|█                                 | 3/100 [00:36<19:54, 12.32s/it][A
Epoch 8:   4%|█▎                                | 4/100 [00:49<19:56, 12.46s/it][A
Epoch 8:   5%|█▋                                | 5/100 [01:01<19:36, 12.39s/it][A
Epoch 8:   6%|██                                | 6/100 [01:14<19:28, 12.43s/it][A
Epoch 8:   7%|██▍                               | 7/100 [01:27<19:43, 12.73s/it][A
Epoch 8:   8%|██▋                               | 8/100 [01:40<19:45, 12.88s/it][A
Epoch 8:   9%|███                               | 9/100 [01:54<19:50, 13.08s/it][A
Epoch 8:  10%|███▎                             | 10/100 [02:06<19:20, 12.90s/it

Epoch 8:  96%|███████████████████████████████▋ | 96/100 [22:17<00:58, 14.59s/it][A
Epoch 8:  97%|████████████████████████████████ | 97/100 [22:32<00:43, 14.52s/it][A
Epoch 8:  98%|████████████████████████████████▎| 98/100 [22:45<00:28, 14.26s/it][A
Epoch 8:  99%|████████████████████████████████▋| 99/100 [22:59<00:14, 14.21s/it][A
Epoch 8: 100%|████████████████████████████████| 100/100 [23:14<00:00, 14.24s/it][A
 90%|████████████████████████████████████    | 9/10 [2:52:11<21:12, 1272.18s/it][A

Epoch 9/10, Loss: 0.7024641272425651



Epoch 9:   0%|                                          | 0/100 [00:00<?, ?it/s][A
Epoch 9:   1%|▎                                 | 1/100 [00:13<22:04, 13.38s/it][A
Epoch 9:   2%|▋                                 | 2/100 [00:26<21:44, 13.31s/it][A
Epoch 9:   3%|█                                 | 3/100 [00:41<22:43, 14.05s/it][A
Epoch 9:   4%|█▎                                | 4/100 [00:56<22:49, 14.26s/it][A
Epoch 9:   5%|█▋                                | 5/100 [01:11<22:57, 14.50s/it][A
Epoch 9:   6%|██                                | 6/100 [01:26<23:08, 14.77s/it][A
Epoch 9:   7%|██▍                               | 7/100 [01:41<23:17, 15.03s/it][A
Epoch 9:   8%|██▋                               | 8/100 [01:57<23:08, 15.09s/it][A
Epoch 9:   9%|███                               | 9/100 [02:11<22:24, 14.78s/it][A
Epoch 9:  10%|███▎                             | 10/100 [02:26<22:10, 14.78s/it][A
Epoch 9:  11%|███▋                             | 11/100 [02:40<21:56, 14.79

Epoch 9:  97%|████████████████████████████████ | 97/100 [24:44<00:45, 15.18s/it][A
Epoch 9:  98%|████████████████████████████████▎| 98/100 [25:00<00:30, 15.43s/it][A
Epoch 9:  99%|████████████████████████████████▋| 99/100 [25:16<00:15, 15.62s/it][A
Epoch 9: 100%|████████████████████████████████| 100/100 [25:31<00:00, 15.52s/it][A
100%|███████████████████████████████████████| 10/10 [3:17:42<00:00, 1186.29s/it][A

Epoch 10/10, Loss: 0.6865372931957245





In [34]:
print(best_loss)

0.6753486797213555


In [35]:
# Evaluation

# Load the saved weights
model.load_state_dict(torch.load('best_model_weights.pth'))

model.eval()
all_preds = []
all_labels = []

with torch.no_grad():
    for images, labels in tqdm(val_loader):

        outputs = model(images)
        _, preds = torch.max(outputs, 1)

        all_preds.extend(preds.cpu().numpy())
        all_labels.extend(labels.cpu().numpy())

100%|███████████████████████████████████████████| 25/25 [01:21<00:00,  3.24s/it]


In [36]:

# Calculate accuracy
accuracy = accuracy_score(all_labels, all_preds)
print(f"Test Accuracy: {accuracy * 100:.4f}%")

Test Accuracy: 57.7500%
