In [1]:
import torch
from torch.utils.data import DataLoader
from torchvision import transforms
from datasets import load_dataset
from functools import partial

# Load the Galaxy10 dataset from Hugging Face
dataset_name = "matthieulel/galaxy10_decals"
galaxy_dataset = load_dataset(dataset_name)


class_names = [
    "Disturbed", "Merging", "Round Smooth", "In-between Round Smooth",
    "Cigar Shaped Smooth", "Barred Spiral", "Unbarred Tight Spiral",
    "Unbarred Loose Spiral", "Edge-on without Bulge", "Edge-on with Bulge"
]

# Create a dictionary for easy lookup
label2name = {i: name for i, name in enumerate(class_names)}
name2label = {name: i for i, name in enumerate(class_names)}

num_classes = len(class_names)
print(f"\nNumber of classes: {num_classes}")
print("Class names:", class_names)

  from .autonotebook import tqdm as notebook_tqdm



Number of classes: 10
Class names: ['Disturbed', 'Merging', 'Round Smooth', 'In-between Round Smooth', 'Cigar Shaped Smooth', 'Barred Spiral', 'Unbarred Tight Spiral', 'Unbarred Loose Spiral', 'Edge-on without Bulge', 'Edge-on with Bulge']


In [2]:
from torch.utils.data import Dataset
import random

class Galaxy10OversampledDataset(Dataset):
    def __init__(self, dataset, split="train"):
        # Load the dataset
        self.data = list(dataset[split])
        
        # Group indices by class
        self.class_indices = {i: [] for i in range(10)}
        for idx, sample in enumerate(self.data):
            self.class_indices[sample["label"]].append(idx)
        
        # Determine maximum class count for oversampling
        max_count = max(len(idxs) for idxs in self.class_indices.values())
        
        # Oversample classes 0 and 4
        self.oversampled_indices = []
        for class_idx, indices in self.class_indices.items():
            if class_idx in [0, 4]:
                # Oversample with replacement and augment
                oversample_count = max_count
                sampled_indices = random.choices(indices, k=oversample_count)
                self.oversampled_indices.extend([(i, True) for i in sampled_indices])
            else:
                # Use original samples
                self.oversampled_indices.extend([(i, False) for i in indices])
        
        # Define transforms
        self.default_transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor()
        ])
        self.augment_transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.RandomHorizontalFlip(),
            transforms.RandomVerticalFlip(),
            transforms.RandomRotation(30),
            transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4),
            transforms.ToTensor()
        ])
        
    def __len__(self):
        return len(self.oversampled_indices)
    
    def __getitem__(self, idx):
        sample_idx, is_augmented = self.oversampled_indices[idx]
        sample = self.data[sample_idx]
        img = sample["image"]
        label = sample["label"]
        if is_augmented:
            img = self.augment_transform(img)
        else:
            img = self.default_transform(img)
        return {"pixel_values": img, "label": label}

In [3]:
train_dataset = Galaxy10OversampledDataset(galaxy_dataset, 'train')
test_dataset = Galaxy10OversampledDataset(galaxy_dataset, 'test')

In [4]:
# Define a custom collate function to handle the format
def collate_fn(batch):
    pixel_values = torch.stack([item["pixel_values"] for item in batch])
    labels = torch.tensor([item["label"] for item in batch])
    return {"pixel_values": pixel_values, "labels": labels}

In [5]:
batch_size = 64
num_workers = 4  # Adjust based on your system
prefetch_factor = 2

# Create DataLoaders
train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=num_workers,
    # pin_memory=True,
    persistent_workers=(num_workers > 0),
    prefetch_factor=prefetch_factor,
    drop_last=True,
    collate_fn=collate_fn
)
if 'test' in galaxy_dataset:
    test_loader = DataLoader(
        test_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        # pin_memory=True,
        persistent_workers=(num_workers > 0),
        prefetch_factor=prefetch_factor,
        collate_fn=collate_fn
    )

# Verify the dataloaders
print(f"Number of training batches: {len(train_loader)}")
if 'test' in galaxy_dataset:
    print(f"Number of test batches: {len(test_loader)}")

# Example of iterating through the dataloader
for batch in train_loader:
    print(f"Pixel values shape: {batch['pixel_values'].shape}")
    print(f"Labels shape: {batch['labels'].shape}")
    break

Number of training batches: 304
Number of test batches: 35
Pixel values shape: torch.Size([64, 3, 224, 224])
Labels shape: torch.Size([64])


In [6]:
from transformers import ViTImageProcessor

model_name_or_path = 'google/vit-large-patch16-224'
processor = ViTImageProcessor.from_pretrained(model_name_or_path)

In [7]:
import numpy as np
from sklearn.metrics import accuracy_score

def compute_metrics(p):
    predictions = np.argmax(p.predictions, axis=1)
    references = p.label_ids
    return {"accuracy": accuracy_score(references, predictions)}

In [8]:
from transformers import ViTForImageClassification

model = ViTForImageClassification.from_pretrained(
    model_name_or_path,
    num_labels=len(class_names),
    id2label=label2name,
    label2id=name2label,
    ignore_mismatched_sizes=True
)

Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-large-patch16-224 and are newly initialized because the shapes did not match:
- classifier.weight: found shape torch.Size([1000, 1024]) in the checkpoint and torch.Size([10, 1024]) in the model instantiated
- classifier.bias: found shape torch.Size([1000]) in the checkpoint and torch.Size([10]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [9]:
from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="./vit-large-galaxy",
    per_device_train_batch_size=8,  # Reduced batch size for larger model
    eval_strategy="steps",
    num_train_epochs=8,  # Increased epochs for better convergence
    fp16=True,
    save_steps=100,
    eval_steps=100,
    logging_steps=10,
    learning_rate=1e-5,  # Lower learning rate for fine-tuning large model
    weight_decay=0.01,   # Explicit weight decay for AdamW
    save_total_limit=2,
    remove_unused_columns=False,
    push_to_hub=False,
    report_to='tensorboard',
    load_best_model_at_end=True,
    optim="adamw_torch",  # Explicitly use AdamW optimizer
    warmup_steps=500,     # Add warmup steps
    lr_scheduler_type="cosine",  # Cosine decay works well with AdamW
)


In [10]:
from transformers import Trainer

trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=collate_fn,
    compute_metrics=compute_metrics,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
    tokenizer=processor,
)


  trainer = Trainer(


RuntimeError: CUDA error: CUDA-capable device(s) is/are busy or unavailable
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.


In [None]:
train_results = trainer.train()
trainer.save_model()
trainer.log_metrics("train", train_results.metrics)
trainer.save_metrics("train", train_results.metrics)
trainer.save_state()