### Importing Required Libraries  
This section imports essential libraries for image classification using Vision Transformers (ViT), including PyTorch, torchvision, Transformers, and sklearn for model training, preprocessing, and evaluation.


In [None]:
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import os
from torchvision import datasets
from torch.cuda.amp import GradScaler, autocast
from torch.utils.data import Dataset, DataLoader, random_split
from PIL import Image
import matplotlib.pyplot as plt
from transformers import ViTForImageClassification, ViTFeatureExtractor
from torchvision.transforms import functional as F
from torchvision import transforms
import torch.optim as optim
from tqdm import tqdm
import matplotlib.pyplot as plt
from sklearn.metrics import roc_curve, auc, confusion_matrix, precision_score, recall_score, f1_score
from itertools import cycle
from einops import rearrange, repeat

### Custom Dataset and DataLoader Creation  
Defines `NpyDataset`, a PyTorch dataset class for loading `.npy` image files, converting them to 3-channel tensors, and applying transformations. The `create_dataloaders` function splits the dataset into training and validation sets and creates corresponding DataLoaders.


In [None]:
class NpyDataset(Dataset):
    """Dataset of ``.npy`` files stored as one sub-directory per class.

    Every sub-directory of ``folder_path`` is a class; class indices are
    assigned in sorted directory-name order, starting at 0. Entries of
    ``folder_path`` that are not directories are ignored and do not consume a
    class index.

    Args:
        folder_path: Root directory containing the class sub-directories.
        transform: Optional callable applied to the 3-channel tensor of each
            sample before it is returned.
    """

    def __init__(self, folder_path, transform=None):
        self.folder_path = folder_path
        self.transform = transform
        self.data = []    # file path of every sample
        self.labels = []  # class index of every sample, parallel to self.data

        # Filter directories *before* enumerating: otherwise a stray file next
        # to the class folders would shift every following class index and
        # the labels would no longer be contiguous from 0.
        class_folders = sorted(
            entry for entry in os.listdir(folder_path)
            if os.path.isdir(os.path.join(folder_path, entry))
        )

        for class_idx, class_folder in enumerate(class_folders):
            class_path = os.path.join(folder_path, class_folder)
            # os.listdir order is arbitrary; sorting keeps the sample order
            # (and therefore any seeded split) stable across machines.
            for file_name in sorted(os.listdir(class_path)):
                if file_name.endswith(".npy"):
                    self.data.append(os.path.join(class_path, file_name))
                    self.labels.append(class_idx)

    def __len__(self):
        """Return the number of ``.npy`` samples found."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample ``idx``.

        Only the first element along axis 0 of the stored array is used. It
        must be 2-D (H, W); it is converted to float32 and replicated to a
        ``(3, H, W)`` tensor before ``transform`` is applied.
        """
        file_path = self.data[idx]
        # NOTE(security): allow_pickle=True lets np.load execute pickled
        # objects; only use this dataset on trusted files.
        array = np.load(file_path, allow_pickle=True)
        first_obj = array[0]  # first object stored in the file
        # Add a channel axis, then copy it three times (grayscale -> 3 channels).
        first_obj = torch.tensor(first_obj, dtype=torch.float32).unsqueeze(0)
        first_obj = first_obj.repeat(3, 1, 1)

        if self.transform:
            first_obj = self.transform(first_obj)

        label = self.labels[idx]
        return first_obj, label

def create_dataloaders(folder_path, batch_size=32, shuffle=True, val_split=0.1,
                       num_workers=2, seed=None):
    """Build training and validation DataLoaders from a folder of ``.npy`` files.

    The folder is read with ``NpyDataset``, every sample is resized to
    224x224, and the dataset is split randomly into a training part and a
    validation part.

    Args:
        folder_path: Root directory passed to ``NpyDataset``.
        batch_size: Batch size used by both loaders.
        shuffle: Whether the training loader reshuffles every epoch. The
            validation loader is never shuffled.
        val_split: Fraction of the samples placed in the validation set
            (rounded down to a whole number of samples).
        num_workers: Number of worker processes used by each loader.
        seed: Optional integer seed for the train/validation split. With the
            default ``None`` the split uses PyTorch's global RNG, so it
            differs from run to run unless that RNG is seeded by the caller.

    Returns:
        ``(train_loader, val_loader)``.
    """
    transform = transforms.Compose([
        transforms.Resize((224, 224)),  # Resize to (224, 224)
    ])

    dataset = NpyDataset(folder_path, transform=transform)

    val_size = int(len(dataset) * val_split)
    train_size = len(dataset) - val_size

    # Only pass a generator when a seed was requested, so the default call
    # keeps using the global RNG.
    split_kwargs = {}
    if seed is not None:
        split_kwargs["generator"] = torch.Generator().manual_seed(seed)
    train_dataset, val_dataset = random_split(
        dataset, [train_size, val_size], **split_kwargs
    )

    train_loader = DataLoader(train_dataset, batch_size=batch_size,
                              shuffle=shuffle, num_workers=num_workers)
    val_loader = DataLoader(val_dataset, batch_size=batch_size,
                            shuffle=False, num_workers=num_workers)

    return train_loader, val_loader

### Evaluation Metrics: Accuracy and ROC-AUC  
Defines an `accuracy` function to compute classification accuracy and a `compute_roc_auc` function to calculate the ROC curve and AUC scores for a multi-class classification problem using a one-vs-rest approach.


In [None]:
def accuracy(outputs, labels):
    """Return the fraction of samples whose highest logit matches its label.

    Args:
        outputs: Object exposing a ``logits`` tensor; the arg-max is taken
            over dimension 1.
        labels: Tensor of target class indices, one per sample.

    Returns:
        A Python float in ``[0, 1]``.
    """
    predicted = outputs.logits.argmax(dim=1)
    n_correct = (predicted == labels).sum().item()
    return n_correct / len(labels)

# Function to compute ROC curve and AUC for each class
def compute_roc_auc(all_labels, all_logits):
    """Compute one-vs-rest ROC curves and AUC values for every class.

    Args:
        all_labels: Tensor of integer class indices, one per sample.
        all_logits: Tensor of per-class scores; indexed as
            ``all_logits[:, class]``, so it must be 2-D with one column per
            class.

    Returns:
        ``(fpr, tpr, roc_auc)``: three dicts keyed by the class index
        ``0..n_classes-1`` plus the key ``"micro"`` for the micro-average.
    """
    # convert to numpy array
    all_labels = all_labels.cpu().numpy()
    all_logits = all_logits.cpu().numpy()

    # Take the class count from the scores instead of hard-coding it, so the
    # one-hot matrix below always has the same size as the logits.
    n_classes = all_logits.shape[1]

    fpr = {}
    tpr = {}
    roc_auc = {}

    # for each class compute the ROC curve using one-vs-rest approach
    for i in range(n_classes):
        fpr[i], tpr[i], _ = roc_curve(all_labels == i, all_logits[:, i])
        roc_auc[i] = auc(fpr[i], tpr[i])

    # Micro-average: flatten the one-hot labels and the scores so that every
    # (sample, class) pair counts as one binary decision.
    fpr["micro"], tpr["micro"], _ = roc_curve(
        np.eye(n_classes)[all_labels].ravel(), all_logits.ravel()
    )
    roc_auc["micro"] = auc(fpr["micro"], tpr["micro"])

    return fpr, tpr, roc_auc

### Patch Embedding Layer for Vision Transformer  
Implements `PatchEmbed`, a PyTorch module that converts an input image into a sequence of patch embeddings using a convolutional layer. This is a key component of Vision Transformers (ViTs), where images are divided into patches and projected into an embedding space.


In [None]:
class PatchEmbed(nn.Module):
    """Turn an image batch into a sequence of patch embeddings.

    A convolution whose kernel size equals its stride projects every
    non-overlapping ``patch_size`` x ``patch_size`` patch to ``embed_dim``
    features.

    Args:
        img_size: Side length of the (square) input image.
        patch_size: Side length of each square patch.
        in_chans: Number of input channels.
        embed_dim: Embedding size produced for each patch.
    """

    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
        super().__init__()
        patches_per_side = img_size // patch_size
        self.img_size = (img_size, img_size)
        self.patch_size = (patch_size, patch_size)
        self.num_patches = patches_per_side * patches_per_side

        self.proj = nn.Conv2d(
            in_chans, embed_dim, kernel_size=patch_size, stride=patch_size
        )

    def forward(self, x):
        """Map ``(B, C, H, W)`` images to ``(B, num_patches, embed_dim)``."""
        # The unpacking only checks that the input is 4-D; the individual
        # sizes are not used.
        _batch, _chans, _height, _width = x.shape
        feature_map = self.proj(x)
        # Merge the two spatial axes, then move the patch axis in front of
        # the embedding axis.
        return feature_map.flatten(2).transpose(1, 2)

### Transformer Encoder for Vision Transformer  
Defines `TransformerEncoder`, a stack of transformer blocks used in Vision Transformers (ViTs). It consists of multiple attention layers with normalization and dropout, processing image patch embeddings sequentially.


In [None]:
class TransformerEncoder(nn.Module):
    """Stack of ``depth`` ``TransformerBlock`` layers followed by a LayerNorm.

    Args:
        embed_dim: Token embedding size.
        depth: Number of blocks.
        num_heads: Attention heads per block.
        mlp_ratio: Ratio of each block's MLP hidden size to ``embed_dim``.
        qkv_bias: Whether the query/key/value projection has a bias.
        drop_rate: Dropout rate forwarded to each block as ``drop``.
        attn_drop_rate: Dropout rate forwarded to each block as ``attn_drop``.
        drop_path_rate: Upper bound of the per-block ``drop_path`` value;
            block ``i`` receives ``drop_path_rate * i / depth``.
    """

    def __init__(self,
                embed_dim = 768,
                depth = 12,
                num_heads = 12,
                mlp_ratio = 4,
                qkv_bias = True,
                drop_rate = 0.,
                attn_drop_rate = 0.,
                drop_path_rate = 0.):
        super().__init__()
        layers = []
        for layer_idx in range(depth):
            # Stochastic-depth schedule: the value grows linearly with the
            # block index, starting at 0 for the first block.
            layer_drop_path = drop_path_rate * layer_idx / depth
            layers.append(
                TransformerBlock(
                    dim=embed_dim,
                    num_heads=num_heads,
                    mlp_ratio=mlp_ratio,
                    qkv_bias=qkv_bias,
                    drop=drop_rate,
                    attn_drop=attn_drop_rate,
                    drop_path=layer_drop_path,
                )
            )
        self.blocks = nn.ModuleList(layers)
        self.norm = nn.LayerNorm(embed_dim)  # final normalisation layer

    def forward(self, x):
        """Apply every block in order, then the final LayerNorm."""
        hidden = x
        for layer in self.blocks:
            hidden = layer(hidden)
        return self.norm(hidden)

### Transformer Block for Vision Transformer  
Defines `TransformerBlock`, a core component of the Vision Transformer (ViT). Each block consists of layer normalization, multi-head self-attention, and a feed-forward MLP with residual connections for efficient feature learning.


In [None]:
class TransformerBlock(nn.Module):
    """Pre-norm transformer block: self-attention and an MLP, each residual.

    Args:
        dim: Token embedding size.
        num_heads: Number of attention heads.
        mlp_ratio: Ratio of the MLP hidden size to ``dim``.
        qkv_bias: Whether the query/key/value projection has a bias.
        drop: Dropout rate for the attention output projection and the MLP.
        attn_drop: Dropout rate applied to the attention weights.
        drop_path: Stochastic-depth probability. During training each sample
            has its residual branch dropped with this probability, and the
            surviving branches are rescaled by ``1 / (1 - drop_path)``. It
            has no effect in eval mode or when it is 0.
    """

    def __init__(self,
                dim,
                num_heads,
                mlp_ratio = 4.,
                qkv_bias = False,
                drop = 0.,
                attn_drop = 0.,
                drop_path = 0.):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)  # normalisation before attention
        self.attn = Attention(dim,
                              num_heads = num_heads,
                              qkv_bias = qkv_bias,
                              attn_drop = attn_drop,
                              proj_drop = drop)  # multi-head self-attention

        self.norm2 = nn.LayerNorm(dim)  # normalisation before the MLP
        mlp_hidden_dim = int(dim * mlp_ratio)  # hidden size of the MLP
        self.mlp = Mlp(in_features = dim,
                       hidden_features = mlp_hidden_dim,
                       drop = drop)
        # Stored as a plain float (no parameters or buffers), so the module's
        # state_dict keys are unchanged.
        self.drop_path_prob = float(drop_path)

    def _drop_path(self, branch):
        """Randomly zero whole samples of a residual branch (stochastic depth)."""
        if self.drop_path_prob == 0. or not self.training:
            return branch
        keep_prob = 1.0 - self.drop_path_prob
        # One Bernoulli draw per sample, broadcast over all remaining axes.
        mask_shape = (branch.shape[0],) + (1,) * (branch.ndim - 1)
        keep_mask = branch.new_empty(mask_shape).bernoulli_(keep_prob)
        if keep_prob > 0.0:
            # Rescale so the expected value of the branch is unchanged.
            keep_mask.div_(keep_prob)
        return branch * keep_mask

    def forward(self, x):
        """Return ``x`` after the attention and MLP residual updates."""
        x = x + self._drop_path(self.attn(self.norm1(x)))  # attention + residual
        x = x + self._drop_path(self.mlp(self.norm2(x)))   # MLP + residual
        return x

### Multi-Head Self-Attention Mechanism  
Implements `Attention`, a multi-head self-attention module used in Vision Transformers (ViTs). It computes relationships between tokens using scaled dot-product attention, applies dropout, and projects the output back to the embedding dimension.


In [None]:
class Attention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    Args:
        dim: Token embedding size; must be divisible by ``num_heads``.
        num_heads: Number of attention heads.
        qkv_bias: Whether the joint query/key/value projection has a bias.
        attn_drop: Dropout rate applied to the attention weights.
        proj_drop: Dropout rate applied after the output projection.

    Raises:
        ValueError: If ``num_heads`` is not positive or does not divide ``dim``.
    """

    def __init__(self,
                dim,
                num_heads = 8,
                qkv_bias = False,
                attn_drop = 0.,
                proj_drop = 0.):
        super().__init__()
        # Fail at construction time: with a non-divisible dim the reshape in
        # forward() would only fail later with an opaque size-mismatch error.
        if num_heads <= 0 or dim % num_heads != 0:
            raise ValueError(
                f"dim ({dim}) must be divisible by num_heads ({num_heads})"
            )
        self.num_heads = num_heads  # number of attention heads
        head_dim = dim // num_heads  # feature size of each head
        self.scale = head_dim ** -0.5  # scaling factor for the dot product

        # One linear layer produces Q, K and V for all heads at once.
        self.qkv = nn.Linear(dim, dim*3, bias = qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)  # dropout on the attention matrix
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)  # dropout on the output projection

    def forward(self, x):
        """Apply self-attention to ``x`` of shape ``(B, N, C)``; same shape out."""
        B, N, C = x.shape
        # (B, N, 3*C) -> (3, B, H, N, C/H): the leading axis selects Q, K or V.
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C//self.num_heads).permute(2,0,3,1,4)
        q, k, v = qkv[0], qkv[1], qkv[2]  # each (B, H, N, C/H)

        # Scaled dot-product attention: query/key similarity, normalised over keys.
        attn = (q@k.transpose(-2, -1))*self.scale  # (B, H, N, N)
        attn = attn.softmax(dim = -1)
        attn = self.attn_drop(attn)

        # Weight the values, then merge the heads back into one embedding axis.
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        # Project back to the embedding dimension.
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

### Multi-Layer Perceptron (MLP) for Vision Transformer  
Defines `Mlp`, a feed-forward neural network with two linear layers, GELU activation, and dropout. It is used in transformer blocks to process token embeddings after the attention mechanism.


In [None]:
class Mlp(nn.Module):
    """Feed-forward network: Linear -> GELU -> Dropout -> Linear -> Dropout.

    Args:
        in_features: Size of the last input dimension.
        hidden_features: Hidden layer size; falls back to ``in_features``
            when omitted (or falsy).
        out_features: Output size; falls back to ``in_features`` when
            omitted (or falsy).
        drop: Dropout rate used after the activation and after the output layer.
    """

    def __init__(self,
                in_features,
                hidden_features = None,
                out_features = None,
                drop = 0.):
        super().__init__()
        if not out_features:
            out_features = in_features
        if not hidden_features:
            hidden_features = in_features

        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)  # one module, used at both dropout sites

    def forward(self, x):
        """Map ``B x N x in_features`` to ``B x N x out_features``."""
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))

### Transformer Decoder for Vision Transformer  
Implements `TransformerDecoder`, a stack of transformer blocks designed to process encoded representations. It consists of multiple self-attention layers, feed-forward MLPs, and layer normalization to refine feature embeddings.


In [None]:
class TransformerDecoder(nn.Module):
    """Stack of ``depth`` ``TransformerBlock`` layers followed by a LayerNorm.

    Args:
        embed_dim: Token embedding size.
        depth: Number of blocks.
        num_heads: Attention heads per block.
        mlp_ratio: Ratio of each block's MLP hidden size to ``embed_dim``.
        qkv_bias: Whether the query/key/value projection has a bias.
        drop_rate: Dropout rate forwarded to each block as ``drop``.
        attn_drop_rate: Dropout rate forwarded to each block as ``attn_drop``.
        drop_path_rate: Upper bound of the per-block ``drop_path`` value;
            block ``i`` receives ``drop_path_rate * i / depth``.
    """

    def __init__(self,
                embed_dim = 768,
                depth = 8,
                num_heads = 16,
                mlp_ratio = 4.,
                qkv_bias = True,
                drop_rate = 0.,
                attn_drop_rate = 0.,
                drop_path_rate = 0.):
        super().__init__()
        shared_kwargs = dict(
            dim=embed_dim,
            num_heads=num_heads,
            mlp_ratio=mlp_ratio,
            qkv_bias=qkv_bias,
            drop=drop_rate,
            attn_drop=attn_drop_rate,
        )
        layers = []
        for layer_idx in range(depth):
            # The drop_path value grows linearly with the block index.
            layers.append(
                TransformerBlock(
                    drop_path=drop_path_rate * layer_idx / depth,
                    **shared_kwargs,
                )
            )
        self.blocks = nn.ModuleList(layers)
        self.norm = nn.LayerNorm(embed_dim)

    def forward(self, x):
        """Apply every block in order, then the final LayerNorm."""
        hidden = x
        for layer in self.blocks:
            hidden = layer(hidden)
        return self.norm(hidden)
        

### Masked Autoencoder (MAE) Implementation
This class implements a Masked Autoencoder (MAE) using a Vision Transformer (ViT) backbone. It randomly masks image patches, encodes visible patches with a transformer encoder, and reconstructs the full image using a transformer decoder. The model is trained using an MSE loss calculated only on masked patches.


In [None]:
class MaskedAutoEncoder(nn.Module):
    """Masked autoencoder (MAE) built from the transformer modules in this file.

    The image is split into patches, a random subset of the patches is
    removed, the remaining patches are encoded, and a decoder predicts the
    pixel values of every patch. The loss is the mean squared error over the
    removed patches only.

    Args:
        img_size: Side length of the (square) input image.
        patch_size: Side length of each square patch.
        in_chans: Number of image channels.
        embed_dim: Encoder token size.
        depth: Number of encoder blocks.
        num_heads: Attention heads per encoder block.
        decoder_embed_dim: Decoder token size.
        decoder_depth: Number of decoder blocks.
        decoder_num_heads: Attention heads per decoder block.
        mlp_ratio: MLP hidden-size ratio used by encoder and decoder blocks.
        norm_layer: Accepted for interface compatibility; not used by this
            implementation.
    """

    def __init__(self,
                img_size = 224,
                patch_size = 16,
                in_chans = 3,
                embed_dim = 1024,
                depth = 24,
                num_heads = 16,
                decoder_embed_dim = 512,
                decoder_depth = 8,
                decoder_num_heads = 16,
                mlp_ratio = 4.,
                norm_layer = nn.LayerNorm):
        super().__init__()
        # Encoder components
        # PatchEmbed splits the image into patches and embeds them
        self.patch_embed = PatchEmbed(img_size, patch_size, in_chans, embed_dim)
        self.num_patches = self.patch_embed.num_patches

        # Class token and positional embedding for the encoder (+1 slot for
        # the class token)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, self.num_patches + 1, embed_dim))

        # Encoder for the visible patches
        self.encoder = TransformerEncoder(
            embed_dim = embed_dim,
            depth = depth,
            num_heads = num_heads,
            mlp_ratio = mlp_ratio
        )

        # Decoder components
        # Convert the encoder output to the decoder dimension
        self.decoder_embed = nn.Linear(embed_dim, decoder_embed_dim, bias = True)
        # Learnable token that stands in for every masked patch
        self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))
        # Positional embedding for the decoder
        self.decoder_pos_embed = nn.Parameter(torch.zeros(1, self.num_patches + 1, decoder_embed_dim))

        # Decoder to reconstruct the full image
        self.decoder = TransformerDecoder(
            embed_dim = decoder_embed_dim,
            depth = decoder_depth,
            num_heads = decoder_num_heads,
            mlp_ratio = mlp_ratio
            )
        # Final prediction layer: predicts the pixel values of each patch
        self.decoder_pred = nn.Linear(decoder_embed_dim, patch_size**2 * in_chans, bias = True)

        # Initialize weights for all components
        self.initialize_weights()

        # Store model parameters for later use
        self.patch_size = patch_size
        self.in_chans = in_chans
        self.img_size = img_size

    def initialize_weights(self):
        """Initialise embeddings and tokens, then every Linear/LayerNorm."""
        # Truncated normal for the positional embeddings and the two tokens
        nn.init.trunc_normal_(self.pos_embed, std = 0.02)
        nn.init.trunc_normal_(self.decoder_pos_embed, std = 0.02)

        nn.init.trunc_normal_(self.cls_token, std = 0.02)
        nn.init.trunc_normal_(self.mask_token, std = 0.02)

        # Recursively visit every sub-module
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Per-module initialiser used by ``initialize_weights``."""
        if isinstance(m, nn.Linear):
            nn.init.trunc_normal_(m.weight, std = 0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def random_masking(self, x, mask_ratio):
        """Keep a random subset of the tokens of every sample.

        Args:
            x: Tokens of shape ``(N, L, D)``.
            mask_ratio: Fraction of the ``L`` tokens to remove.

        Returns:
            ``(x_masked, mask, ids_restore)``: the kept tokens
            ``(N, len_keep, D)``; a ``(N, L)`` mask in the original token
            order with 0 for kept and 1 for removed tokens; and the indices
            that undo the random shuffle.
        """
        N, L, D = x.shape # Batch, length, dimension
        len_keep = int(L * (1 - mask_ratio)) # Number of patches to keep

        # Generate uniform random noise for each patch in each sample
        noise = torch.rand(N, L, device = x.device)

        # Sorting the noise yields a random permutation per sample;
        # ids_restore is the inverse permutation.
        ids_shuffle = torch.argsort(noise, dim=1)
        ids_restore = torch.argsort(ids_shuffle, dim=1)

        # Keep the first len_keep patches (lowest noise values)
        ids_keep = ids_shuffle[:, :len_keep]
        x_masked = torch.gather(x, dim=1, index = ids_keep.unsqueeze(-1).repeat(1, 1, D))

        # Build the mask in shuffled order (kept tokens first) ...
        mask = torch.ones([N, L], device = x.device)
        mask[:, :len_keep] = 0
        # ... then unshuffle to get the binary mask for the original sequence
        mask = torch.gather(mask, dim = 1, index = ids_restore)

        return x_masked, mask, ids_restore

    def forward_encoder(self, x, mask_ratio):
        """Embed the image, drop masked patches and encode the rest.

        Returns:
            ``(latent, mask, ids_restore)`` where ``latent`` holds the class
            token followed by the encoded visible patches.
        """
        # Convert img to patches
        x = self.patch_embed(x)

        # Add positional embeddings (slot 0 belongs to the class token)
        cls_token = self.cls_token + self.pos_embed[:, :1, :]  # [1, 1, D]
        x = x + self.pos_embed[:, 1:, :]  # [B, N, D]

        # Apply random masking
        x, mask, ids_restore = self.random_masking(x, mask_ratio)  # [B, N', D]

        # Expand class token to match batch size
        cls_tokens = cls_token.expand(x.shape[0], -1, -1)  # [B, 1, D]

        # Concatenate cls_token and image tokens
        x = torch.cat((cls_tokens, x), dim=1)  # [B, N'+1, D]

        # Process through transformer encoder
        x = self.encoder(x)

        return x, mask, ids_restore


    def forward_decoder(self, x, ids_restore):
        """Predict the pixel values of every patch from the encoder output.

        Args:
            x: Encoder output, class token first.
            ids_restore: Indices returned by ``random_masking``.

        Returns:
            Predictions of shape ``(B, L, patch_size**2 * in_chans)``.
        """
        # embed the encoder output
        x = self.decoder_embed(x)

        # One mask token per removed patch: the full length (L + 1 with the
        # class token) minus the number of tokens the encoder produced.
        mask_tokens = self.mask_token.repeat(x.shape[0], ids_restore.shape[1] + 1 - x.shape[1], 1)

        # exclude class token x[:, 1:] and append mask tokens
        x_ = torch.cat([x[:, 1:, :], mask_tokens], dim = 1)

        # unshuffle: restore the original sequence order
        x_ = torch.gather(x_, dim = 1, index = ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2]))

        # prepend the class token again
        x = torch.cat([x[:, :1, :], x_], dim = 1)

        # apply positional embedding
        x = x + self.decoder_pos_embed

        # apply transformer decoder
        x = self.decoder(x)

        # predict pixel values for each patch
        x = self.decoder_pred(x)

        # remove class token from predictions
        x = x[:, 1:, :]

        return x

    def forward(self, imgs, mask_ratio = 0.75):
        """Run the full MAE and return ``(loss, pred, mask)``."""
        # run encoder on the images with masking
        latent, mask, ids_restore = self.forward_encoder(imgs, mask_ratio)

        # run decoder to predict all patches
        pred = self.forward_decoder(latent, ids_restore)

        # convert input images to patches for loss calculation
        target = self.patchify(imgs)

        # calculate mse loss only for masked patches
        loss = self.calculate_loss(pred, target, mask)

        return loss, pred, mask

    def patchify(self, imgs):
        """Rearrange ``(B, C, H, W)`` images into ``(B, L, p*p*C)`` patches."""
        p = self.patch_size

        x = rearrange(imgs, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = p, p2 = p)
        return x

    def unpatchify(self, x):
        """Inverse of ``patchify``; assumes a square grid of patches."""
        p = self.patch_size
        h = w = int(x.shape[1] ** 0.5)

        imgs = rearrange(x, 'b (h w) (p1 p2 c) -> b c (h p1) (w p2)', h = h, w = w, p1 = p, p2 = p)
        return imgs

    def calculate_loss(self, pred, target, mask):
        """Mean squared error averaged over the masked patches.

        Args:
            pred: Predicted patches, ``(B, L, p*p*C)``.
            target: Ground-truth patches, same shape as ``pred``.
            mask: ``(B, L)`` tensor, 1 for masked patches and 0 for kept ones.

        Returns:
            Scalar loss tensor. If no patch is masked the result is 0 instead
            of NaN.
        """
        loss = (pred - target)**2
        loss = loss.mean(dim = -1)  # per-patch mean squared error

        # Guard the 0/0 case (e.g. mask_ratio = 0); any non-zero sum is used
        # unchanged, so the usual case is numerically identical.
        masked_count = mask.sum()
        masked_count = torch.where(masked_count > 0, masked_count, torch.ones_like(masked_count))
        loss = (loss*mask).sum()/masked_count

        return loss

In [None]:
class MAEClassifier(nn.Module):
    """Image classifier that reuses the encoder of a ``MaskedAutoEncoder``.

    The patch embedding, class token, positional embedding and encoder are
    taken from ``mae_model`` (shared, not copied). A LayerNorm + Linear head
    on the class token produces the logits. No masking is applied.

    Args:
        mae_model: Object providing ``patch_embed``, ``cls_token``,
            ``pos_embed`` and ``encoder``.
        num_classes: Number of output logits.
        freeze_encoder: With the default ``False`` the encoder is fine-tuned
            (its parameters are explicitly marked trainable). With ``True``
            the encoder, class token and positional embedding are frozen so
            that only the classification head is trained. The patch
            embedding is frozen in both cases.
    """

    def __init__(self, mae_model, num_classes, freeze_encoder=False):
        super().__init__()

        # Use the pre-trained MAE encoder
        self.patch_embed = mae_model.patch_embed
        self.cls_token = mae_model.cls_token
        self.pos_embed = mae_model.pos_embed
        self.encoder = mae_model.encoder

        # Determine the embedding dimension dynamically
        embed_dim = mae_model.encoder.blocks[0].norm1.weight.shape[0]

        # The patch embedding is always frozen.
        for param in self.patch_embed.parameters():
            param.requires_grad = False
        # The encoder is trainable unless the caller asks to freeze it.
        for param in self.encoder.parameters():
            param.requires_grad = not freeze_encoder
        if freeze_encoder:
            self.cls_token.requires_grad_(False)
            self.pos_embed.requires_grad_(False)

        # Classification head
        self.classification_head = nn.Sequential(
            nn.LayerNorm(embed_dim),
            nn.Linear(embed_dim, num_classes)
        )

    def forward(self, x):
        """Return class logits of shape ``(B, num_classes)``."""
        # Patch embedding
        x = self.patch_embed(x)

        # Add positional embeddings (slot 0 belongs to the class token)
        cls_token = self.cls_token + self.pos_embed[:, :1, :]
        x = x + self.pos_embed[:, 1:, :]

        # Expand class token to match batch size
        cls_tokens = cls_token.expand(x.shape[0], -1, -1)

        # Concatenate cls_token and image tokens
        x = torch.cat((cls_tokens, x), dim=1)

        # Encode
        x = self.encoder(x)

        # Use cls token for classification
        x = x[:, 0]

        # Classification
        return self.classification_head(x)

### Training Function for MAE Classifier
This function fine-tunes the MAE-based classifier with cross-entropy loss and the AdamW optimizer, using a cosine-annealing learning-rate schedule. Mixed-precision (AMP) training is currently disabled, so the forward and backward passes run in full precision. After every epoch it reports the loss, the accuracy, and the per-class (one-vs-rest) AUC-ROC scores on both the training and validation sets.


In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.cuda.amp import GradScaler, autocast
from sklearn.metrics import roc_auc_score
import numpy as np

def _per_class_auc_roc(true_labels, pred_probs):
    """Return the one-vs-rest AUC-ROC of every class (NaN where undefined)."""
    labels = np.asarray(true_labels)
    probs = np.asarray(pred_probs)
    if probs.ndim != 2:
        # Nothing was collected (empty loader): there are no classes to score.
        return []

    scores = []
    # The class count comes from the model output instead of being hard-coded.
    for cls in range(probs.shape[1]):
        binary_labels = (labels == cls).astype(int)
        try:
            scores.append(roc_auc_score(binary_labels, probs[:, cls]))
        except ValueError:
            # AUC is undefined when the split holds only positives or only
            # negatives for this class; report NaN instead of aborting training.
            scores.append(float("nan"))
    return scores


def _run_epoch(model, loader, criterion, device, optimizer=None):
    """Run one pass over ``loader``; update weights when ``optimizer`` is given.

    Returns ``(mean_batch_loss, accuracy_percent, true_labels, pred_probs)``.
    """
    training = optimizer is not None
    model.train(training)

    total_loss = 0.0
    correct = 0
    seen = 0
    true_labels = []
    pred_probs = []

    with torch.set_grad_enabled(training):
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)

            if training:
                optimizer.zero_grad()

            outputs = model(images)
            loss = criterion(outputs, labels)

            if training:
                loss.backward()
                optimizer.step()

            total_loss += loss.item()
            correct += outputs.argmax(dim=1).eq(labels).sum().item()
            seen += labels.size(0)

            # Collect data for AUC-ROC
            true_labels.extend(labels.cpu().numpy())
            pred_probs.extend(torch.softmax(outputs.detach(), dim=1).cpu().numpy())

    # max(..., 1) avoids a ZeroDivisionError on an empty loader.
    mean_loss = total_loss / max(len(loader), 1)
    accuracy_pct = 100 * correct / max(seen, 1)
    return mean_loss, accuracy_pct, true_labels, pred_probs


