In [None]:
# Setup
!pip install -q timm
import torch
print(f"GPU: {torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'CPU'}")

In [None]:
# Download EuroSAT RGB dataset
!wget -q https://zenodo.org/record/7711810/files/EuroSAT_RGB.zip
!unzip -q EuroSAT_RGB.zip
!ls EuroSAT_RGB/

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
from pathlib import Path
import numpy as np
from tqdm.auto import tqdm
import timm
import json
from sklearn.model_selection import train_test_split

# Configuration
# Each class name is also the name of its sub-directory under DATA_ROOT, and
# its position in this list is the integer label assigned to its images.
CLASSES = ['AnnualCrop', 'Forest', 'HerbaceousVegetation', 'Highway',
           'Industrial', 'Pasture', 'PermanentCrop', 'Residential', 'River', 'SeaLake']
DATA_ROOT = Path('EuroSAT_RGB')
BATCH_SIZE = 64
NUM_EPOCHS = 15
LR = 1e-4
MODEL_NAME = 'vit_small_patch16_224'  # 22M params, good balance
SEED = 42  # same value the split cell passes as random_state

# Seed torch and NumPy so random draws start from a fixed state on every run.
# This alone does not guarantee bit-identical results on a GPU.
torch.manual_seed(SEED)
np.random.seed(SEED)

device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Device: {device}")

In [None]:
# Create stratified splits (70/15/15)
all_images = []
all_labels = []

for cls_idx, cls_name in enumerate(CLASSES):
    cls_path = DATA_ROOT / cls_name
    # glob() yields entries in arbitrary filesystem order; sorting makes the
    # input to train_test_split identical on every machine, so the fixed
    # random_state really does reproduce the same split.
    cls_files = sorted(cls_path.glob('*.jpg'))
    if not cls_files:
        # A missing or empty class directory would otherwise go unnoticed and
        # the model would silently train on fewer classes than it predicts.
        raise FileNotFoundError(f"No .jpg images found in {cls_path}")
    all_images.extend(str(img_path) for img_path in cls_files)
    all_labels.extend([cls_idx] * len(cls_files))

print(f"Total images: {len(all_images)}")

# Split: 70% train, 15% val, 15% test
# First carve off 30% as a temporary pool, then halve that pool into val/test.
train_imgs, temp_imgs, train_labels, temp_labels = train_test_split(
    all_images, all_labels, test_size=0.3, stratify=all_labels, random_state=42
)
val_imgs, test_imgs, val_labels, test_labels = train_test_split(
    temp_imgs, temp_labels, test_size=0.5, stratify=temp_labels, random_state=42
)

# The three splits must partition the full image list.
assert len(train_imgs) + len(val_imgs) + len(test_imgs) == len(all_images)

print(f"Train: {len(train_imgs)}, Val: {len(val_imgs)}, Test: {len(test_imgs)}")

In [None]:
# Dataset
class EuroSATDataset(Dataset):
    """Map-style dataset that loads images from disk by path.

    Args:
        image_paths: Sequence of image file paths.
        labels: Sequence of labels aligned index-for-index with ``image_paths``.
        transform: Optional callable applied to the RGB PIL image before it
            is returned.

    Raises:
        ValueError: If ``image_paths`` and ``labels`` differ in length.
    """

    def __init__(self, image_paths, labels, transform=None):
        # Catch misaligned inputs here instead of as an IndexError raised
        # later from inside a DataLoader worker.
        if len(image_paths) != len(labels):
            raise ValueError(
                f"image_paths ({len(image_paths)}) and labels ({len(labels)}) "
                "must have the same length"
            )
        self.image_paths = image_paths
        self.labels = labels
        self.transform = transform

    def __len__(self):
        """Return the number of samples."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample ``idx``."""
        # convert() returns a fully decoded image, so the underlying file can
        # be closed as soon as the with-block exits.
        with Image.open(self.image_paths[idx]) as img:
            image = img.convert('RGB')
        # Compare against None: a callable may be falsy (e.g. an empty container
        # module) and would be skipped by a plain truthiness test.
        if self.transform is not None:
            image = self.transform(image)
        return image, self.labels[idx]

# Transforms - ViT expects 224x224
# Resizing and normalisation are shared by the training and evaluation
# pipelines; the training pipeline adds augmentation in between.
# NOTE(review): the mean/std values are presumably the ImageNet channel
# statistics - confirm they match the preprocessing the pretrained
# MODEL_NAME weights were trained with.
resize = transforms.Resize((224, 224))
normalize = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])

train_transform = transforms.Compose([
    resize,
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.RandomRotation(15),
    transforms.ColorJitter(0.2, 0.2, 0.2),
    transforms.ToTensor(),
    normalize,
])

# Evaluation uses no augmentation so validation/test metrics are deterministic.
val_transform = transforms.Compose([
    resize,
    transforms.ToTensor(),
    normalize,
])

train_ds = EuroSATDataset(train_imgs, train_labels, train_transform)
val_ds = EuroSATDataset(val_imgs, val_labels, val_transform)
test_ds = EuroSATDataset(test_imgs, test_labels, val_transform)

# Pinned host memory only speeds up host-to-GPU copies; asking for it on a
# CPU-only runtime has no benefit and makes DataLoader emit a warning.
loader_kwargs = dict(batch_size=BATCH_SIZE, num_workers=2, pin_memory=(device == 'cuda'))

