In [1]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from transformers import ViTModel, AutoModel, AutoFeatureExtractor, AutoTokenizer
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
from tqdm import tqdm
import aim
import torch.optim as optim

class CustomSigLIP(nn.Module):
    def __init__(self, vision_model_name, text_model_name, projection_dim=512):
        super(CustomSigLIP, self).__init__()

        # Load Vision Transformer (ViT) for image encoding
        self.vision_model = ViTModel.from_pretrained(vision_model_name)
        vision_config = self.vision_model.config
        self.image_projection = nn.Linear(vision_config.hidden_size, projection_dim)

        # Load Text Transformer (e.g., BERT) for text encoding
        self.text_model = AutoModel.from_pretrained(text_model_name)
        text_config = self.text_model.config
        self.text_projection = nn.Linear(text_config.hidden_size, projection_dim)

        # Logit scaling parameter
        self.logit_scale = nn.Parameter(torch.ones([]) * torch.log(torch.tensor(1 / 0.07)))

    def forward(self, images, text):
        # Encode images
        image_features = self.vision_model(pixel_values=images).pooler_output
        image_features = self.image_projection(image_features)
        image_features = nn.functional.normalize(image_features, dim=-1)

        # Encode text
        text_features = self.text_model(**text).pooler_output
        text_features = self.text_projection(text_features)
        text_features = nn.functional.normalize(text_features, dim=-1)

        # Compute similarity logits
        logit_scale = self.logit_scale.exp()
        logits = logit_scale * torch.matmul(image_features, text_features.t())
        return logits

    def encode_image(self, images):
        # Encode images only
        image_features = self.vision_model(pixel_values=images).pooler_output
        image_features = self.image_projection(image_features)
        return nn.functional.normalize(image_features, dim=-1)

    def encode_text(self, text):
        # Encode text only
        text_features = self.text_model(**text).pooler_output
        text_features = self.text_projection(text_features)
        return nn.functional.normalize(text_features, dim=-1)


# Helper function to load the custom model
def load_custom_siglip(vision_model_name="google/vit-base-patch16-224-in21k",
                       text_model_name="bert-base-uncased",
                       projection_dim=512):
    return CustomSigLIP(vision_model_name, text_model_name, projection_dim)


def compute_metrics(predictions, targets):
    """
    Compute multi-class classification metrics.
    Args:
        predictions: Tensor of predicted class indices.
        targets: Tensor of ground truth class indices.
    Returns:
        dict: Metrics (accuracy, precision, recall, F1-score).
    """
    predictions = predictions.cpu().numpy()
    targets = targets.cpu().numpy()

    metrics = {
        "accuracy": accuracy_score(targets, predictions),
        "precision": precision_score(targets, predictions, average="weighted", zero_division=0),
        "recall": recall_score(targets, predictions, average="weighted", zero_division=0),
        "f1_score": f1_score(targets, predictions, average="weighted", zero_division=0),
    }
    return metrics


# Parameters
vision_model_name = "google/vit-base-patch16-224-in21k"
text_model_name = "bert-base-uncased"
projection_dim = 512
batch_size = 32
num_epochs = 5
learning_rate = 1e-5

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

#

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Data preparation
train_transform = transforms.Compose([
    transforms.Resize((224, 224)),  # Resize to 224x224 (ViT input size)
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # ImageNet normalization
])

val_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

train_dataset = datasets.Flowers102(root="./data", split="train", transform=train_transform, download=True)
val_dataset = datasets.Flowers102(root="./data", split="val", transform=val_transform, download=True)
test_dataset = datasets.Flowers102(root="./data", split="test", transform=val_transform, download=True)


train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# Load SigLIP model
model = CustomSigLIP(vision_model_name, text_model_name, projection_dim)
model.to(device)

label_names = [item.strip().strip("'").strip('"') for item in open('flower_labels.txt').read().split('\n')]
labels_to_texts = {idx: name for idx, name in enumerate(label_names)}


# Initialize Aim tracker
run = aim.Run(repo='.', experiment='multi_class_classification_flowers102')
run["hparams"] = {
    "vision_model_name": vision_model_name,
    "text_model_name": text_model_name,
    "projection_dim": projection_dim,
    "batch_size": batch_size,
    "num_epochs": num_epochs,
    "learning_rate": learning_rate,
}

