<image src="https://raw.githubusercontent.com/semilleroCV/deep-learning-notes/main/assets/banner-notebook.png" width=100%>

# <font color='#4C5FDA'> **DINO: Emerging Properties in Self-Supervised Vision Transformers** </font>

The paper <font color="#EB9A54">“DINO: Emerging Properties in Self-Supervised Vision Transformers”</font> presents a self-supervised learning method for Vision Transformers (ViTs). In simple terms, it shows that a model can learn useful image representations without labels through self-distillation: a student network is trained to match the output of a teacher network on different augmented views of the same image, and the teacher is an exponential moving average (EMA) of the student.
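
The distillation objective can be sketched on toy tensors. This is a minimal illustration, not the paper's code: the logits are random, the names `student_logits`, `teacher_logits`, and `center` are placeholders, and the temperatures (0.1 for the student, 0.04 for the teacher) are assumed values in the range the paper reports.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Toy stand-ins for the head outputs of two views of the same images
# (batch of 4 images, 8 output dimensions); real DINO heads are much wider.
student_logits = torch.randn(4, 8, requires_grad=True)
teacher_logits = torch.randn(4, 8)

# Assumed temperatures: a lower teacher temperature sharpens its distribution.
student_temp, teacher_temp = 0.1, 0.04

# Centering: subtract a mean of the teacher outputs (here, just the batch mean).
center = teacher_logits.mean(dim=0, keepdim=True)

# The teacher provides the targets, so no gradient flows through it.
teacher_probs = F.softmax((teacher_logits - center) / teacher_temp, dim=-1).detach()
student_log_probs = F.log_softmax(student_logits / student_temp, dim=-1)

# Cross-entropy between the teacher and student distributions.
loss = -(teacher_probs * student_log_probs).sum(dim=-1).mean()
loss.backward()

print(f"loss = {loss.item():.4f}")
print("student gradient shape:", tuple(student_logits.grad.shape))
```

Only the student receives a gradient; in a full training loop the teacher would instead be updated as a moving average of the student.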

<font color="#EB9A54">**Why is DINO relevant?**</font>

 - Self-supervision: The DINO method avoids reliance on large amounts of labeled data, which is useful in scenarios where labeling data is costly or complicated.

 - Vision Transformers: It uses ViTs, a powerful architecture for computer vision tasks, and shows that these networks can be trained effectively without labels.

 - Emergent Properties: The model trained with DINO learns to capture high-level spatial structures and relationships in images. Its self-attention maps are highly interpretable and highlight object boundaries, even though the model is never explicitly trained for segmentation or localization.




<image src="https://i.ibb.co/JymZwqy/Captura-desde-2024-11-29-00-38-53.png" >

In [None]:
%%capture
#@title **Install required packages**

!pip install torchinfo

In [None]:
#@title **Importing libraries**

import torchinfo
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from torchvision.models import vit_b_16
from sklearn.neighbors import KNeighborsClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import StandardScaler
import numpy as np

In [None]:
# Note: Not all dependencies expose a __version__ attribute.

print(torch.__version__)

2.5.1+cu121


### DINO architecture code (ViT-B/16 backbone)

In [None]:
class DINOHead(nn.Module):
    """
    DINO projection head for self-supervised learning
    """
    def __init__(self, in_dim, out_dim, use_bn=True, norm_last_layer=True, nlayers=3, hidden_dim=2048):
        super().__init__()
        nlayers = max(nlayers, 1)
        if nlayers == 1:
            self.mlp = nn.Linear(in_dim, out_dim)
        else:
            layers = [nn.Linear(in_dim, hidden_dim)]
            if use_bn:
                layers.append(nn.BatchNorm1d(hidden_dim))
            layers.append(nn.GELU())

            for _ in range(nlayers - 2):
                layers.append(nn.Linear(hidden_dim, hidden_dim))
                if use_bn:
                    layers.append(nn.BatchNorm1d(hidden_dim))
                layers.append(nn.GELU())

            layers.append(nn.Linear(hidden_dim, out_dim))
            if norm_last_layer:
                layers.append(nn.BatchNorm1d(out_dim, affine=False))

            self.mlp = nn.Sequential(*layers)

    def forward(self, x):
        return self.mlp(x)


class DINO(nn.Module):
    """
    DINO model with a torchvision ViT-B/16 backbone
    """
    def __init__(self,
                 num_classes=1000,
                 out_dim=65536,
                 use_bn_in_head=False,
                 norm_last_layer=True,
                 momentum=0.999,
                 temperature_student=0.1,
                 temperature_teacher=0.04,
                 center_momentum=0.9):
        super().__init__()

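        # Load an ImageNet-pretrained ViT-B/16 backbone.
        # torchvision does not provide a ViT-Small/16, so ViT-Base/16 is used here.
        self.backbone = vit_b_16(weights="IMAGENET1K_V1")

        # Remove the classification head so the backbone returns the [CLS] embedding
        self.backbone.heads = nn.Identity()

        # Feature dimension of ViT-Base/16 (a ViT-Small/16 would use 384)
        feature_dim = 768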

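        # Create student and teacher heads with identical architectures,
        # so that the EMA update pairs matching parameters
        self.student_head = DINOHead(
            feature_dim,
            out_dim,
            use_bn=use_bn_in_head,
            norm_last_layer=norm_last_layer
        )

        self.teacher_head = DINOHead(
            feature_dim,
            out_dim,
            use_bn=use_bn_in_head,
            norm_last_layer=norm_last_layer
        )

        # Start the teacher from the student's weights, then freeze it
        self.teacher_head.load_state_dict(self.student_head.state_dict())
        for param in self.teacher_head.parameters():
            param.requires_grad = False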

        # Define the momentum parameter for EMA update
        self.momentum = momentum

        # Initialize center (C)
        self.center = nn.Parameter(torch.zeros(out_dim), requires_grad=False)

        # Temperatures
        self.temperature_student = temperature_student
        self.temperature_teacher = temperature_teacher

        # Center momentum (for EMA update of the center)
        self.center_momentum = center_momentum

    def update_teacher(self):
        """
        Update teacher model with EMA (Exponential Moving Average)
        """
        with torch.no_grad():
            for student_params, teacher_params in zip(self.student_head.parameters(), self.teacher_head.parameters()):
                teacher_params.data = self.momentum * teacher_params.data + (1. - self.momentum) * student_params.data

    def forward(self, x1, x2):
        """
        Forward pass with two augmented views of the same image
        """
        # Extract features from both augmented views
        z1 = self.backbone(x1)
        z2 = self.backbone(x2)

        # Project features through student head
        p1 = self.student_head(z1)
        p2 = self.student_head(z2)

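        # Teacher projections: reuse the backbone features and block gradients.
        # Raw logits are returned; the softmax is applied in the loss.
        # Note: the teacher shares the student's backbone here, whereas the
        # original DINO keeps a separate EMA copy of the whole network.
        with torch.no_grad():
            t1 = self.teacher_head(z1)
            t2 = self.teacher_head(z2)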

        return p1, p2, t1, t2

    def update_center(self, t1, t2):
        """
        Update the center of the representations using EMA (Exponential Moving Average)
        """
        with torch.no_grad():
            # Concatenate teacher outputs and compute the mean
            center_update = torch.cat([t1, t2]).mean(dim=0)
            self.center.data = self.center_momentum * self.center.data + (1. - self.center_momentum) * center_update



def dino_loss(student_output, teacher_output, student_temp=0.1, teacher_temp=0.04):
    """
    DINO loss function (cross-entropy between teacher and student distributions)
    """
    # Sharpen the teacher with a lower temperature and stop its gradient
    teacher_probs = F.softmax(teacher_output.detach() / teacher_temp, dim=-1)
    student_log_probs = F.log_softmax(student_output / student_temp, dim=-1)

    # Cross-entropy H(teacher, student), averaged over the batch
    loss = torch.sum(-teacher_probs * student_log_probs, dim=-1).mean()
    return loss


def prepare_dino_transforms():
    """
    Prepare data augmentations for DINO
    """
    train_transform = transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.2),
        transforms.RandomGrayscale(p=0.2),
        transforms.GaussianBlur(kernel_size=5, sigma=(0.1, 2.0)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    return train_transform


def train_dino(model, optimizer, train_loader, device):
    """
    Training loop for DINO
    """
    model.train()
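    for (x1, x2), _ in train_loader:
        # The loader must yield two independently augmented views of each image;
        # feeding the same tensor twice would make the task trivial.
        x1, x2 = x1.to(device), x2.to(device)

        # Forward pass
        p1, p2, t1, t2 = model(x1, x2)

        # Compute the symmetric loss; teacher outputs are centered to avoid collapse
        loss1 = dino_loss(p1, t2 - model.center)
        loss2 = dino_loss(p2, t1 - model.center)
        loss = (loss1 + loss2) / 2

        # Backpropagate
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Update teacher network (EMA update) and the running center
        model.update_teacher()
        model.update_center(t1, t2)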

        # Optional: print loss for monitoring
        print(f"Loss: {loss.item()}")


def extract_features(model, dataloader, device):
    """
    Extract features using the DINO backbone
    """
    model.eval()
    all_features = []
    all_labels = []

    with torch.no_grad():
        for images, labels in dataloader:
            images = images.to(device)
            features = model.backbone(images)
            all_features.append(features.cpu().numpy())
            all_labels.append(labels.numpy())

    return np.concatenate(all_features), np.concatenate(all_labels)


def knn_evaluation(train_features, train_labels, test_features, test_labels, k=5):
    """
    K-Nearest Neighbors evaluation
    """
    scaler = StandardScaler()
    train_features_scaled = scaler.fit_transform(train_features)
    test_features_scaled = scaler.transform(test_features)

    knn = KNeighborsClassifier(n_neighbors=k)
    knn.fit(train_features_scaled, train_labels)

    predictions = knn.predict(test_features_scaled)
    accuracy = np.mean(predictions == test_labels)

    return accuracy


def linear_classifier_evaluation(train_features, train_labels, test_features, test_labels):
    """
    Linear Classifier (Logistic Regression) evaluation
    """
    scaler = StandardScaler()
    train_features_scaled = scaler.fit_transform(train_features)
    test_features_scaled = scaler.transform(test_features)

    linear_clf = LogisticRegression(max_iter=1000)
    linear_clf.fit(train_features_scaled, train_labels)

    predictions = linear_clf.predict(test_features_scaled)
    accuracy = np.mean(predictions == test_labels)

    return accuracy


def evaluate_representations(model, train_loader, test_loader, device, knn_k=5):
    """
    Comprehensive evaluation of learned representations
    """
    # Extract features
    train_features, train_labels = extract_features(model, train_loader, device)
    test_features, test_labels = extract_features(model, test_loader, device)

    # KNN Evaluation
    knn_accuracy = knn_evaluation(train_features, train_labels,
                                  test_features, test_labels, k=knn_k)

    # Linear Classifier Evaluation
    linear_accuracy = linear_classifier_evaluation(train_features, train_labels,
                                                   test_features, test_labels)

    return {
        'knn_accuracy': knn_accuracy,
        'linear_classifier_accuracy': linear_accuracy
    }



In [None]:
# Example usage
if __name__ == '__main__':
    # Device, model, and optimizer
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = DINO().to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

    # Note: This is a skeleton. Real implementation would require
    # a complete dataset and proper data loaders
    print("DINO model with ViT-B/16 backbone initialized!")

DINO model with ViT-B/16 backbone initialized!
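
The `DINO` constructor uses a fixed `momentum=0.999` for the teacher's EMA update. The DINO paper instead raises the momentum from 0.996 to 1 with a cosine schedule over training. A minimal sketch of such a schedule follows; `cosine_momentum` and its argument names are illustrative, and how it would be wired into `update_teacher` is left open.

```python
import math


def cosine_momentum(step, total_steps, base=0.996, final=1.0):
    """Cosine ramp of the EMA momentum from `base` (step 0) to `final` (last step)."""
    progress = step / max(total_steps - 1, 1)
    return final - (final - base) * (math.cos(math.pi * progress) + 1) / 2


total_steps = 5
for step in range(total_steps):
    print(step, round(cosine_momentum(step, total_steps), 5))
```

A momentum close to 1 late in training means the teacher changes very slowly, which keeps the targets stable as the student converges.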
