# Image classification baseline

This is a first baseline that fine-tunes models pretrained on the [Pl@ntNet-300K](https://github.com/plantnet/PlantNet-300K/tree/main) dataset. The pretrained models are available on [their website](https://lab.plantnet.org/seafile/d/01ab6658dad6447c95ae/).  
Several pretrained architectures are available; we start with the [resnet152](https://lab.plantnet.org/seafile/d/01ab6658dad6447c95ae/files/?p=%2Fresnet152_weights_best_acc.tar) version.  


## 0 - Prerequisites
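 1. Download the pretrained [model](https://lab.plantnet.org/seafile/d/01ab6658dad6447c95ae/files/?p=%2Fresnet152_weights_best_acc.tar) and save it as `../data/models/resnet152.tar`. The notebook loads weights from `<BASE_PATH>/models/<model_type>.tar`, so the file name must match the model type.
 2. To try the ViT variant, download the [vit_base_patch16_224 weights](https://lab.plantnet.org/seafile/d/01ab6658dad6447c95ae/files/?p=%2Fvit_base_patch16_224_weights_best_acc.tar) and save them as `../data/models/vit_base_patch16_224.tar`.
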
## 1 - Setup

In [None]:
import pandas as pd

import torch
from torch import tensor
from torch.nn import MSELoss
from torch.optim import Adam
from torch.optim.lr_scheduler import OneCycleLR
from torch.utils.data import Dataset, DataLoader, random_split

from timm import create_model

from lightning import LightningModule, Trainer, seed_everything
from lightning.pytorch.loggers import WandbLogger
from lightning.pytorch.callbacks import ModelCheckpoint, EarlyStopping, LearningRateMonitor
seed_everything(42)

from torcheval.metrics import R2Score

import imageio.v3 as imageio

from tqdm.notebook import tqdm
tqdm.pandas()

import wandb

from os import path
import gc

import albumentations as A
from albumentations.pytorch import ToTensorV2

In [None]:
# Root of the data directory (relative to the notebook's working directory).
BASE_PATH = "../data"
# Folder containing the training images, one "<id>.jpeg" file per CSV row.
IMAGES_PATH = f"{BASE_PATH}/train_images"

## 2 - Custom Dataset
We created a custom `ImageDataset` based on PyTorch's `Dataset`. It reads a CSV of image ids and targets and loads the corresponding JPEG files from the image folder into memory. Each image is decoded and transformed only when it is accessed. The dataset is then wrapped in a `DataLoader`, which feeds batches of images to the model during training.

In [None]:
class ImageDataset(Dataset):
    """In-memory dataset of JPEG images and their regression targets.

    Reads the CSV at ``dataset_path``; each row's ``id`` names an image file
    ``{root}/{id}.jpeg``. All files are read into memory up front as raw
    (still encoded) bytes and are decoded lazily in ``__getitem__``.

    Args:
        root: Directory containing the ``<id>.jpeg`` files.
        dataset_path: CSV file with an ``id`` column and the target columns.
        transforms: Callable invoked as ``transforms(image=array)`` that returns
            a dict with an ``"image"`` entry (e.g. an ``A.Compose`` pipeline).
    """

    # Target columns; when the CSV lacks them, their "<name>_mean" variants are used.
    TARGET_COLUMNS = ["X4", "X11", "X18", "X26", "X50", "X3112"]

    def __init__(self, root, dataset_path, transforms):
        super().__init__()
        self.transforms = transforms
        self.dataset = pd.read_csv(dataset_path)
        self.dataset["file_path"] = root + "/" + self.dataset["id"].astype(str) + ".jpeg"
        # progress_apply comes from tqdm.pandas(); it shows a progress bar while loading.
        self.dataset["jpeg_bytes"] = self.dataset["file_path"].progress_apply(self._read_bytes)

        self.target_columns = list(self.TARGET_COLUMNS)
        if self.target_columns[0] not in self.dataset.columns:
            self.target_columns = [f"{col}_mean" for col in self.target_columns]

    @staticmethod
    def _read_bytes(file_path):
        """Return the full binary contents of ``file_path``, closing the file afterwards."""
        with open(file_path, "rb") as f:
            return f.read()

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index):
        """Return ``(image, targets)`` for row ``index``.

        ``image`` is the decoded JPEG after ``self.transforms``; ``targets`` is a
        float32 tensor with one value per entry of ``self.target_columns``.
        """
        sample = self.dataset.iloc[index]

        image = self.transforms(image=imageio.imread(sample["jpeg_bytes"]))["image"]

        targets = tensor(sample[self.target_columns].astype(float).values, dtype=torch.float32)

        return image, targets

## 3 - Model
Our `ImageModel` takes a `model_type` string and a `model_path` directory and loads the pretrained weights from `<model_path>/<model_type>.tar`. The model's last (classification) layer is then replaced by a new linear layer with 6 outputs, one per regression target. The new layer is trained with the full learning rate and the pretrained backbone with a tenth of it.

In [None]:
class ImageModel(LightningModule):
    def __init__(self, model_type, model_path, learning_rate = 1e-3):
        super().__init__()
        self.save_hyperparameters()
        self.load_model(model_type, model_path)
        self.set_last_layer(model_type)
        
        self.loss = MSELoss()
        
    def forward(self, batch):
        x, y = batch
        
        output = self.model(x)
        
        loss = self.loss(output, y)

        metric = R2Score()
        metric.update(output, y)
        
        return loss, metric.compute()

    def training_step(self, batch):
        loss, r2 = self(batch)
        self.log("Train loss", loss, prog_bar=True)
        self.log("Train R2", r2)
        
        return loss
        
    def validation_step(self, batch):
        loss, r2 = self(batch)
        self.log("Validation loss", loss, on_epoch=True)
        self.log("Validation R2", r2, on_epoch=True)
    
    def configure_optimizers(self):
        match self.hparams.model_type:
            case "resnet152" | "resnet18":
                pretrained, classifier = self.get_split_params("fc")
            case "vit_base_patch16_224":
                pretrained, classifier = self.get_split_params("head")            
            case _:
                pretrained, classifier = self.get_split_params("classifier")
        
        optimizer_params = [{"params": pretrained, "lr": self.hparams.learning_rate / 10}]
        optimizer_params += [{"params": classifier, "lr": self.hparams.learning_rate}]
        optimizer = Adam(params=optimizer_params)

        scheduler = LinearLR(optimizer)
        
        return [optimizer], [scheduler]

    def get_split_params(self, classifier_name):
        pretrained = []
        classifier = []
        for name, param in self.model.named_parameters():
            if name.startswith(classifier_name):
                pretrained.append(param)
            else:
                classifier.append(param)
        
        return pretrained, classifier
    
    def load_model(self, model_type, model_path):
        filename = f"{model_path}/{model_type}.tar"
        if not path.exists(filename):
            raise FileNotFoundError

        self.get_model(model_type, filename)

    def get_model(self, model_type, filename):
        d = torch.load(filename)
        
        match model_type:
            case "resnet152":
                model = resnet152(num_classes=1081)
            case "resnet18":
                model = resnet18(num_classes=1081)
            case "densenet121":
                model = densenet121(num_classes=1081)
            case "densenet201":
                model = densenet201(num_classes=1081)
            case _:
                model = create_model(model_type, pretrained=True, num_classes=1081)

        self.model = model
        self.model.load_state_dict(d["model"])

    def get_sequential_layer(self):
        sequence = Sequential()
        for i in range(self.hparams.num_layers):
            sequence.append(Linear(self.model.classifier.in_features, self.model.classifier.in_features))
            sequence.append(ReLU())
        sequence.append(Linear(self.model.classifier.in_features, 6))

        return sequence

    def set_last_layer(self, model_type):
        match model_type:
            case "resnet152" | "resnet18":
                nr_ftrs = self.model.fc.in_features
                #self.set_hidden_dim(nr_ftrs)
                self.model.fc = Linear(nr_ftrs, 6)
            case "vit_base_patch16_224":
                nr_ftrs = self.model.head.in_features
                #self.set_hidden_dim(nr_ftrs)
                self.model.head = Linear(nr_ftrs, 6)                
            case _:
                nr_ftrs = self.model.classifier.in_features
                #self.set_hidden_dim(nr_ftrs)
                self.model.classifier = Linear(nr_ftrs, 6)

    def set_hidden_dim(self, in_ftr):
        self.hparams.hidden_dim = self.hparams.hidden_dim if self.hparams.hidden_dim > 2 else in_ftr * self.hparams.hidden_dim

## 4 - Training
### 4.1 - Constants

In [None]:
# Side length (pixels) of the square images fed to the network.
IMAGE_SIZE = 224
# Per-channel normalisation statistics (the standard ImageNet mean / std values).
MEAN = [0.485, 0.456, 0.406]
STD = [0.229, 0.224, 0.225]

### 4.2 - Transformations

In [None]:
# Training pipeline applied to every sample by ImageDataset. It performs random flip,
# crop, brightness/contrast and JPEG re-compression, then resizes to IMAGE_SIZE,
# scales to [0, 1], normalises with MEAN/STD and converts to a tensor.
_train_steps = [
    A.HorizontalFlip(p=0.5),
    A.RandomSizedCrop([448, 512], IMAGE_SIZE, IMAGE_SIZE, w2h_ratio=1.0, p=0.75),
    A.Resize(IMAGE_SIZE, IMAGE_SIZE),
    A.RandomBrightnessContrast(brightness_limit=0.1, contrast_limit=0.1, p=0.25),
    A.ImageCompression(quality_lower=85, quality_upper=100, p=0.25),
    A.ToFloat(),  # for uint8 input: scale to float in [0, 1]
    A.Normalize(mean=MEAN, std=STD, max_pixel_value=1),  # values are already in [0, 1]
    ToTensorV2(),
]
transforms = A.Compose(_train_steps)

### 4.3 - Train/Validation Split

In [None]:
import copy
from torch.utils.data import Subset

# Load every image once, then split 90/10 into train / validation subsets.
image_dataset = ImageDataset(IMAGES_PATH, BASE_PATH + "/cleaned/cleaned_train.csv", transforms=transforms)
train, val = random_split(image_dataset, [0.9, 0.1])

# The validation subset must not see the random training augmentations, or the
# monitored metrics become noisy and non-deterministic. Give it a shallow copy of
# the dataset (sharing the already-loaded images) with a deterministic pipeline.
val_transforms = A.Compose([
    A.Resize(IMAGE_SIZE, IMAGE_SIZE),
    A.ToFloat(),
    A.Normalize(mean=MEAN, std=STD, max_pixel_value=1),
    ToTensorV2(),
])
val_dataset = copy.copy(image_dataset)
val_dataset.transforms = val_transforms
val = Subset(val_dataset, val.indices)

### 4.4 - Sweep function

In [None]:
def sweep(config = None):
    """Run one W&B sweep trial: fine-tune an ``ImageModel`` on ``train`` / ``val``.

    Reads ``model_type`` and ``learning_rate`` from ``wandb.config``. The
    ``config`` argument is ignored and is kept only for signature compatibility.
    The model and trainer are always released (including GPU cache), even when
    training raises, so the next trial starts clean.
    """
    import os

    wandb.finish()  # close any run left open by a previous trial
    model = trainer = None
    try:
        with wandb.init():
            wandb.define_metric("Validation R2", summary="max")

            config = wandb.config

            model = ImageModel(config["model_type"], BASE_PATH + "/models", config["learning_rate"])

            logger = WandbLogger(project="aicomp", log_model="all")

            callbacks = [ModelCheckpoint(monitor="Validation R2", mode="max")]
            callbacks += [EarlyStopping(monitor="Validation loss", mode="min", patience=30)]
            callbacks += [EarlyStopping(monitor="Validation R2", mode="max", patience=30)]
            callbacks += [EarlyStopping(monitor="Train loss", mode="min", patience=30, stopping_threshold=0.01, check_on_train_epoch_end=True)]
            callbacks += [EarlyStopping(monitor="Train R2", mode="max", patience=30, stopping_threshold=0.99, check_on_train_epoch_end=True)]
            callbacks += [LearningRateMonitor(logging_interval="step")]

            trainer = Trainer(max_epochs=150, logger=logger, num_sanity_val_steps=0, callbacks=callbacks)

            # Never spawn more loader workers than there are CPUs.
            num_workers = min(120, os.cpu_count() or 1)

            train_dataloader = DataLoader(train, batch_size=120, shuffle=True, num_workers=num_workers)
            val_dataloader = DataLoader(val, batch_size=120, num_workers=num_workers)

            trainer.fit(model, train_dataloader, val_dataloader)
    finally:
        del model
        del trainer
        gc.collect()
        torch.cuda.empty_cache()

In [None]:
# Sweep to join, given as "<entity>/<project>/<sweep id>".
SWEEP_ID = "nlpfs24/aicomp/stjn09mr"

wandb.login()
wandb.agent(SWEEP_ID, function=sweep)

In [None]:
# Release the W&B resources held by this kernel once the sweep agent has finished.
wandb.teardown()