# ResNet baseline

- We also measure the accuracy of a ResNet-18 on the same classification task, to check whether our weak MAE pretraining is useful or not. Since the model we train for the task starts to overfit, we interrupt it and don't let the training proceed.
- As a good coding practice I usually follow early-stopping protocols, but I decided not to use them here, in order to observe the entire training dynamics.

In [None]:
import os
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
import torchvision.transforms as transforms
import torchvision.models as models


class LensDatasetWithLabels(Dataset):
    """Labelled image dataset backed by one ``.npy`` file per sample.

    Every ``.npy`` file found directly inside a class directory becomes one
    sample whose label is the integer assigned to that class in
    ``self.class_labels``.

    Args:
        dataset_dirs: Mapping of class name (``"axion"``, ``"cdm"`` or
            ``"no_sub"``) to the directory containing that class's files.
        transform: Optional callable applied to the 2-D ``float32`` array
            before it is returned.

    Raises:
        KeyError: If ``dataset_dirs`` contains a class name that is not in
            ``self.class_labels``.
    """

    def __init__(self, dataset_dirs, transform=None):
        self.data = []      # file path of every sample
        self.labels = []    # integer class id, parallel to self.data
        self.transform = transform
        self.class_labels = {"axion": 0, "cdm": 1, "no_sub": 2}
        for class_name, dir_path in dataset_dirs.items():
            label = self.class_labels[class_name]
            # os.listdir returns entries in arbitrary order; sorting keeps the
            # index -> sample mapping identical across runs and machines.
            npy_files = sorted(
                name for name in os.listdir(dir_path) if name.endswith('.npy')
            )
            self.data.extend(os.path.join(dir_path, name) for name in npy_files)
            self.labels.extend([label] * len(npy_files))

    def __len__(self):
        """Return the number of samples found across all class directories."""
        return len(self.data)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(image, label)``.

        The image is a 2-D ``float32`` array (or whatever ``transform``
        returns for it); the label is a scalar ``torch.long`` tensor.
        """
        img_path = self.data[idx]
        # NOTE(security): allow_pickle=True unpickles arbitrary objects, so
        # only point this dataset at trusted files. It is needed because some
        # files apparently store an object array wrapping the image.
        img = np.load(img_path, allow_pickle=True)

        # Object arrays carry the image as their first element.
        if isinstance(img, np.ndarray) and img.dtype == object:
            img = img[0]

        img = np.array(img, dtype=np.float32)

        # Keep only index 0 of the third axis.
        # NOTE(review): presumably the data is laid out (H, W, C) - confirm.
        if img.ndim > 2:
            img = img[:, :, 0]

        if self.transform is not None:
            img = self.transform(img)

        label = torch.tensor(self.labels[idx], dtype=torch.long)
        return img, label
    
# Class name -> directory holding that class's .npy samples.
dataset_dirs = {
    "axion": "dataset/Dataset/axion",
    "cdm": "dataset/Dataset/cdm",
    "no_sub": "dataset/Dataset/no_sub"
}

# For ResNet18, we need 3-channel images. Convert grayscale (1-channel) to 3 channels.
data_transform_resnet = transforms.Compose([
    transforms.ToTensor(),
    transforms.Lambda(lambda x: x.repeat(3, 1, 1)),  # replicate the single channel into 3 channels
    # ImageNet channel statistics, matching the pretrained backbone.
    # NOTE(review): these assume inputs scaled like ImageNet images - confirm
    # the value range of the .npy data.
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Create the dataset and split into training (90%) and validation (10%)
dataset_resnet = LensDatasetWithLabels(dataset_dirs, transform=data_transform_resnet)
train_size = int(0.9 * len(dataset_resnet))
val_size = len(dataset_resnet) - train_size
# Seed the split so the same permutation of indices is drawn on every run;
# without it the validation set changes each time and results from different
# runs are not comparable.
SPLIT_SEED = 42
split_generator = torch.Generator().manual_seed(SPLIT_SEED)
train_dataset, val_dataset = random_split(
    dataset_resnet, [train_size, val_size], generator=split_generator
)

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)

# ImageNet-pretrained backbone with the final layer replaced by a 3-way head
# (one output per entry in dataset_dirs).
# NOTE(review): recent torchvision releases deprecate `pretrained=` in favour
# of `weights=`; left unchanged because the installed version is unknown.
resnet18_model = models.resnet18(pretrained=True)
num_ftrs = resnet18_model.fc.in_features
resnet18_model.fc = nn.Linear(num_ftrs, 3)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
resnet18_model = resnet18_model.to(device)

# All parameters are passed to the optimizer, so the whole network is fine-tuned.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(resnet18_model.parameters(), lr=1e-4)
num_epochs = 20


# Each epoch runs one optimisation pass over the training split, then one
# gradient-free pass over the validation split, and prints the metrics of both.
for epoch in range(num_epochs):
    # ---- Training phase ----
    resnet18_model.train()
    running_loss, correct_train, total_train = 0.0, 0, 0

    for images, labels in train_loader:
        images = images.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        outputs = resnet18_model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # Weight the batch loss by the batch size so that dividing by the
        # dataset length below gives a per-sample average.
        running_loss += loss.item() * images.size(0)
        # Predicted class = index of the largest logit in each row.
        _, predicted = torch.max(outputs, 1)
        total_train += labels.size(0)
        correct_train += (predicted == labels).sum().item()

    train_loss = running_loss / len(train_dataset)
    train_acc = correct_train / total_train

    # ---- Validation phase ----
    resnet18_model.eval()
    running_val_loss, correct_val, total_val = 0.0, 0, 0

    with torch.no_grad():
        for images, labels in val_loader:
            images = images.to(device)
            labels = labels.to(device)

            outputs = resnet18_model(images)
            loss = criterion(outputs, labels)

            running_val_loss += loss.item() * images.size(0)
            _, predicted = torch.max(outputs, 1)
            total_val += labels.size(0)
            correct_val += (predicted == labels).sum().item()

    val_loss = running_val_loss / len(val_dataset)
    val_acc = correct_val / total_val

    # One summary line per epoch.
    print(f"Epoch [{epoch+1}/{num_epochs}] "
          f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f} | "
          f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}")





Epoch [1/20] Train Loss: 0.2762, Train Acc: 0.8943 | Val Loss: 5.5360, Val Acc: 0.3769
Epoch [2/20] Train Loss: 0.1216, Train Acc: 0.9566 | Val Loss: 0.1080, Val Acc: 0.9593
Epoch [3/20] Train Loss: 0.0875, Train Acc: 0.9707 | Val Loss: 0.4815, Val Acc: 0.8816
Epoch [4/20] Train Loss: 0.0717, Train Acc: 0.9760 | Val Loss: 1.6454, Val Acc: 0.6270
Epoch [5/20] Train Loss: 0.0609, Train Acc: 0.9804 | Val Loss: 0.5672, Val Acc: 0.8536
Epoch [6/20] Train Loss: 0.0534, Train Acc: 0.9824 | Val Loss: 1.4781, Val Acc: 0.7246
Epoch [7/20] Train Loss: 0.0465, Train Acc: 0.9850 | Val Loss: 23.8683, Val Acc: 0.3367
Epoch [8/20] Train Loss: 0.0442, Train Acc: 0.9856 | Val Loss: 0.5983, Val Acc: 0.8911
Epoch [9/20] Train Loss: 0.0393, Train Acc: 0.9873 | Val Loss: 17.3061, Val Acc: 0.3367
Epoch [10/20] Train Loss: 0.0373, Train Acc: 0.9882 | Val Loss: 0.0481, Val Acc: 0.9852
Epoch [11/20] Train Loss: 0.0351, Train Acc: 0.9882 | Val Loss: 1.4913, Val Acc: 0.7471
Epoch [12/20] Train Loss: 0.0307, Train

KeyboardInterrupt: 

In [2]:

# -------------------------------
# Persist the trained weights
# -------------------------------
# Only the state_dict is written, not the model object itself, so loading
# requires rebuilding the same architecture first.
torch.save(obj=resnet18_model.state_dict(), f="resnet18_classifier.pth")