### Self-Supervised Image Representation Prediction with I-JEPA

In [None]:
import torch
import torch.nn as nn
import copy, random
from einops import rearrange

# ——————————————————————————————————————————————————————————————————————————————
# 1) Patch Embed + Positional Embedding
# ——————————————————————————————————————————————————————————————————————————————
class PatchEmbed(nn.Module):
    """Split an image into non-overlapping square patches and embed each one.

    A strided convolution (kernel size == stride == ``patch_size``) projects
    every patch to an ``embed_dim``-dimensional vector.

    Attributes:
        conv: the patch-projection convolution.
        grid: number of patches along one image side (``img_size // patch_size``).
        n_patches: total number of patches per image (``grid * grid``).
    """

    def __init__(self, img_size=32, patch_size=4, in_chans=3, embed_dim=64):
        super().__init__()
        patches_per_side = img_size // patch_size
        self.grid = patches_per_side
        self.n_patches = patches_per_side * patches_per_side
        self.conv = nn.Conv2d(
            in_chans,
            embed_dim,
            kernel_size=patch_size,
            stride=patch_size,
        )

    def forward(self, x):
        """Map images ``[B, C, H, W]`` to patch tokens ``[B, N, D]``."""
        feature_map = self.conv(x)  # [B, D, H/ps, W/ps]
        # Flatten the spatial grid row-major into the token axis.
        return rearrange(feature_map, 'b d h w -> b (h w) d')

