# Transfer Learning

**This notebook uses transfer learning (a pretrained EfficientNet-B0) to build a better model for a custom dataset.**

### Imports

In [71]:
import sys
# NOTE(review): '../scripts' exposes the modules inside scripts/ as top-level names
# (e.g. `data_setup`); the `scripts.` imports below additionally need the project
# root on sys.path (some IDEs add it automatically) — confirm.
sys.path.insert(0, '../scripts')

import os
import random
import shutil
import zipfile
from collections.abc import Callable
from pathlib import Path
from typing import Union

import matplotlib.pyplot as plt
import torch
import torchvision
from PIL import Image
from torchinfo import summary

from scripts.data_setup import create_dataloaders
from scripts.engine import train
from scripts.utils import accuracy_fn, plot_loss_and_accuracy_curves

### Device-agnostic code

In [72]:
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
device

'cpu'

### Getting data

In [73]:
data_name = 'baklava_churros_cheesecake'
data_path = Path('../data/processed')
images_path = data_path / data_name
zip_path = data_path / f'{data_name}.zip'

# Extract the archive only on the first run; later runs reuse the extracted folder.
if not images_path.is_dir():
    if not zip_path.exists():
        # Fail fast: without the dataset every following cell would error out anyway,
        # just with a far less helpful message.
        raise FileNotFoundError(f'{data_name} is wrong name or custom dataset was not created.\n'
                                'Use custom dataset notebook to create dataset.')
    images_path.mkdir(parents=True, exist_ok=True)
    try:
        with zipfile.ZipFile(zip_path, 'r') as zip_file:
            zip_file.extractall(images_path)
    except BaseException:
        # A half-extracted folder (error or KeyboardInterrupt) would make later runs
        # skip extraction because of the is_dir() check above, so remove it first.
        shutil.rmtree(images_path, ignore_errors=True)
        raise

train_dir = images_path / 'train'
test_dir = images_path / 'test'

### Creating DataLoaders

In [74]:
# Reuse the preprocessing bundled with the pretrained weights so the images
# are prepared the way the pretrained backbone expects.
weights = torchvision.models.EfficientNet_B0_Weights.DEFAULT
transforms = weights.transforms()

train_dataloader, test_dataloader, class_names = create_dataloaders(
    train_dir=train_dir,
    test_dir=test_dir,
    transform=transforms,
    batch_size=32,
)

### Setting up a pretrained model

In [75]:
# Keep downloaded pretrained checkpoints inside the project's models directory.
models_dir_path = Path('../models')
models_dir_path.mkdir(parents=True, exist_ok=True)

# Set before building the model so the weight download below uses this location.
os.environ['TORCH_HOME'] = str(models_dir_path)

model = torchvision.models.efficientnet_b0(weights=weights)
model = model.to(device)
model

