In [48]:
!pip install einops



In [49]:
import torch
import timm
import numpy as np

from einops import repeat, rearrange
from einops.layers.torch import Rearrange

from timm.models.layers import trunc_normal_
from timm.models.vision_transformer import Block

import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split, ConcatDataset
from torchvision import transforms
import matplotlib.pyplot as plt
from sklearn.metrics import roc_auc_score
import wandb
import os
import shutil
from sklearn.model_selection import train_test_split
from PIL import Image
import math
from tqdm import tqdm
import torch.nn.functional as F
from sklearn.metrics import roc_curve, auc

In [50]:
wandb.login()



True

In [51]:
wandb.init(project="deeplens-foundational", entity="samkitshah1262-warner-bros-discovery")

In [52]:
# Configuration
class Config:
    SEED = 1
    BATCH_SIZE = 64
    EPOCHS = 200
    LEARNING_RATE = 1e-5
    WEIGHT_DECAY = 1e-5
    MASK_RATIO = 0.75
    NUM_CLASSES = 3
    IMG_SIZE = (64, 64)
    USE_SAVED_MODEL = False
    CKPT_PATH = "./model_weights.pth"

config = Config()

# Set random seeds
torch.manual_seed(config.SEED)
np.random.seed(config.SEED)

In [53]:
device = torch.device("mps" if torch.mps.is_available() else "cpu")

In [54]:
img = np.load("dataset/no_sub/no_sub_sim_6956151560647865808482838248806684.npy")  # Load .npy file
img = torch.tensor(img, dtype=torch.float32)
img.shape

torch.Size([64, 64])

In [55]:
DATA_PATH_ROOT = "/Users/sshah/2024/projects/gsoc/ml4sci/fm/Dataset"
DATA_PATH_AXION = "/axion"
DATA_PATH_CDM = "/cdm"
DATA_PATH_NO_SUB = "/no_sub"

OUTPUT_ROOT = "/dataset"

In [56]:
class NpyDirectoryDataset(Dataset):

    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.samples = []
        self.class_to_idx = {}

        class_names = [d for d in os.listdir(root_dir) if os.path.isdir(os.path.join(root_dir, d))]
        class_names.sort()

        for idx, class_name in enumerate(class_names):
            self.class_to_idx[class_name] = idx

        for class_name in class_names:
            class_dir = os.path.join(root_dir, class_name)
            class_idx = self.class_to_idx[class_name]
            
            for file_name in os.listdir(class_dir):
                file_path = os.path.join(class_dir, file_name)
                if file_path.endswith('.npy') and os.path.isfile(file_path):
                    self.samples.append((file_path, class_idx))
    
    def __len__(self):
        return len(self.samples)
    
    def __getitem__(self, idx):
        file_path, label = self.samples[idx]
        
        try:
            data = np.load(file_path, allow_pickle=True)
            img_array = data[0].astype(np.float32)
            image = Image.fromarray(img_array)
            image = image.point(lambda p: p / img_array.max())

            if self.transform:
                image = self.transform(image)
                
            return image, label
            
        except Exception as e:
            print(f"Error loading {file_path}: {e}")
            placeholder = torch.zeros((1, 64, 64))
            return placeholder, label

In [57]:
transforms = transforms.Compose([
    transforms.Resize((64, 64)),
    transforms.Grayscale(num_output_channels=1),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5])
])

