<a href="https://colab.research.google.com/github/parthava-adabala/learning/blob/main/04_Pytorch_custom_datasets.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
from torch import nn

torch.__version__

In [None]:
device = "cuda" if torch.cuda.is_available() else "cpu"
device

In [None]:
!nvidia-smi

In [None]:
import requests
import zipfile
from pathlib import Path

data_path = Path("data/")
image_path = data_path / "pizza_steak_sushi"

if image_path.is_dir():
  print(f"{image_path} directory exists.")
else:
  print(f"Did not find {image_path} directory, creating one...")
  image_path.mkdir(parents=True, exist_ok=True)

In [None]:
zip_path = data_path / "pizza_steak_sushi.zip"

# Fetch the archive *before* opening the destination file and fail loudly on an
# HTTP error, so an error page or an empty file is never saved as the ".zip".
print("Downlaoding pizza, steak, sushi data...")
request = requests.get("https://github.com/mrdbourke/pytorch-deep-learning/raw/main/data/pizza_steak_sushi.zip")
request.raise_for_status()
with open(zip_path, "wb") as f:
  f.write(request.content)

# Unpack the archive into the image directory.
with zipfile.ZipFile(zip_path, "r") as zip_ref:
  print("Unzipping pizza, steak, sushi data...")
  zip_ref.extractall(image_path)

# Data preparation and data exploration

In [None]:
import os
def walk_through_dir(dir_path):
  """Print a one-line summary for every directory found under ``dir_path``.

  Args:
    dir_path: Root directory that is handed to ``os.walk``.

  Each line reports the number of sub-directories and the number of files
  (worded as "images") directly inside the visited directory.
  """
  for current_dir, subdir_names, file_names in os.walk(dir_path):
    dir_count = len(subdir_names)
    file_count = len(file_names)
    print(f"There are {dir_count} directories and {file_count} images in '{current_dir}'.")
walk_through_dir(image_path)

In [None]:
train_dir = image_path/"train"
test_dir = image_path/"test"
train_dir, test_dir

# Visualizing an image

In [None]:
import random
from PIL import Image

random.seed(42)

image_path_list = list(image_path.glob("*/*/*.jpg"))

random_image_path = random.choice(image_path_list)
print(random_image_path)

image_class = random_image_path.parent.stem
print(image_class)

image = Image.open(random_image_path)
print(image.height, image.width)
image

In [None]:
import numpy as np
import matplotlib.pyplot as plt

image_as_array = np.asarray(image)

plt.figure(figsize=(10,7))
plt.imshow(image_as_array)
plt.title(f"Image class: {image_class}")
plt.axis(False)

# Transforming data

In [None]:
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import transforms, datasets

In [None]:
data_transform = transforms.Compose([transforms.Resize(size=(64,64)),
                                     transforms.RandomHorizontalFlip(p = 0.5),
                                     transforms.ToTensor()])

In [None]:
data_transform(image)

In [None]:
def plot_transformed_image(image_path, transform, n=3, seed=42):
  """Plot randomly sampled images side by side with their transformed version.

  Args:
    image_path: Sequence of image file paths to sample from (despite the
      singular name this is a collection, not a single path).
    transform: Callable applied to each opened PIL image. Its result must
      support ``.permute(1, 2, 0)`` (presumably a channels-first tensor that
      is reordered to channels-last for matplotlib).
    n: Number of images to sample and plot.
    seed: Seed passed to ``random.seed``; pass ``None`` to leave the global
      random state untouched.
  """
  # Compare against None so that seed=0 is honoured instead of being skipped.
  if seed is not None:
    random.seed(seed)
  sampled_paths = random.sample(image_path, k=n)
  # A distinct loop name avoids shadowing the `image_path` parameter.
  for sample_path in sampled_paths:
    with Image.open(sample_path) as f:
      fig, ax = plt.subplots(nrows=1, ncols=2)
      ax[0].imshow(f)
      ax[0].set_title(f"Original \nSize: {f.size}")
      ax[0].axis("off")

      transformed_image = transform(f).permute(1,2,0)
      ax[1].imshow(transformed_image)
      ax[1].set_title(f"Transformed \nSize: {transformed_image.shape}")
      ax[1].axis("off")
      # The class name is taken from the image's parent directory.
      fig.suptitle(f"Class: {sample_path.parent.stem}", fontsize=30)
