<a href="https://colab.research.google.com/github/raihanewubd/selfSupervised/blob/main/i_jepa_aav_v2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [10]:
from google.colab import drive

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import Dataset, DataLoader, random_split
from transformers import ViTForImageClassification, ViTImageProcessor
from torchvision.models import vit_b_16



import timm
from tqdm import tqdm
import matplotlib.pyplot as plt
import numpy as np

from collections import Counter


import random
import os
import copy
import time
import pickle



In [11]:
#drive.mount('/content/drive')

In [12]:
# Define base directory and file name for saving the classifier checkpoint.
base_dir = "/kaggle/working"
#base_dir = "/content/drive/MyDrive/AAVDATASET/spectrogram"

In [13]:
data_dir = '/kaggle/input/aav-spectrogram/spectrogram'
#data_dir = os.path.join(base_dir,'dataset')
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5])
])

dataset = datasets.ImageFolder(root=data_dir, transform=transform)
dataloader = DataLoader(dataset, batch_size=16, shuffle=True)

In [14]:
# Set device for GPU acceleration if available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cuda


In [15]:
def extract_blocks(image, context_scale=0.85, target_scale=0.2, num_targets=4, max_overlap=0.5):
    # Extract a central context block.
    _, H, W = image.shape
    context_size = int(context_scale * H)
    top = (H - context_size) // 2
    left = (W - context_size) // 2
    context_block = image[:, top:top+context_size, left:left+context_size]
    context_block = torch.nn.functional.interpolate(
        context_block.unsqueeze(0),
        size=(224, 224),
        mode='bilinear',
        align_corners=False
    ).squeeze(0)

    # Extract num_targets target blocks randomly.
    target_blocks = []
    for _ in range(num_targets):
        target_size = int(target_scale * H)
        top_t = random.randint(0, H - target_size)
        left_t = random.randint(0, W - target_size)
        target_block = image[:, top_t:top_t+target_size, left_t:left_t+target_size]
        target_block = torch.nn.functional.interpolate(
            target_block.unsqueeze(0),
            size=(224, 224),
            mode='bilinear',
            align_corners=False
        ).squeeze(0)
        target_blocks.append(target_block)
    target_blocks = torch.stack(target_blocks)
    return context_block, target_blocks, (top, left, context_size), None



In [16]:
def process_sample(sample, context_scale, target_scale, num_targets):
    # Unpack sample: sample is ((img, label), image_path)
    (img, label), image_path = sample
    # Move image to GPU if available.
    img = img.to(device)
    context_block, target_blocks, _, _ = extract_blocks(img, context_scale, target_scale, num_targets)
    # Bring results back to CPU before caching.
    return (context_block.cpu(), target_blocks.cpu(), label, image_path)

In [17]:
class PrecomputedIJEPADataset(Dataset):
    def __init__(self, base_dataset, context_scale=0.85, target_scale=0.2, num_targets=4, cache_file=None):
        self.cache_file = cache_file
        if cache_file and os.path.exists(cache_file):
            # Load precomputed data from disk.
            with open(cache_file, 'rb') as f:
                self.data = pickle.load(f)
        else:
            # Create a list of samples along with their original image paths (if available) using a progress bar.
            if hasattr(base_dataset, 'samples'):
                base_samples = [
                    (base_dataset[i], base_dataset.samples[i][0])
                    for i in tqdm(range(len(base_dataset)), desc="Loading samples")
                ]
            else:
                base_samples = [
                    (sample, None) for sample in tqdm(base_dataset, desc="Loading samples")
                ]

            # Process samples sequentially with a progress bar.
            self.data = []
            for sample in tqdm(base_samples, desc="Processing samples"):
                result = process_sample(sample, context_scale, target_scale, num_targets)
                self.data.append(result)

            if cache_file:
                with open(cache_file, 'wb') as f:
                    pickle.dump(self.data, f)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]


In [18]:


