In [None]:
import torch
from torch.utils.data import DataLoader
from torchvision import transforms
from datasets import load_dataset
from functools import partial

# Load the Galaxy10 dataset from Hugging Face
dataset_name = "matthieulel/galaxy10_decals"
galaxy_dataset = load_dataset(dataset_name)

# Per-channel statistics used by torchvision's ImageNet-pretrained checkpoints.
# The ViT fine-tuned later in this notebook starts from such a checkpoint, so
# inputs are normalized the same way the backbone saw them during pre-training.
imagenet_mean = [0.485, 0.456, 0.406]
imagenet_std = [0.229, 0.224, 0.225]

# Define transformations
train_transform = transforms.Compose([
    # transforms.RandomRotation(180),  # Galaxies can appear at any orientation
    transforms.RandomResizedCrop(224, scale=(0.8, 1.0)),  # Slight zoom variation
    # transforms.RandomHorizontalFlip(),  # Horizontal flip is valid for galaxies
    # transforms.RandomVerticalFlip(),   # Vertical flip is also valid for galaxies
    transforms.ColorJitter(brightness=0.2, contrast=0.2),  # Simulate different telescope exposures
    transforms.ToTensor(),
    # Normalize must come after ToTensor: it operates on float tensors in [0, 1].
    transforms.Normalize(mean=imagenet_mean, std=imagenet_std),
])
test_transform = transforms.Compose([
    # Deterministic pipeline: no augmentation at evaluation time.
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=imagenet_mean, std=imagenet_std),
])

# Define preprocessing functions to apply transformations
def preprocess_train(examples, transform=None):
    """Add a ``"pixel_values"`` column to a batch of training examples.

    Args:
        examples: Batch dictionary holding a list of images under ``"image"``.
            The dictionary is updated in place.
        transform: Callable applied to every image. ``None`` (the default)
            means "no transform": the images are copied through unchanged
            instead of failing with ``'NoneType' object is not callable``.

    Returns:
        The same ``examples`` dictionary, now containing ``"pixel_values"``.
    """
    images = examples["image"]
    if transform is None:
        examples["pixel_values"] = list(images)
    else:
        examples["pixel_values"] = [transform(image) for image in images]
    return examples

def preprocess_test(examples, transform=None):
    """Add a ``"pixel_values"`` column to a batch of evaluation examples.

    Args:
        examples: Batch dictionary holding a list of images under ``"image"``.
            The dictionary is updated in place.
        transform: Callable applied to every image. ``None`` (the default)
            means "no transform": the images are copied through unchanged
            instead of failing with ``'NoneType' object is not callable``.

    Returns:
        The same ``examples`` dictionary, now containing ``"pixel_values"``.
    """
    images = examples["image"]
    if transform is None:
        examples["pixel_values"] = list(images)
    else:
        examples["pixel_values"] = [transform(image) for image in images]
    return examples

# Create a held-out split if the hub dataset does not ship one. The result is
# stored back into `galaxy_dataset` so that later `"test" in galaxy_dataset`
# checks (and the evaluation code that relies on them) see the new split.
if "test" not in galaxy_dataset:
    galaxy_dataset = galaxy_dataset["train"].train_test_split(test_size=0.1, seed=42)

# Attach the transforms lazily instead of materializing them with `.map()`.
# `.map()` would run the *random* training augmentations exactly once and cache
# the result, so every epoch would see the same crops and jitter; it would also
# write every float tensor into the on-disk cache. `with_transform` calls the
# function each time examples are accessed, so augmentations are re-sampled.
# Each accessed example exposes "pixel_values" and "label" (plus the original
# "image", which the collate function ignores).
train_dataset = galaxy_dataset["train"].with_transform(
    partial(preprocess_train, transform=train_transform)
)

if "test" in galaxy_dataset:
    test_dataset = galaxy_dataset["test"].with_transform(
        partial(preprocess_test, transform=test_transform)
    )

# Define a custom collate function to handle the format
def collate_fn(batch):
    """Merge a list of per-example dicts into one batch dict.

    Every example must provide a ``"pixel_values"`` tensor and a ``"label"``.
    The tensors are stacked along a new leading dimension and the labels are
    gathered into a single tensor stored under the key ``"labels"``.
    """
    stacked_images = []
    targets = []
    for sample in batch:
        stacked_images.append(sample["pixel_values"])
        targets.append(sample["label"])
    return {
        "pixel_values": torch.stack(stacked_images),
        "labels": torch.tensor(targets),
    }

