In [None]:
#| default_exp distributed

In [1]:
# | export
import random, math, torch, numpy as np, matplotlib.pyplot as plt
from tinyai.model import *
from tinyai.learner import *
from tinyai.hooks import *
from tinyai.init import *
from tinyai.speedup import *
from tinyai.hyperparam import *
import fastcore.all as fc
from functools import partial
import time
import os

In [2]:
# | export
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

In [3]:
# | export
from torch.distributed import init_process_group, destroy_process_group
from torch.nn.parallel import DistributedDataParallel as DDP
import torch.distributed as dist

In [4]:
# | export
# DDP mode is on only when torch.distributed is usable AND a RANK variable is
# present in the environment (an unset RANK falls back to -1, i.e. "off").
ddp_enabled = (
    dist.is_available()
    and int(os.environ.get("RANK", -1)) != -1
)

In [5]:
# Show the flag; when it is False the single-process fallback values are used below.
ddp_enabled

False

In [None]:
# Work out this process's rank, world size and device.
if not ddp_enabled:
    # Single-process run: act as the only (master) rank on the default device.
    ddp_rank, ddp_local_rank, ddp_world_size = 0, 0, 1
    master_process = True
    device = default_device
else:
    # Join the process group using the NCCL backend.
    init_process_group(backend="nccl")
    ddp_rank = dist.get_rank()
    ddp_local_rank = int(os.environ["LOCAL_RANK"])
    ddp_world_size = dist.get_world_size()
    # Bind this process to the GPU whose index matches its local rank.
    device = f"cuda:{ddp_local_rank}"
    torch.cuda.set_device(device)
    # Only global rank 0 is flagged as the master process.
    master_process = ddp_rank == 0

In [None]:
# Fix the random seeds so repeated runs are comparable.
# `set_seed` comes from the star imports at the top — presumably it seeds the
# Python/NumPy/torch generators; confirm against its definition.
set_seed(1337)
# Seed the CUDA generator explicitly with the same value.
torch.cuda.manual_seed(1337)

In [9]:
import tiktoken
import os

# GPT-2 tokenizer; `get_tokens` reads this module-level name when it encodes text.
enc = tiktoken.get_encoding("gpt2")

# The corpus is looked up relative to the current working directory:
# <cwd>/fast-nanogpt/input.txt
cwd = os.getcwd()
data_dir = f"{cwd}/fast-nanogpt"
pattern = "input.txt"

In [7]:
#| export
def to_tensor(f):
    """Decorator: convert the return value of `f` into a `torch.long` tensor.

    Args:
        f: callable whose result `torch.tensor` can convert (e.g. a list of ints).

    Returns:
        A wrapper that forwards every positional and keyword argument to `f`
        and returns `torch.tensor(result, dtype=torch.long)`. The wrapper keeps
        `f`'s metadata (`__name__`, `__doc__`, `__wrapped__`), so decorated
        functions stay identifiable in tracebacks and generated docs.
    """
    # Imported locally: the module-level import only brings in `partial`.
    from functools import wraps

    @wraps(f)
    def _f(*args, **kwargs):
        return torch.tensor(f(*args, **kwargs), dtype=torch.long)

    return _f

@to_tensor
def get_tokens(input_file):
    """Read a text file and return its token ids.

    Args:
        input_file: path of the text file to tokenize.

    Returns:
        The result of `enc.encode` on the file's text; the `@to_tensor`
        decorator then converts it into a `torch.long` tensor.
    """
    # Decode explicitly as UTF-8 so the token stream does not depend on the
    # platform's locale default encoding.
    with open(input_file, encoding="utf-8") as f:
        text = f.read()
    # NOTE(review): `enc` is a module-level tokenizer bound in a notebook cell
    # that is not marked `#| export` — confirm the exported module defines it.
    return enc.encode(text)

In [10]:
# Tokenize the whole corpus file once and check how many tokens it contains.
t = get_tokens(f"{data_dir}/{pattern}")
t.shape

torch.Size([338025])

## Iterable dataset

In distributed settings, multiple processes access the same dataset, so we need to make sure that each process only accesses its own part.

In [None]:
#| export
from torch.utils.data import IterableDataset