In [58]:
def create_data_loaders(root_dir, batch_size=32, train_ratio=0.8, num_workers=4, seed=42):

    torch.manual_seed(seed)
    np.random.seed(seed)

    full_dataset = NpyDirectoryDataset(root_dir=root_dir,transform=transforms)
    
    if len(full_dataset) == 0:
        raise ValueError(f"No valid .npy files found in {root_dir}. Please check the directory structure.")
    
    print(f"Found {len(full_dataset)} .npy files total")
    print(f"Class mapping: {full_dataset.class_to_idx}")

    total_size = len(full_dataset)
    train_size = int(train_ratio * total_size)
    val_size = total_size - train_size

    train_dataset, val_dataset = random_split(
        full_dataset, 
        [train_size, val_size],
        generator=torch.Generator().manual_seed(seed)
    )

    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers
    )
    
    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers
    )

    class_loaders = {}

    class_names = list(full_dataset.class_to_idx.keys())
    
    for class_name in class_names:
        class_idx = full_dataset.class_to_idx[class_name]
        class_indices = [i for i, (_, label) in enumerate(full_dataset.samples) if label == class_idx]
        
        if len(class_indices) == 0:
            print(f"Warning: No samples found for class {class_name}")
            continue

        class_train_indices = [idx for idx in range(len(full_dataset)) if 
                              idx in train_dataset.indices and 
                              full_dataset.samples[idx][1] == class_idx]
                              
        class_val_indices = [idx for idx in range(len(full_dataset)) if 
                            idx in val_dataset.indices and 
                            full_dataset.samples[idx][1] == class_idx]
        
        class_train_subset = torch.utils.data.Subset(full_dataset, class_train_indices)
        class_val_subset = torch.utils.data.Subset(full_dataset, class_val_indices)

        class_loaders[f"{class_name}_train"] = DataLoader(
            class_train_subset,
            batch_size=batch_size,
            shuffle=True,
            num_workers=num_workers
        )
        
        class_loaders[f"{class_name}_val"] = DataLoader(
            class_val_subset,
            batch_size=batch_size,
            shuffle=False,
            num_workers=num_workers
        )

    loaders = {
        "train": train_loader,
        "val": val_loader,
        **class_loaders
    }

    print("\nDataset information:")
    print(f"Training set: {len(train_dataset)} samples")
    print(f"Validation set: {len(val_dataset)} samples")
    
    for name, loader in class_loaders.items():
        print(f"{name}: {len(loader.dataset)} samples")
    
    return loaders

In [59]:
root_directory = DATA_PATH_ROOT

loaders = create_data_loaders(
    root_dir=root_directory,
    batch_size=config.BATCH_SIZE,
    train_ratio=0.9,
    num_workers=0
)

Found 89104 .npy files total
Class mapping: {'axion': 0, 'cdm': 1, 'no_sub': 2}

Dataset information:
Training set: 80193 samples
Validation set: 8911 samples
axion_train: 26981 samples
axion_val: 2915 samples
cdm_train: 26723 samples
cdm_val: 3036 samples
no_sub_train: 26489 samples
no_sub_val: 2960 samples


In [60]:
axion_train_loader = loaders["axion_train"]
cdm_val_loader = loaders["cdm_val"]
combined_train_loader = loaders["train"]

print(f"Number of classes: {len(loaders)//2 - 1}")

for name, loader in loaders.items():
    print(f"{name}: {len(loader.dataset)} samples, {len(loader)} batches")

Number of classes: 3
train: 80193 samples, 1254 batches
val: 8911 samples, 140 batches
axion_train: 26981 samples, 422 batches
axion_val: 2915 samples, 46 batches
cdm_train: 26723 samples, 418 batches
cdm_val: 3036 samples, 48 batches
no_sub_train: 26489 samples, 414 batches
no_sub_val: 2960 samples, 47 batches


In [82]:
def random_indexes(size : int):
    forward_indexes = np.arange(size)
    np.random.shuffle(forward_indexes)
    backward_indexes = np.argsort(forward_indexes)
    return forward_indexes, backward_indexes

def take_indexes(sequences, indexes):
    return torch.gather(sequences, 0, repeat(indexes, 't b -> t b c', c=sequences.shape[-1]))

class PatchShuffle(torch.nn.Module):
    def __init__(self, ratio) -> None:
        super().__init__()
        self.ratio = ratio

    def forward(self, patches : torch.Tensor):
        T, B, C = patches.shape
        remain_T = int(T * (1 - self.ratio))

        indexes = [random_indexes(T) for _ in range(B)]
        forward_indexes = torch.as_tensor(np.stack([i[0] for i in indexes], axis=-1), dtype=torch.long).to(patches.device)
        backward_indexes = torch.as_tensor(np.stack([i[1] for i in indexes], axis=-1), dtype=torch.long).to(patches.device)

        patches = take_indexes(patches, forward_indexes)
        patches = patches[:remain_T]

        return patches, forward_indexes, backward_indexes

