

```
# This is formatted as code
```

# E3: Multimodal Representations


## 0. Setup

### 0.a Dataset creation

Download the necessary data

In [None]:
!wget https://raw.githubusercontent.com/mlfoundations/imagenet-captions/main/imagenet_captions.zip; unzip imagenet_captions.zip; wget https://s3.amazonaws.com/fast-ai-imageclas/imagenette2.tgz; tar -xzf imagenette2.tgz

--2025-07-03 12:33:45--  https://raw.githubusercontent.com/mlfoundations/imagenet-captions/main/imagenet_captions.zip
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.110.133, 185.199.109.133, 185.199.111.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.110.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 34493455 (33M) [application/zip]
Saving to: ‘imagenet_captions.zip’


2025-07-03 12:33:46 (199 MB/s) - ‘imagenet_captions.zip’ saved [34493455/34493455]

Archive:  imagenet_captions.zip
  inflating: imagenet_captions.json  
--2025-07-03 12:33:47--  https://s3.amazonaws.com/fast-ai-imageclas/imagenette2.tgz
Resolving s3.amazonaws.com (s3.amazonaws.com)... 52.217.198.8, 16.182.74.56, 54.231.130.152, ...
Connecting to s3.amazonaws.com (s3.amazonaws.com)|52.217.198.8|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1557161267 (1.5G) [application/x-tar]
Saving to: ‘imagenette

Check the data loads correctly

In [None]:
import os
import json
from pathlib import Path
from typing import Dict, List, Optional
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from torchvision import transforms

from __future__ import annotations

import argparse
import json
import os
import time
from pathlib import Path
from typing import List, Dict
import requests

from transformers import AutoModel, AutoProcessor, get_cosine_schedule_with_warmup

os.environ["TOKENIZERS_PARALLELISM"] = "false"
data_path = '.'

class DirectImageCaptionDataset(Dataset):
    """PyTorch Dataset that loads images and captions directly from files,
       but only includes captioned images in the train split."""

    imagenette_classes = {
        'n01440764': 'tench',
        'n02102040': 'English springer',
        'n02979186': 'cassette player',
        'n03000684': 'chain saw',
        'n03028079': 'church',
        'n03394916': 'French horn',
        'n03417042': 'garbage truck',
        'n03425413': 'gas pump',
        'n03445777': 'golf ball',
        'n03888257': 'parachute',
    }

    def __init__(self,
                 imagenette_path: str,
                 captions_path: str,
                 split: str = 'train',
                 transform=None,
                 use_template: bool = False,
                 n_classes: int = 10):
        """
        Args:
            imagenette_path: Path to imagenette2 directory
            captions_path: Path to imagenet_captions.json
            split: 'train' or 'val'
            transform: Optional transform to be applied to images
        """
        super().__init__()
        self.imagenette_path = Path(imagenette_path)
        self.split = split
        self.transform = transform
        assert 1 <= n_classes <= 10, f"Imagenette contains 10 classes. You cannot use {n_classes} classes."
        self.imagenette_classes = {k: self.imagenette_classes[k] for k in list(self.imagenette_classes.keys())[:n_classes]}

        # Load captions once
        print(f"Loading captions from {captions_path}...")
        with open(captions_path, 'r') as f:
            captions_data = json.load(f)
        # Map filename -> caption dict
        self.filename_to_caption = {item['filename']: item for item in captions_data}

        # Build image list, filtering out un-captioned images if train
        self.image_paths: List[Path] = []
        self._build_image_list()
        self.use_template = use_template
        print(f"Found {len(self.image_paths)} images in '{self.split}' split")

    def _build_image_list(self):
        split_dir = self.imagenette_path / self.split
        for class_dir in split_dir.iterdir():
            if not class_dir.is_dir():
                continue
            wnid = class_dir.name
            if wnid not in self.imagenette_classes:
                continue

            for img_path in class_dir.iterdir():
                if img_path.suffix.lower() not in ('.jpg', '.jpeg', '.png'):
                    continue

                if self.split == 'train':
                    # only keep if we have a caption and wnid matches
                    cap = self.filename_to_caption.get(img_path.name)
                    if cap is not None and cap.get('wnid') == wnid:
                        self.image_paths.append(img_path)
                    # else: skip this image entirely
                else:
                    # val: keep all images, will fall back to template if no caption
                    self.image_paths.append(img_path)

    def _get_caption(self, img_path: Path) -> str:
        """Get or build a suitable caption for an image."""
        fn = img_path.name
        wnid = img_path.parent.name

        cap = self.filename_to_caption.get(fn)
        if not self.use_template and cap is not None and cap.get('wnid') == wnid:
            # build something like "Title. tag1, tag2. description"
            parts = []
            if cap.get('title'):
                parts.append(cap['title'])
            if cap.get('tags'):
                tags = cap['tags'][:5]
                parts.append(', '.join(tags))
            if cap.get('description') and len(cap['description']) < 200:
                parts.append(cap['description'])
            caption = '. '.join(parts).replace('\n',' ').strip()
            return caption or f"A photo of {self.imagenette_classes.get(wnid, wnid)}"
        else:
            # fallback template
            clsname = self.imagenette_classes.get(wnid, wnid)
            return f"A photo of {clsname}"

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        img_path = self.image_paths[idx]
        image = Image.open(img_path).convert('RGB')
        if self.transform:
            image = self.transform(image)
        caption = self._get_caption(img_path)
        return {
            'image': image,
            'caption': caption,
            'image_path': str(img_path),
        }