# Timing the loading of the dataset and DataLoader
cache_path = os.path.join(base_dir,"precomputed_fulldataset_aav.pkl")
print(cache_path)
start_time = time.time()
dataset_aav_ijepa = PrecomputedIJEPADataset(dataset, cache_file=cache_path)
end_time_train_ijepa_dataset = time.time()
dataloader_aav_ijepa = DataLoader(dataset_aav_ijepa, batch_size=32, shuffle=True)
end_time = time.time()
print(f"Time taken to load dataset: {end_time_train_ijepa_dataset - start_time:.4f} seconds and DataLoader: {end_time - end_time_train_ijepa_dataset:.4f} seconds")


/kaggle/working/precomputed_fulldataset_aav.pkl


Loading samples: 100%|██████████| 3513/3513 [00:35<00:00, 97.78it/s] 
Processing samples: 100%|██████████| 3513/3513 [00:12<00:00, 283.78it/s]


Time taken to load dataset: 84.2955 seconds and DataLoader: 0.0014 seconds


In [19]:
num_images = len(dataset_aav_ijepa)
print(f"Number of images in the dataset: {num_images}")

Number of images in the dataset: 3513


In [20]:
total_batches = len(dataloader_aav_ijepa)
print("Total number of batches:", total_batches)

Total number of batches: 110


In [21]:
'''import torch
from torch.utils.data import Subset
from sklearn.model_selection import train_test_split


# Define the split ratio (e.g., 80% train, 20% test)
train_ratio = 0.8
test_ratio = 1 - train_ratio

# Get the total number of samples in the dataset.
num_samples = len(dataset_aav_ijepa)

# Create a list of indices for all samples in the dataset.
indices = list(range(num_samples))

# Split the indices into train and test sets using train_test_split.
train_indices, test_indices = train_test_split(indices, test_size=test_ratio, random_state=42)  # Set random_state for reproducibility.

# Create Subset datasets for train and test using the split indices.
train_dataset_aav_ijepa = Subset(dataset_aav_ijepa, train_indices)
test_dataset_aav_ijepa = Subset(dataset_aav_ijepa, test_indices)

# Create DataLoaders for the train and test datasets.
train_loader_aav_ijepa = torch.utils.data.DataLoader(train_dataset_aav_ijepa, batch_size=32, shuffle=True)
test_loader_aav_ijepa = torch.utils.data.DataLoader(test_dataset_aav_ijepa, batch_size=32, shuffle=False)  # No need to shuffle the test set.

print(f"Training set size: {len(train_dataset_aav_ijepa)}")
print(f"Testing set size: {len(test_dataset_aav_ijepa)}")'''

'import torch\nfrom torch.utils.data import Subset\nfrom sklearn.model_selection import train_test_split\n\n\n# Define the split ratio (e.g., 80% train, 20% test)\ntrain_ratio = 0.8\ntest_ratio = 1 - train_ratio\n\n# Get the total number of samples in the dataset.\nnum_samples = len(dataset_aav_ijepa)\n\n# Create a list of indices for all samples in the dataset.\nindices = list(range(num_samples))\n\n# Split the indices into train and test sets using train_test_split.\ntrain_indices, test_indices = train_test_split(indices, test_size=test_ratio, random_state=42)  # Set random_state for reproducibility.\n\n# Create Subset datasets for train and test using the split indices.\ntrain_dataset_aav_ijepa = Subset(dataset_aav_ijepa, train_indices)\ntest_dataset_aav_ijepa = Subset(dataset_aav_ijepa, test_indices)\n\n# Create DataLoaders for the train and test datasets.\ntrain_loader_aav_ijepa = torch.utils.data.DataLoader(train_dataset_aav_ijepa, batch_size=32, shuffle=True)\ntest_loader_aav_ij

In [22]:
import torch
from torch.utils.data import Subset, random_split
from sklearn.model_selection import train_test_split

# Define the split ratios (e.g., 70% train, 15% validation, 15% test)
train_ratio = 0.7
val_ratio = 0.15
test_ratio = 0.15

# Get the total number of samples in the dataset.
num_samples = len(dataset_aav_ijepa)

# 1. Split into train and (validation + test)
train_indices, val_test_indices = train_test_split(
    list(range(num_samples)),
    test_size=val_ratio + test_ratio,
    random_state=42  # Set random_state for reproducibility
)