EfficientNet(
  (features): Sequential(
    (0): Conv2dNormActivation(
      (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): SiLU(inplace=True)
    )
    (1): Sequential(
      (0): MBConv(
        (block): Sequential(
          (0): Conv2dNormActivation(
            (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
            (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): SiLU(inplace=True)
          )
          (1): SqueezeExcitation(
            (avgpool): AdaptiveAvgPool2d(output_size=1)
            (fc1): Conv2d(32, 8, kernel_size=(1, 1), stride=(1, 1))
            (fc2): Conv2d(8, 32, kernel_size=(1, 1), stride=(1, 1))
            (activation): SiLU(inplace=True)
            (scale_activation): Sigmoid()
          )
          (2): Conv2dNormActivat

In [79]:
def model_summary(model: torch.nn.Module,
                  input_size: tuple[int, ...] = (32, 3, 224, 224)) -> None:
    """Prints a detailed summary of a PyTorch model.

    The table lists per-layer input/output shapes, parameter counts and whether
    each layer is trainable, which makes it easy to confirm which layers are frozen.

    Args:
        model: The PyTorch model to summarize.
        input_size: Shape of the dummy batch fed through the model, as
            (batch_size, color_channels, height, width). Defaults to a batch of
            32 three-channel 224x224 images.
    """
    print(summary(model=model,
                  input_size=input_size,
                  verbose=0,
                  col_names=['input_size', 'output_size', 'num_params', 'trainable'],
                  col_width=18,
                  row_settings=['var_names']))

In [80]:
# Baseline view of the pretrained network, before any layer is frozen or replaced.
model_summary(model=model)

Layer (type (var_name))                                      Input Shape        Output Shape       Param #            Trainable
EfficientNet (EfficientNet)                                  [32, 3, 224, 224]  [32, 1000]         --                 True
├─Sequential (features)                                      [32, 3, 224, 224]  [32, 1280, 7, 7]   --                 True
│    └─Conv2dNormActivation (0)                              [32, 3, 224, 224]  [32, 32, 112, 112] --                 True
│    │    └─Conv2d (0)                                       [32, 3, 224, 224]  [32, 32, 112, 112] 864                True
│    │    └─BatchNorm2d (1)                                  [32, 32, 112, 112] [32, 32, 112, 112] 64                 True
│    │    └─SiLU (2)                                         [32, 32, 112, 112] [32, 32, 112, 112] --                 --
│    └─Sequential (1)                                        [32, 32, 112, 112] [32, 16, 112, 112] --                 True
│    │    └─M

In [81]:
# Seed before building the new head so its random weight initialisation is reproducible.
torch.manual_seed(42)
torch.cuda.manual_seed(42)

# Freeze the pretrained feature extractor so only the new head is trained.
# NOTE: requires_grad=False does not stop BatchNorm running statistics from
# updating while the model is in train() mode.
for param in model.features.parameters():
    param.requires_grad = False

output_shape = len(class_names)

# Read the head's input width from the current classifier instead of hard-coding
# 1280; re-running this cell is safe because the replacement keeps the same
# (Dropout, Linear) layout.
in_features = model.classifier[-1].in_features

model.classifier = torch.nn.Sequential(
    torch.nn.Dropout(p=0.2, inplace=True),
    torch.nn.Linear(in_features=in_features,
                    out_features=output_shape,
                    bias=True)
).to(device)

In [82]:
# Re-inspect: `features` should now be frozen and the head should output len(class_names) logits.
model_summary(model=model)

Layer (type (var_name))                                      Input Shape        Output Shape       Param #            Trainable
EfficientNet (EfficientNet)                                  [32, 3, 224, 224]  [32, 3]            --                 Partial
├─Sequential (features)                                      [32, 3, 224, 224]  [32, 1280, 7, 7]   --                 False
│    └─Conv2dNormActivation (0)                              [32, 3, 224, 224]  [32, 32, 112, 112] --                 False
│    │    └─Conv2d (0)                                       [32, 3, 224, 224]  [32, 32, 112, 112] (864)              False
│    │    └─BatchNorm2d (1)                                  [32, 32, 112, 112] [32, 32, 112, 112] (64)               False
│    │    └─SiLU (2)                                         [32, 32, 112, 112] [32, 32, 112, 112] --                 --
│    └─Sequential (1)                                        [32, 32, 112, 112] [32, 16, 112, 112] --                 False
│    

### Training a model

In [None]:
torch.manual_seed(42)
torch.cuda.manual_seed(42)

loss_fn = torch.nn.CrossEntropyLoss()

# Give the optimizer only the parameters that can learn (the new classifier head);
# the frozen backbone would otherwise be registered but never updated.
trainable_params = [param for param in model.parameters() if param.requires_grad]
optimizer = torch.optim.Adam(params=trainable_params,
                             lr=0.001)

model_results = train(model=model,
                      train_dataloader=train_dataloader,
                      test_dataloader=test_dataloader,
                      optimizer=optimizer,
                      loss_fn=loss_fn,
                      accuracy_fn=accuracy_fn,
                      epochs=5,
                      device=device)

### Plot loss and accuracy curves

In [None]:
# Visualise the training history returned by `train` above (presumably per-epoch
# loss and accuracy for the train and test splits — see scripts/utils to confirm).
plot_loss_and_accuracy_curves(model_results)

### Make predictions on images from the test set

In [None]:
def predict_and_plot_image(
        model: torch.nn.Module,
        images_path: Union[str, Path],
        class_names: list[str],
        transform: Callable[[Image.Image], torch.Tensor] = transforms,
        device: Union[torch.device, str] = device
) -> None:
    """Predicts the class of a single image and plots it with the predicted label.

    The image is loaded from `images_path`, converted to RGB, preprocessed with
    `transform` and passed through `model`; it is then shown in a new matplotlib
    figure titled with the predicted class name.

    Args:
        model: The trained PyTorch model for prediction. As a side effect it is
            moved to `device` and switched to eval mode.
        images_path: Path to the image file to predict (str or pathlib.Path).
        class_names: Class names indexed by the model's output logits.
        transform: Transformation turning a PIL image into a
            [color_channels, height, width] tensor. Defaults to the pretrained
            weights' transforms defined earlier in the notebook.
        device: Target device to run computations on (e.g., 'cpu' or 'cuda').
    """
    # Force 3 channels: grayscale/RGBA/CMYK files would not match the model's
    # RGB input. The context manager releases the underlying file handle.
    with Image.open(images_path) as raw_img:
        img = raw_img.convert('RGB')

    model.to(device)
    model.eval()
    with torch.inference_mode():
        # model requires samples in [batch_size, color_channels, height, width]
        transformed_image = transform(img).unsqueeze(dim=0).to(device)
        image_pred = model(transformed_image)

        # Logits -> Proba -> Label index (a plain int, safe for list indexing)
        image_probs = torch.softmax(image_pred, dim=1)
        image_label = torch.argmax(image_probs, dim=1).item()

    plt.figure()
    plt.imshow(img)
    plt.title(f'Prediction: {class_names[image_label]}')
    plt.axis(False)

In [None]:
num_images = 3
# Sort so the seeded sample below does not depend on filesystem listing order.
test_image_path_list = sorted(test_dir.glob('*/*.jpg'))

# Local seeded RNG: the same images are shown on every Restart & Run All without
# touching the global `random` state; never request more images than exist.
sample_rng = random.Random(42)
test_image_path_random_sample = sample_rng.sample(population=test_image_path_list,
                                                  k=min(num_images, len(test_image_path_list)))

for image_path in test_image_path_random_sample:
    predict_and_plot_image(model=model,
                           images_path=image_path,
                           class_names=class_names)