def create_direct_data_loaders(imagenette_path: str,
                               captions_path: str,
                               batch_size: int = 32,
                               use_template: bool = False,
                               n_train_classes: int = 10):
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])

    train_ds = DirectImageCaptionDataset(
        imagenette_path, captions_path, split='train', use_template=use_template,
        transform=None, n_classes=n_train_classes
    )
    val_ds   = DirectImageCaptionDataset(
        imagenette_path, captions_path, split='val', use_template=True,
        transform=None, n_classes=n_train_classes
    )

    def collate_pil(batch):
        """Return a dict whose 'image' field is a list of PIL images (no stacking)."""
        return {
            "image": [item["image"] for item in batch],
            "caption": [item["caption"] for item in batch],
            "image_path": [item["image_path"] for item in batch],
        }

    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True,
                              num_workers=4, pin_memory=True, collate_fn=collate_pil, drop_last=True)
    val_loader   = DataLoader(val_ds,   batch_size=batch_size, shuffle=False,
                              num_workers=4, pin_memory=True, collate_fn=collate_pil, drop_last=True)
    return train_loader, val_loader, train_ds, val_ds



_IMAGENET_TXT = (
    "https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt"
)
_IMAGENET_JSON = (
    "https://s3.amazonaws.com/deep-learning-models/image-models/imagenet_class_index.json"
)

assert torch.cuda.device_count() == 2, "You need two GPUs for this exercise. You can use Kaggle, 30 hours per week are provided on 2x T4s"

def download_imagenet_label_lists() -> tuple[List[str], Dict[str, int]]:
    """Returns (`class_names`, `wnid2idx`)."""

    # Human‑readable label strings (idx→name)
    names = requests.get(_IMAGENET_TXT, timeout=30).text.strip().split("\n")

    # Mapping WNID → integer class index
    wnid2idx: Dict[str, int] = {}
    j = requests.get(_IMAGENET_JSON, timeout=30).json()
    for idx_str, (wnid, _) in j.items():
        wnid2idx[wnid] = int(idx_str)

    return names, wnid2idx


imagenette_path = os.path.join(data_path, "imagenette2")
captions_path  = os.path.join(data_path, "imagenet_captions.json")

train_loader, val_loader, train_ds, val_ds = create_direct_data_loaders(
    imagenette_path, captions_path, batch_size=256, use_template=False
)

print(f"Train images (with captions only): {len(train_ds)}")
print(f"Val images (all):               {len(val_ds)}\n")

batch = next(iter(train_loader))
print("Sample batch:")
# print(batch['image'].shape)
for i, c in enumerate(batch['caption'][:3]):
    print(f"  {i}: {c}")

def build_text_embeddings(
    processor, model, class_names: List[str], device: torch.device, batch_size: int = 128
):
    """Encode *all* prompt texts once and stack into a single tensor."""

    prompts = [f"This is a photo of a {name}." for name in class_names]
    embeddings = []
    model.eval()
    with torch.no_grad():
        for i in range(0, len(prompts), batch_size):
            batch_prompts = prompts[i : i + batch_size]
            text_inputs = processor(  # padding *must* be max_length for SigLIP
                # text=batch_prompts, padding=True, return_tensors="pt", truncation=True
                text=batch_prompts, padding="max_length", return_tensors="pt"
            ).to(device)
            emb = model.get_text_features(**text_inputs)
            # emb = F.normalize(emb, dim=-1)
            embeddings.append(emb)
    return torch.cat(embeddings, dim=0)  # (1000, D)

2025-07-03 12:38:26.902815: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1751546307.125598      35 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1751546307.192614      35 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


Loading captions from ./imagenet_captions.json...
Found 3124 images in 'train' split
Loading captions from ./imagenet_captions.json...
Found 3925 images in 'val' split
Train images (with captions only): 3124
Val images (all):               3925

Sample batch:
  0: Golf ball. I have no golf club and I'm not a golfer but I have a golf ball
  1: Luke in the hole. English Springer Spaniel, dog, Luke, puppy
  2: Fisher-Price cassette player. fisher, price. see more: <a href="http://www.beatsnothing.com" rel="noreferrer nofollow">www.beatsnothing.com</a>


### 0.b Training and evaluating SigLiP

In [None]:
def train(
    train_loader: DataLoader,
    val_loader: DataLoader,
    model: Optional[torch.nn.Module] = None,
    model_name: str = "google/siglip-base-patch16-224",
    epochs: int = 5,
    lr: float = 5e-6,
    warmup_steps: Optional[int] = None,
    device: str | torch.device = "cuda",
    eval_only: bool = False,
    use_multi_gpu: bool = True,
):
    device = torch.device(device)

    if warmup_steps is None:
        warmup_steps = 0.1*len(train_loader)*epochs
    print("\nLoading model & processor …")
    model = model or AutoModel.from_pretrained(model_name).to(device)
    processor = AutoProcessor.from_pretrained(model_name)

    class_names, wnid2idx = download_imagenet_label_lists()
    text_embeds = build_text_embeddings(processor, model, class_names, device)

    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=0)
    total_steps = epochs * len(train_loader)
    scheduler = get_cosine_schedule_with_warmup(
        optimizer, num_warmup_steps=warmup_steps, num_training_steps=total_steps
    )
    if use_multi_gpu:
        model = torch.nn.DataParallel(model)

    global_step = 0
    for epoch in range(epochs + 1):

        # -------- eval -------------------------------------------------------
        model.eval()
        num_correct = 0
        num_total = 0
        base_model = model.module if use_multi_gpu else model

        with torch.no_grad():
            for batch in val_loader:
                images: List["PIL.Image.Image"] = batch["image"]
                wnids = [Path(p).parent.name for p in batch["image_path"]]
                true_indices = torch.tensor([wnid2idx[w] for w in wnids], device=device)

                image_inputs = processor(images=images, padding="max_length", return_tensors="pt").to(device)
                image_embeds = base_model.get_image_features(**image_inputs)

                logits = image_embeds @ text_embeds.T  # (B, 1000)
                preds = logits.argmax(dim=1)
                num_correct += (preds == true_indices).sum().item()
                num_total += len(images)

        acc = num_correct / num_total
        print(f"Epoch {epoch}/{epochs} – top‑1 Imagenette (1000‑way): {acc * 100:.2f}%")
        if eval_only or epoch >= epochs:
            return base_model
        # -------- training -------------------------------------------------------
        model.train()
        epoch_loss = 0.0

        for batch in train_loader:
            images: List["PIL.Image.Image"] = batch["image"]
            captions: List[str] = batch["caption"]

            inputs = processor(
                text=captions,
                images=images,
                padding="max_length",
                truncation=True,
                return_tensors="pt",
            ).to(device)
            outputs = model(**inputs, return_loss=True)

            loss = outputs.loss

            if use_multi_gpu:
                loss = loss.mean()

            loss.backward()
            optimizer.step()
            optimizer.zero_grad(set_to_none=True)
            scheduler.step()

            epoch_loss += loss.item()
            global_step += 1

        mean_loss = epoch_loss / len(train_loader)
        print(f"Epoch {epoch}/{epochs} – train loss: {mean_loss:.4f}")

    return base_model

