# Homework: Galaxy Image Classification

**Course:** Deep Learning for Computer Vision

**Objective:** Train a deep learning model to classify galaxy images from the Galaxy10 DECals dataset into one of 10 categories.

**Dataset:** Galaxy10 DECals
* **Source:** [Hugging Face Datasets](https://huggingface.co/datasets/matthieulel/galaxy10_decals)
* **Description:** Contains 17,736 color galaxy images (256x256 pixels) divided into 10 classes. Images originate from DESI Legacy Imaging Surveys, with labels from Galaxy Zoo.
* **Classes:**
    * 0: Disturbed Galaxies
    * 1: Merging Galaxies
    * 2: Round Smooth Galaxies
    * 3: In-between Round Smooth Galaxies
    * 4: Cigar Shaped Smooth Galaxies
    * 5: Barred Spiral Galaxies
    * 6: Unbarred Tight Spiral Galaxies
    * 7: Unbarred Loose Spiral Galaxies
    * 8: Edge-on Galaxies without Bulge
    * 9: Edge-on Galaxies with Bulge

**Tasks:**
1.  Load and explore the dataset.
2.  Preprocess the images.
3.  Define and train a model.
4.  Evaluate the model's performance using standard classification metrics on the test set.

Homework is succesfully completed if you get >0.9 Accuracy on the Test set.

# Prerequisites

In [1]:
%pip install -r requirements.txt

Collecting torch
  Using cached torch-2.7.0-cp310-cp310-manylinux_2_28_x86_64.whl (865.2 MB)
Collecting datasets
  Using cached datasets-3.5.0-py3-none-any.whl (491 kB)
Collecting matplotlib
  Using cached matplotlib-3.10.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (8.6 MB)
Collecting scikit-learn
  Using cached scikit_learn-1.6.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (13.5 MB)
Collecting torchvision
  Using cached torchvision-0.22.0-cp310-cp310-manylinux_2_28_x86_64.whl (7.4 MB)
Collecting torchmetrics
  Using cached torchmetrics-1.7.1-py3-none-any.whl (961 kB)
Collecting albumentations
  Using cached albumentations-2.0.5-py3-none-any.whl (290 kB)
Collecting seaborn
  Using cached seaborn-0.13.2-py3-none-any.whl (294 kB)
Collecting timm
  Using cached timm-1.0.15-py3-none-any.whl (2.4 MB)
Collecting transformers
  Using cached transformers-4.51.3-py3-none-any.whl (10.4 MB)
Collecting nvidia-cuda-runtime-cu12==12.6.77
  Using cached nvidia_cuda_run

In [None]:
import torch

In [None]:
# Cell 4: Visualize one example from each class
def show_class_examples(dataset, class_names_map, samples_per_row=5, num_rows=2):
    """Displays one sample image for each class."""
    if not dataset:
        print("Dataset not loaded. Cannot visualize.")
        return

    num_classes_to_show = len(class_names_map)
    if num_classes_to_show > samples_per_row * num_rows:
        print(f"Warning: Not enough space to show all {num_classes_to_show} classes.")
        num_classes_to_show = samples_per_row * num_rows

    fig, axes = plt.subplots(num_rows, samples_per_row, figsize=(15, 6)) # Adjusted figsize
    axes = axes.ravel() # Flatten the axes array

    split_name = 'train' if 'train' in dataset else list(dataset.keys())[0]
    data_split = dataset[split_name]

    images_shown = 0
    processed_labels = set()

    for i in range(len(data_split)):
        if images_shown >= num_classes_to_show:
            break # Stop once we have shown one for each target class

        example = data_split[i]
        label = example['label']

        if label not in processed_labels and label < num_classes_to_show:
            img = example['image']
            ax_idx = label # Use label directly as index into the flattened axes
            axes[ax_idx].imshow(img)
            axes[ax_idx].set_title(f"Class {label}: {class_names_map[label]}", fontsize=9)
            axes[ax_idx].axis('off')
            processed_labels.add(label)
            images_shown += 1

    # Hide any unused subplots
    for i in range(images_shown, len(axes)):
        axes[i].axis('off')

    plt.tight_layout()
    plt.show()

In [None]:
def evaluate_predictions(predicted_labels, true_labels, class_names_list):
    """
    Calculates and prints classification metrics from predicted labels and true labels.

    Args:
        predicted_labels (list or np.array): The predicted class indices for the test set.
        true_labels (list or np.array): The ground truth class indices for the test set.
        class_names_list (list): A list of strings containing the names of the classes.
    """
    if len(predicted_labels) != len(true_labels):
        print(f"Error: Number of predictions ({len(predicted_labels)}) does not match number of true labels ({len(true_labels)}).")
        return None # Indicate failure

    print(f"Evaluating {len(predicted_labels)} predictions against true labels...")

    # Ensure inputs are numpy arrays for scikit-learn
    predicted_labels = np.array(predicted_labels)
    true_labels = np.array(true_labels)

    # Calculate metrics using scikit-learn
    accuracy = accuracy_score(true_labels, predicted_labels)
    # Calculate precision, recall, f1 per class and average (weighted)
    # Use zero_division=0 to handle cases where a class might not be predicted or present in labels
    precision, recall, f1, _ = precision_recall_fscore_support(
        true_labels, predicted_labels, average='weighted', zero_division=0
    )
    # Get per-class metrics as well
    per_class_precision, per_class_recall, per_class_f1, per_class_support = precision_recall_fscore_support(
        true_labels, predicted_labels, average=None, zero_division=0, labels=range(len(class_names_list))
    )

    # Generate Confusion Matrix
    cm = confusion_matrix(true_labels, predicted_labels, labels=range(len(class_names_list)))

    # Print Metrics
    print(f"\n--- Evaluation Metrics ---")
    print(f"Accuracy: {accuracy:.4f}")
    print(f"Weighted Precision: {precision:.4f}")
    print(f"Weighted Recall: {recall:.4f}")
    print(f"Weighted F1-Score: {f1:.4f}")
    print("-" * 25)
    print("Per-Class Metrics:")
    print(f"{'Class':<30} | {'Precision':<10} | {'Recall':<10} | {'F1-Score':<10} | {'Support':<10}")
    print("-" * 80)
    for i, name in enumerate(class_names_list):
         # Handle cases where support might be 0 for a class in true labels if dataset is small/filtered
         support = per_class_support[i] if i < len(per_class_support) else 0
         prec = per_class_precision[i] if i < len(per_class_precision) else 0
         rec = per_class_recall[i] if i < len(per_class_recall) else 0
         f1s = per_class_f1[i] if i < len(per_class_f1) else 0
         print(f"{f'{i}: {name}':<30} | {prec:<10.4f} | {rec:<10.4f} | {f1s:<10.4f} | {support:<10}")
    print("-" * 80)

    # Plot Confusion Matrix
    print("\nPlotting Confusion Matrix...")
    fig, ax = plt.subplots(figsize=(10, 10))
    cm_normalized = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
    disp = ConfusionMatrixDisplay(confusion_matrix=cm_normalized, display_labels=class_names_list)
    disp.plot(cmap=plt.cm.Blues, ax=ax, xticks_rotation='vertical', values_format=".2f")
    plt.title('Normalized Confusion Matrix')

    plt.tight_layout()
    plt.show()

    metrics = {
        'accuracy': accuracy,
        'precision_weighted': precision,
        'recall_weighted': recall,
        'f1_weighted': f1,
        'confusion_matrix': cm,
        'per_class_metrics': {
            'precision': per_class_precision,
            'recall': per_class_recall,
            'f1': per_class_f1,
            'support': per_class_support
        }
    }
    return metrics

# Data

In [None]:
%pip install jupyterlab-widgets

[0mNote: you may need to restart the kernel to use updated packages.


In [None]:
import datasets

dataset_name = "matthieulel/galaxy10_decals"
galaxy_dataset = datasets.load_dataset(dataset_name)

# Define class names based on the dataset card
class_names = [
    "Disturbed", "Merging", "Round Smooth", "In-between Round Smooth",
    "Cigar Shaped Smooth", "Barred Spiral", "Unbarred Tight Spiral",
    "Unbarred Loose Spiral", "Edge-on without Bulge", "Edge-on with Bulge"
]

# Create a dictionary for easy lookup
label2name = {i: name for i, name in enumerate(class_names)}
name2label = {name: i for i, name in enumerate(class_names)}

num_classes = len(class_names)
print(f"\nNumber of classes: {num_classes}")
print("Class names:", class_names)

README.md:   0%|          | 0.00/1.42k [00:00<?, ?B/s]

(…)-00000-of-00005-3ac4e3b3fa8df68d.parquet:   0%|          | 0.00/457M [00:00<?, ?B/s]

(…)-00001-of-00005-72e716d68bb13413.parquet:   0%|          | 0.00/458M [00:00<?, ?B/s]

In [None]:
show_class_examples(galaxy_dataset, label2name, samples_per_row=5, num_rows=2)

# Your training code here

In [None]:
%pip install torchvision torchmetrics albumentations

In [None]:
import torch
# from torchvision.transforms import v2
from torch.utils.data import DataLoader
from torch import nn
from torchmetrics import Accuracy
from tqdm import tqdm
import torchvision.models as models
import numpy as np
import matplotlib.pyplot as plt

In [None]:
import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset
from albumentations import Compose, Rotate, HorizontalFlip, RandomBrightnessContrast, CoarseDropout
from albumentations.pytorch import ToTensorV2 

class GalaxyDataset(Dataset):
    def __init__(self, hf_dataset, transform=None):
        self.dataset = hf_dataset
        self.transform = transform

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        item = self.dataset[idx]
        image = item['image']
        label = item['label']

        image_np = np.array(image)

        if self.transform:
            transformed = self.transform(image=image_np)
            image = transformed['image']
        else:
            image = torch.transforms.ToTensor()(image)

        if isinstance(label, str):
            label = int(label)
        label = torch.tensor(label, dtype=torch.long)

        return image, label



from albumentations import Normalize

train_transform = Compose([
    # Rotate(limit=45, p=0.7),
    # HorizontalFlip(p=0.5),
    # RandomBrightnessContrast(p=0.3),
    # CoarseDropout(num_holes_range=(7,8), hole_height_range=(15,16), hole_width_range=(15, 16), p=0.5),
    Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),

    ToTensorV2()
])

test_transform = Compose([
    Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
    ToTensorV2()
])

train_dataset = GalaxyDataset(galaxy_dataset['train'], transform=train_transform)
test_dataset = GalaxyDataset(galaxy_dataset['test'], transform=test_transform)

BATCH_SIZE = 64

train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=2,
    pin_memory=True
)

