# Tractorun

This notebook demonstrates the `Tractorun` library and CLI tool for running distributed machine learning tasks on Tracto. `Tractorun` integrates with the Tracto distributed data processing system, so you can run and manage machine learning training jobs on it.

Tractorun:
1. Manages the configuration and coordination of distributed training.
2. Provides tools for working with `YtDataset` (which lets you use data stored on Tracto as a dataset), checkpoints, model saving, interaction with `tensorproxy`, and more.
3. Ensures integration with the Tracto ecosystem.

In this notebook, we cover the following steps:

1. Uploading a PyTorch dataset to Tracto.
2. Training a model on MNIST. We train a model on the MNIST dataset directly from a Jupyter notebook, using Tracto as the compute platform.
3. Running the same training with the Tractorun CLI. We show how to run the same training job from the command line.

We use the official PyTorch [MNIST training example](https://github.com/pytorch/examples/blob/cdef4d43fb1a2c6c4349daa5080e4e8731c34569/mnist/main.py) as a reference and show how to modify it with minimal changes to run using `Tractorun`.

In [1]:
from yt import wrapper as yt
from yt import type_info

In [2]:
import uuid
import sys
import io

## Create a base directory for examples

In [4]:
# configure the environment to run this notebook
import uuid
import yt.wrapper as yt

username = yt.get_user_name()
if yt.exists(f"//sys/users/{username}/@user_info/home_path"):
    # prepare working directory on distributed file system
    user_info = yt.get(f"//sys/users/{username}/@user_info")
    homedir = user_info["home_path"]
    # find available pool trees for CPU and H100 GPU jobs
    cpu_pool_trees = [pool_tree for pool_tree in user_info["available_pool_trees"] if pool_tree.endswith("cpu")] or ["default"]
    h100_pool_trees = [pool_tree for pool_tree in user_info["available_pool_trees"] if pool_tree.endswith("h100")]
    h100_8_pool_trees = [pool_tree for pool_tree in user_info["available_pool_trees"] if pool_tree.endswith("h100-8")]
    workdir = f"{homedir}/tmp/demo_workdir/{uuid.uuid4().hex}"
else:
    cpu_pool_trees = ["default"]
    h100_pool_trees = ["gpu_h100"]
    h100_8_pool_trees = ["gpu_h100"]
    workdir = f"//tmp/examples/{uuid.uuid4().hex}"

yt.create("map_node", workdir, recursive=True, ignore_existing=True)
print("Current working directory:", workdir)

Current working directory: //home/equal_amethyst_vulture/tmp/demo_workdir/594e35db3a434a748e844f4d22c07e6c
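The pool tree selection above relies only on name suffixes. A minimal sketch that isolates this logic in a pure function, so it can be checked without a cluster (the pool tree names below are hypothetical examples, not real Tracto pool trees):

```python
def split_pool_trees(available):
    """Group pool tree names by suffix, as the cell above does."""
    cpu = [tree for tree in available if tree.endswith("cpu")] or ["default"]
    h100 = [tree for tree in available if tree.endswith("h100")]
    h100_8 = [tree for tree in available if tree.endswith("h100-8")]
    return cpu, h100, h100_8


# Hypothetical pool tree names, for illustration only.
cpu, h100, h100_8 = split_pool_trees(["main_cpu", "gpu_h100", "gpu_h100-8"])
print(cpu, h100, h100_8)  # ['main_cpu'] ['gpu_h100'] ['gpu_h100-8']
```

Because `"gpu_h100-8".endswith("h100")` is false, the single-GPU and 8-GPU lists do not overlap; if no CPU pool tree matches, the code falls back to `["default"]`.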


## Ensure torch and torchvision exist

Make sure that `torch` and `torchvision` are installed.

In [6]:
import torch
import torchvision 
from torchvision import datasets, transforms

## Upload dataset to Tracto

For this demonstration, we use the MNIST dataset from the `torchvision` library and upload it to Tracto. Some rows in the dataset exceed the default row weight limit, so we raise it with `table_writer={"max_row_weight": 50 * 1024 * 1024}`.
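`50 * 1024 * 1024` is 50 MiB, that is 52,428,800 bytes. To spot rows that would hit a limit before calling `yt.write_table`, you can estimate row sizes locally. This is a rough sketch: `approx_row_weight` is a hypothetical helper that simply sums value sizes, which only approximates how YT computes row weight.

```python
MAX_ROW_WEIGHT = 50 * 1024 * 1024  # 52,428,800 bytes, the value passed via table_writer


def approx_row_weight(row):
    """Rough row size in bytes (an approximation, not YT's exact formula)."""
    total = 0
    for value in row.values():
        if isinstance(value, (bytes, bytearray)):
            total += len(value)
        elif isinstance(value, str):
            total += len(value.encode("utf-8"))
        else:
            total += 8  # assume a fixed-width scalar such as an integer
    return total


def oversized_rows(rows, limit=MAX_ROW_WEIGHT):
    """Return the indices of rows whose estimated weight exceeds `limit`."""
    return [i for i, row in enumerate(rows) if approx_row_weight(row) > limit]


rows = [{"number": 3, "data": b"x" * 10}, {"number": 7, "data": b"x" * 200}]
print(oversized_rows(rows, limit=100))  # [1]
```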

In [8]:
# https://github.com/pytorch/examples/blob/26de41904319c7094afc53a3ee809de47112d387/mnist/main.py#L119
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ],
)

dataset_train_local = datasets.MNIST("./mnist", train=True, download=True)
dataset_test_local = datasets.MNIST("./mnist", train=False, download=True)

100%|██████████| 9.91M/9.91M [00:00<00:00, 10.1MB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 304kB/s]
100%|██████████| 1.65M/1.65M [00:00<00:00, 2.40MB/s]
100%|██████████| 4.54k/4.54k [00:00<00:00, 8.62MB/s]

Let's upload the MNIST dataset to Tracto, both as tensors and as simple types. The table has four columns:
* `image` - the raw PNG image. The column is tagged `image/png`, which lets the Tracto UI render the images directly.
* `number` - the human-readable label.
* `data` and `labels` - the sample's image and label, serialized as tensors.

Storing ready-to-use tensors in YT up front saves time and resources during training: the images are already transformed, so they don't have to be decoded and normalized again on every epoch. In the following examples, we work only with the tensor columns.
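The `data` and `labels` columns store plain bytes, so a serializer only has to turn an object into bytes and back. A minimal sketch of that round trip, using the standard `pickle` module as a hypothetical stand-in for `TensorSerializer` (whose actual format may differ) and a nested list in place of a tensor:

```python
import pickle


class PickleSerializer:
    """Hypothetical stand-in for TensorSerializer: object <-> bytes."""

    def serialize(self, obj):
        return pickle.dumps(obj)

    def deserialize(self, data):
        return pickle.loads(data)


ser = PickleSerializer()
# One row shaped like the table above, minus the PNG image column.
row = {
    "number": 5,
    "labels": ser.serialize(5),
    "data": ser.serialize([[0.0, 1.0], [1.0, 0.0]]),  # stands in for an image tensor
}
print(type(row["data"]).__name__, ser.deserialize(row["data"]))  # bytes [[0.0, 1.0], [1.0, 0.0]]
```

Because the transform is applied before serialization, reading a sample back only requires deserializing it.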

In [10]:
schema = yt.schema.TableSchema()
schema.add_column("image", type_info.Tagged[type_info.String, "image/png"])
schema.add_column("number", type_info.Int8)
schema.add_column("data", type_info.String)
schema.add_column("labels", type_info.String)

TableSchema({'value': [{'name': 'image', 'type_v3': {'type_name': 'tagged', 'item': 'string', 'tag': 'image/png'}}, {'name': 'number', 'type_v3': 'int8'}, {'name': 'data', 'type_v3': 'string'}, {'name': 'labels', 'type_v3': 'string'}], 'attributes': {'strict': True, 'unique_keys': False}})

In [11]:
from tractorun.backend.tractorch import TensorSerializer

dataset_train_path = f"{workdir}/dataset_train"
dataset_test_path = f"{workdir}/dataset_test"
print(dataset_train_path)
print(dataset_test_path)

yt.create("table", dataset_train_path, force=True, attributes={"schema": schema.to_yson_type()})
yt.create("table", dataset_test_path, force=True, attributes={"schema": schema.to_yson_type()})

def pil_to_png(image):
    r = io.BytesIO()
    image.save(r, format="PNG")
    return r.getvalue()

ts = TensorSerializer()

yt_train_data = [
    {
        "image": pil_to_png(data),
        "number": labels,
        "labels": ts.serialize(labels),
        "data": ts.serialize(transform(data)),
    }
    for data, labels in dataset_train_local
]
yt.write_table(dataset_train_path, yt_train_data, table_writer={"max_row_weight": 50 * 1024 * 1024})

yt_test_data = [
    {
        "image": pil_to_png(data),
        "number": labels,
        "labels": ts.serialize(labels),
        "data": ts.serialize(transform(data)),
    }
    for data, labels in dataset_test_local
]
yt.write_table(dataset_test_path, yt_test_data, table_writer={"max_row_weight": 50 * 1024 * 1024})

//home/equal_amethyst_vulture/tmp/demo_workdir/594e35db3a434a748e844f4d22c07e6c/dataset_train
//home/equal_amethyst_vulture/tmp/demo_workdir/594e35db3a434a748e844f4d22c07e6c/dataset_test


## Run training

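Tractorun stores the following data in the training directory:
1. Checkpoints.
2. Metadata about each training run.
3. Models.
4. Locks.
5. Other internal data.

Let's create the training directory. `force=True` recreates it if it already exists, which cleans up data left over from previous runs.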

In [13]:
training_dir = f"{workdir}/tractorun"
yt.create("map_node", training_dir, force=True)

print(training_dir)

//home/equal_amethyst_vulture/tmp/demo_workdir/594e35db3a434a748e844f4d22c07e6c/tractorun


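The model training process runs in a Docker container. When launching from a Jupyter notebook, make sure the job uses the same Docker image as the notebook kernel.

We use the official PyTorch [MNIST training example](https://github.com/pytorch/examples/blob/cdef4d43fb1a2c6c4349daa5080e4e8731c34569/mnist/main.py) as a reference. To run it with Tractorun, it needs only a few changes:
1. Add a `toolbox: Toolbox` parameter to the `main` function. The `Toolbox` object provides training utilities such as a checkpoint manager, coordination metadata, an initialized YTsaurus client, and more.
2. Add `file=sys.stderr` to each `print` call, so that the log lines are forwarded to the notebook (see `proxy_stderr_mode` in the `run` call).
3. Call `parser.parse_args([])` instead of `parser.parse_args()`, so that `argparse` ignores the command-line arguments of the process running `main` and uses the defaults.
4. Use `YtTensorDataset` instead of the local `datasets.MNIST` dataset.
5. Save the model with `toolbox.save_model` instead of `torch.save`.
6. Call `tractorun.run.run` to launch `main` on the cluster.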
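Why `parse_args([])`? Called without arguments, `argparse` parses `sys.argv`. In a Jupyter kernel that list holds the kernel's own arguments (typically `-f <connection file>`), which the MNIST parser does not recognize. A minimal, stdlib-only illustration with a trimmed-down version of the parser (`kernel.json` is a placeholder file name):

```python
import argparse

parser = argparse.ArgumentParser(description="PyTorch MNIST Example")
parser.add_argument("--batch-size", type=int, default=64)
parser.add_argument("--epochs", type=int, default=14)

# An explicit empty list ignores sys.argv, so every option keeps its default.
args = parser.parse_args([])
print(args.batch_size, args.epochs)  # 64 14

# Kernel-style arguments are rejected: argparse prints an
# "unrecognized arguments" error and raises SystemExit.
exit_code = None
try:
    parser.parse_args(["-f", "kernel.json"])
except SystemExit as exc:
    exit_code = exc.code
print(exit_code)  # 2
```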

<details>
  <summary>Show the full diff</summary>

```diff
@@ -6,6 +6,13 @@
 from torchvision import datasets, transforms
 from torch.optim.lr_scheduler import StepLR

+from tractorun.backend.tractorch import YtTensorDataset, Tractorch
+from tractorun.toolbox import Toolbox
+from tractorun.run import run
+from tractorun.mesh import Mesh
+from tractorun.resources import Resources
+from tractorun.stderr_reader import StderrMode
+from tractorun.backend.tractorch.serializer import TensorSerializer

 class Net(nn.Module):
     def __init__(self):
@@ -45,7 +52,7 @@
         if batch_idx % args.log_interval == 0:
             print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                 epoch, batch_idx * len(data), len(train_loader.dataset),
-                100. * batch_idx / len(train_loader), loss.item()))
+                100. * batch_idx / len(train_loader), loss.item()), file=sys.stderr)
             if args.dry_run:
                 break

@@ -66,10 +73,10 @@

     print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
         test_loss, correct, len(test_loader.dataset),
-        100. * correct / len(test_loader.dataset)))
+        100. * correct / len(test_loader.dataset)), file=sys.stderr)


-def main():
+def main(toolbox: Toolbox):
     # Training settings
     parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
     parser.add_argument('--batch-size', type=int, default=64, metavar='N',
@@ -94,7 +101,7 @@
                         help='how many batches to wait before logging training status')
     parser.add_argument('--save-model', action='store_true', default=False,
                         help='For Saving the current Model')
-    args = parser.parse_args()
+    args = parser.parse_args([])
     use_cuda = not args.no_cuda and torch.cuda.is_available()
     use_mps = not args.no_mps and torch.backends.mps.is_available()

@@ -120,10 +127,9 @@
         transforms.ToTensor(),
         transforms.Normalize((0.1307,), (0.3081,))
         ])
-    dataset1 = datasets.MNIST('../data', train=True, download=True,
-                       transform=transform)
-    dataset2 = datasets.MNIST('../data', train=False,
-                       transform=transform)
+    dataset1 = YtTensorDataset(path=dataset_train_path, yt_client=toolbox.yt_client, columns=['data', 'labels'])
+    dataset2 = YtTensorDataset(path=dataset_test_path, yt_client=toolbox.yt_client, columns=['data', 'labels'])
+
     train_loader = torch.utils.data.DataLoader(dataset1,**train_kwargs)
     test_loader = torch.utils.data.DataLoader(dataset2, **test_kwargs)

@@ -137,9 +143,20 @@
         scheduler.step()

     if args.save_model:
-        torch.save(model.state_dict(), "mnist_cnn.pt")
+        ts = TensorSerializer()
+        toolbox.save_model(ts.serialize(model.state_dict()), dataset_train_path, metadata={})


-if __name__ == '__main__':
-    main()
+run(
+    main,
+    backend=Tractorch(),
+    yt_path=training_dir,
+    mesh=Mesh(node_count=1, process_per_node=1, gpu_per_process=1, pool_trees=h100_pool_trees),
+    resources=Resources(
+        cpu_limit=10,
+        memory_limit=32212254720,
+    ),
+    proxy_stderr_mode=StderrMode.primary,
+)
```
</details>

<font color="red">IMPORTANT NOTE</font> In this example, we run tractorun directly from a Jupyter notebook.

This is convenient for experiments and demonstrations: tractorun uses [pickle](https://docs.python.org/3/library/pickle.html) to serialize the entire notebook state and transfer it to the cluster. As a result, all notebook variables are available in the model training function, and tractorun attempts to transfer all Python modules from the local environment to the cluster.

However, this approach does not make training runs reproducible. For production workloads, run training with the Tractorun CLI, as described below.

In [16]:
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.optim.lr_scheduler import StepLR

from tractorun.backend.tractorch import YtTensorDataset, Tractorch
from tractorun.toolbox import Toolbox
from tractorun.run import run
from tractorun.mesh import Mesh
from tractorun.resources import Resources
from tractorun.stderr_reader import StderrMode
from tractorun.backend.tractorch.serializer import TensorSerializer


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.dropout1 = nn.Dropout(0.25)
        self.dropout2 = nn.Dropout(0.5)
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.dropout2(x)
        x = self.fc2(x)
        output = F.log_softmax(x, dim=1)
        return output


def train(args, model, device, train_loader, optimizer, epoch):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % args.log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()), file=sys.stderr)
            if args.dry_run:
                break


def test(model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.nll_loss(output, target, reduction='sum').item()  # sum up batch loss
            pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)

    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)), file=sys.stderr)


def main(toolbox: Toolbox):
    # Training settings
    parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
    parser.add_argument('--batch-size', type=int, default=64, metavar='N',
                        help='input batch size for training (default: 64)')
    parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N',
                        help='input batch size for testing (default: 1000)')
    parser.add_argument('--epochs', type=int, default=14, metavar='N',
                        help='number of epochs to train (default: 14)')
    parser.add_argument('--lr', type=float, default=1.0, metavar='LR',
                        help='learning rate (default: 1.0)')
    parser.add_argument('--gamma', type=float, default=0.7, metavar='M',
                        help='Learning rate step gamma (default: 0.7)')
    parser.add_argument('--no-cuda', action='store_true', default=False,
                        help='disables CUDA training')
    parser.add_argument('--no-mps', action='store_true', default=False,
                        help='disables macOS GPU training')
    parser.add_argument('--dry-run', action='store_true', default=False,
                        help='quickly check a single pass')
    parser.add_argument('--seed', type=int, default=1, metavar='S',
                        help='random seed (default: 1)')
    parser.add_argument('--log-interval', type=int, default=10, metavar='N',
                        help='how many batches to wait before logging training status')
    parser.add_argument('--save-model', action='store_true', default=False,
                        help='For Saving the current Model')
    args = parser.parse_args([])
    use_cuda = not args.no_cuda and torch.cuda.is_available()
    use_mps = not args.no_mps and torch.backends.mps.is_available()

    torch.manual_seed(args.seed)

    if use_cuda:
        device = torch.device("cuda")
    elif use_mps:
        device = torch.device("mps")
    else:
        device = torch.device("cpu")

    train_kwargs = {'batch_size': args.batch_size}
    test_kwargs = {'batch_size': args.test_batch_size}
    if use_cuda:
        cuda_kwargs = {'num_workers': 1,
                       'pin_memory': True,
                       'shuffle': False}
        train_kwargs.update(cuda_kwargs)
        test_kwargs.update(cuda_kwargs)

    transform=transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
        ])
    dataset1 = YtTensorDataset(path=dataset_train_path, yt_client=toolbox.yt_client, columns=['data', 'labels'])
    dataset2 = YtTensorDataset(path=dataset_test_path, yt_client=toolbox.yt_client, columns=['data', 'labels'])

    train_loader = torch.utils.data.DataLoader(dataset1,**train_kwargs)
    test_loader = torch.utils.data.DataLoader(dataset2, **test_kwargs)

    model = Net().to(device)
    optimizer = optim.Adadelta(model.parameters(), lr=args.lr)

    scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma)
    for epoch in range(1, args.epochs + 1):
        train(args, model, device, train_loader, optimizer, epoch)
        test(model, device, test_loader)
        scheduler.step()

    if args.save_model:
        ts = TensorSerializer()
        toolbox.save_model(ts.serialize(model.state_dict()), dataset_train_path, metadata={})


run(
    main,
    backend=Tractorch(),
    yt_path=training_dir,
    mesh=Mesh(node_count=1, process_per_node=1, gpu_per_process=1, pool_trees=h100_pool_trees),
    resources=Resources(
        cpu_limit=10,
        memory_limit=32212254720,
    ),
    proxy_stderr_mode=StderrMode.primary,
)





2025-06-19 20:44:27,225	INFO	Operation started: https://playground.tracto.ai/playground/operations/833c9385-39d5f04a-24dd03e8-b3861697/details


2025-06-19 20:44:27,268	INFO	( 0 min) operation 833c9385-39d5f04a-24dd03e8-b3861697 initializing


2025-06-19 20:44:30,066	INFO	( 0 min) Unrecognized spec: {'enable_partitioned_data_balancing': false}


2025-06-19 20:44:30,118	INFO	( 0 min) operation 833c9385-39d5f04a-24dd03e8-b3861697: running=0     completed=0     pending=1     failed=0     aborted=0     lost=0     total=1     blocked=0    


2025-06-19 20:44:31,292	INFO	( 0 min) operation 833c9385-39d5f04a-24dd03e8-b3861697: running=1     completed=0     pending=0     failed=0     aborted=0     lost=0     total=1     blocked=0    


  __tar.extractall(destination)


  __tar.extractall(destination)
Failed to write user statistics
[rank0]:[W619 20:44:34.999815987 ProcessGroupNCCL.cpp:4561] [PG ID 0 PG GUID 0 Rank 0]  using GPU 0 to perform barrier as devices used by this process are currently unknown. This can potentially cause a hang if this rank to GPU mapping is incorrect. Specify device_ids in barrier() to force use of a particular device, or call init_process_group() with a device_id.

Test set: Average loss: 0.0531, Accuracy: 9821/10000 (98%)

Test set: Average loss: 0.0368, Accuracy: 9874/10000 (99%)

Test set: Average loss: 0.0354, Accuracy: 9879/10000 (99%)

Test set: Average loss: 0.0360, Accuracy: 9883/10000 (99%)

Test set: Average loss: 0.0316, Accuracy: 9898/10000 (99%)

Test set: Average loss: 0.0318, Accuracy: 9898/10000 (99%)

Test set: Average loss: 0.0304, Accuracy: 9907/10000 (99%)

Test set: Average loss: 0.0304, Accuracy: 9909/10000 (99%)

Test set: Average loss: 0.0303, Accuracy: 9905/10000 (99%)

Test set: Average loss: 0.0295, Accuracy: 9908/10000 (99%)

Test set: Average loss: 0.0281, Accuracy: 9915/10000 (99%)

Test set: Average loss: 0.0284, Accuracy: 9911/10000 (99%)

Test set: Average loss: 0.0282, Accuracy: 9916/10000 (99%)

Test set: Average loss: 0.0286, Accuracy: 9917/10000 (99%)

2025-06-19 20:48:06,908	INFO	( 3 min) operation 833c9385-39d5f04a-24dd03e8-b3861697 completed


RunInfo(operation_spec={'intermediate_data_medium': 'nvme', 'intermediate_data_account': 'equal_amethyst_vulture', 'started_by': {'hostname': 'end-end-4.exec-nodes-end.nebius-playground.svc.kyt.k8s.nebius.yt', 'pid': 618, 'wrapper_version': '0.13.28', 'python_version': '3.12.10', 'binary_name': 'ipykernel_launcher.py', 'command': ['/slot/sandbox/jlab/lib/python42/site-packages/ipykernel_launcher.py', '-f', '/slot/sandbox/.local/share/jupyter/runtime/kernel-2a501a16-d03c-4234-9f78-97ef00f8ebd3.json'], 'user': 'root', 'platform': 'Debian GNU/Linux 12 (bookworm)'}, 'fail_on_job_restart': True, 'is_gang': True, 'annotations': {'is_tractorun': True}, 'tasks': {'task': {'command': 'python3 _py_runner.py wrapped.pickle config_dump _modules_info modules/_main_module.py _main_module PY_SOURCE', 'job_count': 1, 'gpu_limit': 1, 'port_count': 1, 'cpu_limit': 10, 'memory_limit': 32213319843, 'docker_image': 'cr.eu-north1.nebius.cloud/e00faee7vas5hpsh3s/solutions/examples:v5', 'file_paths': [{'value

## Tractorun CLI

Let's consider a production-like scenario: running model training through the `Tractorun CLI`.

The `Tractorun CLI` lets you:
1. Make model training reproducible.
2. Separate the model training code from the run parameters: the training process is configured with a configuration file and CLI parameters.
3. Run the training module from any host where Python and `Tractorun` are installed.

We will use the official PyTorch [MNIST training example](https://github.com/pytorch/examples/blob/cdef4d43fb1a2c6c4349daa5080e4e8731c34569/mnist/main.py) again.
Here is how to modify it, with minimal changes, to run with Tractorun:
1. Call `toolbox = prepare_and_get_toolbox(backend=Tractorch())` at the start of `main`. The `Toolbox` object provides training utilities such as a checkpoint manager, coordination metadata, an initialized YTsaurus client, and more. Read run-specific settings with `toolbox.get_user_config()`, which returns the `user_config` section of the run configuration.
2. Add `file=sys.stderr` to each `print` call.
3. Call `parser.parse_args([])` so that the script always uses the default hyperparameters.
4. Use `YtTensorDataset` with the dataset paths from `user_config` instead of the local `datasets.MNIST` dataset.
5. Save the model with `toolbox.save_model` instead of `torch.save`.
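The notebook run could use notebook globals such as `dataset_train_path` because the whole notebook state is pickled and sent to the cluster. `main.py`, on the other hand, runs as a standalone script: it only sees names it defines itself, so values like dataset paths have to come from the run configuration. A minimal sketch that imitates this with `exec` on a fresh namespace (`run_isolated` and the path are hypothetical):

```python
def run_isolated(source, **injected):
    """Run `source` in a fresh namespace that contains only `injected` names."""
    namespace = dict(injected)
    exec(source, namespace)
    return namespace


# Referencing a notebook-only variable fails in a standalone script.
error = None
try:
    run_isolated("train_path = dataset_train_path")
except NameError as exc:
    error = str(exc)
print(error)

# Passing the value in explicitly, the way user_config does, works.
ns = run_isolated(
    'train_path = user_config["dataset_train_path"]',
    user_config={"dataset_train_path": "//tmp/examples/dataset_train"},
)
print(ns["train_path"])  # //tmp/examples/dataset_train
```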

<details>
  <summary>Show the full diff</summary>

```diff
@@ -6,6 +6,15 @@
 from torchvision import datasets, transforms
 from torch.optim.lr_scheduler import StepLR

+import sys
+from tractorun.backend.tractorch import YtTensorDataset, Tractorch
+from tractorun.toolbox import Toolbox
+from tractorun.run import run
+from tractorun.mesh import Mesh
+from tractorun.resources import Resources
+from tractorun.stderr_reader import StderrMode
+from tractorun.backend.tractorch.serializer import TensorSerializer
+from tractorun.run import prepare_and_get_toolbox

 class Net(nn.Module):
     def __init__(self):
@@ -45,7 +54,7 @@
         if batch_idx % args.log_interval == 0:
             print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                 epoch, batch_idx * len(data), len(train_loader.dataset),
-                100. * batch_idx / len(train_loader), loss.item()))
+                100. * batch_idx / len(train_loader), loss.item()), file=sys.stderr)
             if args.dry_run:
                 break

@@ -66,10 +75,13 @@

     print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
         test_loss, correct, len(test_loader.dataset),
-        100. * correct / len(test_loader.dataset)))
+        100. * correct / len(test_loader.dataset)), file=sys.stderr)


 def main():
+    toolbox = prepare_and_get_toolbox(backend=Tractorch())
+    user_config = toolbox.get_user_config()
+
     # Training settings
     parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
     parser.add_argument('--batch-size', type=int, default=64, metavar='N',
@@ -94,7 +106,7 @@
                         help='how many batches to wait before logging training status')
     parser.add_argument('--save-model', action='store_true', default=False,
                         help='For Saving the current Model')
-    args = parser.parse_args()
+    args = parser.parse_args([])
     use_cuda = not args.no_cuda and torch.cuda.is_available()
     use_mps = not args.no_mps and torch.backends.mps.is_available()

@@ -120,10 +132,9 @@
         transforms.ToTensor(),
         transforms.Normalize((0.1307,), (0.3081,))
         ])
-    dataset1 = datasets.MNIST('../data', train=True, download=True,
-                       transform=transform)
-    dataset2 = datasets.MNIST('../data', train=False,
-                       transform=transform)
+    dataset1 = YtTensorDataset(toolbox=toolbox, yt_client=toolbox.yt_client, path=user_config["dataset_train_path"], columns=['data', 'labels'])
+    dataset2 = YtTensorDataset(toolbox=toolbox, yt_client=toolbox.yt_client, path=user_config["dataset_test_path"], columns=['data', 'labels'])
+
     train_loader = torch.utils.data.DataLoader(dataset1,**train_kwargs)
     test_loader = torch.utils.data.DataLoader(dataset2, **test_kwargs)

@@ -137,9 +148,9 @@
         scheduler.step()

     if args.save_model:
-        torch.save(model.state_dict(), "mnist_cnn.pt")
+        ts = TensorSerializer()
+        toolbox.save_model(ts.serialize(model.state_dict()), user_config["dataset_train_path"], metadata={})


-if __name__ == '__main__':
+if __name__ == "__main__":
     main()
```
</details>

Let's create two files:
1. `main.py` with our model-training code.
2. `run_config.yaml` with training configuration.

In [19]:
code = r"""
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.optim.lr_scheduler import StepLR

import sys
from tractorun.backend.tractorch import YtTensorDataset, Tractorch
from tractorun.toolbox import Toolbox
from tractorun.run import run
from tractorun.mesh import Mesh
from tractorun.resources import Resources
from tractorun.stderr_reader import StderrMode
from tractorun.backend.tractorch.serializer import TensorSerializer
from tractorun.run import prepare_and_get_toolbox

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.dropout1 = nn.Dropout(0.25)
        self.dropout2 = nn.Dropout(0.5)
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.dropout2(x)
        x = self.fc2(x)
        output = F.log_softmax(x, dim=1)
        return output


def train(args, model, device, train_loader, optimizer, epoch):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % args.log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()), file=sys.stderr)
            if args.dry_run:
                break


def test(model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.nll_loss(output, target, reduction='sum').item()  # sum up batch loss
            pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)

    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)), file=sys.stderr)


def main():
    toolbox = prepare_and_get_toolbox(backend=Tractorch())
    user_config = toolbox.get_user_config()

    # Training settings
    parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
    parser.add_argument('--batch-size', type=int, default=64, metavar='N',
                        help='input batch size for training (default: 64)')
    parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N',
                        help='input batch size for testing (default: 1000)')
    parser.add_argument('--epochs', type=int, default=14, metavar='N',
                        help='number of epochs to train (default: 14)')
    parser.add_argument('--lr', type=float, default=1.0, metavar='LR',
                        help='learning rate (default: 1.0)')
    parser.add_argument('--gamma', type=float, default=0.7, metavar='M',
                        help='Learning rate step gamma (default: 0.7)')
    parser.add_argument('--no-cuda', action='store_true', default=False,
                        help='disables CUDA training')
    parser.add_argument('--no-mps', action='store_true', default=False,
                        help='disables macOS GPU training')
    parser.add_argument('--dry-run', action='store_true', default=False,
                        help='quickly check a single pass')
    parser.add_argument('--seed', type=int, default=1, metavar='S',
                        help='random seed (default: 1)')
    parser.add_argument('--log-interval', type=int, default=10, metavar='N',
                        help='how many batches to wait before logging training status')
    parser.add_argument('--save-model', action='store_true', default=False,
                        help='For Saving the current Model')
    args = parser.parse_args([])
    use_cuda = not args.no_cuda and torch.cuda.is_available()
    use_mps = not args.no_mps and torch.backends.mps.is_available()

    torch.manual_seed(args.seed)

    if use_cuda:
        device = torch.device("cuda")
    elif use_mps:
        device = torch.device("mps")
    else:
        device = torch.device("cpu")

    train_kwargs = {'batch_size': args.batch_size}
    test_kwargs = {'batch_size': args.test_batch_size}
    if use_cuda:
        cuda_kwargs = {'num_workers': 1,
                       'pin_memory': True,
                       'shuffle': True}
        train_kwargs.update(cuda_kwargs)
        test_kwargs.update(cuda_kwargs)

    transform=transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
        ])
    dataset1 = YtTensorDataset(toolbox=toolbox, yt_client=toolbox.yt_client, path=user_config["dataset_train_path"], columns=['data', 'labels'])
    dataset2 = YtTensorDataset(toolbox=toolbox, yt_client=toolbox.yt_client, path=user_config["dataset_test_path"], columns=['data', 'labels'])

    train_loader = torch.utils.data.DataLoader(dataset1,**train_kwargs)
    test_loader = torch.utils.data.DataLoader(dataset2, **test_kwargs)

    model = Net().to(device)
    optimizer = optim.Adadelta(model.parameters(), lr=args.lr)

    scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma)
    for epoch in range(1, args.epochs + 1):
        train(args, model, device, train_loader, optimizer, epoch)
        test(model, device, test_loader)
        scheduler.step()

    if args.save_model:
        ts = TensorSerializer()
        toolbox.save_model(ts.serialize(model.state_dict()), user_config["dataset_train_path"], metadata={})


if __name__ == "__main__":
    main()
"""

with open("main.py", "w") as f:
    f.write(code)

In [20]:
import yaml


config = {
    "yt_path": training_dir,
    "mesh": {
        "node_count": 1,
        "process_per_node": 1,
        "gpu_per_process": 0,
        "pool_trees": list(map(str, h100_pool_trees)),
    },  
    "user_config": {
        "dataset_train_path": dataset_train_path,
        "dataset_test_path": dataset_test_path,
    },
    "resources": {
        "cpu_limit": 10,
        "memory_limit": 32212254720,
    },
    "bind_local": [
        {
            "source": "./main.py",
            "destination": "/tractorun/main.py",
        },
    ],
    "command": ["python3", "/tractorun/main.py"],
    "proxy_stderr_mode": "primary",
}

with open("run_config.yaml", "w") as f:
    yaml.dump(config, f)
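
A quick sanity check of what the configuration above requests. This sketch assumes, as the field names suggest, that `gpu_per_process` applies to every process on every node; `summarize_run` is a hypothetical helper, not part of Tractorun:

```python
GIB = 1024 ** 3


def summarize_run(config):
    """Derive totals from the `mesh` and `resources` sections of a run config."""
    mesh = config["mesh"]
    processes = mesh["node_count"] * mesh["process_per_node"]
    return {
        "processes": processes,
        "gpus": processes * mesh["gpu_per_process"],
        "memory_gib": config["resources"]["memory_limit"] / GIB,
    }


example = {
    "mesh": {"node_count": 1, "process_per_node": 1, "gpu_per_process": 0},
    "resources": {"cpu_limit": 10, "memory_limit": 32212254720},
}
print(summarize_run(example))  # {'processes': 1, 'gpus': 0, 'memory_gib': 30.0}
```

So `memory_limit=32212254720` is exactly 30 GiB, and with `gpu_per_process` set to 0 this run requests no GPUs even though it targets the H100 pool trees (the notebook run above used `gpu_per_process=1`).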

In [21]:
!tractorun --run-config-path run_config.yaml

(1/2) [OK] __file_0.zip:  33%|############                         | 6.39k/19.6k

(2/2) [OK] __bootstrap_config: 100%|###############################| 19.6k/19.6k


2025-06-19 20:51:09,904	INFO	Operation started: https://playground.tracto.ai/playground/operations/af3941d-fca6ac76-24dd03e8-9b60f6d2/details


2025-06-19 20:51:09,944	INFO	( 0 min) operation af3941d-fca6ac76-24dd03e8-9b60f6d2 starting


2025-06-19 20:51:10,480	INFO	( 0 min) operation af3941d-fca6ac76-24dd03e8-9b60f6d2 initializing


2025-06-19 20:51:12,694	INFO	( 0 min) Unrecognized spec: {'enable_partitioned_data_balancing': false}


2025-06-19 20:51:12,734	INFO	( 0 min) operation af3941d-fca6ac76-24dd03e8-9b60f6d2: running=0     completed=0     pending=1     failed=0     aborted=0     lost=0     total=1     blocked=0    


2025-06-19 20:51:13,893	INFO	( 0 min) operation af3941d-fca6ac76-24dd03e8-9b60f6d2: running=1     completed=0     pending=0     failed=0     aborted=0     lost=0     total=1     blocked=0    

Test set: Average loss: 0.0464, Accuracy: 9843/10000 (98%)

Test set: Average loss: 0.0366, Accuracy: 9869/10000 (99%)

Test set: Average loss: 0.0353, Accuracy: 9874/10000 (99%)

Test set: Average loss: 0.0313, Accuracy: 9895/10000 (99%)

Test set: Average loss: 0.0298, Accuracy: 9902/10000 (99%)

Test set: Average loss: 0.0325, Accuracy: 9892/10000 (99%)

Test set: Average loss: 0.0292, Accuracy: 9903/10000 (99%)

Test set: Average loss: 0.0285, Accuracy: 9905/10000 (99%)

Test set: Average loss: 0.0277, Accuracy: 9915/10000 (99%)

Test set: Average loss: 0.0276, Accuracy: 9912/10000 (99%)

Test set: Average loss: 0.0279, Accuracy: 9909/10000 (99%)

Test set: Average loss: 0.0269, Accuracy: 9912/10000 (99%)

Test set: Average loss: 0.0269, Accuracy: 9913/10000 (99%)

Test set: Average loss: 0.0267, Accuracy: 9915/10000 (99%)

2025-06-19 20:57:06,591	INFO	( 5 min) operation af3941d-fca6ac76-24dd03e8-9b60f6d2 completed