## 1. Fix the training evaluation

In [None]:
def train(
    train_loader: DataLoader,
    val_loader: DataLoader,
    model: Optional[torch.nn.Module] = None,
    model_name: str = "google/siglip-base-patch16-224",
    epochs: int = 5,
    lr: float = 5e-6,
    warmup_steps: Optional[int] = None,
    device: str | torch.device = "cuda",
    eval_only: bool = False,
    use_multi_gpu: bool = True,
):
    device = torch.device(device)

    if warmup_steps is None:
        warmup_steps = 0.1*len(train_loader)*epochs
    print("\nLoading model & processor …")
    model = model or AutoModel.from_pretrained(model_name).to(device)
    processor = AutoProcessor.from_pretrained(model_name)

    class_names, wnid2idx = download_imagenet_label_lists()

    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=0)
    total_steps = epochs * len(train_loader)
    scheduler = get_cosine_schedule_with_warmup(
        optimizer, num_warmup_steps=warmup_steps, num_training_steps=total_steps
    )
    if use_multi_gpu:
        model = torch.nn.DataParallel(model)

    global_step = 0
    for epoch in range(epochs + 1):

        # -------- eval -------------------------------------------------------
        model.eval()
        num_correct = 0
        num_total = 0
        base_model = model.module if use_multi_gpu else model

        with torch.no_grad():
            text_embeds = build_text_embeddings(processor, base_model, class_names, device) # Moved down in loop, as text embeds should update?
            for batch in val_loader:
                images: List["PIL.Image.Image"] = batch["image"]
                wnids = [Path(p).parent.name for p in batch["image_path"]]
                true_indices = torch.tensor([wnid2idx[w] for w in wnids], device=device)

                image_inputs = processor(images=images, padding="max_length", return_tensors="pt").to(device)
                image_embeds = base_model.get_image_features(**image_inputs)

                logits = image_embeds @ text_embeds.T  # (B, 1000)
                preds = logits.argmax(dim=1)
                num_correct += (preds == true_indices).sum().item()
                num_total += len(images)

        acc = num_correct / num_total
        print(f"Epoch {epoch}/{epochs} – top‑1 Imagenette (1000‑way): {acc * 100:.2f}%")
        if eval_only or epoch >= epochs:
            return base_model
        # -------- training -------------------------------------------------------
        model.train()
        epoch_loss = 0.0

        for batch in train_loader:
            images: List["PIL.Image.Image"] = batch["image"]
            captions: List[str] = batch["caption"]

            inputs = processor(
                text=captions,
                images=images,
                padding="max_length",
                truncation=True,
                return_tensors="pt",
            ).to(device)
            outputs = model(**inputs, return_loss=True)

            loss = outputs.loss

            if use_multi_gpu:
                loss = loss.mean()

            loss.backward()
            optimizer.step()
            optimizer.zero_grad(set_to_none=True)
            scheduler.step()

            epoch_loss += loss.item()
            global_step += 1

        mean_loss = epoch_loss / len(train_loader)
        print(f"Epoch {epoch}/{epochs} – train loss: {mean_loss:.4f}")

    return base_model

In [None]:
from difflib import unified_diff

# pick the execution numbers you want to compare
left  = In[3].splitlines()   # cell with prompt In [3]:
right = In[4].splitlines()   # cell with prompt In [5]:

diff_lines = unified_diff(
    left,
    right,
    fromfile='cell 3',
    tofile='cell 4',
    lineterm=''            # avoid extra blank lines
)

print('\n'.join(diff_lines))

--- cell 3
+++ cell 4
@@ -19,7 +19,6 @@
     processor = AutoProcessor.from_pretrained(model_name)
 
     class_names, wnid2idx = download_imagenet_label_lists()
-    text_embeds = build_text_embeddings(processor, model, class_names, device)
 
     optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=0)
     total_steps = epochs * len(train_loader)
@@ -39,6 +38,7 @@
         base_model = model.module if use_multi_gpu else model
 
         with torch.no_grad():
+            text_embeds = build_text_embeddings(processor, base_model, class_names, device) # Moved down in loop, as text embeds should update?
             for batch in val_loader:
                 images: List["PIL.Image.Image"] = batch["image"]
                 wnids = [Path(p).parent.name for p in batch["image_path"]]


## 2. Finetuning SigLip

### 2.a) Using templates vs captions

In [None]:
train_loader, val_loader, train_ds, val_ds = create_direct_data_loaders(
    imagenette_path, captions_path, batch_size=64, use_template=False, n_train_classes=10
)

model = train(train_loader, val_loader, )
print(f"\nFinal validation accuracy using templates:")
model = train(train_loader, val_loader, model=model, eval_only=True)

print(f"{'%'*100}\n Using template \n {'%'*100}")
train_loader, val_loader, train_ds, val_ds = create_direct_data_loaders(
    imagenette_path, captions_path, batch_size=64, use_template=True, n_train_classes=10
)
model = train(train_loader, val_loader,)
print(f"\nFinal validation accuracy using templates:")
model = train(train_loader, val_loader, model=model, eval_only=True)

