# Tractorun advanced

This notebook provides an extended demonstration of the advanced capabilities of the `tractorun` library. The focus will be on two key features:
1. **Checkpoints for PyTorch**: how to save a checkpoint and restore training from a checkpoint.
2. **Distributed Model Training**: how to run distributed training with `tractorun` on multiple nodes with multiple processes.

<font color="red">IMPORTANT NOTE:</font> This notebook won't run on the [playground](https://playground.tracto.ai/): it requires multiple hosts for distributed training, but only one host is available there.

For a basic example, please refer to [tractorun-torch-mnist](./tractorun-torch-mnist.ipynb).

In [1]:
import uuid
import sys

from yt import wrapper as yt
from yt import type_info

## Create a base directory for examples

In [3]:
# Configure the environment for running this notebook.
import uuid
import yt.wrapper as yt

username = yt.get_user_name()
if yt.exists(f"//sys/users/{username}/@user_info/home_path"):
    # Prepare a working directory on the distributed file system.
    # Reuse `username` instead of asking the cluster for it a second time.
    user_info = yt.get(f"//sys/users/{username}/@user_info")
    homedir = user_info["home_path"]
    # Find the available VM presets by the suffix of the pool tree name.
    available_pool_trees = user_info["available_pool_trees"]
    cpu_pool_trees = [tree for tree in available_pool_trees if tree.endswith("cpu")] or ["default"]
    h100_pool_trees = [tree for tree in available_pool_trees if tree.endswith("h100")]
    h100_8_pool_trees = [tree for tree in available_pool_trees if tree.endswith("h100_8")]
    workdir = f"{homedir}/tmp/demo_workdir/{uuid.uuid4().hex}"
else:
    # No per-user info is published: fall back to fixed pool trees and a
    # working directory under //tmp.
    cpu_pool_trees = ["default"]
    h100_pool_trees = ["gpu_h100"]
    h100_8_pool_trees = ["gpu_h100"]
    workdir = f"//tmp/examples/{uuid.uuid4().hex}"

# A fresh uuid makes the directory unique per run; `ignore_existing` keeps the
# cell safe to re-execute.
yt.create("map_node", workdir, recursive=True, ignore_existing=True)
print("Current working directory:", workdir)

Current working directory: //home/equal_amethyst_vulture/tmp/demo_workdir/f901bc5a573e450cb0074d7814b72c73


## Ensure torch and torchvision exist

Let's ensure that the system has installed `torch` and `torchvision`.

In [5]:
import torch
import torchvision 

## Run distributed training

Let's use the [MNIST dataset](https://en.wikipedia.org/wiki/MNIST_database). The process of uploading the data is described in the [basic tractorun notebook](./tractorun-torch-mnist.ipynb).

In [7]:
# Paths of the MNIST training and test datasets; `main` passes them to
# YtTensorDataset, and the training path is also passed to `toolbox.save_model`.
dataset_train_path = "//home/samples/mnist-torch-train"
dataset_test_path = "//home/samples/mnist-torch-test"

To run tractorun in distributed mode with checkpoints:
1. Use `toolbox.checkpoint_manager` to manage checkpoints.
2. Configure distributed training with `tractorun.mesh.Mesh`.

<details>
  <summary>Show the full diff</summary>

