![Py4Eng](../logo.png)

# Finetuning a pretrained model
## Yoav Ram

This notebook shows a simple, step-by-step finetuning workflow using a pretrained [EfficientNetV2](https://arxiv.org/pdf/2104.00298) backbone (from `timm`) and PyTorch. The target dataset is the [Hyena ID 2022](https://lila.science/datasets/hyena-id-2022/) dataset (3104 photos, 256 individuals) — the task is per-individual classification.

We will load the pretrained model using [`timm`](https://timm.fast.ai), a collection of state-of-the-art computer vision models. Install it with `pip install timm`.

In [15]:
import os
import shutil
import tarfile
import urllib.request
import random
from tqdm import tqdm
from pathlib import Path
import json
from PIL import Image

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as T
from torchvision.datasets import ImageFolder
import timm

print('torch', torch.__version__)


torch 2.7.1


# Data

We download the [Hyena ID 2022 dataset](https://lila.science/datasets/hyena-id-2022/) and extract it to `../data/hyena`.

In [10]:
url = "http://us-west-2.opendata.source.coop.s3.amazonaws.com/agentmorris/lila-wildlife/wild-me/hyena.coco.tar.gz"
out_path = '../data/hyena.coco.tar.gz'
extract_dir = '../data/hyena'
chunk = 1024 * 1024

os.makedirs(os.path.dirname(out_path), exist_ok=True)

In [11]:
if not os.path.exists(extract_dir):
    if not os.path.exists(out_path):
        print(f'Downloading {url} to {out_path}')
        with urllib.request.urlopen(url) as r, open(out_path, 'wb') as f:
            total = r.getheader('Content-Length')
            total = int(total) if total else None
            with tqdm(total=total, unit='B', unit_scale=True, desc='download') as p:
                while True:
                    data = r.read(chunk)
                    if not data:
                        break
                    f.write(data)
                    p.update(len(data))
    else:
        print(f'File {out_path} already exists, skipping download.')

    print(f"Extracting {out_path}")
    with tarfile.open(out_path, 'r:gz') as t:
        os.makedirs(extract_dir, exist_ok=True)
        t.extractall(path=extract_dir)
else:
    print(f'Extraction directory already exists, skipping extraction.')

Extraction directory already exists, skipping extraction.


All the hyena images are now in the `../data/hyena/hyena.coco/images/train2022` folder (`test2022` and `val2022` are empty). Each image filename is just the image's running number.

The metadata is in `../data/hyena/hyena.coco/annotations/instances_train2022.json`, which contains the bounding boxes of the individuals in the images, as well as their identities.

In [17]:

IMG_DIR = Path('../data/hyena/hyena.coco/images/train2022')
ANNO = Path('../data/hyena/hyena.coco/annotations/instances_train2022.json')
if not IMG_DIR.exists() or not ANNO.exists():
    print('Images or annotations not found:')
    print('IMG_DIR =', IMG_DIR.exists(), IMG_DIR)
    print('ANNO =', ANNO.exists(), ANNO)

Load the COCO-style annotations and build a mapping from image ID to identity.

In [31]:
with open(ANNO, 'r') as f:
    metadata = json.load(f)
img2name = {a['image_id']: a['name'] for a in metadata['annotations']}
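
To make the structure concrete, here is a minimal sketch of how a COCO-style file links images to identities. The `toy` dictionary, its file names, and its identity names are made up for illustration; the real file has many more fields. The `name` key on annotations follows the cell above, and the `id` / `file_name` keys on images are assumed from the standard COCO layout.

```python
from pathlib import Path

# Hypothetical miniature of a COCO-style annotation file (all values are made up)
toy = {
    "images": [
        {"id": 1, "file_name": "000001.jpg"},
        {"id": 2, "file_name": "000002.jpg"},
        {"id": 3, "file_name": "000003.jpg"},  # has no annotation below
    ],
    "annotations": [
        {"image_id": 1, "bbox": [10, 20, 100, 80], "name": "hyena_A"},
        {"image_id": 2, "bbox": [5, 5, 50, 60], "name": "hyena_B"},
    ],
}

# image id -> identity, built the same way as in the cell above
toy_img2name = {a["image_id"]: a["name"] for a in toy["annotations"]}

# Join images and identities through the shared image id
toy_dir = Path("images")  # placeholder directory
toy_samples = [
    (toy_dir / im["file_name"], toy_img2name[im["id"]])
    for im in toy["images"]
    if im["id"] in toy_img2name  # skip images without an identity
]
print(toy_samples)
```

Images without an annotation are dropped, so the number of samples can be smaller than the number of image files.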

In [None]:
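# Build (path, identity) samples; this assumes the standard COCO layout,
# where each entry of metadata['images'] has 'id' and 'file_name' keys
samples = [(IMG_DIR / im['file_name'], img2name[im['id']])
           for im in metadata['images'] if im['id'] in img2name]

# Build label->index mapping
unique_labels = sorted({lab for _, lab in samples})
label2idx = {lab: i for i, lab in enumerate(unique_labels)}
NUM_CLASSES = len(unique_labels)

# Convert samples to (path, idx)
samples = [(p, label2idx[l]) for p, l in samples]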

# Shuffle and split 90/10
random.seed(42)
random.shuffle(samples)
cut = int(len(samples) * 0.9)
train_samples = samples[:cut]
val_samples = samples[cut:]

print(f'Total images with annotation: {len(samples)}, classes: {NUM_CLASSES}')
print('Train / Val sizes =', len(train_samples), len(val_samples))

# Simple Dataset that reads images from path


class HyenaDataset(Dataset):
    def __init__(self, samples, transform=None):
        self.samples = samples
        self.transform = transform

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        p, label = self.samples[idx]
        img = Image.open(p).convert('RGB')
        if self.transform:
            img = self.transform(img)
        return img, label
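
A plain random split does not guarantee that every individual ends up in the training set, which matters with 256 individuals and only 3104 photos. One possible alternative, sketched here and not part of the workflow above, splits each individual's images separately. The helper name `stratified_split` and the toy data are hypothetical.

```python
import random
from collections import defaultdict

def stratified_split(samples, val_frac=0.1, seed=42):
    """Split (path, label) pairs label by label, keeping at least one training sample per label."""
    rng = random.Random(seed)
    by_label = defaultdict(list)
    for path, label in samples:
        by_label[label].append((path, label))

    train, val = [], []
    for label in sorted(by_label):
        items = by_label[label]
        rng.shuffle(items)
        # Round to the nearest count, but never move every sample of a label to validation
        n_val = min(round(len(items) * val_frac), len(items) - 1)
        val.extend(items[:n_val])
        train.extend(items[n_val:])
    rng.shuffle(train)
    return train, val

# Toy data: label 0 has 20 images, label 1 has 10, label 2 has a single image
toy = [(f"img_{lab}_{i}.jpg", lab) for lab, n in [(0, 20), (1, 10), (2, 1)] for i in range(n)]
train, val = stratified_split(toy)
print(len(train), len(val))
```

With this rule, an individual with very few photos contributes nothing to validation, so validation accuracy is then measured only on individuals that have enough images.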



Define the image transforms: resize to the image size suitable for the model, convert to a PyTorch tensor, and normalize the colors with the ImageNet channel means and standard deviations. For the training transforms, we also augment the data using random horizontal flips.

In [None]:
IMG_SIZE = 224

val_transforms = T.Compose([
    T.Resize((IMG_SIZE, IMG_SIZE)),  # same resize as training, so the whole image is kept
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

train_transforms = T.Compose([
    T.Resize((IMG_SIZE, IMG_SIZE)),
    T.RandomHorizontalFlip(),
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
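
`T.Normalize` applies `(x - mean) / std` per channel to values that `T.ToTensor` has already scaled to the range [0, 1]. The following pure-Python sketch works through that arithmetic for a single made-up pixel and then inverts it, which is the same un-normalization needed to display a transformed image.

```python
# ImageNet channel statistics used in the transforms above
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]

# A made-up RGB pixel with 8-bit values
pixel_uint8 = [128, 64, 255]

# ToTensor step: scale 0..255 to 0..1
pixel = [v / 255 for v in pixel_uint8]

# Normalize step: subtract the channel mean and divide by the channel std
normalized = [(x - m) / s for x, m, s in zip(pixel, mean, std)]

# Inverse: multiply by the std and add the mean back
restored = [n * s + m for n, m, s in zip(normalized, mean, std)]

print([round(n, 3) for n in normalized])
print([round(r * 255) for r in restored])
```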

Now we create the datasets and data loaders.

In [None]:
# DataLoader params
BATCH_SIZE = 32
NUM_WORKERS = 4  # set to 0 if worker processes fail to start on your platform

train_ds = HyenaDataset(train_samples, transform=train_transforms)
val_ds = HyenaDataset(val_samples, transform=val_transforms)

train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)
val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)
print('DataLoaders ready.')


Before setting up the model, we display a few training images with their labels as a visual check.

In [None]:
# Show some sample images and their labels (visual check)
import matplotlib.pyplot as plt
import numpy as np

if 'train_ds' in globals():
    fig, axs = plt.subplots(2, 5, figsize=(15, 6))
    axs = axs.ravel()
    for i in range(10):
        img, label = train_ds[i]
        # img is normalized tensor; un-normalize for display
        img_np = img.numpy().transpose(1, 2, 0)
        mean = np.array([0.485, 0.456, 0.406])
        std = np.array([0.229, 0.224, 0.225])
        img_np = img_np * std + mean
        img_np = np.clip(img_np, 0, 1)
        axs[i].imshow(img_np)
        axs[i].set_title(f'label={label}')
        axs[i].axis('off')
    plt.show()
else:
    print('train_ds not found. Run Data loaders cell first.')


## Model setup
Load a pretrained EfficientNetV2 from `timm` and replace the classifier head with a new linear layer sized to `NUM_CLASSES`. The code is defensive so that it handles a couple of naming conventions for the head.

In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Device:', device)

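# Create model (small variant for speed; change to the _m or _l variant if you want).
# If timm reports that this name has no pretrained weights, list the names that do
# with timm.list_models('*efficientnetv2*', pretrained=True) and pick one of those.
MODEL_NAME = 'efficientnetv2_s'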

# lazy guard: NUM_CLASSES must exist from dataset cell
try:
    NUM_CLASSES
except NameError:
    NUM_CLASSES = 256  # fallback; change to actual number after you build split

model = timm.create_model(MODEL_NAME, pretrained=True)
# replace classifier robustly
if hasattr(model, 'classifier') and isinstance(model.classifier, nn.Linear):
    in_f = model.classifier.in_features
    model.classifier = nn.Linear(in_f, NUM_CLASSES)
elif hasattr(model, 'fc') and isinstance(model.fc, nn.Linear):
    in_f = model.fc.in_features
    model.fc = nn.Linear(in_f, NUM_CLASSES)
else:
    # timm models often have a .get_classifier / .reset_classifier API; try that
    try:
        model.reset_classifier(num_classes=NUM_CLASSES)
    except Exception as e:
        print('Could not find a known classifier attribute; inspect model to set classifier manually:', e)

model.to(device)
print(model)
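
The head swap can be seen in isolation on a small stand-in network. This sketch assumes only PyTorch; `ToyNet` is a hypothetical model with a `classifier` attribute, mimicking the naming that the cell above checks for first, and it is not the EfficientNetV2 architecture.

```python
import torch
import torch.nn as nn

class ToyNet(nn.Module):
    """Hypothetical stand-in for a pretrained network with a `classifier` head."""
    def __init__(self, num_classes=1000):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 8, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
        )
        self.classifier = nn.Linear(8, num_classes)

    def forward(self, x):
        return self.classifier(self.features(x))

toy = ToyNet()
num_classes = 256  # number of individuals, taken from the dataset description

# Keep the input width of the old head, but output one logit per individual
in_f = toy.classifier.in_features
toy.classifier = nn.Linear(in_f, num_classes)

with torch.no_grad():
    out = toy(torch.randn(4, 3, 32, 32))  # batch of 4 random "images"
print(out.shape)
```

The new layer is randomly initialized, which is why the head is trained first while the pretrained weights stay untouched. As far as I know, `timm.create_model` also accepts a `num_classes` argument that builds the model with a fresh head of that size directly.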


### Freeze backbone and train head
We freeze all parameters, then unfreeze only the classifier head and train it for a few epochs. The head is taken from the model's `get_classifier()` method rather than matched by parameter name, because a substring such as `'head'` would also match the backbone's `conv_head` layer.

In [None]:
# Freeze everything, then unfreeze only the classifier head
for p in model.parameters():
    p.requires_grad = False
for p in model.get_classifier().parameters():
    p.requires_grad = True

# Simple optimizer for the trainable params only
head_params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.Adam(head_params, lr=1e-3)
criterion = nn.CrossEntropyLoss()

def evaluate(loader, model):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for xb, yb in loader:
            xb, yb = xb.to(device), yb.to(device)
            out = model(xb)
            pred = out.argmax(dim=1)
            correct += (pred == yb).sum().item()
            total += yb.size(0)
    return correct / total if total else 0.0

def train_one_epoch(loader, model, opt, loss_fn):
    model.train()
    running_loss = 0.0
    for xb, yb in loader:
        xb, yb = xb.to(device), yb.to(device)
        opt.zero_grad()
        out = model(xb)
        loss = loss_fn(out, yb)
        loss.backward()
        opt.step()
        running_loss += loss.item() * xb.size(0)
    return running_loss / (len(loader.dataset) if hasattr(loader, 'dataset') else 1)

# Small sanity run (only if loaders exist)
if 'train_loader' in globals():
    EPOCHS_HEAD = 3
    for ep in range(EPOCHS_HEAD):
        loss = train_one_epoch(train_loader, model, optimizer, criterion)
        val_acc = evaluate(val_loader, model)
        print(f'Head epoch {ep+1}/{EPOCHS_HEAD} - loss {loss:.4f} - val_acc {val_acc:.4f}')
else:
    print('No dataloaders found. Create splits and DataLoaders first.')
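
The effect of freezing can be checked by listing and counting the trainable parameters. In this sketch, `toy` is a hypothetical small model (not the notebook's `model`) whose backbone contains a layer named `conv_head`. It also illustrates that matching parameter names on the substring `'head'` selects more than the classifier.

```python
import torch
import torch.nn as nn

# Hypothetical toy model: a small backbone (including a `conv_head` layer) plus a classifier
toy = nn.Sequential()
toy.add_module("conv_stem", nn.Conv2d(3, 4, kernel_size=3))
toy.add_module("conv_head", nn.Conv2d(4, 8, kernel_size=1))
toy.add_module("pool", nn.AdaptiveAvgPool2d(1))
toy.add_module("flatten", nn.Flatten())
toy.add_module("classifier", nn.Linear(8, 5))

# Substring matching: 'head' also matches the backbone layer 'conv_head'
by_name = [n for n, _ in toy.named_parameters() if "classifier" in n or "head" in n]
print(by_name)

# Freeze everything, then unfreeze the classifier module only
for p in toy.parameters():
    p.requires_grad = False
for p in toy.classifier.parameters():
    p.requires_grad = True

trainable = [n for n, p in toy.named_parameters() if p.requires_grad]
n_trainable = sum(p.numel() for p in toy.parameters() if p.requires_grad)
print(trainable, n_trainable)

# Only the trainable parameters are handed to the optimizer
opt = torch.optim.Adam([p for p in toy.parameters() if p.requires_grad], lr=1e-3)
```

Printing the trainable names on the real model before training is a cheap way to confirm that only the intended layers will be updated.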


### Unfreeze and finetune the whole model
Now we unfreeze all parameters and continue training with a smaller learning rate. This typically improves performance but is slower.

In [None]:
# Unfreeze all params
for p in model.parameters():
    p.requires_grad = True

# New optimizer with lower LR
optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)

# Train a few more epochs
if 'train_loader' in globals():
    EPOCHS_FULL = 3
    for ep in range(EPOCHS_FULL):
        loss = train_one_epoch(train_loader, model, optimizer, criterion)
        val_acc = evaluate(val_loader, model)
        print(f'Finetune epoch {ep+1}/{EPOCHS_FULL} - loss {loss:.4f} - val_acc {val_acc:.4f}')
else:
    print('No dataloaders found. Create splits and DataLoaders first.')


## Next steps and tips
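- Increase `IMG_SIZE`, `BATCH_SIZE`, and the epoch counts (`EPOCHS_HEAD`, `EPOCHS_FULL`), which may improve final accuracy.
- Use a learning rate schedule (cosine or step), weight decay, and label smoothing for better generalization.
- Consider a split stratified by individual so that every individual is represented in both the training and validation sets.
- If memory is tight, use mixed precision (`torch.amp`; the older `torch.cuda.amp` entry points are deprecated in recent PyTorch releases).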
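
Several of these tips map onto standard PyTorch components. The following is a minimal sketch on a hypothetical stand-in linear model, with illustrative and untuned hyperparameter values: `AdamW` for weight decay, `CosineAnnealingLR` for the schedule, and the `label_smoothing` argument of `CrossEntropyLoss`.

```python
import torch
import torch.nn as nn

toy = nn.Linear(10, 4)  # hypothetical stand-in for the finetuned model
epochs = 5              # illustrative value

optimizer = torch.optim.AdamW(toy.parameters(), lr=1e-3, weight_decay=1e-2)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
criterion = nn.CrossEntropyLoss(label_smoothing=0.1)

xb = torch.randn(8, 10)         # random inputs, for illustration only
yb = torch.randint(0, 4, (8,))  # random class labels

lrs = []
for ep in range(epochs):
    optimizer.zero_grad()
    loss = criterion(toy(xb), yb)
    loss.backward()
    optimizer.step()
    lrs.append(optimizer.param_groups[0]["lr"])  # learning rate used in this epoch
    scheduler.step()                              # decay once per epoch

print([f"{lr:.2e}" for lr in lrs])
```

In the notebook's loops, the `scheduler.step()` call would go after each call to `train_one_epoch`, and `criterion` would replace the plain `nn.CrossEntropyLoss()`.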

That's it — the notebook is intentionally simple so you can run and understand each step.