Loading captions from ./imagenet_captions.json...
Found 3124 images in 'train' split
Loading captions from ./imagenet_captions.json...
Found 3925 images in 'val' split

Loading model & processor …


config.json:   0%|          | 0.00/432 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/813M [00:00<?, ?B/s]

preprocessor_config.json:   0%|          | 0.00/368 [00:00<?, ?B/s]

Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.


tokenizer_config.json:   0%|          | 0.00/711 [00:00<?, ?B/s]

spiece.model:   0%|          | 0.00/798k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/409 [00:00<?, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

Epoch 0/5 – top‑1 Imagenette (1000‑way): 77.79%




Epoch 0/5 – train loss: 2.0811
Epoch 1/5 – top‑1 Imagenette (1000‑way): 70.13%
Epoch 1/5 – train loss: 1.1454
Epoch 2/5 – top‑1 Imagenette (1000‑way): 72.64%
Epoch 2/5 – train loss: 0.8410
Epoch 3/5 – top‑1 Imagenette (1000‑way): 70.90%
Epoch 3/5 – train loss: 0.6911
Epoch 4/5 – top‑1 Imagenette (1000‑way): 69.88%
Epoch 4/5 – train loss: 0.6261
Epoch 5/5 – top‑1 Imagenette (1000‑way): 69.49%

Final validation accuracy using templates:

Loading model & processor …
Epoch 0/5 – top‑1 Imagenette (1000‑way): 69.49%
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
 Using template 
 %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
Loading captions from ./imagenet_captions.json...
Found 3124 images in 'train' split
Loading captions from ./imagenet_captions.json...
Found 3925 images in 'val' split

Loading model & processor …
Epoch 0/5 – top‑1 Imagenette (1000‑way): 77.79%
Epoch 0/5 – train l

### 2.b) Explore the effect of batch size and number of classes

In [None]:
# ---- YOUR CODE STARTS HERE ----
batch_size = 128
print(f"{'+'*100} \nUsing batch size of: {batch_size} and template")
print(f"{'%'*100}\n Using template \n {'%'*100}")
train_loader, val_loader, train_ds, val_ds = create_direct_data_loaders(
    imagenette_path, captions_path, batch_size=batch_size, use_template=True, n_train_classes=10
)
model = train(train_loader, val_loader, )
print(f"Final validation accuracy using templates:")
model = train(train_loader, val_loader, model=model, eval_only=True)
# ---- YOUR CODE ENDS HERE ----

++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 
Using batch size of: 128 and template
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
 Using template 
 %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
Loading captions from ./imagenet_captions.json...
Found 3124 images in 'train' split
Loading captions from ./imagenet_captions.json...
Found 3925 images in 'val' split

Loading model & processor …
Epoch 0/5 – top‑1 Imagenette (1000‑way): 78.05%
Epoch 0/5 – train loss: 5.2859
Epoch 1/5 – top‑1 Imagenette (1000‑way): 82.76%
Epoch 1/5 – train loss: 3.2692
Epoch 2/5 – top‑1 Imagenette (1000‑way): 86.22%
Epoch 2/5 – train loss: 3.0853
Epoch 3/5 – top‑1 Imagenette (1000‑way): 86.64%
Epoch 3/5 – train loss: 3.0743
Epoch 4/5 – top‑1 Imagenette (1000‑way): 86.95%
Epoch 4/5 – train loss: 3.0524
Epoch 5/5 – top‑1 Imagenette (1000‑way): 86.95%


In [None]:
# ---- YOUR CODE STARTS HERE ----
batch_size = 32
print(f"{'+'*100} \nUsing batch size of: {batch_size} and template")
print(f"{'%'*100}\n Using template \n {'%'*100}")
train_loader, val_loader, train_ds, val_ds = create_direct_data_loaders(
    imagenette_path, captions_path, batch_size=batch_size, use_template=True, n_train_classes=10
)
model = train(train_loader, val_loader, )
print(f"Final validation accuracy using templates:")
model = train(train_loader, val_loader, model=model, eval_only=True)
# ---- YOUR CODE ENDS HERE ----

In [None]:
# ---- YOUR CODE STARTS HERE ----
batch_size = 16
print(f"{'+'*100} \nUsing batch size of: {batch_size} and template")
print(f"{'%'*100}\n Using template \n {'%'*100}")
train_loader, val_loader, train_ds, val_ds = create_direct_data_loaders(
    imagenette_path, captions_path, batch_size=batch_size, use_template=True, n_train_classes=10
)
model = train(train_loader, val_loader, )
print(f"Final validation accuracy using templates:")
model = train(train_loader, val_loader, model=model, eval_only=True)
# ---- YOUR CODE ENDS HERE ----

++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 
Using batch size of: 16 and template
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
 Using template 
 %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
Loading captions from ./imagenet_captions.json...
Found 3124 images in 'train' split
Loading captions from ./imagenet_captions.json...
Found 3925 images in 'val' split

Loading model & processor …
Epoch 0/5 – top‑1 Imagenette (1000‑way): 77.73%
Epoch 0/5 – train loss: 1.5715
Epoch 1/5 – top‑1 Imagenette (1000‑way): 87.58%
Epoch 1/5 – train loss: 1.2983
Epoch 2/5 – top‑1 Imagenette (1000‑way): 89.16%
Epoch 2/5 – train loss: 1.2574
Epoch 3/5 – top‑1 Imagenette (1000‑way): 90.46%
Epoch 3/5 – train loss: 1.2617
Epoch 4/5 – top‑1 Imagenette (1000‑way): 91.25%
Epoch 4/5 – train loss: 1.2339
Epoch 5/5 – top‑1 Imagenette (1000‑way): 91.05%
F

In [None]:
# ---- YOUR CODE STARTS HERE ----
batch_size = 8
print(f"{'+'*100} \nUsing batch size of: {batch_size} and template")
print(f"{'%'*100}\n Using template \n {'%'*100}")
train_loader, val_loader, train_ds, val_ds = create_direct_data_loaders(
    imagenette_path, captions_path, batch_size=batch_size, use_template=True, n_train_classes=10
)
model = train(train_loader, val_loader, )
print(f"Final validation accuracy using templates:")
model = train(train_loader, val_loader, model=model, eval_only=True)
# ---- YOUR CODE ENDS HERE ----

++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 
Using batch size of: 8 and template
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
 Using template 
 %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
Loading captions from ./imagenet_captions.json...
Found 3124 images in 'train' split
Loading captions from ./imagenet_captions.json...
Found 3925 images in 'val' split

Loading model & processor …
Epoch 0/5 – top‑1 Imagenette (1000‑way): 77.73%
Epoch 0/5 – train loss: 1.0309
Epoch 1/5 – top‑1 Imagenette (1000‑way): 89.72%
Epoch 1/5 – train loss: 0.8393
Epoch 2/5 – top‑1 Imagenette (1000‑way): 93.09%
Epoch 2/5 – train loss: 0.7668
Epoch 3/5 – top‑1 Imagenette (1000‑way): 90.28%
Epoch 3/5 – train loss: 0.7438
Epoch 4/5 – top‑1 Imagenette (1000‑way): 93.42%
Epoch 4/5 – train loss: 0.7157
Epoch 5/5 – top‑1 Imagenette (1000‑way): 93.44%
Fi

In [None]:
# ---- YOUR CODE STARTS HERE ----
batch_size = 4
print(f"{'+'*100} \nUsing batch size of: {batch_size} and template")
print(f"{'%'*100}\n Using template \n {'%'*100}")
train_loader, val_loader, train_ds, val_ds = create_direct_data_loaders(
    imagenette_path, captions_path, batch_size=batch_size, use_template=True, n_train_classes=10
)
model = train(train_loader, val_loader, )
print(f"Final validation accuracy using templates:")
model = train(train_loader, val_loader, model=model, eval_only=True)
# ---- YOUR CODE ENDS HERE ----

++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 
Using batch size of: 4 and template
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
 Using template 
 %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
Loading captions from ./imagenet_captions.json...
Found 3124 images in 'train' split
Loading captions from ./imagenet_captions.json...
Found 3925 images in 'val' split

Loading model & processor …
Epoch 0/5 – top‑1 Imagenette (1000‑way): 77.68%
Epoch 0/5 – train loss: 0.5712
Epoch 1/5 – top‑1 Imagenette (1000‑way): 93.55%


KeyboardInterrupt: 

In [None]:
# ---- YOUR CODE STARTS HERE ----
n_train_classes = 1
print(f"{'+'*100} \nUsing n_train_classes of: {n_train_classes} and template")
print(f"{'%'*100}\n Using template \n {'%'*100}")
train_loader, val_loader, train_ds, val_ds = create_direct_data_loaders(
    imagenette_path, captions_path, batch_size=64, use_template=True, n_train_classes=n_train_classes
)
model = train(train_loader, val_loader, )
print(f"Final validation accuracy using templates:")
model = train(train_loader, val_loader, model=model, eval_only=True)
# ---- YOUR CODE ENDS HERE ----

++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 
Using n_train_classes of: 1 and template
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
 Using template 
 %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
Loading captions from ./imagenet_captions.json...
Found 154 images in 'train' split
Loading captions from ./imagenet_captions.json...
Found 387 images in 'val' split

Loading model & processor …
Epoch 0/5 – top‑1 Imagenette (1000‑way): 89.06%
Epoch 0/5 – train loss: 89.6434
Epoch 1/5 – top‑1 Imagenette (1000‑way): 28.12%
Epoch 1/5 – train loss: 10.4061
Epoch 2/5 – top‑1 Imagenette (1000‑way): 0.78%
Epoch 2/5 – train loss: 14.2488
Epoch 3/5 – top‑1 Imagenette (1000‑way): 0.52%
Epoch 3/5 – train loss: 13.7850
Epoch 4/5 – top‑1 Imagenette (1000‑way): 1.04%
Epoch 4/5 – train loss: 12.9667
Epoch 5/5 – top‑1 Imagenette (1000‑way): 1.30

In [None]:
# ---- YOUR CODE STARTS HERE ----
n_train_classes = 2
print(f"{'+'*100} \nUsing n_train_classes of: {n_train_classes} and template")
print(f"{'%'*100}\n Using template \n {'%'*100}")
train_loader, val_loader, train_ds, val_ds = create_direct_data_loaders(
    imagenette_path, captions_path, batch_size=64, use_template=True, n_train_classes=n_train_classes
)
model = train(train_loader, val_loader, )
print(f"Final validation accuracy using templates:")
model = train(train_loader, val_loader, model=model, eval_only=True)
# ---- YOUR CODE ENDS HERE ----

In [None]:
# ---- YOUR CODE STARTS HERE ----
n_train_classes = 4
print(f"{'+'*100} \nUsing n_train_classes of: {n_train_classes} and template")
print(f"{'%'*100}\n Using template \n {'%'*100}")
train_loader, val_loader, train_ds, val_ds = create_direct_data_loaders(
    imagenette_path, captions_path, batch_size=64, use_template=True, n_train_classes=n_train_classes
)
model = train(train_loader, val_loader, )
print(f"Final validation accuracy using templates:")
model = train(train_loader, val_loader, model=model, eval_only=True)
# ---- YOUR CODE ENDS HERE ----

In [None]:
# ---- YOUR CODE STARTS HERE ----
n_train_classes = 8
print(f"{'+'*100} \nUsing n_train_classes of: {n_train_classes} and template")
print(f"{'%'*100}\n Using template \n {'%'*100}")
train_loader, val_loader, train_ds, val_ds = create_direct_data_loaders(
    imagenette_path, captions_path, batch_size=64, use_template=True, n_train_classes=n_train_classes
)
model = train(train_loader, val_loader, )
print(f"Final validation accuracy using templates:")
model = train(train_loader, val_loader, model=model, eval_only=True)
# ---- YOUR CODE ENDS HERE ----

In [None]:
# ---- YOUR CODE STARTS HERE ----
n_train_classes = 4
batch_size = 32
print(f"{'+'*100} \nUsing n_train_classes of: {n_train_classes} and template")
print(f"{'%'*100}\n Using template \n {'%'*100}")
train_loader, val_loader, train_ds, val_ds = create_direct_data_loaders(
    imagenette_path, captions_path, batch_size=batch_size, use_template=True, n_train_classes=n_train_classes
)
model = train(train_loader, val_loader, )
print(f"Final validation accuracy using templates:")
model = train(train_loader, val_loader, model=model, eval_only=True)
# ---- YOUR CODE ENDS HERE ----

++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 
Using n_train_classes of: 4 and template
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
 Using template 
 %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
Loading captions from ./imagenet_captions.json...
Found 1069 images in 'train' split
Loading captions from ./imagenet_captions.json...
Found 1525 images in 'val' split

Loading model & processor …
Epoch 0/5 – top‑1 Imagenette (1000‑way): 71.48%
Epoch 0/5 – train loss: 4.0838
Epoch 1/5 – top‑1 Imagenette (1000‑way): 83.44%
Epoch 1/5 – train loss: 2.6809
Epoch 2/5 – top‑1 Imagenette (1000‑way): 86.97%
Epoch 2/5 – train loss: 2.6245
Epoch 3/5 – top‑1 Imagenette (1000‑way): 86.97%
Epoch 3/5 – train loss: 2.5890
Epoch 4/5 – top‑1 Imagenette (1000‑way): 87.23%
Epoch 4/5 – train loss: 2.5678
Epoch 5/5 – top‑1 Imagenette (1000‑way): 87.2

## 3. Finetuning SigLiT

In [None]:
# ---- YOUR CODE STARTS HERE ----
def fine_tune_siglit_style(
    train_loader: DataLoader,
    val_loader: DataLoader,
    model: Optional[torch.nn.Module] = None,
    model_name: str = "google/siglip-base-patch16-224",
    epochs: int = 5,
    lr: float = 5e-6,
    warmup_steps: Optional[int] = None,
    device: str | torch.device = "cuda",
    eval_only: bool = False,
    use_multi_gpu: bool = True,
):
    device = torch.device(device)

    if warmup_steps is None:
        warmup_steps = 0.1*len(train_loader)*epochs
    print("\nLoading model & processor …")
    model = model or AutoModel.from_pretrained(model_name).to(device)
    processor = AutoProcessor.from_pretrained(model_name)

    # ❄️❄️❄️freeze the vision tower❄️❄️❄️
    for parameter in model.vision_model.parameters():
        parameter.requires_grad = False


    # only optimize text tower
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=0)

    class_names, wnid2idx = download_imagenet_label_lists()

    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=0)
    total_steps = epochs * len(train_loader)
    scheduler = get_cosine_schedule_with_warmup(
        optimizer, num_warmup_steps=warmup_steps, num_training_steps=total_steps
    )


    if use_multi_gpu:
        model = torch.nn.DataParallel(model)

    global_step = 0
    for epoch in range(epochs + 1):

        base_model = model.module if use_multi_gpu else model

        # -------- eval -------------------------------------------------------
        model.eval()
        num_correct = 0
        num_total = 0
        base_model = model.module if use_multi_gpu else model

        with torch.no_grad():
            text_embeds = build_text_embeddings(processor, base_model, class_names, device)

            for batch in val_loader:
                images: List["PIL.Image.Image"] = batch["image"]
                wnids = [Path(p).parent.name for p in batch["image_path"]]
                true_indices = torch.tensor([wnid2idx[w] for w in wnids], device=device)

                image_inputs = processor(images=images, padding="max_length", return_tensors="pt").to(device)
                image_embeds = base_model.get_image_features(**image_inputs)

                logits = image_embeds @ text_embeds.T  # (B, 1000)
                preds = logits.argmax(dim=1)
                num_correct += (preds == true_indices).sum().item()
                num_total += len(images)

        acc = num_correct / num_total
        print(f"Epoch {epoch}/{epochs} – top‑1 Imagenette (1000‑way): {acc * 100:.2f}%")
        if eval_only or epoch >= epochs:
            return base_model
        # -------- training -------------------------------------------------------
        model.train()
        epoch_loss = 0.0

        for batch in train_loader:
            images: List["PIL.Image.Image"] = batch["image"]
            captions: List[str] = batch["caption"]

            inputs = processor(
                text=captions,
                images=images,
                padding="max_length",
                truncation=True,
                return_tensors="pt",
            ).to(device)
            outputs = model(**inputs, return_loss=True)

            loss = outputs.loss

            if use_multi_gpu:
                loss = loss.mean()

            loss.backward()
            optimizer.step()
            optimizer.zero_grad(set_to_none=True)
            scheduler.step()

            epoch_loss += loss.item()
            global_step += 1

        mean_loss = epoch_loss / len(train_loader)
        print(f"Epoch {epoch}/{epochs} – train loss: {mean_loss:.4f}")

    return base_model
