In [None]:
!pip install torch==2.1.1 torchvision==0.16.1 torchaudio==2.1.1 --index-url https://download.pytorch.org/whl/cu121

Looking in indexes: https://download.pytorch.org/whl/cu121
Collecting torch==2.1.1
  Using cached https://download.pytorch.org/whl/cu121/torch-2.1.1%2Bcu121-cp39-cp39-linux_x86_64.whl (2200.7 MB)
Collecting torchvision==0.16.1
  Using cached https://download.pytorch.org/whl/cu121/torchvision-0.16.1%2Bcu121-cp39-cp39-linux_x86_64.whl (6.8 MB)
Collecting torchaudio==2.1.1
  Using cached https://download.pytorch.org/whl/cu121/torchaudio-2.1.1%2Bcu121-cp39-cp39-linux_x86_64.whl (3.3 MB)
Collecting filelock (from torch==2.1.1)
  Using cached https://download.pytorch.org/whl/filelock-3.13.1-py3-none-any.whl (11 kB)
Collecting sympy (from torch==2.1.1)
  Using cached https://download.pytorch.org/whl/sympy-1.12-py3-none-any.whl (5.7 MB)
Collecting networkx (from torch==2.1.1)
  Using cached https://download.pytorch.org/whl/networkx-3.2.1-py3-none-any.whl (1.6 MB)
Collecting fsspec (from torch==2.1.1)
  Using cached https://download.pytorch.org/whl/fsspec-2024.2.0-py3-none-any.whl (170 kB)
Coll

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils import data as dt
from torch.autograd import Variable
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime
import numpy as np
import warnings
import mlflow
import requests
warnings.simplefilter("ignore")

In [None]:
import pathlib

# Absolute path of the current working directory,
# e.g. '/home/jovyan/DemoExample/'
BASE_DIR = pathlib.Path.cwd()
print(f"Working dir: {BASE_DIR}")

## Download dataset

In [None]:
def save_file(url, filename):
    """Download ``url`` and store the response body in ``filename``.

    Args:
        url: Address of the file to download.
        filename: Destination path on local storage; opened in binary write mode.

    Raises:
        requests.exceptions.HTTPError: If the server answers with a 4xx/5xx
            status. Nothing is written in that case, so an error page can
            never be mistaken for the downloaded file.

    A request timeout (10 seconds) is reported with a printed message and the
    function returns without writing anything.
    """
    try:
        response = requests.get(url, timeout=10)
        # Fail loudly on HTTP errors instead of saving the error page to disk
        response.raise_for_status()
    except requests.exceptions.Timeout:
        print("No internet connection")
        return

    with open(filename, 'wb') as f:
        f.write(response.content)
    print(f"{filename} downloaded from {url}")

In [None]:
# Fetch the MNIST archive and store it as mnist.npz inside BASE_DIR
mnist_url = "https://github.com/sbercloud-ai/aicloud-examples/raw/master/quick-start/notebooks_gpu/mnist.npz"
mnist_path = BASE_DIR.joinpath("mnist.npz")
save_file(mnist_url, mnist_path)

## Load dataset

In [None]:
# Read the four MNIST arrays from the archive. The context manager closes the
# archive even when a key is missing or the file turns out to be corrupt.
with np.load(BASE_DIR.joinpath('mnist.npz')) as data:
    # Insert an axis at position 1 so every image carries an explicit channel
    # dimension for the Conv2d layers (assumes images are stored as (N, H, W) — TODO confirm)
    mnist_images_train = np.expand_dims(data['x_train'], 1)
    mnist_labels_train = data['y_train']

    mnist_images_test = np.expand_dims(data['x_test'], 1)
    mnist_labels_test = data['y_test']

# Images become float tensors, labels become int64 class indices (as nll_loss expects)
dataset_train = dt.TensorDataset(torch.Tensor(mnist_images_train), torch.Tensor(mnist_labels_train).long())
dataset_test = dt.TensorDataset(torch.Tensor(mnist_images_test), torch.Tensor(mnist_labels_test).long())

# Reshuffle the training samples every epoch so SGD does not see the batches in the
# same fixed order each time; the evaluation order does not matter and stays fixed.
train_loader = torch.utils.data.DataLoader(dataset_train, batch_size=50, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset_test, batch_size=50)

## Define model

