# 5. Computer Vision classification challenge

## Exploratory analysis and baseline model training and evaluation

Use a sample of images from the swimming-pool dataset to develop a model that classifies whether an image contains a swimming pool or not. Use the provided labels to validate your model.

## Setup

The cell below is specifically for running the notebook in a Kaggle environment. If running the notebook in Kaggle, make sure the `stefan87/cct-ds-code-challenge` dataset is attached to the Kaggle notebook. For all other environments, it is assumed that all dependencies have already been installed.

In [None]:
import os

# A non-empty KAGGLE_URL_BASE environment variable is taken as the signal that
# the notebook is running on Kaggle.
is_kaggle = bool(os.environ.get('KAGGLE_URL_BASE', ''))


def run_or_raise(cmd: str) -> None:
    """Run a shell command via ``os.system`` and fail loudly on error.

    Raises:
        RuntimeError: if the command exits with a non-zero status, so a failed
            download/clone/install does not go unnoticed.
    """
    status = os.system(cmd)
    if status != 0:
        raise RuntimeError(f"Setup command failed with status {status}: {cmd}")


if is_kaggle:
    # Bootstrap only once: cloning/installing is skipped when the repo folder
    # already exists (note: this also skips a retry after a failed install).
    if not os.path.exists("/kaggle/working/ds_code_challenge"):
        print("This notebook is running on Kaggle.")
        print("Installing additional packages...")
        run_or_raise("curl -LsSf https://astral.sh/uv/install.sh | sh")
        run_or_raise("git clone https://github.com/stefan027/ds_code_challenge.git")
        run_or_raise("cd ds_code_challenge && uv pip install -r requirements.txt")
    # Work from the cloned repo root for the rest of the notebook.
    os.chdir("/kaggle/working/ds_code_challenge")

In [None]:
import sys

# Add the repo root to the Python path
# NOTE(review): "../" is resolved against the current working directory, so this
# assumes the notebook is launched from a direct sub-folder of the repo — TODO confirm.
# On Kaggle the setup cell chdir's into the repo root instead, so `src` is presumably
# importable from the working directory there.
# This must run before the `src.*` imports below.
sys.path.append("../")

import os
from pathlib import Path
from importlib.metadata import version
from typing import Union
import random

from fastai.vision.all import (
    set_seed, DataLoader, DataLoaders, Learner,
    BCEWithLogitsLossFlat, Adam,
)
from PIL import Image
import pandas as pd
from IPython.display import display
import torch

# Project-local helpers (dataset, model construction/freezing, evaluation metrics).
from src.data import TiffImageDataset
from src.modeling import (
    create_timm_model, freeze_except_head, unfreeze_all
)
from src.metrics import (
    balanced_accuracy, ap_score, precision, recall, roc_auc
)

In [None]:
# Record the installed versions of the main modelling dependencies.
pckgs = ["torch", "torchvision", "fastai", "timm"]
for pckg in pckgs:
    installed = version(pckg)
    print(pckg, installed, sep="==")

In [None]:
KAGGLE_IMAGE_DIR = Path("/kaggle/input/cct-ds-code-challenge/images/swimming-pool")
# Root of the image data, which holds one sub-folder per class name.
if is_kaggle:
    IMAGE_DIR = KAGGLE_IMAGE_DIR
else:
    IMAGE_DIR = Path("../data/images/swimming-pool")
CLASSES = ["no", "yes"]  # class sub-folder names
POSITIVE_CLASS = "yes"   # folder name treated as the positive (label 1) class
VALIDATION_PCT = 0.2     # fraction of images held out for validation

Set the random seed for reproducibility. The `set_seed` function sets the seed for `numpy`, `random`, and `torch`.

In [None]:
# Fix the global seed once; the random split and image sampling below draw from it.
set_seed(42, reproducible=True)

## Load data

In [None]:
# Collect every file under each class folder.
# The listings are sorted because os.listdir returns entries in arbitrary,
# filesystem-dependent order. The seeded random split below selects files by
# index, so an unsorted listing would give a different train/validation split
# on different machines despite the fixed seed.
image_paths = [
    IMAGE_DIR / category / fname
    for category in CLASSES
    for fname in sorted(os.listdir(IMAGE_DIR / category))
]

Check the file extensions and get value counts:

In [None]:
# Tally file extensions so any non-TIF files are visible before they are filtered out
extension_counts = pd.Series([path.suffix for path in image_paths]).value_counts()
extension_counts

Remove non-TIF files:

In [None]:
# Keep only TIF images (case-insensitive, so ".TIF" is kept as well)
def is_tif_file(path):
    return path.suffix.lower() == ".tif"


image_paths = list(filter(is_tif_file, image_paths))
print("Number of image files:", len(image_paths))

Create a function to derive the binary label from an image file path:

In [None]:
def get_label(fp: Union[str, Path], positive_class: str) -> int:
    """Derive a binary label from the name of the image's parent folder.

    Images are organised as ``<root>/<class name>/<file>``, so the second-to-last
    path component is the class name.

    Args:
        fp: Path to the image file.
        positive_class: Folder name that denotes the positive class.

    Returns:
        1 if the parent folder name equals ``positive_class``, otherwise 0.

    Raises:
        IndexError: if ``fp`` has no parent folder component (e.g. a bare filename).
    """
    class_name = Path(fp).parts[-2]
    return int(class_name == positive_class)

