# Defining a custom dataset


In [None]:
# Notebook setup: load the autotime extension (per-cell timing) and autoreload,
# with mode 2 so edited local modules (utils, model, config) are reloaded before
# each cell executes.
%load_ext autotime
%load_ext autoreload
%autoreload 2

In [None]:
import torch, torch.nn as nn
import numpy as np
from tqdm import tqdm
from torch.utils.data import DataLoader
import torch
import torch.nn as nn
from utils import *

from model import PetsDataset
from torchvision.utils import make_grid
from model import Net
import torch.optim as optim
import copy

In [None]:
import os

# Enumerate GPUs in PCI bus order and expose only GPU #1 to this process.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

# Pick the device from what torch can actually see. This respects
# CUDA_VISIBLE_DEVICES as long as CUDA has not been initialised earlier in the
# session, and falls back to CPU when GPU #1 does not exist or no GPU is usable.
DEVICE_NAME = "cuda" if torch.cuda.is_available() else "cpu"

# Route outbound HTTP(S) traffic through the proxy. Both lower- and upper-case
# variable names are set because different tools read different spellings.
PROXY_URL = "http://proxy.dev.dszn.cz:3128"
for proxy_var in ("http_proxy", "HTTP_PROXY", "https_proxy", "HTTPS_PROXY"):
    os.environ[proxy_var] = PROXY_URL

In [None]:
# Load the experiment configuration. Keys read later in this notebook:
# dataset_path, val_set_coef, batch_size_train, batch_size_eval, learning_rate,
# weight_decay, patience, factor and epoch.
from config import get_config
config = get_config()

{'batch_size_train': 32, 'batch_size_eval': 1}

In [None]:
# Torch device used for the model and every batch; DEVICE_NAME is set in the environment cell.
device = torch.device(DEVICE_NAME)

## Data loading

In [None]:
dataset = PetsDataset(config['dataset_path'])

# Hold out a fraction `val_set_coef` of the samples for validation; the rest is training data.
n_total = len(dataset)
trainset_size = int(n_total - config['val_set_coef'] * n_total)
validset_size = n_total - trainset_size

# Seed the split so re-running the notebook reproduces the same partition.
# best_weights.pth is selected on the validation set, so an unseeded split would
# evaluate a saved model on samples it may have been trained on.
SPLIT_SEED = 42
split_generator = torch.Generator().manual_seed(SPLIT_SEED)
trainset, validset = torch.utils.data.random_split(
    dataset, [trainset_size, validset_size], generator=split_generator
)
print(f"trainset_sz={len(trainset)}, validset_sz={len(validset)}")

In [None]:
# Both loaders load in the main process; only the training split is shuffled,
# so validation order is fixed from epoch to epoch.
_loader_common = {"num_workers": 0}
train_loader = DataLoader(
    trainset,
    shuffle=True,
    batch_size=config['batch_size_train'],
    **_loader_common,
)
valid_loader = DataLoader(
    validset,
    shuffle=False,
    batch_size=config['batch_size_eval'],
    **_loader_common,
)

In [None]:
# Peek at one training batch (images, species labels, breed labels, masks) and show tensor shapes.
first_batch = next(iter(train_loader))
x, s, b, mask = first_batch
(x.shape, mask.shape)

In [None]:
# Move the peeked batch to the compute device and display the images followed by
# their masks in one grid. mask_to_img (from utils) presumably renders each mask
# as an image tensor compatible with x — confirm.
x, mask = x.to(device), mask.to(device)
mask_images = mask_to_img(mask)
to_viz = torch.cat((x, mask_images), dim=0)

imshow(make_grid(to_viz))

# PyTorch LEGO (multi-task model: species, breed and segmentation)


In [None]:
# Build the network and move its parameters to the compute device. Its forward pass
# returns species, breed and segmentation predictions (unpacked in the training loop).
model = Net().to(device)

In [None]:
# Adam over all model parameters; learning rate and weight decay come from the config.
optimizer = optim.Adam(
    model.parameters(),
    lr=config['learning_rate'],
    weight_decay=config['weight_decay'],
)
# Shrink the LR by `factor` when the metric passed to scheduler.step stops
# improving for `patience` epochs; mode='max' means higher is better.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='max',
    patience=config['patience'],
    factor=config['factor'],
)

In [None]:
# One criterion per task head; each is a standard (mean-reduced) cross-entropy.
seg_loss_fn = nn.CrossEntropyLoss()
species_loss_fn = nn.CrossEntropyLoss()
breed_loss_fn = nn.CrossEntropyLoss()


def combined_loss(seg_pred, seg_target, species_pred, species_target, breed_pred, breed_target, alpha=1.0, beta=1.0):
    """Return the weighted multi-task training loss.

    The segmentation term is cross-entropy plus Dice loss (``dice_loss`` is
    presumably provided by ``from utils import *``). The species and breed
    cross-entropy terms are scaled by ``alpha`` and ``beta`` respectively.
    """
    segmentation_term = seg_loss_fn(seg_pred, seg_target) + dice_loss(seg_pred, seg_target)
    species_term = alpha * species_loss_fn(species_pred, species_target)
    breed_term = beta * breed_loss_fn(breed_pred, breed_target)
    return segmentation_term + species_term + breed_term

## Training

In [None]:
epochs = config['epoch']
best_mean_iou = 0.0
min_delta = 0.001      # smallest mean-IoU gain that counts as an improvement
epochs_no_improve = 0
patience = 5           # early-stopping patience; independent of the LR scheduler's config['patience']
best_model_wts = None  # state_dict of the best epoch so far (None until the first improvement)

