In [1]:
!pip install einops



In [2]:
import torch
import timm
import numpy as np

from einops import repeat, rearrange
from einops.layers.torch import Rearrange

from timm.models.layers import trunc_normal_
from timm.models.vision_transformer import Block

import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split, ConcatDataset
from torchvision import transforms
import matplotlib.pyplot as plt
from sklearn.metrics import roc_auc_score
import wandb
import os
import shutil
from sklearn.model_selection import train_test_split
from PIL import Image
import math
from tqdm import tqdm
import torch.nn.functional as F
from sklearn.metrics import roc_curve, auc



In [3]:
wandb.login()

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33msamkitshah1262[0m ([33msamkitshah1262-warner-bros-discovery[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

In [4]:
wandb.init(project="deeplens-foundational", entity="samkitshah1262-warner-bros-discovery")

In [5]:
# Configuration
class Config:
    SEED = 1
    BATCH_SIZE = 64
    EPOCHS = 200
    LEARNING_RATE = 1e-4
    WEIGHT_DECAY = 1e-5
    MASK_RATIO = 0.75
    NUM_CLASSES = 3
    IMG_SIZE = (64, 64)
    USE_SAVED_MODEL = False
    CKPT_PATH = "./model_weights.pth"

config = Config()

# Set random seeds
torch.manual_seed(config.SEED)
np.random.seed(config.SEED)

In [6]:
device = torch.device("mps" if torch.mps.is_available() else "cpu")

In [7]:
img = np.load("dataset/no_sub/no_sub_sim_6956151560647865808482838248806684.npy")  # Load .npy file
img = torch.tensor(img, dtype=torch.float32)
img.shape

torch.Size([64, 64])

In [8]:
DATA_PATH_ROOT = "/Users/sshah/2024/projects/gsoc/ml4sci/fm/Dataset"
DATA_PATH_AXION = "/axion"
DATA_PATH_CDM = "/cdm"
DATA_PATH_NO_SUB = "/no_sub"

OUTPUT_ROOT = "/dataset"

In [9]:
class NpyDirectoryDataset(Dataset):

    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.samples = []
        self.class_to_idx = {}

        class_names = [d for d in os.listdir(root_dir) if os.path.isdir(os.path.join(root_dir, d))]
        class_names.sort()

        for idx, class_name in enumerate(class_names):
            self.class_to_idx[class_name] = idx

        for class_name in class_names:
            class_dir = os.path.join(root_dir, class_name)
            class_idx = self.class_to_idx[class_name]
            
            for file_name in os.listdir(class_dir):
                file_path = os.path.join(class_dir, file_name)
                if file_path.endswith('.npy') and os.path.isfile(file_path):
                    self.samples.append((file_path, class_idx))
    
    def __len__(self):
        return len(self.samples)
    
    def __getitem__(self, idx):
        file_path, label = self.samples[idx]
        
        try:
            data = np.load(file_path, allow_pickle=True)

            data_tensor = torch.from_numpy(data).float()
            if data_tensor.ndim == 2:
                data_tensor = data_tensor.unsqueeze(0)  
            if self.transform:
                data_tensor = self.transform(data_tensor)
            
            return data_tensor, label
            
        except Exception as e:
            print(f"Error loading {file_path}: {e}")
            placeholder = torch.zeros((1, 64, 64))
            return placeholder, label

In [10]:
transforms = transforms.Compose([
    transforms.Resize((64, 64)),
    transforms.Grayscale(num_output_channels=1),
    transforms.Normalize(mean=[0.5], std=[0.5])
])

In [11]:
def create_data_loaders(root_dir, batch_size=32, train_ratio=0.8, num_workers=4, seed=42):

    torch.manual_seed(seed)
    np.random.seed(seed)

    full_dataset = NpyDirectoryDataset(root_dir=root_dir,transform=transforms)
    
    if len(full_dataset) == 0:
        raise ValueError(f"No valid .npy files found in {root_dir}. Please check the directory structure.")
    
    print(f"Found {len(full_dataset)} .npy files total")
    print(f"Class mapping: {full_dataset.class_to_idx}")

    total_size = len(full_dataset)
    train_size = int(train_ratio * total_size)
    val_size = total_size - train_size

    train_dataset, val_dataset = random_split(
        full_dataset, 
        [train_size, val_size],
        generator=torch.Generator().manual_seed(seed)
    )

    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers
    )
    
    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers
    )

    class_loaders = {}

    class_names = list(full_dataset.class_to_idx.keys())
    
    for class_name in class_names:
        class_idx = full_dataset.class_to_idx[class_name]
        class_indices = [i for i, (_, label) in enumerate(full_dataset.samples) if label == class_idx]
        
        if len(class_indices) == 0:
            print(f"Warning: No samples found for class {class_name}")
            continue

        class_train_indices = [idx for idx in range(len(full_dataset)) if 
                              idx in train_dataset.indices and 
                              full_dataset.samples[idx][1] == class_idx]
                              
        class_val_indices = [idx for idx in range(len(full_dataset)) if 
                            idx in val_dataset.indices and 
                            full_dataset.samples[idx][1] == class_idx]
        
        class_train_subset = torch.utils.data.Subset(full_dataset, class_train_indices)
        class_val_subset = torch.utils.data.Subset(full_dataset, class_val_indices)

        class_loaders[f"{class_name}_train"] = DataLoader(
            class_train_subset,
            batch_size=batch_size,
            shuffle=True,
            num_workers=num_workers
        )
        
        class_loaders[f"{class_name}_val"] = DataLoader(
            class_val_subset,
            batch_size=batch_size,
            shuffle=False,
            num_workers=num_workers
        )

    loaders = {
        "train": train_loader,
        "val": val_loader,
        **class_loaders
    }

    print("\nDataset information:")
    print(f"Training set: {len(train_dataset)} samples")
    print(f"Validation set: {len(val_dataset)} samples")
    
    for name, loader in class_loaders.items():
        print(f"{name}: {len(loader.dataset)} samples")
    
    return loaders

In [12]:
root_directory = DATA_PATH_ROOT

loaders = create_data_loaders(
    root_dir=root_directory,
    batch_size=config.BATCH_SIZE,
    train_ratio=0.9,
    num_workers=0
)

Found 89104 .npy files total
Class mapping: {'axion': 0, 'cdm': 1, 'no_sub': 2}

Dataset information:
Training set: 80193 samples
Validation set: 8911 samples
axion_train: 26981 samples
axion_val: 2915 samples
cdm_train: 26723 samples
cdm_val: 3036 samples
no_sub_train: 26489 samples
no_sub_val: 2960 samples


In [13]:
axion_train_loader = loaders["axion_train"]
cdm_val_loader = loaders["cdm_val"]
combined_train_loader = loaders["train"]

print(f"Number of classes: {len(loaders)//2 - 1}")

for name, loader in loaders.items():
    print(f"{name}: {len(loader.dataset)} samples, {len(loader)} batches")

Number of classes: 3
train: 80193 samples, 1254 batches
val: 8911 samples, 140 batches
axion_train: 26981 samples, 422 batches
axion_val: 2915 samples, 46 batches
cdm_train: 26723 samples, 418 batches
cdm_val: 3036 samples, 48 batches
no_sub_train: 26489 samples, 414 batches
no_sub_val: 2960 samples, 47 batches


In [14]:
def random_indexes(size : int):
    forward_indexes = np.arange(size)
    np.random.shuffle(forward_indexes)
    backward_indexes = np.argsort(forward_indexes)
    return forward_indexes, backward_indexes

def take_indexes(sequences, indexes):
    return torch.gather(sequences, 0, repeat(indexes, 't b -> t b c', c=sequences.shape[-1]))

class PatchShuffle(torch.nn.Module):
    def __init__(self, ratio) -> None:
        super().__init__()
        self.ratio = ratio

    def forward(self, patches : torch.Tensor):
        T, B, C = patches.shape
        remain_T = int(T * (1 - self.ratio))

        indexes = [random_indexes(T) for _ in range(B)]
        forward_indexes = torch.as_tensor(np.stack([i[0] for i in indexes], axis=-1), dtype=torch.long).to(patches.device)
        backward_indexes = torch.as_tensor(np.stack([i[1] for i in indexes], axis=-1), dtype=torch.long).to(patches.device)

        patches = take_indexes(patches, forward_indexes)
        patches = patches[:remain_T]

        return patches, forward_indexes, backward_indexes

class MAE_Encoder(torch.nn.Module):
    def __init__(self,
                 image_size=64,
                 patch_size=4,
                 emb_dim=192,
                 num_layer=12,
                 num_head=3,
                 mask_ratio=0.75,
                 ) -> None:
        super().__init__()

        self.cls_token = torch.nn.Parameter(torch.zeros(1, 1, emb_dim))
        self.pos_embedding = torch.nn.Parameter(torch.zeros((image_size // patch_size) ** 2, 1, emb_dim))
        self.shuffle = PatchShuffle(mask_ratio)

        self.patchify = torch.nn.Conv2d(1, emb_dim, patch_size, patch_size)

        self.transformer = torch.nn.Sequential(*[Block(emb_dim, num_head) for _ in range(num_layer)])

        self.layer_norm = torch.nn.LayerNorm(emb_dim)

        self.init_weight()

    def init_weight(self):
        trunc_normal_(self.cls_token, std=.02)
        trunc_normal_(self.pos_embedding, std=.02)

    def forward(self, img):
        patches = self.patchify(img)
        patches = rearrange(patches, 'b c h w -> (h w) b c')
        patches = patches + self.pos_embedding

        patches, forward_indexes, backward_indexes = self.shuffle(patches)

        patches = torch.cat([self.cls_token.expand(-1, patches.shape[1], -1), patches], dim=0)
        patches = rearrange(patches, 't b c -> b t c')
        features = self.layer_norm(self.transformer(patches))
        features = rearrange(features, 'b t c -> t b c')

        return features, backward_indexes

class MAE_Decoder(torch.nn.Module):
    def __init__(self,
                 image_size=64,
                 patch_size=4,
                 emb_dim=192,
                 num_layer=4,
                 num_head=3,
                 ) -> None:
        super().__init__()

        self.mask_token = torch.nn.Parameter(torch.zeros(1, 1, emb_dim))
        self.pos_embedding = torch.nn.Parameter(torch.zeros((image_size // patch_size) ** 2 + 1, 1, emb_dim))

        self.transformer = torch.nn.Sequential(*[Block(emb_dim, num_head) for _ in range(num_layer)])

        self.head = torch.nn.Linear(emb_dim, 1 * patch_size ** 2)
        self.patch2img = Rearrange('(h w) b (c p1 p2) -> b c (h p1) (w p2)', p1=patch_size, p2=patch_size, h=image_size//patch_size)

        self.init_weight()

    def init_weight(self):
        trunc_normal_(self.mask_token, std=.02)
        trunc_normal_(self.pos_embedding, std=.02)

    def forward(self, features, backward_indexes):
        T = features.shape[0]
        backward_indexes = torch.cat([torch.zeros(1, backward_indexes.shape[1]).to(backward_indexes), backward_indexes + 1], dim=0)
        features = torch.cat([features, self.mask_token.expand(backward_indexes.shape[0] - features.shape[0], features.shape[1], -1)], dim=0)
        features = take_indexes(features, backward_indexes)
        features = features + self.pos_embedding

        features = rearrange(features, 't b c -> b t c')
        features = self.transformer(features)
        features = rearrange(features, 'b t c -> t b c')
        features = features[1:]

        patches = self.head(features)
        mask = torch.zeros_like(patches)
        mask[T-1:] = 1
        mask = take_indexes(mask, backward_indexes[1:] - 1)
        img = self.patch2img(patches)
        mask = self.patch2img(mask)

        return img, mask

class MAE_ViT(torch.nn.Module):
    def __init__(self,
                 image_size=64,
                 patch_size=4,
                 emb_dim=192,
                 encoder_layer=12,
                 encoder_head=3,
                 decoder_layer=4,
                 decoder_head=3,
                 mask_ratio=0.75,
                 ) -> None:
        super().__init__()

        self.encoder = MAE_Encoder(image_size, patch_size, emb_dim, encoder_layer, encoder_head, mask_ratio)
        self.decoder = MAE_Decoder(image_size, patch_size, emb_dim, decoder_layer, decoder_head)

    def forward(self, img):
        features, backward_indexes = self.encoder(img)
        predicted_img, mask = self.decoder(features,  backward_indexes)
        return predicted_img, mask

In [15]:
model = MAE_ViT(mask_ratio=config.MASK_RATIO).to(device)

In [16]:
optimizer = torch.optim.AdamW(model.parameters(), lr=config.LEARNING_RATE, betas=(0.9, 0.95), weight_decay=config.WEIGHT_DECAY)
def lr_func(epoch):
    return min((epoch + 1) / (200 + 1e-8), 0.5 * (math.cos(epoch / config.EPOCHS * math.pi) + 1))
lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_func, verbose=True)



In [17]:
train_dataloader = loaders['no_sub_train']
val_dataloader = loaders['no_sub_val']

In [18]:
def train_mae(model, train_loader, optimizer, device, epoch):
    model.train()
    total_loss = 0.0
    progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1} [Train]", leave=False)
    
    for images, _ in progress_bar:
        images = images.to(device)
        # print("Image shape before model:", images.shape)   
        images = images.clamp(0, 1)
        images = images.repeat(1, 3, 1, 1)
        optimizer.zero_grad()
        reconstructed, mask = model(images)
        loss = F.mse_loss(reconstructed * mask, images * mask)
        loss.backward()
        optimizer.step()
        total_loss += loss.item() * images.size(0)
        progress_bar.set_postfix({'loss': loss.item()})
    
    return total_loss / len(train_loader.dataset)

def validate_mae(model, val_loader, device, epoch):
    model.eval()
    total_loss = 0.0
    progress_bar = tqdm(val_loader, desc=f"Epoch {epoch+1} [Val]", leave=False)
    
    with torch.no_grad():
        for images, _ in progress_bar:
            images = images.to(device)
            images = images.clamp(0, 1)
            images = images.repeat(1, 3, 1, 1)
            reconstructed, mask = model(images)
            loss = F.mse_loss(reconstructed * mask, images * mask)
            
            total_loss += loss.item() * images.size(0)
            progress_bar.set_postfix({'val_loss': loss.item()})
    
    return total_loss / len(val_loader.dataset)

In [231]:
best_loss = float('inf')
for epoch in range(config.EPOCHS):
    train_loss = train_mae(model, train_dataloader, optimizer, device, epoch)
    val_loss = validate_mae(model, val_dataloader, device, epoch)
    
    print(f"Epoch {epoch+1}/{config.EPOCHS}")
    print(f"Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f}")
    
    wandb.log({
        "epoch": epoch,
        "train_loss": train_loss,
        "val_loss": val_loss,
    })

    # Save best model
    if val_loss < best_loss:
        best_loss = val_loss
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': val_loss,
        }, "best_mae_model.pth")
        print(f"Saved new best model with val loss: {val_loss:.4f}")

                                                                               

Epoch 1/200
Train Loss: 0.0529 | Val Loss: 0.0354
Saved new best model with val loss: 0.0354


                                                                               

Epoch 2/200
Train Loss: 0.0351 | Val Loss: 0.0346
Saved new best model with val loss: 0.0346


                                                                               

Epoch 3/200
Train Loss: 0.0334 | Val Loss: 0.0319
Saved new best model with val loss: 0.0319


                                                                               

Epoch 4/200
Train Loss: 0.0300 | Val Loss: 0.0279
Saved new best model with val loss: 0.0279


                                                                               

Epoch 5/200
Train Loss: 0.0256 | Val Loss: 0.0235
Saved new best model with val loss: 0.0235


                                                                               

Epoch 6/200
Train Loss: 0.0222 | Val Loss: 0.0212
Saved new best model with val loss: 0.0212


                                                                               

Epoch 7/200
Train Loss: 0.0207 | Val Loss: 0.0205
Saved new best model with val loss: 0.0205


                                                                               

Epoch 8/200
Train Loss: 0.0202 | Val Loss: 0.0200
Saved new best model with val loss: 0.0200


                                                                               

Epoch 9/200
Train Loss: 0.0199 | Val Loss: 0.0198
Saved new best model with val loss: 0.0198


                                                                                

Epoch 10/200
Train Loss: 0.0196 | Val Loss: 0.0196
Saved new best model with val loss: 0.0196


                                                                                

Epoch 11/200
Train Loss: 0.0195 | Val Loss: 0.0193
Saved new best model with val loss: 0.0193


                                                                                

Epoch 12/200
Train Loss: 0.0194 | Val Loss: 0.0193
Saved new best model with val loss: 0.0193


                                                                                

Epoch 13/200
Train Loss: 0.0192 | Val Loss: 0.0193
Saved new best model with val loss: 0.0193


                                                                                

Epoch 14/200
Train Loss: 0.0191 | Val Loss: 0.0190
Saved new best model with val loss: 0.0190


                                                                                

Epoch 15/200
Train Loss: 0.0190 | Val Loss: 0.0190
Saved new best model with val loss: 0.0190


                                                                                

Epoch 16/200
Train Loss: 0.0189 | Val Loss: 0.0189
Saved new best model with val loss: 0.0189


                                                                                

Epoch 17/200
Train Loss: 0.0188 | Val Loss: 0.0188
Saved new best model with val loss: 0.0188


                                                                                

Epoch 18/200
Train Loss: 0.0188 | Val Loss: 0.0188
Saved new best model with val loss: 0.0188


                                                                                

Epoch 19/200
Train Loss: 0.0186 | Val Loss: 0.0186
Saved new best model with val loss: 0.0186


                                                                                

Epoch 20/200
Train Loss: 0.0185 | Val Loss: 0.0185
Saved new best model with val loss: 0.0185


                                                                                

Epoch 21/200
Train Loss: 0.0184 | Val Loss: 0.0187


                                                                                

Epoch 22/200
Train Loss: 0.0183 | Val Loss: 0.0183
Saved new best model with val loss: 0.0183


                                                                                

Epoch 23/200
Train Loss: 0.0182 | Val Loss: 0.0182
Saved new best model with val loss: 0.0182


                                                                                

Epoch 24/200
Train Loss: 0.0181 | Val Loss: 0.0181
Saved new best model with val loss: 0.0181


                                                                                

Epoch 25/200
Train Loss: 0.0180 | Val Loss: 0.0179
Saved new best model with val loss: 0.0179


                                                                                

Epoch 26/200
Train Loss: 0.0180 | Val Loss: 0.0179
Saved new best model with val loss: 0.0179


                                                                                

Epoch 27/200
Train Loss: 0.0179 | Val Loss: 0.0178
Saved new best model with val loss: 0.0178


                                                                                

Epoch 28/200
Train Loss: 0.0177 | Val Loss: 0.0176
Saved new best model with val loss: 0.0176


                                                                                

Epoch 29/200
Train Loss: 0.0176 | Val Loss: 0.0176
Saved new best model with val loss: 0.0176


                                                                                

Epoch 30/200
Train Loss: 0.0175 | Val Loss: 0.0175
Saved new best model with val loss: 0.0175


                                                                                

Epoch 31/200
Train Loss: 0.0174 | Val Loss: 0.0174
Saved new best model with val loss: 0.0174


                                                                                

Epoch 32/200
Train Loss: 0.0173 | Val Loss: 0.0173
Saved new best model with val loss: 0.0173


                                                                                

Epoch 33/200
Train Loss: 0.0173 | Val Loss: 0.0175


                                                                                

Epoch 34/200
Train Loss: 0.0172 | Val Loss: 0.0172
Saved new best model with val loss: 0.0172


                                                                                

Epoch 35/200
Train Loss: 0.0171 | Val Loss: 0.0171
Saved new best model with val loss: 0.0171


                                                                                

Epoch 36/200
Train Loss: 0.0170 | Val Loss: 0.0170
Saved new best model with val loss: 0.0170


                                                                                

Epoch 37/200
Train Loss: 0.0169 | Val Loss: 0.0169
Saved new best model with val loss: 0.0169


                                                                                

Epoch 38/200
Train Loss: 0.0169 | Val Loss: 0.0169
Saved new best model with val loss: 0.0169


                                                                                

Epoch 39/200
Train Loss: 0.0168 | Val Loss: 0.0168
Saved new best model with val loss: 0.0168


                                                                                

Epoch 40/200
Train Loss: 0.0167 | Val Loss: 0.0168
Saved new best model with val loss: 0.0168


                                                                                

Epoch 41/200
Train Loss: 0.0167 | Val Loss: 0.0167
Saved new best model with val loss: 0.0167


                                                                                

Epoch 42/200
Train Loss: 0.0166 | Val Loss: 0.0166
Saved new best model with val loss: 0.0166


                                                                                

Epoch 43/200
Train Loss: 0.0166 | Val Loss: 0.0165
Saved new best model with val loss: 0.0165


                                                                                

Epoch 44/200
Train Loss: 0.0165 | Val Loss: 0.0166


                                                                                

Epoch 45/200
Train Loss: 0.0164 | Val Loss: 0.0165
Saved new best model with val loss: 0.0165


                                                                                

Epoch 46/200
Train Loss: 0.0164 | Val Loss: 0.0162
Saved new best model with val loss: 0.0162


                                                                                

Epoch 47/200
Train Loss: 0.0164 | Val Loss: 0.0163


                                                                                

Epoch 48/200
Train Loss: 0.0163 | Val Loss: 0.0166


                                                                                

Epoch 49/200
Train Loss: 0.0163 | Val Loss: 0.0162
Saved new best model with val loss: 0.0162


                                                                                

Epoch 50/200
Train Loss: 0.0162 | Val Loss: 0.0161
Saved new best model with val loss: 0.0161


                                                                                

Epoch 51/200
Train Loss: 0.0162 | Val Loss: 0.0161
Saved new best model with val loss: 0.0161


                                                                                

Epoch 52/200
Train Loss: 0.0161 | Val Loss: 0.0161


                                                                                

Epoch 53/200
Train Loss: 0.0160 | Val Loss: 0.0161
Saved new best model with val loss: 0.0161


                                                                                

Epoch 54/200
Train Loss: 0.0160 | Val Loss: 0.0159
Saved new best model with val loss: 0.0159


                                                                                

Epoch 55/200
Train Loss: 0.0159 | Val Loss: 0.0160


                                                                                

Epoch 56/200
Train Loss: 0.0159 | Val Loss: 0.0159


                                                                                

Epoch 57/200
Train Loss: 0.0158 | Val Loss: 0.0158
Saved new best model with val loss: 0.0158


                                                                                

Epoch 58/200
Train Loss: 0.0157 | Val Loss: 0.0158
Saved new best model with val loss: 0.0158


                                                                                

Epoch 59/200
Train Loss: 0.0157 | Val Loss: 0.0156
Saved new best model with val loss: 0.0156


                                                                                

Epoch 60/200
Train Loss: 0.0156 | Val Loss: 0.0155
Saved new best model with val loss: 0.0155


                                                                                

Epoch 61/200
Train Loss: 0.0156 | Val Loss: 0.0154
Saved new best model with val loss: 0.0154


                                                                                

Epoch 62/200
Train Loss: 0.0154 | Val Loss: 0.0154
Saved new best model with val loss: 0.0154


                                                                                

Epoch 63/200
Train Loss: 0.0154 | Val Loss: 0.0154


                                                                                

Epoch 64/200
Train Loss: 0.0153 | Val Loss: 0.0152
Saved new best model with val loss: 0.0152


                                                                                

Epoch 65/200
Train Loss: 0.0152 | Val Loss: 0.0151
Saved new best model with val loss: 0.0151


                                                                                

Epoch 66/200
Train Loss: 0.0151 | Val Loss: 0.0150
Saved new best model with val loss: 0.0150


                                                                                

Epoch 67/200
Train Loss: 0.0150 | Val Loss: 0.0149
Saved new best model with val loss: 0.0149


                                                                                

Epoch 68/200
Train Loss: 0.0150 | Val Loss: 0.0150


                                                                                

Epoch 69/200
Train Loss: 0.0149 | Val Loss: 0.0147
Saved new best model with val loss: 0.0147


                                                                                

Epoch 70/200
Train Loss: 0.0147 | Val Loss: 0.0147


                                                                                

Epoch 71/200
Train Loss: 0.0146 | Val Loss: 0.0145
Saved new best model with val loss: 0.0145


                                                                                

Epoch 72/200
Train Loss: 0.0145 | Val Loss: 0.0144
Saved new best model with val loss: 0.0144


                                                                                

Epoch 73/200
Train Loss: 0.0144 | Val Loss: 0.0143
Saved new best model with val loss: 0.0143


                                                                                

Epoch 74/200
Train Loss: 0.0143 | Val Loss: 0.0141
Saved new best model with val loss: 0.0141


                                                                                

Epoch 75/200
Train Loss: 0.0142 | Val Loss: 0.0143


                                                                                

Epoch 76/200
Train Loss: 0.0141 | Val Loss: 0.0138
Saved new best model with val loss: 0.0138


                                                                                

Epoch 77/200
Train Loss: 0.0139 | Val Loss: 0.0138
Saved new best model with val loss: 0.0138


                                                                                

Epoch 78/200
Train Loss: 0.0139 | Val Loss: 0.0138
Saved new best model with val loss: 0.0138


                                                                                

Epoch 79/200
Train Loss: 0.0137 | Val Loss: 0.0136
Saved new best model with val loss: 0.0136


                                                                                

Epoch 80/200
Train Loss: 0.0136 | Val Loss: 0.0134
Saved new best model with val loss: 0.0134


                                                                                

Epoch 81/200
Train Loss: 0.0135 | Val Loss: 0.0133
Saved new best model with val loss: 0.0133


                                                                                

Epoch 82/200
Train Loss: 0.0134 | Val Loss: 0.0133
Saved new best model with val loss: 0.0133


                                                                                

Epoch 83/200
Train Loss: 0.0133 | Val Loss: 0.0134


                                                                                

Epoch 84/200
Train Loss: 0.0132 | Val Loss: 0.0130
Saved new best model with val loss: 0.0130


                                                                                

Epoch 85/200
Train Loss: 0.0130 | Val Loss: 0.0131


                                                                                

Epoch 86/200
Train Loss: 0.0129 | Val Loss: 0.0128
Saved new best model with val loss: 0.0128


                                                                                

Epoch 87/200
Train Loss: 0.0128 | Val Loss: 0.0128
Saved new best model with val loss: 0.0128


                                                                                

Epoch 88/200
Train Loss: 0.0128 | Val Loss: 0.0126
Saved new best model with val loss: 0.0126


                                                                                

Epoch 89/200
Train Loss: 0.0127 | Val Loss: 0.0126


                                                                                

Epoch 90/200
Train Loss: 0.0126 | Val Loss: 0.0123
Saved new best model with val loss: 0.0123


                                                                                

Epoch 91/200
Train Loss: 0.0125 | Val Loss: 0.0124


                                                                                

Epoch 92/200
Train Loss: 0.0124 | Val Loss: 0.0122
Saved new best model with val loss: 0.0122


                                                                                

Epoch 93/200
Train Loss: 0.0122 | Val Loss: 0.0124


                                                                                

Epoch 94/200
Train Loss: 0.0121 | Val Loss: 0.0120
Saved new best model with val loss: 0.0120


                                                                                

Epoch 95/200
Train Loss: 0.0120 | Val Loss: 0.0121


                                                                                

Epoch 96/200
Train Loss: 0.0120 | Val Loss: 0.0118
Saved new best model with val loss: 0.0118


                                                                                

Epoch 97/200
Train Loss: 0.0118 | Val Loss: 0.0117
Saved new best model with val loss: 0.0117


                                                                                

Epoch 98/200
Train Loss: 0.0118 | Val Loss: 0.0117


                                                                                

Epoch 99/200
Train Loss: 0.0117 | Val Loss: 0.0114
Saved new best model with val loss: 0.0114


                                                                                 

Epoch 100/200
Train Loss: 0.0116 | Val Loss: 0.0116


                                                                                  

Epoch 101/200
Train Loss: 0.0115 | Val Loss: 0.0116


                                                                                  

Epoch 102/200
Train Loss: 0.0114 | Val Loss: 0.0114
Saved new best model with val loss: 0.0114


                                                                                  

Epoch 103/200
Train Loss: 0.0114 | Val Loss: 0.0112
Saved new best model with val loss: 0.0112


                                                                                  

Epoch 104/200
Train Loss: 0.0112 | Val Loss: 0.0112


                                                                                  

Epoch 105/200
Train Loss: 0.0112 | Val Loss: 0.0111
Saved new best model with val loss: 0.0111


                                                                                  

Epoch 106/200
Train Loss: 0.0111 | Val Loss: 0.0113


                                                                                  

Epoch 107/200
Train Loss: 0.0110 | Val Loss: 0.0111


                                                                                  

Epoch 108/200
Train Loss: 0.0110 | Val Loss: 0.0108
Saved new best model with val loss: 0.0108


                                                                                  

Epoch 109/200
Train Loss: 0.0109 | Val Loss: 0.0109


                                                                                  

Epoch 110/200
Train Loss: 0.0108 | Val Loss: 0.0108
Saved new best model with val loss: 0.0108


                                                                                  

Epoch 111/200
Train Loss: 0.0107 | Val Loss: 0.0106
Saved new best model with val loss: 0.0106


                                                                                  

Epoch 112/200
Train Loss: 0.0106 | Val Loss: 0.0104
Saved new best model with val loss: 0.0104


                                                                                  

Epoch 113/200
Train Loss: 0.0105 | Val Loss: 0.0104
Saved new best model with val loss: 0.0104


                                                                                  

Epoch 114/200
Train Loss: 0.0105 | Val Loss: 0.0104


                                                                                  

Epoch 115/200
Train Loss: 0.0104 | Val Loss: 0.0104


                                                                                  

Epoch 116/200
Train Loss: 0.0103 | Val Loss: 0.0103
Saved new best model with val loss: 0.0103


                                                                                  

Epoch 117/200
Train Loss: 0.0102 | Val Loss: 0.0104


                                                                                  

Epoch 118/200
Train Loss: 0.0102 | Val Loss: 0.0105


                                                                                  

Epoch 119/200
Train Loss: 0.0101 | Val Loss: 0.0101
Saved new best model with val loss: 0.0101


                                                                                  

Epoch 120/200
Train Loss: 0.0100 | Val Loss: 0.0100
Saved new best model with val loss: 0.0100


                                                                                  

Epoch 121/200
Train Loss: 0.0100 | Val Loss: 0.0101


                                                                                  

Epoch 122/200
Train Loss: 0.0099 | Val Loss: 0.0098
Saved new best model with val loss: 0.0098


                                                                                  

Epoch 123/200
Train Loss: 0.0099 | Val Loss: 0.0098
Saved new best model with val loss: 0.0098


                                                                                  

Epoch 124/200
Train Loss: 0.0098 | Val Loss: 0.0100


                                                                                  

Epoch 125/200
Train Loss: 0.0097 | Val Loss: 0.0096
Saved new best model with val loss: 0.0096


                                                                                  

Epoch 126/200
Train Loss: 0.0097 | Val Loss: 0.0097


                                                                                  

Epoch 127/200
Train Loss: 0.0096 | Val Loss: 0.0095
Saved new best model with val loss: 0.0095


                                                                                  

Epoch 128/200
Train Loss: 0.0096 | Val Loss: 0.0095


                                                                                  

Epoch 129/200
Train Loss: 0.0095 | Val Loss: 0.0091
Saved new best model with val loss: 0.0091


                                                                                  

Epoch 130/200
Train Loss: 0.0094 | Val Loss: 0.0095


                                                                                  

Epoch 131/200
Train Loss: 0.0094 | Val Loss: 0.0092


                                                                                  

Epoch 132/200
Train Loss: 0.0093 | Val Loss: 0.0092


                                                                                  

Epoch 133/200
Train Loss: 0.0093 | Val Loss: 0.0092


                                                                                  

Epoch 134/200
Train Loss: 0.0092 | Val Loss: 0.0093


                                                                                  

Epoch 135/200
Train Loss: 0.0091 | Val Loss: 0.0090
Saved new best model with val loss: 0.0090


                                                                                  

Epoch 136/200
Train Loss: 0.0091 | Val Loss: 0.0089
Saved new best model with val loss: 0.0089


                                                                                  

Epoch 137/200
Train Loss: 0.0090 | Val Loss: 0.0090


                                                                                  

Epoch 138/200
Train Loss: 0.0090 | Val Loss: 0.0090


                                                                                  

Epoch 139/200
Train Loss: 0.0089 | Val Loss: 0.0089
Saved new best model with val loss: 0.0089


                                                                                    

Epoch 140/200
Train Loss: 0.0089 | Val Loss: 0.0088
Saved new best model with val loss: 0.0088


                                                                                  

Epoch 141/200
Train Loss: 0.0088 | Val Loss: 0.0086
Saved new best model with val loss: 0.0086


                                                                                  

Epoch 142/200
Train Loss: 0.0088 | Val Loss: 0.0087


                                                                                  

Epoch 143/200
Train Loss: 0.0087 | Val Loss: 0.0086


                                                                                  

Epoch 144/200
Train Loss: 0.0087 | Val Loss: 0.0086
Saved new best model with val loss: 0.0086


                                                                                  

Epoch 145/200
Train Loss: 0.0086 | Val Loss: 0.0085
Saved new best model with val loss: 0.0085


                                                                                  

Epoch 146/200
Train Loss: 0.0085 | Val Loss: 0.0084
Saved new best model with val loss: 0.0084


                                                                                  

Epoch 147/200
Train Loss: 0.0085 | Val Loss: 0.0085


                                                                                  

Epoch 148/200
Train Loss: 0.0085 | Val Loss: 0.0083
Saved new best model with val loss: 0.0083


                                                                                  

Epoch 149/200
Train Loss: 0.0084 | Val Loss: 0.0083
Saved new best model with val loss: 0.0083


                                                                                  

Epoch 150/200
Train Loss: 0.0084 | Val Loss: 0.0082
Saved new best model with val loss: 0.0082


                                                                                  

Epoch 151/200
Train Loss: 0.0083 | Val Loss: 0.0085


                                                                                  

Epoch 152/200
Train Loss: 0.0083 | Val Loss: 0.0082
Saved new best model with val loss: 0.0082


                                                                                  

Epoch 153/200
Train Loss: 0.0082 | Val Loss: 0.0080
Saved new best model with val loss: 0.0080


                                                                                  

Epoch 154/200
Train Loss: 0.0082 | Val Loss: 0.0079
Saved new best model with val loss: 0.0079


                                                                                  

Epoch 155/200
Train Loss: 0.0081 | Val Loss: 0.0081


                                                                                  

Epoch 156/200
Train Loss: 0.0081 | Val Loss: 0.0079
Saved new best model with val loss: 0.0079


                                                                                  

Epoch 157/200
Train Loss: 0.0080 | Val Loss: 0.0078
Saved new best model with val loss: 0.0078


                                                                                  

Epoch 158/200
Train Loss: 0.0079 | Val Loss: 0.0080


                                                                                  

Epoch 159/200
Train Loss: 0.0079 | Val Loss: 0.0079


                                                                                  

Epoch 160/200
Train Loss: 0.0078 | Val Loss: 0.0080


                                                                                  

Epoch 161/200
Train Loss: 0.0078 | Val Loss: 0.0077
Saved new best model with val loss: 0.0077


                                                                                  

Epoch 162/200
Train Loss: 0.0078 | Val Loss: 0.0076
Saved new best model with val loss: 0.0076


                                                                                  

Epoch 163/200
Train Loss: 0.0077 | Val Loss: 0.0076
Saved new best model with val loss: 0.0076


                                                                                  

Epoch 164/200
Train Loss: 0.0076 | Val Loss: 0.0076


                                                                                  

Epoch 165/200
Train Loss: 0.0077 | Val Loss: 0.0077


                                                                                  

Epoch 166/200
Train Loss: 0.0076 | Val Loss: 0.0074
Saved new best model with val loss: 0.0074


                                                                                  

Epoch 167/200
Train Loss: 0.0075 | Val Loss: 0.0074
Saved new best model with val loss: 0.0074


                                                                                  

Epoch 168/200
Train Loss: 0.0075 | Val Loss: 0.0074
Saved new best model with val loss: 0.0074


                                                                                  

Epoch 169/200
Train Loss: 0.0074 | Val Loss: 0.0074
Saved new best model with val loss: 0.0074


                                                                                  

Epoch 170/200
Train Loss: 0.0074 | Val Loss: 0.0073
Saved new best model with val loss: 0.0073


                                                                                  

Epoch 171/200
Train Loss: 0.0074 | Val Loss: 0.0073
Saved new best model with val loss: 0.0073


                                                                                  

Epoch 172/200
Train Loss: 0.0074 | Val Loss: 0.0073
Saved new best model with val loss: 0.0073


                                                                                  

Epoch 173/200
Train Loss: 0.0073 | Val Loss: 0.0072
Saved new best model with val loss: 0.0072


                                                                                  

Epoch 174/200
Train Loss: 0.0073 | Val Loss: 0.0072


                                                                                  

Epoch 175/200
Train Loss: 0.0072 | Val Loss: 0.0074


                                                                                  

Epoch 176/200
Train Loss: 0.0072 | Val Loss: 0.0071
Saved new best model with val loss: 0.0071


                                                                                  

Epoch 177/200
Train Loss: 0.0072 | Val Loss: 0.0071


                                                                                  

Epoch 178/200
Train Loss: 0.0071 | Val Loss: 0.0071


                                                                                  

Epoch 179/200
Train Loss: 0.0070 | Val Loss: 0.0068
Saved new best model with val loss: 0.0068


                                                                                  

Epoch 180/200
Train Loss: 0.0070 | Val Loss: 0.0072


                                                                                  

Epoch 181/200
Train Loss: 0.0070 | Val Loss: 0.0069


                                                                                  

Epoch 182/200
Train Loss: 0.0069 | Val Loss: 0.0069


                                                                                  

Epoch 183/200
Train Loss: 0.0069 | Val Loss: 0.0068
Saved new best model with val loss: 0.0068


                                                                                  

Epoch 184/200
Train Loss: 0.0068 | Val Loss: 0.0068
Saved new best model with val loss: 0.0068


                                                                                  

Epoch 185/200
Train Loss: 0.0068 | Val Loss: 0.0067
Saved new best model with val loss: 0.0067


                                                                                  

Epoch 186/200
Train Loss: 0.0067 | Val Loss: 0.0066
Saved new best model with val loss: 0.0066


                                                                                  

Epoch 187/200
Train Loss: 0.0067 | Val Loss: 0.0066
Saved new best model with val loss: 0.0066


                                                                                  

Epoch 188/200
Train Loss: 0.0067 | Val Loss: 0.0067


                                                                                  

Epoch 189/200
Train Loss: 0.0067 | Val Loss: 0.0066


                                                                                  

Epoch 190/200
Train Loss: 0.0067 | Val Loss: 0.0065
Saved new best model with val loss: 0.0065


                                                                                  

Epoch 191/200
Train Loss: 0.0066 | Val Loss: 0.0064
Saved new best model with val loss: 0.0064


                                                                                  

Epoch 192/200
Train Loss: 0.0065 | Val Loss: 0.0065


                                                                                  

Epoch 193/200
Train Loss: 0.0065 | Val Loss: 0.0065


                                                                                  

Epoch 194/200
Train Loss: 0.0065 | Val Loss: 0.0067


                                                                                  

Epoch 195/200
Train Loss: 0.0064 | Val Loss: 0.0063
Saved new best model with val loss: 0.0063


                                                                                  

Epoch 196/200
Train Loss: 0.0064 | Val Loss: 0.0065


                                                                                  

Epoch 197/200
Train Loss: 0.0064 | Val Loss: 0.0063
Saved new best model with val loss: 0.0063


                                                                                  

Epoch 198/200
Train Loss: 0.0064 | Val Loss: 0.0065


                                                                                  

Epoch 199/200
Train Loss: 0.0063 | Val Loss: 0.0063
Saved new best model with val loss: 0.0063


                                                                                  

Epoch 200/200
Train Loss: 0.0063 | Val Loss: 0.0062
Saved new best model with val loss: 0.0062


In [257]:
total_train_loader = loaders['train']
total_val_loader = loaders['val']

In [232]:
class Classifier(nn.Module):
    def __init__(self, encoder, num_classes=3):
        super().__init__()
        self.encoder = encoder
        self.classifier = nn.Linear(encoder.patchify.out_channels, num_classes)
        
        # Freeze encoder (optional)
        # for param in self.encoder.parameters():
        #     param.requires_grad = False

    def forward(self, x):
        features, _ = self.encoder(x)
        cls_token = features[0]
        return self.classifier(cls_token)


In [235]:
print(torch.load('best_mae_model.pth').keys())

dict_keys(['epoch', 'model_state_dict', 'optimizer_state_dict', 'loss'])


In [249]:
mae = MAE_ViT(
        image_size=64,
        patch_size=4,
        emb_dim=192,
        encoder_layer=12,
        encoder_head=3,
        decoder_layer=4,
        decoder_head=3,
        mask_ratio=0.75
).to(device)
mae.load_state_dict(torch.load('best_mae_model.pth')['model_state_dict'])
model = Classifier(mae.encoder).to(device)
criterion = nn.CrossEntropyLoss().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=config.LEARNING_RATE)

In [250]:
def evaluate_roc(model, loader):
    model.eval()
    all_probs = []
    all_labels = []
    
    with torch.no_grad():
        for images, labels in loader:
            images = images.cuda()
            outputs = model(images)
            probs = torch.softmax(outputs, dim=1)
            all_probs.append(probs.cpu())
            all_labels.append(labels)
    
    probs = torch.cat(all_probs).numpy()
    labels = torch.cat(all_labels).numpy()

    fpr, tpr, roc_auc = {}, {}, {}
    for i in range(3):
        fpr[i], tpr[i], _ = roc_curve(labels == i, probs[:, i])
        roc_auc[i] = auc(fpr[i], tpr[i])
    
    roc_auc["macro"] = np.mean(list(roc_auc.values()))
    
    return fpr, tpr, roc_auc

In [251]:
def train_classifier(model, train_loader, optimizer, criterion, device, epoch):
    model.train()
    total_loss = 0.0
    all_probs = []
    all_labels = []
    progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1} [Train]", leave=False)
    
    for images, labels in progress_bar:
        images = images.to(device)
        # print("Image shape before model:", images.shape)   
        images = images.clamp(0, 1)
        images = images.repeat(1, 3, 1, 1)
        optimizer.zero_grad()
        outputs = model(images)
        probs = torch.softmax(outputs, dim=1)
        all_probs.append(probs.cpu())
        all_labels.append(labels.cpu())
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        progress_bar.set_postfix({'loss': loss.item()})
        probs = torch.cat(all_probs).numpy()
    labels = torch.cat(all_labels).numpy()

    macro_auc = roc_auc_score(labels, probs, multi_class='ovr', average='macro')
    class_auc = roc_auc_score(labels, probs, multi_class='ovr', average=None)
    
    return macro_auc, class_auc , total_loss/len(train_loader.dataset)

def validate_classifier(model, val_loader, criterion, device, epoch):
    model.eval()
    total_loss = 0.0
    all_probs = []
    all_labels = []
    progress_bar = tqdm(val_loader, desc=f"Epoch {epoch+1} [Val]", leave=False)
    
    with torch.no_grad():
        for images, labels in progress_bar:
            images = images.to(device)
            images = images.clamp(0, 1)
            images = images.repeat(1, 3, 1, 1)
            outputs = model(images)
            loss = criterion(outputs, labels)
            probs = torch.softmax(outputs, dim=1)
            all_probs.append(probs.cpu())
            all_labels.append(labels.cpu())
            total_loss += loss.item()
            progress_bar.set_postfix({'val_loss': loss.item()})
        
    probs = torch.cat(all_probs).numpy()
    labels = torch.cat(all_labels).numpy()

    macro_auc = roc_auc_score(labels, probs, multi_class='ovr', average='macro')
    class_auc = roc_auc_score(labels, probs, multi_class='ovr', average=None)
    
    return macro_auc, class_auc , total_loss/len(val_loader.dataset)

In [258]:
best_auc = 0
cnt = 0
for epoch in range(50):
    if(cnt>10):
        break
    train_macro_auc, train_class_auc, train_loss = train_classifier(model, total_train_loader, optimizer, criterion, device, epoch)
    val_macro_auc, val_class_auc, val_loss = validate_classifier(model, total_val_loader, criterion, device, epoch)
    
    print(f"Epoch {epoch+1}/{config.EPOCHS}")
    print(f"Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f}")
    
    wandb.log({
        "epoch": epoch,
        "train_macro_auc": train_macro_auc,
        "val_macro_auc": val_macro_auc,
        "train_loss_cls": train_loss,
        "val_loss_cls": val_loss,
        **{f"train_class_{i}": train_class_auc[i] for i in range(len(train_class_auc))},
        **{f"val_class_{i}": val_class_auc[i] for i in range(len(val_class_auc))}
    })

    if val_macro_auc < best_auc:
        best_auc = val_macro_auc
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'auc': val_macro_auc,
        }, "best_mae_model_classifier.pth")
        print(f"Saved new best model with val loss: {val_loss:.4f}")

    else:
        cnt += 1

Epoch 1 [Train]:   0%|          | 0/1254 [00:00<?, ?it/s]

Error loading /Users/sshah/2024/projects/gsoc/ml4sci/fm/Dataset/axion/axion_sim_17244485153218637000981744574520177317.npy: can't convert np.ndarray of type numpy.object_. The only supported types are: float64, float32, float16, complex64, complex128, int64, int32, int16, int8, uint64, uint32, uint16, uint8, and bool.
Error loading /Users/sshah/2024/projects/gsoc/ml4sci/fm/Dataset/axion/axion_sim_99052418600427477553149686448666035119.npy: can't convert np.ndarray of type numpy.object_. The only supported types are: float64, float32, float16, complex64, complex128, int64, int32, int16, int8, uint64, uint32, uint16, uint8, and bool.
Error loading /Users/sshah/2024/projects/gsoc/ml4sci/fm/Dataset/axion/axion_sim_197246890915100861128457917963192811610.npy: can't convert np.ndarray of type numpy.object_. The only supported types are: float64, float32, float16, complex64, complex128, int64, int32, int16, int8, uint64, uint32, uint16, uint8, and bool.
Error loading /Users/sshah/2024/project

                                                         

RuntimeError: Placeholder storage has not been allocated on MPS device!

In [None]:
def evaluate_roc(model, loader):
    model.eval()
    all_probs = []
    all_labels = []

    with torch.no_grad():
        for images, labels in loader:
            images = images.cuda()
            outputs = model(images)
            probs = torch.softmax(outputs, dim=1)
            all_probs.append(probs.cpu())
            all_labels.append(labels)

    probs = torch.cat(all_probs).numpy()
    labels = torch.cat(all_labels).numpy()

    # Calculate metrics
    fpr, tpr, roc_auc = {}, {}, {}
    wandb_roc_data = []

    for i in range(3):
        fpr[i], tpr[i], _ = roc_curve(labels == i, probs[:, i])
        roc_auc[i] = auc(fpr[i], tpr[i])

        # Store data for wandb
        for j in range(len(fpr[i])):
            wandb_roc_data.append([fpr[i][j], tpr[i][j], f'Class {i}'])

    roc_auc["macro"] = np.mean(list(roc_auc.values()))

    # Log ROC curve to wandb
    wandb.log({"roc_curve": wandb.Table(data=wandb_roc_data, columns=["FPR", "TPR", "Class"])})

    # Log AUC values
    wandb.log({f"AUC_Class_{i}_final": roc_auc[i] for i in range(3)})
    wandb.log({"AUC_Macro_final": roc_auc["macro"]})

    return fpr, tpr, roc_auc

fpr, tpr, roc_auc = evaluate_roc(mae, total_val_loader)