# ——————————————————————————————————————————————————————————————————————————————
# 2) I-JEPA model using PyTorch Transformers
# ——————————————————————————————————————————————————————————————————————————————
class IJEPA_base(nn.Module):
    """Small I-JEPA style model.

    A context encoder sees a partially masked image; for each of ``M`` target
    blocks a predictor regresses the embeddings that a frozen, EMA-updated
    target encoder produces for the patches of that block.

    Args:
        img_size: side length of the (square) input image in pixels.
        patch_size: side length of one square patch in pixels.
        embed_dim: token embedding width ``D``.
        enc_depth: number of transformer layers in each encoder.
        pred_depth: number of transformer layers in the predictor.
        num_heads: attention heads per transformer layer.
        M: number of target blocks sampled per image.
        ema_m: EMA momentum used to update the target encoder.
    """

    def __init__(self,
                 img_size=32,
                 patch_size=4,
                 embed_dim=64,
                 enc_depth=6,
                 pred_depth=3,
                 num_heads=8,
                 M=4,
                 ema_m=0.996):
        super().__init__()
        # patch embedding
        self.patch_embed = PatchEmbed(img_size, patch_size,
                                      in_chans=3, embed_dim=embed_dim)
        N = self.patch_embed.n_patches
        self.pos_emb = nn.Parameter(torch.randn(1, N, embed_dim))

        # context & target encoders using TransformerEncoder
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=embed_dim, nhead=num_heads, batch_first=True
        )
        self.ctx_enc = nn.TransformerEncoder(encoder_layer, num_layers=enc_depth)
        # The target encoder starts as an exact copy and is only ever changed
        # through update_ema(); it receives no gradients.
        self.tgt_enc = copy.deepcopy(self.ctx_enc)
        for p in self.tgt_enc.parameters():
            p.requires_grad = False
        # Targets must be deterministic, so dropout stays off (see train()).
        self.tgt_enc.eval()

        # predictor using TransformerEncoder as decoder proxy
        pred_layer = nn.TransformerEncoderLayer(
            d_model=embed_dim, nhead=num_heads, batch_first=True
        )
        self.pred = nn.TransformerEncoder(pred_layer, num_layers=pred_depth)

        # mask token: placeholder fed to the predictor at target positions
        self.mask_token = nn.Parameter(torch.randn(1, 1, embed_dim))

        self.M = M
        self.N = N
        self.grid = self.patch_embed.grid  # patches per image side
        self.ema_m = ema_m

    def train(self, mode=True):
        """Switch train/eval mode but always keep the target encoder in eval.

        Without this override ``model.train()`` would enable dropout inside
        the frozen target encoder and add noise to the regression targets.
        """
        super().train(mode)
        self.tgt_enc.eval()
        return self

    @torch.no_grad()
    def update_ema(self):
        """Move target-encoder weights toward the context encoder.

        Applies ``tgt = ema_m * tgt + (1 - ema_m) * ctx`` in place to every
        parameter pair.
        """
        for p_ctx, p_tgt in zip(self.ctx_enc.parameters(),
                                self.tgt_enc.parameters()):
            p_tgt.mul_(self.ema_m).add_(p_ctx.detach(), alpha=1 - self.ema_m)

    def sample_block(self, scale_range, ratio_range):
        """Sample a random rectangular block of patches.

        Args:
            scale_range: ``(lo, hi)`` fraction of all ``N`` patches the block
                should cover (before rounding and clamping).
            ratio_range: ``(lo, hi)`` range for the height/width aspect ratio.

        Returns:
            Flat row-major patch indices of the block, in ascending order.
        """
        S = random.uniform(*scale_range)
        R = random.uniform(*ratio_range)
        area = S * self.N
        h = int((area * R)**0.5)
        w = int((area / R)**0.5)
        # Clamp to at least one patch and at most the full grid.
        h = max(1, min(self.grid, h))
        w = max(1, min(self.grid, w))
        i = random.randint(0, self.grid - h)
        j = random.randint(0, self.grid - w)
        idxs = [(i+di)*self.grid + (j+dj) for di in range(h) for dj in range(w)]
        return idxs

    def _index_mask(self, index_sets, device):
        """Return a ``[len(index_sets), N]`` bool mask, True at listed indices."""
        mask = torch.zeros(len(index_sets), self.N, dtype=torch.bool)
        for row, idxs in enumerate(index_sets):
            if idxs:
                mask[row, list(idxs)] = True
        return mask.to(device)

    def forward(self, img):
        """Run one I-JEPA prediction pass.

        Args:
            img: batch of images accepted by ``self.patch_embed``.

        Returns:
            ``(all_preds, all_targs)``: two lists of length ``B``. Entry ``b``
            of each is a ``[K_b, D]`` tensor holding, for all ``M`` target
            blocks of sample ``b`` concatenated, the predicted embeddings and
            the (gradient-free) target-encoder embeddings, row-aligned.

        Side effects:
            In training mode the target encoder is EMA-updated once per call.
        """
        B = img.size(0)
        N, M = self.N, self.M

        # 1) patch embed + add pos
        x = self.patch_embed(img) + self.pos_emb   # [B,N,D]
        D = x.size(-1)

        # 2) compute target repr (full, unmasked image)
        with torch.no_grad():
            sy = self.tgt_enc(x)                   # [B,N,D]

        # 3) sample one context block and M target blocks per image
        ctx_idxs = [self.sample_block((0.85, 1.0), (1.0, 1.0)) for _ in range(B)]
        tgt_idxs = [[self.sample_block((0.15, 0.2), (0.75, 1.5))
                     for _ in range(M)] for _ in range(B)]

        # 4) context encoding. Patches that belong to any target block are
        #    removed from the visible context; otherwise the context encoder
        #    would already have seen what the predictor must predict.
        visible = [set(ctx_idxs[b]).difference(*tgt_idxs[b]) for b in range(B)]
        keep = self._index_mask(visible, x.device)              # [B,N]
        x_ctx = x.masked_fill(~keep.unsqueeze(-1), 0.0)
        sx = self.ctx_enc(x_ctx)                                # [B,N,D]

        # 5) predictor: one sequence per (image, target block), all batched.
        #    Target positions are replaced by mask token + positional embedding.
        flat_blocks = [block for blocks in tgt_idxs for block in blocks]
        tgt_mask = self._index_mask(flat_blocks, x.device).view(B, M, N)
        query = (self.mask_token + self.pos_emb).unsqueeze(1)   # [1,1,N,D]
        seqs = torch.where(tgt_mask.unsqueeze(-1), query, sx.unsqueeze(1))
        out = self.pred(seqs.reshape(B * M, N, D)).reshape(B, M, N, D)

        # Boolean indexing walks (b, m, n) in row-major order, so rows come
        # out grouped per image, then per block, with ascending patch index.
        counts = [sum(len(block) for block in blocks) for blocks in tgt_idxs]
        all_preds = list(torch.split(out[tgt_mask], counts, dim=0))
        targs_flat = sy.unsqueeze(1).expand(B, M, N, D)[tgt_mask]
        all_targs = list(torch.split(targs_flat, counts, dim=0))

        # 6) EMA update: only while training, so inference never changes weights
        if self.training:
            self.update_ema()

        return all_preds, all_targs


### Training Loop

In [None]:
import os
import glob
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
from PIL import Image

# Assume IJEPA_base and PatchEmbed are already defined/imported above

# --- 1) Hyperparameters ---
# Input geometry
img_size = 32
patch_size = 4

# Model size
embed_dim = 64
enc_depth = 6
pred_depth = 3
num_heads = 8

# I-JEPA specifics: target blocks per image and target-encoder EMA momentum
M = 4
ema_m = 0.996

# Optimisation
batch_size = 64
num_epochs = 20
learning_rate = 1e-4