# Copy the dataset's breed vocabularies onto the model; breed_name2idx is used
# below to map predicted breed names back to class indices.
model.breed_idx2name = dataset.breed_idx2name
model.breed_name2idx = dataset.breed_name2idx
model.cat_breed_names = dataset.cat_breed_names
model.dog_breed_names = dataset.dog_breed_names

# Species string returned by model.predict -> label index compared against `species`.
species_map = {'dog': 0, 'cat': 1}

for epoch in range(epochs):
    # ---- training ----
    model.train()
    for images, species, breed, masks in tqdm(train_loader, desc=f"Training Epoch {epoch+1}/{epochs}"):
        images, species, breed, masks = images.to(device), species.to(device), breed.to(device), masks.to(device)

        sp_pred, br_pred, seg_pred = model(images)
        # combined_loss already contains the species, breed and segmentation (CE + Dice) terms.
        loss = combined_loss(seg_pred, masks, sp_pred, species, br_pred, breed)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # ---- validation ----
    # NOTE(review): breed.item() and the single species string from model.predict
    # assume valid_loader yields one sample per batch (batch_size_eval == 1).
    model.eval()
    species_acc_sum = 0.0
    breed_top3_acc_count = 0
    ious = []

    with torch.no_grad():
        for images, species, breed, masks in valid_loader:
            images, species, breed, masks = images.to(device), species.to(device), breed.to(device), masks.to(device)

            species_pred_str, top3_breeds_str, seg_pred_class = model.predict(images)
            seg_pred_class = seg_pred_class.to(device)

            species_preds_idx = torch.tensor(species_map[species_pred_str], device=device)
            species_acc_sum += accuracy(species_preds_idx, species)

            top3_breeds_indices = [model.breed_name2idx[name] for name in top3_breeds_str]
            breed_top3_acc_count += breed.item() in top3_breeds_indices

            # Mean IoU of this sample over the three segmentation classes.
            img_ious = []
            for c in (0, 1, 2):
                _, _, iou_c = compute_seg_metrics(seg_pred_class, masks, c)
                img_ious.append(iou_c)
            ious.append(np.mean([iou.cpu().numpy() for iou in img_ious]))

    species_acc = species_acc_sum / len(valid_loader)
    breed_top3_acc = breed_top3_acc_count / len(valid_loader)
    mean_iou = np.mean(ious)
    print(f"Epoch {epoch+1}: Species Acc: {species_acc:.2f}, Breed Top-3 Acc: {breed_top3_acc:.2f}, Mean IoU: {mean_iou:.2f}")

    # ---- checkpointing / early stopping ----
    if mean_iou - best_mean_iou > min_delta:
        best_mean_iou = mean_iou
        best_model_wts = copy.deepcopy(model.state_dict())
        epochs_no_improve = 0
        torch.save(best_model_wts, "best_weights.pth")
    else:
        epochs_no_improve += 1
        print(f"No improvement in Mean IoU for {epochs_no_improve} epoch(s).")

    if epochs_no_improve >= patience:
        print("Early stopping triggered.")
        break

    scheduler.step(mean_iou)

# The loop leaves the model on the LAST epoch's weights, which can be several
# epochs past the best one when early stopping fires. Reload the best checkpoint
# so the evaluation and inference cells use the weights saved to best_weights.pth.
if best_model_wts is not None:
    model.load_state_dict(best_model_wts)

## Evaluation

In [None]:
# Final evaluation on the validation split, using the weights the model currently holds.
# NOTE(review): breed.item() and the single species string from model.predict
# assume valid_loader yields one sample per batch (batch_size_eval == 1).
model.eval()
species_map = {'dog': 0, 'cat': 1}  # species string from model.predict -> label index
species_acc_sum = 0.0
breed_top3_acc_count = 0
ious = []

with torch.no_grad():
    for images, species, breed, masks in valid_loader:
        images, species, breed, masks = images.to(device), species.to(device), breed.to(device), masks.to(device)

        species_pred_str, top3_breeds_str, seg_pred_class = model.predict(images)
        seg_pred_class = seg_pred_class.to(device)

        species_preds_idx = torch.tensor(species_map[species_pred_str], device=device)
        species_acc_sum += accuracy(species_preds_idx, species)

        top3_breeds_indices = [model.breed_name2idx[name] for name in top3_breeds_str]
        breed_top3_acc_count += breed.item() in top3_breeds_indices

        # Mean IoU of this sample over the three segmentation classes.
        img_ious = []
        for c in (0, 1, 2):
            _, _, iou_c = compute_seg_metrics(seg_pred_class, masks, c)
            img_ious.append(iou_c)
        ious.append(np.mean([iou.cpu().numpy() for iou in img_ious]))

    species_acc = species_acc_sum / len(valid_loader)
    breed_top3_acc = breed_top3_acc_count / len(valid_loader)
    mean_iou = np.mean(ious)
    # Labelled as a final evaluation: it is not tied to a training epoch, and it must
    # not depend on the loop variable `epoch` from the training cell.
    print(f"Final evaluation: Species Acc: {species_acc:.2f}, Breed Top-3 Acc: {breed_top3_acc:.2f}, Mean IoU: {mean_iou:.2f}")

In [None]:
# NOTE(review): the training cell attaches `cat_breed_names` (not `cat_breed_indices`)
# to the model; this only works if Net itself defines `cat_breed_indices` — confirm.
model.cat_breed_indices

In [None]:
# Run inference on the first validation example (valid_loader is not shuffled).
images, species, breed, masks = next(iter(valid_loader))
images, species, breed, masks = images.to(device), species.to(device), breed.to(device), masks.to(device)

# Slice with [:1] instead of [0] so the input keeps its batch dimension, matching
# how model.predict is called on whole batches in the validation/evaluation loops.
model.predict(images[:1])