In [1]:
import timm
import torch
import torch.nn as nn
import sys
import os
from torch.optim import Adam
from torch.utils.data import DataLoader
from tqdm import tqdm
from torchvision import datasets, transforms
sys.path.append(os.path.abspath(".."))
from data.ImageDataset import ImageDataset

In [10]:
class ViTBinaryClassifier(nn.Module):
    """Vision Transformer backbone from ``timm`` with a sigmoid classification head.

    The backbone's ``head`` is replaced by ``Linear(in_features, num_classes)``
    followed by ``Sigmoid``. As a result ``forward`` returns probabilities in
    (0, 1), not logits. This matches ``nn.BCELoss``. If you switch to the more
    numerically stable ``nn.BCEWithLogitsLoss``, remove the Sigmoid here.

    Args:
        model_name: ``timm`` model identifier used for the backbone.
        pretrained: Whether ``timm`` should load pretrained weights.
        num_classes: Number of sigmoid outputs (1 for binary classification).
        drop_rate: Dropout rate forwarded to ``timm.create_model``.
        attn_drop_rate: Attention dropout rate forwarded to ``timm.create_model``.
    """

    def __init__(self, model_name="vit_base_patch16_224", pretrained=True, num_classes=1,
                 drop_rate=0.6, attn_drop_rate=0.5):
        super().__init__()
        self.vit = timm.create_model(
            model_name,
            pretrained=pretrained,
            drop_rate=drop_rate,
            attn_drop_rate=attn_drop_rate,
        )
        # Reuse the width of the backbone's original classifier for the new head.
        in_features = self.vit.head.in_features
        self.vit.head = nn.Sequential(
            nn.Linear(in_features, num_classes),
            nn.Sigmoid(),  # probabilities, to pair with nn.BCELoss
        )

    def forward(self, x):
        """Return sigmoid probabilities of shape (batch, num_classes).

        ``x`` is presumably an image batch of shape (batch, 3, 224, 224) for the
        default backbone — TODO confirm against the data pipeline.
        """
        return self.vit(x)


In [11]:
# Seed the global torch RNG so that head initialisation, dropout masks and
# DataLoader shuffling are reproducible after a kernel restart.
SEED = 42
torch.manual_seed(SEED)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

# Move the model to its device *before* constructing the optimizer, as the
# torch.optim docs recommend, so the optimizer holds the final parameters.
model = ViTBinaryClassifier().to(device)

# ViTBinaryClassifier's head ends in Sigmoid (it outputs probabilities), so
# BCELoss is the matching criterion. BCEWithLogitsLoss would be more stable
# but requires a logits-producing head.
criterion = nn.BCELoss()
optimizer = Adam((p for p in model.parameters() if p.requires_grad), lr=1e-4)

# Multiply the learning rate by 0.1 every 7 scheduler steps; the training loop
# calls scheduler.step() once per epoch.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)

cuda


In [12]:
# Project data root — change this one constant to run on another machine.
DATA_DIR = "/home/ec2-user/CS230Project/data"
BATCH_SIZE = 64
NUM_WORKERS = 7

# Resize to the ViT input resolution, then map pixels from [0, 1] to [-1, 1].
# NOTE: mean/std = 0.5 are NOT the ImageNet statistics (0.485/0.456/0.406,
# 0.229/0.224/0.225). They appear to match timm's default preprocessing for
# vit_base_patch16_224 — confirm via timm.data.resolve_data_config(model=model.vit).
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])


def make_dataset(split):
    """Build the ImageDataset for ``split`` ("train" or "val") rooted at DATA_DIR."""
    return ImageDataset(
        annotations_path=os.path.join(DATA_DIR, "annotations", f"{split}.json"),
        images_dir=os.path.join(DATA_DIR, split),
        transform=transform,
    )


train_dataset = make_dataset("train")
val_dataset = make_dataset("val")

# Pinned host memory speeds up host-to-GPU copies; only useful when CUDA is present.
pin_memory = torch.cuda.is_available()
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS,
                          shuffle=True, pin_memory=pin_memory)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS,
                        shuffle=False, pin_memory=pin_memory)

In [None]:
num_epochs = 10
CHECKPOINT_DIR = "/home/ec2-user/CS230Project/code/models/saved-weights/ViT"
# torch.save does not create missing directories.
os.makedirs(CHECKPOINT_DIR, exist_ok=True)


