# Image Classification using PyTorch and 🔭 Galileo

In this tutorial, we'll train a model with PyTorch and explore the results in Galileo.

**Make sure to select GPU in your Runtime! (Runtime -> Change Runtime type)**

In [None]:
#@title Install `dataquality`
# Upgrade pip
!pip install -U pip &> /dev/null

# Install all dependencies
!pip install -U dataquality matplotlib==3.1.3 torch torchmetrics==0.10.0 datasets &> /dev/null

print('👋 Installed necessary libraries!')

In [None]:
#@markdown Check that a GPU is available

import torch
# Query CUDA availability once, report it, and select the matching device.
gpu_status_messages = {
    True: "⚡ You are connected to a GPU!",
    False: "❗You are NOT connected to a GPU ❗It is recommended to connect to a GPU before training",
}
has_gpu = torch.cuda.is_available()
print(gpu_status_messages[has_gpu])
device = torch.device("cuda" if has_gpu else "cpu")

# 1. Login to Galileo

In [None]:
import dataquality as dq

dq.login()

# 2. Load Data

In [None]:
#@title Load a 🤗 HuggingFace Dataset
#@markdown You can find more datasets [here](https://huggingface.co/datasets?task_categories=task_categories:image-classification&sort=downloads).

dataset_name = "CVdatasets/food27" #@param ["mnist", "fashion_mnist", "cifar10", "cifar100", "Maysee/tiny-imagenet", "frgfm/imagenette"] {allow-input: true}
print(f"You selected the {dataset_name} dataset")

from IPython.utils import io
from datasets import load_dataset, get_dataset_config_names

# Try to load the data. If a config (subset) is needed, ask the user to pick one.
# NOTE(review): `trust_remote_code=True` presumably lets the dataset repository
# run its own loading script; only use it with repositories you trust.
try:
  with io.capture_output() as captured:
    data = load_dataset(dataset_name, trust_remote_code=True)
except ValueError as e:
  # Only the "missing config" error is handled here; anything else is re-raised
  # unchanged (bare `raise` keeps the original traceback).
  if "Config name is missing" not in repr(e):
    raise

  configs = get_dataset_config_names(dataset_name)
  print(f"The dataset {dataset_name} has multiple subsets {configs}.")
  config = input(f"🖖 Enter the name of the subset to pick (or leave blank for any): ")
  if config:
    assert config in configs, f"{config} is not a valid subset"
  else:
    # Blank answer: fall back to the first listed subset.
    config = configs[0]
  with io.capture_output() as captured:
    data = load_dataset(dataset_name, name=config, trust_remote_code=True)

# Check that the dataset has a train split and at least one evaluation split.
# Candidates are tested in a fixed priority order so that the chosen split is
# the same on every run, even when the dataset ships several of them (taking
# the first element of a set intersection gives no such guarantee).
eval_split_candidates = ("validation", "valid", "test")
available_eval_splits = [name for name in eval_split_candidates if name in data]
assert "train" in data and available_eval_splits, \
f"💾 The dataset {dataset_name} has either no train, or no validation or test splits, select another one."
test_split_name = available_eval_splits[0]
if test_split_name == 'valid':
  # Expose the split under the name 'validation' as well and use that name from here on.
  data['validation'] = data['valid']
  test_split_name = 'validation'

print(f"\n🏆 Dataset {dataset_name} loaded succesfully")

# Select a small portion of the dataset for CI.
import os
def _minimize_for_ci() -> bool:
    return os.getenv("MINIMIZE_FOR_CI", "false") == "true"

if _minimize_for_ci():
  from datasets.features.features import ClassLabel

  # Find the name of the ground truth column: use the single column whose name
  # contains "label", otherwise ask the user.
  col_names = list(data['train'].features)
  good_col_names = [name for name in col_names if "label" in name]
  if len(good_col_names) == 1:
    label_col = good_col_names[0]
  else:
    print(f"The name of the columns are {col_names}.")
    label_col = input(f"🏅 Please enter the name of the column containing the labels: ")
    assert label_col in col_names, f"{label_col} is not an existing column"

  # Create a tiny dataset: keep only the samples whose ground truth is label 0,
  # capped at CI_MAX_SAMPLES samples per split.
  CI_MAX_SAMPLES = 100
  for ci_split in ('train', test_split_name):
    label_0_only = data[ci_split].filter(lambda example: example[label_col] == 0)
    # Cap the range at the number of available rows so that a split with fewer
    # than CI_MAX_SAMPLES matching samples does not index out of range.
    data[ci_split] = label_0_only.select(range(min(CI_MAX_SAMPLES, len(label_0_only))))
    # Shrink the label feature to the single class that remains.
    data[ci_split].features[label_col] = ClassLabel(names = [data[ci_split].features[label_col].names[0]])

