<a href="https://colab.research.google.com/github/simecek/dspracticum2023/blob/main/lesson03/Finetuning.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install timm

In [2]:
import json
import os
from getpass import getpass

import matplotlib.pyplot as plt
import PIL
from PIL import Image

import timm
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
import torch
import torchvision
from torchvision import datasets, transforms
from torchsummary import summary

from fastai.vision.all import *

In [None]:
# Prefer the GPU whenever PyTorch can see one; otherwise run everything on the CPU.
device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'
print(device)

## ConvNeXt


In [None]:
# The ".fb_in22k" suffix is a timm pretrained tag selecting which weights are downloaded.
model_name = "convnext_tiny.fb_in22k"
convnext = timm.create_model(model_name, pretrained=True)
# Module.to() moves parameters in place and returns the same module.
convnext = convnext.to(device)

In [None]:
# You can also list all available models, optionally filtered with a wildcard pattern:
# timm.list_models('*convnext*')

In [None]:
# Per-layer output shapes and parameter counts for one 3x256x256 input,
# built on the same device the model lives on.
summary(convnext, (3, 256, 256), device=device)

### Download the label mapping for the model

In [None]:
!wget https://dl.fbaipublicfiles.com/convnext/label_to_words.json
imagenet_labels = json.load(open('label_to_words.json'))

### Download a test image and preprocess it for ConvNeXt

In [None]:
!wget --output-document=test.jpeg https://upload.wikimedia.org/wikipedia/commons/d/d7/Squirrel_in_Seurasaari_autumn.JPG
img = PIL.Image.open('test.jpeg')

# Define transforms for image
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD

transformations = [
              transforms.Resize(256, interpolation=transforms.InterpolationMode.BICUBIC), # resize smaller edge to 256
              transforms.ToTensor(),
              transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD),
              ]

transformations = transforms.Compose(transformations)

img_tensor = transformations(img).unsqueeze(0).to(device)

### Predict the top-5 labels for our image

In [None]:
# Switch to inference mode (disables training-only behaviour such as dropout) and skip
# autograd bookkeeping: we only need a forward pass here.
convnext.eval()
with torch.no_grad():
    output = torch.softmax(convnext(img_tensor), dim=1)

top5 = torch.topk(output, k=5)
top5_prob = top5.values[0]
top5_indices = top5.indices[0]

# Label file keys are stringified class indices.
for class_idx, class_prob in zip(top5_indices.tolist(), top5_prob.tolist()):
    print(imagenet_labels[str(class_idx)], f"{class_prob * 100:.2f}%")

plt.imshow(img)
plt.axis('off');



---



## Our custom dataset - Tom & Jerry

https://www.kaggle.com/datasets/balabaskar/tom-and-jerry-image-classification

---



In [10]:
# Folder holding the extracted images (one sub-folder per class, read via parent_label).
IMAGES_PATH = './tom_and_jerry/tom_and_jerry'
# Kaggle "<owner>/<dataset>" slug, and the archive the Kaggle CLI downloads for it.
DATASET = 'balabaskar/tom-and-jerry-image-classification'
ZIP_PATH = './tom-and-jerry-image-classification.zip'

In [None]:
os.environ['KAGGLE_USERNAME'] = 'evaklimentov'
os.environ['KAGGLE_KEY'] = 'c3161c890c8b21e1e5cba18c9a7505c0'

!kaggle datasets download -d {DATASET} -p ./

In [12]:
import zipfile

# Unpack the downloaded Kaggle archive into the current working directory.
archive = zipfile.ZipFile(ZIP_PATH, mode='r')
with archive:
    archive.extractall(path='./')

In [None]:
# Images labelled by their parent folder name, with a reproducible 80/20 train/valid
# split. Every image is squished to 256x256.
images = DataBlock(
    blocks=(ImageBlock, CategoryBlock),
    get_items=get_image_files,
    get_y=parent_label,
    splitter=RandomSplitter(seed=42, valid_pct=0.2),
    item_tfms=[Resize(256, method='squish')],
)

dls = images.dataloaders(IMAGES_PATH, bs=64)
dls.show_batch(max_n=6)

In [None]:
# Number of images in the training split, then in the validation split.
for split_dl in (dls.train, dls.valid):
    print(len(split_dl.dataset))

## Load a pretrained ConvNeXt model and fine-tune it

In [None]:
# NOTE(review): `convnext_tiny` here comes from the fastai star import, presumably the
# torchvision ImageNet-1k model rather than the timm in22k model used above — verify.
learn = vision_learner(dls, arch=convnext_tiny, metrics=[accuracy])
# 1 epoch with the pretrained body frozen, then 3 epochs with everything unfrozen.
learn.fine_tune(epochs=3, freeze_epochs=1)

## See how our model performs:


In [None]:
# Collect the learner's predictions/losses, then visualise which classes get confused.
interp = ClassificationInterpretation.from_learner(learn)
interp.plot_confusion_matrix(dpi=80, figsize=(8, 8))

### What is hard to predict?

In [None]:
# The 8 images with the highest loss, i.e. the ones the model struggles with most.
interp.plot_top_losses(k=8, figsize=(13, 13))

## Predict a new image

In [None]:
# Colab-only: opens a browser file picker; the chosen file(s) end up in `uploaded`.
import google.colab.files as files

uploaded = files.upload()

In [None]:
# Fail with a clear message (instead of a bare IndexError) if the picker was cancelled.
assert uploaded, "No file was uploaded - re-run the upload cell and choose an image."
# Only the first uploaded file is used.
img = PILImage.create(next(iter(uploaded.values())))
img

In [None]:
# Unpacked as: predicted label, its class index, and the per-class probabilities.
pred, pred_idx, probs = learn.predict(img)
(pred, pred_idx, probs)

## Data augmentation

Use image transformations from https://docs.fast.ai/vision.augment.html (for example `aug_transforms()`) to fill in the `tfms` list below.

In [None]:
# TODO: add batch-level augmentations here (hint: start from aug_transforms()).
tfms = []

# Same data pipeline as before, plus the batch transforms above (applied on the GPU
# per batch). unique=True shows several augmented versions of a single image.
aug_block = DataBlock(
    blocks=(ImageBlock, CategoryBlock),
    get_items=get_image_files,
    get_y=parent_label,
    splitter=RandomSplitter(seed=42, valid_pct=0.2),
    item_tfms=Resize(256, method='squish'),
    batch_tfms=tfms,
)
dls = aug_block.dataloaders(IMAGES_PATH, bs=64)

dls.show_batch(max_n=6, unique=True)

In [None]:
# Retrain from the pretrained weights on the (augmented) dataloaders to compare accuracy.
learn = vision_learner(dls, arch=convnext_tiny, metrics=[accuracy])
learn.fine_tune(epochs=3, freeze_epochs=1)

### Diving into `fine_tune`

Let's uncover what's inside:

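`fine_tune(epochs, base_lr=2e-3, freeze_epochs=1)` is roughly `learn.freeze()`, `learn.fit_one_cycle(freeze_epochs, base_lr)`, `learn.unfreeze()`, `learn.fit_one_cycle(epochs, slice(base_lr/200, base_lr/2))`. The second phase uses half the base learning rate, spread across the layer groups, so early pretrained layers change least.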

But first, let's look at how the learning rate changed during the training we just performed:

In [None]:
# Learning-rate values recorded by the Recorder at every training step of the last fit.
learn.recorder.plot_sched(keys=['lr'])

In [None]:
# Fresh learner so the two phases of `fine_tune` can be run by hand.
learn = vision_learner(dls, arch=convnext_tiny, metrics=[accuracy])
# Freeze the pretrained body (only the head stays trainable), then inspect the layers.
learn.freeze()
learn.summary()

In [None]:
# Phase 1 of `fine_tune`: train only the new head while the pretrained body is frozen.
# A constant lr=0.5 (as used before) is far too large for fine-tuning with fastai's
# default Adam optimiser. `fine_tune` uses base_lr=2e-3 with a one-cycle schedule here.
learn.fit_one_cycle(1, slice(2e-3), pct_start=0.99)

In [None]:
# Phase 2 of `fine_tune`: unfreeze all layers and train with discriminative learning
# rates. `fine_tune` halves its base lr (2e-3 / 2 = 1e-3) for this phase: the first
# layer group gets 1e-3/100, the head gets 1e-3.
# The previous constant lr=0.5 would destroy the pretrained weights.
learn.unfreeze()
learn.fit_one_cycle(3, slice(1e-3 / 100, 1e-3), pct_start=0.3, div=5.0)