<a href="https://colab.research.google.com/github/ossnat/VSD_foundation_model_evaluation/blob/main/testDino2D.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install torch torchvision transformers

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from transformers import AutoModel

# ------------------------
# Preprocess function (normalizes clip layout and resizes frames for the backbone)
# ------------------------
def preprocess_vsd_clip(clip, model_name, *, size=(224, 224)):
    """Convert a clip into a float ``[T, C, H, W]`` tensor for a 2D backbone.

    Args:
        clip: ``torch.Tensor`` or ``np.ndarray`` shaped ``[T, C, H, W]``,
            ``[T, H, W]`` (a channel axis is inserted) or ``[H, W]`` (treated
            as a single one-channel frame).
        model_name: Backbone identifier. If it contains ``"dino"`` or
            ``"resnet"`` (case-insensitive), single-channel clips are tiled to
            3 channels and every frame is bilinearly resized to ``size``.
            For any other name the clip is returned without tiling/resizing.
        size: Target ``(height, width)`` of the resize. Defaults to
            ``(224, 224)``, the value that was previously hard-coded.

    Returns:
        A ``torch.float32`` tensor shaped ``[T, C, H, W]``.

    Raises:
        ValueError: If ``clip`` does not have 2, 3 or 4 dimensions.
    """
    # NumPy input is wrapped as a tensor first so everything below is torch-only.
    if isinstance(clip, np.ndarray):
        clip = torch.from_numpy(clip)

    clip = clip.float()

    # Normalize the layout to [T, C, H, W].
    if clip.ndim == 2:  # [H, W], single frame
        clip = clip.unsqueeze(0).unsqueeze(0)  # -> [1, 1, H, W]
    elif clip.ndim == 3:  # [T, H, W]
        clip = clip.unsqueeze(1)  # -> [T, 1, H, W]
    elif clip.ndim != 4:
        # Fail early with a clear message instead of an IndexError further
        # down, or a silently passed-through tensor of the wrong rank.
        raise ValueError(
            "clip must be [H, W], [T, H, W] or [T, C, H, W]; "
            f"got shape {tuple(clip.shape)}"
        )

    # Both supported backbone families need the same treatment: 3 input
    # channels and a fixed spatial size.
    name = model_name.lower()
    if "dino" in name or "resnet" in name:
        if clip.shape[1] == 1:
            clip = clip.repeat(1, 3, 1, 1)  # -> [T, 3, H, W]
        clip = F.interpolate(clip, size=size, mode="bilinear", align_corners=False)

    return clip

# ------------------------
# Simple Dino2D Backbone wrapper
# ------------------------
class Dino2DBackbone(nn.Module):
    """Per-frame pretrained backbone with temporal mean pooling and a linear head.

    Each frame of a clip is passed through the backbone independently, the
    first token of ``last_hidden_state`` is taken as the frame embedding,
    the embeddings are averaged over time, and a linear layer maps the
    pooled vector to class logits.

    Args:
        model_name: Identifier passed to ``AutoModel.from_pretrained``.
        num_classes: Output size of the linear classifier.
        freeze_backbone: If True, gradients are disabled for every backbone
            parameter so only the linear head is trainable.
    """

    def __init__(self, model_name="facebook/dinov2-base", num_classes=2, freeze_backbone=True):
        super().__init__()
        self.backbone = AutoModel.from_pretrained(model_name)
        hidden_dim = self.backbone.config.hidden_size

        if freeze_backbone:
            # Disables gradients on all backbone parameters in one call.
            self.backbone.requires_grad_(False)

        # Simple linear classifier on top of the pooled frame embeddings.
        self.fc = nn.Linear(hidden_dim, num_classes)

    def forward(self, x):
        """Compute class logits for a batch of clips.

        Args:
            x: Tensor shaped ``[B, T, C, H, W]``.

        Returns:
            Logits shaped ``[B, num_classes]``.

        Raises:
            ValueError: If ``x`` is not 5-dimensional.
        """
        if x.ndim != 5:
            raise ValueError(
                f"expected input shaped [B, T, C, H, W], got {tuple(x.shape)}"
            )
        batch_size, num_frames = x.shape[:2]

        # Fold time into the batch axis so the 2D backbone sees single frames.
        # reshape (not view) so non-contiguous inputs, e.g. a transposed
        # [B, C, T, H, W] clip, are handled instead of raising RuntimeError.
        frames = x.reshape(batch_size * num_frames, *x.shape[2:])

        # Token at index 0 of the backbone output is used as the frame
        # embedding (the CLS token).
        frame_embeddings = self.backbone(frames).last_hidden_state[:, 0]
        frame_embeddings = frame_embeddings.reshape(batch_size, num_frames, -1)  # [B, T, hidden]

        # Simple mean pooling over time -> [B, hidden].
        clip_embedding = frame_embeddings.mean(dim=1)
        return self.fc(clip_embedding)





In [None]:
# Smoke test: push one random clip through preprocessing and the model.
# Mock input is 8 grayscale frames of 64x64 pixels.
mock_clip = torch.randn(8, 64, 64)  # [T, H, W]

# Run the DINO preprocessing path on the mock clip.
processed = preprocess_vsd_clip(mock_clip, "dino")
print("Processed shape:", processed.shape)  # [T, 3, 224, 224]

# The model expects batched clips, so prepend a batch axis of size 1.
batch = processed[None]  # [1, T, 3, 224, 224]

# Build the classifier with a frozen backbone and run a single forward pass.
model = Dino2DBackbone(num_classes=2, freeze_backbone=True)
logits = model(batch)
print("Output logits shape:", logits.shape)  # [1, 2]

# Tally parameter counts by whether they will receive gradients.
param_sizes = [(p.requires_grad, p.numel()) for p in model.parameters()]
trainable = sum(count for needs_grad, count in param_sizes if needs_grad)
frozen = sum(count for needs_grad, count in param_sizes if not needs_grad)
print(f"Trainable params: {trainable}, Frozen params: {frozen}")