In [None]:
import sys
sys.path.append('/mnt/v-jepa/jepa')
sys.path.append('/mnt/v-jepa/jepa/decord/python')  # For decord
import torch
import torch.nn as nn
from src.models.vision_transformer import VisionTransformer as VideoEncoder
import cv2
import os
import numpy as np

# Define dataset
class VideoDataset(torch.utils.data.Dataset):
    """Folder-per-class video dataset that yields fixed-length frame clips.

    ``data_path`` must contain one sub-directory per class. Every file in a
    class directory whose name ends in one of ``VIDEO_EXTENSIONS``
    (case-insensitive) becomes one sample labelled with that directory's name.

    Args:
        data_path: Root directory holding the class sub-directories.
        num_frames: Number of frames sampled from each video.
        frame_size: Target size passed to ``cv2.resize`` as ``dsize``, i.e.
            ``(width, height)``.
    """

    VIDEO_EXTENSIONS = ('.mp4', '.avi')

    def __init__(self, data_path, num_frames=16, frame_size=(224, 224)):
        self.data_path = data_path
        self.num_frames = num_frames
        self.frame_size = frame_size
        # Sorted so the class -> index mapping is stable between runs.
        self.classes = sorted(
            d for d in os.listdir(data_path)
            if os.path.isdir(os.path.join(data_path, d))
        )
        self.class_to_idx = {cls: idx for idx, cls in enumerate(self.classes)}
        # List of (video_path, class_name) pairs. os.listdir returns entries
        # in arbitrary order, so sort them to keep sample indices reproducible.
        self.videos = []
        for cls in self.classes:
            cls_path = os.path.join(data_path, cls)
            for video in sorted(os.listdir(cls_path)):
                if video.lower().endswith(self.VIDEO_EXTENSIONS):
                    self.videos.append((os.path.join(cls_path, video), cls))

    def __len__(self):
        """Return the number of discovered video files."""
        return len(self.videos)

    def __getitem__(self, idx):
        """Decode ``num_frames`` frames from the ``idx``-th video.

        Frames are sampled evenly over the video; a video shorter than
        ``num_frames`` is cycled from its start. A frame that fails to decode
        is replaced by the previous decoded frame, or by a black frame when
        there is no previous one.

        Returns:
            dict with keys
            ``'frames'``: float32 tensor ``[T, C, H, W]``, RGB, scaled to [0, 1];
            ``'label'``: scalar ``torch.long`` class index;
            ``'path'``: path of the source video.

        Raises:
            RuntimeError: If the video cannot be opened or reports no frames.
        """
        video_path, cls = self.videos[idx]
        width, height = self.frame_size
        cap = cv2.VideoCapture(video_path)
        try:
            total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
            # Guard before the modulo below, which would otherwise fail with
            # an uninformative ZeroDivisionError on an empty/unreadable file.
            if not cap.isOpened() or total_frames <= 0:
                raise RuntimeError(
                    f"Cannot read any frames from video: {video_path}"
                )
            if total_frames < self.num_frames:
                frame_indices = [i % total_frames for i in range(self.num_frames)]
            else:
                frame_indices = np.linspace(
                    0, total_frames - 1, self.num_frames, dtype=int
                )

            # Frames are kept as HxWx3 uint8 RGB arrays until the very end so
            # the "repeat previous frame" fallback reuses a finished frame
            # instead of pushing an already converted one through
            # resize/cvtColor a second time.
            frames = []
            for frame_idx in frame_indices:
                cap.set(cv2.CAP_PROP_POS_FRAMES, int(frame_idx))
                ret, frame = cap.read()
                if ret:
                    frame = cv2.resize(frame, self.frame_size)
                    frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                elif frames:
                    frame = frames[-1]
                else:
                    frame = np.zeros((height, width, 3), dtype=np.uint8)
                frames.append(frame)
        finally:
            # Release the decoder even when decoding raises.
            cap.release()

        clip = torch.from_numpy(np.stack(frames))  # [T, H, W, C], uint8
        clip = clip.permute(0, 3, 1, 2).contiguous().float() / 255.0
        label = torch.tensor(self.class_to_idx[cls], dtype=torch.long)
        return {'frames': clip, 'label': label, 'path': video_path}

# Build the dataset, then a loader that serves one shuffled clip per step.
dataset = VideoDataset(data_path="/path/to/dataset/")
loader_options = dict(
    batch_size=1,
    shuffle=True,
    num_workers=0,  # decode videos in the main process
)
loader = torch.utils.data.DataLoader(dataset, **loader_options)

# Load pre-trained model
encoder = VideoEncoder(
    model_name='vit_large_patch16_224',
    num_frames=16,
    tubelet_size=1, # Process each frame individually
    img_size=224,
    patch_size=16,
    embed_dim=1024, # ViT-Large
    depth=24,
    num_heads=16
)
# Deserialize onto the CPU: only the 'encoder' entry is used below, so any
# other entries in the checkpoint should not occupy GPU memory.
# load_state_dict copies the values into the module, which is moved to the
# GPU afterwards.
# NOTE(review): torch.load unpickles the file -- only load checkpoints that
# come from a trusted source.
checkpoint = torch.load('/mnt/v-jepa/jepa/vitl16.pth.tar', map_location='cpu')
# strict=False skips non-matching keys instead of raising, so a naming
# mismatch would leave the encoder (partly) randomly initialised without any
# error. Report what was skipped so that situation is visible.
load_result = encoder.load_state_dict(checkpoint['encoder'], strict=False)
if load_result.missing_keys or load_result.unexpected_keys:
    # NOTE(review): checkpoints saved from wrapped models often carry key
    # prefixes (e.g. 'module.') -- confirm against this checkpoint's keys.
    print(
        f"Warning: encoder weights loaded with "
        f"{len(load_result.missing_keys)} missing and "
        f"{len(load_result.unexpected_keys)} unexpected keys"
    )
encoder.cuda().train()

# Linear classification head: one logit per dataset class, fed with
# encoder.embed_dim-sized features.
num_classes = len(dataset.classes)
classifier = nn.Linear(encoder.embed_dim, num_classes)
classifier.cuda()
classifier.train()

# Encoder and head are optimised jointly by a single Adam instance;
# cross-entropy is the training objective.
trainable_params = [*encoder.parameters(), *classifier.parameters()]
optimizer = torch.optim.Adam(trainable_params, lr=1e-4)
criterion = nn.CrossEntropyLoss()

# Fine-tuning loop: 20 passes over the loader, one optimizer step per batch.
for epoch in range(20):
    for batch in loader:
        frames = batch['frames'].cuda() # [B, T, C, H, W]
        frames = frames.permute(0, 2, 1, 3, 4) # [B, C, T, H, W]
        labels = batch['label'].cuda()
        optimizer.zero_grad()
        # Named `features` rather than `repr` so the builtin repr() is not
        # shadowed at module scope for the rest of the session.
        features = encoder(frames)
        # Average over dim 1 before classifying.
        # NOTE(review): assumes the encoder returns [B, tokens, embed_dim] -- confirm.
        logits = classifier(features.mean(dim=1))
        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()
        print(f"Epoch {epoch}, Loss: {loss.item()}")
# Persist both fine-tuned modules once training has finished.
torch.save({'encoder': encoder.state_dict(), 'classifier': classifier.state_dict()}, '/mnt/v-jepa/jepa/finetuned_vitl16_20e.pth')