In [1]:
import os
import torch
import torch.distributed as dist
from torch.utils.data import DataLoader, Dataset
import torch.nn.functional as F
from torch.optim import AdamW
from transformers import get_linear_schedule_with_warmup
from datasets import load_dataset
from tqdm import tqdm

# Import the model and tokenizer
from model import Transformer, ModelArgs
from tokenizer import Tokenizer

import fairscale.nn.model_parallel.initialize as fs_init

class WikipediaDataset(Dataset):
    """Map-style dataset turning raw texts into fixed-length token-id tensors.

    Each text is encoded with both BOS and EOS requested from the tokenizer,
    cut down to ``seq_len`` ids when it is longer, and right-padded with
    ``tokenizer.pad_id`` when it is shorter.
    """

    def __init__(self, texts, tokenizer, seq_len):
        # texts: indexable, sized collection of strings.
        # tokenizer: must provide encode(text, bos=..., eos=...) and pad_id.
        # seq_len: number of token ids in every returned item.
        self.texts = texts
        self.tokenizer = tokenizer
        self.seq_len = seq_len

    def __len__(self):
        """Return how many texts the dataset holds."""
        return len(self.texts)

    def __getitem__(self, idx):
        """Return the token ids of text ``idx`` as a tensor of ``seq_len`` ids."""
        ids = self.tokenizer.encode(self.texts[idx], bos=True, eos=True)
        clipped = ids[:self.seq_len]
        # Zero padding entries when the text already fills the window.
        n_pad = max(self.seq_len - len(clipped), 0)
        padding = [self.tokenizer.pad_id] * n_pad
        return torch.tensor(clipped + padding)

def _next_token_loss(outputs, batch, pad_id):
    """Cross-entropy between the logits at position t and the token at t + 1.

    Assumes ``outputs`` is (batch, seq, vocab) and ``batch`` is (batch, seq),
    which is what the flattening in the training loop implies -- TODO confirm
    against model.py. Targets equal to ``pad_id`` are ignored.
    """
    vocab = outputs.size(-1)
    # Drop the last logit (nothing follows it) and the first token (nothing
    # predicts it) so that each position is scored on the *next* token.
    logits = outputs[:, :-1, :].reshape(-1, vocab)
    targets = batch[:, 1:].reshape(-1)
    return F.cross_entropy(logits, targets, ignore_index=pad_id)


def _train():
    """Build data, model and optimizer, then run the training epochs."""
    # Load the dataset
    dataset = load_dataset("wikipedia", language="en", date="20240401", split='train[:5%]', trust_remote_code=True)
    texts = dataset['text']  # Extract the texts from the dataset

    # Initialize the tokenizer
    tokenizer = Tokenizer(model_path='cl100k_base.tiktoken')

    # Prepare the dataset and dataloader
    seq_len = 2048
    wiki_dataset = WikipediaDataset(texts, tokenizer, seq_len)
    dataloader = DataLoader(wiki_dataset, batch_size=1, shuffle=True, num_workers=2)

    # Initialize the model
    model_args = ModelArgs(
        vocab_size=tokenizer.get_vocab_size(),
        dim=512,
        n_layers=6,
        n_heads=8,
        ffn_dim_multiplier=4
    )
    model = Transformer(model_args).cuda()

    num_epochs = 3
    gradient_accumulation_steps = 4  # Accumulate gradients over 4 batches
    num_batches = len(dataloader)

    # The scheduler advances once per optimizer update, not once per batch, so
    # its horizon is the number of updates (a trailing partial group counts).
    updates_per_epoch = (num_batches + gradient_accumulation_steps - 1) // gradient_accumulation_steps
    total_steps = updates_per_epoch * num_epochs

    optimizer = AdamW(model.parameters(), lr=1e-4)
    scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=0, num_training_steps=total_steps)

    # Mixed precision scaler
    scaler = torch.cuda.amp.GradScaler()

    model.train()
    for epoch in range(num_epochs):
        epoch_loss = 0
        optimizer.zero_grad()

        for step, batch in enumerate(tqdm(dataloader, desc=f"Epoch {epoch + 1}/{num_epochs}")):
            batch = batch.cuda()

            with torch.cuda.amp.autocast():
                outputs = model(batch, start_pos=0)
                loss = _next_token_loss(outputs, batch, tokenizer.pad_id)

            # Every iteration builds a fresh graph, so the graph never has to be
            # retained; gradients accumulate in .grad across backward() calls.
            # NOTE(review): if "Trying to backward through the graph a second
            # time" still occurs, the graph is presumably kept alive inside the
            # model (e.g. a cache tensor written in forward and reused by the
            # next call) -- check model.py; retain_graph would not fix that.
            scaler.scale(loss / gradient_accumulation_steps).backward()

            is_group_end = (step + 1) % gradient_accumulation_steps == 0
            is_last_batch = (step + 1) == num_batches
            # Also update on the last batch so a partial group is not discarded.
            if is_group_end or is_last_batch:
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad()
                scheduler.step()

            epoch_loss += loss.item()

        print(f"Epoch {epoch + 1} Loss: {epoch_loss / len(dataloader)}")


def main():
    """Train a small Transformer language model on a slice of English Wikipedia.

    Sets up a single-process NCCL process group and fairscale model
    parallelism, runs training with mixed precision and gradient accumulation,
    and always tears the distributed state down again, so the function can be
    re-run in the same interpreter (e.g. a notebook kernel) after a failure.
    """
    # Debugging aid: anomaly detection adds overhead to every backward pass.
    torch.autograd.set_detect_anomaly(True)

    # Single-process defaults; setdefault keeps values provided by a launcher.
    os.environ.setdefault('RANK', '0')
    os.environ.setdefault('WORLD_SIZE', '1')
    os.environ.setdefault('MASTER_ADDR', 'localhost')
    os.environ.setdefault('MASTER_PORT', '12355')

    # Initialize the distributed environment
    dist.init_process_group(backend='nccl')
    try:
        # Adjust model_parallel_size based on your setup
        fs_init.initialize_model_parallel(model_parallel_size_=1)
        try:
            _train()
        finally:
            fs_init.destroy_model_parallel()
    finally:
        dist.destroy_process_group()

# Start training only when this file is executed as a script (or notebook
# cell), not when it is imported as a module.
if __name__ == "__main__":
    main()


> initializing model parallel with size 1
> initializing ddp with size 1
> initializing pipeline with size 1


  File "<frozen runpy>", line 198, in _run_module_as_main
  File "<frozen runpy>", line 88, in _run_code
  File "/home/user/anaconda3/lib/python3.11/site-packages/ipykernel_launcher.py", line 17, in <module>
    app.launch_new_instance()
  File "/home/user/anaconda3/lib/python3.11/site-packages/traitlets/config/application.py", line 992, in launch_instance
    app.start()
  File "/home/user/anaconda3/lib/python3.11/site-packages/ipykernel/kernelapp.py", line 701, in start
    self.io_loop.start()
  File "/home/user/anaconda3/lib/python3.11/site-packages/tornado/platform/asyncio.py", line 195, in start
    self.asyncio_loop.run_forever()
  File "/home/user/anaconda3/lib/python3.11/asyncio/base_events.py", line 607, in run_forever
    self._run_once()
  File "/home/user/anaconda3/lib/python3.11/asyncio/base_events.py", line 1922, in _run_once
    handle._run()
  File "/home/user/anaconda3/lib/python3.11/asyncio/events.py", line 80, in _run
    self._context.run(self._callback, *self._arg

RuntimeError: Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward.