<a href="https://colab.research.google.com/github/raihanewubd/selfSupervised/blob/main/i_jepa_aav_v2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
from google.colab import drive

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import Dataset, DataLoader, random_split
from transformers import ViTForImageClassification, ViTImageProcessor
from torchvision.models import vit_b_16



import timm
from tqdm import tqdm
import matplotlib.pyplot as plt
import numpy as np

from collections import Counter


import random
import os
import copy
import time
import pickle



In [2]:
#drive.mount('/content/drive')

In [3]:
# Define base directory and file name for saving the classifier checkpoint.
base_dir = "/kaggle/working"
#base_dir = "/content/drive/MyDrive/AAVDATASET/spectrogram"

In [4]:
data_dir = '/kaggle/input/aav-spectrogram/spectrogram'
#data_dir = os.path.join(base_dir,'dataset')
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5])
])

dataset = datasets.ImageFolder(root=data_dir, transform=transform)
dataloader = DataLoader(dataset, batch_size=16, shuffle=True)

In [5]:
# Set device for GPU acceleration if available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cuda


In [6]:
def extract_blocks(image, context_scale=0.85, target_scale=0.2, num_targets=4, max_overlap=0.5):
    # Extract a central context block.
    _, H, W = image.shape
    context_size = int(context_scale * H)
    top = (H - context_size) // 2
    left = (W - context_size) // 2
    context_block = image[:, top:top+context_size, left:left+context_size]
    context_block = torch.nn.functional.interpolate(
        context_block.unsqueeze(0),
        size=(224, 224),
        mode='bilinear',
        align_corners=False
    ).squeeze(0)

    # Extract num_targets target blocks randomly.
    target_blocks = []
    for _ in range(num_targets):
        target_size = int(target_scale * H)
        top_t = random.randint(0, H - target_size)
        left_t = random.randint(0, W - target_size)
        target_block = image[:, top_t:top_t+target_size, left_t:left_t+target_size]
        target_block = torch.nn.functional.interpolate(
            target_block.unsqueeze(0),
            size=(224, 224),
            mode='bilinear',
            align_corners=False
        ).squeeze(0)
        target_blocks.append(target_block)
    target_blocks = torch.stack(target_blocks)
    return context_block, target_blocks, (top, left, context_size), None



In [7]:
def process_sample(sample, context_scale, target_scale, num_targets):
    # Unpack sample: sample is ((img, label), image_path)
    (img, label), image_path = sample
    # Move image to GPU if available.
    img = img.to(device)
    context_block, target_blocks, _, _ = extract_blocks(img, context_scale, target_scale, num_targets)
    # Bring results back to CPU before caching.
    return (context_block.cpu(), target_blocks.cpu(), label, image_path)

In [8]:
class PrecomputedIJEPADataset(Dataset):
    def __init__(self, base_dataset, context_scale=0.85, target_scale=0.2, num_targets=4, cache_file=None):
        self.cache_file = cache_file
        if cache_file and os.path.exists(cache_file):
            # Load precomputed data from disk.
            with open(cache_file, 'rb') as f:
                self.data = pickle.load(f)
        else:
            # Create a list of samples along with their original image paths (if available) using a progress bar.
            if hasattr(base_dataset, 'samples'):
                base_samples = [
                    (base_dataset[i], base_dataset.samples[i][0])
                    for i in tqdm(range(len(base_dataset)), desc="Loading samples")
                ]
            else:
                base_samples = [
                    (sample, None) for sample in tqdm(base_dataset, desc="Loading samples")
                ]

            # Process samples sequentially with a progress bar.
            self.data = []
            for sample in tqdm(base_samples, desc="Processing samples"):
                result = process_sample(sample, context_scale, target_scale, num_targets)
                self.data.append(result)

            if cache_file:
                with open(cache_file, 'wb') as f:
                    pickle.dump(self.data, f)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]


In [9]:


# Timing the loading of the dataset and DataLoader
cache_path = os.path.join(base_dir,"precomputed_fulldataset_aav.pkl")
print(cache_path)
start_time = time.time()
dataset_aav_ijepa = PrecomputedIJEPADataset(dataset, cache_file=cache_path)
end_time_train_ijepa_dataset = time.time()
dataloader_aav_ijepa = DataLoader(dataset_aav_ijepa, batch_size=32, shuffle=True)
end_time = time.time()
print(f"Time taken to load dataset: {end_time_train_ijepa_dataset - start_time:.4f} seconds and DataLoader: {end_time - end_time_train_ijepa_dataset:.4f} seconds")


/kaggle/working/precomputed_fulldataset_aav.pkl


Loading samples: 100%|██████████| 3513/3513 [00:28<00:00, 125.14it/s]
Processing samples: 100%|██████████| 3513/3513 [00:12<00:00, 287.50it/s]


Time taken to load dataset: 77.0128 seconds and DataLoader: 0.0011 seconds


