In [1]:
import os

# Create directory to store the scripts
os.makedirs('scripts', exist_ok=True)

In [2]:
%%writefile scripts/get_data.py
import os
import requests
import zipfile
from pathlib import Path

def download_and_extract_data(
  data_path: str,
  data_url: str,
):
  # Create directory to store the data
  data_path = Path(data_path)
  if not data_path.is_dir():
    data_path.mkdir(parents=True, exist_ok=True)

  # Download data from github
  response = requests.get(data_url)
  file_name = data_url.split('/')[-1]
  with open(data_path / file_name, 'wb') as file:
    file.write(response.content)

  # Unzip file
  with zipfile.ZipFile(data_path / file_name, 'r') as zip_file:
    zip_file.extractall(data_path)

  # Remove zip file
  os.remove(data_path / file_name)

Writing scripts/get_data.py


In [3]:
%%writefile scripts/create_vitb16_model.py
import torchvision
import torch.nn as nn

def finetune_vitb16(
  num_classes: int,
):
  vit_weights = torchvision.models.ViT_B_16_Weights.DEFAULT
  vit_model = torchvision.models.vit_b_16(weights=vit_weights)
  vit_transform = vit_weights.transforms()

  for param in vit_model.parameters():
    param.requires_grad = False

  vit_model.heads = nn.Sequential(
    nn.Linear(in_features=768, out_features=num_classes, bias=True),
  )

  return vit_model, vit_transform

Writing scripts/create_vitb16_model.py


In [4]:
%%writefile scripts/data_setup.py
import os
from torchvision.datasets import ImageFolder
from torchvision import transforms
from torch.utils.data import DataLoader

NUM_WORKERS = os.cpu_count()

def create_dataloaders(
  train_dir: str,
  test_dir: str,
  transform: transforms.Compose,
  batch_size: int,
  num_workers: int = NUM_WORKERS,
):
  train_data = ImageFolder(
    train_dir,
    transform=transform,
  )

  test_data = ImageFolder(
    test_dir,
    transform=transform,
  )

  train_dataloader = DataLoader(
    train_data,
    batch_size=batch_size,
    shuffle=True,
    num_workers=num_workers,
    pin_memory=True,
  )

  test_dataloader = DataLoader(
    test_data,
    batch_size=batch_size,
    shuffle=True,
    num_workers=num_workers,
    pin_memory=True,
  )

  class_names = train_data.classes

  return train_dataloader, test_dataloader, class_names

Writing scripts/data_setup.py


In [5]:
%%writefile scripts/engine.py
import torch
from tqdm.auto import tqdm

def train_step(
  model: torch.nn.Module,
  dataloader: torch.utils.data.DataLoader,
  criterion: torch.nn.Module,
  optimizer: torch.optim.Optimizer,
  device: str,
):
  train_loss = 0

  model.train()

  for batch, (X, y) in enumerate(dataloader):
    X = X.to(device, non_blocking=True)
    y = y.to(device, non_blocking=True)

    out = model(X)
    loss = criterion(out, y)
    train_loss += loss.item()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    train_loss = train_loss / len(dataloader)
  
  return train_loss


def test_step(
  model: torch.nn.Module,
  dataloader: torch.utils.data.DataLoader,
  criterion: torch.nn.Module,
  device: str,
):
  test_loss = 0

  model.eval()

  with torch.inference_mode():
    for batch, (X, y) in enumerate(dataloader):
      X = X.to(device, non_blocking=True)
      y = y.to(device, non_blocking=True)

      out = model(X)
      loss = criterion(out, y)
      test_loss += loss.item()

      test_loss = test_loss / len(dataloader)
  
  return test_loss


def train(
  epochs: int,
  model: torch.nn.Module,
  train_dataloader: torch.utils.data.DataLoader,
  test_dataloader: torch.utils.data.DataLoader,
  criterion: torch.nn.Module,
  optimizer: torch.optim.Optimizer,
  device: str,
):
  model = model.to(device)

  for ep in tqdm(range(epochs)):
    train_loss = train_step(
      model=model,
      dataloader=train_dataloader,
      criterion=criterion,
      optimizer=optimizer,
      device=device,
    )

    test_loss = test_step(
      model=model,
      dataloader=test_dataloader,
      criterion=criterion,
      device=device,
    )

    print(f'Epoch: {ep} | Train loss: {train_loss} | Test loss: {test_loss}')

Writing scripts/engine.py


In [6]:
%%writefile scripts/utils.py
import torch
from pathlib import Path