In [None]:
batch_size = 64
num_workers = 4  # Adjust based on your system
prefetch_factor = 2

# Options shared by the training and evaluation loaders.
loader_kwargs = dict(
    batch_size=batch_size,
    num_workers=num_workers,
    # pin_memory=True,
    collate_fn=collate_fn,
)
# `persistent_workers` and `prefetch_factor` only make sense with worker
# processes; DataLoader raises ValueError if `prefetch_factor` is passed while
# `num_workers == 0`, so both are added only in the multi-process case.
if num_workers > 0:
    loader_kwargs.update(persistent_workers=True, prefetch_factor=prefetch_factor)

# Create DataLoaders
# Training: reshuffled every epoch, and the incomplete last batch is dropped.
train_loader = DataLoader(train_dataset, shuffle=True, drop_last=True, **loader_kwargs)
if 'test' in galaxy_dataset:
    # Evaluation: fixed order, every example is kept.
    test_loader = DataLoader(test_dataset, shuffle=False, **loader_kwargs)

# Verify the dataloaders
print(f"Number of training batches: {len(train_loader)}")
if 'test' in galaxy_dataset:
    print(f"Number of test batches: {len(test_loader)}")

# Example of iterating through the dataloader: inspect only the first batch
for batch in train_loader:
    print(f"Pixel values shape: {batch['pixel_values'].shape}")
    print(f"Labels shape: {batch['labels'].shape}")
    break

In [None]:
import torch
import torch.nn as nn
import numpy as np
from sklearn.metrics import accuracy_score
from tqdm import tqdm
from torchvision.models import vit_l_16, ViT_L_16_Weights
from transformers import ViTFeatureExtractor

# Define the MLP classification head (LayerNorm followed by a GELU MLP)
class AdvancedMLPHead(nn.Module):
    """Classification head: LayerNorm followed by a three-layer GELU MLP.

    The MLP maps ``in_features -> hidden_dim -> hidden_dim // 2 -> num_classes``
    with dropout after each of the two hidden activations.

    Args:
        in_features: Size of the incoming feature vector.
        hidden_dim: Width of the first hidden layer; the second is half of it.
        num_classes: Number of output logits.
        dropout: Dropout probability used after both hidden activations.
    """

    def __init__(self, in_features, hidden_dim, num_classes, dropout=0.3):
        super().__init__()
        bottleneck_dim = hidden_dim // 2
        self.norm = nn.LayerNorm(in_features)
        # Layer order (and therefore the Sequential indices used in state-dict
        # keys) must stay fixed so that saved checkpoints remain loadable.
        layers = [
            nn.Linear(in_features, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, bottleneck_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(bottleneck_dim, num_classes),
        ]
        self.mlp = nn.Sequential(*layers)

    def forward(self, x):
        """Return class logits for ``x``.

        Inputs with more than two dimensions are reduced to index 0 along
        dimension 1 (the CLS token when ``x`` is a token sequence); 2-D inputs
        are used as-is.
        """
        features = x[:, 0] if x.dim() > 2 else x
        return self.mlp(self.norm(features))

# Build ViT-L/16 initialized from torchvision's default pretrained weights.
model = vit_l_16(weights=ViT_L_16_Weights.DEFAULT)

# Swap the stock linear classifier for the MLP head, keeping the input width
# of the layer being replaced.
num_features = model.heads.head.in_features
num_classes = 10  # For Galaxy10 DECals dataset
model.heads.head = AdvancedMLPHead(num_features, 2048, num_classes)

# Optional: resume from a previously saved checkpoint
# model.load_state_dict(torch.load('best_galaxy_vit_model.pth'))

# Full fine-tuning: backbone and head are both trainable.
model.requires_grad_(True)

# Hugging Face preprocessor; the training loop only uses it for batches that
# are not dictionaries.
feature_extractor = ViTFeatureExtractor.from_pretrained("google/vit-large-patch16-224")

# One optimizer over every parameter, with a small learning rate because the
# whole pretrained network is being updated.
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=1e-5,
    betas=(0.9, 0.999),
    eps=1e-8,
    weight_decay=0.05,
)
criterion = nn.CrossEntropyLoss()

# Learning-rate schedule: cosine annealing that restarts every 5 epochs
# (T_mult=1 keeps the cycle length constant) and decays towards eta_min.
# Alternative kept for reference:
# scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
#     optimizer, mode='max', factor=0.1, patience=3,
# )
scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
    optimizer,
    T_0=5,
    T_mult=1,
    eta_min=1e-6,
)