```diff
@@ -6,7 +6,15 @@
 from torchvision import datasets, transforms
 from torch.optim.lr_scheduler import StepLR

+from tractorun.backend.tractorch import YtTensorDataset, Tractorch
+from tractorun.toolbox import Toolbox
+from tractorun.run import run
+from tractorun.mesh import Mesh
+from tractorun.resources import Resources
+from tractorun.stderr_reader import StderrMode
+from tractorun.backend.tractorch.serializer import TensorSerializer

 class Net(nn.Module):
     def __init__(self):
         super(Net, self).__init__()
@@ -33,9 +41,12 @@
         return output


-def train(args, model, device, train_loader, optimizer, epoch):
+def train(args, model, device, train_loader, optimizer, epoch, first_batch_index, checkpoint_manager):
     model.train()
+    ts = TensorSerializer()
     for batch_idx, (data, target) in enumerate(train_loader):
+        if batch_idx < first_batch_index:
+            continue
         data, target = data.to(device), target.to(device)
         optimizer.zero_grad()
         output = model(data)
@@ -45,9 +56,18 @@
         if batch_idx % args.log_interval == 0:
             print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                 epoch, batch_idx * len(data), len(train_loader.dataset),
-                100. * batch_idx / len(train_loader), loss.item()))
+                100. * batch_idx / len(train_loader), loss.item()), file=sys.stderr)
             if args.dry_run:
                 break
+            state_dict = {
+                "model": model.state_dict(),
+                "optimizer": optimizer.state_dict(),
+            }
+            metadata_dict = {
+                "first_batch_index": batch_idx + 1,
+                "loss": loss.item(),
+            }
+            checkpoint_manager.save_checkpoint(ts.serialize(state_dict), metadata_dict)


 def test(model, device, test_loader):
@@ -66,10 +86,10 @@

     print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
         test_loss, correct, len(test_loader.dataset),
-        100. * correct / len(test_loader.dataset)))
+        100. * correct / len(test_loader.dataset)), file=sys.stderr)


-def main():
+def main(toolbox: Toolbox):
     # Training settings
     parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
     parser.add_argument('--batch-size', type=int, default=64, metavar='N',
@@ -94,7 +114,7 @@
                         help='how many batches to wait before logging training status')
     parser.add_argument('--save-model', action='store_true', default=False,
                         help='For Saving the current Model')
-    args = parser.parse_args()
+    args = parser.parse_args([])
     use_cuda = not args.no_cuda and torch.cuda.is_available()
     use_mps = not args.no_mps and torch.backends.mps.is_available()

@@ -120,26 +140,48 @@
         transforms.ToTensor(),
         transforms.Normalize((0.1307,), (0.3081,))
         ])
-    dataset1 = datasets.MNIST('../data', train=True, download=True,
-                       transform=transform)
-    dataset2 = datasets.MNIST('../data', train=False,
-                       transform=transform)
+    dataset1 = YtTensorDataset(yt_client=toolbox.yt_client, path=dataset_train_path, columns=['data', 'labels'])
+    dataset2 = YtTensorDataset(yt_client=toolbox.yt_client, path=dataset_test_path, columns=['data', 'labels'])
+
     train_loader = torch.utils.data.DataLoader(dataset1,**train_kwargs)
     test_loader = torch.utils.data.DataLoader(dataset2, **test_kwargs)

     model = Net().to(device)
     optimizer = optim.Adadelta(model.parameters(), lr=args.lr)

+    ts = TensorSerializer()
+    first_batch_index = 0
+    checkpoint = toolbox.checkpoint_manager.get_last_checkpoint()
+    if checkpoint is not None:
+        first_batch_index = checkpoint.metadata["first_batch_index"]
+        print(
+            "Found checkpoint with index",
+            checkpoint.index,
+            "and first batch index",
+            first_batch_index,
+            file=sys.stderr,
+        )
+
     scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma)
     for epoch in range(1, args.epochs + 1):
-        train(args, model, device, train_loader, optimizer, epoch)
+        train(args, model, device, train_loader, optimizer, epoch, first_batch_index, toolbox.checkpoint_manager)
         test(model, device, test_loader)
         scheduler.step()

     if args.save_model:
-        torch.save(model.state_dict(), "mnist_cnn.pt")
+        toolbox.save_model(ts.serialize(model.state_dict()), dataset_train_path, metadata={})


-if __name__ == '__main__':
-    main()
+run(
+    main,
+    backend=Tractorch(),
+    yt_path=training_dir,
+    mesh=Mesh(node_count=2, process_per_node=2, gpu_per_process=0),
+    resources=Resources(
+        cpu_limit=8,
+        memory_limit=105899345920,
+    ),
+    proxy_stderr_mode=StderrMode.primary,
+)
```
</details>

<font color="red">IMPORTANT NOTE:</font> In this example we are running tractorun directly from a Jupyter notebook.