# 3. Initialize Galileo

In [None]:
# 🔭🌕 Initializing a new run in Galileo. Each run is part of a project.
dq.init(task_type="image_classification", 
        project_name="image-classification-demo", 
        run_name=f"example_run_{dataset_name.replace('/', '-')}")

# 4. Create Dataset and Log Input Data with Galileo

Input data can be logged via `log_image_dataset`. This step logs the images, gold labels, data split, and the list of all labels. You can achieve this by adding one line of code to the standard PyTorch Dataset class.

In [None]:
#@markdown Fix a random Seed and load helper methods.
from typing import Optional, List
from io import BytesIO
from PIL import Image
import numpy as np
import random

# Fix a random seed.
def seed_all(seed: int) -> None:
    """Seed every random number generator this notebook relies on.

    Seeds Python's `random`, NumPy and PyTorch (including the CUDA seeding
    functions), then sets the cuDNN flags `deterministic=True` and
    `benchmark=False`.

    Reference:
    https://discuss.pytorch.org/t/reproducibility-with-all-the-bells-and-whistles/81097

    Args:
        seed: the value handed to each seeding function.
    """
    # Python and NumPy generators.
    random.seed(seed)
    np.random.seed(seed)

    # PyTorch generators.
    for seed_torch in (
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    ):
        seed_torch(seed)

    # cuDNN flags.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True


def seed_worker(worker_id: int) -> None:
    """Seed NumPy and Python's `random` for a dataloader worker.

    The seed is `torch.initial_seed()` reduced modulo 2**32. The `worker_id`
    argument is accepted but not used.

    Reference:
    https://discuss.pytorch.org/t/reproducibility-with-all-the-bells-and-whistles/81097
    """
    derived_seed = torch.initial_seed() % (1 << 32)
    for apply_seed in (np.random.seed, random.seed):
        apply_seed(derived_seed)

# Methods for converting between bytes, numpy, image.
def _bytes_to_img(b: bytes) -> Image.Image:
    """Open the encoded image contained in `b`.

    Args:
        b: the bytes of an encoded image file.

    Returns:
        The object returned by `Image.open` on an in-memory buffer wrapping `b`.
        The annotation names the `Image.Image` class, since `Image` itself is
        the PIL module imported by this notebook.
    """
    return Image.open(BytesIO(b))

def _bytes_to_np(b: bytes, dtype: str = "float") -> np.ndarray:    
    array = np.frombuffer(b, dtype=np.uint8)
    if dtype == "uint8":
      return array
    elif dtype == "float":
      return array / 255

# Methods for loading the df into a dataset.
def find_label_col_name(col_names: List[str]) -> Optional[str]:
    """Return the first column name that contains "label", or None if there is none."""
    label_like = (candidate for candidate in col_names if "label" in candidate)
    return next(label_like, None)

def find_image_col_name(col_names: List[str]) -> Optional[str]:
    """Return the first column name that contains "img" or "image", or None if there is none."""
    image_markers = ("img", "image")
    image_like = (
        candidate
        for candidate in col_names
        if any(marker in candidate for marker in image_markers)
    )
    return next(image_like, None)

def find_imgs_location_col_name(col_names: List[str]) -> Optional[str]:
    """Return the first column name that contains "path", or None if there is none."""
    return next(filter(lambda candidate: "path" in candidate, col_names), None)

def find_raw_image_col_name(col_names: List[str]) -> Optional[str]:
    """Return the first column name that contains "img" or "image", or None if there is none."""
    for candidate in col_names:
        if "img" in candidate or "image" in candidate:
            return candidate
    return None

