# Load Libraries

In [None]:
!pip install evaluate

In [None]:
%matplotlib inline

import os, cv2
import glob

import shutil
import torch

import pandas as pd
import numpy as np

from datasets import load_dataset
import evaluate
from transformers import Swinv2ForImageClassification, AutoModelForImageClassification, AutoFeatureExtractor, TrainingArguments, Trainer
from pathlib import Path

from matplotlib import pyplot as plt

from sklearn.metrics import classification_report, confusion_matrix, ConfusionMatrixDisplay

# Setup

In [None]:
# Hub checkpoint id; later cells load both the feature extractor and the model weights from it.
model_name_or_path = "microsoft/swinv2-tiny-patch4-window8-256" # pre-trained model from which to fine-tune
# Passed to TrainingArguments as both the per-device train and eval batch size.
batch_size = 32 # batch size for training and evaluation

# Create dataset

In [None]:
# Build the dataset with the generic "imagefolder" loader from a Kaggle input directory.
# NOTE(review): the path is hard-coded, so this cell only runs as-is where
# /kaggle/input/buildings/dataset exists.
# Later cells read dataset['train'] and dataset['test'], so the directory presumably holds
# train/ and test/ sub-folders with one folder per class — confirm against the data.
dataset = load_dataset("imagefolder", data_dir='/kaggle/input/buildings/dataset')

In [None]:
# Echo the loaded dataset as the cell output for a quick visual check.
dataset

In [None]:
# Load the metrics from the evaluate library (accuracy, precision, recall, F1)
# They are created once at module level and reused by compute_metrics on every call.
accuracy_metric = evaluate.load("accuracy")
precision_metric = evaluate.load("precision")
recall_metric = evaluate.load("recall")
f1_metric = evaluate.load("f1")

def compute_metrics(p):
    """Score one evaluation pass with accuracy and macro precision/recall/F1.

    Args:
        p: Prediction container exposing ``predictions`` (scores indexed as
            sample x class; argmax is taken over axis 1) and ``label_ids``
            (the true labels).

    Returns:
        dict with the keys ``accuracy``, ``precision``, ``recall`` and ``f1``.
    """
    # The highest-scoring entry along axis 1 is the predicted class.
    predicted = np.argmax(p.predictions, axis=1)
    shared = {"predictions": predicted, "references": p.label_ids}
    # Precision, recall and F1 are all macro-averaged over the classes.
    macro = {**shared, "average": "macro"}

    acc_out = accuracy_metric.compute(**shared)
    prec_out = precision_metric.compute(**macro)
    rec_out = recall_metric.compute(**macro)
    f1_out = f1_metric.compute(**macro)

    # Unwrap each single-entry result dict into one flat mapping.
    return {
        'accuracy': acc_out['accuracy'],
        'precision': prec_out['precision'],
        'recall': rec_out['recall'],
        'f1': f1_out['f1'],
    }

In [None]:
# Load the preprocessing config that matches the checkpoint; its size / image_mean / image_std
# attributes are read by the torchvision transforms built in a later cell.
feature_extractor = AutoFeatureExtractor.from_pretrained(model_name_or_path)
# Echo the object as the cell output for inspection.
feature_extractor

In [None]:
from torchvision.transforms import (
    CenterCrop,
    Compose,
    Normalize,
    RandomHorizontalFlip,
    RandomResizedCrop,
    Resize,
    ToTensor,
)

# Side length used by every crop/resize below, read once from the extractor's "height" entry.
_image_side = feature_extractor.size['height']

# Channel-wise normalisation with the statistics stored on the feature extractor.
normalize = Normalize(mean=feature_extractor.image_mean, std=feature_extractor.image_std)

# Training pipeline: random resized crop + random horizontal flip, then tensor conversion
# and normalisation.
train_transforms = Compose([
    RandomResizedCrop(_image_side),
    RandomHorizontalFlip(),
    ToTensor(),
    normalize,
])

# Validation pipeline: deterministic resize + centre crop, then the same tensor conversion
# and normalisation as training.
val_transforms = Compose([
    Resize(_image_side),
    CenterCrop(_image_side),
    ToTensor(),
    normalize,
])

def preprocess_train(example_batch):
    """Add a "pixel_values" column produced by the training transforms.

    Args:
        example_batch: Batch dict whose "image" entry is an iterable of images
            exposing ``convert`` (each is converted to RGB first).

    Returns:
        The same dict, with "pixel_values" set to the list of transformed images.
    """
    transformed = []
    for picture in example_batch["image"]:
        rgb = picture.convert("RGB")
        transformed.append(train_transforms(rgb))
    example_batch["pixel_values"] = transformed
    return example_batch

def preprocess_val(example_batch):
    """Add a "pixel_values" column produced by the validation transforms.

    Args:
        example_batch: Batch dict whose "image" entry is an iterable of images
            exposing ``convert`` (each is converted to RGB first).

    Returns:
        The same dict, with "pixel_values" set to the list of transformed images.
    """
    transformed = []
    for picture in example_batch["image"]:
        rgb = picture.convert("RGB")
        transformed.append(val_transforms(rgb))
    example_batch["pixel_values"] = transformed
    return example_batch