plot_transformed_image(image_path=image_path_list, transform = data_transform, n=3, seed=42)

# Loading image data using image folder

In [None]:
from torchvision import datasets
train_data = datasets.ImageFolder(root=train_dir, transform=data_transform, target_transform=None)
test_data = datasets.ImageFolder(root=test_dir, transform=data_transform)
train_data, test_data

In [None]:
class_names = train_data.classes
class_names

In [None]:
class_dict = train_data.class_to_idx
class_dict

In [None]:
len(train_data), len(test_data)

In [None]:
train_data.samples[0]

In [None]:
img, label = train_data[0][0], train_data[0][1]
img, label, class_names[label]

In [None]:
img_permute = img.permute(1,2,0)
print(img_permute.shape, img.shape)

plt.figure(figsize=(10,7))
plt.imshow(img_permute)
plt.title(f"Class: {class_names[label]}")
plt.axis(False)

# data loaders

In [None]:
from torch.utils.data import DataLoader
BATCH_SIZE=1
train_dataloader = DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True, num_workers=1)
test_dataloader = DataLoader(dataset=test_data, batch_size=BATCH_SIZE, shuffle=False, num_workers=1)
train_dataloader, test_dataloader

In [None]:
len(train_dataloader), len(test_dataloader)

In [None]:
img, label = next(iter(train_dataloader))
print(img.shape, label.shape)

# Loading Image data with a custom dataset

In [None]:
import os
import pathlib
import torch

from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms
from typing import Tuple, Dict, List

In [None]:
train_data.classes, train_data.class_to_idx

In [None]:
# Creating a helper function to get class names

target_directory = train_dir
class_names_found = sorted([entry.name for entry in os.scandir(target_directory) if entry.is_dir()])
class_names_found

In [None]:
list(os.scandir(target_directory))

In [None]:
def find_classes(target_directory: str) -> Tuple[List[str],Dict[str, int]]:
  """Derive class names from the sub-directories of ``target_directory``.

  Args:
    target_directory: Directory whose immediate sub-directories are the classes.

  Returns:
    A tuple ``(classes, class_to_idx)``: the sorted sub-directory names and a
    dict mapping each name to its position in that sorted list.

  Raises:
    FileNotFoundError: If ``target_directory`` contains no sub-directories.
  """
  subdir_names = []
  for entry in os.scandir(target_directory):
    if entry.is_dir():
      subdir_names.append(entry.name)
  subdir_names.sort()

  if len(subdir_names) == 0:
    raise FileNotFoundError(f"Couldn't find any classes in {target_directory}")

  index_by_name = dict(zip(subdir_names, range(len(subdir_names))))
  return subdir_names, index_by_name

In [None]:
find_classes(target_directory)

In [None]:
# create custom dataset to replicate image folder
from torch.utils.data import Dataset

class ImageFolderCustom(Dataset):
  """Dataset over a directory laid out as ``<target_dir>/<class_name>/<image>.jpg``.

  Attributes are set in ``__init__``:
    paths: Sorted list of every ``*/*.jpg`` path under ``target_dir``.
    transform: Optional callable applied to each loaded image.
    classes / class_to_idx: Class names and name-to-index mapping as returned
      by ``find_classes(target_dir)``.
  """
  def __init__(self, target_dir: str, transform=None) -> None:
    """Collect image paths and class metadata for ``target_dir``.

    Args:
      target_dir: Root directory containing one sub-directory per class.
      transform: Optional callable applied to every image in ``__getitem__``.
    """
    # Sorted so that an index refers to the same file on every run and machine;
    # glob order otherwise depends on how the filesystem enumerates entries.
    self.paths = sorted(pathlib.Path(target_dir).glob("*/*.jpg"))
    self.transform = transform
    self.classes, self.class_to_idx = find_classes(target_dir)

  def load_image(self, index: int) -> Image.Image:
    """Open and return the image stored at ``self.paths[index]``."""
    image_path = self.paths[index]
    return Image.open(image_path)

  def __len__(self) -> int:
    """Return the number of image paths found."""
    return len(self.paths)

  def __getitem__(self, index: int) -> Tuple[torch.Tensor, int]:
    """Return ``(image, class_idx)`` for the sample at ``index``.

    The class is the name of the image's parent directory. The image is passed
    through ``self.transform`` when one was given, otherwise the opened PIL
    image is returned as is.
    """
    image = self.load_image(index)
    class_name = self.paths[index].parent.name
    class_idx = self.class_to_idx[class_name]

    # Explicit None check: a transform object that happens to be falsy
    # (e.g. an empty container module) must still be applied.
    if self.transform is not None:
      return self.transform(image), class_idx
    return image, class_idx