def run_epoch(model, loader, criterion, device, optimizer=None, desc=None):
    """Run one pass over ``loader``; return (mean per-sample loss, accuracy in %).

    If ``optimizer`` is given, the model is in train mode and its parameters are
    updated. Otherwise it runs in eval mode with gradients disabled. A sample is
    predicted positive when the model's sigmoid output exceeds 0.5. ``desc``
    wraps the loader in a tqdm progress bar with that label.

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    training = optimizer is not None
    model.train(training)
    loss_sum, correct, total = 0.0, 0, 0
    batches = tqdm(loader, desc=desc) if desc is not None else loader

    with torch.set_grad_enabled(training):
        for images, labels in batches:
            images = images.to(device)
            labels = labels.to(device).float().view(-1, 1)
            outputs = model(images)
            loss = criterion(outputs, labels)

            if training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            batch_size = labels.size(0)
            # Weight by batch size so a short final batch does not skew the mean.
            loss_sum += loss.item() * batch_size
            correct += ((outputs > 0.5).float() == labels).sum().item()
            total += batch_size

    if total == 0:
        raise ValueError("DataLoader yielded no samples")
    return loss_sum / total, 100. * correct / total


# NOTE(review): in the recorded run, validation loss bottoms out around epoch 2
# and then climbs (0.50 -> 1.23) while train accuracy reaches ~96%, i.e. the
# model overfits. Consider early stopping or choosing a checkpoint by a
# validation metric.
for epoch in range(num_epochs):
    train_loss, train_accuracy = run_epoch(
        model, train_loader, criterion, device,
        optimizer=optimizer, desc=f"Training Epoch {epoch+1}/{num_epochs}",
    )
    print(f"Epoch {epoch+1}, Train Loss: {train_loss:.4f}, Accuracy: {train_accuracy:.2f}%")

    val_loss, val_accuracy = run_epoch(model, val_loader, criterion, device)
    print(f"Validation Loss: {val_loss:.4f}, Accuracy: {val_accuracy:.2f}%")

    scheduler.step()

    checkpoint_path = os.path.join(CHECKPOINT_DIR, f"ViT_{epoch+1}.pth")
    torch.save(model.state_dict(), checkpoint_path)
    print(f"Model saved to {checkpoint_path}")


Training Epoch 1/10: 100%|██████████| 690/690 [06:51<00:00,  1.68it/s]

Epoch 1, Train Loss: 0.6628, Accuracy: 59.61%





Validation Loss: 0.5235, Accuracy: 74.60%
Model saved to /home/ec2-user/CS230Project/code/models/saved-weights/ViT/ViT_1.pth


Training Epoch 2/10: 100%|██████████| 690/690 [06:52<00:00,  1.67it/s]

Epoch 2, Train Loss: 0.5936, Accuracy: 67.26%





Validation Loss: 0.4976, Accuracy: 76.00%
Model saved to /home/ec2-user/CS230Project/code/models/saved-weights/ViT/ViT_2.pth


Training Epoch 3/10: 100%|██████████| 690/690 [06:52<00:00,  1.67it/s]

Epoch 3, Train Loss: 0.5441, Accuracy: 72.07%





Validation Loss: 0.5998, Accuracy: 74.89%
Model saved to /home/ec2-user/CS230Project/code/models/saved-weights/ViT/ViT_3.pth


Training Epoch 4/10: 100%|██████████| 690/690 [06:52<00:00,  1.67it/s]

Epoch 4, Train Loss: 0.4782, Accuracy: 76.69%





Validation Loss: 0.4985, Accuracy: 77.67%
Model saved to /home/ec2-user/CS230Project/code/models/saved-weights/ViT/ViT_4.pth


Training Epoch 5/10: 100%|██████████| 690/690 [06:52<00:00,  1.67it/s]

Epoch 5, Train Loss: 0.4045, Accuracy: 81.52%





Validation Loss: 0.5149, Accuracy: 76.66%
Model saved to /home/ec2-user/CS230Project/code/models/saved-weights/ViT/ViT_5.pth


Training Epoch 6/10: 100%|██████████| 690/690 [06:53<00:00,  1.67it/s]

Epoch 6, Train Loss: 0.3387, Accuracy: 85.30%





Validation Loss: 0.5723, Accuracy: 77.55%
Model saved to /home/ec2-user/CS230Project/code/models/saved-weights/ViT/ViT_6.pth


Training Epoch 7/10: 100%|██████████| 690/690 [06:53<00:00,  1.67it/s]

Epoch 7, Train Loss: 0.2826, Accuracy: 88.32%





Validation Loss: 0.6345, Accuracy: 75.83%
Model saved to /home/ec2-user/CS230Project/code/models/saved-weights/ViT/ViT_7.pth


Training Epoch 8/10: 100%|██████████| 690/690 [06:52<00:00,  1.67it/s]

Epoch 8, Train Loss: 0.1584, Accuracy: 93.89%





Validation Loss: 0.8424, Accuracy: 78.13%
Model saved to /home/ec2-user/CS230Project/code/models/saved-weights/ViT/ViT_8.pth


Training Epoch 9/10: 100%|██████████| 690/690 [06:52<00:00,  1.67it/s]

Epoch 9, Train Loss: 0.1152, Accuracy: 95.36%





Validation Loss: 1.0382, Accuracy: 78.22%
Model saved to /home/ec2-user/CS230Project/code/models/saved-weights/ViT/ViT_9.pth


Training Epoch 10/10: 100%|██████████| 690/690 [06:52<00:00,  1.67it/s]

Epoch 10, Train Loss: 0.0972, Accuracy: 95.75%





Validation Loss: 1.2304, Accuracy: 78.26%
Model saved to /home/ec2-user/CS230Project/code/models/saved-weights/ViT/ViT_10.pth


: 