def _write_to_disk(x, path) -> str:
    """Decode the image stored under `x["bytes"]` and save it at `path`.

    Args:
        x: a mapping whose "bytes" entry holds the encoded image.
        path: destination file path; its parent directory is created when missing.

    Returns:
        `path`, unchanged.
    """
    parent_dir = os.path.dirname(path)
    # `os.makedirs("")` raises, so only create the directory when `path`
    # actually has a directory component.
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)
    # The context manager closes the decoded image once it has been saved.
    with Image.open(BytesIO(x["bytes"])) as img:
        img.save(path)
    return path

In [None]:
#@markdown 🔭🌕 Galileo -- Log Input Data
from uuid import uuid4

from datasets import Image as datasetsImage
from datasets import DatasetDict
from datasets.features.features import ClassLabel
from torch.utils.data import Dataset as TorchDataset
from torchvision import transforms

STANDARD_DATA_COLUMNS_CV = ["id", "text", "label_idx", "path"]

class ImageDataset(TorchDataset):
    """PyTorch dataset built from one split of a HuggingFace image dataset.

    On construction the split is converted to a pandas DataFrame, every image is
    written to disk as a PNG file, and the resulting table is logged to Galileo
    through `dq.log_image_dataset`.
    """

    def __init__(
        self,
        hf: DatasetDict,
        split: str,
        transform: Optional[transforms.Compose] = None,
        class_labels: Optional[ClassLabel] = None
    ):
        """
        Args:
          hf: a HuggingFace dataset
          split: the split for the hf dataset
          transform [Optional]: a transform to apply to the images dynamically
            before training
          class_labels [Optional]: the ClassLabel object containing the list of
            labels and the method to convert between label (string) and
            label_idx (int). To ensure consistency pass the class_labels of the
            training dataset to the test/val datasets.

        Raises:
          ValueError: if no image column or no label column can be found.
        """
        hf = hf[split]
        # Images are written next to the split's first cache file.
        self.imgs_dir = os.path.dirname(hf.cache_files[0]["filename"])
        self.transform = transform

        # Find the column containing the raw images; they are saved to disk below.
        self.raw_img_location_colname = find_raw_image_col_name(hf.column_names)
        if self.raw_img_location_colname is None:
            raise ValueError("Could not find the images location column in the dataframe")

        # Convert to pandas df. Decoding is disabled so that each cell exposes the
        # encoded image under a "bytes" key, which `_write_to_disk` reads.
        hf = hf.cast_column(self.raw_img_location_colname, datasetsImage(decode=False))
        self.ds = hf.to_pandas()
        # The "text" column holds the local path of each image written to disk.
        self.ds["text"] = self.ds[self.raw_img_location_colname].apply(
            lambda x: _write_to_disk(x, f"{self.imgs_dir}/{uuid4()}.png")
        )

        # Find the label column name: could be label, labels, coarse_label, etc.
        self.label_col_name = find_label_col_name(self.ds.columns)
        if self.label_col_name is None:
            raise ValueError("Could not find the label column in the dataframe")

        # Set the list of labels for this split.
        self.class_labels = class_labels
        if self.class_labels is None:
            self.class_labels = hf.features[self.label_col_name]
        self.list_of_labels = self.class_labels.names
        if split == "train":
            dq.set_labels_for_run(self.list_of_labels)

        # Keep the integer labels in "label_idx" and replace the label column
        # with the labels as strings (for dq).
        self.ds["label_idx"] = self.ds[self.label_col_name]
        labels_int2str = self.class_labels.int2str
        self.ds[self.label_col_name] = self.ds["label_idx"].map(labels_int2str)

        # Find the id column, or create it if it doesn't exist.
        if "id" not in self.ds.columns:
            self.ds = self.ds.reset_index().rename(columns={"index": "id"})

        # Get the metadata columns: everything that is neither a standard column
        # nor this split's image/label column. The exclusion set is built per
        # instance instead of appending to the module-level list, so creating
        # several datasets neither grows that list nor leaks column names from
        # one dataset into another.
        non_meta_cols = set(STANDARD_DATA_COLUMNS_CV) | {
            self.raw_img_location_colname,
            self.label_col_name,
        }
        meta_data_cols = [
            column
            for column in self.ds.columns
            if column not in non_meta_cols
        ]

        # 🔭🌕 Galileo logging -- Log Input Data
        dq.log_image_dataset(
            dataset=self.ds,
            label=self.label_col_name,
            split=split,
            meta=meta_data_cols,
            imgs_local_colname="text",
        )

    def __getitem__(self, idx: int):
        """Return the sample at position `idx`.

        Returns:
          A dict with the RGB image (transformed when a transform was given),
          the integer label and the sample id.
        """
        # `idx` is a position, so index positionally rather than by label.
        row = self.ds.iloc[idx]
        # Convert inside the context manager so the file handle is released.
        with Image.open(row["text"]) as raw_image:
            image = raw_image.convert('RGB')
        label, sample_id = row["label_idx"], row["id"]

        if self.transform is not None:
            image = self.transform(image)

        return {"image": image, "label": label, "id": sample_id}

    def __len__(self) -> int:
        """Return the number of rows in the split's DataFrame."""
        return len(self.ds)