# Bookkeeping for the training run
best_val_acc = 0.0  # Best validation accuracy seen so far (drives checkpointing)
num_epochs = 10
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

# Training loop
def _unpack_batch(batch):
    """Split a loader batch into ``(images, labels)``.

    Dictionary batches are read from the ``'pixel_values'`` and ``'labels'``
    keys. Any other batch is treated as an ``(images, labels)`` pair whose
    images are first run through ``feature_extractor``.
    """
    if isinstance(batch, dict):
        return batch['pixel_values'], batch['labels']
    images, labels = batch
    inputs = feature_extractor(images=images, return_tensors="pt")
    return inputs['pixel_values'], labels


def _forward_with_loss(images, labels):
    """Run ``model`` on ``images`` and return ``(logits, loss)``.

    Supports models returning a plain tensor (torchvision ViT) as well as
    output objects that expose a ``logits`` attribute.
    """
    outputs = model(images)  # For torchvision ViT, don't use pixel_values parameter
    logits = outputs.logits if hasattr(outputs, 'logits') else outputs
    return logits, criterion(logits, labels)


num_batches = len(train_loader)

for epoch in range(num_epochs):
    # ---- Training phase ----
    model.train()
    train_losses = []
    train_preds = []
    train_labels = []

    # Iterating the tqdm wrapper advances the progress bar automatically and
    # starts a fresh pass over the loader every epoch.
    train_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs} [Train]")
    for batch_idx, batch in enumerate(train_bar):
        images, labels = _unpack_batch(batch)
        images = images.to(device)
        labels = labels.to(device)

        # Forward pass
        logits, loss = _forward_with_loss(images, labels)

        # Backward pass and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Fractional epoch: the cosine schedule is advanced once per batch.
        scheduler.step(epoch + batch_idx / num_batches)

        # Track statistics (read the loss from the device only once)
        loss_value = loss.item()
        train_losses.append(loss_value)
        batch_preds = torch.argmax(logits, dim=1).cpu().numpy()
        train_preds.extend(batch_preds)
        batch_labels = labels.cpu().numpy()
        train_labels.extend(batch_labels)

        # Update progress bar
        batch_acc = accuracy_score(batch_labels, batch_preds)
        train_bar.set_postfix(loss=loss_value, acc=batch_acc)

    # Calculate training metrics. The training loader drops the incomplete
    # last batch, so all batches have equal size and a plain mean is exact.
    train_acc = accuracy_score(train_labels, train_preds)
    train_loss = sum(train_losses) / len(train_losses)

    # ---- Evaluation phase ----
    model.eval()
    val_preds = []
    val_labels = []
    val_loss_total = 0.0  # Sum of per-sample losses
    val_samples = 0

    with torch.no_grad():
        val_bar = tqdm(test_loader, desc=f"Epoch {epoch+1}/{num_epochs} [Valid]")
        for batch in val_bar:
            images, labels = _unpack_batch(batch)
            images = images.to(device)
            labels = labels.to(device)

            logits, loss = _forward_with_loss(images, labels)

            # Track statistics. The evaluation loader keeps a shorter final
            # batch, so each batch mean is weighted by its size; a plain mean
            # of batch means would over-weight that last batch.
            loss_value = loss.item()
            batch_count = labels.size(0)
            val_loss_total += loss_value * batch_count
            val_samples += batch_count
            batch_preds = torch.argmax(logits, dim=1).cpu().numpy()
            val_preds.extend(batch_preds)
            batch_labels = labels.cpu().numpy()
            val_labels.extend(batch_labels)

            # Update progress bar
            batch_acc = accuracy_score(batch_labels, batch_preds)
            val_bar.set_postfix(loss=loss_value, acc=batch_acc)

    # Calculate validation metrics
    val_acc = accuracy_score(val_labels, val_preds)
    val_loss = val_loss_total / val_samples

    # Print progress
    print(f'Epoch {epoch+1}/{num_epochs} - Train loss: {train_loss:.4f}, Train acc: {train_acc:.4f}, Val loss: {val_loss:.4f}, Val acc: {val_acc:.4f}')

    # Save the best model (checkpoint only when validation accuracy improves)
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        torch.save(model.state_dict(), 'best_galaxy_vit_model_reg.pth')
        print(f'Model saved with validation accuracy: {val_acc:.4f}') 

print(f'Training completed. Best validation accuracy: {best_val_acc:.4f}')