# 2. Split (validation + test) into validation and test
val_indices, test_indices = train_test_split(
    val_test_indices,
    test_size=test_ratio / (val_ratio + test_ratio),
    random_state=42  # Set random_state for reproducibility
)

# Create Subset datasets for train, validation, and test
train_dataset_aav_ijepa = Subset(dataset_aav_ijepa, train_indices)
val_dataset_aav_ijepa = Subset(dataset_aav_ijepa, val_indices)
test_dataset_aav_ijepa = Subset(dataset_aav_ijepa, test_indices)

# Create DataLoaders for train, validation, and test
train_loader_aav_ijepa = torch.utils.data.DataLoader(train_dataset_aav_ijepa, batch_size=32, shuffle=True)
val_loader_aav_ijepa = torch.utils.data.DataLoader(val_dataset_aav_ijepa, batch_size=32, shuffle=False)
test_loader_aav_ijepa = torch.utils.data.DataLoader(test_dataset_aav_ijepa, batch_size=32, shuffle=False)

print(f"Training set size: {len(train_dataset_aav_ijepa)}")
print(f"Validation set size: {len(val_dataset_aav_ijepa)}")
print(f"Testing set size: {len(test_dataset_aav_ijepa)}")

Training set size: 2459
Validation set size: 527
Testing set size: 527


In [23]:
def get_vit_encoder():
    model = vit_b_16(pretrained=False)
    model.heads = nn.Identity()  # remove classification head
    return model

In [24]:

context_encoder = get_vit_encoder().cuda()
target_encoder  = get_vit_encoder().cuda()
target_encoder.load_state_dict(context_encoder.state_dict())



<All keys matched successfully>