In [4]:
# Training loop
def train_with_logging(model, train_loader, labels_to_texts, tokenizer, optimizer, criterion, device, num_epochs, val_loader=None):
    model.train()

    for epoch in range(num_epochs):
        total_loss = 0.0
        all_predictions = []
        all_targets = []

        for images, labels in tqdm(train_loader):
            # Move data to device
            images = images.to(device)
            labels = labels.to(device)

            # Generate text descriptions for the labels
            text_descriptions = [labels_to_texts[label.item()][0] for label in labels]
            encoded_texts = tokenizer(text_descriptions, return_tensors="pt", padding=True, truncation=True).to(device)

            # Forward pass
            logits = model(images, encoded_texts)

            # Compute labels for cross-entropy loss (diagonal matrix)
            targets = torch.arange(len(images)).to(device)
            loss = criterion(logits, targets)

            # Backward pass
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Accumulate loss
            total_loss += loss.item()

            # Collect predictions and targets for metrics
            predictions = torch.argmax(logits, dim=1)
            all_predictions.append(predictions)
            all_targets.append(targets)

        # Concatenate all predictions and targets
        all_predictions = torch.cat(all_predictions, dim=0)
        all_targets = torch.cat(all_targets, dim=0)

        # Compute metrics
        metrics = compute_metrics(all_predictions, all_targets)

        # Log metrics and loss to Aim
        run.track(total_loss / len(train_loader), name='train_loss', step=epoch, context={"subset":"train"})
        for metric_name, metric_value in metrics.items():
            run.track(metric_value, metric_name, step=epoch, context={"subset":"train"})

        # Print epoch summary
        print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {total_loss / len(train_loader):.4f}")
        print(f"Train Metrics: {metrics}")

        # Validation loop (if val_loader is provided)
        if val_loader:
            val_metrics = test_with_logging(model, val_loader, labels_to_texts, tokenizer, device, log_prefix="val")
            print(f"Validation Metrics: {val_metrics}")


def test_with_logging(model, test_loader, labels_to_texts, tokenizer, device, log_prefix="test"):
    model.eval()
    all_predictions = []
    all_targets = []

    with torch.no_grad():
        for images, labels in tqdm(test_loader):
            # Move data to device
            images = images.to(device)
            labels = labels.to(device)

            # Generate text descriptions for the labels
            text_descriptions = [labels_to_texts[label.item()][0] for label in labels]
            encoded_texts = tokenizer(text_descriptions, return_tensors="pt", padding=True, truncation=True).to(device)

            # Forward pass
            logits = model(images, encoded_texts)

            # Collect predictions and targets
            predictions = torch.argmax(logits, dim=1)
            all_predictions.append(predictions)
            all_targets.append(torch.arange(len(labels)).to(device))  # Target is the diagonal (index matches)

    # Concatenate all predictions and targets
    all_predictions = torch.cat(all_predictions, dim=0)
    all_targets = torch.cat(all_targets, dim=0)

    # Compute metrics
    metrics = compute_metrics(all_predictions, all_targets)

    # Log metrics to Aim
    for metric_name, metric_value in metrics.items():
        run.track(metric_value, name=metric_name, context={"subset":log_prefix})

    return metrics



In [None]:
# Load feature extractor and tokenizer
feature_extractor = AutoFeatureExtractor.from_pretrained(vision_model_name)
tokenizer = AutoTokenizer.from_pretrained(text_model_name)

# Define optimizer and loss function
optimizer = optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=0.01)
criterion = nn.CrossEntropyLoss()

# Train the model
train_with_logging(model, train_loader, labels_to_texts, tokenizer, optimizer, criterion, device, num_epochs, val_loader)

# Test the model
test_metrics = test_with_logging(model, test_loader, labels_to_texts, tokenizer, device)
print(f"Test Metrics: {test_metrics}")

run.close()

100%|██████████| 32/32 [00:13<00:00,  2.41it/s]


Epoch [1/5], Loss: 3.4169
Train Metrics: {'accuracy': 0.05196078431372549, 'precision': 0.0476224736571144, 'recall': 0.05196078431372549, 'f1_score': 0.04835113497085057}


100%|██████████| 32/32 [00:07<00:00,  4.39it/s]


Validation Metrics: {'accuracy': 0.06470588235294118, 'precision': 0.03180562637670194, 'recall': 0.06470588235294118, 'f1_score': 0.0391933441254895}


100%|██████████| 32/32 [00:12<00:00,  2.67it/s]


Epoch [2/5], Loss: 2.9995
Train Metrics: {'accuracy': 0.1803921568627451, 'precision': 0.21347992340882366, 'recall': 0.1803921568627451, 'f1_score': 0.16676826429080457}


100%|██████████| 32/32 [00:07<00:00,  4.10it/s]


Validation Metrics: {'accuracy': 0.09313725490196079, 'precision': 0.04867427357739729, 'recall': 0.09313725490196079, 'f1_score': 0.05656420490982364}


 47%|████▋     | 15/32 [00:05<00:06,  2.48it/s]