test_loader = DataLoader(
    test_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=2,
    pin_memory=True
)

sample_images, sample_labels = next(iter(train_loader))
print(f"Image batch shape: {sample_images.shape}")
print(f"Label batch shape: {sample_labels.shape}")

In [None]:
%pip install seaborn

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
import seaborn as sns

accuracy = Accuracy(task="multiclass", num_classes=10).to(device)

def epoch_train(loader, model, criterion, optimizer):
    model.train()
    total_loss = 0.0
    accuracy.reset()
    
    for images, labels in loader:
        if isinstance(labels, list):
            labels = torch.tensor(labels)
            
        images, labels = images.to(device), labels.to(device)


        optimizer.zero_grad()
        outputs = model(images)
        # print(outputs) 
        loss = criterion(outputs, labels)
            
        loss.backward()
        optimizer.step()
            
        total_loss += loss.item() * images.size(0)
        _, preds = torch.max(outputs, 1) 
        accuracy.update(preds, labels)
    
    avg_loss = total_loss / len(loader.dataset)
    avg_acc = accuracy.compute()
    return avg_loss, avg_acc

def epoch_test(loader, model, criterion, epoch=0):
    model.eval()
    total_loss = 0.0
    accuracy.reset()
    all_labels = []
    all_preds = []
    
    with torch.no_grad():
        for images, labels in loader:
            if isinstance(labels, list):
                labels = torch.tensor(labels)
        
            images, labels = images.to(device), labels.to(device)

            outputs = model(images)
            loss = criterion(outputs, labels)
            
            total_loss += loss.item() * images.size(0)
            _, preds = torch.max(outputs, 1)

            all_labels.extend(labels.cpu().numpy())
            all_preds.extend(preds.cpu().numpy())
            
            accuracy.update(preds, labels)
    
    plot_confusion_matrix(all_labels, all_preds, epoch)
    
    avg_loss = total_loss / len(loader.dataset)
    avg_acc = accuracy.compute()
    return avg_loss, avg_acc