In [None]:
# Pick the splits used for fitting and for per-epoch evaluation.
# NOTE(review): the "test" split doubles as the validation set, and the TrainingArguments
# further down select the best checkpoint on it (load_best_model_at_end=True), so scores
# reported on this split are likely optimistic — confirm whether a separate validation
# split is available.
train_ds = dataset['train']
val_ds = dataset['test']

In [None]:
# Register the batch preprocessing functions on each split: augmenting transforms for
# training, deterministic ones for validation.
train_ds.set_transform(preprocess_train)
val_ds.set_transform(preprocess_val)

In [None]:
# Class names come from the training split's "label" feature; their order defines the class ids.
labels = dataset['train'].features['label'].names

# Load the pre-trained checkpoint with a classification head sized to the number of classes.
# Both id<->label mappings are passed with the ids written as strings.
model = AutoModelForImageClassification.from_pretrained(
    model_name_or_path,
    num_labels=len(labels),
    id2label={str(i): c for i, c in enumerate(labels)},
    label2id={c: str(i) for i, c in enumerate(labels)},
    ignore_mismatched_sizes = True, # provide this in case you're planning to fine-tune an already fine-tuned checkpoint
)

# Echo the model as the cell output to inspect its architecture.
model

In [None]:
# Short checkpoint name (the part after the last "/"), used to name the output directory.
model_name = model_name_or_path.rsplit("/", 1)[-1]

epochs = 40

args = TrainingArguments(
    output_dir=f"{model_name}-finetuned-buildings",
    # Data handling: keep all dataset columns (the preprocessing functions read "image").
    remove_unused_columns=False,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    gradient_accumulation_steps=4,
    # Optimisation schedule.
    learning_rate=4e-4,
    warmup_ratio=0.1,
    num_train_epochs=epochs,
    # Evaluate, checkpoint and log once per epoch.
    # NOTE(review): newer transformers releases name the first argument `eval_strategy`;
    # confirm against the installed version.
    evaluation_strategy="epoch",
    save_strategy="epoch",
    logging_strategy='epoch',
    # Reload the best checkpoint at the end, ranked by the "f1" value from compute_metrics.
    load_best_model_at_end=True,
    metric_for_best_model="f1",
    push_to_hub=False,
)

In [None]:
def collate_fn(examples):
    """Merge per-sample dicts into a single batch dict.

    Args:
        examples: Sequence of dicts, each holding a "pixel_values" tensor and
            a "label" value.

    Returns:
        dict with "pixel_values" (the sample tensors stacked along a new
        leading dimension) and "labels" (a tensor of the labels, in input order).
    """
    images = []
    targets = []
    for sample in examples:
        images.append(sample["pixel_values"])
        targets.append(sample["label"])
    return {
        "pixel_values": torch.stack(images),
        "labels": torch.tensor(targets),
    }

In [None]:
# Wire the model, training arguments, datasets, metric function and collator into a Trainer.
# The feature extractor is handed over through the `tokenizer` argument — presumably so it is
# saved together with the model; confirm against the installed transformers version.
trainer = Trainer(
    model,
    args,
    train_dataset=train_ds,
    eval_dataset=val_ds,
    tokenizer=feature_extractor,
    compute_metrics=compute_metrics,
    data_collator=collate_fn,
)

In [None]:
# Authenticate the Weights & Biases CLI before training starts.
# NOTE(review): the API key is passed literally on the command line and is stored with the
# notebook; prefer supplying it through a secret store or the WANDB_API_KEY environment variable.
!wandb login <ENDER_WANDB_KEY>

In [None]:
# Launch fine-tuning; the returned object carries the training metrics that are logged
# and saved in a later cell.
train_results = trainer.train()

In [None]:
# Create the export directory for the final model.
# exist_ok=True keeps the cell safe to re-run: a plain `mkdir best` reports an error once
# the folder already exists.
os.makedirs('best', exist_ok=True)

In [None]:
# Export the trainer's current model into ./best/.
# With load_best_model_at_end=True in the training arguments this is expected to be the
# best checkpoint by F1 — TODO confirm.
trainer.save_model('./best/')

In [None]:
# Additionally dump the raw state dict of `model` into the same folder.
# NOTE(review): trainer.save_model('./best/') is already called on this directory in the
# preceding cell, so this likely stores the weights a second time — confirm it is needed.
torch.save(model.state_dict(), './best/pytorch_model.bin')

In [None]:
# Report the metrics returned by trainer.train() under the "train" prefix, write them to
# disk, and persist the trainer state.
trainer.log_metrics("train", train_results.metrics)
trainer.save_metrics("train", train_results.metrics)
trainer.save_state()

In [None]:
%%time

# Run evaluation on the eval_dataset given to the Trainer (val_ds) and keep the resulting metrics.
metrics = trainer.evaluate()
# some nice to haves:
# report the evaluation metrics under the "eval" prefix and write them to disk
trainer.log_metrics("eval", metrics)
trainer.save_metrics("eval", metrics)