## Intro To TP

[Article Link](https://pytorch.org/tutorials/intermediate/TP_tutorial.html)

[Dummy Training Link](https://github.com/pytorch/examples/blob/main/distributed/tensor_parallelism/fsdp_tp_example.py)

Tensor Parallel (TP) was introduced in the `Megatron-LM` paper. It is an efficient model parallelism technique for training large-scale Transformer models.

SP (Sequence Parallel) is a variant of TP.
- It shards on the sequence dimension.
- It is applied to `nn.LayerNorm` or `RMSNorm` to further save activation memory.

As the model size increases, activation memory becomes the bottleneck, because all the intermediate outputs have to be stored for the gradient calculation.

Therefore, in TP training, SP is usually also applied to layers such as `LayerNorm` and `RMSNorm`.
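As a rough illustration of why sharding on the sequence dimension helps, the sketch below estimates the memory one rank holds for a single activation tensor of shape `[batch, seq, hidden]`, with and without splitting the sequence dimension. The shapes, the 2-byte element size (bf16), and the helper name `activation_bytes` are assumptions chosen for illustration, not values from the tutorial.

```python
def activation_bytes(batch: int, seq: int, hidden: int,
                     bytes_per_elem: int = 2, seq_shards: int = 1) -> int:
    """Bytes one rank holds for a [batch, seq, hidden] activation.

    With Sequence Parallel the sequence dimension is split evenly across
    `seq_shards` ranks, so each rank stores only seq / seq_shards positions.
    """
    if seq % seq_shards != 0:
        raise ValueError("seq must be divisible by seq_shards in this sketch")
    return batch * (seq // seq_shards) * hidden * bytes_per_elem


# Hypothetical shapes: batch 1, 4096 tokens, hidden size 8192, bf16.
full = activation_bytes(1, 4096, 8192)
sharded = activation_bytes(1, 4096, 8192, seq_shards=8)

print(f"replicated: {full / 2**20:.0f} MiB per rank")
print(f"sharded x8: {sharded / 2**20:.0f} MiB per rank")
```

This counts only one tensor; a real model stores many such activations per layer, so the saving scales with depth.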

### How TP Works

**Sharding Initialization**
- Determine which `ParallelStyle` to apply to each layer, then shard the initialized module by calling `parallelize_module`.
- The parameters of the parallelized modules are swapped to DTensors, and DTensor is responsible for running the parallelized module using sharded computation.

**Runtime forward / backward**
- Depending on the input/output DTensor layouts the user specified for each `ParallelStyle`, the proper communication operations (such as `allreduce`, `allgather`, `reduce_scatter`) are run to transform the DTensor layouts of the inputs/outputs.
- Sharded computation is run for the parallelized layers (e.g. `nn.Linear`, `nn.Embedding`) to save compute/memory.
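The sharded computation can be imitated on a single process, without DTensor or any real communication. The sketch below is an assumption-laden simulation, not the tutorial's code: it splits the first weight of a two-layer MLP by column and the second by row, lets each simulated rank compute a partial result, and sums the partials where a real run would issue an `allreduce`. Weights use the `x @ w` convention (`[in, out]`), unlike `nn.Linear`, which stores `[out, in]`.

```python
import torch

torch.manual_seed(0)
tp_size = 4                                     # simulated number of TP ranks
x = torch.randn(2, 16, dtype=torch.float64)     # replicated input: [batch, hidden]
w1 = torch.randn(16, 32, dtype=torch.float64)   # first linear, applied as x @ w1
w2 = torch.randn(32, 16, dtype=torch.float64)   # second linear, applied as h @ w2

# Reference: the unsharded two-layer MLP.
expected = torch.relu(x @ w1) @ w2

# Column-wise shard w1 (split its output features) and row-wise shard w2
# (split its input features). Each simulated rank holds one shard of each.
w1_shards = torch.chunk(w1, tp_size, dim=1)
w2_shards = torch.chunk(w2, tp_size, dim=0)

# Each rank computes a partial output with no communication in between,
# because ReLU is elementwise and acts independently on each column shard.
partials = [torch.relu(x @ a) @ b for a, b in zip(w1_shards, w2_shards)]

# Summing the partial outputs plays the role of the allreduce.
output = torch.stack(partials).sum(dim=0)

print(torch.allclose(output, expected))
```

Pairing a column-wise layer with a row-wise layer is what keeps the communication down to a single reduction at the end of the block.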

### When & Why TP should be applied

PyTorch Fully Sharded Data Parallel (FSDP) can already scale model training to a certain number of GPUs. However, scaling further in terms of model size and GPU count raises additional challenges that may require combining Tensor Parallel with FSDP:

- When the number of GPUs becomes very large (beyond 128/256), the FSDP collectives (such as `allgather`) are dominated by `ring latency`.
- Implementing TP/SP on top of FSDP reduces the FSDP world size by a factor of 8 (TP/SP runs within a host and FSDP only across hosts), which reduces the latency cost accordingly.
- The data parallelism limit is hit when the global batch size cannot be raised above the number of GPUs because of convergence and GPU memory limitations.
- TP/SP is the only known way to work around this limit and keep scaling both the model size and the number of GPUs.
- For certain types of models (## TODO: IDENTIFY WHICH), when the local batch size becomes smaller, TP/SP can yield matrix multiplication shapes that are more optimized for floating point operations (FLOPS).

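**Real-World Example**

Llama 2 70B
- GPUs: 2k
- Global batch size: 1k

FSDP alone cannot work here at 2k GPUs. With a global batch size of 1k there is less than one sample per GPU, so even a local `batch size = 1` is not reachable, and the global batch size cannot simply be raised because of convergence and memory constraints.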
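The arithmetic behind this example can be checked with a few lines. The sketch below treats "2k" and "1k" as 2048 and 1024 and uses a TP degree of 8 (one TP group per 8-GPU host); those exact numbers and the helper name `local_batch_size` are assumptions for illustration.

```python
def local_batch_size(global_batch: int, num_gpus: int, tp_size: int = 1) -> float:
    """Samples per data-parallel replica.

    GPUs inside one TP group work on the same samples, so the number of
    data-parallel replicas is num_gpus / tp_size.
    """
    if num_gpus % tp_size != 0:
        raise ValueError("num_gpus must be divisible by tp_size")
    dp_size = num_gpus // tp_size
    return global_batch / dp_size


num_gpus, global_batch = 2048, 1024   # "2k" GPUs, "1k" global batch (assumed exact values)

fsdp_only = local_batch_size(global_batch, num_gpus)             # data parallel over all GPUs
with_tp = local_batch_size(global_batch, num_gpus, tp_size=8)    # TP within each 8-GPU host

print(f"FSDP only:  {fsdp_only} samples per replica")
print(f"FSDP + TP8: {with_tp} samples per replica")
```

With data parallelism alone each GPU would get half a sample, which is not a valid batch. Adding TP shrinks the data-parallel world to 256 replicas, so each replica gets a whole number of samples.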

### How to apply Tensor Parallel

Some module-level primitives:

`ColwiseParallel` and `RowwiseParallel`: Shard `nn.Linear` and `nn.Embedding` in a column-wise or row-wise fashion.

`SequenceParallel`: Perform sharded computations on the `nn.LayerNorm`, `nn.Dropout`, `RMSNormPython`

`PrepareModuleInput` and `PrepareModuleOutput`: Configure the module inputs/outputs sharding layout with proper communication operations.

TODO: Implement this using Kaggle multi-GPU, training on any random tensor.