class MAE_Encoder(torch.nn.Module):
    def __init__(self,
                 image_size=64,
                 patch_size=4,
                 emb_dim=192,
                 num_layer=12,
                 num_head=3,
                 mask_ratio=0.75,
                 ) -> None:
        super().__init__()

        self.cls_token = torch.nn.Parameter(torch.zeros(1, 1, emb_dim))
        self.pos_embedding = torch.nn.Parameter(torch.zeros((image_size // patch_size) ** 2, 1, emb_dim))
        self.shuffle = PatchShuffle(mask_ratio)

        self.patchify = torch.nn.Conv2d(3, emb_dim, patch_size, patch_size)

        self.transformer = torch.nn.Sequential(*[Block(emb_dim, num_head) for _ in range(num_layer)])

        self.layer_norm = torch.nn.LayerNorm(emb_dim)

        self.init_weight()

    def init_weight(self):
        trunc_normal_(self.cls_token, std=.02)
        trunc_normal_(self.pos_embedding, std=.02)

    def forward(self, img):
        patches = self.patchify(img)
        patches = rearrange(patches, 'b c h w -> (h w) b c')
        patches = patches + self.pos_embedding

        patches, forward_indexes, backward_indexes = self.shuffle(patches)
        print(f"Shuffled patches: {patches.shape} | Stride: {patches.stride()}")
        patches = torch.cat([self.cls_token.expand(-1, patches.shape[1], -1), patches], dim=0)
        patches = rearrange(patches, 't b c -> b t c')
        features = self.layer_norm(self.transformer(patches))
        features = rearrange(features, 'b t c -> t b c')

        return features, backward_indexes

class MAE_Decoder(torch.nn.Module):
    def __init__(self,
                 image_size=64,
                 patch_size=4,
                 emb_dim=192,
                 num_layer=4,
                 num_head=3,
                 ) -> None:
        super().__init__()

        self.mask_token = torch.nn.Parameter(torch.zeros(1, 1, emb_dim))
        self.pos_embedding = torch.nn.Parameter(torch.zeros((image_size // patch_size) ** 2 + 1, 1, emb_dim))

        self.transformer = torch.nn.Sequential(*[Block(emb_dim, num_head) for _ in range(num_layer)])

        self.head = torch.nn.Linear(emb_dim, 3 * patch_size ** 2)
        self.patch2img = Rearrange('(h w) b (c p1 p2) -> b c (h p1) (w p2)', p1=patch_size, p2=patch_size, h=image_size//patch_size)

        self.init_weight()

    def init_weight(self):
        trunc_normal_(self.mask_token, std=.02)
        trunc_normal_(self.pos_embedding, std=.02)

    def forward(self, features, backward_indexes):
        T = features.shape[0]
        backward_indexes = torch.cat([torch.zeros(1, backward_indexes.shape[1]).to(backward_indexes), backward_indexes + 1], dim=0)
        features = torch.cat([features, self.mask_token.expand(backward_indexes.shape[0] - features.shape[0], features.shape[1], -1)], dim=0)
        features = take_indexes(features, backward_indexes)
        features = features + self.pos_embedding

        features = rearrange(features, 't b c -> b t c')
        features = self.transformer(features)
        features = rearrange(features, 'b t c -> t b c')
        features = features[1:]

        patches = self.head(features)
        mask = torch.zeros_like(patches)
        mask[T-1:] = 1
        mask = take_indexes(mask, backward_indexes[1:] - 1)
        img = self.patch2img(patches)
        mask = self.patch2img(mask)

        return img, mask

class MAE_ViT(torch.nn.Module):
    def __init__(self,
                 image_size=64,
                 patch_size=4,
                 emb_dim=192,
                 encoder_layer=12,
                 encoder_head=3,
                 decoder_layer=4,
                 decoder_head=3,
                 mask_ratio=0.75,
                 ) -> None:
        super().__init__()

        self.encoder = MAE_Encoder(image_size, patch_size, emb_dim, encoder_layer, encoder_head, mask_ratio)
        self.decoder = MAE_Decoder(image_size, patch_size, emb_dim, decoder_layer, decoder_head)

    def forward(self, img):
        features, backward_indexes = self.encoder(img)
        predicted_img, mask = self.decoder(features,  backward_indexes)
        return predicted_img, mask

In [83]:
total_train_loader = loaders['train']
total_val_loader = loaders['val']

In [84]:
class ViT_Classifier(torch.nn.Module):
    def __init__(self, encoder : MAE_Encoder, num_classes=3) -> None:
        super().__init__()
        encoder.mask_ratio = 0
        self.cls_token = encoder.cls_token
        self.pos_embedding = encoder.pos_embedding
        self.patchify = encoder.patchify
        self.transformer = encoder.transformer
        self.layer_norm = encoder.layer_norm
        self.head = torch.nn.Linear(self.pos_embedding.shape[-1], num_classes)

    def forward(self, img):
        patches = self.patchify(img)
        print(f"Patch stride: {patches.stride()}, Contiguous: {patches.is_contiguous()}")
        patches = rearrange(patches, 'b c h w -> (h w) b c')
        print(f"Rearranged stride: {patches.stride()}")
        patches = patches + self.pos_embedding
        patches = torch.cat([self.cls_token.expand(-1, patches.shape[1], -1), patches], dim=0)
        patches = rearrange(patches, 't b c -> b t c')
        features = self.layer_norm(self.transformer(patches))
        if not features.is_contiguous():
            print("Warning: Non-contiguous features detected!")  # [1]
            features = features.contiguous()
        logits = self.head(features[:, 0, :])
        print(logits)
        return logits


In [85]:
print(torch.load('best_mae_model2.pth').keys())

dict_keys(['epoch', 'model_state_dict', 'optimizer_state_dict', 'loss'])


In [87]:
mae = MAE_ViT(
        image_size=64,
        patch_size=4,
        emb_dim=192,
        encoder_layer=12,
        encoder_head=3,
        decoder_layer=4,
        decoder_head=3,
        mask_ratio=0.75
).to(device)
mae.load_state_dict(torch.load('best_mae_model.pth')['model_state_dict'])
model = ViT_Classifier(mae.encoder).to(device)
criterion = nn.CrossEntropyLoss().to(device)
def acc_fn(logit, label):
    return torch.mean((logit.argmax(dim=-1) == label).float())
optimizer = torch.optim.Adam(model.parameters(), lr=config.LEARNING_RATE)
def lr_func(epoch):
    return min((epoch + 1) / (200 + 1e-8), 0.5 * (math.cos(epoch / config.EPOCHS * math.pi) + 1))
lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_func, verbose=True)



In [88]:
def evaluate_roc(model, loader):
    model.eval()
    all_probs = []
    all_labels = []
    
    with torch.no_grad():
        for images, labels in loader:
            images = images.to(device)
            outputs = model(images)
            probs = torch.softmax(outputs, dim=1)
            all_probs.append(probs.cpu())
            all_labels.append(labels)
    
    probs = torch.cat(all_probs).numpy()
    labels = torch.cat(all_labels).numpy()

    fpr, tpr, roc_auc = {}, {}, {}
    for i in range(3):
        fpr[i], tpr[i], _ = roc_curve(labels == i, probs[:, i])
        roc_auc[i] = auc(fpr[i], tpr[i])
    
    roc_auc["macro"] = np.mean(list(roc_auc.values()))
    
    return fpr, tpr, roc_auc

In [89]:
optimizer.zero_grad()
def train_classifier(model, train_loader, optimizer, criterion, device, epoch):
    model.train()
    total_loss = 0.0
    all_probs = []
    all_labels = []
    progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1} [Train]", leave=False)
    
    for images, labels in progress_bar:
        images = images.to(device)
        labels = labels.to(device)
        images = images.repeat(1, 3, 1, 1)
        print("Image shape before model:", images.shape)   
        images = images.clamp(0, 1)
        optimizer.zero_grad()
        outputs = model(images)
        print(f"Model output shape: {outputs.shape}")
        print(f"Labels shape: {labels.shape}")  
        probs = torch.softmax(outputs, dim=1)
        all_probs.append(probs.cpu())
        all_labels.append(labels.cpu())
        loss = criterion(outputs, labels)
        print("output cont: ",outputs.is_contiguous())
        print("labels cont: ",labels.is_contiguous())
        print(loss)    
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        progress_bar.set_postfix({'loss': loss.item()})
    probs = torch.cat(all_probs).detach().to("cpu").numpy()
    labels = torch.cat(all_labels).detach().to("cpu").numpy()

    macro_auc = roc_auc_score(labels, probs, multi_class='ovr', average='macro')
    class_auc = roc_auc_score(labels, probs, multi_class='ovr', average=None)
    
    return macro_auc, class_auc , total_loss/len(train_loader.dataset)

def validate_classifier(model, val_loader, criterion, device, epoch):
    model.eval()
    total_loss = 0.0
    all_probs = []
    all_labels = []
    progress_bar = tqdm(val_loader, desc=f"Epoch {epoch+1} [Val]", leave=False)
    
    with torch.no_grad():
        for images, labels in progress_bar:
            images = images.to(device)
            labels = labels.to(device)
            images = images.repeat(1, 3, 1, 1)
            images = images.clamp(0, 1)
            outputs = model(images)
            loss = criterion(outputs, labels)
            probs = torch.softmax(outputs, dim=1)
            all_probs.append(probs.cpu())
            all_labels.append(labels.cpu())
            total_loss += loss.item()
            progress_bar.set_postfix({'val_loss': loss.item()})
    probs = torch.cat(all_probs).detach().to("cpu").numpy()
    labels = torch.cat(all_labels).detach().to("cpu").numpy()

    macro_auc = roc_auc_score(labels, probs, multi_class='ovr', average='macro')
    class_auc = roc_auc_score(labels, probs, multi_class='ovr', average=None)
    
    return macro_auc, class_auc , total_loss/len(val_loader.dataset)

In [90]:
best_auc = 0
cnt = 0
# wandb.watch(
#     model,
#     log='all',
#     log_freq=50,
#     log_graph=True
# )
for epoch in range(50):
    if(cnt>10):
        break
    train_macro_auc, train_class_auc, train_loss = train_classifier(model, total_train_loader, optimizer, criterion, device, epoch)
    val_macro_auc, val_class_auc, val_loss = validate_classifier(model, total_val_loader, criterion, device, epoch)
    
    lr_scheduler.step()
    print(f"Epoch {epoch+1}/{config.EPOCHS}")
    print(f"Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f}")
    
    wandb.log({
        "epoch": epoch,
        "train_macro_auc": train_macro_auc,
        "val_macro_auc": val_macro_auc,
        "train_loss_cls": train_loss,
        "val_loss_cls": val_loss,
        **{f"train_class_{i}": train_class_auc[i] for i in range(len(train_class_auc))},
        **{f"val_class_{i}": val_class_auc[i] for i in range(len(val_class_auc))}
    })

    if val_macro_auc < best_auc:
        best_auc = val_macro_auc
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'auc': val_macro_auc,
        }, "best_mae_model_classifier.pth")
        print(f"Saved new best model with val loss: {val_loss:.4f}")

    else:
        cnt += 1

Epoch 1 [Train]:   0%|          | 0/1254 [00:00<?, ?it/s]

Image shape before model: torch.Size([64, 3, 64, 64])
Patch stride: (49152, 256, 16, 1), Contiguous: True
Rearranged stride: (1, 49152, 256)


                                                         

tensor([[ 0.7127,  0.3803, -0.1758],
        [ 0.7127,  0.3803, -0.1758],
        [ 0.7127,  0.3803, -0.1758],
        [ 0.7127,  0.3803, -0.1758],
        [ 0.7127,  0.3803, -0.1758],
        [ 0.7127,  0.3803, -0.1758],
        [ 0.7127,  0.3803, -0.1758],
        [ 0.7127,  0.3803, -0.1758],
        [ 0.7127,  0.3803, -0.1758],
        [ 0.7127,  0.3803, -0.1758],
        [ 0.7127,  0.3803, -0.1758],
        [ 0.7127,  0.3803, -0.1758],
        [ 0.7127,  0.3803, -0.1758],
        [ 0.7127,  0.3803, -0.1758],
        [ 0.7127,  0.3803, -0.1758],
        [ 0.7127,  0.3803, -0.1758],
        [ 0.7127,  0.3803, -0.1758],
        [ 0.7127,  0.3803, -0.1758],
        [ 0.7127,  0.3803, -0.1758],
        [ 0.7127,  0.3803, -0.1758],
        [ 0.7127,  0.3803, -0.1758],
        [ 0.7127,  0.3803, -0.1758],
        [ 0.7127,  0.3803, -0.1758],
        [ 0.7127,  0.3803, -0.1758],
        [ 0.7127,  0.3803, -0.1758],
        [ 0.7127,  0.3803, -0.1758],
        [ 0.7127,  0.3803, -0.1758],
 



RuntimeError: view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.

In [None]:
x = torch.randn(2, 1, 64, 64).to(device)  # Match input channels
with torch.no_grad():
    out = model(x)
print(out[0].shape) 

Patch stride: (49152, 256, 16, 1), Contiguous: True
Rearranged stride: (1, 49152, 256)
tensor([[-0.4271, -0.9497, -0.7466],
        [-0.6088, -0.9777, -0.9135]], device='mps:0')
torch.Size([3])


In [None]:
# Try SGD with momentum
optimizer = torch.optim.SGD(
    model.parameters(),
    lr=1e-4,
    momentum=0.9,
    nesterov=True
)

# Add aggressive LR scheduler
scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer,
    max_lr=0.5,  # Very high to force learning
    steps_per_epoch=len(total_train_loader),
    epochs=20,
    div_factor=10,
    final_div_factor=100
)

# # Add gradient clipping
# torch.nn.utils.clip_grad_value_(model.parameters(), 1.0)

In [None]:
x = torch.randn(64, 1, 64, 64).to(device)* 0.5 + 0.5  # Batch of random noise
y = torch.randint(0, 1, (64,)).to(device)   # Random labels

# Should see loss decrease rapidly if pipeline works
for epoch in range(10):
    optimizer.zero_grad()
    outputs = model(x)
    loss = criterion(outputs, y)
    loss.backward()
    
    # Print gradients
    total_norm = 0
    for p in model.parameters():
        if p.grad is not None:
            param_norm = p.grad.detach().data.norm(2)
            total_norm += param_norm.item() ** 2
    total_norm = total_norm ** 0.5
    print(f"Epoch {epoch}: Loss {loss.item():.4f}, Grad Norm {total_norm:.4f}")
    
    optimizer.step()

Patch stride: (49152, 256, 16, 1), Contiguous: True
Rearranged stride: (1, 49152, 256)
tensor([[ 0.5295,  0.4326, -0.5376],
        [ 0.6002,  0.4888, -0.5002],
        [ 0.6355,  0.5752, -0.4670],
        [ 0.5756,  0.5411, -0.5141],
        [ 0.5559,  0.5105, -0.4952],
        [ 0.5530,  0.5286, -0.5353],
        [ 0.5873,  0.4629, -0.4959],
        [ 0.5475,  0.5014, -0.5208],
        [ 0.5939,  0.4649, -0.5071],
        [ 0.5330,  0.5722, -0.4829],
        [ 0.5751,  0.5399, -0.5443],
        [ 0.5976,  0.5495, -0.4857],
        [ 0.5212,  0.5207, -0.5028],
        [ 0.5961,  0.5154, -0.4940],
        [ 0.5815,  0.5479, -0.4840],
        [ 0.5778,  0.5628, -0.5304],
        [ 0.5379,  0.5145, -0.5297],
        [ 0.5786,  0.6069, -0.4967],
        [ 0.5685,  0.5575, -0.4672],
        [ 0.5874,  0.5011, -0.5251],
        [ 0.5965,  0.4880, -0.4712],
        [ 0.6031,  0.5097, -0.4804],
        [ 0.5602,  0.5148, -0.5162],
        [ 0.5873,  0.5368, -0.4902],
        [ 0.5261,  0.5227

RuntimeError: view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.

In [None]:
# class EnhancedClassifier(nn.Module):
#     def __init__(self, encoder, num_classes=3):
#         super().__init__()
#         self.encoder = encoder
#         self.align = nn.Sequential(
#             nn.Linear(192, 512),
#             nn.LayerNorm(512),
#             nn.GELU(),
#             nn.Dropout(0.5))
#         self.classifier = nn.Linear(512, num_classes)

#     def forward(self, x):
#         features, _ = self.encoder(x)
#         aligned = self.align(features[0])  # CLS token
#         return self.classifier(aligned)

# mae = MAE_ViT(
#         image_size=64,
#         patch_size=4,
#         emb_dim=192,
#         encoder_layer=12,
#         encoder_head=3,
#         decoder_layer=4,
#         decoder_head=3,
#         mask_ratio=0.75
# ).to(device)
# mae.load_state_dict(torch.load('best_mae_model.pth')['model_state_dict'])
# model = EnhancedClassifier(mae.encoder).to(device)
# optimizer = torch.optim.AdamW([
#     {'params': model.align.parameters(), 'lr': 1e-3},
#     {'params': model.classifier.parameters(), 'lr': 1e-3},
#     {'params': model.encoder.parameters(), 'lr': 1e-5}  # Slower encoder updates
# ], weight_decay=0.01)

# # Gradient clipping
# torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
# x = torch.randn(64, 3, 64, 64).to(device)* 0.5 + 0.5  # Batch of random noise
# x = (x - x.mean()) / x.std()  # Standardize
# x = torch.clamp(x, min=-3, max=3)
# y = torch.randint(0, 3, (64,)).to(device)   # Random labels
# with torch.no_grad():
#     features = model.encoder(x)[0][0]
#     aligned = model.align(features)

# print("Original feature std:", features.std().item())  # Should > 0.1
# print("Aligned feature std:", aligned.std().item())  
# # Should see loss decrease rapidly if pipeline works
# for epoch in range(10):
#     optimizer.zero_grad()
#     outputs = model(x)
#     loss = criterion(outputs, y)
#     loss.backward()
    
#     # Print gradients
#     total_norm = 0
#     for p in model.parameters():
#         if p.grad is not None:
#             param_norm = p.grad.detach().data.norm(2)
#             total_norm += param_norm.item() ** 2
#     total_norm = total_norm ** 0.5
#     print(f"Epoch {epoch}: Loss {loss.item():.4f}, Grad Norm {total_norm:.4f}")
    
#     optimizer.step()

Original feature std: 1.0000218152999878
Aligned feature std: 0.887792706489563
Epoch 0: Loss 1.1467, Grad Norm 4.9957
Epoch 1: Loss 1.2066, Grad Norm 6.9606
Epoch 2: Loss 1.1566, Grad Norm 4.6473
Epoch 3: Loss 1.1579, Grad Norm 5.1724
Epoch 4: Loss 1.0509, Grad Norm 4.7101
Epoch 5: Loss 1.0846, Grad Norm 3.8509
Epoch 6: Loss 1.0984, Grad Norm 4.1500
Epoch 7: Loss 1.1093, Grad Norm 6.1238
Epoch 8: Loss 1.0519, Grad Norm 4.9738
Epoch 9: Loss 1.0787, Grad Norm 3.8720


In [None]:
import torch
from skimage.metrics import structural_similarity as ssim
import numpy as np

def calculate_metrics(original, reconstructed, mask_patch, patch_size=4):
    """
    original: Tensor [B, C, H, W] (denormalized)
    reconstructed: Tensor [B, C, H, W] (denormalized)
    mask_patch: Tensor [B, N_patches] (1=masked)
    """
    # Convert patch mask to pixel mask
    B, C, H, W = original.shape
    mask = mask_patch.view(B,1, H, W)
    mask = mask.unsqueeze(-1).unsqueeze(-1)
    mask = mask.repeat(1, 1, 1, patch_size, patch_size)
    mask = mask.view(B, 1, H, W).float()
    
    # Calculate metrics only on masked regions
    mse = ((original - reconstructed)**2 * mask).sum() / mask.sum()
    
    # PSNR
    max_pixel = 1.0  # Assuming normalized to [0,1]
    psnr = 20 * torch.log10(max_pixel / torch.sqrt(mse))
    
    # SSIM (requires numpy conversion)
    ssim_total = 0
    for b in range(B):
        orig_np = original[b].permute(1,2,0).numpy()
        recon_np = reconstructed[b].permute(1,2,0).numpy()
        mask_np = mask[b].squeeze().numpy()
        
        # SSIM on masked pixels only
        ssim_val = ssim(orig_np, recon_np, 
                        data_range=1.0,
                        multichannel=True,
                        full=True,
                        win_size=3,
                        use_sample_covariance=False)[1]
        ssim_masked = (ssim_val * mask_np).sum() / mask_np.sum()
        ssim_total += ssim_masked
        
    return {
        'mse': mse.item(),
        'psnr': psnr.item(),
        'ssim': ssim_total/B
    }

In [None]:
import matplotlib.pyplot as plt

def visualize_reconstructions(original, masked, reconstructed, mask_patch, num_samples=3):
    """
    original: Tensor [B, C, H, W]
    masked: Tensor [B, C, H, W]
    reconstructed: Tensor [B, C, H, W]
    mask_patch: Tensor [B, N_patches]
    """
    patch_size = 4
    B, C, H, W = original.shape
    
    fig, axs = plt.subplots(num_samples, 4, figsize=(15, num_samples*3))
    
    for i in range(num_samples):
        # Original image
        axs[i,0].imshow(original[i].permute(1,2,0))
        axs[i,0].set_title('Original')
        axs[i,0].axis('off')
        
        # Mask visualization
        patch_mask = mask_patch[i].view(H//patch_size, W//patch_size)
        axs[i,1].imshow(patch_mask, cmap='gray')
        axs[i,1].set_title('Patch Mask')
        axs[i,1].axis('off')
        
        # Masked input
        axs[i,2].imshow(masked[i].permute(1,2,0))
        axs[i,2].set_title('Masked Input')
        axs[i,2].axis('off')
        
        # Reconstruction
        axs[i,3].imshow(reconstructed[i].permute(1,2,0))
        axs[i,3].set_title('Reconstruction')
        axs[i,3].axis('off')
    
    plt.tight_layout()
    plt.show()

In [None]:
def verify_mae(model, dataloader, device):
    model.eval()
    metrics = {'mse': [], 'psnr': [], 'ssim': []}
    
    with torch.no_grad():
        for images, _ in dataloader:
            images = images.to(device)
            
            # Forward pass with mask return
            reconstructed, pixel_mask, patch_mask = model(images)
            
            # Denormalize
            original = (images * 0.5) + 0.5  # Assuming normalization was (mean=0.5, std=0.5)
            reconstructed = (reconstructed * 0.5) + 0.5
            
            # Calculate metrics
            batch_metrics = calculate_metrics(
                original.cpu(),
                reconstructed.cpu(),
                patch_mask.cpu()
            )
            
            # Store metrics
            for k in metrics:
                metrics[k].append(batch_metrics[k])
            
            # Visualize first batch
            if len(metrics['mse']) == 1:
                masked_input = original * (1 - patch_mask.unsqueeze(1))  # Create masked input
                visualize_reconstructions(
                    original.cpu(),
                    masked_input.cpu(),
                    reconstructed.cpu(),
                    patch_mask.cpu()
                )
    
    # Aggregate metrics
    final_metrics = {k: np.mean(v) for k,v in metrics.items()}
    print("\nFinal Metrics:")
    print(f"MSE: {final_metrics['mse']:.4f}")
    print(f"PSNR: {final_metrics['psnr']:.2f} dB")
    print(f"SSIM: {final_metrics['ssim']:.4f}")
    
    return final_metrics

In [None]:
mae = MAE_ViT(
        image_size=64,
        patch_size=4,
        emb_dim=192,
        encoder_layer=12,
        encoder_head=3,
        decoder_layer=4,
        decoder_head=3,
        mask_ratio=0.75
).to(device)
mae.load_state_dict(torch.load('best_mae_model.pth')['model_state_dict'])
# Create test dataloader with proper normalization

test_loader = loaders['val']
# Run verification
metrics = verify_mae(mae, test_loader, device)

AttributeError: 'MAE_ViT' object has no attribute 'image_size'

In [None]:
# Add gradient monitoring
for batch_idx, (data, target) in enumerate(total_train_loader):
    data, target = data.to(device), target.to(device)
    optimizer.zero_grad()
    output = model(data)
    loss = criterion(output, target)
    loss.backward()
    
    # Log gradients
    total_grad = 0
    for name, param in model.named_parameters():
        if param.grad is not None:
            grad_mean = param.grad.abs().mean().item()
            total_grad += grad_mean
    print(f"Batch {batch_idx} - Avg grad magnitude: {total_grad/len(list(model.parameters())):.6f}")
    
    optimizer.step()

RuntimeError: view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.

In [None]:
def evaluate_roc(model, loader):
    model.eval()
    all_probs = []
    all_labels = []

    with torch.no_grad():
        for images, labels in loader:
            images = images.cuda()
            outputs = model(images)
            probs = torch.softmax(outputs, dim=1)
            all_probs.append(probs.cpu())
            all_labels.append(labels)

    probs = torch.cat(all_probs).numpy()
    labels = torch.cat(all_labels).numpy()

    # Calculate metrics
    fpr, tpr, roc_auc = {}, {}, {}
    wandb_roc_data = []

    for i in range(3):
        fpr[i], tpr[i], _ = roc_curve(labels == i, probs[:, i])
        roc_auc[i] = auc(fpr[i], tpr[i])

        # Store data for wandb
        for j in range(len(fpr[i])):
            wandb_roc_data.append([fpr[i][j], tpr[i][j], f'Class {i}'])

    roc_auc["macro"] = np.mean(list(roc_auc.values()))

    # Log ROC curve to wandb
    wandb.log({"roc_curve": wandb.Table(data=wandb_roc_data, columns=["FPR", "TPR", "Class"])})

    # Log AUC values
    wandb.log({f"AUC_Class_{i}_final": roc_auc[i] for i in range(3)})
    wandb.log({"AUC_Macro_final": roc_auc["macro"]})

    return fpr, tpr, roc_auc

fpr, tpr, roc_auc = evaluate_roc(mae, total_val_loader)