## 2. Single GPU Training with PyTorch

### 2.1. Overview

We will start by fitting a `ResNet18` model to the `MNIST` dataset. Conceptually, we follow the recipe shown below.

|<img src="https://anyscale-public-materials.s3.us-west-2.amazonaws.com/ray-ai-libraries/diagrams/single_gpu_pytorch_v3.png" width="70%" loading="lazy">|
|:--|
|An overview of the single GPU training process. At a high level, a PyTorch training loop has three key stages: loading the dataset, running the training on mini-batches on a single GPU, and saving the model checkpoint to persistent storage.|
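
The three stages in the figure do not depend on PyTorch. As a minimal, framework-free sketch (not the PyTorch code used in this notebook), the loop below fits a hypothetical one-feature linear model `y = w*x + b` to synthetic, noise-free data with mini-batch gradient descent and writes a JSON checkpoint after every epoch. The data, learning rate, epoch count, and the `checkpoint.json` file name are assumptions chosen for illustration.

```python
import json
import os
import random
import tempfile


def train_linear_sketch(num_epochs=200, batch_size=10, lr=0.1, local_path=None):
    """Mini-batch gradient descent on synthetic data drawn from y = 2x + 1."""
    # 1. Load the dataset: 100 synthetic (x, y) pairs with x in [0, 1).
    rng = random.Random(0)
    data = [(i / 100, 2 * (i / 100) + 1) for i in range(100)]

    if local_path is None:
        local_path = tempfile.mkdtemp()
    os.makedirs(local_path, exist_ok=True)

    w, b = 0.0, 0.0
    for epoch in range(num_epochs):
        rng.shuffle(data)
        epoch_loss = 0.0
        num_batches = 0
        # 2. Train on mini-batches (full batches only, like drop_last=True).
        for start in range(0, len(data) - batch_size + 1, batch_size):
            batch = data[start:start + batch_size]
            errors = [(w * x + b) - y for x, y in batch]
            # Gradients of the mean squared error with respect to w and b.
            grad_w = 2 * sum(e * x for e, (x, _) in zip(errors, batch)) / len(batch)
            grad_b = 2 * sum(errors) / len(batch)
            w -= lr * grad_w
            b -= lr * grad_b
            epoch_loss += sum(e * e for e in errors) / len(batch)
            num_batches += 1

        # 3. Save a checkpoint (parameters and metrics) after every epoch.
        checkpoint = {"w": w, "b": b, "epoch": epoch, "loss": epoch_loss / num_batches}
        with open(os.path.join(local_path, "checkpoint.json"), "w") as f:
            json.dump(checkpoint, f)
    return w, b, local_path


w, b, ckpt_dir = train_linear_sketch()
print(round(w, 3), round(b, 3))  # 2.0 1.0
```

Overwriting a single checkpoint file each epoch mirrors what `save_checkpoint_and_metrics_torch` does with `model.pt` further below.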

In [None]:
from tqdm.notebook import tqdm
from pathlib import Path
import torch
from torch.nn import CrossEntropyLoss
from torch.optim import Adam

def train_loop_torch(num_epochs: int = 2, batch_size: int = 128, local_path: str = "./checkpoints"):
    # Use the GPU if one is available, otherwise fall back to the CPU
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using device: {device}")

    # Loss function, model on `device`, optimizer, and batched training data
    criterion = CrossEntropyLoss()
    model = load_model_torch().to(device)
    optimizer = Adam(model.parameters(), lr=1e-5)
    data_loader = build_data_loader_torch(batch_size=batch_size)

    for epoch in tqdm(range(num_epochs), desc="Epochs"):
        epoch_loss = 0.0

        for images, labels in tqdm(data_loader, desc=f"Epoch {epoch+1}/{num_epochs}", leave=False):
            # Move the mini-batch to the same device as the model
            images, labels = images.to(device), labels.to(device)
            # Forward pass
            outputs = model(images)
            loss = criterion(outputs, labels)

            # Backward pass and parameter update
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            epoch_loss += loss.item()

        # Mean loss over the mini-batches of this epoch
        avg_loss = epoch_loss / len(data_loader)

        # Report the metrics, then save them together with a checkpoint at the end of every epoch
        metrics = report_metrics_torch(loss=avg_loss, epoch=epoch)
        Path(local_path).mkdir(parents=True, exist_ok=True)
        save_checkpoint_and_metrics_torch(metrics=metrics, model=model, local_path=local_path)


<div class="alert alert-block alert-info">

Quick notes:

<ul>
    <li><code>report_metrics_torch</code> and <code>save_checkpoint_and_metrics_torch</code> are defined below.</li>
    <li><code>local_path</code> is the directory where metrics and checkpoints are saved. Its default, <code>./checkpoints</code>, is relative to the current working directory, which is the notebook's location (check it with <code>pwd</code> below).</li>
</ul>
</div>
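
Because the default `./checkpoints` is a relative path, Python resolves it against the current working directory at the time of the filesystem call, not against the location of the code that defines it. The standard-library sketch below shows that resolution and the `mkdir(parents=True, exist_ok=True)` call used in `train_loop_torch`; the nested `runs/torch_demo` folder is a hypothetical name, created inside a throwaway temporary directory.

```python
import os
import tempfile
from pathlib import Path

# A relative path is interpreted relative to the current working directory.
default_path = "./checkpoints"
resolved_default = os.path.abspath(default_path)
print(resolved_default)  # prints <current working directory>/checkpoints

# parents=True creates missing intermediate folders; exist_ok=True turns a
# repeated call (for example, once per epoch) into a no-op instead of an error.
demo_base = Path(tempfile.mkdtemp())
demo_nested = demo_base / "runs" / "torch_demo"  # hypothetical folder names
demo_nested.mkdir(parents=True, exist_ok=True)
demo_nested.mkdir(parents=True, exist_ok=True)  # second call does nothing
print(demo_nested.is_dir())  # True
```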

In [None]:
!pwd

### 2.2. Build model and load it on the GPU

Build a [ResNet18](https://pytorch.org/vision/main/models/resnet.html#resnet) model and adapt its first layer to single-channel images.

In [None]:
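import torch
from torchvision.models import resnet18

def build_resnet18():
    # ResNet18 with 10 output classes, one per MNIST digit
    model = resnet18(num_classes=10)
    # Replace the first convolution so that it accepts 1-channel (grayscale) input
    model.conv1 = torch.nn.Conv2d(
        in_channels=1, # grayscale MNIST images
        out_channels=64,
        kernel_size=(7, 7),
        stride=(2, 2),
        padding=(3, 3),
        bias=False,
    )
    return model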

<div class="alert alert-block alert-info">

By default, the first convolution of <code>resnet18</code>, <code>model.conv1</code>, expects <code>in_channels=3</code> (RGB). Here, we work with <a href="https://pytorch.org/vision/main/generated/torchvision.datasets.MNIST.html#mnist" target="_blank">MNIST</a> grayscale images, so we replace it with a layer that has <code>in_channels=1</code>.
</div>
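
The replacement keeps the original kernel size, stride, and padding, so only the number of input channels changes. As a quick sanity check, the sketch below applies the output-size formula from the PyTorch `Conv2d` documentation (dilation 1, floor rounding) to a 28×28 MNIST image, and counts the `conv1` weights before and after the change. `conv_output_size` is a hypothetical helper, not a PyTorch function.

```python
def conv_output_size(size, kernel_size, stride=1, padding=0, dilation=1):
    """Output height/width of a Conv2d or MaxPool2d layer (floor rounding)."""
    return (size + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1


# Replacement conv1 (kernel 7, stride 2, padding 3) on a 28x28 MNIST image.
conv1_out = conv_output_size(28, kernel_size=7, stride=2, padding=3)
# torchvision's ResNet follows conv1 with a 3x3 max-pool (stride 2, padding 1).
maxpool_out = conv_output_size(conv1_out, kernel_size=3, stride=2, padding=1)
print(conv1_out, maxpool_out)  # 14 7

# Conv2d weights have shape (out_channels, in_channels, kH, kW), so going from
# 3 input channels to 1 shrinks conv1 from 9408 to 3136 weights (bias=False).
rgb_conv1_weights = 64 * 3 * 7 * 7
gray_conv1_weights = 64 * 1 * 7 * 7
print(rgb_conv1_weights, gray_conv1_weights)  # 9408 3136
```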

Load the model on a single GPU (or on the CPU if no GPU is available).

In [None]:
def load_model_torch() -> torch.nn.Module:
    model = build_resnet18()

    # Move to the GPU device if available, otherwise use CPU
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device)
    return model

### 2.3. Create Dataset and DataLoader

In [None]:
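from torchvision.datasets import MNIST

# Without a transform, each sample is a (PIL image, int label) pair
dataset = MNIST(root="./data", train=True, download=True)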

In [None]:
!tree ./data

Let's display 9 example (image, target) pairs:

In [None]:
import matplotlib.pyplot as plt
import numpy as np

figure = plt.figure(figsize=(8, 8))
cols, rows = 3, 3

# Plot 9 randomly chosen training samples, each titled with its label
for i in range(1, cols * rows + 1):
    sample_idx = np.random.randint(0, len(dataset.data))
    img, label = dataset[sample_idx]
    figure.add_subplot(rows, cols, i)
    plt.title(label)
    plt.axis("off")
    plt.imshow(img, cmap="gray")

Define a function that builds a DataLoader: the dataset applies the transformations to each image, and the DataLoader groups the samples into shuffled mini-batches.

In [None]:
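import torch
from torchvision.datasets import MNIST
from torchvision.transforms import Compose, Normalize, ToTensor

def build_data_loader_torch(batch_size: int) -> torch.utils.data.DataLoader:
    # ToTensor scales pixels to [0, 1]; Normalize then maps them to [-1, 1]
    transform = Compose([ToTensor(), Normalize((0.5,), (0.5,))])
    dataset = MNIST(root="./data", train=True, download=True, transform=transform)
    # drop_last=True discards the final, incomplete batch of each epoch
    train_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=True)
    return train_loader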
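
For a single pixel, `ToTensor` divides a uint8 value in [0, 255] by 255, and `Normalize((0.5,), (0.5,))` then computes `(x - 0.5) / 0.5`, so the model sees inputs in [-1, 1]. The pure-Python sketch below reproduces that arithmetic; `to_tensor_value` and `normalize_value` are hypothetical helpers, not torchvision functions. Section 2.6 applies the same two transforms manually before calling the model, which keeps inference inputs on the same scale as the training inputs.

```python
def to_tensor_value(pixel: int) -> float:
    """Mimic ToTensor's scaling of one uint8 pixel to [0, 1] (hypothetical helper)."""
    return pixel / 255.0


def normalize_value(x: float, mean: float = 0.5, std: float = 0.5) -> float:
    """Mimic Normalize((mean,), (std,)) for a single value (hypothetical helper)."""
    return (x - mean) / std


for pixel in (0, 128, 255):
    print(pixel, "->", normalize_value(to_tensor_value(pixel)))
# 0 maps to -1.0 and 255 maps to 1.0; mid-gray 128 lands just above 0.
```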

### 2.4. Create metrics and checkpointing

Report the metrics with a simple print statement. The function also returns them, so they can be saved to a CSV file together with the checkpoint.

In [None]:
def report_metrics_torch(loss: float, epoch: int) -> dict[str, float]:
    metrics = {"loss": loss, "epoch": epoch}
    print(metrics)
    return metrics

Save the metrics and the model checkpoint to the directory passed in as `local_path`.

In [None]:
import csv
import os

def save_checkpoint_and_metrics_torch(metrics: dict[str, float], model: torch.nn.Module, local_path: str) -> None:

    # Append the metrics as one CSV row (values only, no header)
    with open(os.path.join(local_path, "metrics.csv"), "a", newline="") as f:
        writer = csv.writer(f)
        writer.writerow(metrics.values())

    # Save the model weights, overwriting the previous epoch's checkpoint
    checkpoint_path = os.path.join(local_path, "model.pt")
    torch.save(model.state_dict(), checkpoint_path)
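
Because `report_metrics_torch` builds its dictionary as `{"loss": ..., "epoch": ...}` and Python dictionaries preserve insertion order, each call to `save_checkpoint_and_metrics_torch` appends one header-less `loss,epoch` row. The sketch below reproduces that pattern in a temporary directory with hypothetical metric values; it is why the CSV is read back later with `header=None` and explicit column names. Append mode also means that reusing the same `local_path` keeps adding rows to the existing file.

```python
import csv
import os
import tempfile

demo_dir = tempfile.mkdtemp()
demo_csv = os.path.join(demo_dir, "metrics.csv")

# Hypothetical metrics for two epochs, in the same key order as report_metrics_torch.
for demo_metrics in ({"loss": 0.9, "epoch": 0}, {"loss": 0.4, "epoch": 1}):
    # Append mode: each epoch adds one row and keeps the earlier ones.
    with open(demo_csv, "a", newline="") as f:
        csv.writer(f).writerow(demo_metrics.values())

# Only values are written, no header row; csv.reader returns them as strings.
with open(demo_csv, newline="") as f:
    demo_rows = list(csv.reader(f))
print(demo_rows)  # [['0.9', '0'], ['0.4', '1']]
```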

### 2.5. Run the training loop

Define a timestamped output directory, then schedule the training loop on a single GPU.

In [None]:
import datetime

timestamp = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
storage_folder = "/mnt/cluster_storage"  # Change this to a local folder if you run outside Anyscale
local_path = f"{storage_folder}/torch_{timestamp}/"
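
The `%Y-%m-%d_%H-%M-%S` format zero-pads every field, so each run gets its own folder and the folder names sort chronologically as plain strings. A sketch with fixed, hypothetical dates in place of `datetime.now()`; the `example_` names are illustrative and leave the notebook's `timestamp` and `local_path` untouched.

```python
import datetime

# Fixed, hypothetical moments instead of datetime.now(), so the output is predictable.
earlier = datetime.datetime(2024, 5, 17, 9, 3, 7)
later = datetime.datetime(2024, 11, 2, 18, 45, 0)

example_timestamps = [t.strftime("%Y-%m-%d_%H-%M-%S") for t in (later, earlier)]
print(example_timestamps)  # ['2024-11-02_18-45-00', '2024-05-17_09-03-07']

# Zero-padding makes string order match time order.
print(sorted(example_timestamps))  # ['2024-05-17_09-03-07', '2024-11-02_18-45-00']

example_path = f"/mnt/cluster_storage/torch_{example_timestamps[1]}/"
print(example_path)  # /mnt/cluster_storage/torch_2024-05-17_09-03-07/
```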

<div class="alert alert-info">

<b>Note about Anyscale storage options</b>

In this example, <code>local_path</code> points to Anyscale's <a href="https://docs.anyscale.com/configuration/storage/#local-storage-for-a-node" target="_blank">local storage</a>. It's a convenient, fast location for this basic example.

* Anyscale provides each node with its own volume and disk and doesn't share them with other nodes.
* Local storage is very fast - Anyscale supports the Non-Volatile Memory Express (NVMe) interface.
* This is not persistent storage: Anyscale deletes the data in local storage after the instances are terminated.

Read more about available <a href="https://docs.anyscale.com/configuration/storage" target="_blank">storage</a> options.
</div>
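
If you switch between Anyscale and a local machine, you could pick the storage folder automatically instead of editing it by hand. A minimal sketch; `choose_storage_folder` and the `./outputs` fallback are hypothetical and not part of Anyscale or this notebook.

```python
import os


def choose_storage_folder(preferred: str = "/mnt/cluster_storage", fallback: str = "./outputs") -> str:
    """Hypothetical helper: return `preferred` if it is an existing directory, else `fallback`."""
    return preferred if os.path.isdir(preferred) else fallback


# On Anyscale this prints /mnt/cluster_storage; elsewhere it prints ./outputs.
print(choose_storage_folder())
```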

Start the training. If you run it on your local machine, it can take up to 15 minutes.

In [None]:
train_loop_torch(
    num_epochs=3,
    local_path=local_path
)
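
With `num_epochs=3` and the default `batch_size=128`, the number of optimizer steps follows from the size of the MNIST training split (60,000 images) and the `drop_last=True` setting in `build_data_loader_torch`. The arithmetic, as a quick sketch:

```python
import math

num_train_images = 60_000  # size of the MNIST training split
default_batch_size = 128   # default batch_size of train_loop_torch
num_epochs_run = 3         # num_epochs passed in the call above

# drop_last=True keeps only full batches.
steps_per_epoch = num_train_images // default_batch_size
dropped_per_epoch = num_train_images - steps_per_epoch * default_batch_size
# Without drop_last, the last, smaller batch would also be used.
steps_without_drop_last = math.ceil(num_train_images / default_batch_size)
total_steps = steps_per_epoch * num_epochs_run

print(steps_per_epoch, dropped_per_epoch, steps_without_drop_last, total_steps)  # 468 96 469 1404
```

So the inner progress bar should show 468 iterations per epoch, `avg_loss` divides by `len(data_loader) == 468`, and because `shuffle=True`, a different set of 96 images is left out in each epoch.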

Let's inspect the produced checkpoints and metrics:

In [None]:
!ls -l {local_path}

In [None]:
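import os
import pandas as pd

# metrics.csv has no header row, so the column names are given explicitly
metrics = pd.read_csv(
    os.path.join(local_path, "metrics.csv"),
    header=None,
    names=["loss", "epoch"],
)

metrics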

### 2.6. Use checkpointed model to generate predictions

Load the model checkpoint onto the device, then generate predictions for 9 randomly selected images from the MNIST dataset.

In [None]:
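import os
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
loaded_model = build_resnet18()
# map_location lets a checkpoint saved on a GPU also load on a CPU-only machine
loaded_model.load_state_dict(torch.load(os.path.join(local_path, "model.pt"), map_location=device))
loaded_model.to(device)
loaded_model.eval()  # inference mode: BatchNorm layers use their running statistics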

In [None]:
figure = plt.figure(figsize=(8, 8))
cols, rows = 3, 3

for i in range(1, cols * rows + 1):
    sample_idx = np.random.randint(0, len(dataset.data))
    img, label = dataset[sample_idx]
    normalized_img = Normalize((0.5,), (0.5,))(ToTensor()(img))
    normalized_img = normalized_img.to(device)

    # use loaded model to generate preds
    with torch.no_grad():        
        prediction = loaded_model(normalized_img.unsqueeze(0)).argmax().cpu()

    figure.add_subplot(rows, cols, i)
    plt.title(f"label: {label}; pred: {int(prediction)}")
    plt.axis("off")
    plt.imshow(img, cmap="gray")
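
`loaded_model` returns raw logits of shape `(1, 10)`, one score per digit. Calling `.argmax()` without a `dim` returns the index into the flattened tensor, which for a single row is the predicted class. Applying softmax first would not change the prediction, because softmax preserves the order of the scores. A pure-Python sketch with hypothetical logit values; `argmax` and `softmax` here are illustrative helpers, not the torch functions.

```python
import math


def argmax(values):
    """Index of the largest value in a flat list."""
    return max(range(len(values)), key=lambda i: values[i])


def softmax(values):
    """Numerically stable softmax: subtract the max before exponentiating."""
    m = max(values)
    exps = [math.exp(v - m) for v in values]
    total = sum(exps)
    return [e / total for e in exps]


# Hypothetical logits for one image, shaped like the model output: 1 row x 10 classes.
demo_logits = [[-1.2, 0.3, 4.1, 0.0, -0.5, 1.7, -2.0, 0.9, 2.2, -0.1]]
flat_logits = [v for row in demo_logits for v in row]  # argmax() without dim flattens

demo_pred = argmax(flat_logits)
demo_probs = softmax(flat_logits)
print(demo_pred, argmax(demo_probs))  # 2 2
```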