In [158]:
import pandas as pd
import timm  # Import timm library
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from PIL import Image
from sklearn.metrics import accuracy_score
from sklearn.model_selection import StratifiedShuffleSplit
from sklearn.utils import shuffle
from torch.utils.data import Dataset, DataLoader
from torchvision import models
from tqdm import tqdm

In [160]:
class CustomDataset(Dataset):
    """Image-classification dataset backed by a CSV index file.

    The CSV is read with pandas, shuffled with a fixed seed and optionally
    truncated. Columns are accessed by position: column 0 holds the image
    path and column 1 holds an integer label.
    """

    def __init__(self, csv_file, transform=None, data_length=None):
        """Load and shuffle the CSV index.

        Args:
            csv_file: Path of the CSV file (passed to ``pd.read_csv``).
                The original author expected columns "image_sample" and "label".
            transform: Optional callable applied to each RGB PIL image.
            data_length: If not None, keep only the first ``data_length``
                rows of the shuffled table.
        """
        frame = pd.read_csv(csv_file)
        # Fixed seed so the row order is reproducible across runs.
        frame = shuffle(frame, random_state=42)
        if data_length is not None:
            frame = frame.iloc[:data_length]

        self.data = frame
        self.transform = transform

    def __len__(self):
        """Return the number of rows (samples) in the index."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for the row at position ``idx``.

        The image is converted to RGB and passed through ``transform`` when
        one was supplied. The returned label is the CSV value plus one.
        """
        img_path = self.data.iloc[idx, 0]  # image path lives in the first column
        # Image.open is lazy and keeps the file open; convert() returns an
        # independent, fully loaded copy, so the handle can be released here.
        with Image.open(img_path) as img:
            image = img.convert('RGB')

        # Label lives in the second column and is shifted by one.
        # NOTE(review): presumably the raw labels start at -1 and the shift
        # makes them valid 0-based class indices -- confirm against the CSV.
        label = int(self.data.iloc[idx, 1]) + 1

        if self.transform is not None:
            image = self.transform(image)

        return image, label

In [161]:
# Preprocessing: fixed 224x224 resize, tensor conversion, per-channel
# normalisation (the mean/std values are presumably the ImageNet statistics).
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
])

# Full dataset; it is split by index below rather than with random_split.
dataset = CustomDataset(
    csv_file='./data/data_with_only_image_dir.csv',
    transform=transform,
)

# Raw values of the "label" column, used only to stratify the split.
labels = dataset.data['label'].values

# One stratified 80/20 split so both subsets keep the class ratios.
sss = StratifiedShuffleSplit(n_splits=1, test_size=0.2, random_state=42)
# n_splits=1, so the generator yields exactly one (train, val) index pair.
train_index, val_index = next(sss.split(range(len(dataset)), labels))
train_dataset = torch.utils.data.Subset(dataset, train_index)
val_dataset = torch.utils.data.Subset(dataset, val_index)

# Batched loaders: training batches are reshuffled, validation order is fixed.
batch_size = 32
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

In [162]:
# Number of output classes for the replacement classification head.
num_classes = 3

# Pretrained ViT-Base (16x16 patches, 224x224 input) from timm.
model = timm.create_model('vit_base_patch16_224', pretrained=True)

# Replace the head with a fresh linear layer of the same input width,
# producing `num_classes` logits.
head_in_features = model.head.in_features
model.head = nn.Linear(head_in_features, num_classes)

In [163]:
# Loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# Training
num_epochs = 6

# Follow whatever device the model currently lives on, so the loop keeps
# working unchanged if the model is moved to an accelerator beforehand.
device = next(model.parameters()).device

for epoch in tqdm(range(num_epochs)):
    model.train()
    running_loss = 0.0

    # `targets` (not `labels`) so the label array used for the stratified
    # split is not overwritten by the last training batch.
    epoch_desc = 'Epoch {}/{}'.format(epoch + 1, num_epochs)
    for inputs, targets in tqdm(train_loader, desc=epoch_desc, leave=False):
        inputs = inputs.to(device)
        targets = targets.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    # Mean loss per batch; max(..., 1) guards against an empty loader.
    mean_loss = running_loss / max(len(train_loader), 1)
    # tqdm.write prints without corrupting the active progress bars.
    tqdm.write(f"Epoch {epoch+1}/{num_epochs}, Loss: {mean_loss}")

  0%|                                                     | 0/6 [00:00<?, ?it/s]
Epoch 0:   0%|                                          | 0/289 [00:00<?, ?it/s][A
Epoch 0:   0%|                                | 1/289 [00:19<1:33:53, 19.56s/it][A
  0%|                                                     | 0/6 [00:38<?, ?it/s][A


KeyboardInterrupt: 

In [140]:
# Validation: top-1 accuracy of the current model on the validation loader.
model.eval()
# Run on whatever device the model currently lives on.
device = next(model.parameters()).device
correct = 0
total = 0

with torch.no_grad():
    # The description is a fixed string: it must not depend on the `epoch`
    # variable leaked by the training loop, which may not have been run.
    for inputs, labels in tqdm(val_loader, desc='Validation', leave=False):
        inputs = inputs.to(device)
        labels = labels.to(device)

        outputs = model(inputs)
        predicted = outputs.argmax(dim=1)  # index of the largest logit per sample
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

# NaN rather than ZeroDivisionError when the loader yielded no samples.
accuracy = correct / total if total else float('nan')
print('Validation Accuracy: {:.4%}'.format(accuracy))


                                                                                

Validation Accuracy: 67.3885%




In [143]:
# Evaluation: collect per-sample predictions and ground-truth labels over the
# validation loader so that metrics can be computed from them afterwards.
model.eval()
# Inputs must sit on the same device as the model; outputs are brought back
# to the CPU below before being converted to NumPy.
device = next(model.parameters()).device
all_preds = []
all_labels = []

with torch.no_grad():
    for images, labels in tqdm(val_loader):
        outputs = model(images.to(device))
        preds = outputs.argmax(dim=1)  # index of the largest logit per sample

        all_preds.extend(preds.cpu().numpy())
        all_labels.extend(labels.cpu().numpy())

100%|███████████████████████████████████████████| 73/73 [07:59<00:00,  6.56s/it]


In [144]:

# Accuracy over the collected predictions. These were gathered from the
# validation loader, so the figure is reported as validation accuracy.
accuracy = accuracy_score(all_labels, all_preds)
print(f"Validation Accuracy: {accuracy * 100:.2f}%")

Test Accuracy: 67.39%