def train_classifier(model, train_loader, val_loader, num_epochs=5, learning_rate=5e-4):
    """Train ``model`` as a classifier and print per-epoch metrics.

    Uses cross-entropy loss, AdamW (weight decay 0.05) and a cosine-annealing
    learning-rate schedule stepped once per epoch. Training runs in full
    precision on CUDA when available, otherwise on the CPU. After every epoch
    the loss, accuracy and per-class one-vs-rest AUC-ROC are printed for the
    training and the validation set; an AUC that cannot be computed is
    printed as ``nan``.

    Args:
        model: Module mapping an image batch to class logits.
        train_loader: Iterable of ``(images, labels)`` training batches.
        val_loader: Iterable of ``(images, labels)`` validation batches.
        num_epochs: Number of epochs, also used as the scheduler's ``T_max``.
        learning_rate: Initial learning rate.

    Returns:
        The trained ``model`` (moved to the training device).
    """
    # Loss and Optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.AdamW(
        model.parameters(),
        lr=learning_rate,
        weight_decay=0.05
    )

    # Learning Rate Scheduler
    scheduler = optim.lr_scheduler.CosineAnnealingLR(
        optimizer,
        T_max=num_epochs
    )

    # Device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)

    # Training Loop
    for epoch in range(num_epochs):
        train_loss, train_acc, train_true_labels, train_pred_probs = _run_epoch(
            model, train_loader, criterion, device, optimizer=optimizer
        )
        val_loss, val_acc, val_true_labels, val_pred_probs = _run_epoch(
            model, val_loader, criterion, device
        )

        train_auc_roc = _per_class_auc_roc(train_true_labels, train_pred_probs)
        val_auc_roc = _per_class_auc_roc(val_true_labels, val_pred_probs)

        # Learning rate step
        scheduler.step()

        # Print metrics
        print(f'Epoch [{epoch+1}/{num_epochs}]')
        print(f'Train Loss: {train_loss:.4f}, '
              f'Train Accuracy: {train_acc:.2f}%')
        print(f'Train AUC-ROC per class: {[f"{score:.4f}" for score in train_auc_roc]}')

        print(f'Val Loss: {val_loss:.4f}, '
              f'Val Accuracy: {val_acc:.2f}%')
        print(f'Val AUC-ROC per class: {[f"{score:.4f}" for score in val_auc_roc]}')

    return model