class FSDataSet(IterableDataset):
    """Endless stream of `(x, y)` next-token pairs read from shard files on disk.

    Every entry of `data_dir` (optionally filtered by `pattern`) is treated as a
    shard. One shard at a time is tokenized with `token_fn`; each `next()` call
    returns a window of `T` tokens as `x` and the same window shifted by one
    token as `y`. Process `rank` starts `T * rank` tokens into a shard and
    advances by `T * num_proc` per step, so the `x` windows of different
    processes do not overlap. When the current shard is exhausted the next one
    is loaded (wrapping around to the first), so this class never raises
    `StopIteration` itself.
    """

    def __init__(
        self, data_dir, pattern=None, token_fn=get_tokens, T=32, num_proc=1, rank=0
    ):
        """
        Args:
            data_dir: directory containing the shard files.
            pattern: if given, only entries whose name contains this substring are used.
            token_fn: callable mapping a shard path to a sliceable token sequence
                that supports `len()`.
            T: number of tokens in each `x` / `y` sample.
            num_proc: number of processes sharing the data.
            rank: index of this process among `num_proc`.
        """
        self.T = T
        self.num_proc = num_proc
        self.rank = rank
        self.pattern = pattern
        self.token_fn = token_fn

        self.shards = self.get_shards(data_dir, pattern)
        # reset() selects shard 0, positions this rank and loads the tokens.
        self.reset()

    def reset(self):
        """Rewind to the first shard and to this rank's starting offset."""
        self.current_shard = 0
        # each process starts at a different offset corresponding to its rank
        self.current_pos = self.T * self.rank
        self.tokens = self.token_fn(self.shards[self.current_shard])

    def get_shards(self, data_dir, pattern=None):
        """Return the sorted paths of the shards found in `data_dir`.

        Args:
            data_dir: directory to list.
            pattern: optional substring a file name must contain to be kept.

        Returns:
            Sorted list of paths, each joined with `data_dir`.
        """
        names = os.listdir(data_dir)
        if pattern is not None:
            names = [name for name in names if pattern in name]
        # os.listdir yields bare names; always join them with data_dir so the
        # shards can be opened from any working directory, including when no
        # pattern is given.
        return sorted(os.path.join(data_dir, name) for name in names)

    def step(self):
        """Advance to this rank's next window, rolling over to the next shard when needed."""
        # advance position
        self.current_pos += self.T * self.num_proc
        # if next step will go over the end of the shard, move to the next shard
        if self.current_pos + self.T * self.num_proc + 1 > len(self.tokens):
            # the modulo wraps back to the first shard after the last one
            self.current_shard = (self.current_shard + 1) % len(self.shards)
            self.tokens = self.token_fn(self.shards[self.current_shard])
            self.current_pos = self.T * self.rank

    def __iter__(self):
        # The dataset is its own iterator, so every iterator shares one position.
        return self

    def __next__(self):
        """Return the current `(x, y)` pair and advance the position."""
        # T + 1 tokens: the extra token supplies the target for the last input.
        buf = self.tokens[self.current_pos : self.current_pos + self.T + 1]
        x = buf[:-1]
        y = buf[1:]

        self.step()
        return x, y

In [None]:
# Build the dataset for this process and peek at a single (x, y) sample.
tds = FSDataSet(data_dir, pattern=pattern, T=32, num_proc=ddp_world_size, rank=ddp_rank)
it = iter(tds)
# `iter(tds)` returns `tds` itself, so this `next` also advances `tds`.
x, y = next(it)
print(x, y, sep="\n")

In [None]:
# Wrap the dataset in the project's DataLoaders and pull one batch to check the
# batched shapes. The second list entry is None — presumably the validation
# set slot; confirm against `DataLoaders.from_dd`.
dls = DataLoaders.from_dd([tds, None], batch_size=4, drop_last=True)
x, y = next(iter(dls.train))
x.shape, y.shape

In [None]:
#| export
class FixedStepCallback(Callback):
    """Callback that cancels the fit once the optimizer's step counter reaches `step_count`."""

    def __init__(self, step_count=50):
        # Step-counter value at which training is cancelled.
        self.step_count = step_count

    def after_batch(self, learn):
        """Raise `CancelFitException` when `learn.opt._step_count` has reached the limit."""
        # The learner may not have an optimizer attached; nothing to check then.
        if not hasattr(learn, "opt"):
            return
        # NOTE(review): `_step_count` is a private optimizer attribute — confirm
        # the optimizer in use maintains it.
        steps_done = learn.opt._step_count
        if steps_done >= self.step_count:
            raise CancelFitException()