In [None]:
# create a transform
from torchvision import transforms
train_transforms = transforms.Compose([transforms.Resize(size=(64,64)),
                                        transforms.RandomHorizontalFlip(p=0.5),
                                        transforms.ToTensor()])
test_transforms = transforms.Compose([transforms.Resize(size=(64,64)),
                                       transforms.ToTensor()])

In [None]:
train_data_custom = ImageFolderCustom(target_dir=train_dir, transform=train_transforms)
test_data_custom = ImageFolderCustom(target_dir=test_dir, transform=test_transforms)
train_data_custom, test_data_custom

In [None]:
len(train_data_custom), len(test_data_custom)

In [None]:
len(test_data), len(test_data_custom)

In [None]:
train_data_custom.classes

In [None]:
train_data_custom.class_to_idx

In [None]:
# check for equality between the ImageFolder dataset and the custom dataset
print(train_data.classes == train_data_custom.classes)
print(train_data.class_to_idx == train_data_custom.class_to_idx)

In [None]:
# create a function to display random images
import random

def display_random_images(dataset: torch.utils.data.Dataset, classes:List[str]=None, n:int=10, display_shape: bool = True, seed:int = None):
  """Show ``n`` randomly chosen samples of ``dataset`` in a single row.

  Args:
    dataset: Indexable dataset whose items start with ``(image, label)``; the
      image must support ``.permute(1, 2, 0)``.
    classes: Optional list of class names indexed by label. Titles are only
      drawn when this is given.
    n: Number of samples to show. Values above 10 are reduced to 10, and the
      shape is then left out of the titles.
    display_shape: Whether to append the permuted image shape to the title.
    seed: Seed passed to ``random.seed``; ``None`` leaves the random state as is.
  """
  if n>10:
    n=10
    display_shape = False
    print(f"For display purposes, n shouldn't be greater than 10, setting to 10.")

  # Compare against None so that seed=0 is honoured.
  if seed is not None:
    random.seed(seed)
  random_samples_idx = random.sample(range(len(dataset)), k=n)

  plt.figure(figsize=(16,8))

  for i, targ_sample in enumerate(random_samples_idx):
    # Fetch the sample once: indexing twice would load and transform the image
    # two times, and with random transforms the two results need not match.
    sample = dataset[targ_sample]
    targ_img, targ_label = sample[0], sample[1]
    targ_img_adjusted = targ_img.permute(1,2,0)
    plt.subplot(1,n,i+1)
    plt.imshow(targ_img_adjusted)
    if classes:
      title = f"class: {classes[targ_label]}"
      if display_shape:
        title = f"class: {classes[targ_label]} \n shape: {targ_img_adjusted.shape}"
      plt.title(title)

In [None]:
display_random_images(dataset=train_data, classes=train_data.classes, n=5, seed=42)

In [None]:
display_random_images(dataset=train_data_custom, classes=train_data_custom.classes, n=5, seed=42)

In [None]:
# Turn loaded images into dataloader
from torch.utils.data import DataLoader
BATCH_SIZE = 32
train_dataloader_custom = DataLoader(dataset=train_data_custom, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)
test_dataloader_custom = DataLoader(dataset=test_data_custom, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)
train_dataloader_custom, test_dataloader_custom

In [None]:
image_custom, label_custom = next(iter(train_dataloader_custom))
print(image_custom.shape, label_custom.shape)

In [None]:
# Other forms of transforms (data augmentation)
from torchvision import transforms
train_transforms = transforms.Compose([transforms.Resize(size=(64,64)),
                                            transforms.TrivialAugmentWide(num_magnitude_bins=31),
                                        transforms.ToTensor()])
test_transforms = transforms.Compose([transforms.Resize(size=(64,64)),
                                       transforms.ToTensor()])

In [None]:
image_path_list = list(image_path.glob("*/*/*.jpg"))
image_path_list[:10]

In [None]:
plot_transformed_image(image_path=image_path_list, transform = train_transforms, n=3, seed=42)

# Model 0 Tiny VGG without data augmentation