This is a convenient method for experiments and demonstrations, as tractorun uses [pickle](https://docs.python.org/3/library/pickle.html) for easy serialization of the entire notebook state and transferring it to the cluster. This means that all variables will be available in the model training function, and tractorun will attempt to transfer all Python modules from the local environment to the cluster.

However, this method does not guarantee that a training run is reproducible. For production processes, run training via the tractorun CLI, which is described in the [basic notebook](./tractorun-torch-mnist.ipynb).

In [10]:
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.optim.lr_scheduler import StepLR

from tractorun.backend.tractorch import YtTensorDataset, Tractorch
from tractorun.toolbox import Toolbox
from tractorun.run import run
from tractorun.mesh import Mesh
from tractorun.resources import Resources
from tractorun.stderr_reader import StderrMode
from tractorun.backend.tractorch.serializer import TensorSerializer

class Net(nn.Module):
    """Small convolutional classifier: two conv layers followed by two linear layers.

    ``forward`` returns log-probabilities over 10 classes (``log_softmax``
    along dim 1).
    """

    def __init__(self):
        super().__init__()
        # Attribute names and creation order are preserved on purpose: they
        # determine the ``state_dict`` keys and the seeded initial weights.
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.dropout1 = nn.Dropout(0.25)
        self.dropout2 = nn.Dropout(0.5)
        # 9216 = 64 * 12 * 12 — presumably sized for 1x28x28 MNIST images.
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        """Map a batch of single-channel images to per-class log-probabilities."""
        # Convolutional feature extractor.
        features = F.relu(self.conv1(x))
        features = F.relu(self.conv2(features))
        features = self.dropout1(F.max_pool2d(features, 2))

        # Classifier head on the flattened feature maps.
        hidden = F.relu(self.fc1(torch.flatten(features, 1)))
        logits = self.fc2(self.dropout2(hidden))
        return F.log_softmax(logits, dim=1)


def train(args, model, device, train_loader, optimizer, epoch, first_batch_index, checkpoint_manager):
    """Run one pass over ``train_loader``, logging and checkpointing periodically.

    Args:
        args: parsed settings; ``log_interval`` and ``dry_run`` are read here.
        model: network to optimise; switched to training mode.
        device: device the batches are moved to.
        train_loader: iterable of ``(data, target)`` batches exposing ``dataset``.
        optimizer: optimizer stepped once per processed batch.
        epoch: epoch number, used only in the log line.
        first_batch_index: batches with a smaller index are skipped.
        checkpoint_manager: object whose ``save_checkpoint`` receives the
            serialized model/optimizer state plus a metadata dict.
    """
    model.train()
    serializer = TensorSerializer()
    for step, (inputs, labels) in enumerate(train_loader):
        # Skip batches that precede the requested starting index.
        if step < first_batch_index:
            continue

        inputs = inputs.to(device)
        labels = labels.to(device)
        optimizer.zero_grad()
        loss = F.nll_loss(model(inputs), labels)
        loss.backward()
        optimizer.step()

        # Logging and checkpointing happen only every `log_interval` batches.
        if step % args.log_interval != 0:
            continue

        percent_done = 100. * step / len(train_loader)
        message = 'Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
            epoch, step * len(inputs), len(train_loader.dataset),
            percent_done, loss.item())
        print(message, file=sys.stderr)
        if args.dry_run:
            # A dry run stops after the first logged batch, before checkpointing.
            break

        snapshot = {
            "model": model.state_dict(),
            "optimizer": optimizer.state_dict(),
        }
        # `first_batch_index` records the batch to start from on resume.
        checkpoint_info = {
            "first_batch_index": step + 1,
            "loss": loss.item(),
        }
        checkpoint_manager.save_checkpoint(serializer.serialize(snapshot), checkpoint_info)


def test(model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.nll_loss(output, target, reduction='sum').item()  # sum up batch loss
            pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)

    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)), file=sys.stderr)