# ---- YOUR CODE ENDS HERE ----

In [None]:
# ---- YOUR CODE STARTS HERE ----
# Fine tuning SigLit model
n_train_classes = 10
batch_size = 128
print(f"{'+'*100} \nUsing batch_size of: {batch_size} and template")
print(f"{'+'*100} \nUsing n_train_classes of: {n_train_classes} and template")
print(f"{'%'*100}\n Using template \n {'%'*100}")
train_loader, val_loader, train_ds, val_ds = create_direct_data_loaders(
    imagenette_path, captions_path, batch_size=batch_size, use_template=True, n_train_classes=n_train_classes
)
model = fine_tune_siglit_style(train_loader, val_loader, )
print(f"Final validation accuracy using templates:")
model = fine_tune_siglit_style(train_loader, val_loader, model=model, eval_only=True)
# ---- YOUR CODE ENDS HERE ----

++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 
Using batch_size of: 128 and template
++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 
Using n_train_classes of: 10 and template
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
 Using template 
 %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
Loading captions from ./imagenet_captions.json...
Found 3124 images in 'train' split
Loading captions from ./imagenet_captions.json...
Found 3925 images in 'val' split

Loading model & processor …
Epoch 0/5 – top‑1 Imagenette (1000‑way): 78.05%
Epoch 0/5 – train loss: 5.3876
Epoch 1/5 – top‑1 Imagenette (1000‑way): 76.64%
Epoch 1/5 – train loss: 3.5065
Epoch 2/5 – top‑1 Imagenette (1000‑way): 80.83%
Epoch 2/5 – train loss: 3.2617
Epoch 3/5 – top‑1 Imagenette (1000‑way): 81.51%
Epoch 3/5 – tr