In [None]:
#@markdown Create the Dataset and DataLoader

# Build the train and evaluation datasets.
image_crop_size = (224, 224)

# Evaluation images are only resized and converted to tensors.
val_transforms = transforms.Compose([
    transforms.Resize(image_crop_size),
    transforms.ToTensor(),
])
# Training reuses the evaluation steps and appends a random horizontal flip.
train_transforms = transforms.Compose(
    [*val_transforms.transforms, transforms.RandomHorizontalFlip()]
)

TRAIN_SPLIT_NAME = "train"
VAL_SPLIT_NAME = test_split_name  # this var is needed in dq.set_split down below

train_dataset = ImageDataset(data, split=TRAIN_SPLIT_NAME, transform=train_transforms)
# The evaluation dataset receives the training ClassLabel object.
test_dataset = ImageDataset(
    data,
    split=VAL_SPLIT_NAME,
    transform=val_transforms,
    class_labels=train_dataset.class_labels,
)

print(f"Loaded {TRAIN_SPLIT_NAME} dataset with {len(train_dataset.ds)} samples and {len(train_dataset.list_of_labels)} labels")
print(f"Loaded {VAL_SPLIT_NAME} dataset with {len(test_dataset.ds)} samples and  {len(test_dataset.list_of_labels)} labels")


# Create the DataLoaders.
from torch.utils.data import DataLoader as TorchDataLoader

BATCH_SIZE = 64

NUM_WORKERS = 0
SEED_WORKER = 42

seed_all(SEED_WORKER)

# Settings shared by the train and evaluation loaders.
_common_loader_kwargs = dict(
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
    worker_init_fn=seed_worker,
    # Only request pinned memory when a GPU is present: this notebook also
    # supports running on CPU, where pinning is not used.
    pin_memory=torch.cuda.is_available(),
)

# Training batches are shuffled; evaluation batches keep the dataset order.
train_dataloader = TorchDataLoader(train_dataset, shuffle=True, **_common_loader_kwargs)
test_dataloader = TorchDataLoader(test_dataset, shuffle=False, **_common_loader_kwargs)

In [None]:
#@markdown Visualize the Data.
# Visualizing a few images of the dataset (post-processing/augmentation)
import random
import matplotlib.pyplot as plt
from torchvision.utils import make_grid
# Draw 20 random training samples (indices may repeat) and show them in rows of 5.
sample_count, images_per_row = 20, 5
last_index = len(train_dataset) - 1
idxs = [random.randint(0, last_index) for _ in range(sample_count)]
sample_images = [train_dataset[sample_idx]["image"] for sample_idx in idxs]
grid_img = make_grid(sample_images, nrow=images_per_row)
plt.figure(figsize=(20, 10))
# Move the first axis last before plotting (presumably channels-first -> channels-last; confirm).
plt.imshow(grid_img.permute(1, 2, 0))
plt.show()

# 6. Log model data with Galileo

Model data is logged by wrapping the model with the `watch` function. This step logs the model logits and embeddings. You can achieve this by adding one line of code to the standard PyTorch model.

In [None]:
from torchvision.models import resnet34, resnet50

EPOCHS = 3
if _minimize_for_ci():
    # A single epoch is enough for the reduced CI run.
    EPOCHS = 1

# Load the ImageNet-pretrained model and replace its last layer so that it
# outputs one logit per label of the training dataset.
# `weights="IMAGENET1K_V1"` is the replacement for the deprecated
# `pretrained=True` flag (torchvision >= 0.13).
model = resnet50(weights="IMAGENET1K_V1")
model.fc = torch.nn.Linear(model.fc.in_features, len(train_dataset.list_of_labels))
torch.nn.init.xavier_uniform_(model.fc.weight)