def plot_confusion_matrix(all_labels, all_preds, epoch):
    cm = confusion_matrix(all_labels, all_preds)
    
    plt.figure(figsize=(10, 7))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                xticklabels=class_names, yticklabels=class_names)
    plt.xlabel('predicted')
    plt.ylabel('true')
    plt.title(f'cf  (Epoch {epoch})')
    
    filename = f'confusion_matrix_epoch_{epoch:03d}.png'
    plt.savefig(filename, bbox_inches='tight')
    plt.close() 


In [None]:
def train(train_loader, test_loader, model, criterion, optimizer, epochs=50):
    model = model.to(device)
    
    for epoch in tqdm(range(epochs), desc="Training"):
        train_loss, train_acc = epoch_train(train_loader, model, criterion, optimizer)
        test_loss, test_acc = epoch_test(test_loader, model, criterion, epoch=epoch)
        
        print(f"Epoch {epoch+1}/{epochs}")
        print(f"Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.4f}")
        print(f"Test Loss: {test_loss:.4f} | Test Acc: {test_acc:.4f}\n")

In [None]:
%pip install timm transformers

In [None]:
import timm
import torch.nn as nn

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# model = timm.create_model('resnet50', pretrained=True, num_classes=10)
# model = model.to(device)

 #https://huggingface.co/docs/transformers/model_doc/swinv2