In [None]:
simple_transform = transforms.Compose([transforms.Resize(size=(64,64)),
                                        transforms.ToTensor()])

In [None]:
from torchvision import datasets
train_data_simple = datasets.ImageFolder(root=train_dir, transform=simple_transform, target_transform=None)
test_data_simple = datasets.ImageFolder(root=test_dir, transform=simple_transform)

import os
from torch.utils.data import DataLoader
BATCH_SIZE=1
NUM_WORKERS=os.cpu_count()
train_dataloader_simple = DataLoader(dataset=train_data_simple, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)
test_dataloader_simple = DataLoader(dataset=test_data_simple, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)
train_dataloader_simple, test_dataloader_simple

In [None]:
NUM_WORKERS

In [None]:
class TinyVGG(nn.Module):
  """Small VGG-style CNN: two convolutional blocks followed by a linear classifier.

  Args:
    input_shape: Number of channels of the input images.
    hidden_units: Number of channels produced by every convolution.
    output_shape: Number of output logits.
  """
  def __init__(self, input_shape: int, hidden_units: int, output_shape: int) -> None:
    super().__init__()

    def make_conv_block(first_in_channels: int) -> nn.Sequential:
      # conv -> ReLU -> conv -> ReLU -> 2x2 max-pool; the 3x3 convolutions with
      # padding=1 and stride=1 keep the spatial size, the pooling halves it.
      conv_kwargs = dict(out_channels=hidden_units, kernel_size=3, stride=1, padding=1)
      return nn.Sequential(
          nn.Conv2d(in_channels=first_in_channels, **conv_kwargs),
          nn.ReLU(),
          nn.Conv2d(in_channels=hidden_units, **conv_kwargs),
          nn.ReLU(),
          nn.MaxPool2d(kernel_size=2, stride=2),
      )

    self.conv_block_1 = make_conv_block(input_shape)
    self.conv_block_2 = make_conv_block(hidden_units)

    # 16*16 is the feature-map size after the two halvings above for a 64x64
    # input, the size this classifier is hard-wired to.
    flattened_features = hidden_units*16*16
    self.classifier = nn.Sequential(
        nn.Flatten(),
        nn.Linear(in_features=flattened_features, out_features=output_shape),
    )

  def forward(self, x: torch.Tensor):
    """Run ``x`` through both conv blocks and the classifier; return the logits."""
    features = self.conv_block_2(self.conv_block_1(x))
    return self.classifier(features)

In [None]:
torch.manual_seed(42)
model_0 = TinyVGG(input_shape=3, hidden_units=10, output_shape=len(train_data_simple.classes)).to(device)
model_0

In [None]:
# Try a single image
image_batch, label_batch = next(iter(train_dataloader_simple))
image_batch.shape, label_batch.shape

In [None]:
# model_0 was moved to `device` when it was created, so the batch has to be
# moved there as well or the forward pass fails with a device mismatch on GPU.
model_0(image_batch.to(device))

In [None]:
# Use torch info
try:
  from torchinfo import summary
except:
  print("[INFO] Couldn't find torchinfo... installing it.")
  !pip install -q torchinfo
  from torchinfo import summary

summary(model=model_0, input_size=(1, 3, 64, 64))

# train and test loops

In [None]:
def train_step(model: torch.nn.Module, dataloader: torch.utils.data.DataLoader, loss_fn: torch.nn.Module, optimizer: torch.optim.Optimizer, device: torch.device):
  """Train ``model`` for one pass over ``dataloader``.

  Args:
    model: Model to optimise; it is switched to training mode.
    dataloader: Sized iterable yielding ``(X, y)`` batches.
    loss_fn: Callable computing the loss from ``(model(X), y)``.
    optimizer: Optimizer stepping the model's parameters.
    device: Device every batch is moved to before the forward pass.

  Returns:
    ``(train_loss, train_acc)``: the per-batch loss and the per-batch accuracy,
    each averaged over the number of batches.
  """
  model.train()
  train_loss, train_acc = 0, 0
  for X, y in dataloader:
    X, y = X.to(device), y.to(device)
    y_pred = model(X)
    loss = loss_fn(y_pred, y)
    train_loss += loss.item()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    # Softmax is monotonic, so the argmax of the logits is the predicted class;
    # this also matches how test_step derives its predictions.
    y_pred_class = y_pred.argmax(dim=1)
    train_acc += (y_pred_class == y).sum().item()/len(y_pred)
  # Average the per-batch values over the number of batches.
  train_loss /= len(dataloader)
  train_acc /= len(dataloader)
  return train_loss, train_acc