In [None]:
def load_mae_weights(model, checkpoint_path):
    """Load a saved state dict from ``checkpoint_path`` into ``model``.

    The file is expected to contain the state dict itself, because the loaded
    object is passed directly to ``model.load_state_dict``.

    Args:
        model: Module whose parameters are overwritten in place.
        checkpoint_path: Path of the checkpoint file.

    Returns:
        The object loaded from the checkpoint file.
    """
    # Tensors are mapped to the CPU so loading works without a GPU.
    # NOTE(security): torch.load unpickles the file; only load checkpoints
    # from a trusted source.
    state = torch.load(checkpoint_path, map_location='cpu')

    model.load_state_dict(state)
    print("Successfully loaded weights.")

    return state

In [None]:
# Root folder of the dataset; it is handed to create_dataloaders, which
# builds an NpyDataset from it.
path = "/kaggle/input/dataset-task-6/Dataset"
# Training and validation loaders with batches of 32; the other settings use
# create_dataloaders' defaults (10% validation split, shuffled training data).
train_loader, val_loader = create_dataloaders(path, batch_size=32)

In [None]:
# Usage: build the MAE with its default configuration and load the
# pre-trained weights into it.
checkpoint_path = '/kaggle/input/pretrained-masked-auto-encoder/pytorch/default/1/model.pth'
mae_model = MaskedAutoEncoder()
checkpoint = load_mae_weights(mae_model, checkpoint_path)

# Create the classifier on top of the pre-trained encoder
num_classes = 3
mae_classifier = MAEClassifier(mae_model, num_classes)

# Fine-tune the classifier with train_classifier's default epoch count and
# learning rate
trained_model = train_classifier(mae_classifier, train_loader, val_loader)

In [None]:
# Save only the state dict (parameters and buffers) of the fine-tuned
# classifier, not the whole module object.
model_path = "/kaggle/working/finetuned_mae.pth"
torch.save(trained_model.state_dict(), model_path)