from transformers import Swinv2ForImageClassification

model = Swinv2ForImageClassification.from_pretrained(
    "microsoft/swinv2-tiny-patch4-window8-256",
    num_labels=10, ignore_mismatched_sizes=True
)

opt = torch.optim.AdamW(model.parameters(), lr=0.005)
criterion = nn.CrossEntropyLoss()

In [None]:
train(train_loader, test_loader, model, criterion, opt, epochs=20)

Training:   0%|          | 0/20 [00:00<?, ?it/s]

Training:   5%|▌         | 1/20 [00:48<15:29, 48.93s/it]

Epoch 1/20
Train Loss: 1.2060 | Train Acc: 0.5687
Test Loss: 1.2547 | Test Acc: 0.5750



Training:  10%|█         | 2/20 [01:37<14:34, 48.61s/it]

Epoch 2/20
Train Loss: 0.7446 | Train Acc: 0.7475
Test Loss: 1.1357 | Test Acc: 0.6471



Training:  15%|█▌        | 3/20 [02:25<13:46, 48.61s/it]

Epoch 3/20
Train Loss: 0.6070 | Train Acc: 0.7937
Test Loss: 0.8488 | Test Acc: 0.7198



Training:  20%|██        | 4/20 [03:14<12:58, 48.63s/it]