In [None]:
def test_step(model: torch.nn.Module, dataloader: torch.utils.data.DataLoader, loss_fn: torch.nn.Module, device: torch.device):
  """Evaluate ``model`` over one pass of ``dataloader`` without updating it.

  Args:
    model: Model to evaluate; it is switched to eval mode.
    dataloader: Sized iterable yielding ``(X, y)`` batches.
    loss_fn: Callable computing the loss from ``(logits, y)``.
    device: Device every batch is moved to before the forward pass.

  Returns:
    ``(test_loss, test_acc)``: per-batch loss and per-batch accuracy, each
    averaged over the number of batches.
  """
  model.eval()
  loss_total = 0
  acc_total = 0
  with torch.inference_mode():
    for inputs, targets in dataloader:
      inputs = inputs.to(device)
      targets = targets.to(device)
      logits = model(inputs)
      loss_total += loss_fn(logits, targets).item()
      # Predicted class is the index of the largest logit.
      predicted = logits.argmax(dim=1)
      n_correct = (predicted == targets).sum().item()
      acc_total += n_correct/len(predicted)
  n_batches = len(dataloader)
  return loss_total/n_batches, acc_total/n_batches

In [None]:
from tqdm.auto import tqdm
def train(model: torch.nn.Module, train_dataloader: DataLoader, test_dataloader: DataLoader, optimizer: torch.optim.Optimizer, loss_fn: torch.nn.Module = nn.CrossEntropyLoss(), epochs: int=5, device: torch.device=device):
  """Run ``epochs`` rounds of ``train_step`` followed by ``test_step``.

  Args:
    model: Model to train and evaluate.
    train_dataloader: Batches passed to ``train_step``.
    test_dataloader: Batches passed to ``test_step``.
    optimizer: Optimizer passed to ``train_step``.
    loss_fn: Loss used for both steps.
    epochs: Number of epochs to run.
    device: Device forwarded to both steps; defaults to the module-level
      ``device`` as it was when this function was defined.

  Returns:
    Dict with the keys ``train_loss``, ``train_acc``, ``test_loss`` and
    ``test_acc``, each holding one value per epoch.
  """
  metric_names = ("train_loss", "train_acc", "test_loss", "test_acc")
  results = {name: [] for name in metric_names}
  for epoch in tqdm(range(epochs)):
    train_loss, train_acc = train_step(model=model, dataloader=train_dataloader, loss_fn=loss_fn, optimizer=optimizer, device=device)
    test_loss, test_acc = test_step(model=model, dataloader=test_dataloader, loss_fn=loss_fn, device=device)
    print(f"Epoch: {epoch} | train_loss: {train_loss:.4f} | train_acc: {train_acc:.4f} | test_loss: {test_loss:.4f} | test_acc: {test_acc:.4f}")
    # Record this epoch's metrics in the same order as `metric_names`.
    epoch_values = (train_loss, train_acc, test_loss, test_acc)
    for name, value in zip(metric_names, epoch_values):
      results[name].append(value)
  return results

In [None]:
#train and evaluate model 0
torch.manual_seed(42)
torch.cuda.manual_seed(42)

NUM_EPOCHS = 5

# The rebuilt model must live on `device`: train() moves every batch there, so
# a CPU-only model would fail with a device mismatch when a GPU is available.
# The class count comes from the dataset this model is actually trained on.
model_0 = TinyVGG(input_shape=3, hidden_units=10, output_shape=len(train_data_simple.classes)).to(device)

loss_fn = torch.nn.CrossEntropyLoss()
# Created after the move so the optimizer holds the on-device parameters.
optimizer = torch.optim.Adam(params=model_0.parameters(), lr=0.001)

from timeit import default_timer as timer
start_time = timer()

model_0_results = train(model=model_0, train_dataloader=train_dataloader_simple, test_dataloader=test_dataloader_simple, optimizer=optimizer, loss_fn=loss_fn, epochs=NUM_EPOCHS)

end_time = timer()
print(f"Total training time: {end_time-start_time:.3f} seconds")

In [None]:
model_0_results.keys()

