In [1]:
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
from torchvision import models
from PIL import Image
import timm  # Import timm library
from tqdm import tqdm
from sklearn.utils import shuffle
from sklearn.model_selection import StratifiedShuffleSplit
from sklearn.metrics import accuracy_score
import pandas as pd
from sklearn.model_selection import train_test_split



In [2]:
class CustomDataset(Dataset):
    """Image-classification dataset indexed by a CSV file.

    The CSV is loaded with pandas and shuffled with a fixed seed. Each item is
    an ``(image, label)`` pair: the image path is read from the first column,
    the raw label from the second column, and the returned label is the raw
    value plus one (the notebook output shows raw labels -1/0/1, which would
    map to 0/1/2 -- confirm against the CSV).

    Args:
        csv_file: Path of the CSV file. A column named ``label`` is required
            whenever ``data_length`` is given (it drives stratification).
        transform: Optional callable applied to the RGB ``PIL.Image`` before
            it is returned.
        data_length: Optional number of rows to keep. When smaller than the
            CSV, a stratified sample of exactly that many rows is drawn; when
            ``None`` or not smaller than the CSV, every row is kept.

    Raises:
        ValueError: If ``data_length`` is given but is not positive.
    """

    def __init__(self, csv_file, transform=None, data_length=None):
        # Assuming csv_file is a CSV file with columns "image_sample" and "label"
        self.data = shuffle(pd.read_csv(csv_file), random_state=42)

        if data_length is not None:
            data_length = int(data_length)
            if data_length <= 0:
                raise ValueError(f"data_length must be positive, got {data_length}")

            if data_length < len(self.data):
                # An integer train_size asks for an exact row count. The
                # previous float test_size (1 - data_length / n) was rounded
                # up by scikit-learn and could leave the sample one row short.
                self.data, _ = train_test_split(
                    self.data,
                    train_size=data_length,
                    random_state=42,
                    stratify=self.data['label'],
                )
            # Show the class balance of the rows that were kept.
            print(self.data['label'].value_counts())

        self.transform = transform

    def __len__(self):
        """Return the number of rows (samples) currently held."""
        return len(self.data)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(image, label)``.

        The image is opened from the path in the first column and converted
        to RGB; ``transform`` is applied when set. The label is the integer
        in the second column shifted by +1.
        """
        img_path = self.data.iloc[idx, 0]  # assuming the image path is in the first column
        # The context manager closes the underlying file as soon as the pixel
        # data has been copied into the converted image.
        with Image.open(img_path) as img:
            image = img.convert('RGB')

        # assuming the label is in the second column
        label = int(self.data.iloc[idx, 1]) + 1

        if self.transform:
            image = self.transform(image)

        return image, label

In [6]:
# Preprocessing: resize every image to 224x224, convert it to a tensor and
# normalise each channel (these are the standard ImageNet statistics).
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Build the dataset, asking for a stratified subset of 2000 rows.
dataset = CustomDataset(
    csv_file='./data/data_with_only_image_dir.csv',
    transform=transform,
    data_length=2000,
)

# Raw labels (the "label" column) drive the stratified train/validation split.
labels = dataset.data['label'].values

# A single 80/20 stratified shuffle split keeps the class ratios in both parts.
sss = StratifiedShuffleSplit(n_splits=1, test_size=0.2, random_state=42)
train_index, val_index = next(sss.split(range(len(dataset)), labels))
train_dataset = torch.utils.data.Subset(dataset, train_index)
val_dataset = torch.utils.data.Subset(dataset, val_index)

# Batch both splits; only the training batches are reshuffled each epoch.
batch_size = 16
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

-1    64
 1    21
 0    15
Name: label, dtype: int64


In [7]:
# Class counts of the DataFrame behind the Subset's parent dataset, i.e. every
# row of `dataset`, not only the rows selected by the training indices.
train_dataset.dataset.data['label'].value_counts()

-1    64
 1    21
 0    15
Name: label, dtype: int64

In [8]:
# Number of target classes for the replacement classification head.
num_classes = 3

# Vision Transformer backbone with pretrained weights from timm.
model = timm.create_model('vit_base_patch16_224', pretrained=True)

# Replace the classification head with a fresh linear layer that maps the
# existing head's input width to `num_classes` outputs.
head_in_features = model.head.in_features
model.head = nn.Linear(head_in_features, num_classes)

In [9]:
# Training configuration
num_epochs = 30
best_loss = float('inf')  # lowest mean training loss seen so far

# Cross-entropy objective, optimised with Adam over every model parameter.
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

for epoch in tqdm(range(num_epochs)):
    model.train()
    running_loss = 0.0

    # One pass over the training batches; the inner bar disappears when done.
    for inputs, labels in tqdm(train_loader, desc='Epoch {}'.format(epoch), leave=False):
        optimizer.zero_grad()

        outputs = model(inputs)
        loss = criterion(outputs, labels)

        loss.backward()
        optimizer.step()

        running_loss = running_loss + loss.item()

    # Mean of the per-batch losses for this epoch.
    epoch_loss = running_loss / len(train_loader)
    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss}")

    # Checkpoint whenever the mean training loss improves on the best so far.
    if epoch_loss < best_loss:
        best_loss = epoch_loss
        torch.save(model.state_dict(), 'best_model_weights.pth')

  0%|                                                    | 0/30 [00:00<?, ?it/s]
Epoch 0:   0%|                                            | 0/5 [00:00<?, ?it/s][A
Epoch 0:  20%|███████▏                            | 1/5 [00:08<00:34,  8.65s/it][A
Epoch 0:  40%|██████████████▍                     | 2/5 [00:17<00:26,  8.89s/it][A
Epoch 0:  60%|█████████████████████▌              | 3/5 [00:27<00:18,  9.12s/it][A
Epoch 0:  80%|████████████████████████████▊       | 4/5 [00:36<00:09,  9.23s/it][A
Epoch 0: 100%|████████████████████████████████████| 5/5 [00:46<00:00,  9.55s/it][A
                                                                                [A

Epoch 1/30, Loss: 6.437681555747986


  3%|█▍                                          | 1/30 [00:47<22:55, 47.42s/it]
Epoch 1:   0%|                                            | 0/5 [00:00<?, ?it/s][A
  3%|█▍                                          | 1/30 [00:57<27:46, 57.46s/it][A


KeyboardInterrupt: 

In [34]:
# Lowest mean per-epoch training loss recorded by the training loop.
# NOTE(review): execution counts jump from In [9] to In [34], so the value
# shown may come from a different run than the training log above -- confirm
# with Restart & Run All.
print(best_loss)

0.6753486797213555


In [35]:
# Evaluation: restore the checkpoint written during training and collect the
# predicted and true class indices for every batch of the validation loader.
best_state = torch.load('best_model_weights.pth')
model.load_state_dict(best_state)
model.eval()

all_preds, all_labels = [], []

# Gradients are not needed for inference.
with torch.no_grad():
    for images, labels in tqdm(val_loader):
        outputs = model(images)
        # Index of the highest logit along the class dimension.
        preds = torch.max(outputs, 1)[1]

        all_preds.extend(preds.cpu().numpy())
        all_labels.extend(labels.cpu().numpy())

100%|███████████████████████████████████████████| 25/25 [01:21<00:00,  3.24s/it]


In [36]:

# Share of samples whose predicted class equals the true class. The inputs
# were gathered from `val_loader`, so this is accuracy on the validation split.
accuracy = accuracy_score(y_true=all_labels, y_pred=all_preds)
print(f"Test Accuracy: {accuracy * 100:.4f}%")

Test Accuracy: 57.7500%