In [25]:
class Predictor(nn.Module):
    def __init__(self, input_dim=768, hidden_dim=768, output_dim=768, num_targets=4):
        super().__init__()
        self.num_targets = num_targets
        self.mlp = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim * num_targets)
        )
    def forward(self, context_repr):
        pred = self.mlp(context_repr)
        # Reshape to [B, num_targets, output_dim]
        return pred.view(-1, self.num_targets, pred.size(-1) // self.num_targets)

In [26]:
# 6. Set up optimizer, loss, and EMA update (same as your CIFAR code).
predictor = Predictor().cuda()
optimizer = optim.Adam(list(context_encoder.parameters()) + list(predictor.parameters()), lr=1e-1)
criterion = nn.MSELoss()
ema_decay = 0.99

In [27]:
@torch.no_grad()
def update_ema(model, model_ema, beta):
    for param, param_ema in zip(model.parameters(), model_ema.parameters()):
        param_ema.data.mul_(beta).add_(param.data, alpha=1 - beta)

In [30]:
sample = next(iter(train_dataset_aav_ijepa))
for i, item in enumerate(sample):
    print(f"Item {i}: shape = {item.shape if hasattr(item, 'shape') else type(item)}")
    print(f"Item {i}: shape = {item}")

Item 0: shape = torch.Size([3, 224, 224])
Item 0: shape = tensor([[[-0.5137, -0.5137, -0.5137,  ..., -0.5294, -0.5294, -0.5294],
         [-0.5137, -0.5137, -0.5100,  ..., -0.5294, -0.5294, -0.5294],
         [-0.5089, -0.5089, -0.5070,  ..., -0.5197, -0.5197, -0.5197],
         ...,
         [-0.3569, -0.3569, -0.3569,  ..., -0.3155, -0.3155, -0.3155],
         [-0.3926, -0.3926, -0.3926,  ..., -0.3327, -0.3327, -0.3327],
         [-0.5137, -0.5137, -0.5137,  ..., -0.4902, -0.4902, -0.4902]],

        [[ 0.4510,  0.4510,  0.4510,  ...,  0.4118,  0.4118,  0.4118],
         [ 0.4510,  0.4510,  0.4510,  ...,  0.3671,  0.3633,  0.3633],
         [ 0.4558,  0.4558,  0.4558,  ...,  0.4141,  0.4123,  0.4123],
         ...,
         [ 0.5245,  0.5245,  0.5245,  ...,  0.5383,  0.5383,  0.5383],
         [ 0.5116,  0.5116,  0.5116,  ...,  0.5315,  0.5315,  0.5315],
         [ 0.4510,  0.4510,  0.4510,  ...,  0.4588,  0.4588,  0.4588]],

        [[-0.1373, -0.1373, -0.1373,  ..., -0.1059, -0.105

In [33]:
# Create a directory for visualizations if it doesn't exist.
#viz_dir = "/kaggle/working/viz"
#os.makedirs(viz_dir, exist_ok=True)

num_epochs = 5
ema_decay = 0.1
best_loss = float('inf')
total_start_time = time.time()

for epoch in range(num_epochs):
    epoch_start_time = time.time()
    context_encoder.train()
    predictor.train()
    running_loss = 0.0

    # Enumerate over batches with a progress bar.
    for batch_idx, (context_block, target_blocks, class_label, filepath) in enumerate(tqdm(train_dataset_aav_ijepa, desc=f"Epoch {epoch+1}/{num_epochs}", leave=False)):
        context_block = context_block.cuda()            # [B, C, 224, 224]
        target_blocks = target_blocks.cuda()              # [B, num_targets, C, 224, 224]

        # Forward pass through context encoder and predictor.
        context_repr = context_encoder(context_block)     # [B, 768]
        preds = predictor(context_repr)                   # [B, num_targets, 768]

        B, num_targets, C, Ht, Wt = target_blocks.shape
        target_blocks_flat = target_blocks.view(B * num_targets, C, Ht, Wt)
        with torch.no_grad():
            target_repr_flat = target_encoder(target_blocks_flat)
        target_repr = target_repr_flat.view(B, num_targets, -1)

        loss = criterion(preds, target_repr)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        update_ema(context_encoder, target_encoder, ema_decay)
        running_loss += loss.item() * context_block.size(0)

        # --- Visualization for first image of the current batch ---
        '''with torch.no_grad():
            # Get the first sample's context block and compute its feature vector.
            context_img = context_block[0].cpu()  # shape: [C, 224, 224]
            context_feat = context_encoder(context_block[0].unsqueeze(0)).cpu().squeeze(0)  # shape: [768]
            # Reshape feature vector to a 2D heatmap (24x32).
            context_heat = context_feat.view(24, 32).numpy()

            # For target, choose the first target block of the first sample.
            target_img = target_blocks[0][0].cpu()  # shape: [C, 224, 224]
            target_feat = target_encoder(target_blocks[0][0].unsqueeze(0).to(context_block.device)).cpu().squeeze(0)
            target_heat = target_feat.view(24, 32).numpy()

            # Plot the images and corresponding heatmaps.
            fig, axs = plt.subplots(2, 2, figsize=(10, 8))

            # Display context block image.
            if context_img.shape[0] == 1:
                axs[0, 0].imshow(context_img.squeeze(), cmap='gray')
            else:
                axs[0, 0].imshow(context_img.permute(1, 2, 0))
            axs[0, 0].set_title("Context Block")
            axs[0, 0].axis("off")

            # Display context feature heatmap.
            im0 = axs[0, 1].imshow(context_heat, cmap="viridis")
            axs[0, 1].set_title("Context Feature Heatmap")
            axs[0, 1].axis("off")
            fig.colorbar(im0, ax=axs[0, 1])

            # Display target block image.
            if target_img.shape[0] == 1:
                axs[1, 0].imshow(target_img.squeeze(), cmap='gray')
            else:
                axs[1, 0].imshow(target_img.permute(1, 2, 0))
            axs[1, 0].set_title("Target Block")
            axs[1, 0].axis("off")

            # Display target feature heatmap.
            im1 = axs[1, 1].imshow(target_heat, cmap="viridis")
            axs[1, 1].set_title("Target Feature Heatmap")
            axs[1, 1].axis("off")
            fig.colorbar(im1, ax=axs[1, 1])

            # Save the visualization figure with epoch and batch number.
            viz_path = os.path.join(viz_dir, f"epoch{epoch+1}_batch{batch_idx+1}.png")
            plt.savefig(viz_path)
            plt.close(fig)'''

    epoch_loss = running_loss / len(train_dataset_aav_ijepa)
    epoch_time = time.time() - epoch_start_time
    print(f"Epoch {epoch+1}/{num_epochs} - Train Loss: {epoch_loss:.10f} - Epoch Time: {epoch_time:.2f}s")

    # Save checkpoint if current epoch loss is lower than previous best.
    if epoch_loss < best_loss:
        best_loss = epoch_loss
        checkpoint = {
            'epoch': epoch+1,
            'context_encoder_state_dict': context_encoder.state_dict(),
            'target_encoder_state_dict': target_encoder.state_dict(),
            'predictor_state_dict': predictor.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': epoch_loss
        }
        torch.save(checkpoint, os.path.join(base_dir,"ijepa_checkpoint_best.pth"))
        print(f"Checkpoint saved at epoch {epoch+1} with loss {epoch_loss:.4f}")


total_train_time = time.time() - total_start_time
print(f"Total Training Time: {total_train_time:.2f}s")

                                                   

ValueError: not enough values to unpack (expected 4, got 3)

# Train the Classifier

## 1. Load the Saved Checkpoint for the Self-Supervised Model

In [None]:
checkpoint = torch.load(os.path.join(base_dir, "ijepa_checkpoint_best.pth"))
context_encoder.load_state_dict(checkpoint['context_encoder_state_dict'])
# Freeze the context encoder.
context_encoder.eval()
for param in context_encoder.parameters():
    param.requires_grad = False

## 2. Define the Classifier

In [None]:
num_classes = 3  # Adjust this number based on your dataset.
classifier = nn.Linear(768, num_classes).cuda()


## 3. Set Up Optimizer and Loss Criterion

In [None]:
clf_optimizer = optim.Adam(classifier.parameters(), lr=1e-3)
criterion_cls = nn.CrossEntropyLoss()


## 4. Training Loop for the Classifier (Using Training Data Only)

In [None]:
num_epochs_clf = 5
best_train_acc = 0.0  # Best training accuracy so far.

for epoch in range(num_epochs_clf):
    epoch_start_time = time.time()
    classifier.train()
    running_loss = 0.0
    correct_train = 0
    total_train = 0

    for context_block, _, label in train_loader:
        context_block = context_block.cuda()  # [B, C, 224, 224]
        label = label.cuda()

        with torch.no_grad():
            features = context_encoder(context_block)  # [B, 768]

        logits = classifier(features)  # [B, num_classes]
        loss = criterion_cls(logits, label)

        clf_optimizer.zero_grad()
        loss.backward()
        clf_optimizer.step()

        running_loss += loss.item() * context_block.size(0)
        preds = logits.argmax(dim=1)
        correct_train += (preds == label).sum().item()
        total_train += label.size(0)

    epoch_train_loss = running_loss / len(train_ijepa_dataset)
    epoch_train_acc = correct_train / total_train
    epoch_time = time.time() - epoch_start_time

    print(f"Epoch {epoch+1}/{num_epochs_clf} - Train Loss: {epoch_train_loss:.10f} | Train Acc: {epoch_train_acc*100:.10f}% | Time: {epoch_time:.2f}s")

    # Save checkpoint if training accuracy improves.
    if epoch_train_acc > best_train_acc:
        best_train_acc = epoch_train_acc
        checkpoint = {
            'epoch': epoch+1,
            'classifier_state_dict': classifier.state_dict(),
            'optimizer_state_dict': clf_optimizer.state_dict(),
            'train_loss': epoch_train_loss,
            'train_acc': epoch_train_acc
        }
        torch.save(checkpoint, os.path.join(base_dir,"ijepa_classifier_best.pth"))
        print(f"Checkpoint saved at epoch {epoch+1} with Train Acc: {epoch_train_acc*100:.10f}%")

print("Classifier training complete!")

In [None]:
from google.colab import output
output.eval_js('google.colab.kernel.disconnect()')

In [None]:
!kill -9 -1