def main(toolbox: Toolbox):
    """Train and evaluate the MNIST CNN, resuming from the last checkpoint if any.

    Settings come from the argparse defaults (an empty argv is parsed, since
    the function is not launched from a command line).

    Args:
        toolbox: tractorun toolbox; this function uses its ``yt_client``,
            ``checkpoint_manager`` and ``save_model``.
    """
    # Training settings
    parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
    parser.add_argument('--batch-size', type=int, default=64, metavar='N',
                        help='input batch size for training (default: 64)')
    parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N',
                        help='input batch size for testing (default: 1000)')
    parser.add_argument('--epochs', type=int, default=14, metavar='N',
                        help='number of epochs to train (default: 14)')
    parser.add_argument('--lr', type=float, default=1.0, metavar='LR',
                        help='learning rate (default: 1.0)')
    parser.add_argument('--gamma', type=float, default=0.7, metavar='M',
                        help='Learning rate step gamma (default: 0.7)')
    parser.add_argument('--no-cuda', action='store_true', default=False,
                        help='disables CUDA training')
    parser.add_argument('--no-mps', action='store_true', default=False,
                        help='disables macOS GPU training')
    parser.add_argument('--dry-run', action='store_true', default=False,
                        help='quickly check a single pass')
    parser.add_argument('--seed', type=int, default=1, metavar='S',
                        help='random seed (default: 1)')
    parser.add_argument('--log-interval', type=int, default=10, metavar='N',
                        help='how many batches to wait before logging training status')
    parser.add_argument('--save-model', action='store_true', default=False,
                        help='For Saving the current Model')
    # Parse an empty argv so that the defaults above are always used.
    args = parser.parse_args([])
    use_cuda = not args.no_cuda and torch.cuda.is_available()
    use_mps = not args.no_mps and torch.backends.mps.is_available()

    torch.manual_seed(args.seed)

    if use_cuda:
        device = torch.device("cuda")
    elif use_mps:
        device = torch.device("mps")
    else:
        device = torch.device("cpu")

    train_kwargs = {'batch_size': args.batch_size}
    test_kwargs = {'batch_size': args.test_batch_size}
    if use_cuda:
        cuda_kwargs = {'num_workers': 1,
                       'pin_memory': True,
                       'shuffle': False}
        train_kwargs.update(cuda_kwargs)
        test_kwargs.update(cuda_kwargs)

    # NOTE(review): `transform` is not referenced anywhere below.
    transform=transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
        ])
    dataset1 = YtTensorDataset(yt_client=toolbox.yt_client, path=dataset_train_path, columns=['data', 'labels'])
    dataset2 = YtTensorDataset(yt_client=toolbox.yt_client, path=dataset_test_path, columns=['data', 'labels'])

    train_loader = torch.utils.data.DataLoader(dataset1,**train_kwargs)
    test_loader = torch.utils.data.DataLoader(dataset2, **test_kwargs)

    model = Net().to(device)
    optimizer = optim.Adadelta(model.parameters(), lr=args.lr)

    ts = TensorSerializer()
    first_batch_index = 0
    checkpoint = toolbox.checkpoint_manager.get_last_checkpoint()
    if checkpoint is not None:
        first_batch_index = checkpoint.metadata["first_batch_index"]
        print(
            "Found checkpoint with index",
            checkpoint.index,
            "and first batch index",
            first_batch_index,
            file=sys.stderr,
        )
        # Restore through the serializer instance created above; the original
        # referenced an undefined name `serializer`, which raised NameError on
        # every resume.
        # NOTE(review): the method spelling `desirialize` is kept as written —
        # confirm it against the TensorSerializer API.
        checkpoint_dict = ts.desirialize(checkpoint.value)
        model.load_state_dict(checkpoint_dict["model"])
        optimizer.load_state_dict(checkpoint_dict["optimizer"])

    # NOTE(review): the checkpoint written by `train` stores neither the epoch
    # number nor the scheduler state, so a resumed run restarts at epoch 1.
    scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma)
    for epoch in range(1, args.epochs + 1):
        train(args, model, device, train_loader, optimizer, epoch, first_batch_index, toolbox.checkpoint_manager)
        # The restored offset applies only to the first epoch after a restart;
        # every later epoch must iterate over the whole dataset again.
        first_batch_index = 0
        test(model, device, test_loader)
        scheduler.step()

    if args.save_model:
        toolbox.save_model(ts.serialize(model.state_dict()), dataset_train_path, metadata={})


# `training_dir` was never defined in the notebook, so the call below raised
# NameError. Place it under this run's working directory and make sure it exists.
training_dir = f"{workdir}/training"
yt.create("map_node", training_dir, recursive=True, ignore_existing=True)

# Launch `main` through tractorun: 2 nodes x 1 process, each process with 1 GPU,
# scheduled in the `h100_8` pool trees discovered above.
run(
    main,
    backend=Tractorch(),
    yt_path=training_dir,
    mesh=Mesh(node_count=2, process_per_node=1, gpu_per_process=1, pool_trees=h100_8_pool_trees),
    resources=Resources(
        cpu_limit=4,
        memory_limit=32212254720,  # 30 GiB
    ),
    proxy_stderr_mode=StderrMode.primary,
)