In [None]:
# Accumulation length passed to GradAccuTrainCB — presumably the number of
# micro-batches accumulated per optimizer step; confirm against its definition.
grad_accu_steps = 50
# Callbacks shared by every `fit` call below.
cbs = [GradAccuTrainCB(grad_accu_steps), InitWeightsCB(), DeviceCB()]
def fit(model, epochs=1, opt_func=optim.AdamW, xtra_cbs=None, lr=3e-4):
    """Train `model` on the notebook-level `dls` and return the learner.

    Args:
        model: model handed to `Learner`.
        epochs: number of epochs passed to `Learner.fit`.
        opt_func: optimizer factory handed to `Learner`.
        xtra_cbs: extra callbacks appended after the shared `cbs` (None adds nothing).
        lr: learning rate handed to `Learner`.

    Returns:
        The `Learner` after fitting (validation is disabled via `valid=False`).
    """
    all_cbs = cbs + fc.L(xtra_cbs)
    learner = Learner(model, dls=dls, opt_func=opt_func, cbs=all_cbs, lr=lr)
    learner.fit(epochs, valid=False)
    return learner

In [None]:
# Re-seed so the model initialisation is repeatable, then run a short training job.
set_seed(1337)
# NOTE(review): the model is moved to `default_device`, not the per-rank
# `device` resolved earlier — confirm this is intended when DDP is enabled.
model = get_model().to(default_device)
# Records the learning rate and gradient norm during training.
record = GradAccuRecordCB(lr=get_lr, grad_norm=get_grad_norm)
# Cosine schedule with 10 warm-up steps out of 50.
schd = GradAccuScheduleCB(partial(CosineLR, warmup_steps=10, max_steps=50))
# FixedStepCallback(8) cancels the fit once the optimizer step counter reaches 8.
fit(
    model,
    opt_func=get_optimizer,
    xtra_cbs=[schd, record, GradAccuLogCallback(), FixedStepCallback(8)],
    lr=6e-4,
)

In [None]:
# Gradient norms collected by the `record` callback during the run above.
record.recs['grad_norm']

## DDP

The DDP container handles the communication between the different processes. The forward pass remains unchanged; after `backward` is called, DDP uses `all_reduce` to average the gradients across all GPUs.


In [None]:
# When DDP is enabled, wrap the model in the DDP container; `device_ids` is set
# to this process's local rank.
if ddp_enabled:
    model = DDP(model, device_ids=[ddp_local_rank])

In [None]:
#| export
class DDPCB(Callback):
    """Callback that optionally compiles the learner's model and then wraps it in DDP."""

    def __init__(self, local_rank, compile=True):
        # local_rank: device index handed to DDP as `device_ids`.
        # compile: whether to run `torch.compile` on the model first.
        self.local_rank, self.compile = local_rank, compile

    def before_fit(self, learn):
        """Replace `learn.model` with its (optionally compiled) DDP-wrapped version."""
        # NOTE(review): compilation happens before the DDP wrap here — confirm
        # this order against the PyTorch DDP / torch.compile guidance.
        if self.compile:
            learn.model = torch.compile(learn.model)

        wrapped = DDP(learn.model, device_ids=[self.local_rank])
        learn.model = wrapped

Since we are using gradient accumulation, we don't want to synchronize the gradients after every batch, as that would be extremely wasteful. Instead, we synchronize the gradients only on the last micro step of each accumulation window, i.e. when `micro_step % accu_steps == accu_steps - 1`.


In [None]:
#| export
class DDPGradAccuTrainCB(GradAccuTrainCB):
    """Gradient-accumulation trainer that limits DDP gradient sync to the last micro step."""

    def backward(self, learn):
        """Set the model's sync flag for this micro step, then run the parent backward."""
        if ddp_enabled:
            # Index (within an accumulation window) of the final micro step.
            last_micro_step = self.accu_steps - 1
            # Sync is requested only on that final micro step; on every other
            # micro step the flag is switched off.
            learn.model.require_backward_grad_sync = (
                learn._micro_step_count % self.accu_steps == last_micro_step
            )
        super().backward(learn)

It is quite tricky to start DDP from a notebook, so we will use the `train_gpt2.py` script to train the model.