In [None]:
def plot_loss_curves(results: Dict[str, List[float]]):
  """Plot train/test loss and train/test accuracy curves side by side.

  Args:
    results: Dict with the keys ``train_loss``, ``test_loss``, ``train_acc``
      and ``test_acc``, each a list with one value per epoch.
  """
  # Read every series up front so a missing key fails before any plotting.
  train_loss, test_loss = results["train_loss"], results["test_loss"]
  train_acc, test_acc = results["train_acc"], results["test_acc"]
  epochs = range(len(results["train_loss"]))

  plt.figure(figsize=(15,7))
  # One entry per subplot: (title, [(series, legend label), ...]).
  panels = [
      ("Loss", [(train_loss, "train_loss"), (test_loss, "test_loss")]),
      ("Accuracy", [(train_acc, "train_accuracy"), (test_acc, "test_accuracy")]),
  ]
  for position, (heading, curves) in enumerate(panels, start=1):
    plt.subplot(1,2,position)
    for values, curve_label in curves:
      plt.plot(epochs, values, label=curve_label)
    plt.title(heading)
    plt.xlabel("Epochs")
    plt.legend()

In [None]:
plot_loss_curves(model_0_results)

In [None]:
# create model 1 with data augmentation
from torchvision import transforms
train_transforms_trivial = transforms.Compose([transforms.Resize(size=(64,64)),
                                        transforms.TrivialAugmentWide(num_magnitude_bins=31),
                                        transforms.ToTensor()])
test_transforms_trivial = transforms.Compose([transforms.Resize(size=(64,64)),
                                       transforms.ToTensor()])

In [None]:
train_data_augmented = ImageFolderCustom(target_dir=train_dir, transform=train_transforms_trivial)
# Use the test pipeline defined together with `train_transforms_trivial` for
# this experiment rather than the older, separately defined `test_transforms`.
test_data_simple = ImageFolderCustom(target_dir=test_dir, transform=test_transforms_trivial)

In [None]:
import os
from torch.utils.data import DataLoader
BATCH_SIZE=32
NUM_WORKERS=os.cpu_count()

torch.manual_seed(42)
train_dataloader_augmented = DataLoader(dataset=train_data_augmented, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)
test_dataloader_simple = DataLoader(dataset=test_data_simple, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)
train_dataloader_augmented, test_dataloader_simple

In [None]:
# construct and train model 1
torch.manual_seed(42)

model_1 = TinyVGG(input_shape=3, hidden_units=10, output_shape=len(train_data_augmented.classes)).to(device)
model_1

In [None]:
torch.manual_seed(42)
torch.cuda.manual_seed(42)

NUM_EPOCHS=5

loss_fn = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params=model_1.parameters(), lr=0.001)

from timeit import default_timer as timer
start_time = timer()

model_1_results = train(model=model_1, train_dataloader=train_dataloader_augmented, test_dataloader=test_dataloader_simple, optimizer=optimizer, loss_fn=loss_fn, epochs=NUM_EPOCHS)

end_time = timer()
print(f"Total training time: {end_time-start_time:.3f} seconds")

In [None]:
# Plot loss curves of model 1 results
plot_loss_curves(model_1_results)

In [None]:
# compare model results
import pandas as pd
model_0_results_df = pd.DataFrame(model_0_results)
model_1_results_df = pd.DataFrame(model_1_results)
model_0_results_df, model_1_results_df

In [None]:
plt.figure(figsize=(15,10))

epochs = range(len(model_0_results_df))

plt.subplot(2,2,1)
plt.plot(epochs, model_0_results_df["train_loss"], label="model_0")
plt.plot(epochs, model_1_results_df["train_loss"], label="model_1")
plt.title("Train loss")
plt.xlabel("Epochs")
plt.legend()

plt.subplot(2,2,2)
plt.plot(epochs, model_0_results_df["test_loss"], label="model_0")
plt.plot(epochs, model_1_results_df["test_loss"], label="model_1")
plt.title("Test loss")
plt.xlabel("Epochs")
plt.legend()

plt.subplot(2,2,3)
plt.plot(epochs, model_0_results_df["train_acc"], label="model_0")
plt.plot(epochs, model_1_results_df["train_acc"], label="model_1")
plt.title("Train accuracy")
plt.xlabel("Epochs")
plt.legend()