In [10]:
num_images = len(dataset_aav_ijepa)
print(f"Number of images in the dataset: {num_images}")

Number of images in the dataset: 3513


In [11]:
total_batches = len(dataloader_aav_ijepa)
print("Total number of batches:", total_batches)

Total number of batches: 110


In [12]:
'''import torch
from torch.utils.data import Subset
from sklearn.model_selection import train_test_split


# Define the split ratio (e.g., 80% train, 20% test)
train_ratio = 0.8
test_ratio = 1 - train_ratio

# Get the total number of samples in the dataset.
num_samples = len(dataset_aav_ijepa)

# Create a list of indices for all samples in the dataset.
indices = list(range(num_samples))

# Split the indices into train and test sets using train_test_split.
train_indices, test_indices = train_test_split(indices, test_size=test_ratio, random_state=42)  # Set random_state for reproducibility.

# Create Subset datasets for train and test using the split indices.
train_dataset_aav_ijepa = Subset(dataset_aav_ijepa, train_indices)
test_dataset_aav_ijepa = Subset(dataset_aav_ijepa, test_indices)

# Create DataLoaders for the train and test datasets.
train_loader_aav_ijepa = torch.utils.data.DataLoader(train_dataset_aav_ijepa, batch_size=32, shuffle=True)
test_loader_aav_ijepa = torch.utils.data.DataLoader(test_dataset_aav_ijepa, batch_size=32, shuffle=False)  # No need to shuffle the test set.

print(f"Training set size: {len(train_dataset_aav_ijepa)}")
print(f"Testing set size: {len(test_dataset_aav_ijepa)}")'''

'import torch\nfrom torch.utils.data import Subset\nfrom sklearn.model_selection import train_test_split\n\n\n# Define the split ratio (e.g., 80% train, 20% test)\ntrain_ratio = 0.8\ntest_ratio = 1 - train_ratio\n\n# Get the total number of samples in the dataset.\nnum_samples = len(dataset_aav_ijepa)\n\n# Create a list of indices for all samples in the dataset.\nindices = list(range(num_samples))\n\n# Split the indices into train and test sets using train_test_split.\ntrain_indices, test_indices = train_test_split(indices, test_size=test_ratio, random_state=42)  # Set random_state for reproducibility.\n\n# Create Subset datasets for train and test using the split indices.\ntrain_dataset_aav_ijepa = Subset(dataset_aav_ijepa, train_indices)\ntest_dataset_aav_ijepa = Subset(dataset_aav_ijepa, test_indices)\n\n# Create DataLoaders for the train and test datasets.\ntrain_loader_aav_ijepa = torch.utils.data.DataLoader(train_dataset_aav_ijepa, batch_size=32, shuffle=True)\ntest_loader_aav_ij

In [13]:
import torch
from torch.utils.data import Subset, random_split
from sklearn.model_selection import train_test_split

# Define the split ratios (e.g., 70% train, 15% validation, 15% test)
train_ratio = 0.7
val_ratio = 0.15
test_ratio = 0.15

# Get the total number of samples in the dataset.
num_samples = len(dataset_aav_ijepa)

# 1. Split into train and (validation + test)
train_indices, val_test_indices = train_test_split(
    list(range(num_samples)),
    test_size=val_ratio + test_ratio,
    random_state=42  # Set random_state for reproducibility
)

# 2. Split (validation + test) into validation and test
val_indices, test_indices = train_test_split(
    val_test_indices,
    test_size=test_ratio / (val_ratio + test_ratio),
    random_state=42  # Set random_state for reproducibility
)

# Create Subset datasets for train, validation, and test
train_dataset_aav_ijepa = Subset(dataset_aav_ijepa, train_indices)
val_dataset_aav_ijepa = Subset(dataset_aav_ijepa, val_indices)
test_dataset_aav_ijepa = Subset(dataset_aav_ijepa, test_indices)

# Create DataLoaders for train, validation, and test
train_loader_aav_ijepa = torch.utils.data.DataLoader(train_dataset_aav_ijepa, batch_size=32, shuffle=True)
val_loader_aav_ijepa = torch.utils.data.DataLoader(val_dataset_aav_ijepa, batch_size=32, shuffle=False)
test_loader_aav_ijepa = torch.utils.data.DataLoader(test_dataset_aav_ijepa, batch_size=32, shuffle=False)

print(f"Training set size: {len(train_dataset_aav_ijepa)}")
print(f"Validation set size: {len(val_dataset_aav_ijepa)}")
print(f"Testing set size: {len(test_dataset_aav_ijepa)}")

Training set size: 2459
Validation set size: 527
Testing set size: 527


In [14]:
def get_vit_encoder():
    model = vit_b_16(pretrained=False)
    model.heads = nn.Identity()  # remove classification head
    return model

In [15]:

context_encoder = get_vit_encoder().cuda()
target_encoder  = get_vit_encoder().cuda()
target_encoder.load_state_dict(context_encoder.state_dict())



<All keys matched successfully>

In [16]:
class Predictor(nn.Module):
    def __init__(self, input_dim=768, hidden_dim=768, output_dim=768, num_targets=4):
        super().__init__()
        self.num_targets = num_targets
        self.mlp = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim * num_targets)
        )
    def forward(self, context_repr):
        pred = self.mlp(context_repr)
        # Reshape to [B, num_targets, output_dim]
        return pred.view(-1, self.num_targets, pred.size(-1) // self.num_targets)

In [17]:
# 6. Set up optimizer, loss, and EMA update (same as your CIFAR code).
predictor = Predictor().cuda()
optimizer = optim.Adam(list(context_encoder.parameters()) + list(predictor.parameters()), lr=1e-1)
criterion = nn.MSELoss()
ema_decay = 0.99

In [18]:
@torch.no_grad()
def update_ema(model, model_ema, beta):
    for param, param_ema in zip(model.parameters(), model_ema.parameters()):
        param_ema.data.mul_(beta).add_(param.data, alpha=1 - beta)

In [19]:
sample = next(iter(train_dataset_aav_ijepa))
for i, item in enumerate(sample):
    print(f"Item {i}: shape = {item.shape if hasattr(item, 'shape') else type(item)}")
    print(f"Item {i}: shape = {item}")

Item 0: shape = torch.Size([3, 224, 224])
Item 0: shape = tensor([[[-0.5137, -0.5137, -0.5137,  ..., -0.5294, -0.5294, -0.5294],
         [-0.5137, -0.5137, -0.5100,  ..., -0.5294, -0.5294, -0.5294],
         [-0.5089, -0.5089, -0.5070,  ..., -0.5197, -0.5197, -0.5197],
         ...,
         [-0.3569, -0.3569, -0.3569,  ..., -0.3155, -0.3155, -0.3155],
         [-0.3926, -0.3926, -0.3926,  ..., -0.3327, -0.3327, -0.3327],
         [-0.5137, -0.5137, -0.5137,  ..., -0.4902, -0.4902, -0.4902]],

        [[ 0.4510,  0.4510,  0.4510,  ...,  0.4118,  0.4118,  0.4118],
         [ 0.4510,  0.4510,  0.4510,  ...,  0.3671,  0.3633,  0.3633],
         [ 0.4558,  0.4558,  0.4558,  ...,  0.4141,  0.4123,  0.4123],
         ...,
         [ 0.5245,  0.5245,  0.5245,  ...,  0.5383,  0.5383,  0.5383],
         [ 0.5116,  0.5116,  0.5116,  ...,  0.5315,  0.5315,  0.5315],
         [ 0.4510,  0.4510,  0.4510,  ...,  0.4588,  0.4588,  0.4588]],

        [[-0.1373, -0.1373, -0.1373,  ..., -0.1059, -0.105

In [20]:
# Create a directory for visualizations if it doesn't exist.
#viz_dir = "/kaggle/working/viz"
#os.makedirs(viz_dir, exist_ok=True)

num_epochs = 150
ema_decay = 0.1
best_loss = float('inf')
total_start_time = time.time()

for epoch in range(num_epochs):
    epoch_start_time = time.time()
    context_encoder.train()
    predictor.train()
    running_loss = 0.0

    # Enumerate over batches with a progress bar.
    for batch_idx, (context_block, target_blocks, class_label, filepath) in enumerate(tqdm(train_loader_aav_ijepa, desc=f"Epoch {epoch+1}/{num_epochs}", leave=False)):
        context_block = context_block.cuda()            # [B, C, 224, 224]
        target_blocks = target_blocks.cuda()              # [B, num_targets, C, 224, 224]

        # Forward pass through context encoder and predictor.
        context_repr = context_encoder(context_block)     # [B, 768]
        preds = predictor(context_repr)                   # [B, num_targets, 768]

        B, num_targets, C, Ht, Wt = target_blocks.shape
        target_blocks_flat = target_blocks.view(B * num_targets, C, Ht, Wt)
        with torch.no_grad():
            target_repr_flat = target_encoder(target_blocks_flat)
        target_repr = target_repr_flat.view(B, num_targets, -1)

        loss = criterion(preds, target_repr)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        update_ema(context_encoder, target_encoder, ema_decay)
        running_loss += loss.item() * context_block.size(0)

        # --- Visualization for first image of the current batch ---
        '''with torch.no_grad():
            # Get the first sample's context block and compute its feature vector.
            context_img = context_block[0].cpu()  # shape: [C, 224, 224]
            context_feat = context_encoder(context_block[0].unsqueeze(0)).cpu().squeeze(0)  # shape: [768]
            # Reshape feature vector to a 2D heatmap (24x32).
            context_heat = context_feat.view(24, 32).numpy()

            # For target, choose the first target block of the first sample.
            target_img = target_blocks[0][0].cpu()  # shape: [C, 224, 224]
            target_feat = target_encoder(target_blocks[0][0].unsqueeze(0).to(context_block.device)).cpu().squeeze(0)
            target_heat = target_feat.view(24, 32).numpy()

            # Plot the images and corresponding heatmaps.
            fig, axs = plt.subplots(2, 2, figsize=(10, 8))

            # Display context block image.
            if context_img.shape[0] == 1:
                axs[0, 0].imshow(context_img.squeeze(), cmap='gray')
            else:
                axs[0, 0].imshow(context_img.permute(1, 2, 0))
            axs[0, 0].set_title("Context Block")
            axs[0, 0].axis("off")

            # Display context feature heatmap.
            im0 = axs[0, 1].imshow(context_heat, cmap="viridis")
            axs[0, 1].set_title("Context Feature Heatmap")
            axs[0, 1].axis("off")
            fig.colorbar(im0, ax=axs[0, 1])

            # Display target block image.
            if target_img.shape[0] == 1:
                axs[1, 0].imshow(target_img.squeeze(), cmap='gray')
            else:
                axs[1, 0].imshow(target_img.permute(1, 2, 0))
            axs[1, 0].set_title("Target Block")
            axs[1, 0].axis("off")

            # Display target feature heatmap.
            im1 = axs[1, 1].imshow(target_heat, cmap="viridis")
            axs[1, 1].set_title("Target Feature Heatmap")
            axs[1, 1].axis("off")
            fig.colorbar(im1, ax=axs[1, 1])

            # Save the visualization figure with epoch and batch number.
            viz_path = os.path.join(viz_dir, f"epoch{epoch+1}_batch{batch_idx+1}.png")
            plt.savefig(viz_path)
            plt.close(fig)'''

    epoch_loss = running_loss / len(train_dataset_aav_ijepa)
    epoch_time = time.time() - epoch_start_time
    print(f"Epoch {epoch+1}/{num_epochs} - Train Loss: {epoch_loss:.10f} - Epoch Time: {epoch_time:.2f}s")

    # Save checkpoint if current epoch loss is lower than previous best.
    if epoch_loss < best_loss:
        best_loss = epoch_loss
        checkpoint = {
            'epoch': epoch+1,
            'context_encoder_state_dict': context_encoder.state_dict(),
            'target_encoder_state_dict': target_encoder.state_dict(),
            'predictor_state_dict': predictor.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': epoch_loss
        }
        torch.save(checkpoint, os.path.join(base_dir,"ijepa_checkpoint_best.pth"))
        print(f"Checkpoint saved at epoch {epoch+1} with loss {epoch_loss:.4f}")


total_train_time = time.time() - total_start_time
print(f"Total Training Time: {total_train_time:.2f}s")

                                                            

Epoch 1/150 - Train Loss: 9933.0125395605 - Epoch Time: 107.27s
Checkpoint saved at epoch 1 with loss 9933.0125


                                                            

Epoch 2/150 - Train Loss: 0.0396372993 - Epoch Time: 107.37s
Checkpoint saved at epoch 2 with loss 0.0396


                                                            

Epoch 3/150 - Train Loss: 0.0000303999 - Epoch Time: 106.85s
Checkpoint saved at epoch 3 with loss 0.0000


                                                            

Epoch 4/150 - Train Loss: 0.0000105425 - Epoch Time: 106.96s
Checkpoint saved at epoch 4 with loss 0.0000


                                                            

Epoch 5/150 - Train Loss: 0.0000105456 - Epoch Time: 106.76s


                                                            

Epoch 6/150 - Train Loss: 0.0000107935 - Epoch Time: 106.89s


                                                            

Epoch 7/150 - Train Loss: 0.0000107122 - Epoch Time: 106.96s


                                                            

Epoch 8/150 - Train Loss: 0.0000107772 - Epoch Time: 106.98s


                                                            

Epoch 9/150 - Train Loss: 0.0000105551 - Epoch Time: 107.09s


                                                             

Epoch 10/150 - Train Loss: 0.0000106493 - Epoch Time: 106.91s


                                                             

Epoch 11/150 - Train Loss: 0.0000106900 - Epoch Time: 106.93s


                                                             

Epoch 12/150 - Train Loss: 0.0000105622 - Epoch Time: 107.12s


                                                             

Epoch 13/150 - Train Loss: 0.0000106139 - Epoch Time: 107.05s


                                                             

Epoch 14/150 - Train Loss: 0.0000105650 - Epoch Time: 107.04s


                                                             

Epoch 15/150 - Train Loss: 0.0000106657 - Epoch Time: 107.06s


                                                             

Epoch 16/150 - Train Loss: 0.0000105473 - Epoch Time: 106.97s


                                                             

Epoch 17/150 - Train Loss: 0.0000105052 - Epoch Time: 107.02s
Checkpoint saved at epoch 17 with loss 0.0000


                                                             

Epoch 18/150 - Train Loss: 0.0000106344 - Epoch Time: 107.25s


                                                             

Epoch 19/150 - Train Loss: 0.0000107107 - Epoch Time: 107.20s


                                                             

Epoch 20/150 - Train Loss: 0.0000106454 - Epoch Time: 107.18s


                                                             

Epoch 21/150 - Train Loss: 0.0000106319 - Epoch Time: 107.21s


                                                             

Epoch 22/150 - Train Loss: 0.0000111355 - Epoch Time: 107.12s


                                                             

Epoch 23/150 - Train Loss: 0.0000105512 - Epoch Time: 107.28s


                                                             

Epoch 24/150 - Train Loss: 0.0000074876 - Epoch Time: 107.26s
Checkpoint saved at epoch 24 with loss 0.0000


                                                             

Epoch 25/150 - Train Loss: 0.0000005274 - Epoch Time: 106.99s
Checkpoint saved at epoch 25 with loss 0.0000


                                                             

Epoch 26/150 - Train Loss: 0.0000004784 - Epoch Time: 107.10s
Checkpoint saved at epoch 26 with loss 0.0000


                                                             

Epoch 27/150 - Train Loss: 0.0000004783 - Epoch Time: 106.97s
Checkpoint saved at epoch 27 with loss 0.0000


                                                             

Epoch 28/150 - Train Loss: 0.0000004836 - Epoch Time: 107.02s


                                                             

Epoch 29/150 - Train Loss: 0.0000004834 - Epoch Time: 107.35s


                                                             

Epoch 30/150 - Train Loss: 0.0000004713 - Epoch Time: 107.13s
Checkpoint saved at epoch 30 with loss 0.0000


                                                             

Epoch 31/150 - Train Loss: 0.0000004805 - Epoch Time: 107.09s


                                                             

Epoch 32/150 - Train Loss: 0.0000004824 - Epoch Time: 107.04s


                                                             

Epoch 33/150 - Train Loss: 0.0000004764 - Epoch Time: 107.18s


                                                             

Epoch 34/150 - Train Loss: 0.0000004753 - Epoch Time: 107.44s


                                                             

Epoch 35/150 - Train Loss: 0.0000004814 - Epoch Time: 107.14s


                                                             

Epoch 36/150 - Train Loss: 0.0000005078 - Epoch Time: 107.34s


                                                             

Epoch 37/150 - Train Loss: 0.0000004749 - Epoch Time: 107.47s


                                                             

Epoch 38/150 - Train Loss: 0.0000004879 - Epoch Time: 107.13s


                                                             

Epoch 39/150 - Train Loss: 0.0000005223 - Epoch Time: 107.34s


                                                             

Epoch 40/150 - Train Loss: 0.0000005056 - Epoch Time: 107.44s


                                                             

Epoch 41/150 - Train Loss: 0.0000004725 - Epoch Time: 107.09s


                                                             

Epoch 42/150 - Train Loss: 0.0000004950 - Epoch Time: 107.40s


                                                             

Epoch 43/150 - Train Loss: 0.0000005075 - Epoch Time: 107.35s


                                                             

Epoch 44/150 - Train Loss: 0.0000004958 - Epoch Time: 107.29s


                                                             

Epoch 45/150 - Train Loss: 0.0000005210 - Epoch Time: 107.41s


                                                             

Epoch 46/150 - Train Loss: 0.0000005225 - Epoch Time: 107.31s


                                                             

Epoch 47/150 - Train Loss: 0.0000005026 - Epoch Time: 107.40s


                                                             

Epoch 48/150 - Train Loss: 0.0000004808 - Epoch Time: 107.22s


                                                             

Epoch 49/150 - Train Loss: 0.0000005004 - Epoch Time: 107.17s


                                                             

Epoch 50/150 - Train Loss: 0.0000005338 - Epoch Time: 107.36s


                                                             

Epoch 51/150 - Train Loss: 0.0000004779 - Epoch Time: 107.37s


                                                             

Epoch 52/150 - Train Loss: 0.0000005533 - Epoch Time: 107.44s


                                                             

Epoch 53/150 - Train Loss: 0.0000004996 - Epoch Time: 107.30s


                                                             

Epoch 54/150 - Train Loss: 0.0000004975 - Epoch Time: 107.43s


                                                             

Epoch 55/150 - Train Loss: 0.0000005309 - Epoch Time: 107.13s


                                                             

Epoch 56/150 - Train Loss: 0.0000006238 - Epoch Time: 107.48s


                                                             

Epoch 57/150 - Train Loss: 0.0000005418 - Epoch Time: 107.53s


                                                             

Epoch 58/150 - Train Loss: 0.0000006051 - Epoch Time: 107.46s


                                                             

Epoch 59/150 - Train Loss: 0.0000005226 - Epoch Time: 107.40s


                                                             

Epoch 60/150 - Train Loss: 0.0000005147 - Epoch Time: 107.42s


                                                             

Epoch 61/150 - Train Loss: 0.0000005316 - Epoch Time: 107.28s


                                                             

Epoch 62/150 - Train Loss: 0.0000005201 - Epoch Time: 107.35s


                                                             

Epoch 63/150 - Train Loss: 0.0000006084 - Epoch Time: 107.20s


                                                             

Epoch 64/150 - Train Loss: 0.0000005030 - Epoch Time: 107.34s


                                                             

Epoch 65/150 - Train Loss: 0.0000005373 - Epoch Time: 107.47s


                                                             

Epoch 66/150 - Train Loss: 0.0000005435 - Epoch Time: 107.32s


                                                             

Epoch 67/150 - Train Loss: 0.0000005381 - Epoch Time: 107.09s


                                                             

Epoch 68/150 - Train Loss: 0.0000005580 - Epoch Time: 107.34s


                                                             

Epoch 69/150 - Train Loss: 0.0000005703 - Epoch Time: 107.40s


                                                             

Epoch 70/150 - Train Loss: 0.0000007013 - Epoch Time: 107.38s


                                                             

Epoch 71/150 - Train Loss: 0.0000007536 - Epoch Time: 107.36s


                                                             

Epoch 72/150 - Train Loss: 0.0000006479 - Epoch Time: 107.19s


                                                             

Epoch 73/150 - Train Loss: 0.0000006654 - Epoch Time: 107.58s


                                                             

Epoch 74/150 - Train Loss: 0.0000007969 - Epoch Time: 107.47s


                                                             

Epoch 75/150 - Train Loss: 0.0000006994 - Epoch Time: 107.36s


                                                             

Epoch 76/150 - Train Loss: 0.0000010518 - Epoch Time: 107.35s


                                                             

Epoch 77/150 - Train Loss: 0.0000006269 - Epoch Time: 107.35s


                                                             

Epoch 78/150 - Train Loss: 0.0000008526 - Epoch Time: 107.50s


                                                             

Epoch 79/150 - Train Loss: 0.0000012404 - Epoch Time: 107.49s


                                                             

Epoch 80/150 - Train Loss: 0.0000010181 - Epoch Time: 107.50s


                                                             

Epoch 81/150 - Train Loss: 0.0000014421 - Epoch Time: 107.45s


                                                             

Epoch 82/150 - Train Loss: 0.0000020086 - Epoch Time: 107.52s


                                                             

Epoch 83/150 - Train Loss: 0.0000013026 - Epoch Time: 107.18s


                                                             

Epoch 84/150 - Train Loss: 0.0000027119 - Epoch Time: 107.41s


                                                             

Epoch 85/150 - Train Loss: 0.0000108288 - Epoch Time: 107.61s


                                                             

Epoch 86/150 - Train Loss: 0.0335191947 - Epoch Time: 107.51s


                                                             

Epoch 87/150 - Train Loss: 0.0002173877 - Epoch Time: 107.59s


                                                             

Epoch 88/150 - Train Loss: 0.0000008129 - Epoch Time: 107.31s


                                                             

Epoch 89/150 - Train Loss: 0.0000011182 - Epoch Time: 107.48s


                                                             

Epoch 90/150 - Train Loss: 0.0000006072 - Epoch Time: 107.37s


                                                             

Epoch 91/150 - Train Loss: 0.0000017031 - Epoch Time: 107.53s


                                                             

Epoch 92/150 - Train Loss: 0.0000033567 - Epoch Time: 107.39s


                                                             

Epoch 93/150 - Train Loss: 0.0012622321 - Epoch Time: 107.54s


                                                             

Epoch 94/150 - Train Loss: 0.0157197408 - Epoch Time: 107.15s


                                                             

Epoch 95/150 - Train Loss: 0.0000031637 - Epoch Time: 107.59s


                                                             

Epoch 96/150 - Train Loss: 0.0000006788 - Epoch Time: 107.44s


                                                             

Epoch 97/150 - Train Loss: 0.0000011293 - Epoch Time: 107.46s


                                                             

Epoch 98/150 - Train Loss: 0.0000017528 - Epoch Time: 107.28s


                                                             

Epoch 99/150 - Train Loss: 0.0000054151 - Epoch Time: 107.05s


                                                              

Epoch 100/150 - Train Loss: 0.0240594276 - Epoch Time: 107.43s


                                                              

Epoch 101/150 - Train Loss: 0.0000903395 - Epoch Time: 107.11s


                                                              

Epoch 102/150 - Train Loss: 0.0000004740 - Epoch Time: 107.36s


                                                              

Epoch 103/150 - Train Loss: 0.0000004091 - Epoch Time: 107.03s
Checkpoint saved at epoch 103 with loss 0.0000


                                                              

Epoch 104/150 - Train Loss: 0.0000004067 - Epoch Time: 107.14s
Checkpoint saved at epoch 104 with loss 0.0000


                                                              

Epoch 105/150 - Train Loss: 0.0000004230 - Epoch Time: 107.20s


                                                              

Epoch 106/150 - Train Loss: 0.0000004331 - Epoch Time: 107.21s


                                                              

Epoch 107/150 - Train Loss: 0.0000004418 - Epoch Time: 107.18s


                                                              

Epoch 108/150 - Train Loss: 0.0000004794 - Epoch Time: 107.13s


                                                              

Epoch 109/150 - Train Loss: 0.0000004372 - Epoch Time: 107.19s


                                                              

Epoch 110/150 - Train Loss: 0.0000004795 - Epoch Time: 107.16s


                                                              

Epoch 111/150 - Train Loss: 0.0000004915 - Epoch Time: 107.21s


                                                              

Epoch 112/150 - Train Loss: 0.0000006456 - Epoch Time: 107.18s


                                                              

Epoch 113/150 - Train Loss: 0.0000008599 - Epoch Time: 107.00s


                                                              

Epoch 114/150 - Train Loss: 0.0001194077 - Epoch Time: 107.29s


                                                              

Epoch 115/150 - Train Loss: 0.0076031929 - Epoch Time: 107.31s


                                                              

Epoch 116/150 - Train Loss: 0.0000017180 - Epoch Time: 107.11s


                                                              

Epoch 117/150 - Train Loss: 0.0000005330 - Epoch Time: 107.13s


                                                              

Epoch 118/150 - Train Loss: 0.0000005618 - Epoch Time: 107.19s


                                                              

Epoch 119/150 - Train Loss: 0.0000005691 - Epoch Time: 107.14s


                                                              

Epoch 120/150 - Train Loss: 0.0000005377 - Epoch Time: 107.23s


                                                              

Epoch 121/150 - Train Loss: 0.0000005937 - Epoch Time: 107.22s


                                                              

Epoch 122/150 - Train Loss: 0.0000005797 - Epoch Time: 107.14s


                                                              

Epoch 123/150 - Train Loss: 0.0000006270 - Epoch Time: 107.09s


                                                              

Epoch 124/150 - Train Loss: 0.0000005586 - Epoch Time: 107.07s


                                                              

Epoch 125/150 - Train Loss: 0.0000005701 - Epoch Time: 107.04s


                                                              

Epoch 126/150 - Train Loss: 0.0000008348 - Epoch Time: 107.07s


                                                              

Epoch 127/150 - Train Loss: 0.0033423548 - Epoch Time: 107.10s


                                                              

Epoch 128/150 - Train Loss: 0.0000105633 - Epoch Time: 107.30s


                                                              

Epoch 129/150 - Train Loss: 0.0000007841 - Epoch Time: 107.13s


                                                              

Epoch 130/150 - Train Loss: 0.0000010025 - Epoch Time: 107.07s


                                                              

Epoch 131/150 - Train Loss: 0.0000008815 - Epoch Time: 107.12s


                                                              

Epoch 132/150 - Train Loss: 0.0000015642 - Epoch Time: 107.28s


                                                              

Epoch 133/150 - Train Loss: 0.0000010456 - Epoch Time: 107.15s


                                                              

Epoch 134/150 - Train Loss: 0.0000008773 - Epoch Time: 107.41s


                                                              

Epoch 135/150 - Train Loss: 0.0000012329 - Epoch Time: 107.15s


                                                              

Epoch 136/150 - Train Loss: 0.0000016365 - Epoch Time: 107.16s


                                                              

Epoch 137/150 - Train Loss: 0.0000014308 - Epoch Time: 107.07s


                                                              

Epoch 138/150 - Train Loss: 0.0000013933 - Epoch Time: 107.31s


                                                              

Epoch 139/150 - Train Loss: 0.0013548351 - Epoch Time: 107.24s


                                                              

Epoch 140/150 - Train Loss: 0.0000191366 - Epoch Time: 107.16s


                                                              

Epoch 141/150 - Train Loss: 0.0000275483 - Epoch Time: 107.28s


                                                              

Epoch 142/150 - Train Loss: 0.0003201611 - Epoch Time: 107.13s


                                                              

Epoch 143/150 - Train Loss: 0.0005432271 - Epoch Time: 107.05s


                                                              

Epoch 144/150 - Train Loss: 0.0010315167 - Epoch Time: 107.13s


                                                              

Epoch 145/150 - Train Loss: 0.0008990226 - Epoch Time: 107.07s


                                                              

Epoch 146/150 - Train Loss: 0.0021160119 - Epoch Time: 107.01s


                                                              

Epoch 147/150 - Train Loss: 0.0754867362 - Epoch Time: 107.03s


                                                              

Epoch 148/150 - Train Loss: 0.0003909256 - Epoch Time: 107.07s


                                                              

Epoch 149/150 - Train Loss: 0.0000001251 - Epoch Time: 107.10s
Checkpoint saved at epoch 149 with loss 0.0000


                                                              

Epoch 150/150 - Train Loss: 0.0000000020 - Epoch Time: 107.08s
Checkpoint saved at epoch 150 with loss 0.0000
Total Training Time: 16125.67s


# Train the Classifier

## 1. Load the Saved Checkpoint for the Self-Supervised Model

In [21]:
checkpoint = torch.load(os.path.join(base_dir, "ijepa_checkpoint_best.pth"))
context_encoder.load_state_dict(checkpoint['context_encoder_state_dict'])
# Freeze the context encoder.
context_encoder.eval()
for param in context_encoder.parameters():
    param.requires_grad = False

  checkpoint = torch.load(os.path.join(base_dir, "ijepa_checkpoint_best.pth"))


## 2. Define the Classifier

In [22]:
num_classes = 3  # Adjust this number based on your dataset.
classifier = nn.Linear(768, num_classes).cuda()


## 3. Set Up Optimizer and Loss Criterion

In [23]:
clf_optimizer = optim.Adam(classifier.parameters(), lr=1e-3)
criterion_cls = nn.CrossEntropyLoss()


## 4. Training Loop for the Classifier (Using Training Data Only)

In [24]:
num_epochs_clf = 500
best_train_acc = 0.0  # Best training accuracy so far.

for epoch in range(num_epochs_clf):
    epoch_start_time = time.time()
    classifier.train()
    running_loss = 0.0
    correct_train = 0
    total_train = 0

    for context_block, _, label, _  in train_loader_aav_ijepa:
        context_block = context_block.cuda()  # [B, C, 224, 224]
        label = label.cuda()

        with torch.no_grad():
            features = context_encoder(context_block)  # [B, 768]

        logits = classifier(features)  # [B, num_classes]
        loss = criterion_cls(logits, label)

        clf_optimizer.zero_grad()
        loss.backward()
        clf_optimizer.step()

        running_loss += loss.item() * context_block.size(0)
        preds = logits.argmax(dim=1)
        correct_train += (preds == label).sum().item()
        total_train += label.size(0)

    epoch_train_loss = running_loss / len(train_loader_aav_ijepa)
    epoch_train_acc = correct_train / total_train
    epoch_time = time.time() - epoch_start_time

    print(f"Epoch {epoch+1}/{num_epochs_clf} - Train Loss: {epoch_train_loss:.10f} | Train Acc: {epoch_train_acc*100:.10f}% | Time: {epoch_time:.2f}s")

    # Save checkpoint if training accuracy improves.
    if epoch_train_acc > best_train_acc:
        best_train_acc = epoch_train_acc
        checkpoint = {
            'epoch': epoch+1,
            'classifier_state_dict': classifier.state_dict(),
            'optimizer_state_dict': clf_optimizer.state_dict(),
            'train_loss': epoch_train_loss,
            'train_acc': epoch_train_acc
        }
        torch.save(checkpoint, os.path.join(base_dir,"ijepa_classifier_best.pth"))
        print(f"Checkpoint saved at epoch {epoch+1} with Train Acc: {epoch_train_acc*100:.10f}%")

print("Classifier training complete!")

Epoch 1/500 - Train Loss: 35.3247123539 | Train Acc: 41.3989426596% | Time: 17.46s
Checkpoint saved at epoch 1 with Train Acc: 41.3989426596%
Epoch 2/500 - Train Loss: 35.6016814446 | Train Acc: 40.1789345262% | Time: 14.94s
Epoch 3/500 - Train Loss: 34.9700855559 | Train Acc: 42.4562830419% | Time: 14.98s
Checkpoint saved at epoch 3 with Train Acc: 42.4562830419%
Epoch 4/500 - Train Loss: 34.9525565791 | Train Acc: 42.0089467263% | Time: 14.95s
Epoch 5/500 - Train Loss: 34.5196446134 | Train Acc: 44.4489629931% | Time: 14.99s
Checkpoint saved at epoch 5 with Train Acc: 44.4489629931%
Epoch 6/500 - Train Loss: 35.3358164964 | Train Acc: 42.8629524197% | Time: 14.95s
Epoch 7/500 - Train Loss: 35.1088398478 | Train Acc: 41.8869459130% | Time: 14.95s
Epoch 8/500 - Train Loss: 34.9299192382 | Train Acc: 40.9516063440% | Time: 14.90s
Epoch 9/500 - Train Loss: 35.6973246259 | Train Acc: 40.2602684018% | Time: 14.97s
Epoch 10/500 - Train Loss: 35.0103146751 | Train Acc: 41.7649450996% | Time:

In [25]:
#from google.colab import output
#output.eval_js('google.colab.kernel.disconnect()')