Epoch 4/20
Train Loss: 0.5286 | Train Acc: 0.8173
Test Loss: 0.5925 | Test Acc: 0.8134



Training:  25%|██▌       | 5/20 [04:03<12:09, 48.66s/it]

Epoch 5/20
Train Loss: 0.4682 | Train Acc: 0.8382
Test Loss: 0.7861 | Test Acc: 0.7486



Training:  30%|███       | 6/20 [04:51<11:21, 48.66s/it]

Epoch 6/20
Train Loss: 0.4079 | Train Acc: 0.8609
Test Loss: 0.6808 | Test Acc: 0.7745



Training:  35%|███▌      | 7/20 [05:40<10:32, 48.63s/it]

Epoch 7/20
Train Loss: 0.3745 | Train Acc: 0.8704
Test Loss: 0.6195 | Test Acc: 0.7937



Training:  40%|████      | 8/20 [06:29<09:43, 48.63s/it]

Epoch 8/20
Train Loss: 0.3132 | Train Acc: 0.8919
Test Loss: 0.6953 | Test Acc: 0.7847



Training:  45%|████▌     | 9/20 [07:17<08:54, 48.60s/it]

Epoch 9/20
Train Loss: 0.2842 | Train Acc: 0.8999
Test Loss: 0.6971 | Test Acc: 0.7869



Training:  50%|█████     | 10/20 [08:06<08:05, 48.60s/it]

Epoch 10/20
Train Loss: 0.2336 | Train Acc: 0.9202
Test Loss: 0.7124 | Test Acc: 0.7948



Training:  50%|█████     | 10/20 [08:33<08:33, 51.31s/it]


KeyboardInterrupt: 

# trying ensemble models..

In [None]:
%%python -m ipykernel install --user --name=venv

UsageError: %%python is a cell magic, but the cell body is empty.


In [None]:
import torch
import torch.nn as nn

from torchvision.models import densenet121, DenseNet121_Weights
from torchvision.models import resnet50, ResNet50_Weights
from torchvision.models import efficientnet_v2_m, EfficientNet_V2_M_Weights
from torchvision.models import alexnet, AlexNet_Weights
from torchvision.models import vgg16, VGG16_Weights
from torchvision.models import vgg19, VGG19_Weights

In [None]:
def changedClassifierLayer(model, modelName, N_CLASSES=10):
    for param in model.parameters():
      param.requires_grad = False

    if modelName == "DenseNet121":
      num_input = model.classifier.in_features

    elif modelName == "ResNet50":
      num_input = model.fc.in_features

    elif modelName == "EfficientNet-V2-M" or modelName == "AlexNet":
      num_input = model.classifier[1].in_features

    elif modelName == "VGG19" or modelName == "VGG16":
      num_input = model.classifier[0].in_features

    classifier = nn.Sequential(
      nn.Linear(num_input, 256),
      nn.ReLU(),
      nn.Dropout(0.2),
      nn.Linear(256, 128),
      nn.ReLU(),
      nn.Dropout(0.2),
      nn.Linear(128, N_CLASSES),
      nn.LogSoftmax(dim=1)
    )

    if modelName == "ResNet50":
      model.fc = classifier
    else:
      model.classifier = classifier

In [None]:
efficientnet_weights_path = 'models/EfficientNet-V2-M.pth'
densenet_weights_path = 'models/DenseNet121.pth'
resnet_weights_path = 'models/ResNet50.pth'
alexnet_weights_path = 'models/AlexNet.pth'
vgg16_weights_path = 'models/VGG16.pth'
vgg19_weights_path = 'models/VGG19.pth'

In [None]:
efficientnetV2M_model = efficientnet_v2_m(weights=EfficientNet_V2_M_Weights.IMAGENET1K_V1)
densenet_model = densenet121(weights=DenseNet121_Weights.IMAGENET1K_V1)
resnet_model = resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)
alexnet_model = alexnet(weights=AlexNet_Weights.IMAGENET1K_V1)
vgg16_model = vgg16(weights=VGG16_Weights.IMAGENET1K_V1)
vgg19_model = vgg19(weights=VGG19_Weights.IMAGENET1K_V1)

