# Flower classification model training notebook
In this notebook we fine-tune a model to classify images of flowers.

In [None]:
import os
from os.path import join as opj
import shutil
from collections import Counter
from datetime import datetime

In [None]:
import kagglehub
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Subset
import torchvision.transforms as transforms
from torchvision.transforms import v2
from torchvision.datasets import ImageFolder
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay, accuracy_score
from tqdm.autonotebook import tqdm
import splitfolders
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

# Global configuration shared by the cells below.
RANDOM_SEED = 42  # seed for the train/val/test split and torch's global RNG
BATCH_SIZE = 32   # mini-batch size used by every DataLoader

# Seed torch's global RNG so DataLoader shuffling is reproducible.
torch.manual_seed(RANDOM_SEED)

# Train on the GPU when one is available.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(device)

# 1. Download and prepare Data

## 1.1 Download dataset
Download the dataset from Kaggle and visualise some samples.

In [None]:
def visualise_samples(dataset, samples_per_class=5, title="Dataset samples"):
    """Plot a grid of example images with one column per class.

    Args:
        dataset: An ``ImageFolder``-style dataset exposing ``classes``,
            ``targets`` and ``__getitem__`` returning ``(image, target)``.
            Class index ``i`` is assumed to name ``dataset.classes[i]``
            (true for ``ImageFolder``).
        samples_per_class: Number of rows, i.e. images shown per class.
        title: Figure title.
    """
    n_cols = len(dataset.classes)
    # squeeze=False keeps ``axes`` 2-D even with a single row or column.
    fig, axes = plt.subplots(
        samples_per_class, n_cols,
        figsize=(15, samples_per_class * 2), squeeze=False,
    )
    fig.suptitle(title, fontsize=15)
    # Convert once instead of on every class iteration.
    targets = np.asarray(dataset.targets)
    for col, class_name in enumerate(dataset.classes):
        # Hide every cell up front so classes with fewer than
        # ``samples_per_class`` images leave blank cells, not empty frames.
        for ax in axes[:, col]:
            ax.axis('off')
        axes[0, col].set_title(class_name, fontsize=10)
        idxs = np.flatnonzero(targets == col)
        for row, idx in enumerate(idxs[:samples_per_class]):
            axes[row, col].imshow(dataset[idx][0])
    plt.show()

In [None]:
# Fetch (or reuse the cached copy of) the Kaggle flowers dataset.
base_path = kagglehub.dataset_download('alxmamaev/flowers-recognition')
dataset_dir = opj(base_path, "flowers")
dataset = ImageFolder(dataset_dir)

# Invert ImageFolder's name -> index mapping so integer labels can be named.
idx_to_class = {
    class_idx: class_name
    for class_name, class_idx in dataset.class_to_idx.items()
}
n_classes = len(dataset.classes)

# Per-sample labels: as an integer array and as class-name strings.
targets = np.array(dataset.targets)
targets_names = [idx_to_class[label] for label in dataset.targets]

In [None]:
# Print how many images each class has and its (truncated) share of the total.
n_samples = len(dataset)
print(f"Dataset summary:\n {n_samples} samples")
for class_idx, class_count in Counter(dataset.targets).items():
    share = int(100 * class_count / n_samples)
    print(f" * {idx_to_class[class_idx]}: {class_count} samples ({share}%)")

visualise_samples(dataset)

## 1.2 Train/validation/test split

In [None]:
splitted_path = './data/flowers'
# Start from an empty output folder. splitfolders does not clear an existing
# output directory, so re-running the split (e.g. with a different seed or
# ratio) would leave stale copies behind and leak images across splits.
shutil.rmtree(splitted_path, ignore_errors=True)
# Ratios are (train, val, test).
splitfolders.ratio(
    input=dataset_dir, output=splitted_path, seed=RANDOM_SEED, ratio=(0.7, 0.15, 0.15)
)

In [None]:
train_ds = ImageFolder(opj(splitted_path, "train"))
val_ds = ImageFolder(opj(splitted_path, "val"))
test_ds = ImageFolder(opj(splitted_path, "test"))

