<a href="https://colab.research.google.com/github/rdkworld/AIPND-2022/blob/main/Generalized/Train_an_Existing_Pytorch_Model.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Prerequisite Setup

### User Input Parameters including Hyperparameters

In [2]:
# --- Data ------------------------------------------------------------------
SOURCE_URL = 'https://www.robots.ox.ac.uk/~vgg/data/flowers/102/102flowers.tgz'
BASE_DIRECTORY = 'flowers'      # root working directory created by this notebook
DATA_DIRECTORY = 'data'         # sub-directory of BASE_DIRECTORY holding the dataset
MODEL_DIRECTORY = 'models'      # sub-directory of BASE_DIRECTORY holding saved models
FILE_NAME = 'flowers.tar.gz'    # dataset archive name (copied in from Google Drive below)

# --- Training hyperparameters ------------------------------------------------
NUM_EPOCHS = 1
BATCH_SIZE = 64
LEARNING_RATE = 0.003
LOSS_FUNCTION = 'CrossEntropyLoss'   # class name looked up in torch.nn
OPTIMIZER = 'Adam'                   # class name looked up in torch.optim

# --- Model -------------------------------------------------------------------
MODEL_NAME = 'vit_b_16'      # torchvision.models builder function name
MODEL_WEIGHT = 'ViT_B_16'    # prefix of the matching torchvision "<prefix>_Weights" enum
NUM_CLASSES = 102            # number of outputs of the replacement classifier head
FEATURE_EXTRACT = True       # passed to model_builder; the summary below shows the backbone non-trainable when True
RGB = 3                      # input channels: 3 for colour pictures, 1 for black & white

# --- Not used by the training code (HIDDEN_UNITS is still stored in the checkpoint)
HIDDEN_UNITS = ''
MANUAL_RESIZE = 64


### Get Libraries

In [3]:
# Install atleast torch 1.12+ and torchvision 0.13+
try:
    import torch
    import torchvision
    assert int(torch.__version__.split(".")[1]) >= 12, "torch version should be 1.12+"
    assert int(torchvision.__version__.split(".")[1]) >= 13, "torchvision version should be 0.13+"
    print(f"torch version: {torch.__version__}")
    print(f"torchvision version: {torchvision.__version__}")
except:
    print(f"[INFO] torch/torchvision versions not as required, installing nightly versions.")
    !pip3 install -U --pre torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/nightly/cu113
    import torch
    import torchvision
    print(f"torch version: {torch.__version__}")
    print(f"torchvision version: {torchvision.__version__}")

torch version: 1.12.1+cu113
torchvision version: 0.13.1+cu113


### Regular Imports

In [4]:
# Continue with regular imports
import matplotlib.pyplot as plt
import torch
import torchvision

from torch import nn
from torchvision import transforms

# Try to get torchinfo, install it if it doesn't work
try:
    from torchinfo import summary
except:
    print("[INFO] Couldn't find torchinfo... installing it.")
    !pip install -q torchinfo
    from torchinfo import summary

#Additions from functions
import os
import sys
import tarfile
from pathlib import Path

[INFO] Couldn't find torchinfo... installing it.


### Helper Functions from GitHub

In [5]:
# Try to import the helper functions, download it from GitHub if it doesn't work
try:
    import data_setup, engine, model_builder, utils 
    from helper_functions import download_data, set_seeds, plot_loss_curves, create_directory