In [None]:
class EnsembleModel(nn.Module):
    def __init__(self, model_list, weights=None):
        super(EnsembleModel, self).__init__()
        self.models = nn.ModuleList(model_list)
        self.weights = weights

    def forward(self, x):
        outputs = []
        for model in self.models:
            output = model(x)  
            outputs.append(output)
        
        if self.weights is None:
            ensemble_output = torch.mean(torch.stack(outputs), dim=0)
        else:
            weighted_outputs = torch.stack([w * output for w, output in zip(self.weights, outputs)])
            ensemble_output = torch.sum(weighted_outputs, dim=0)

        return ensemble_output


In [None]:
models_list = [
    # efficientnetV2M_model.to(device),
    # densenet_model.to(device),
    resnet_model.to(device),
    alexnet_model.to(device),
    vgg16_model.to(device),
    vgg19_model.to(device)
]

ensemble_model = EnsembleModel(models_list)

model = ensemble_model.to(device)
parameters_to_optimize = []

for m in models_list:
    parameters_to_optimize += list(filter(lambda p: p.requires_grad, m.parameters()))

opt = torch.optim.AdamW(parameters_to_optimize, lr=0.005)
criterion = nn.CrossEntropyLoss()


train(train_loader, test_loader, model, criterion, opt, epochs=20)

Training:   5%|▌         | 1/20 [04:11<1:19:43, 251.79s/it]

Epoch 1/20
Train Loss: 1933581.3621 | Train Acc: 0.1952
Test Loss: 1.8647 | Test Acc: 0.2948



Training:  10%|█         | 2/20 [08:24<1:15:38, 252.16s/it]

Epoch 2/20
Train Loss: 1.8692 | Train Acc: 0.4199
Test Loss: 1.4648 | Test Acc: 0.4645



Training:  15%|█▌        | 3/20 [12:37<1:11:34, 252.62s/it]

Epoch 3/20
Train Loss: 1.3114 | Train Acc: 0.5469
Test Loss: 1.3096 | Test Acc: 0.5519



Training:  15%|█▌        | 3/20 [14:49<1:24:01, 296.58s/it]


RuntimeError: Detected more unique values in `preds` than expected. Expected only 10 but found 12 in `preds`. Found values: tensor([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9, 473, 836],
       device='cuda:0').

In [None]:
%pip install transformers