# For each split: "<count> (<percent>%)" per class plus the split total.
# Class names come from the full dataset's idx_to_class; this assumes each
# split contains every class folder so ImageFolder assigns the same indices.
class_distribution = {}
# The split names must line up with the datasets: (train, val, test).
for ds, split_name in zip((train_ds, val_ds, test_ds), ("train", "val", "test")):
    split_count = Counter(ds.targets)
    total_samples = sum(split_count.values())
    class_distribution[split_name] = {
        idx_to_class[idx]: f"{count} ({int(100 * count / total_samples)}%)"
        for idx, count in split_count.items()
    }
    class_distribution[split_name]["total"] = total_samples

# One row per split, one column per class (plus "total").
class_distribution_df = pd.DataFrame(class_distribution).T
class_distribution_df

In [None]:
# Preliminary loaders (rebuilt in section 1.4). Only the training loader
# shuffles; validation/test order is kept fixed so evaluation is deterministic.
# NOTE: the datasets have no transform at this point, so iterating these
# loaders yields variable-size PIL images that default collation cannot batch.
train_loader = torch.utils.data.DataLoader(
    train_ds, batch_size=BATCH_SIZE, shuffle=True
)
val_loader = torch.utils.data.DataLoader(
    val_ds, batch_size=BATCH_SIZE, shuffle=False
)
test_loader = torch.utils.data.DataLoader(
    test_ds, batch_size=BATCH_SIZE, shuffle=False
)

## 1.3 Add augmentations

In [None]:
# Training-time augmentation pipeline; the final step always emits a
# 224x224 crop covering 50-100% of the (rotated, flipped) image area.
augmentations = v2.Compose(transforms=[
    # NOTE(review): (0, 15) only rotates in one direction; degrees=15 would
    # sample from [-15, 15] — confirm which is intended.
    v2.RandomRotation(degrees=(0, 15)),
    v2.RandomHorizontalFlip(p=0.5),
    v2.RandomVerticalFlip(p=0.5),
    v2.RandomResizedCrop(size=(224, 224), scale=(0.5, 1.0)),
])

In [None]:
# Show the variability of the augmentation pipeline by applying it 25 times
# to the same image. The transform is called directly rather than assigned
# to ``dataset.transform``, so ``dataset`` keeps returning un-augmented
# images in later cells.
sample_image, _ = dataset[0]
fig, axes = plt.subplots(5, 5, figsize=(15, 10))
fig.suptitle("Augmentations example", fontsize=15)
for ax in axes.flatten():
    ax.imshow(augmentations(sample_image))
    ax.axis('off')
plt.show()

## 1.4 Create dataloaders

In [None]:
# Without a transform, ImageFolder yields PIL images of varying size, which
# the default collate function cannot batch. Attach transforms producing
# fixed-size float tensors in [0, 1]: the random augmentations (ending in a
# 224x224 crop) for training, and a deterministic resize + centre crop to
# the same size for evaluation.
# NOTE(review): no normalisation is applied; add v2.Normalize with the
# statistics expected by the model chosen in section 2.
to_tensor_steps = [v2.ToImage(), v2.ToDtype(torch.float32, scale=True)]
train_ds.transform = v2.Compose([augmentations, *to_tensor_steps])
eval_transform = v2.Compose([v2.Resize(256), v2.CenterCrop(224), *to_tensor_steps])
val_ds.transform = eval_transform
test_ds.transform = eval_transform

# Only training data is shuffled; evaluation order stays deterministic.
train_loader = torch.utils.data.DataLoader(
    train_ds, batch_size=BATCH_SIZE, shuffle=True
)
val_loader = torch.utils.data.DataLoader(
    val_ds, batch_size=BATCH_SIZE, shuffle=False
)
test_loader = torch.utils.data.DataLoader(
    test_ds, batch_size=BATCH_SIZE, shuffle=False
)

# 2. Train models

# 3. Similarity Search

# 4. Convert to ONNX