except:
    # Get the scripts
    print("[INFO] Couldn't find the scripts... downloading them from GitHub.")
    !git clone https://github.com/rdkworld/AIPND-2022
    #create_directory(Path().absolute() / BASE_DIRECTORY)
    !mkdir --parents /content/$BASE_DIRECTORY 
    !mv AIPND-2022/Generalized/*.py /content/$BASE_DIRECTORY
    !rm -rf AIPND-2022
    sys.path.append(os.path.join(os.getcwd(), BASE_DIRECTORY))
    import data_setup, engine, model_builder, utils 
    from helper_functions import download_data, set_seeds, plot_loss_curves, create_directory

[INFO] Couldn't find the scripts... downloading them from GitHub.
Cloning into 'AIPND-2022'...
remote: Enumerating objects: 328, done.[K
remote: Counting objects: 100% (218/218), done.[K
remote: Compressing objects: 100% (162/162), done.[K
remote: Total 328 (delta 116), reused 116 (delta 53), pack-reused 110[K
Receiving objects: 100% (328/328), 11.02 MiB | 21.62 MiB/s, done.
Resolving deltas: 100% (138/138), done.


In [8]:
# Create the working directory tree: <base>/, <base>/<data>/ and <base>/<models>/.
for directory in (Path(BASE_DIRECTORY),
                  Path(BASE_DIRECTORY) / DATA_DIRECTORY,
                  Path(BASE_DIRECTORY) / MODEL_DIRECTORY):
    create_directory(directory)

# The split paths assume the archive extracts to a <base>/{train,valid,test} tree
# inside <base>/<data>/.
dataset_root = f"{BASE_DIRECTORY}/{DATA_DIRECTORY}/{BASE_DIRECTORY}"
train_dir, valid_dir, test_dir = (f"{dataset_root}/{split}" for split in ("train", "valid", "test"))

### Connect Colab and Google Drive to save and load models

In [11]:
#Mount Google Drive
# Colab-only: exposes the user's Google Drive under /content/drive. Later cells
# copy the dataset archive in from /content/drive/MyDrive and copy the saved
# model back out to it.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


### Set Up the Target Device

In [12]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
# Kept as a plain string; it is also stored in the checkpoint as 'gpu_or_cpu'.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
device

'cpu'

## Download the Data and Organize It into train/valid/test Folders

In [9]:
#Work in Progress, See next cell as a temporary measure

In [13]:
# Temporary measure while the download cell is a work in progress: copy the dataset
# archive (FILE_NAME) from the mounted Google Drive into <base>/<data>/.
# shutil.copy raises if the archive is missing, whereas a failing `!cp` only prints
# an error and lets the following cells run on a missing file.
import shutil

shutil.copy(Path("/content/drive/MyDrive") / FILE_NAME, Path(BASE_DIRECTORY) / DATA_DIRECTORY / FILE_NAME)

In [11]:
#Temporarily comment
#!tar xf $BASE_DATA_DIRECTORY/$FILE_NAME

In [14]:
# Untar the dataset archive into <base>/<data>/.
archive_path = Path(BASE_DIRECTORY) / DATA_DIRECTORY / FILE_NAME
with tarfile.open(archive_path, "r") as tar_ref:
    print(f"[INFO] Unzipping {FILE_NAME}...")
    if hasattr(tarfile, "data_filter"):
        # Python 3.12+ and recent security releases: the "data" filter rejects
        # absolute paths, ".." components and special files (path-traversal guard).
        tar_ref.extractall(archive_path.parent, filter="data")
    else:
        # Older interpreters have no extraction filters; only use trusted archives.
        tar_ref.extractall(archive_path.parent)

[INFO] Unzipping flowers.tar.gz...


EOFError: ignored

In [15]:
# Delete the extracted archive to free disk space (only if it is a regular file).
downloaded_archive = Path(BASE_DIRECTORY) / DATA_DIRECTORY / FILE_NAME
if downloaded_archive.is_file():
    downloaded_archive.unlink()

## Get info on Pre-Trained Models

### Pre-trained Model & Transform Details

In [16]:
# Get the pre-trained weights enum and model builder by name. getattr is used
# instead of eval() so a malformed config value raises AttributeError rather than
# being executed as code.
pretrained_weights = getattr(torchvision.models, f"{MODEL_WEIGHT}_Weights").DEFAULT
pretrained_model = getattr(torchvision.models, MODEL_NAME)(weights=pretrained_weights).to(device)
# Preprocessing (resize/crop/normalize) that matches these weights.
auto_transforms = pretrained_weights.transforms()

Downloading: "https://download.pytorch.org/models/vit_b_16-c867db91.pth" to /root/.cache/torch/hub/checkpoints/vit_b_16-c867db91.pth


  0%|          | 0.00/330M [00:00<?, ?B/s]

In [17]:
# Summarize the unmodified pre-trained model (one row per named layer) and display
# it together with the transforms bundled with its weights. The dummy input is a
# batch of BATCH_SIZE square images with RGB channels at the transforms' crop size.
summary(
    model=pretrained_model,
    input_size=(BATCH_SIZE, RGB, auto_transforms.crop_size[0], auto_transforms.crop_size[0]),
    row_settings=["var_names"],
    col_width=20,
    col_names=["input_size", "output_size", "num_params", "trainable"],
), auto_transforms

 Layer (type (var_name))                                                Input Shape          Output Shape         Param #              Trainable
 VisionTransformer (VisionTransformer)                                  [64, 3, 224, 224]    [64, 1000]           768                  True
 ├─Conv2d (conv_proj)                                                   [64, 3, 224, 224]    [64, 768, 14, 14]    590,592              True
 ├─Encoder (encoder)                                                    [64, 197, 768]       [64, 197, 768]       151,296              True
 │    └─Dropout (dropout)                                               [64, 197, 768]       [64, 197, 768]       --                   --
 │    └─Sequential (layers)                                             [64, 197, 768]       [64, 197, 768]       --                   True
 │    │    └─EncoderBlock (encoder_layer_0)                             [64, 197, 768]       [64, 197, 768]       7,087,872            True
 │    │    └─Enco

In [18]:
# Replace the classifier head with one producing NUM_CLASSES outputs (helper from
# model_builder.py). The summary below shows the backbone layers as non-trainable
# when FEATURE_EXTRACT is True.
updated_pretrained_model = model_builder.update_last_layer_pretrained_model(
    pretrained_model,
    NUM_CLASSES,
    FEATURE_EXTRACT,
).to(device)

In [19]:
# Same layer summary for the model with the new head: the output shape should now
# end in NUM_CLASSES and the "Trainable" column reflects FEATURE_EXTRACT.
summary(
    model=updated_pretrained_model,
    input_size=(BATCH_SIZE, RGB, auto_transforms.crop_size[0], auto_transforms.crop_size[0]),
    row_settings=["var_names"],
    col_width=20,
    col_names=["input_size", "output_size", "num_params", "trainable"],
)

Layer (type (var_name))                                                Input Shape          Output Shape         Param #              Trainable
VisionTransformer (VisionTransformer)                                  [64, 3, 224, 224]    [64, 102]            768                  Partial
├─Conv2d (conv_proj)                                                   [64, 3, 224, 224]    [64, 768, 14, 14]    (590,592)            False
├─Encoder (encoder)                                                    [64, 197, 768]       [64, 197, 768]       151,296              False
│    └─Dropout (dropout)                                               [64, 197, 768]       [64, 197, 768]       --                   --
│    └─Sequential (layers)                                             [64, 197, 768]       [64, 197, 768]       --                   False
│    │    └─EncoderBlock (encoder_layer_0)                             [64, 197, 768]       [64, 197, 768]       (7,087,872)          False
│    │    └─Encod

In [18]:
# from IPython.lib import pretty
# if getattr(pretrained_model, 'heads'):
#   print(True)
# pretrained_model.heads.head

### DataLoaders

In [20]:
# Create train/test data loaders with help from data_setup.py, preprocessing images
# with the transforms that ship with the pre-trained weights.
# NOTE(review): valid_dir is not passed here; the test split is what training evaluates on.
train_dataloader, test_dataloader, class_names, class_to_idx = data_setup.create_dataloaders(
    train_dir=train_dir,
    test_dir=test_dir,
    transform=auto_transforms,
    batch_size=BATCH_SIZE,
)

# The new head has NUM_CLASSES outputs, so the dataset must have exactly that many
# classes. Raise rather than call exit(): in an IPython kernel exit() asks the
# kernel to shut down instead of just stopping this cell.
if len(class_names) != NUM_CLASSES:
    raise ValueError(
        "Mismatch in the number of unique classes/labels and user input NUM_CLASSES: "
        f"found {len(class_names)} classes in {train_dir!r}, NUM_CLASSES={NUM_CLASSES}"
    )

### Set the Loss Function and Optimizer

In [22]:
# Look up the loss and optimizer classes by name (getattr instead of eval, so a
# typo in the config raises AttributeError rather than executing arbitrary code).
loss_fn = getattr(torch.nn, LOSS_FUNCTION)()
# Only hand the optimizer parameters that require gradients: frozen parameters
# (FEATURE_EXTRACT=True) never receive gradients, so tracking them is wasted work.
optimizer = getattr(torch.optim, OPTIMIZER)(
    [param for param in updated_pretrained_model.parameters() if param.requires_grad],
    lr=LEARNING_RATE,
)

### Train the Model

In [23]:
# Train with help from engine.py on train_dataloader, presumably evaluating on
# test_dataloader; the returned results feed the loss-curve plot below.
pretrained_model_results = engine.train(
    model=updated_pretrained_model,
    train_dataloader=train_dataloader,
    test_dataloader=test_dataloader,
    optimizer=optimizer,
    loss_fn=loss_fn,
    epochs=NUM_EPOCHS,
    device=device,
)

  0%|          | 0/1 [00:00<?, ?it/s]

KeyboardInterrupt: ignored

In [None]:
#Plot the results/loss curves
# Plots the history returned by engine.train using helper_functions.plot_loss_curves.
# Requires the training cell to have completed: if it was interrupted before the
# assignment finished, pretrained_model_results is undefined in a fresh kernel.
plot_loss_curves(pretrained_model_results)

### Save the Model



In [84]:
# Everything needed to rebuild the fine-tuned model later: its weights, how it was
# built and trained, and the label mapping returned by create_dataloaders.
checkpoint = dict(
    state_dict=updated_pretrained_model.state_dict(),
    arch=MODEL_NAME,
    arch_weight=MODEL_WEIGHT,
    arch_type='EXISTING',
    loss_function=LOSS_FUNCTION,
    optimizer=OPTIMIZER,
    class_names=class_names,
    class_to_idx=class_to_idx,
    hidden_units=HIDDEN_UNITS,
    num_classes=NUM_CLASSES,
    feature_extract=FEATURE_EXTRACT,
    gpu_or_cpu=device,
)

In [87]:
# Save the full checkpoint dict (weights + metadata) on the local disk.
# The load cell further down indexes checkpoint['state_dict'], ['class_names'] and
# ['class_to_idx'], so the dict itself must be saved; utils.save_model only
# receives the model object, so its file cannot contain those keys.
model_save_path = Path(BASE_DIRECTORY) / MODEL_DIRECTORY / f"{MODEL_NAME}_model.pth"
model_save_path.parent.mkdir(parents=True, exist_ok=True)
print(f"[INFO] Saving model checkpoint to: {model_save_path}")
torch.save(checkpoint, model_save_path)

[INFO] Saving model to: flowers/models/vit_b_16_model.pth


In [None]:
# Size of the saved model file in whole MiB (integer division by 1024**2, i.e. >> 20).
saved_model_path = Path(f"{BASE_DIRECTORY}/{MODEL_DIRECTORY}/{MODEL_NAME}_model.pth")
pretrained_model_size = saved_model_path.stat().st_size >> 20
print(f"Pretrained model size: {pretrained_model_size} MB")

In [88]:
# Export/copy the saved model from the Colab disk to the mounted Google Drive.
# shutil.copy raises if the model file or the Drive folder is missing, whereas a
# failing `!cp` only prints an error and the notebook carries on.
import shutil

shutil.copy(Path(BASE_DIRECTORY) / MODEL_DIRECTORY / f"{MODEL_NAME}_model.pth", "/content/drive/MyDrive/")

## Load the Model


In [None]:
# Pick the device and load the checkpoint onto it. map_location lets a checkpoint
# saved from a GPU session load on a CPU-only runtime (and vice versa).
device = "cuda" if torch.cuda.is_available() else "cpu"
checkpoint = torch.load(f"{BASE_DIRECTORY}/{MODEL_DIRECTORY}/{MODEL_NAME}_model.pth", map_location=device)
print(f"Using {device} device for predicting/inference")

# Rebuild the architecture without downloading weights; the transforms still come
# from the weights enum so inference preprocessing matches training.
pretrained_weights = getattr(torchvision.models, f"{MODEL_WEIGHT}_Weights").DEFAULT
auto_transforms = pretrained_weights.transforms()
pretrained_model = getattr(torchvision.models, MODEL_NAME)(weights=None)
# The saved state_dict contains the NUM_CLASSES-way head installed by model_builder,
# not the stock 1000-class head, so install the same head first; otherwise
# load_state_dict fails with a size mismatch.
pretrained_model = model_builder.update_last_layer_pretrained_model(pretrained_model, NUM_CLASSES, FEATURE_EXTRACT)
pretrained_model.class_to_idx = checkpoint['class_to_idx']
pretrained_model.class_names = checkpoint['class_names']
pretrained_model.load_state_dict(checkpoint['state_dict'])
# eval() disables dropout for inference; the trailing ';' suppresses the long model repr.
pretrained_model.to(device).eval();

## Inference

In [None]:
!ls flowers

In [None]:
# Predict on a single image and plot it with the predicted label.
# pred_and_plot_image is not imported anywhere else in this notebook; it is assumed
# to live in the downloaded helper_functions.py — TODO confirm.
from helper_functions import pred_and_plot_image

# Default to the first .jpg in the test split; replace with your own image path.
custom_image_path = next(Path(test_dir).glob("*/*.jpg"), None)
if custom_image_path is None:
    raise FileNotFoundError(f"No .jpg images found under {test_dir!r}; set custom_image_path manually.")
pred_and_plot_image(model=pretrained_model,
                    image_path=str(custom_image_path),
                    # Labels stored in the checkpoint, so this works right after the load cell.
                    class_names=pretrained_model.class_names,
                    transform=auto_transforms)