In [None]:
# ---- YOUR CODE STARTS HERE ----
# Fine tuning SigLit model
n_train_classes = 10
batch_size = 64
print(f"{'+'*100} \nUsing batch_size of: {batch_size} and template")
print(f"{'+'*100} \nUsing n_train_classes of: {n_train_classes} and template")
print(f"{'%'*100}\n Using template \n {'%'*100}")
train_loader, val_loader, train_ds, val_ds = create_direct_data_loaders(
    imagenette_path, captions_path, batch_size=batch_size, use_template=True, n_train_classes=n_train_classes
)
model = fine_tune_siglit_style(train_loader, val_loader, )
print(f"Final validation accuracy using templates:")
model = fine_tune_siglit_style(train_loader, val_loader, model=model, eval_only=True)
# ---- YOUR CODE ENDS HERE ----

++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 
Using batch_size of: 64 and template
++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 
Using n_train_classes of: 10 and template
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
 Using template 
 %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
Loading captions from ./imagenet_captions.json...
Found 3124 images in 'train' split
Loading captions from ./imagenet_captions.json...
Found 3925 images in 'val' split

Loading model & processor …
Epoch 0/5 – top‑1 Imagenette (1000‑way): 77.79%
Epoch 0/5 – train loss: 3.4323
Epoch 1/5 – top‑1 Imagenette (1000‑way): 79.59%
Epoch 1/5 – train loss: 2.5941
Epoch 2/5 – top‑1 Imagenette (1000‑way): 80.20%
Epoch 2/5 – train loss: 2.5091
Epoch 3/5 – top‑1 Imagenette (1000‑way): 80.71%
Epoch 3/5 – tra