Collecting transformers
  Downloading transformers-4.51.3-py3-none-any.whl (10.4 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m10.4/10.4 MB[0m [31m45.5 MB/s[0m eta [36m0:00:00[0m00:01[0m0:01[0m
Collecting safetensors>=0.4.3
  Downloading safetensors-0.5.3-cp38-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (471 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m471.6/471.6 KB[0m [31m18.6 MB/s[0m eta [36m0:00:00[0m
Collecting tokenizers<0.22,>=0.21
  Downloading tokenizers-0.21.1-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (3.0 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.0/3.0 MB[0m [31m71.2 MB/s[0m eta [36m0:00:00[0m:00:01[0m
[?25hCollecting regex!=2019.12.17
  Downloading regex-2024.11.6-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (781 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m781.7/781.7 KB[0m [31m83.5 MB/s[0m eta [36m0:00:00[0m
Installing co

# SWIN

In [None]:
import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset
from albumentations import Compose, Normalize
from albumentations.pytorch import ToTensorV2
from transformers import AutoImageProcessor, Swinv2ForImageClassification

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# model = Swinv2ForImageClassification.from_pretrained(
#     "microsoft/swinv2-tiny-patch4-window8-256",
#     num_labels=10,
#     ignore_mismatched_sizes=True
# ).to(device)


# model = Swinv2ForImageClassification.from_pretrained(
#     "swin",
#     num_labels=10,
#     ignore_mismatched_sizes=True
# ).to(device)
# # processor = AutoImageProcessor.from_pretrained("microsoft/swinv2-tiny-patch4-window8-256")
# processor = AutoImageProcessor.from_pretrained("swin")

from transformers import AutoImageProcessor, Swinv2ForImageClassification
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model_dir = "./swin_3"  

try:
    processor = AutoImageProcessor.from_pretrained(model_dir)
except:
    processor = AutoImageProcessor.from_pretrained("microsoft/swinv2-tiny-patch4-window8-256")
    print(f"Couldn't load processor from {model_dir}, using default processor")

model = Swinv2ForImageClassification.from_pretrained(
    model_dir,
    num_labels=10,
    ignore_mismatched_sizes=True
).to(device)




class GalaxyDataset(Dataset):
    def __init__(self, hf_dataset, transform=None):
        self.dataset = hf_dataset
        self.transform = transform

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        item = self.dataset[idx]
        image = item['image']
        label = item['label']

        image_np = np.array(image)
        if self.transform:
            transformed = self.transform(image=image_np)
            image = transformed['image']
        else:
            image = torch.from_numpy(image_np).permute(2, 0, 1).float()
        
        if isinstance(label, str):
            label = int(label)
        label = torch.tensor(label, dtype=torch.long)

        return image, label

# Define transforms - using SwinV2's recommended normalization
train_transform = Compose([
    # Rotate(limit=45, p=0.7),
    # HorizontalFlip(p=0.5),
    # RandomBrightnessContrast(p=0.3),

    # CoarseDropout(num_holes_range=(7,8), hole_height_range=(15,16), hole_width_range=(15, 16), p=0.5),
    Normalize(mean=processor.image_mean, std=processor.image_std),
    ToTensorV2()
])

test_transform = Compose([
    Normalize(mean=processor.image_mean, std=processor.image_std),
    ToTensorV2()
])

# Create datasets and dataloaders (assuming galaxy_dataset is already defined)
train_dataset = GalaxyDataset(galaxy_dataset['train'], transform=train_transform)
test_dataset = GalaxyDataset(galaxy_dataset['test'], transform=test_transform)

BATCH_SIZE = 64

train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=2,
    pin_memory=True
)

test_loader = DataLoader(
    test_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=2,
    pin_memory=True
)

# Modified training loop for SwinV2
def epoch_train(loader, model, criterion, optimizer):
    model.train()
    total_loss = 0.0
    accuracy = Accuracy(task="multiclass", num_classes=10).to(device)
    
    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)
        
        # SwinV2 expects pixel_values as input
        outputs = model(pixel_values=images)
        loss = criterion(outputs.logits, labels)
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        total_loss += loss.item() * images.size(0)
        _, preds = torch.max(outputs.logits, 1)
        accuracy.update(preds, labels)
    
    avg_loss = total_loss / len(loader.dataset)
    avg_acc = accuracy.compute()
    return avg_loss, avg_acc

def epoch_test(loader, model, criterion, epoch=0, class_names=None):
    model.eval()
    total_loss = 0.0
    accuracy = Accuracy(task="multiclass", num_classes=10).to(device)
    all_labels = []
    all_preds = []
    
    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            
            outputs = model(pixel_values=images)
            loss = criterion(outputs.logits, labels)
            
            total_loss += loss.item() * images.size(0)
            _, preds = torch.max(outputs.logits, 1)
            
            all_labels.extend(labels.cpu().numpy())
            all_preds.extend(preds.cpu().numpy())
            accuracy.update(preds, labels)
    
    if class_names:
        plot_confusion_matrix(all_labels, all_preds, epoch, class_names)
    
    avg_loss = total_loss / len(loader.dataset)
    avg_acc = accuracy.compute()
    return avg_loss, avg_acc

def plot_confusion_matrix(all_labels, all_preds, epoch, class_names):
    cm = confusion_matrix(all_labels, all_preds)
    
    plt.figure(figsize=(10, 7))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                xticklabels=class_names, yticklabels=class_names)
    plt.xlabel('Predicted')
    plt.ylabel('True')
    plt.title(f'Confusion Matrix (Epoch {epoch})')
    
    filename = f'confusion_matrix_epoch_{epoch:03d}.png'
    plt.savefig(filename, bbox_inches='tight')
    plt.close()


optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
criterion = torch.nn.CrossEntropyLoss()



preprocessor_config.json:   0%|          | 0.00/240 [00:00<?, ?B/s]

Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.


Couldn't load processor from ./swin_3, using default processor


In [None]:
!pip cache purge
!rm -rf ~/.cache/pip

Files removed: 86


In [None]:
import os

def save_model(model, output_dir):
    os.makedirs(output_dir, exist_ok=True)
    model.save_pretrained(output_dir)

In [None]:
def train(train_loader, test_loader, model, criterion, optimizer, epochs=50):
    model = model.to(device)
    
    for epoch in tqdm(range(epochs), desc="Training"):
        train_loss, train_acc = epoch_train(train_loader, model, criterion, optimizer)
        test_loss, test_acc = epoch_test(test_loader, model, criterion, epoch=epoch)

        if test_acc >= 0.9:
            save_model(model, "swin_final")

        
        print(f"Epoch {epoch+1}/{epochs}")
        print(f"Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.4f}")
        print(f"Test Loss: {test_loss:.4f} | Test Acc: {test_acc:.4f}\n")

In [None]:
%pip install fastai>=2.7 fastcore>=1.3.27

In [None]:
%pip install toml

In [None]:
%pip install fastai==2.7 

In [None]:
%pip install torch>=2.0.0

In [None]:
%pip install numpy<2

/bin/bash: line 1: 2: No such file or directory


Note: you may need to restart the kernel to use updated packages.


In [None]:
from huggingface_hub import from_pretrained_fastai

learn = from_pretrained_fastai("dcarpintero/fastai-interstellar-object")

In [None]:
train(train_loader, test_loader, learn, criterion, optimizer, epochs=5)

NameError: name 'train' is not defined

In [None]:
train(train_loader, test_loader, model, criterion, optimizer, epochs=5)

Training:   0%|          | 0/5 [00:00<?, ?it/s]

Training:  20%|██        | 1/5 [01:27<05:48, 87.17s/it]

Epoch 1/5
Train Loss: 0.1842 | Train Acc: 0.9357
Test Loss: 0.5391 | Test Acc: 0.8484



Training:  40%|████      | 2/5 [02:54<04:21, 87.23s/it]

Epoch 2/5
Train Loss: 0.1764 | Train Acc: 0.9400
Test Loss: 0.5524 | Test Acc: 0.8596



Training:  60%|██████    | 3/5 [04:21<02:54, 87.09s/it]

Epoch 3/5
Train Loss: 0.1645 | Train Acc: 0.9419
Test Loss: 0.6013 | Test Acc: 0.8388



Training:  80%|████████  | 4/5 [05:48<01:26, 86.98s/it]

Epoch 4/5
Train Loss: 0.1660 | Train Acc: 0.9426
Test Loss: 0.5852 | Test Acc: 0.8427



Training: 100%|██████████| 5/5 [07:15<00:00, 87.03s/it]

Epoch 5/5
Train Loss: 0.1526 | Train Acc: 0.9472
Test Loss: 0.5568 | Test Acc: 0.8393






In [None]:
save_model(model, "swin_3")

# Evaluation

In [None]:
preds = # <Your preidctions here for TEST>
true_test_labels = galaxy_dataset['test']['label']
test_metrics = evaluate_predictions(preds, true_test_labels, class_names)