In [1]:
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
from torchvision import models
from PIL import Image
import timm  # Import timm library
from tqdm import tqdm
from sklearn.utils import shuffle
from sklearn.model_selection import StratifiedShuffleSplit
from sklearn.metrics import accuracy_score
import pandas as pd
from sklearn.model_selection import train_test_split

In [2]:
# Load the raw table and show how many rows carry each label value.
dd = pd.read_csv("./data/data_with_only_image_dir.csv")
dd['label'].value_counts()

-1    7452
 1    2373
 0    1719
Name: label, dtype: int64

In [3]:
# Keep only the rows labelled -1 or 1 (rows labelled 0 are discarded).
dd_filtered = dd[dd['label'] != 0]

# Draw 2000 rows from each of the two remaining classes with a fixed seed.
dd_negative, dd_positive = (
    dd_filtered[dd_filtered['label'] == class_value].sample(n=2000, random_state=42)
    for class_value in (-1, 1)
)

# Stack both samples, shuffle the rows and renumber the index from zero.
balanced_dataset = (
    pd.concat([dd_negative, dd_positive])
    .sample(frac=1, random_state=42)
    .reset_index(drop=True)
)

# Recode -1 as 0 so the two classes are labelled 0 / 1.
balanced_dataset['label'] = balanced_dataset['label'].replace(-1, 0)

balanced_dataset.to_csv('./data/balanced_dataset.csv', index=False)

In [4]:
class CustomDataset(Dataset):
    """Image-classification dataset backed by a CSV file.

    Rows are accessed positionally: column 0 is used as the image path and
    column 1 as the integer label. The column named ``label`` is additionally
    used to stratify the optional sub-sampling.

    Args:
        csv_file: Path of the CSV file to read.
        transform: Optional callable applied to each RGB PIL image before it
            is returned.
        data_length: Optional number of rows to keep. When it is smaller than
            the number of rows in the CSV, a label-stratified subset of
            exactly ``data_length`` rows is kept. When it is ``None`` or not
            smaller than the CSV, every row is kept. Sizes that
            ``train_test_split`` cannot satisfy (e.g. non-positive values)
            are rejected by that function.
    """

    def __init__(self, csv_file, transform=None, data_length=None):
        # Shuffle once with a fixed seed so the row order is reproducible.
        self.data = shuffle(pd.read_csv(csv_file), random_state=42)

        if data_length is not None and data_length < len(self.data):
            # An integer train_size requests the exact row count. Deriving a
            # float test_size from 1 - data_length / n can round up and drop a
            # row, and fails outright when data_length equals n.
            included_data, _ = train_test_split(
                self.data,
                train_size=int(data_length),
                random_state=42,
                stratify=self.data.label,
            )
            print(included_data.label.value_counts())
            self.data = included_data

        self.transform = transform

    def __len__(self):
        """Return the number of rows currently held."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for the row at position ``idx``.

        The image is loaded from the path in column 0, converted to RGB and
        passed through ``transform`` when one was supplied; the label is the
        value in column 1 cast to ``int``.
        """
        img_path = self.data.iloc[idx, 0]
        label = int(self.data.iloc[idx, 1])

        # Image.open keeps the underlying file open; convert() returns an
        # independent in-memory image, so the file can be closed right away.
        with Image.open(img_path) as img:
            image = img.convert('RGB')

        if self.transform is not None:
            image = self.transform(image)

        return image, label

In [11]:
# Preprocessing applied to every image: resize to 224x224, convert to a
# tensor, then normalise each channel with the given mean / std.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Full balanced dataset and its label vector (used to stratify the split).
dataset = CustomDataset(csv_file='./data/balanced_dataset.csv', transform=transform)
labels = dataset.data['label'].values

# One stratified 80/20 split; n_splits=1 means the splitter yields a single
# (train, validation) pair of index arrays.
sss = StratifiedShuffleSplit(n_splits=1, test_size=0.2, random_state=42)
train_index, val_index = next(sss.split(range(len(dataset)), labels))
train_dataset = torch.utils.data.Subset(dataset, train_index)
val_dataset = torch.utils.data.Subset(dataset, val_index)

print("Size of the training dataset:", len(train_dataset))
print("Size of the testing dataset:", len(val_dataset))

# Batched loaders: the training batches are reshuffled every epoch, the
# validation batches keep a fixed order.
batch_size = 32
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

Size of the training dataset: 3200
Size of the testing dataset: 800


In [12]:
# NOTE: `.dataset` is the full dataset wrapped by the Subset, so these counts
# cover every row (training + validation), not only the training indices.
train_dataset.dataset.data['label'].value_counts()

1    2000
0    2000
Name: label, dtype: int64

In [13]:
# Two output classes (labels 0 and 1).
num_classes = 2

# Pretrained ViT backbone from timm.
model = timm.create_model('vit_base_patch16_224', pretrained=True)

# Swap the classification head for a fresh linear layer of the right width.
head_in_features = model.head.in_features
model.head = nn.Linear(head_in_features, num_classes)

### Stage 1 training (3 epochs)