# Only the training loader shuffles; val/test keep a fixed order.
train_loader = DataLoader(train_ds, shuffle=True, **loader_kwargs)
val_loader = DataLoader(val_ds, shuffle=False, **loader_kwargs)
test_loader = DataLoader(test_ds, shuffle=False, **loader_kwargs)

In [None]:
# Model: ViT-S/16 with pretrained ImageNet weights
# The head size is taken from CLASSES so the number of outputs cannot drift
# from the label set used to build the dataset.
model = timm.create_model(MODEL_NAME, pretrained=True, num_classes=len(CLASSES))
model = model.to(device)

# Total parameter count (every tensor returned by model.parameters()).
num_params = sum(p.numel() for p in model.parameters())
print(f"Model: {MODEL_NAME}")
print(f"Parameters: {num_params:,}")

In [None]:
# Training setup
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=LR, weight_decay=0.01)
# One scheduler step per epoch; the cosine period spans the whole run.
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=NUM_EPOCHS)

best_val_acc = 0.0
history = {'train_loss': [], 'train_acc': [], 'val_loss': [], 'val_acc': []}


def _train_one_epoch(model, loader, criterion, optimizer, device, desc):
    """Run one optimisation pass over ``loader``.

    Args:
        model: Network to train (switched to train mode here).
        loader: Iterable of ``(images, labels)`` batches.
        criterion: Loss returning the mean over a batch.
        optimizer: Optimizer stepped once per batch.
        device: Device the batches are moved to.
        desc: Label shown on the progress bar.

    Returns:
        Tuple ``(mean per-sample loss, accuracy in percent)``.
    """
    model.train()
    loss_sum, correct, total = 0.0, 0, 0

    pbar = tqdm(loader, desc=desc)
    for images, labels in pbar:
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        n = labels.size(0)
        # criterion yields a batch mean; weighting it by the batch size gives
        # a true per-sample average even when the last batch is smaller.
        loss_sum += loss.item() * n
        correct += (outputs.argmax(1) == labels).sum().item()
        total += n
        pbar.set_postfix({'loss': f'{loss.item():.4f}', 'acc': f'{100*correct/total:.1f}%'})

    return loss_sum / total, 100 * correct / total


def _evaluate(model, loader, criterion, device):
    """Score ``model`` on ``loader`` without tracking gradients.

    Returns:
        Tuple ``(mean per-sample loss, accuracy in percent)``.
    """
    model.eval()
    loss_sum, correct, total = 0.0, 0, 0
    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            n = labels.size(0)
            loss_sum += criterion(outputs, labels).item() * n
            correct += (outputs.argmax(1) == labels).sum().item()
            total += n
    return loss_sum / total, 100 * correct / total


# Training loop
for epoch in range(NUM_EPOCHS):
    train_loss, train_acc = _train_one_epoch(
        model, train_loader, criterion, optimizer, device,
        desc=f'Epoch {epoch+1}/{NUM_EPOCHS}',
    )
    val_loss, val_acc = _evaluate(model, val_loader, criterion, device)

    scheduler.step()

    history['train_loss'].append(train_loss)
    history['train_acc'].append(train_acc)
    history['val_loss'].append(val_loss)
    history['val_acc'].append(val_acc)

    print(f'  Train Acc: {train_acc:.2f}% | Val Acc: {val_acc:.2f}%')

    # Keep only the checkpoint with the best validation accuracy so far.
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'val_acc': val_acc,
            'model_name': MODEL_NAME
        }, 'vit_classifier.pth')
        print(f'  ✓ Saved best model (Val Acc: {val_acc:.2f}%)')

print(f'\nBest Val Acc: {best_val_acc:.2f}%')

In [None]:
# Final test evaluation
# Reload the best-validation checkpoint rather than using the last-epoch weights.
# map_location places the tensors on the current device, so a checkpoint saved
# from a GPU run can still be loaded on a CPU-only runtime.
checkpoint = torch.load('vit_classifier.pth', map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()

test_correct, test_total = 0, 0
with torch.no_grad():
    for images, labels in tqdm(test_loader, desc='Testing'):
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        test_correct += (outputs.argmax(1) == labels).sum().item()
        test_total += labels.size(0)

test_acc = 100 * test_correct / test_total
print(f'\n{"="*50}')
print(f'ViT-S/16 Test Accuracy: {test_acc:.2f}%')
print(f'(ResNet-18 Baseline: 97.65%)')
print(f'{"="*50}')

# Save final results
# Hyper-parameters and the per-epoch history are stored next to the metrics so
# the JSON file describes the run on its own.
results = {
    'model': MODEL_NAME,
    'test_accuracy': test_acc,
    'best_val_accuracy': best_val_acc,
    'epochs': NUM_EPOCHS,
    'batch_size': BATCH_SIZE,
    'learning_rate': LR,
    'optimizer': 'AdamW',
    'scheduler': 'CosineAnnealingLR',
    'history': history
}
with open('vit_results.json', 'w') as f:
    json.dump(results, f, indent=2)

print('\n✓ Download vit_classifier.pth and vit_results.json')

In [None]:
# Download files
# google.colab is only importable on Colab. Elsewhere the two artifacts were
# already written to the working directory by the earlier cells, so report
# that instead of failing the final cell with an ImportError.
try:
    from google.colab import files
except ImportError:
    print('Not running in Colab - vit_classifier.pth and vit_results.json are in the working directory')
else:
    files.download('vit_classifier.pth')
    files.download('vit_results.json')