In [None]:
# ---- YOUR CODE STARTS HERE ----
# Fine tuning SigLit model
n_train_classes = 10
batch_size = 32
print(f"{'+'*100} \nUsing batch_size of: {batch_size} and template")
print(f"{'+'*100} \nUsing n_train_classes of: {n_train_classes} and template")
print(f"{'%'*100}\n Using template \n {'%'*100}")
train_loader, val_loader, train_ds, val_ds = create_direct_data_loaders(
    imagenette_path, captions_path, batch_size=batch_size, use_template=True, n_train_classes=n_train_classes
)
model = fine_tune_siglit_style(train_loader, val_loader, )
print(f"Final validation accuracy using templates:")
model = fine_tune_siglit_style(train_loader, val_loader, model=model, eval_only=True)
# ---- YOUR CODE ENDS HERE ----

++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 
Using n_train_classes of: 10 and template
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
 Using template 
 %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
Loading captions from ./imagenet_captions.json...
Found 3124 images in 'train' split
Loading captions from ./imagenet_captions.json...
Found 3925 images in 'val' split

Loading model & processor …
Epoch 0/5 – top‑1 Imagenette (1000‑way): 77.79%
Epoch 0/5 – train loss: 2.3395
Epoch 1/5 – top‑1 Imagenette (1000‑way): 83.66%
Epoch 1/5 – train loss: 1.9254
Epoch 2/5 – top‑1 Imagenette (1000‑way): 83.20%
Epoch 2/5 – train loss: 1.8958
Epoch 3/5 – top‑1 Imagenette (1000‑way): 82.33%
Epoch 3/5 – train loss: 1.8720
Epoch 4/5 – top‑1 Imagenette (1000‑way): 82.48%
Epoch 4/5 – train loss: 1.8036
Epoch 5/5 – top‑1 Imagenette (1000‑way): 83.

In [None]:
# ---- YOUR CODE STARTS HERE ----
# Fine tuning SigLit model
n_train_classes = 10
batch_size = 16
print(f"{'+'*100} \nUsing batch_size of: {batch_size} and template")
print(f"{'+'*100} \nUsing n_train_classes of: {n_train_classes} and template")
print(f"{'%'*100}\n Using template \n {'%'*100}")
train_loader, val_loader, train_ds, val_ds = create_direct_data_loaders(
    imagenette_path, captions_path, batch_size=batch_size, use_template=True, n_train_classes=n_train_classes
)
model = fine_tune_siglit_style(train_loader, val_loader, )
print(f"Final validation accuracy using templates:")
model = fine_tune_siglit_style(train_loader, val_loader, model=model, eval_only=True)
# ---- YOUR CODE ENDS HERE ----

++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 
Using batch_size of: 16 and template
++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 
Using n_train_classes of: 10 and template
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
 Using template 
 %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
Loading captions from ./imagenet_captions.json...
Found 3124 images in 'train' split
Loading captions from ./imagenet_captions.json...
Found 3925 images in 'val' split

Loading model & processor …
Epoch 0/5 – top‑1 Imagenette (1000‑way): 77.73%
Epoch 0/5 – train loss: 1.6496
Epoch 1/5 – top‑1 Imagenette (1000‑way): 82.65%
Epoch 1/5 – train loss: 1.4073
Epoch 2/5 – top‑1 Imagenette (1000‑way): 81.84%
Epoch 2/5 – train loss: 1.3022
Epoch 3/5 – top‑1 Imagenette (1000‑way): 80.61%
Epoch 3/5 – tra

