# In [None]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import Dataset, DataLoader, DistributedSampler
from transformers import GPT2Tokenizer, GPT2LMHeadModel

# Sample dataset
class TextDataset(Dataset):
    """In-memory dataset of texts tokenized once, up front.

    Every text is tokenized in a single batched call at construction time and
    the resulting tensors are kept on ``self.encodings``. Each item is a dict
    holding one row of ``input_ids`` and the matching ``attention_mask`` row.

    Args:
        texts: The raw strings to tokenize.
        tokenizer: Callable tokenizer that returns a mapping with
            ``input_ids`` and ``attention_mask`` tensors, one row per text.
    """

    def __init__(self, texts, tokenizer):
        # padding=True so all rows share one length and can be returned as a
        # single stacked tensor (return_tensors="pt").
        self.encodings = tokenizer(
            texts,
            return_tensors="pt",
            padding=True,
            truncation=True,
        )

    def __len__(self):
        # One row of input_ids per text.
        token_ids = self.encodings["input_ids"]
        return token_ids.size(0)

    def __getitem__(self, idx):
        wanted = ("input_ids", "attention_mask")
        return {key: self.encodings[key][idx] for key in wanted}

def setup(rank, world_size):
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)

def cleanup():
    """Destroy the default process group if one is currently initialized.

    Calling it a second time, or when no group was ever created, is a no-op
    instead of an error. This means it can sit in a ``finally`` clause
    without masking the exception that got us there.
    """
    if dist.is_initialized():
        dist.destroy_process_group()

def train(rank, world_size):
    """Fine-tune GPT-2 on a tiny in-memory corpus inside one DDP worker.

    Meant as the ``mp.spawn`` target, which passes the worker index as
    ``rank``. Every rank trains on its own shard of the data. Rank 0 also
    prints the per-step loss and saves the model and tokenizer to
    ``gpt2-ddp-manual``. The process group is always torn down on exit, even
    if training raises.

    Args:
        rank: Index of this worker; also used as its CUDA device index.
        world_size: Total number of worker processes.
    """
    setup(rank, world_size)
    try:
        # Tokenizer: reuse EOS as the padding token.
        tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
        tokenizer.pad_token = tokenizer.eos_token

        # Model: move to this rank's GPU, then wrap for gradient syncing.
        model = GPT2LMHeadModel.from_pretrained("gpt2").to(rank)
        model = DDP(model, device_ids=[rank])

        # Data
        train_texts = [
            "Hello, how are you?",
            "I am learning to train GPT-2 with multiple GPUs!",
            "Deep learning is amazing.",
            "PyTorch makes model training easier.",
        ]

        dataset = TextDataset(train_texts, tokenizer)
        # Each rank iterates over its own shard of the dataset.
        sampler = DistributedSampler(
            dataset, num_replicas=world_size, rank=rank, shuffle=True
        )
        dataloader = DataLoader(dataset, batch_size=2, sampler=sampler)

        optimizer = optim.AdamW(model.parameters(), lr=5e-5)
        epochs = 10

        for epoch in range(epochs):
            # Vary the shuffle per epoch while keeping it consistent across ranks.
            sampler.set_epoch(epoch)

            for batch in dataloader:
                input_ids = batch["input_ids"].to(rank)
                attention_mask = batch["attention_mask"].to(rank)
                # Padded positions (attention_mask == 0) hold the EOS pad
                # token. Label them -100 so the LM loss ignores them, rather
                # than teaching the model to emit runs of EOS.
                labels = input_ids.masked_fill(attention_mask == 0, -100)

                optimizer.zero_grad()
                outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
                loss = outputs.loss

                loss.backward()  # DDP averages gradients across ranks here
                optimizer.step()

                if rank == 0:
                    print(f"Epoch [{epoch+1}/{epochs}] | Loss: {loss.item():.4f}")

        # Save from rank 0 only; unwrap DDP so the plain HF model is saved.
        if rank == 0:
            save_path = "gpt2-ddp-manual"
            model.module.save_pretrained(save_path)
            tokenizer.save_pretrained(save_path)
            print(f"Model saved at {save_path}")
    finally:
        cleanup()

if __name__ == "__main__":
    # One worker process per visible CUDA device (e.g., 2 for 2 T4s).
    world_size = torch.cuda.device_count()
    if world_size < 1:
        # mp.spawn with nprocs=0 would start no workers and return silently.
        raise RuntimeError(
            "No CUDA devices found; this script needs at least one GPU "
            "for the NCCL backend."
        )
    # NOTE(review): spawned workers are fresh interpreters that must be able
    # to import `train`. When launched from a notebook cell this may fail;
    # running the code as a .py script avoids it. Confirm in your environment.
    mp.spawn(train, args=(world_size,), nprocs=world_size, join=True)


# Output from the last run: KeyError: 'RANK'