# Fine-Tuning
This notebook fine-tunes a pretrained Vision Transformer (ViT) image classifier on the image dataset using the Hugging Face Trainer API.

### Data Preparation

In [None]:
from datasets import load_dataset

In [None]:
# Load the train and test splits from the same image folder
IMAGE_DIR = "../data/images"
dataset_train = load_dataset("imagefolder", data_dir=IMAGE_DIR, split="train")
dataset_test = load_dataset("imagefolder", data_dir=IMAGE_DIR, split="test")

In [None]:
# Inspect the label column: the label feature itself, and how many
# distinct label ids actually occur in the training split
labels = dataset_train.features['label']
num_classes = len(set(dataset_train['label']))
num_classes, labels

In [None]:
# Show the first training example
dataset_train[0]

In [None]:
# Show the 'image' entry of the first training example
dataset_train[0]['image']

### Training

In [None]:
from transformers import ViTImageProcessor

# Pretrained ViT checkpoint; the same identifier is reused when loading the model
model_name_or_path = 'google/vit-base-patch16-224'
# Image processor for that checkpoint; `transform` uses it to produce pixel values
processor = ViTImageProcessor.from_pretrained(model_name_or_path)

In [None]:
from PIL import Image

In [None]:
def transform(example_batch):
    """
    Convert a batch of PIL images into model-ready pixel values and
    attach the corresponding labels.

    Every image is converted to RGB first so that grayscale, palette or
    RGBA images do not break the processor's 3-channel preprocessing.

    Parameters:
    ----------
    example_batch: dict
        A batch from the dataset with an 'image' entry (PIL images)
        and a 'label' entry (the matching labels)

    Returns:
    -------
    inputs: dict
        The processor output (requested as PyTorch tensors) plus a
        'labels' entry holding the batch's labels unchanged
    """
    # Normalize the colour mode before preprocessing
    rgb_images = [image.convert('RGB') for image in example_batch['image']]

    # Take the list of PIL images and turn them into pixel values
    inputs = processor(rgb_images, return_tensors='pt')

    # Include the labels
    inputs['labels'] = example_batch['label']
    return inputs

In [None]:
# Attach `transform` to both splits so examples come out as pixel values plus labels
prepared_ds_train = dataset_train.with_transform(transform)
prepared_ds_test = dataset_test.with_transform(transform)

In [None]:
import torch

def collate_fn(batch):
    """
    Collate a list of per-example dicts into a single batch dict
    that can be fed to the model.

    Parameters:
    ----------
    batch: list
        Examples from the dataset, each a dict with a 'pixel_values'
        tensor and a 'labels' value

    Returns:
    -------
    dict
        'pixel_values': the example tensors stacked along a new first
        dimension; 'labels': a tensor built from the example labels
    """
    pixel_values = torch.stack([example['pixel_values'] for example in batch])
    label_tensor = torch.tensor([example['labels'] for example in batch])
    return {'pixel_values': pixel_values, 'labels': label_tensor}

In [None]:
import numpy as np
from datasets import load_metric

# Load the accuracy metric
# NOTE(review): `load_metric` is deprecated in newer `datasets` releases in
# favour of the separate `evaluate` package - confirm against the pinned version.
metric = load_metric("accuracy")


def compute_metrics(p):
    """
    Score a set of model predictions with the accuracy metric.

    Parameters:
    ----------
    p: Prediction
        A prediction object exposing `predictions` (per-class scores,
        classes along axis 1) and `label_ids` (the reference labels)

    Returns:
    -------
    The result of `metric.compute` for the predicted class ids
    (presumably a dict with an 'accuracy' entry - verify against
    the metric's documentation)
    """
    # The predicted class is the index of the highest score per example
    predicted_ids = np.argmax(p.predictions, axis=1)
    return metric.compute(predictions=predicted_ids, references=p.label_ids)

In [None]:
from transformers import ViTForImageClassification

# Class names in the dataset; a name's position is its label id
labels = dataset_train.features['label'].names

# Load the pretrained model with a classification head sized for our labels.
# Label ids are kept as integers in both mappings, matching the documented
# types of `id2label` (int -> str) and `label2id` (str -> int).
model = ViTForImageClassification.from_pretrained(
    model_name_or_path,
    num_labels=len(labels),
    id2label={i: c for i, c in enumerate(labels)},  # Converts output ids to labels
    label2id={c: i for i, c in enumerate(labels)},  # Converts labels to output ids
    # The checkpoint ships with its own classification head, which typically
    # has a different number of classes than this dataset; allow that layer
    # to be re-initialized instead of failing on the shape mismatch.
    ignore_mismatched_sizes=True,
)

In [None]:
from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="./vit",
    per_device_train_batch_size=16,
    evaluation_strategy="steps",
    num_train_epochs=4,
    # Mixed precision is only enabled when a CUDA device is present, so the
    # notebook still runs (in full precision) on CPU-only machines.
    # NOTE(review): other accelerators that support fp16 will also fall back
    # to full precision with this check - adjust if that matters.
    fp16=torch.cuda.is_available(),
    # Checkpoint and evaluate on the same schedule so that the best
    # checkpoint can be restored at the end of training.
    save_steps=100,
    eval_steps=100,
    logging_steps=10,
    learning_rate=2e-4,
    save_total_limit=2,
    # Keep the 'image' column: `transform` needs it to build pixel values
    remove_unused_columns=False,
    push_to_hub=False,
    report_to='tensorboard',
    load_best_model_at_end=True,
)

In [None]:
from transformers import Trainer

trainer = Trainer(
    # What to train and how
    model=model,
    args=training_args,
    # Data: transformed train/test splits, batched by `collate_fn`
    train_dataset=prepared_ds_train,
    eval_dataset=prepared_ds_test,
    data_collator=collate_fn,
    # Evaluation metric
    compute_metrics=compute_metrics,
    # The image processor is passed in the tokenizer slot
    # (presumably so it is saved together with the model - verify)
    tokenizer=processor,
)

In [None]:
# Run fine-tuning; the returned object carries the training metrics
train_results = trainer.train()

In [None]:
# Persist the trained model, then log and save the training metrics and the trainer state
trainer.save_model()
trainer.log_metrics("train", train_results.metrics)
trainer.save_metrics("train", train_results.metrics)
trainer.save_state()

In [None]:
# Evaluate on the test split, then log and save the results under the "eval" name
metrics = trainer.evaluate(prepared_ds_test)
trainer.log_metrics("eval", metrics)
trainer.save_metrics("eval", metrics)