In [None]:
# ---- YOUR CODE STARTS HERE ----
# Fine tuning SigLit model
n_train_classes = 10
batch_size = 8
print(f"{'+'*100} \nUsing batch_size of: {batch_size} and template")
print(f"{'+'*100} \nUsing n_train_classes of: {n_train_classes} and template")
print(f"{'%'*100}\n Using template \n {'%'*100}")
train_loader, val_loader, train_ds, val_ds = create_direct_data_loaders(
    imagenette_path, captions_path, batch_size=batch_size, use_template=True, n_train_classes=n_train_classes
)
model = fine_tune_siglit_style(train_loader, val_loader, )
print(f"Final validation accuracy using templates:")
model = fine_tune_siglit_style(train_loader, val_loader, model=model, eval_only=True)
# ---- YOUR CODE ENDS HERE ----

In [None]:
# ---- YOUR CODE STARTS HERE ----
# Fine tuning SigLit model
n_train_classes = 10
batch_size = 4
print(f"{'+'*100} \nUsing batch_size of: {batch_size} and template")
print(f"{'+'*100} \nUsing n_train_classes of: {n_train_classes} and template")
print(f"{'%'*100}\n Using template \n {'%'*100}")
train_loader, val_loader, train_ds, val_ds = create_direct_data_loaders(
    imagenette_path, captions_path, batch_size=batch_size, use_template=True, n_train_classes=n_train_classes
)
model = fine_tune_siglit_style(train_loader, val_loader, )
print(f"Final validation accuracy using templates:")
model = fine_tune_siglit_style(train_loader, val_loader, model=model, eval_only=True)
# ---- YOUR CODE ENDS HERE ----

In [None]:
# ---- YOUR CODE STARTS HERE ----
# Fine tuning SigLit model
n_train_classes = 8
batch_size = 32
print(f"{'+'*100} \nUsing batch_size of: {batch_size} and template")
print(f"{'+'*100} \nUsing n_train_classes of: {n_train_classes} and template")
print(f"{'%'*100}\n Using template \n {'%'*100}")
train_loader, val_loader, train_ds, val_ds = create_direct_data_loaders(
    imagenette_path, captions_path, batch_size=batch_size, use_template=True, n_train_classes=n_train_classes
)
model = fine_tune_siglit_style(train_loader, val_loader, )
print(f"Final validation accuracy using templates:")
model = fine_tune_siglit_style(train_loader, val_loader, model=model, eval_only=True)
# ---- YOUR CODE ENDS HERE ----

In [None]:
# ---- YOUR CODE STARTS HERE ----
# Fine tuning SigLit model
n_train_classes = 6
batch_size = 32
print(f"{'+'*100} \nUsing batch_size of: {batch_size} and template")
print(f"{'+'*100} \nUsing n_train_classes of: {n_train_classes} and template")
print(f"{'%'*100}\n Using template \n {'%'*100}")
train_loader, val_loader, train_ds, val_ds = create_direct_data_loaders(
    imagenette_path, captions_path, batch_size=batch_size, use_template=True, n_train_classes=n_train_classes
)
model = fine_tune_siglit_style(train_loader, val_loader, )
print(f"Final validation accuracy using templates:")
model = fine_tune_siglit_style(train_loader, val_loader, model=model, eval_only=True)
# ---- YOUR CODE ENDS HERE ----

In [None]:
# ---- YOUR CODE STARTS HERE ----
# Fine tuning SigLit model
n_train_classes = 4
batch_size = 32
print(f"{'+'*100} \nUsing batch_size of: {batch_size} and template")
print(f"{'+'*100} \nUsing n_train_classes of: {n_train_classes} and template")
print(f"{'%'*100}\n Using template \n {'%'*100}")
train_loader, val_loader, train_ds, val_ds = create_direct_data_loaders(
    imagenette_path, captions_path, batch_size=batch_size, use_template=True, n_train_classes=n_train_classes
)
model = fine_tune_siglit_style(train_loader, val_loader, )
print(f"Final validation accuracy using templates:")
model = fine_tune_siglit_style(train_loader, val_loader, model=model, eval_only=True)
# ---- YOUR CODE ENDS HERE ----

In [None]:
# ---- YOUR CODE STARTS HERE ----
# Fine tuning SigLit model
n_train_classes = 2
batch_size = 32
print(f"{'+'*100} \nUsing batch_size of: {batch_size} and template")
print(f"{'+'*100} \nUsing n_train_classes of: {n_train_classes} and template")
print(f"{'%'*100}\n Using template \n {'%'*100}")
train_loader, val_loader, train_ds, val_ds = create_direct_data_loaders(
    imagenette_path, captions_path, batch_size=batch_size, use_template=True, n_train_classes=n_train_classes
)
model = fine_tune_siglit_style(train_loader, val_loader, )
print(f"Final validation accuracy using templates:")
model = fine_tune_siglit_style(train_loader, val_loader, model=model, eval_only=True)
# ---- YOUR CODE ENDS HERE ----

In [None]:
# ---- YOUR CODE STARTS HERE ----
# Fine tuning SigLit model
n_train_classes = 1
batch_size = 32
print(f"{'+'*100} \nUsing batch_size of: {batch_size} and template")
print(f"{'+'*100} \nUsing n_train_classes of: {n_train_classes} and template")
print(f"{'%'*100}\n Using template \n {'%'*100}")
train_loader, val_loader, train_ds, val_ds = create_direct_data_loaders(
    imagenette_path, captions_path, batch_size=batch_size, use_template=True, n_train_classes=n_train_classes
)
model = fine_tune_siglit_style(train_loader, val_loader, )
print(f"Final validation accuracy using templates:")
model = fine_tune_siglit_style(train_loader, val_loader, model=model, eval_only=True)
# ---- YOUR CODE ENDS HERE ----