# --- 2) Custom PNG dataset ---
class PNGFolder(Dataset):
    """Dataset over the ``*.png`` files located directly inside ``root``.

    Files whose name (without extension) parses as an integer are ordered
    numerically (``2.png`` before ``10.png``); any other PNG files follow,
    ordered by name.

    Args:
        root: directory that contains the PNG files.
        transform: optional callable applied to each loaded RGB PIL image.
    """

    def __init__(self, root, transform=None):
        self.paths = sorted(
            glob.glob(os.path.join(root, '*.png')),
            key=self._sort_key
        )
        self.transform = transform

    @staticmethod
    def _sort_key(path):
        """Order numeric file names by value; non-numeric names come last."""
        stem = os.path.splitext(os.path.basename(path))[0]
        try:
            return (0, int(stem), stem)
        except ValueError:
            return (1, 0, stem)

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        """Return ``(image, 0)``; the label is a dummy value."""
        # convert() returns a fully loaded copy, so the file handle can be
        # closed right away instead of lingering until garbage collection.
        with Image.open(self.paths[idx]) as src:
            img = src.convert('RGB')
        if self.transform is not None:
            img = self.transform(img)
        return img, 0  # label is unused

# --- 3) Prepare dataset & loader ---
transform = transforms.Compose([
    transforms.Resize((img_size, img_size)),
    transforms.ToTensor(),
])

# The image folder can be overridden with the IJEPA_TRAIN_DIR environment
# variable; the default is the original hard-coded location.
train_root = os.environ.get(
    'IJEPA_TRAIN_DIR',
    '/Users/srirammandalika/Downloads/cifar-10/train'
)

train_dataset = PNGFolder(
    root=train_root,
    transform=transform
)
# Fail early with a clear message: an empty dataset would otherwise surface
# later as a sampler error or a division by zero in the epoch average.
if len(train_dataset) == 0:
    raise FileNotFoundError(f'No PNG files found in {train_root!r}')

train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=0,     # avoid pickling with custom class
    pin_memory=False   # pin_memory doesn't apply to MPS
)

# --- 4) Device selection (MPS > CUDA > CPU) ---
# Older torch builds have no `torch.backends.mps`, hence the hasattr guard.
device = torch.device(
    'mps' if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available()
    else 'cuda' if torch.cuda.is_available()
    else 'cpu'
)
print(f'Using device: {device}')

# --- 5) Initialize model & optimizer ---
model = IJEPA_base(
    img_size=img_size,
    patch_size=patch_size,
    embed_dim=embed_dim,
    enc_depth=enc_depth,
    pred_depth=pred_depth,
    num_heads=num_heads,
    M=M,
    ema_m=ema_m
).to(device)

# Optimised: context encoder, predictor and the mask token. The mask token is
# an nn.Parameter of the model; leaving it out of the optimizer would keep it
# at its random initial value forever.
# NOTE(review): patch_embed and pos_emb are still not optimised here. Both
# are shared by the context and target paths, so this may be deliberate;
# confirm before adding them.
optimizer = torch.optim.AdamW(
    list(model.ctx_enc.parameters())
    + list(model.pred.parameters())
    + [model.mask_token],
    lr=learning_rate,
    weight_decay=0.05
)

# --- 6) Training loop ---
for epoch in range(1, num_epochs + 1):
    model.train()
    running_loss = 0.0

    for imgs, _ in train_loader:
        imgs = imgs.to(device, non_blocking=False)
        preds_list, targs_list = model(imgs)

        # Mean squared error over every predicted embedding element in the
        # batch: total squared error divided by the total element count.
        squared_error = sum(
            F.mse_loss(p, t.to(device, non_blocking=False), reduction='sum')
            for p, t in zip(preds_list, targs_list)
        )
        n_elements = sum(p.numel() for p in preds_list)
        loss = squared_error / n_elements

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Weight by batch size so the epoch average is per sample.
        running_loss += loss.item() * imgs.size(0)

    avg_loss = running_loss / len(train_dataset)
    print(f"Epoch {epoch}/{num_epochs} — Loss: {avg_loss:.6f}")

# --- 7) Quick inference check ---
# Run the first 8 images of one training batch through the model without
# gradients and report how many patch embeddings were predicted per image.
model.eval()
with torch.no_grad():
    first_batch = next(iter(train_loader))
    batch = first_batch[0][:8].to(device, non_blocking=False)
    preds_list, targs_list = model(batch)
    for i, preds in enumerate(preds_list):
        print(f"Sample {i+1}: predicted {preds.shape[0]} patches × {preds.shape[1]}-d embeddings")
