# m4 : Parallel training with FSDP (PyTorch 2.5)

Use PyTorch 2.5 on a Kaggle dual-GPU instance.
Train GPT-2 on OpenWebText using PyTorch FSDP. Do not use Hugging Face Transformers or Accelerate.

In [15]:
import os

import torch
import torch.distributed as dist
import torch.nn as nn
import torch.optim as optim
from datasets import load_dataset
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.distributed import DistributedSampler

In [13]:
!pip install torchdata torchtext


Collecting torchdata
  Downloading torchdata-0.10.1-py3-none-any.whl.metadata (6.3 kB)
Collecting torchtext
  Downloading torchtext-0.18.0-cp310-cp310-manylinux1_x86_64.whl.metadata (7.9 kB)
Downloading torchdata-0.10.1-py3-none-any.whl (57 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m57.5/57.5 kB[0m [31m2.3 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading torchtext-0.18.0-cp310-cp310-manylinux1_x86_64.whl (2.0 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.0/2.0 MB[0m [31m30.3 MB/s[0m eta [36m0:00:00[0ma [36m0:00:01[0m
[?25hInstalling collected packages: torchdata, torchtext
Successfully installed torchdata-0.10.1 torchtext-0.18.0


In [14]:
!pip install datasets



In [16]:
def prepare_dataset(split="train", max_length=1024):
    """Load OpenWebText and tokenize every document to a fixed token length.

    Args:
        split: Split expression forwarded to ``load_dataset``. The default
            ``"train"`` loads everything; a slice such as ``"train[:1%]"``
            keeps a trial run small.
        max_length: Each document is truncated or right-padded to exactly
            this many tokens.

    Returns:
        The result of ``dataset.map`` with the raw ``"text"`` column removed
        and the tokenizer's output columns added.
    """
    dataset = load_dataset("openwebtext", split=split)
    # NOTE(review): ``AutoTokenizer`` is not imported in any visible cell. It
    # comes from ``transformers``, which the notebook header says not to use --
    # either import it or swap in another GPT-2 BPE tokenizer before running.
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    # The GPT-2 tokenizer ships without a padding token, so requesting
    # padding="max_length" raises unless one is assigned. Reuse EOS.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    def tokenize_fn(examples):
        # Called by ``map`` with a batch of rows; "text" is a list of strings.
        return tokenizer(
            examples["text"],
            truncation=True,
            padding="max_length",
            max_length=max_length,
            return_tensors="pt",
        )

    tokenized_dataset = dataset.map(tokenize_fn, batched=True, remove_columns=["text"])
    return tokenized_dataset

In [18]:
def _collate_next_token(batch):
    """Turn a list of tokenized samples into next-token ``(inputs, targets)``.

    ``batch`` is the list of per-sample dicts the DataLoader hands to a
    collate function. Inputs are tokens ``[0, T-1)`` and targets are tokens
    ``[1, T)``, so position ``i`` of the model output is scored against token
    ``i + 1``. Where a sample carries an ``"attention_mask"``, padded target
    positions are set to -100, the default ``ignore_index`` of cross-entropy.
    """
    ids = torch.stack(
        [torch.as_tensor(sample["input_ids"], dtype=torch.long) for sample in batch]
    )
    # Independent, contiguous copies: targets are edited in place below and
    # must not alias the inputs.
    inputs = ids[:, :-1].clone(memory_format=torch.contiguous_format)
    targets = ids[:, 1:].clone(memory_format=torch.contiguous_format)
    if all("attention_mask" in sample for sample in batch):
        mask = torch.stack(
            [torch.as_tensor(sample["attention_mask"]) for sample in batch]
        )
        targets[mask[:, 1:] == 0] = -100
    return inputs, targets


def get_dataloader(tokenized_dataset, batch_size):
    """Build a DataLoader that yields next-token ``(inputs, targets)`` batches.

    Args:
        tokenized_dataset: Indexable dataset whose items are dicts with an
            ``"input_ids"`` sequence (and optionally ``"attention_mask"``).
        batch_size: Number of samples per batch on this process.

    Returns:
        A ``DataLoader`` driven by a ``DistributedSampler``. Each batch is a
        pair of ``(batch, T-1)`` long tensors.
    """
    if dist.is_available() and dist.is_initialized():
        # Rank and world size come from the default process group.
        sampler = DistributedSampler(tokenized_dataset)
    else:
        # No process group yet (e.g. single-process debugging): behave as the
        # only replica instead of raising inside DistributedSampler.
        sampler = DistributedSampler(tokenized_dataset, num_replicas=1, rank=0)
    return DataLoader(
        tokenized_dataset,
        batch_size=batch_size,
        sampler=sampler,
        collate_fn=_collate_next_token,
    )

In [19]:
class GPT2(nn.Module):
    """Decoder-only (causal) Transformer language model.

    Token and learned position embeddings feed a stack of self-attention
    layers restricted by a causal mask, followed by a linear projection to
    vocabulary logits.

    Args:
        vocab_size: Number of distinct token ids.
        seq_len: Longest sequence the model accepts (size of the position
            embedding table).
        embed_dim: Width of the embeddings and of every Transformer layer.
        num_heads: Attention heads per layer.
        num_layers: Number of stacked Transformer layers.
    """

    def __init__(self, vocab_size, seq_len, embed_dim, num_heads, num_layers):
        super().__init__()
        self.seq_len = seq_len
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        # Self-attention has no notion of token order on its own, so each
        # position gets a learned embedding.
        self.pos_embedding = nn.Embedding(seq_len, embed_dim)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=embed_dim,
            nhead=num_heads,
            dim_feedforward=4 * embed_dim,
            # Inputs are (batch, seq, embed). The default (False) would treat
            # the batch axis as the sequence and attend across samples.
            batch_first=True,
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.fc_out = nn.Linear(embed_dim, vocab_size)

    def forward(self, x):
        """Map token ids of shape (batch, seq) to logits (batch, seq, vocab).

        Raises:
            ValueError: If the sequence is longer than ``seq_len``.
        """
        num_tokens = x.size(1)
        if num_tokens > self.seq_len:
            raise ValueError(
                f"sequence length {num_tokens} exceeds the model maximum {self.seq_len}"
            )
        positions = torch.arange(num_tokens, device=x.device)
        # True marks pairs that may NOT attend: every position is blocked from
        # seeing later positions, otherwise next-token prediction is trivial.
        causal_mask = torch.triu(
            torch.ones(num_tokens, num_tokens, dtype=torch.bool, device=x.device),
            diagonal=1,
        )
        x = self.embedding(x) + self.pos_embedding(positions)
        x = self.transformer(x, mask=causal_mask)
        x = self.fc_out(x)
        return x

In [20]:
def train(model, dataloader, optimizer, scaler, device, epochs=1):
    """Run mixed-precision training over ``dataloader`` for ``epochs`` epochs.

    Args:
        model: Module mapping token ids ``(batch, seq)`` to logits
            ``(batch, seq, vocab)``.
        dataloader: Iterable of ``(inputs, targets)`` batches. Targets must
            already line up with the model outputs position by position;
            entries equal to -100 are skipped by the loss.
        optimizer: Optimizer over ``model``'s parameters.
        scaler: Gradient scaler exposing ``scale``, ``step`` and ``update``.
        device: Device (or device string) the batches are moved to.
        epochs: Number of passes over ``dataloader``.
    """
    device_type = torch.device(device).type
    sampler = getattr(dataloader, "sampler", None)

    model.train()
    for epoch in range(epochs):
        # A DistributedSampler reshuffles only when told the epoch number;
        # without this every epoch replays the same order.
        if hasattr(sampler, "set_epoch"):
            sampler.set_epoch(epoch)

        for inputs, targets in dataloader:
            inputs, targets = inputs.to(device), targets.to(device)

            optimizer.zero_grad()
            with torch.autocast(device_type=device_type):
                outputs = model(inputs)
                # reshape (not view): sliced targets may be non-contiguous.
                loss = nn.functional.cross_entropy(
                    outputs.reshape(-1, outputs.size(-1)), targets.reshape(-1)
                )

            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            print(f"Epoch {epoch+1}, Loss: {loss.item()}")

In [None]:
def setup_environment(rank=None, world_size=None, local_rank=None,
                      master_addr=None, master_port=None):
    """Populate the environment variables read when the process group is set up.

    For each setting, an explicit argument always wins. Otherwise a value
    already present in the environment (for example exported by a launcher)
    is left untouched, and only a missing value falls back to the
    single-process default.

    Args:
        rank: Global rank of this process (default ``0``).
        world_size: Total number of processes (default ``1``).
        local_rank: GPU index on this machine (default ``0``).
        master_addr: Rendezvous address (default ``"127.0.0.1"``).
        master_port: Rendezvous port (default ``"12355"``).
    """
    settings = (
        ("MASTER_ADDR", master_addr, "127.0.0.1"),
        ("MASTER_PORT", master_port, "12355"),
        ("WORLD_SIZE", world_size, "1"),
        ("RANK", rank, "0"),
        ("LOCAL_RANK", local_rank, "0"),
    )
    for key, explicit, fallback in settings:
        if explicit is not None:
            os.environ[key] = str(explicit)
        else:
            # Do not clobber values a launcher or an earlier call provided.
            os.environ.setdefault(key, fallback)

def setup_distributed():
    """Bind this process to its GPU and join the NCCL process group.

    Reads ``RANK``, ``WORLD_SIZE`` and ``LOCAL_RANK`` from the environment.
    Safe to call again (e.g. when the cell is re-run): an already initialised
    process group is left as is.

    Returns:
        The ``torch.device`` for this process's GPU.

    Raises:
        KeyError: If one of the three environment variables is missing.
    """
    rank = int(os.environ["RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    local_rank = int(os.environ["LOCAL_RANK"])

    # Select the GPU first so the NCCL group is created against this device
    # rather than whatever the current default happens to be.
    torch.cuda.set_device(local_rank)
    if not dist.is_initialized():
        dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)
    return torch.device("cuda", local_rank)


# FSDP needs a default process group, so bring it up here; the guard keeps the
# cell re-runnable.
# NOTE(review): this starts ONE process. Using both GPUs needs one process per
# GPU (e.g. torchrun / torch.multiprocessing.spawn), each with its own rank.
if not dist.is_initialized():
    setup_environment()
    setup_distributed()
device = torch.device("cuda", torch.cuda.current_device())

In [None]:
 # Downloads OpenWebText (21 tar shards, ~8M documents per the output below)
 # and tokenizes it via prepare_dataset(); expect this cell to be slow.
 # NOTE(review): the leading space on this cell is tolerated by IPython but
 # would be an IndentationError in a plain .py export -- confirm and remove.
 tokenized_dataset = prepare_dataset()

README.md:   0%|          | 0.00/7.35k [00:00<?, ?B/s]

openwebtext.py:   0%|          | 0.00/2.73k [00:00<?, ?B/s]

The repository for openwebtext contains custom code which must be executed to correctly load the dataset. You can inspect the repository content at https://hf.co/datasets/openwebtext.
You can avoid this prompt in future by passing the argument `trust_remote_code=True`.

Do you wish to run the custom code? [y/N]  y


Downloading data:   0%|          | 0/21 [00:00<?, ?files/s]

urlsf_subset00.tar:   0%|          | 0.00/633M [00:00<?, ?B/s]

urlsf_subset01.tar:   0%|          | 0.00/629M [00:00<?, ?B/s]

urlsf_subset02.tar:   0%|          | 0.00/629M [00:00<?, ?B/s]

urlsf_subset03.tar:   0%|          | 0.00/628M [00:00<?, ?B/s]

urlsf_subset04.tar:   0%|          | 0.00/627M [00:00<?, ?B/s]

urlsf_subset05.tar:   0%|          | 0.00/630M [00:00<?, ?B/s]

urlsf_subset06.tar:   0%|          | 0.00/626M [00:00<?, ?B/s]

urlsf_subset07.tar:   0%|          | 0.00/625M [00:00<?, ?B/s]

urlsf_subset08.tar:   0%|          | 0.00/625M [00:00<?, ?B/s]

urlsf_subset09.tar:   0%|          | 0.00/626M [00:00<?, ?B/s]

urlsf_subset10.tar:   0%|          | 0.00/625M [00:00<?, ?B/s]

urlsf_subset11.tar:   0%|          | 0.00/625M [00:00<?, ?B/s]

urlsf_subset12.tar:   0%|          | 0.00/624M [00:00<?, ?B/s]

urlsf_subset13.tar:   0%|          | 0.00/629M [00:00<?, ?B/s]

urlsf_subset14.tar:   0%|          | 0.00/627M [00:00<?, ?B/s]

urlsf_subset15.tar:   0%|          | 0.00/621M [00:00<?, ?B/s]

urlsf_subset16.tar:   0%|          | 0.00/619M [00:00<?, ?B/s]

urlsf_subset17.tar:   0%|          | 0.00/619M [00:00<?, ?B/s]

urlsf_subset18.tar:   0%|          | 0.00/618M [00:00<?, ?B/s]

urlsf_subset19.tar:   0%|          | 0.00/619M [00:00<?, ?B/s]

urlsf_subset20.tar:   0%|          | 0.00/377M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/8013769 [00:00<?, ? examples/s]

In [None]:
# Hyperparameters (GPT-2 "small" sized). Defined here because no earlier cell
# sets them; adjust batch_size / model size to fit GPU memory.
vocab_size = 50257  # size of the GPT-2 BPE vocabulary
seq_len = 1024      # matches the max_length used when tokenizing
embed_dim = 768
num_heads = 12
num_layers = 12
batch_size = 4
lr = 3e-4
epochs = 1

dataloader = get_dataloader(tokenized_dataset, batch_size)
model = GPT2(vocab_size, seq_len, embed_dim, num_heads, num_layers).to(device)
# NOTE(review): wrapping only the root module shards the model as one unit;
# consider an auto_wrap_policy so each Transformer layer is its own FSDP unit.
model = FSDP(model)
# Build the optimizer AFTER wrapping so it sees FSDP's sharded parameters.
optimizer = optim.AdamW(model.parameters(), lr=lr)
# FSDP shards gradients across ranks, so use the FSDP-aware scaler, which
# agrees on inf/NaN detection across ranks.
scaler = ShardedGradScaler()

In [None]:
# Kick off training with the objects built in the previous cell.
train(model, dataloader, optimizer, scaler, device, epochs=epochs)