def save_model(
  model: torch.nn.Module,
  target_path: str,
  model_name: str,
):
  assert model_name.endswith('pth') or model_name.endswith('.pt'), "[Invalid model name]: model_name should end with '.pth' or '.pt'."

  target_path = Path(target_path)
  target_path.mkdir(parents=True, exist_ok=True)

  torch.save(
    obj = model.state_dict(),
    f = target_path / model_name,
  )

Writing scripts/utils.py


In [7]:
%%writefile scripts/train.py
import torch
import get_data, create_vitb16_model, data_setup, engine, utils

DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
DATA_PATH = 'data'
DATA_URL = 'https://github.com/oschan77/AnimalsVision-App/raw/main/data/elephant_chicken_sheep_data.zip'
NUM_CLASSES = 3
TRAIN_DIR = 'data/elephant_chicken_sheep_data/train'
TEST_DIR = 'data/elephant_chicken_sheep_data/test'
BATCH_SIZE = 32
LEARNING_RATE = 5e-3
EPOCHS = 10
TARGET_PATH = 'saved_models'
MODEL_NAME = 'vitb16_v1.pth'

get_data.download_and_extract_data(
  data_path=DATA_PATH,
  data_url=DATA_URL,
)

vit_model, vit_transform = create_vitb16_model.finetune_vitb16(
  num_classes=NUM_CLASSES,
)

train_dataloader, test_dataloader, class_names = data_setup.create_dataloaders(
  train_dir=TRAIN_DIR,
  test_dir=TEST_DIR,
  transform=vit_transform,
  batch_size=BATCH_SIZE,
)

criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(vit_model.parameters(), lr=LEARNING_RATE)

engine.train(
  epochs=EPOCHS,
  model=vit_model,
  train_dataloader=train_dataloader,
  test_dataloader=test_dataloader,
  criterion=criterion,
  optimizer=optimizer,
  device=DEVICE,
)

utils.save_model(
  model=vit_model,
  target_path=TARGET_PATH,
  model_name=MODEL_NAME,
)

Writing scripts/train.py


In [8]:
%%writefile scripts/predict.py
import torch
from torch import nn
from torchvision import transforms
from timeit import default_timer as timer
from typing import List

def predict_single_image(
  image,
  model: nn.Module,
  transform: transforms.Compose,
  class_names: List[str],
  device: str,
):
  start_time = timer()
  image = transform(image).unsqueeze(0).to(device)
  model.eval()
  logits = model(image)
  with torch.inference_mode():
    probs = torch.softmax(logits, dim=1)

  classes_and_probs = {class_names[i]: float(probs[0][i]) for i in range(len(class_names))}
  inference_time = round(timer() - start_time, 5)

  print(f'classes_and_probs: {classes_and_probs}')
  print(f'inference_time: {inference_time}')

  return classes_and_probs, inference_time

Writing scripts/predict.py


In [9]:
!python scripts/train.py

Downloading: "https://download.pytorch.org/models/vit_b_16-c867db91.pth" to /root/.cache/torch/hub/checkpoints/vit_b_16-c867db91.pth
100% 330M/330M [00:03<00:00, 88.4MB/s]
  0% 0/10 [00:00<?, ?it/s]Epoch: 0 | Train loss: 0.00024761712973031536 | Test loss: 0.0008062787690303392
 10% 1/10 [00:11<01:40, 11.11s/it]Epoch: 1 | Train loss: 2.9407685559104668e-05 | Test loss: 4.519576651984136e-05
 20% 2/10 [00:15<00:55,  6.97s/it]Epoch: 2 | Train loss: 6.888002583932233e-06 | Test loss: 1.6687422801085092e-05
 30% 3/10 [00:19<00:39,  5.67s/it]Epoch: 3 | Train loss: 1.4119983562620175e-06 | Test loss: 8.81089979682454e-06
 40% 4/10 [00:23<00:31,  5.21s/it]Epoch: 4 | Train loss: 6.934541221814e-06 | Test loss: 1.6546397394752905e-05
 50% 5/10 [00:27<00:24,  4.82s/it]Epoch: 5 | Train loss: 2.039409309366469e-06 | Test loss: 4.486573170989114e-06
 60% 6/10 [00:32<00:18,  4.61s/it]Epoch: 6 | Train loss: 4.972168890094082e-06 | Test loss: 5.635954645792178e-06
 70% 7/10 [00:36<00:13,  4.60s/it]Epo