In [None]:
class CNNClassifier(nn.Module):
    """Simple convolutional classifier producing log-probabilities over 10 classes.

    Layout: two 5x5 convolutions (1 -> 10 -> 20 channels), each followed by
    2x2 max-pooling and ReLU, with 2D dropout applied to the second
    convolution's output; then two fully connected layers (320 -> 50 -> 10).
    The hard-coded 320 equals 20 * 4 * 4, the flattened feature size obtained
    for 28x28 single-channel inputs.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.dropout = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        """Return per-class log-probabilities for a batch of images.

        Args:
            x: Input batch with one channel, i.e. shaped (batch, 1, H, W).

        Returns:
            Tensor of shape (batch, 10) with log-softmax scores along dim 1.
        """
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.dropout(self.conv2(x)), 2))
        # Flatten the feature maps into one 320-element vector per sample
        x = x.view(-1, 320)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        # Normalise over the class dimension explicitly; omitting `dim` is deprecated
        return F.log_softmax(x, dim=1)

In [None]:
clf = CNNClassifier()
# Train on the first CUDA GPU when one is visible; otherwise fall back to the CPU
# so that moving the model and batches to `device` does not fail on GPU-less machines.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

## Wrap the model in DataParallel if several GPUs are available

In [None]:
# True when PyTorch can use a CUDA device; the value is displayed as the cell output
torch.cuda.is_available()

In [None]:
# Replicate the model across GPUs only when more than one is visible
gpu_count = torch.cuda.device_count()
if gpu_count > 1:
    print("Let's use", gpu_count, "GPUs!")
    clf = nn.DataParallel(clf)

In [None]:
# Move the model's parameters and buffers to `device`; the returned module is shown as the cell output
clf.to(device)

In [None]:
# Minute-resolution timestamp that names this run's log directory under BASE_DIR/logs
current_time = datetime.now().strftime("%Y%m%d-%H_%M")
run_log_dir = BASE_DIR.joinpath('logs/log_' + current_time)
writer = SummaryWriter(log_dir=run_log_dir)

In [None]:
# SGD with momentum over all parameters of the classifier
sgd_settings = {"lr": 0.01, "momentum": 0.5}
optimizer = optim.SGD(clf.parameters(), **sgd_settings)

In [None]:
def train(epoch, clf, optimizer, writer):
    """Run one pass over ``train_loader``, updating ``clf`` with ``optimizer``.

    Reads the module-level ``train_loader`` and ``device``. On every 100th
    batch (starting with batch 0) the current loss is printed and written to
    ``writer`` under the tag ``'Train'`` at step
    ``epoch * len(train_loader) + batch_id``.
    """
    clf.train()  # training mode, so dropout is active

    for step, (inputs, labels) in enumerate(train_loader):
        inputs, labels = inputs.to(device), labels.to(device)

        # Reset gradients, run the forward pass, backpropagate and update the weights
        optimizer.zero_grad()
        batch_loss = F.nll_loss(clf(inputs), labels)
        batch_loss.backward()
        optimizer.step()

        if step % 100:
            continue  # only report every 100th batch

        loss_value = batch_loss.item()
        print(f'train loss = {loss_value}')
        writer.add_scalar('Train', loss_value, epoch * len(train_loader) + step)

In [None]:
def test(epoch, clf, writer):
    clf.eval()  # set model in inference mode (need this because of dropout)
    test_loss = 0
    correct = 0

    for data, target in test_loader:
        data = data.to(device)
        target = target.to(device)
        output = clf(data)
        test_loss += F.nll_loss(output, target).item()
        pred = output.data.max(1)[1]  # get the index of the max log-probability
        correct += pred.eq(target.data).cpu().sum()
    
    test_loss = test_loss
    test_loss /= len(test_loader)  # loss function already averages over batch size
    accuracy = 100. * correct / len(test_loader.dataset)
    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        accuracy))
    
    
    mlflow.log_metric("Test loss", test_loss)  # add mlflow metrics
    mlflow.log_metric("Accuracy", np.round(accuracy.item(),1)) # add mlflow metrics

In [None]:
num_epochs = 3
print(f'Start train {num_epochs} epochs total')

# Loading from checkpoint
# https://pytorch.org/tutorials/beginner/saving_loading_models.html
import os

last_epoch = 0
start_epoch = 0  # index of the first epoch trained in this run

# Collect the checkpoints from every directory below logs/, not only from the
# directory that os.walk happens to visit last.
checkpoint_paths = [
    os.path.join(root, model_filename)
    for root, dirs, files in os.walk(BASE_DIR.joinpath('logs'))
    for model_filename in files
    if model_filename.endswith(".bin")
]

if checkpoint_paths:
    # Directory listing order is arbitrary, so pick the most recently written file
    latest_checkpoint = max(checkpoint_paths, key=os.path.getmtime)
    # NOTE: torch.load unpickles the file - only load checkpoints you trust.
    # map_location lets a checkpoint saved on a GPU be restored on the current device.
    checkpoint = torch.load(latest_checkpoint, map_location=device)
    clf.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    last_epoch = checkpoint['epoch']
    # Resume after the saved epoch; this also holds when the checkpoint is from epoch 0
    start_epoch = last_epoch + 1
    print(f"Continue training from {last_epoch} epoch")

# Start training
mlflow.set_tracking_uri('file:/home/jovyan/mlruns')
mlflow.set_experiment("pytorch_tensorboard_mlflow.ipynb")
with mlflow.start_run(nested=True) as run:
    try:
        for epoch in range(start_epoch, start_epoch + num_epochs):
            print("Epoch %d" % epoch)
            train(epoch, clf, optimizer, writer)
            test(epoch, clf, writer)
            # Save checkpoint every epoch
            torch.save({
                'epoch': epoch,
                'model_state_dict': clf.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
            }, BASE_DIR.joinpath('logs/log_' + current_time + f"/model_epoch_{epoch}.bin"))
    finally:
        # Close the TensorBoard writer once, after the last epoch (or on failure),
        # instead of closing and implicitly reopening it on every epoch.
        writer.close()