plt.subplot(2,2,4)
plt.plot(epochs, model_0_results_df["test_acc"], label="model_0")
plt.plot(epochs, model_1_results_df["test_acc"], label="model_1")
plt.title("Test accuracy")
plt.xlabel("Epochs")
plt.legend()

In [None]:
# making a prediction on a custom image
import requests

custom_image_path = data_path / "04-pizza-dad.jpeg"

if not custom_image_path.is_file():
  # Download before opening the destination: if the request fails, no empty or
  # partial file is left behind for the `is_file()` check to accept next time.
  request = requests.get("https://github.com/mrdbourke/pytorch-deep-learning/raw/main/images/04-pizza-dad.jpeg")
  request.raise_for_status()
  with open(custom_image_path, "wb") as f:
    f.write(request.content)
else:
  print(f"{custom_image_path} already exists")

In [None]:
custom_image_path

In [None]:
# load custom image
import torchvision
custom_image_uint8 = torchvision.io.read_image(str(custom_image_path))
custom_image_uint8, custom_image_uint8.shape, custom_image_uint8.dtype

In [None]:
plt.imshow(custom_image_uint8.permute(1,2,0))

In [None]:
custom_image = torchvision.io.read_image(str(custom_image_path)).type(torch.float32)

custom_image = custom_image / 255.

# Print out image data
print(f"Custom image tensor:\n{custom_image}\n")
print(f"Custom image shape: {custom_image.shape}\n")
print(f"Custom image dtype: {custom_image.dtype}")

In [None]:
plt.imshow(custom_image.permute(1, 2, 0))
plt.title(f"Image shape: {custom_image.shape}")
plt.axis(False);

In [None]:
# Create transform pipeline to resize image
custom_image_transform = transforms.Compose([transforms.Resize((64, 64))])

custom_image_transformed = custom_image_transform(custom_image)

print(f"Original shape: {custom_image.shape}")
print(f"New shape: {custom_image_transformed.shape}")

In [None]:
model_1.eval()
with torch.inference_mode():
  custom_image_pred = model_1(custom_image_transformed.unsqueeze(dim=0).to(device))
custom_image_pred

In [None]:
# convert logits
custom_image_pred_probs = torch.softmax(custom_image_pred, dim=1)
custom_image_pred_probs

In [None]:
custom_image_pred_labels = torch.argmax(custom_image_pred_probs, dim=1)
custom_image_pred_labels

In [None]:
class_names[custom_image_pred_labels]

In [None]:
# putting together
def pred_plot_image(model: torch.nn.Module, image_path: str, class_names: List[str], transform: torchvision.transforms=None, device: torch.device=device):
  """Predict the class of a single image file and plot it with the prediction.

  Args:
    model: Model producing one row of class logits for a batch of one image;
      it is moved to ``device`` and switched to eval mode.
    image_path: Path of the image file to read.
    class_names: Optional list of class names indexed by predicted label; when
      empty or ``None`` the numeric label is shown instead.
    transform: Optional callable applied to the float image tensor (values
      scaled by 1/255) before the forward pass.
    device: Device used for the forward pass.
  """
  target_image = torchvision.io.read_image(str(image_path)).type(torch.float32)
  target_image = target_image / 255.

  if transform:
    target_image = transform(target_image)
  model.to(device)

  model.eval()
  with torch.inference_mode():
    # Add the batch dimension the model expects.
    target_image = target_image.unsqueeze(0)
    target_image_pred = model(target_image.to(device))
  target_image_pred_probs = torch.softmax(target_image_pred, dim=1)
  target_image_pred_label = torch.argmax(target_image_pred_probs, dim=1)

  # Plain Python numbers for the title, independent of the tensor's device.
  pred_index = target_image_pred_label.item()
  pred_prob = target_image_pred_probs.max().item()

  #plotting

  # squeeze(0) removes only the batch dimension added above; a bare squeeze()
  # would also drop any other size-1 dimension and break the permute.
  plt.imshow(target_image.squeeze(0).permute(1,2,0))
  if class_names:
    title = f"Pred: {class_names[pred_index]} | Prob: {pred_prob:.3f}"
  else:
    # Show the label as a number rather than a tensor repr.
    title = f"Pred: {pred_index} | Prob: {pred_prob:.3f}"
  plt.title(title)
  plt.axis(False)

In [None]:
pred_plot_image(model=model_1, image_path=custom_image_path, class_names=class_names, transform=custom_image_transform, device=device)