In [14]:
# Objective and optimiser; the later training stages reuse these objects.
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# Total number of epochs planned over all stages; this cell runs the first 3.
num_epochs = 10

# Lowest mean training loss observed so far.
best_loss = float('inf')

for epoch in tqdm(range(3)):
    model.train()
    running_loss = 0.0

    batch_iter = tqdm(train_loader, desc='Epoch {}'.format(epoch), leave=False)
    for batch_images, batch_targets in batch_iter:
        optimizer.zero_grad()
        logits = model(batch_images)
        loss = criterion(logits, batch_targets)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    # Mean of the per-batch losses for this epoch.
    epoch_loss = running_loss / len(train_loader)
    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss}")

    # Checkpoint whenever the training loss improves.
    if epoch_loss < best_loss:
        best_loss = epoch_loss
        torch.save(model.state_dict(), 'best_model_weights.pth')

  0%|                                                     | 0/3 [00:00<?, ?it/s]
Epoch 0:   0%|                                          | 0/100 [00:00<?, ?it/s][A
Epoch 0:   1%|▎                                 | 1/100 [00:19<32:07, 19.47s/it][A
Epoch 0:   2%|▋                                 | 2/100 [00:40<33:22, 20.44s/it][A
Epoch 0:   3%|█                                 | 3/100 [00:58<31:05, 19.23s/it][A
Epoch 0:   4%|█▎                                | 4/100 [01:16<29:46, 18.61s/it][A
Epoch 0:   5%|█▋                                | 5/100 [01:33<28:47, 18.18s/it][A
Epoch 0:   6%|██                                | 6/100 [01:51<28:17, 18.05s/it][A
Epoch 0:   7%|██▍                               | 7/100 [02:09<27:57, 18.04s/it][A
Epoch 0:   8%|██▋                               | 8/100 [02:26<27:21, 17.84s/it][A
Epoch 0:   9%|███                               | 9/100 [02:44<26:48, 17.68s/it][A
Epoch 0:  10%|███▎                             | 10/100 [03:01<26:25, 17.62s/it

Epoch 0:  96%|███████████████████████████████▋ | 96/100 [28:52<01:14, 18.67s/it][A
Epoch 0:  97%|████████████████████████████████ | 97/100 [29:11<00:55, 18.63s/it][A
Epoch 0:  98%|████████████████████████████████▎| 98/100 [29:30<00:37, 18.85s/it][A
Epoch 0:  99%|████████████████████████████████▋| 99/100 [29:50<00:18, 18.93s/it][A
Epoch 0: 100%|████████████████████████████████| 100/100 [30:08<00:00, 18.86s/it][A
                                                                                [A

Epoch 1/3, Loss: 1.3809153044223785


 33%|█████████████▋                           | 1/3 [30:09<1:00:19, 1809.65s/it]
Epoch 1:   0%|                                          | 0/100 [00:00<?, ?it/s][A
Epoch 1:   1%|▎                                 | 1/100 [00:18<30:12, 18.31s/it][A
Epoch 1:   2%|▋                                 | 2/100 [00:36<30:05, 18.42s/it][A
Epoch 1:   3%|█                                 | 3/100 [00:54<29:35, 18.30s/it][A
Epoch 1:   4%|█▎                                | 4/100 [01:13<29:32, 18.47s/it][A
Epoch 1:   5%|█▋                                | 5/100 [01:31<29:02, 18.34s/it][A
Epoch 1:   6%|██                                | 6/100 [01:50<28:40, 18.31s/it][A
Epoch 1:   7%|██▍                               | 7/100 [02:08<28:22, 18.31s/it][A
Epoch 1:   8%|██▋                               | 8/100 [02:26<28:08, 18.35s/it][A
Epoch 1:   9%|███                               | 9/100 [02:45<27:55, 18.42s/it][A
Epoch 1:  10%|███▎                             | 10/100 [03:03<27:36, 18.41s/it

Epoch 1:  96%|███████████████████████████████▋ | 96/100 [30:15<01:17, 19.43s/it][A
Epoch 1:  97%|████████████████████████████████ | 97/100 [30:34<00:57, 19.27s/it][A
Epoch 1:  98%|████████████████████████████████▎| 98/100 [30:53<00:38, 19.23s/it][A
Epoch 1:  99%|████████████████████████████████▋| 99/100 [31:12<00:19, 19.10s/it][A
Epoch 1: 100%|████████████████████████████████| 100/100 [31:31<00:00, 19.15s/it][A
                                                                                [A

Epoch 2/3, Loss: 0.580397612452507


 67%|███████████████████████████▎             | 2/3 [1:01:42<30:58, 1858.45s/it]
Epoch 2:   0%|                                          | 0/100 [00:00<?, ?it/s][A
Epoch 2:   1%|▎                                 | 1/100 [00:18<30:45, 18.64s/it][A
Epoch 2:   2%|▋                                 | 2/100 [00:37<31:00, 18.98s/it][A
Epoch 2:   3%|█                                 | 3/100 [00:57<30:53, 19.11s/it][A
Epoch 2:   4%|█▎                                | 4/100 [01:16<30:32, 19.09s/it][A
Epoch 2:   5%|█▋                                | 5/100 [01:35<30:33, 19.30s/it][A
Epoch 2:   6%|██                                | 6/100 [01:55<30:21, 19.38s/it][A
Epoch 2:   7%|██▍                               | 7/100 [02:14<30:05, 19.41s/it][A
Epoch 2:   8%|██▋                               | 8/100 [02:34<29:51, 19.47s/it][A
Epoch 2:   9%|███                               | 9/100 [02:54<29:38, 19.54s/it][A
Epoch 2:  10%|███▎                             | 10/100 [03:13<29:18, 19.53s/it

Epoch 2:  96%|███████████████████████████████▋ | 96/100 [32:29<01:22, 20.54s/it][A
Epoch 2:  97%|████████████████████████████████ | 97/100 [32:49<01:01, 20.52s/it][A
Epoch 2:  98%|████████████████████████████████▎| 98/100 [33:09<00:40, 20.47s/it][A
Epoch 2:  99%|████████████████████████████████▋| 99/100 [33:30<00:20, 20.48s/it][A
Epoch 2: 100%|████████████████████████████████| 100/100 [33:50<00:00, 20.43s/it][A
                                                                                [A

Epoch 3/3, Loss: 0.5505284512042999


100%|█████████████████████████████████████████| 3/3 [1:35:33<00:00, 1911.31s/it]


In [19]:
# Snapshot of the weights at the end of stage 1; stage 2 reloads this file.
torch.save(obj=model.state_dict(), f='best_model_weights1.pth')

### Stage 2 training (3 epochs)

In [None]:
# Resume from the weights saved at the end of stage 1.
stage1_state = torch.load('best_model_weights1.pth')
model.load_state_dict(stage1_state)

# Epochs 3-5 (zero-based); criterion, optimizer, best_loss and num_epochs
# come from the stage 1 cell.
for epoch in tqdm(range(3, 6)):
    model.train()
    running_loss = 0.0

    batch_iter = tqdm(train_loader, desc='Epoch {}'.format(epoch), leave=False)
    for batch_images, batch_targets in batch_iter:
        optimizer.zero_grad()
        logits = model(batch_images)
        loss = criterion(logits, batch_targets)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    # Mean of the per-batch losses for this epoch.
    epoch_loss = running_loss / len(train_loader)
    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss}")

    # Checkpoint whenever the training loss improves.
    if epoch_loss < best_loss:
        best_loss = epoch_loss
        torch.save(model.state_dict(), 'best_model_weights.pth')

In [None]:
# Snapshot of the weights at the end of stage 2; stage 3 reloads this file.
torch.save(obj=model.state_dict(), f='best_model_weights2.pth')

### Stage 3 training (4 epochs)

In [None]:
# Resume from the weights saved at the end of stage 2.
stage2_state = torch.load('best_model_weights2.pth')
model.load_state_dict(stage2_state)

# Remaining epochs, from 6 up to num_epochs (zero-based); criterion,
# optimizer, best_loss and num_epochs come from the stage 1 cell.
for epoch in tqdm(range(6, num_epochs)):
    model.train()
    running_loss = 0.0

    batch_iter = tqdm(train_loader, desc='Epoch {}'.format(epoch), leave=False)
    for batch_images, batch_targets in batch_iter:
        optimizer.zero_grad()
        logits = model(batch_images)
        loss = criterion(logits, batch_targets)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    # Mean of the per-batch losses for this epoch.
    epoch_loss = running_loss / len(train_loader)
    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss}")

    # Checkpoint whenever the training loss improves.
    if epoch_loss < best_loss:
        best_loss = epoch_loss
        torch.save(model.state_dict(), 'best_model_weights.pth')

In [None]:
# Snapshot of the weights at the end of stage 3.
torch.save(obj=model.state_dict(), f='best_model_weights3.pth')

In [15]:
# Lowest mean training loss seen so far; the training cells update it
# whenever an epoch's mean loss improves on the previous best.
print(best_loss)

0.5505284512042999


In [16]:
# Evaluation on the held-out split.

# Restore the checkpoint written when the training loss was lowest.
best_state = torch.load('best_model_weights.pth')
model.load_state_dict(best_state)
model.eval()

all_preds, all_labels = [], []

# Gradients are not needed for inference.
with torch.no_grad():
    for images, labels in tqdm(val_loader):
        logits = model(images)
        # Predicted class = index of the largest logit in each row.
        preds = logits.argmax(dim=1)

        all_preds.extend(preds.cpu().numpy())
        all_labels.extend(labels.cpu().numpy())

100%|███████████████████████████████████████████| 25/25 [02:29<00:00,  5.96s/it]


In [18]:

# Fraction of held-out samples whose predicted class matches the label.
accuracy = accuracy_score(y_true=all_labels, y_pred=all_preds)
print(f"Test Accuracy: {accuracy * 100:.4f}%")

Test Accuracy: 74.7500%