In [None]:
# Example: label of the first image (1 = positive class, 0 = otherwise)
get_label(fp=image_paths[0], positive_class=POSITIVE_CLASS)

Get the labels for each image file and look at the class distribution:

In [None]:
# Binary label per image, aligned index-for-index with image_paths
labels = list(map(lambda path: get_label(path, POSITIVE_CLASS), image_paths))
# Class balance: counts of label 0 vs label 1
pd.Series(labels).value_counts()

Create training and validation splits. Although the classes are imbalanced, we only have two classes and a relatively large sample of images, so a simple random split should produce training and validation splits with similar class balance. The printed percentages below confirm this.

In [None]:
n_images = len(image_paths)
n_valid = int(n_images * VALIDATION_PCT)
# Draw validation indices without replacement from the globally seeded RNG.
valid_idx = sorted(random.sample(range(n_images), k=n_valid))
valid_set = set(valid_idx)
# Every remaining index, in ascending order, goes to training.
train_idx = [i for i in range(n_images) if i not in valid_set]

train_fps = [image_paths[i] for i in train_idx]
train_labels = [labels[i] for i in train_idx]
valid_fps = [image_paths[i] for i in valid_idx]
valid_labels = [labels[i] for i in valid_idx]

print(f"Training set size:   {len(train_fps)}")
print(f"Validation set size: {len(valid_fps)}")
print("Percentage positive class:")
for tag, split_labels in (("Training:  ", train_labels), ("Validation:", valid_labels)):
    print(f"  - {tag} {sum(split_labels) / len(split_labels) * 100:.1f}%")

Look at some sample images:

In [None]:
# Look at some sample images, shown with their labels as a visual sanity check
sz = (448, 448)  # display size only
for _ in range(2):
    i = random.randint(0, len(image_paths)-1)
    print(f"Path: {image_paths[i]}, Label: {labels[i]}")
    # Open with a context manager so the file handle is released immediately;
    # resize() returns a new image, so it stays usable after the file is closed.
    with Image.open(image_paths[i]) as img:
        display(img.resize(sz))

## Create `Dataset`s and `DataLoaders`

In [None]:
# One dataset per split, built from matching (paths, labels) lists: training first, then validation.
ds_trn, ds_val = (
    TiffImageDataset(paths=split_paths, labels=split_labels)
    for split_paths, split_labels in ((train_fps, train_labels), (valid_fps, valid_labels))
)

In [None]:
# Shuffle only the training loader; validation batches keep a fixed order.
dls_trn, dls_val = (
    DataLoader(split_ds, batch_size=16, shuffle=is_train)
    for split_ds, is_train in ((ds_trn, True), (ds_val, False))
)
dls = DataLoaders(dls_trn, dls_val)

In [None]:
# Pull one batch to inspect tensor shapes (x is reused for the forward-pass check below).
x, y = dls.one_batch()
for kind, batch_tensor in (("Image", x), ("Label", y)):
    print(f"{kind} tensor shape: {batch_tensor.shape}")

## Setup model

In [None]:
# Pretrained backbone with a single output (one logit for the binary task).
arch = "convnext_tiny"
model, cfg = create_timm_model(arch, n_out=1, pretrained=True)

In [None]:
# Sanity-check a forward pass on the sample batch without tracking gradients.
with torch.no_grad():
    logits = model(x)
(logits.shape, logits)

In [None]:
# Single-logit binary classifier trained with BCE-with-logits and Adam.
learn = Learner(
    dls,
    model,
    loss_func=BCEWithLogitsLossFlat(),
    opt_func=Adam,
    metrics=[
        balanced_accuracy(),
        ap_score(),
        precision(),
        recall(),
        roc_auc(),
    ],
)

In [None]:
# Show fastai's summary of the Learner's model.
learn.summary()

## Train

Freeze all model parameters except for the randomly initialised classification head. We will first train the classification head, and then fine-tune all model parameters.

In [None]:
# Freeze everything except the classification head before the first training phase.
# NOTE(review): assumes freeze_except_head toggles requires_grad in place and returns
# the same module object that `learn` already holds — TODO confirm in src.modeling.
model = freeze_except_head(model)

Use `fastai`'s learning rate finder to help set a reasonable initial learning rate:

In [None]:
# Run fastai's learning-rate finder to guide the learning rate used below.
learn.lr_find()

Train the classification head for one epoch:

In [None]:
# Phase 1: one one-cycle epoch with only the classification head trainable
learn.fit_one_cycle(n_epoch=1, lr_max=1e-3)

Unfreeze all model parameters and fine-tune for a further 3 epochs:

In [None]:
model = unfreeze_all(model)
# Phase 2: fine-tune all parameters at a lower learning rate.
# reset_opt=True rebuilds the optimiser. fastai creates the optimiser on the first
# fit call, and its default splitter only collects parameters with
# requires_grad=True. So the optimiser built while the backbone was frozen holds
# only the head parameters. Without a rebuild, the newly unfrozen backbone
# parameters would never be updated.
# NOTE(review): assumes freeze_except_head/unfreeze_all work by toggling
# requires_grad — TODO confirm in src.modeling.
learn.fit_one_cycle(3, 1e-4, reset_opt=True)

Save the model weights for further evaluation and inference:

In [None]:
# Save only the model weights (state_dict) for later evaluation and inference.
model_dir = Path("./models")
model_dir.mkdir(parents=True, exist_ok=True)
torch.save(learn.model.state_dict(), model_dir / "classification_model_v0.pth")