In [3]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import transforms
from PIL import Image
import numpy as np
from transformers import AutoImageProcessor, AutoModelForImageClassification

# -------------------------------
# 1. Load dataset
# -------------------------------
data = np.load('data/images.npy')   # shape: (2682, 600, 400, 3)
label = np.load('data/labels.npy')  # shape: (2682, 10)

# Convert one-hot labels to class indices
labels = np.argmax(label, axis=1)

# -------------------------------
# 2. Load image processor
# -------------------------------
image_processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224", use_fast=True)

# -------------------------------
# 3. Define PyTorch Dataset
# -------------------------------
class NumpyDataset(Dataset):
    def __init__(self, images, labels, processor):
        self.images = images
        self.labels = labels
        self.processor = processor
        self.transform = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=self.processor.image_mean, std=self.processor.image_std)
        ])
    
    def __len__(self):
        return len(self.images)
    
    def __getitem__(self, idx):
        img = self.images[idx]
        img = self.transform(img)
        lbl = self.labels[idx]
        return img, lbl

dataset = NumpyDataset(data, labels, image_processor)

# -------------------------------
# 4. Split into train and validation
# -------------------------------
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = random_split(dataset, [train_size, val_size])

train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=16)

# -------------------------------
# 5. Load pretrained ViT and modify for transfer learning
# -------------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = AutoModelForImageClassification.from_pretrained("google/vit-base-patch16-224")
model.to(device)

# Freeze backbone (optional for feature extraction)
for param in model.vit.parameters():
    param.requires_grad = False

# Replace classifier for 10 classes
num_classes = 10
model.classifier = nn.Linear(model.config.hidden_size, num_classes).to(device)

# -------------------------------
# 6. Training setup
# -------------------------------
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)

# -------------------------------
# 7. Training loop
# -------------------------------
epochs = 5
for epoch in range(epochs):
    model.train()
    running_loss = 0.0
    for images, labels_batch in train_loader:
        images = images.to(device)
        labels_batch = labels_batch.to(device)

        optimizer.zero_grad()
        outputs = model(images).logits
        loss = criterion(outputs, labels_batch)
        loss.backward()
        optimizer.step()
        
        running_loss += loss.item()
    
    print(f"Epoch {epoch+1}/{epochs}, Loss: {running_loss/len(train_loader):.4f}")

# -------------------------------
# 8. Validation
# -------------------------------
model.eval()
correct, total = 0, 0
with torch.no_grad():
    for images, labels_batch in val_loader:
        images = images.to(device)
        labels_batch = labels_batch.to(device)
        outputs = model(images).logits
        predicted = torch.argmax(outputs, dim=1)
        total += labels_batch.size(0)
        correct += (predicted == labels_batch).sum().item()

print(f"Validation Accuracy: {100*correct/total:.2f}%")


  from .autonotebook import tqdm as notebook_tqdm


Epoch 1/5, Loss: 1.6471
Epoch 2/5, Loss: 1.3385
Epoch 3/5, Loss: 1.2083
Epoch 4/5, Loss: 1.1119
Epoch 5/5, Loss: 1.0408
Validation Accuracy: 61.45%
