# Distributed data parallel minGPT training with PyTorch Lightning and SMDataParallel


## Background
SMDataParallel is a distributed data parallel training framework for PyTorch in Amazon SageMaker. It is designed to train deep learning models faster and at lower cost.

This notebook example shows how to use SMDataParallel with PyTorch Lightning in SageMaker.

For more information:
1. [PyTorch in SageMaker](https://sagemaker.readthedocs.io/en/stable/frameworks/pytorch/using_pytorch.html)
2. [SMDataParallel PyTorch API Specification](https://sagemaker.readthedocs.io/en/stable/api/training/smd_data_parallel_pytorch.html)
3. [Getting started with SMDataParallel on SageMaker](https://sagemaker.readthedocs.io/en/stable/api/training/smd_data_parallel.html)

**NOTE:** This example requires SageMaker Python SDK v2.X.

In [2]:
%%capture
!pip install sagemaker --upgrade

In [10]:
import sagemaker
from sagemaker.pytorch import PyTorch

sagemaker_session = sagemaker.Session()
role = sagemaker.get_execution_role()

## Model training with SMDataParallel

### Training script

The training script provides the code you need for distributed data parallel (DDP) training using SMDataParallel. It is very similar to a PyTorch Lightning training script you might run outside of SageMaker, but it is modified to run with SMDataParallel by passing the flag `accelerator="ddp_sm"` to the Lightning `Trainer`.
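
The contents of `benchmark.py` are not reproduced in this notebook, so the following is only a sketch of how such a script might derive its `Trainer` arguments. It assumes the standard SageMaker training environment variables `SM_HOSTS` and `SM_NUM_GPUS` and the Lightning 1.x `Trainer` arguments `gpus` and `num_nodes`. The helper name `trainer_kwargs` is hypothetical; only `accelerator="ddp_sm"` comes from the text above.

```python
import json
import os


def trainer_kwargs(env=None):
    """Build keyword arguments for a Lightning Trainer (hypothetical helper)."""
    env = os.environ if env is None else env
    # SM_HOSTS holds a JSON list of the host names in the training cluster.
    hosts = json.loads(env.get("SM_HOSTS", '["localhost"]'))
    # SM_NUM_GPUS holds the number of GPUs on the current instance.
    gpus = int(env.get("SM_NUM_GPUS", "0"))
    return {
        "accelerator": "ddp_sm",  # flag named in this notebook
        "gpus": gpus,
        "num_nodes": len(hosts),
    }


# Simulate a two-node job with 8 GPUs per node.
example = trainer_kwargs({"SM_HOSTS": '["algo-1", "algo-2"]', "SM_NUM_GPUS": "8"})
print(example)
```

Inside a training script, the result could be unpacked as `Trainer(**trainer_kwargs())`, assuming the Lightning build installed in the container accepts these arguments.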


In [5]:
!pygmentize minGPT/benchmark.py

### Estimator function options

In the following code block, you can update the estimator to use a different instance type, instance count, and distribution strategy. You also pass in the training script you reviewed in the previous cell.

**Instance types**

SMDataParallel supports model training on SageMaker with the following instance types only:
1. ml.p3.16xlarge
1. ml.p3dn.24xlarge [Recommended]
1. ml.p4d.24xlarge [Recommended]
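
Because only these three instance types are supported, it can help to fail fast before a job is submitted. The sketch below is one way to do that; `SUPPORTED_INSTANCE_TYPES` and `check_instance_type` are hypothetical names, not part of the SageMaker SDK.

```python
# Instance types listed above as supported by SMDataParallel.
SUPPORTED_INSTANCE_TYPES = (
    "ml.p3.16xlarge",
    "ml.p3dn.24xlarge",
    "ml.p4d.24xlarge",
)


def check_instance_type(instance_type):
    """Return instance_type unchanged, or raise if it is not in the list."""
    if instance_type not in SUPPORTED_INSTANCE_TYPES:
        supported = ", ".join(SUPPORTED_INSTANCE_TYPES)
        raise ValueError(
            f"{instance_type!r} is not supported; choose one of: {supported}"
        )
    return instance_type


instance_type = check_instance_type("ml.p3.16xlarge")
print(instance_type)
```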

**Instance count**

For the best performance with SMDataParallel, use at least 2 instances; a single instance is enough to test this example.
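
To see what the instance count means for the job, the sketch below computes the number of data parallel workers and the resulting global batch size. It relies on each of the supported instance types having 8 GPUs and assumes one worker per GPU; the function names and the per-GPU batch size of 32 are illustrative only.

```python
# Each supported instance type (p3.16xlarge, p3dn.24xlarge, p4d.24xlarge) has 8 GPUs.
GPUS_PER_INSTANCE = 8


def world_size(instance_count, gpus_per_instance=GPUS_PER_INSTANCE):
    """Total number of data parallel workers, assuming one worker per GPU."""
    return instance_count * gpus_per_instance


def global_batch_size(per_gpu_batch_size, instance_count):
    """Samples processed per optimizer step across all workers."""
    return per_gpu_batch_size * world_size(instance_count)


for count in (1, 2):
    print(count, world_size(count), global_batch_size(32, count))
```

Going from 1 to 2 instances doubles the global batch size at a fixed per-GPU batch size, so hyperparameters such as the learning rate may need to be revisited.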

**Distribution strategy**

To use DDP mode, update the `distribution` argument so that it enables `dataparallel` under `smdistributed`.
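
The nested dictionary expected by `distribution` is easy to mistype, so one option is to build it in a small helper. This is a sketch; `smddp_distribution` and `distribution_kwargs` are hypothetical names, and only the dictionary layout is taken from the estimator call in this notebook.

```python
def smddp_distribution(enabled=True):
    """Return the distribution setting that turns SMDataParallel on or off."""
    return {"smdistributed": {"dataparallel": {"enabled": enabled}}}


def distribution_kwargs(use_smddp):
    """Keyword arguments to splat into the estimator; empty when SMDataParallel is not used."""
    return {"distribution": smddp_distribution()} if use_smddp else {}


print(distribution_kwargs(True))
print(distribution_kwargs(False))
```

With this helper, toggling a boolean replaces commenting the `distribution` argument in and out, for example `PyTorch(..., **distribution_kwargs(True))`.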

Pass the filename of the training script as the `entry_point` parameter and the directory that contains it as the `source_dir` parameter. You can also include a `requirements.txt` file in the same directory as your training script to install other dependencies at runtime. This is where the PyTorch Lightning dependency is added.

```
minGPT
   |--benchmark.py
   |--requirements.txt
```
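
Before launching a job, you may want to confirm that the source directory has this layout. The sketch below checks for the entry point and `requirements.txt`; `missing_files` is a hypothetical helper, and the demonstration runs against a temporary directory rather than the real `minGPT` folder.

```python
import tempfile
from pathlib import Path


def missing_files(source_dir, entry_point):
    """List the expected files that are absent from source_dir."""
    root = Path(source_dir)
    expected = [entry_point, "requirements.txt"]
    return [name for name in expected if not (root / name).is_file()]


# Demonstrate on a throwaway directory so nothing on disk is modified.
with tempfile.TemporaryDirectory() as tmp:
    (Path(tmp) / "benchmark.py").write_text("# training script placeholder\n")
    before = missing_files(tmp, "benchmark.py")
    (Path(tmp) / "requirements.txt").write_text("# dependencies go here\n")
    after = missing_files(tmp, "benchmark.py")

print(before, after)
```

In this notebook the equivalent call would be `missing_files(source_dir, entry_point)` once those two variables are defined.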

In [None]:
! echo "git+https://github.com/kaushikb11/pytorch-lightning.git@smddp" > ./minGPT/requirements.txt

In [8]:
entry_point = "benchmark.py"
source_dir = "minGPT"

In [None]:
estimator = PyTorch(
    base_job_name='lightning-benchmarks',
    source_dir=source_dir,
    entry_point=entry_point,
    role=role,
    framework_version='1.6.0',
    py_version='py36',
    # For multi-node distributed training, increase this count. Example: 2
    instance_count=1,
    # For a p3dn instance use ml.p3dn.24xlarge; for a p4d instance use ml.p4d.24xlarge
    instance_type='ml.p3.16xlarge',
    sagemaker_session=sagemaker_session,
    # Train with the SMDataParallel distributed training framework.
    # Comment this out when not training with SMDataParallel.
    distribution={
        'smdistributed': {
            'dataparallel': {
                'enabled': True
            }
        }
    },
    debugger_hook_config=False,
)

In [None]:
estimator.fit()