# Train with PyTorch Lightning

In [None]:
from azureml.core import Workspace

# Load the Azure ML workspace this notebook runs against.
# NOTE(review): assumes a workspace config file is discoverable from the
# current directory, as Workspace.from_config() requires — confirm locally.
ws = Workspace.from_config()
# Bare trailing expression: renders the workspace in the notebook output.
ws

In [None]:
from pathlib import Path

import git

# Repository root, found by walking up from the current directory until a
# git working tree is located.
prefix = Path(git.Repo(".", search_parent_directories=True).working_tree_dir)

# Training code: the directory uploaded as the job's source, and the entry script.
source_dir = prefix / "code" / "models" / "pytorch-lightning" / "mnist-autoencoder"
script_name = "train.py"

# Conda specification used to build the Azure ML environment.
environment_file = prefix / "environments" / "pt-lightning.yml"

# Azure ML resource names: environment, experiment, and compute cluster.
environment_name = "pt-lightning"
experiment_name = "pt-lightning-example"
cluster_name = "gpu-k80-2"

In [None]:
# Print the training script for reference. Path.read_text() opens and closes
# the file itself, whereas a bare open(...).read() leaves the handle open until
# garbage collection. UTF-8 is explicit so decoding does not depend on the
# platform's locale encoding.
print(source_dir.joinpath(script_name).read_text(encoding="utf-8"))

## Create environment

In [None]:
from azureml.core import Environment

# Build an Azure ML environment named `environment_name` from the conda
# specification file chosen in the configuration cell.
env = Environment.from_conda_specification(environment_name, environment_file)

# Run the job inside Docker on a GPU base image (CUDA 10.2 / cuDNN 8, per the
# image tag) so training can use the cluster's GPUs.
# NOTE(review): newer azureml-core releases deprecate `DockerSection.enabled`
# in favour of passing `DockerConfiguration(use_docker=True)` to
# ScriptRunConfig — confirm against the installed SDK version.
env.docker.enabled = True
env.docker.base_image = (
    "mcr.microsoft.com/azureml/openmpi3.1.2-cuda10.2-cudnn8-ubuntu18.04"
)

## Configure and run training job
Create a `ScriptRunConfig` that specifies the training script and its arguments, the environment, and the compute cluster to run on.

For GPU training on a single node, specify the number of GPUs to train on (typically the number of GPUs in your cluster's VM SKU) and the distributed mode, in this case DistributedDataParallel (`"ddp"`). PyTorch Lightning expects these as the `--gpus` and `--distributed_backend` arguments, respectively. See the PyTorch Lightning [Multi-GPU training](https://pytorch-lightning.readthedocs.io/en/latest/multi_gpu.html) documentation for more information.

In [None]:
import os
from azureml.core import ScriptRunConfig, Experiment

# Compute target to run on, looked up by name among the workspace's targets.
cluster = ws.compute_targets[cluster_name]

# Training settings forwarded to train.py on the command line.
# `gpus_per_node` should match the number of GPUs in the cluster's VM SKU;
# change it here (not inside the argument list) when switching clusters.
max_epochs = 25
gpus_per_node = 2
distributed_backend = "ddp"  # DistributedDataParallel

src = ScriptRunConfig(
    source_directory=source_dir,
    script=script_name,
    arguments=[
        "--max_epochs", max_epochs,
        "--gpus", gpus_per_node,
        "--distributed_backend", distributed_backend,
    ],
    compute_target=cluster,
    environment=env,
)

# Submit the job; the bare trailing expression displays the run in the notebook.
run = Experiment(ws, experiment_name).submit(src)
run

In [None]:
from azureml.widgets import RunDetails

# Display the Azure ML run-details widget for the run submitted above.
RunDetails(run).show()

In [None]:
# Block until the submitted run finishes; show_output=True prints the run's
# output while waiting.
run.wait_for_completion(show_output=True)