# Table of Contents

1. Setup
2. Get Data
3. Create Datasets and DataLoaders
4. Get and customize a pretrained model
5. Train Model
6. Evaluate the model by plotting loss curves
7. Make predictions on images from the test set

##### So far, the models we've built haven't performed as well as we'd like.
##### Training a top-performing model from scratch usually requires massive amounts of data and a large, complex model that strikes the right balance between underfitting and overfitting.
##### To sidestep these tedious and sometimes impossible requirements, we can use **transfer learning**.

## What is Transfer Learning?

##### Transfer learning lets us take the weights and biases from a model that has already learned from a similar problem or task and reuse them for our own problem.
##### For example, we can take a computer vision model pretrained on the ImageNet dataset and adapt it to our own image classification task.
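To make "taking the weights" concrete, here is a minimal sketch in PyTorch. `TinyNet` is a hypothetical toy model standing in for a real pretrained network (in practice you would load one from `torchvision.models` with ImageNet weights). The sketch copies the feature-extractor weights into a new model, freezes them, and gives the new model a fresh output layer sized for three classes.

```python
import torch
from torch import nn


class TinyNet(nn.Module):
    """Hypothetical toy model: a feature extractor followed by a classifier head."""

    def __init__(self, num_classes: int):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 8, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d(1),  # (N, 8, H, W) -> (N, 8, 1, 1)
            nn.Flatten(),             # (N, 8, 1, 1) -> (N, 8)
        )
        self.classifier = nn.Linear(8, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.classifier(self.features(x))


# Stand-in for a model already trained on a large dataset (e.g. 1000 ImageNet classes)
source_model = TinyNet(num_classes=1000)

# Our own model for a 3-class problem (e.g. pizza, steak, sushi)
target_model = TinyNet(num_classes=3)

# Transfer: copy only the feature-extractor weights and biases
target_model.features.load_state_dict(source_model.features.state_dict())

# Freeze the transferred layers so only the new classifier head is trained
for param in target_model.features.parameters():
    param.requires_grad = False

trainable = sum(p.numel() for p in target_model.parameters() if p.requires_grad)
frozen = sum(p.numel() for p in target_model.parameters() if not p.requires_grad)
print(f"trainable params: {trainable}, frozen params: {frozen}")

logits = target_model(torch.randn(2, 3, 32, 32))
print(logits.shape)
```

Only the 27 parameters of the new head are trainable here, versus 224 frozen ones, which is why a transferred model can often be fit with much less custom data.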

## Why use Transfer Learning?

##### There are two main benefits to using transfer learning:
1) We can leverage an existing model that has been proven to work on problems similar to our own.

2) We can leverage a working model that has already learned patterns on data similar to our own. This often achieves great results with less custom data.

# 1. Setup

##### Let's import the necessary libraries first.

In [1]:
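try:
    import torch
    import torchvision
    # Compare (major, minor) tuples so that e.g. torch 2.0 counts as newer than 1.12
    torch_version = tuple(int(x) for x in torch.__version__.split(".")[:2])
    torchvision_version = tuple(int(x) for x in torchvision.__version__.split(".")[:2])
    assert torch_version >= (1, 12), "torch version should be 1.12+"
    assert torchvision_version >= (0, 13), "torchvision version should be 0.13+"
    print(f"torch version: {torch.__version__}")
    print(f"torchvision version: {torchvision.__version__}")
except (ImportError, AssertionError):
    print("[INFO] torch/torchvision versions not as required, installing nightly versions.")
    !pip3 install -U torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu113
    # Note: if an older torch was already imported above, restart the runtime to load the new version
    import torch
    import torchvision
    print(f"torch version: {torch.__version__}")
    print(f"torchvision version: {torchvision.__version__}")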

[INFO] torch/torchvision versions not as required, installing nightly versions.
Looking in indexes: https://pypi.org/simple, https://download.pytorch.org/whl/cu113
torch version: 2.0.0.dev20230212
torchvision version: 0.15.0.dev20230212
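The cell above only reinstalls when its version check fails, so that check has to be right. Comparing just the minor number of a version string breaks once the major number changes: torch 2.0 has minor version 0, which is less than 12. A minimal sketch of a safer comparison, using hypothetical helper names `major_minor` and `meets_minimum`, compares `(major, minor)` tuples instead (for full PEP 440 handling, the `packaging` library's `Version` class is the more thorough option).

```python
def major_minor(version: str) -> tuple[int, int]:
    """Parse (major, minor) from a version string like '2.0.0.dev20230212' or '1.12.1+cu113'."""
    major, minor = version.split(".")[:2]
    return int(major), int(minor)


def meets_minimum(version: str, minimum: tuple[int, int]) -> bool:
    """Return True if `version` is at least `minimum`; tuples compare element by element."""
    return major_minor(version) >= minimum


# Checking only the minor number wrongly rejects torch 2.0, because 0 < 12...
naive_ok = int("2.0.0.dev20230212".split(".")[1]) >= 12
# ...while tuple comparison correctly accepts it, because (2, 0) >= (1, 12)
tuple_ok = meets_minimum("2.0.0.dev20230212", (1, 12))
print(naive_ok, tuple_ok)
```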


In [2]:
import matplotlib.pyplot as plt
import torch
import torchvision

from torch import nn
from torchvision import transforms

# Try to get torchinfo, install it if it doesn't work
try:
    from torchinfo import summary
except ImportError:
    print("[INFO] Couldn't find torchinfo... installing it.")
    !pip install -q torchinfo
    from torchinfo import summary

# Try to import the going_modular directory, download it from GitHub if it doesn't work
try:
    from going_modular.going_modular import data_setup, engine
except ImportError:
    # Get the going_modular scripts
    print("[INFO] Couldn't find going_modular scripts... downloading them from GitHub.")
    !git clone https://github.com/mrdbourke/pytorch-deep-learning
    !mv pytorch-deep-learning/going_modular .
    !rm -rf pytorch-deep-learning
    from going_modular.going_modular import data_setup, engine

[INFO] Couldn't find going_modular scripts... downloading them from GitHub.
Cloning into 'pytorch-deep-learning'...
remote: Enumerating objects: 3487, done.
remote: Counting objects: 100% (107/107), done.
remote: Compressing objects: 100% (68/68), done.
remote: Total 3487 (delta 40), reused 93 (delta 34), pack-reused 3380
Receiving objects: 100% (3487/3487), 642.17 MiB | 7.41 MiB/s, done.
Resolving deltas: 100% (1996/1996), done.
Updating files: 100% (223/223), done.


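The cell above installs or downloads a package only when importing it fails. A minimal sketch of an alternative check, assuming a hypothetical helper name `is_installed`, uses the standard library's `importlib.util.find_spec` to test whether a top-level module is importable without actually importing it.

```python
import importlib.util


def is_installed(module_name: str) -> bool:
    """Return True if the top-level module `module_name` can be found on the import path."""
    return importlib.util.find_spec(module_name) is not None


# Modules to check before deciding whether to install or download anything
for name in ["json", "torchinfo", "going_modular"]:
    status = "found" if is_installed(name) else "missing (would need installing/downloading)"
    print(f"{name}: {status}")
```

Whichever style you use, catching `ImportError` specifically (rather than a bare `except:`) keeps unrelated errors from being mistaken for a missing package.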


##### Let's set up our device too.

In [3]:
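# Prefer a CUDA GPU, then Apple Silicon (MPS), then fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu")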
device

device(type='mps')
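Selecting a device only matters if the model and the data actually live on it: PyTorch raises an error when one operation mixes tensors on different devices. A minimal sketch with a hypothetical one-layer model shows the pattern of moving both to the chosen device; it checks CUDA first, which is an assumption about the hardware you might run on.

```python
import torch
from torch import nn

# Device-selection order assumed here: CUDA, then Apple MPS, then CPU
if torch.cuda.is_available():
    selected_device = torch.device("cuda")
elif torch.backends.mps.is_available():
    selected_device = torch.device("mps")
else:
    selected_device = torch.device("cpu")

# Hypothetical tiny model and batch, used only to show the pattern
toy_model = nn.Linear(4, 2).to(selected_device)           # move parameters to the device
toy_batch = torch.randn(8, 4, device=selected_device)     # create the input on the same device

with torch.inference_mode():
    output = toy_model(toy_batch)

print(output.device, output.shape)
```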

# 2. Get Data

##### For this notebook, we'll use the pizza_steak_sushi dataset from the course's GitHub repository.

In [4]:
import os
import zipfile

from pathlib import Path

import requests

# Setup path to data folder
data_path = Path("data/")
image_path = data_path / "pizza_steak_sushi"

# If the image folder doesn't exist, download it and prepare it... 
if image_path.is_dir():
    print(f"{image_path} directory exists.")
else:
    print(f"Did not find {image_path} directory, creating one...")
    image_path.mkdir(parents=True, exist_ok=True)
    
    # Download pizza, steak, sushi data (raise an error if the request fails)
    print("Downloading pizza, steak, sushi data...")
    request = requests.get("https://github.com/mrdbourke/pytorch-deep-learning/raw/main/data/pizza_steak_sushi.zip")
    request.raise_for_status()
    with open(data_path / "pizza_steak_sushi.zip", "wb") as f:
        f.write(request.content)

    # Unzip pizza, steak, sushi data
    with zipfile.ZipFile(data_path / "pizza_steak_sushi.zip", "r") as zip_ref:
        print("Unzipping pizza, steak, sushi data...") 
        zip_ref.extractall(image_path)

    # Remove .zip file
    os.remove(data_path / "pizza_steak_sushi.zip")

data/pizza_steak_sushi directory exists.
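The cell above follows a "download only if missing" pattern: it checks for the image folder, and only if it is absent does it download the archive, extract it, and delete the zip. A minimal offline sketch of the extraction half, with a hypothetical helper `extract_if_missing` and a tiny stand-in archive built on the fly instead of the real download:

```python
import tempfile
import zipfile
from pathlib import Path


def extract_if_missing(zip_path: Path, target_dir: Path) -> bool:
    """Extract `zip_path` into `target_dir` unless that directory already exists.

    Returns True if extraction happened, False if it was skipped.
    """
    if target_dir.is_dir():
        return False
    target_dir.mkdir(parents=True, exist_ok=True)
    with zipfile.ZipFile(zip_path, "r") as zip_ref:
        zip_ref.extractall(target_dir)
    return True


with tempfile.TemporaryDirectory() as tmp:
    tmp_path = Path(tmp)

    # Build a tiny stand-in archive with a train/test/class layout
    zip_path = tmp_path / "toy_data.zip"
    with zipfile.ZipFile(zip_path, "w") as zf:
        zf.writestr("train/pizza/img_0.txt", "fake image bytes")
        zf.writestr("test/sushi/img_1.txt", "fake image bytes")

    target = tmp_path / "toy_data"
    first_run = extract_if_missing(zip_path, target)   # extracts the archive
    second_run = extract_if_missing(zip_path, target)  # skipped: folder already exists
    extracted = sorted(p.relative_to(target).as_posix() for p in target.rglob("*.txt"))

print(first_run, second_run, extracted)
```

One caveat this sketch shares with the cell above: because the folder is created before the data arrives, an interrupted run leaves a folder behind that later runs will treat as complete.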


##### Let's create paths to our training and test directories.

In [5]:
train_dir = image_path / "train"
test_dir = image_path / "test"
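Each of these directories is expected to hold one subfolder per class (for example `train/pizza`, `train/steak`, `train/sushi`), the layout that `torchvision.datasets.ImageFolder` reads. A minimal sketch of how class names and integer labels can be derived from that layout, using a hypothetical `find_classes` helper that mirrors ImageFolder's convention of sorting folder names alphabetically:

```python
import tempfile
from pathlib import Path


def find_classes(directory: Path) -> tuple[list[str], dict[str, int]]:
    """Return sorted class-folder names and a mapping from class name to integer label."""
    classes = sorted(entry.name for entry in directory.iterdir() if entry.is_dir())
    if not classes:
        raise FileNotFoundError(f"No class folders found in {directory}")
    class_to_idx = {name: idx for idx, name in enumerate(classes)}
    return classes, class_to_idx


with tempfile.TemporaryDirectory() as tmp:
    # Recreate the standard image-classification layout: <split>/<class_name>/<images>
    toy_train_dir = Path(tmp) / "train"
    for class_name in ["sushi", "pizza", "steak"]:
        (toy_train_dir / class_name).mkdir(parents=True)
    classes, class_to_idx = find_classes(toy_train_dir)

print(classes, class_to_idx)
```

Sorting makes the label assignment deterministic regardless of the order the filesystem lists the folders in.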