model = model.to(device)

# Set optimizer and loss.
params_1x = [  # get the original weights, they'll be updated with a lower learning rate
    param
    for name, param in model.named_parameters()
    # Everything except the new classification head (`model.fc`).
    if not name.startswith("fc.")
]
lr, weight_decay = 1e-5, 5e-4
optimizer = torch.optim.Adam(
    [
        {"params": params_1x, "lr": lr},
        # The freshly initialised head is trained with a 10x larger learning rate.
        {"params": model.fc.parameters(), "lr": lr * 10},
    ],
    weight_decay=weight_decay,
)
criterion = torch.nn.CrossEntropyLoss()

from dataquality.integrations.torch import watch

# 🔭🌕 Galileo logging -- Log Embeddings
watch(
    model=model,
    classifier_layer=model.fc,
    dataloaders=[train_dataloader, test_dataloader]
)

# 7. Putting into Action: Training a Model

We complete the training pipeline by using a standard PyTorch training setup. While training, we log the current `epoch` and `split`. To complete logging, we call `dq.finish()` after training.

In [None]:
from tqdm import tqdm
from time import sleep, time

# Train !
start = time()
print(f"Training for {EPOCHS} epochs on {device}")

for epoch in range(1, EPOCHS + 1):
    print(f"Epoch {epoch}/{EPOCHS}")
    dq.set_epoch(epoch)  # 🔭🌕 Galileo -- Set epoch

    model.train()
    # Running totals for the epoch, kept as tensors on `device`.
    train_loss = torch.tensor(0.0, device=device)
    train_correct = torch.tensor(0, device=device)

    dq.set_split(TRAIN_SPLIT_NAME)  # 🔭🌕 Galileo -- Set split
    with tqdm(train_dataloader, unit="batch") as train_minibatchs:
        for train_minibatch in train_minibatchs:
            train_minibatchs.set_description(f"Epoch {epoch}")

            images = train_minibatch["image"].to(device)
            labels = train_minibatch["label"].to(device)

            preds = model(images)
            loss = criterion(preds, labels)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            with torch.no_grad():
                # `train_loss` accumulates the per-batch loss values (a sum, not a mean).
                train_loss += loss
                train_batch_correct = (torch.argmax(preds, dim=1) == labels).sum()
                train_correct += train_batch_correct

            # Divide by the actual number of samples in this batch: the last
            # batch of an epoch can be smaller than BATCH_SIZE.
            train_minibatchs.set_postfix(
                batch_loss=loss.item(),
                batch_accuracy=float(train_batch_correct) / labels.size(0),
            )
            # NOTE(review): the purpose of this short pause is not evident from
            # the code (presumably progress-bar rendering) -- confirm before removing.
            sleep(0.01)

    print(f"Training loss: {train_loss:.2f}")
    print(f"Training accuracy: {100 * float(train_correct) / len(train_dataloader.dataset):.2f}")

    dq.set_split(VAL_SPLIT_NAME)  # 🔭🌕 Galileo -- Set split
    if test_dataloader is not None:
        model.eval()
        val_loss = torch.tensor(0.0, device=device)
        val_correct = torch.tensor(0, device=device)

        # No gradients are needed for evaluation.
        with torch.no_grad():
            for val_minibatch in tqdm(test_dataloader):
                images = val_minibatch["image"].to(device)
                labels = val_minibatch["label"].to(device)

                preds = model(images)
                loss = criterion(preds, labels)

                val_loss += loss
                val_correct += (torch.argmax(preds, dim=1) == labels).sum()

        print(f"{VAL_SPLIT_NAME} loss: {val_loss:.2f}")
        # Same formatting as the training accuracy above.
        print(f"{VAL_SPLIT_NAME} accuracy: {100 * float(val_correct) / len(test_dataloader.dataset):.2f}")

end = time()
print(f"Total training time: {end-start:.1f} seconds")
dq.finish()  # 🔭🌕 Galileo -- Complete logging after training

# General Help and Docs
- To get help with your task's requirements, call `dq.get_data_logger().doc()`
- To see more general data and model logging docs, run `dq.docs()`

In [None]:
dq.get_data_logger().doc()
help(dq.log_dataset)