In [1]:
!pip install einops



In [2]:
import torch
import timm
import numpy as np

from einops import repeat, rearrange
from einops.layers.torch import Rearrange

from timm.models.layers import trunc_normal_
from timm.models.vision_transformer import Block

import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split, ConcatDataset
from torchvision import transforms
import matplotlib.pyplot as plt
from sklearn.metrics import roc_auc_score
import wandb
import os
import shutil
from sklearn.model_selection import train_test_split
from PIL import Image
import math
from tqdm import tqdm
import torch.nn.functional as F
from sklearn.metrics import roc_curve, auc



In [3]:
wandb.login()

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33msamkitshah1262[0m ([33msamkitshah1262-warner-bros-discovery[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

In [4]:
wandb.init(project="deeplens-foundational", entity="samkitshah1262-warner-bros-discovery", reinit=True)

In [5]:
# Configuration
class Config:
    SEED = 1
    BATCH_SIZE = 64
    EPOCHS = 200
    LEARNING_RATE = 1e-4
    WEIGHT_DECAY = 1e-5
    MASK_RATIO = 0.75
    NUM_CLASSES = 3
    IMG_SIZE = (64, 64)
    USE_SAVED_MODEL = False
    CKPT_PATH = "./model_weights.pth"

config = Config()

# Set random seeds
torch.manual_seed(config.SEED)
np.random.seed(config.SEED)

In [6]:
device = torch.device("mps" if torch.mps.is_available() else "cpu")

In [7]:
img = np.load("dataset/no_sub/no_sub_sim_6956151560647865808482838248806684.npy")  # Load .npy file
img = torch.tensor(img, dtype=torch.float32)
img.shape

torch.Size([64, 64])

In [8]:
DATA_PATH_ROOT = "/Users/sshah/2024/projects/gsoc/ml4sci/fm/Dataset"
DATA_PATH_AXION = "/axion"
DATA_PATH_CDM = "/cdm"
DATA_PATH_NO_SUB = "/no_sub"

OUTPUT_ROOT = "/dataset"

In [9]:
class NpyDirectoryDataset(Dataset):

    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.samples = []
        self.class_to_idx = {}

        class_names = [d for d in os.listdir(root_dir) if os.path.isdir(os.path.join(root_dir, d))]
        class_names.sort()
        print(class_names)
        for idx, class_name in enumerate(class_names):
            self.class_to_idx[class_name] = idx

        for class_name in class_names:
            class_dir = os.path.join(root_dir, class_name)
            class_idx = self.class_to_idx[class_name]
            
            for file_name in os.listdir(class_dir):
                file_path = os.path.join(class_dir, file_name)
                if file_path.endswith('.npy') and os.path.isfile(file_path):
                    self.samples.append((file_path, class_idx))
    
    def __len__(self):
        return len(self.samples)
    
    def __getitem__(self, idx):
        file_path, label = self.samples[idx]
        
        try:
            # Load the data with allow_pickle=True since it's an object array
            data = np.load(file_path, allow_pickle=True)
            
            # Extract the first element which contains the (64, 64) array
            if data.dtype == np.dtype('object') and data.size > 0:
                data = data[0]
            
            # Convert to tensor and ensure it's float
            data_tensor = torch.from_numpy(data).float()
            
            # Add channel dimension if needed
            if data_tensor.ndim == 2:
                data_tensor = data_tensor.unsqueeze(0)
                
            # Apply transforms if any
            if self.transform:
                data_tensor = self.transform(data_tensor)
            
            return data_tensor, label
            
        except Exception as e:
            print(f"Error loading {file_path}: {e}")
            # Return a placeholder if loading fails
            placeholder = torch.zeros((1, 64, 64))
            return placeholder, label

In [10]:
transforms = transforms.Compose([
    transforms.Resize((64, 64)),
    transforms.Grayscale(num_output_channels=1),
    transforms.Normalize(mean=[0.5], std=[0.5])
])

In [11]:
def create_data_loaders(root_dir, batch_size=32, train_ratio=0.8, num_workers=4, seed=42):

    torch.manual_seed(seed)
    np.random.seed(seed)

    full_dataset = NpyDirectoryDataset(root_dir=root_dir,transform=transforms)
    
    if len(full_dataset) == 0:
        raise ValueError(f"No valid .npy files found in {root_dir}. Please check the directory structure.")
    
    print(f"Found {len(full_dataset)} .npy files total")
    print(f"Class mapping: {full_dataset.class_to_idx}")

    total_size = len(full_dataset)
    train_size = int(train_ratio * total_size)
    val_size = total_size - train_size

    train_dataset, val_dataset = random_split(
        full_dataset, 
        [train_size, val_size],
        generator=torch.Generator().manual_seed(seed)
    )

    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers
    )
    
    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers
    )

    class_loaders = {}

    class_names = list(full_dataset.class_to_idx.keys())
    
    for class_name in class_names:
        class_idx = full_dataset.class_to_idx[class_name]
        class_indices = [i for i, (_, label) in enumerate(full_dataset.samples) if label == class_idx]
        
        if len(class_indices) == 0:
            print(f"Warning: No samples found for class {class_name}")
            continue

        class_train_indices = [idx for idx in range(len(full_dataset)) if 
                              idx in train_dataset.indices and 
                              full_dataset.samples[idx][1] == class_idx]
                              
        class_val_indices = [idx for idx in range(len(full_dataset)) if 
                            idx in val_dataset.indices and 
                            full_dataset.samples[idx][1] == class_idx]
        
        class_train_subset = torch.utils.data.Subset(full_dataset, class_train_indices)
        class_val_subset = torch.utils.data.Subset(full_dataset, class_val_indices)

        class_loaders[f"{class_name}_train"] = DataLoader(
            class_train_subset,
            batch_size=batch_size,
            shuffle=True,
            num_workers=num_workers
        )
        
        class_loaders[f"{class_name}_val"] = DataLoader(
            class_val_subset,
            batch_size=batch_size,
            shuffle=False,
            num_workers=num_workers
        )

    loaders = {
        "train": train_loader,
        "val": val_loader,
        **class_loaders
    }

    print("\nDataset information:")
    print(f"Training set: {len(train_dataset)} samples")
    print(f"Validation set: {len(val_dataset)} samples")
    
    for name, loader in class_loaders.items():
        print(f"{name}: {len(loader.dataset)} samples")
    
    return loaders

In [12]:
root_directory = DATA_PATH_ROOT

loaders = create_data_loaders(
    root_dir=root_directory,
    batch_size=config.BATCH_SIZE,
    train_ratio=0.9,
    num_workers=0
)

['axion', 'cdm', 'no_sub']
Found 89104 .npy files total
Class mapping: {'axion': 0, 'cdm': 1, 'no_sub': 2}

Dataset information:
Training set: 80193 samples
Validation set: 8911 samples
axion_train: 26981 samples
axion_val: 2915 samples
cdm_train: 26723 samples
cdm_val: 3036 samples
no_sub_train: 26489 samples
no_sub_val: 2960 samples


In [13]:
def random_indexes(size : int):
    forward_indexes = np.arange(size)
    np.random.shuffle(forward_indexes)
    backward_indexes = np.argsort(forward_indexes)
    return forward_indexes, backward_indexes

def take_indexes(sequences, indexes):
    return torch.gather(sequences, 0, repeat(indexes, 't b -> t b c', c=sequences.shape[-1]))

class PatchShuffle(torch.nn.Module):
    def __init__(self, ratio) -> None:
        super().__init__()
        self.ratio = ratio

    def forward(self, patches : torch.Tensor):
        T, B, C = patches.shape
        remain_T = int(T * (1 - self.ratio))

        indexes = [random_indexes(T) for _ in range(B)]
        forward_indexes = torch.as_tensor(np.stack([i[0] for i in indexes], axis=-1), dtype=torch.long).to(patches.device)
        backward_indexes = torch.as_tensor(np.stack([i[1] for i in indexes], axis=-1), dtype=torch.long).to(patches.device)

        patches = take_indexes(patches, forward_indexes)
        patches = patches[:remain_T]

        return patches, forward_indexes, backward_indexes

class MAE_Encoder(torch.nn.Module):
    def __init__(self,
                 image_size=64,
                 patch_size=4,
                 emb_dim=192,
                 num_layer=12,
                 num_head=3,
                 mask_ratio=0.75,
                 ) -> None:
        super().__init__()

        self.cls_token = torch.nn.Parameter(torch.zeros(1, 1, emb_dim))
        self.pos_embedding = torch.nn.Parameter(torch.zeros((image_size // patch_size) ** 2, 1, emb_dim))
        self.shuffle = PatchShuffle(mask_ratio)

        self.patchify = torch.nn.Conv2d(1, emb_dim, patch_size, patch_size)

        self.transformer = torch.nn.Sequential(*[Block(emb_dim, num_head) for _ in range(num_layer)])

        self.layer_norm = torch.nn.LayerNorm(emb_dim)

        self.init_weight()

    def init_weight(self):
        trunc_normal_(self.cls_token, std=.02)
        trunc_normal_(self.pos_embedding, std=.02)

    def forward(self, img):
        patches = self.patchify(img)
        patches = rearrange(patches, 'b c h w -> (h w) b c')
        patches = patches + self.pos_embedding

        patches, forward_indexes, backward_indexes = self.shuffle(patches)

        patches = torch.cat([self.cls_token.expand(-1, patches.shape[1], -1), patches], dim=0)
        patches = rearrange(patches, 't b c -> b t c')
        features = self.layer_norm(self.transformer(patches))
        features = rearrange(features, 'b t c -> t b c')

        return features, backward_indexes

class MAE_Decoder(torch.nn.Module):
    def __init__(self,
                 image_size=64,
                 patch_size=4,
                 emb_dim=192,
                 num_layer=4,
                 num_head=3,
                 ) -> None:
        super().__init__()

        self.mask_token = torch.nn.Parameter(torch.zeros(1, 1, emb_dim))
        self.pos_embedding = torch.nn.Parameter(torch.zeros((image_size // patch_size) ** 2 + 1, 1, emb_dim))

        self.transformer = torch.nn.Sequential(*[Block(emb_dim, num_head) for _ in range(num_layer)])

        self.head = torch.nn.Linear(emb_dim,  * patch_size ** 2)
        self.patch2img = Rearrange('(h w) b (c p1 p2) -> b c (h p1) (w p2)', p1=patch_size, p2=patch_size, h=image_size//patch_size)

        self.init_weight()

    def init_weight(self):
        trunc_normal_(self.mask_token, std=.02)
        trunc_normal_(self.pos_embedding, std=.02)

    def forward(self, features, backward_indexes):
        T = features.shape[0]
        backward_indexes = torch.cat([torch.zeros(1, backward_indexes.shape[1]).to(backward_indexes), backward_indexes + 1], dim=0)
        features = torch.cat([features, self.mask_token.expand(backward_indexes.shape[0] - features.shape[0], features.shape[1], -1)], dim=0)
        features = take_indexes(features, backward_indexes)
        features = features + self.pos_embedding

        features = rearrange(features, 't b c -> b t c')
        features = self.transformer(features)
        features = rearrange(features, 'b t c -> t b c')
        features = features[1:]

        patches = self.head(features)
        mask = torch.zeros_like(patches)
        mask[T-1:] = 1
        mask = take_indexes(mask, backward_indexes[1:] - 1)
        img = self.patch2img(patches)
        mask = self.patch2img(mask)

        return img, mask

class MAE_ViT(torch.nn.Module):
    def __init__(self,
                 image_size=64,
                 patch_size=4,
                 emb_dim=192,
                 encoder_layer=12,
                 encoder_head=3,
                 decoder_layer=4,
                 decoder_head=3,
                 mask_ratio=0.75,
                 ) -> None:
        super().__init__()

        self.encoder = MAE_Encoder(image_size, patch_size, emb_dim, encoder_layer, encoder_head, mask_ratio)
        self.decoder = MAE_Decoder(image_size, patch_size, emb_dim, decoder_layer, decoder_head)

    def forward(self, img):
        features, backward_indexes = self.encoder(img)
        predicted_img, mask = self.decoder(features,  backward_indexes)
        return predicted_img, mask

In [14]:
# optimizer = torch.optim.AdamW(model.parameters(), lr=config.LEARNING_RATE, betas=(0.9, 0.95), weight_decay=config.WEIGHT_DECAY)
# def lr_func(epoch):
#     return min((epoch + 1) / (200 + 1e-8), 0.5 * (math.cos(epoch / config.EPOCHS * math.pi) + 1))
# lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_func, verbose=True)

In [15]:
# train_dataloader = loaders['no_sub_train']
# val_dataloader = loaders['no_sub_val']

In [16]:
# def train_mae(model, train_loader, optimizer, device, epoch):
#     model.train()
#     total_loss = 0.0
#     progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1} [Train]", leave=False)
    
#     for images, _ in progress_bar:
#         images = images.to(device)
#         # print("Image shape before model:", images.shape)   
#         images = images.clamp(0, 1)
#         # images = images.repeat(1, 3, 1, 1)
#         optimizer.zero_grad()
#         reconstructed, mask = model(images)
#         loss = F.mse_loss(reconstructed * mask, images * mask)
#         loss.backward()
#         optimizer.step()
#         total_loss += loss.item() * images.size(0)
#         progress_bar.set_postfix({'loss': loss.item()})
    
#     return total_loss / len(train_loader.dataset)

# def validate_mae(model, val_loader, device, epoch):
#     model.eval()
#     total_loss = 0.0
#     progress_bar = tqdm(val_loader, desc=f"Epoch {epoch+1} [Val]", leave=False)
    
#     with torch.no_grad():
#         for images, _ in progress_bar:
#             images = images.to(device)
#             images = images.clamp(0, 1)
#             # images = images.repeat(1, 3, 1, 1)
#             reconstructed, mask = model(images)
#             loss = F.mse_loss(reconstructed * mask, images * mask)
            
#             total_loss += loss.item() * images.size(0)
#             progress_bar.set_postfix({'val_loss': loss.item()})
    
#     return total_loss / len(val_loader.dataset)

In [17]:
# best_loss = float('inf')
# for epoch in range(75):
#     train_loss = train_mae(model, train_dataloader, optimizer, device, epoch)
#     val_loss = validate_mae(model, val_dataloader, device, epoch)
    
#     print(f"Epoch {epoch+1}/{config.EPOCHS}")
#     print(f"Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f}")
    
#     wandb.log({
#         "epoch": epoch,
#         "train_loss": train_loss,
#         "val_loss": val_loss,
#     })

#     # Save best model
#     if val_loss < best_loss:
#         best_loss = val_loss
#         torch.save({
#             'epoch': epoch,
#             'model_state_dict': model.state_dict(),
#             'optimizer_state_dict': optimizer.state_dict(),
#             'loss': val_loss,
#         }, "best_mae_model2.pth")
#         print(f"Saved new best model with val loss: {val_loss:.4f}")

In [18]:
total_train_loader = loaders['train']
total_val_loader = loaders['val']

In [19]:
TORCH_SHOW_CPP_STACKTRACES=1

In [20]:
# # First, add this at the top of your training cell
# torch.autograd.set_detect_anomaly(True)

# class ViT_Classifier(torch.nn.Module):
#     def __init__(self, encoder : MAE_Encoder, num_classes=3) -> None:
#         super().__init__()
#         encoder.mask_ratio = 0
#         self.cls_token = encoder.cls_token
#         self.pos_embedding = encoder.pos_embedding
#         self.patchify = encoder.patchify
#         self.transformer = encoder.transformer
#         self.layer_norm = encoder.layer_norm
#         self.head = torch.nn.Linear(self.pos_embedding.shape[-1], num_classes)

#     def forward(self, img):
#         # Enable anomaly detection
#         with torch.autograd.detect_anomaly():
#             patches = self.patchify(img)
#             print(f"After patchify shape: {patches.shape}")
            
#             patches = rearrange(patches, 'b c h w -> (h w) b c')
#             print(f"After first rearrange shape: {patches.shape}")
            
#             patches = patches + self.pos_embedding
#             print(f"After adding pos_embedding shape: {patches.shape}")
            
#             patches = torch.cat([self.cls_token.expand(-1, patches.shape[1], -1), patches], dim=0)
#             print(f"After adding cls_token shape: {patches.shape}")
            
#             patches = rearrange(patches, 't b c -> b t c')
#             print(f"After second rearrange shape: {patches.shape}")
            
#             features = self.layer_norm(self.transformer(patches))
#             print(f"After transformer shape: {features.shape}")
            
#             # Ensure tensor is contiguous before reshaping
#             features = features.contiguous()
#             cls_token_features = features[:, 0, :]
#             print(f"Classifier token features shape: {cls_token_features.shape}")
            
#             logits = self.head(cls_token_features)
#             print(f"Final logits shape: {logits.shape}")
            
#             return logits

In [21]:
class ViT_Classifier(torch.nn.Module):
    def __init__(self, encoder: MAE_Encoder, num_classes=3):
        super().__init__()
        self.encoder = encoder
        # Freeze encoder parameters
        for param in self.encoder.parameters():
            param.requires_grad = False
            
        self.classifier = nn.Sequential(
            nn.LayerNorm(192),  # Normalize input features
            nn.Linear(192, 256),
            nn.GELU(),  # GELU activation often works better than ReLU
            nn.Dropout(0.2),
            nn.LayerNorm(256),
            nn.Linear(256, 128),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.LayerNorm(128),
            nn.Linear(128, num_classes)
        )
        
        # Initialize weights properly
        self.apply(self._init_weights)
        
    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.xavier_uniform_(module.weight)
            if module.bias is not None:
                torch.nn.init.constant_(module.bias, 0)

                
    def forward(self, x):
        # Add channel dimension if needed
        # if x.size(1) == 1:
        #     x = x.repeat(1, 3, 1, 1)
            
        # Get features from encoder
        with torch.no_grad():
            features, _ = self.encoder(x)
            # Take the CLS token features (first token)
            cls_features = features[0]  # Shape: [batch_size, embedding_dim]
            # print("cls_features: ",cls_features)
        # Pass through classifier
        logits = self.classifier(cls_features)
        return logits

In [218]:
# mae = MAE_ViT(
#         image_size=64,
#         patch_size=4,
#         emb_dim=192,
#         encoder_layer=12,
#         encoder_head=3,
#         decoder_layer=4,
#         decoder_head=3,
#         mask_ratio=0
# ).to(device)
# mae.load_state_dict(torch.load('best_mae_model2.pth')['model_state_dict'])
# model = ViT_Classifier(mae.encoder).to(device)
# criterion = nn.CrossEntropyLoss().to(device)
# optimizer = torch.optim.Adam(model.parameters(), lr=config.LEARNING_RATE)

In [22]:
def evaluate_roc(model, loader):
    model.eval()
    all_probs = []
    all_labels = []
    
    with torch.no_grad():
        for images, labels in loader:
            images = images.cuda()
            outputs = model(images)
            probs = torch.softmax(outputs, dim=1)
            all_probs.append(probs.cpu())
            all_labels.append(labels)
    
    probs = torch.cat(all_probs).numpy()
    labels = torch.cat(all_labels).numpy()

    fpr, tpr, roc_auc = {}, {}, {}
    for i in range(3):
        fpr[i], tpr[i], _ = roc_curve(labels == i, probs[:, i])
        roc_auc[i] = auc(fpr[i], tpr[i])
    
    roc_auc["macro"] = np.mean(list(roc_auc.values()))
    
    return fpr, tpr, roc_auc

In [23]:
def train_classifier(model, train_loader, optimizer, criterion, device, epoch):
    model.train()
    total_loss = 0.0
    all_probs = []
    all_labels = []
    progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1} [Train]", leave=False)
    torch.autograd.set_detect_anomaly(True)
    for images, labels in progress_bar:
        images = images.to(device)
        labels = labels.to(device)
        # print("Image shape before model:", images.shape)   
        images = images.clamp(0, 1)
        # images = images.repeat(1, 3, 1, 1)
        optimizer.zero_grad()
        # print(f"Input image shape: {images.shape}")
        outputs = model(images)
        # print(f"Model output shape: {outputs.shape}")
        # print(f"Labels shape: {labels.shape}")
        probs = torch.softmax(outputs, dim=1)
        all_probs.append(probs.cpu())
        all_labels.append(labels.cpu())
        loss = criterion(outputs, labels)
        # print(f"Loss value: {loss.item()}")
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        progress_bar.set_postfix({'loss': loss.item()})
    probs = torch.cat(all_probs).detach().numpy()
    labels = torch.cat(all_labels).detach().numpy()

    macro_auc = roc_auc_score(labels, probs, multi_class='ovr', average='macro')
    class_auc = roc_auc_score(labels, probs, multi_class='ovr', average=None)
    
    return macro_auc, class_auc , total_loss/len(train_loader.dataset)

def validate_classifier(model, val_loader, criterion, device, epoch):
    model.eval()
    total_loss = 0.0
    all_probs = []
    all_labels = []
    progress_bar = tqdm(val_loader, desc=f"Epoch {epoch+1} [Val]", leave=False)
    
    with torch.no_grad():
        for images, labels in progress_bar:
            images = images.to(device)
            labels = labels.to(device)
            images = images.clamp(0, 1)
            # images = images.repeat(1, 3, 1, 1)
            outputs = model(images)
            loss = criterion(outputs, labels)
            probs = torch.softmax(outputs, dim=1)
            all_probs.append(probs.cpu())
            all_labels.append(labels.cpu())
            total_loss += loss.item()
            progress_bar.set_postfix({'val_loss': loss.item()})
        
    probs = torch.cat(all_probs).detach().numpy()
    labels = torch.cat(all_labels).detach().numpy()

    macro_auc = roc_auc_score(labels, probs, multi_class='ovr', average='macro')
    class_auc = roc_auc_score(labels, probs, multi_class='ovr', average=None)
    
    return macro_auc, class_auc , total_loss/len(val_loader.dataset)

In [33]:
mae = MAE_ViT(
    image_size=64,
    patch_size=4,
    emb_dim=192,
    encoder_layer=12,
    encoder_head=3,
    decoder_layer=4,
    decoder_head=3,
    mask_ratio=0
).to(device)
mae.load_state_dict(torch.load('best_mae_model2.pth')['model_state_dict'])
model = ViT_Classifier(mae.encoder).to(device)
criterion = nn.CrossEntropyLoss().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)


TypeError: Value after * must be an iterable, not int



In [25]:
best_auc = 0
cnt = 0
for epoch in range(50):
    if(cnt>10):
        break
    train_macro_auc, train_class_auc, train_loss = train_classifier(model, total_train_loader, optimizer, criterion, device, epoch)
    val_macro_auc, val_class_auc, val_loss = validate_classifier(model, total_val_loader, criterion, device, epoch)
    
    print(f"Epoch {epoch+1}/{config.EPOCHS}")
    print(f"Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f}")
    
    wandb.log({
        "epoch": epoch,
        "train_macro_auc": train_macro_auc,
        "val_macro_auc": val_macro_auc,
        "train_loss_cls": train_loss,
        "val_loss_cls": val_loss,
        **{f"train_class_{i}": train_class_auc[i] for i in range(len(train_class_auc))},
        **{f"val_class_{i}": val_class_auc[i] for i in range(len(val_class_auc))}
    })

    if val_macro_auc < best_auc:
        best_auc = val_macro_auc
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'auc': val_macro_auc,
        }, "best_mae_model_classifier.pth")
        print(f"Saved new best model with val loss: {val_loss:.4f}")

    else:
        cnt += 1

NameError: name 'model' is not defined

In [None]:
def evaluate_roc(model, loader):
    model.eval()
    all_probs = []
    all_labels = []

    with torch.no_grad():
        for images, labels in loader:
            images = images.cuda()
            outputs = model(images)
            probs = torch.softmax(outputs, dim=1)
            all_probs.append(probs.cpu())
            all_labels.append(labels)

    probs = torch.cat(all_probs).numpy()
    labels = torch.cat(all_labels).numpy()

    # Calculate metrics
    fpr, tpr, roc_auc = {}, {}, {}
    wandb_roc_data = []

    for i in range(3):
        fpr[i], tpr[i], _ = roc_curve(labels == i, probs[:, i])
        roc_auc[i] = auc(fpr[i], tpr[i])

        # Store data for wandb
        for j in range(len(fpr[i])):
            wandb_roc_data.append([fpr[i][j], tpr[i][j], f'Class {i}'])

    roc_auc["macro"] = np.mean(list(roc_auc.values()))

    # Log ROC curve to wandb
    wandb.log({"roc_curve": wandb.Table(data=wandb_roc_data, columns=["FPR", "TPR", "Class"])})

    # Log AUC values
    wandb.log({f"AUC_Class_{i}_final": roc_auc[i] for i in range(3)})
    wandb.log({"AUC_Macro_final": roc_auc["macro"]})

    return fpr, tpr, roc_auc

fpr, tpr, roc_auc = evaluate_roc(mae, total_val_loader)

In [224]:
# Create random data for testing convergence
import torch
from torch.utils.data import TensorDataset, DataLoader
import numpy as np

def create_random_data(num_samples=1000, img_size=64, num_classes=3):
    # Create random images with a learnable pattern
    images = torch.randn(num_samples, 1, img_size, img_size)
    
    # Create random labels with clear patterns for each class
    labels = torch.zeros(num_samples, dtype=torch.long)
    for i in range(num_samples):
        # Add some pattern to make the data learnable
        if images[i, 0, 0:32, 0:32].mean() > 0:
            labels[i] = 0
        elif images[i, 0, 32:64, 0:32].mean() > 0:
            labels[i] = 1
        else:
            labels[i] = 2
            
        # Add some noise to the pattern
        if torch.rand(1) < 0.1:  # 10% noise
            labels[i] = torch.randint(0, 3, (1,))
    
    return images, labels

# Create random datasets
train_images, train_labels = create_random_data(num_samples=1000)
val_images, val_labels = create_random_data(num_samples=200)

# Create dataloaders
random_train_dataset = TensorDataset(train_images, train_labels)
random_val_dataset = TensorDataset(val_images, val_labels)

random_train_loader = DataLoader(random_train_dataset, batch_size=32, shuffle=True)
random_val_loader = DataLoader(random_val_dataset, batch_size=32, shuffle=False)

# Initialize model and optimizer
mae = MAE_ViT(
    image_size=64,
    patch_size=4,
    emb_dim=192,
    encoder_layer=12,
    encoder_head=3,
    decoder_layer=4,
    decoder_head=3,
    mask_ratio=0
).to(device)

# Load pre-trained weights
mae.load_state_dict(torch.load('best_mae_model2.pth')['model_state_dict'])
model = ViT_Classifier(mae.encoder).to(device)

# Initialize optimizer and criterion
criterion = nn.CrossEntropyLoss().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

# Training loop
print("Starting training on random data...")
best_val_acc = 0
patience = 5
patience_counter = 0

for epoch in range(20):
    # Training
    model.train()
    train_loss = 0
    correct = 0
    total = 0
    
    for batch_idx, (inputs, targets) in enumerate(random_train_loader):
        inputs, targets = inputs.to(device), targets.to(device)
        
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()
        
        train_loss += loss.item()
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()
        
        if batch_idx % 10 == 0:
            print(f'Epoch: {epoch} | Batch: {batch_idx} | Loss: {loss.item():.4f} | Acc: {100.*correct/total:.2f}%')
    
    train_acc = 100.*correct/total
    
    # Validation
    model.eval()
    val_loss = 0
    correct = 0
    total = 0
    
    with torch.no_grad():
        for inputs, targets in random_val_loader:
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            
            val_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
    
    val_acc = 100.*correct/total
    
    # Log metrics
    wandb.log({
        'random_train_loss': train_loss/len(random_train_loader),
        'random_train_acc': train_acc,
        'random_val_loss': val_loss/len(random_val_loader),
        'random_val_acc': val_acc,
        'epoch': epoch
    })
    
    print(f'\nEpoch: {epoch}')
    print(f'Train Loss: {train_loss/len(random_train_loader):.4f} | Train Acc: {train_acc:.2f}%')
    print(f'Val Loss: {val_loss/len(random_val_loader):.4f} | Val Acc: {val_acc:.2f}%')
    
    # Early stopping
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        patience_counter = 0
    else:
        patience_counter += 1
        
    if patience_counter >= patience:
        print(f'\nEarly stopping triggered. Best validation accuracy: {best_val_acc:.2f}%')
        break

print("Training completed!")

Starting training on random data...
cls_features:  tensor([[ 0.2450, -1.2289, -0.5948,  ..., -0.7491,  0.1013, -0.6281],
        [ 0.2141, -1.0479, -0.5244,  ..., -0.8510, -0.0119, -0.4586],
        [-0.1153, -0.9460, -0.5355,  ..., -0.8240,  0.0588, -0.5707],
        ...,
        [ 0.1473, -1.3263, -1.0752,  ..., -0.6534,  0.3485, -0.7120],
        [-0.0344, -1.0565, -0.5682,  ..., -0.8781,  0.1710, -0.4143],
        [ 0.0994, -1.1471, -0.6164,  ..., -0.8475,  0.1146, -0.7579]],
       device='mps:0')
Epoch: 0 | Batch: 0 | Loss: 1.1979 | Acc: 46.88%
cls_features:  tensor([[ 0.1107, -0.9384, -0.4480,  ..., -0.7079,  0.1194, -0.5513],
        [ 0.1257, -1.1308, -0.5046,  ..., -0.7745,  0.1917, -0.5722],
        [ 0.0393, -1.1002, -0.6289,  ..., -0.8673,  0.2991, -0.3502],
        ...,
        [ 0.0504, -1.0735, -0.3893,  ..., -0.7381, -0.0332, -0.7761],
        [-0.1058, -0.9896, -0.3954,  ..., -0.8490,  0.2324, -0.6006],
        [ 0.1579, -0.9971, -0.6543,  ..., -0.6129,  0.1758, -0.52

KeyboardInterrupt: 