In [1]:
%pip install --upgrade jupyter ipywidgets
%pip install torch fairscale tiktoken==0.4.0 fair blobfile datasets mwparserfromhell

Note: you may need to restart the kernel to use updated packages.
Note: you may need to restart the kernel to use updated packages.


In [2]:
import logging
import subprocess
import torch
import torch.distributed as dist
from common import save_to_disk, load_from_disk, save_layer_state_dict, load_layer_state_dict, model_args, tokenizer
from tokenizer import Tokenizer
from datasets import load_dataset
from torch.utils.data import DataLoader, IterableDataset
import torch.nn as nn
import os
import time
from tqdm import tqdm
from model import TransformerBlock, VocabParallelEmbedding, ColumnParallelLinear, RMSNorm, Transformer, precompute_freqs_cis
from fairscale.nn.model_parallel.initialize import initialize_model_parallel, model_parallel_is_initialized
from fairscale.nn.model_parallel.layers import (
    ColumnParallelLinear,
    RowParallelLinear,
    VocabParallelEmbedding,
)

logging.basicConfig(level=logging.DEBUG)


class StreamingWikipediaDataset(IterableDataset):
    def __init__(self, texts, tokenizer, seq_len):
        self.texts = texts
        self.tokenizer = tokenizer
        self.seq_len = seq_len

    def parse(self, text):
        tokens = self.tokenizer.encode(text, bos=True, eos=True)
        tokens = tokens[:self.seq_len] + [self.tokenizer.pad_id] * (self.seq_len - len(tokens))
        return torch.tensor(tokens)

    def __iter__(self):
        for item in self.texts:
            yield self.parse(item['text'])


def wait_for_file(filename, timeout=30):
    start_time = time.time()
    while not os.path.exists(filename):
        if time.time() - start_time > timeout:
            raise TimeoutError(f"Timeout waiting for {filename}")
        time.sleep(0.1)


def initialize_layers(model_args):
    def init_weights(m):
        if isinstance(m, (nn.Linear, ColumnParallelLinear, RowParallelLinear, VocabParallelEmbedding)):
            nn.init.xavier_uniform_(m.weight)
            if hasattr(m, 'bias') and m.bias is not None:
                nn.init.zeros_(m.bias)

    for layer_idx in range(model_args.n_layers):
        layer = TransformerBlock(layer_idx, model_args)
        layer.apply(init_weights)

        for param in layer.parameters():
            if torch.isnan(param).any() or torch.isinf(param).any():
                raise ValueError(f"NaNs or Infs detected in weights of layer {layer_idx}")

        save_layer_state_dict(layer.state_dict(), f"data/layer_{layer_idx}.pt")

    embedding = VocabParallelEmbedding(model_args.vocab_size, model_args.dim)
    embedding.apply(init_weights)

    for param in embedding.parameters():
        if torch.isnan(param).any() or torch.isinf(param).any():
            raise ValueError("NaNs or Infs detected in embedding weights")

    save_layer_state_dict(embedding.state_dict(), "data/embedding.pt")

    output = ColumnParallelLinear(model_args.dim, model_args.vocab_size, bias=False)
    output.apply(init_weights)

    for param in output.parameters():
        if torch.isnan(param).any() or torch.isinf(param).any():
            raise ValueError("NaNs or Infs detected in output weights")

    save_layer_state_dict(output.state_dict(), "data/output.pt")

    freqs_cis = precompute_freqs_cis(
        model_args.dim // model_args.n_heads,
        model_args.max_seq_len * 2,
        model_args.rope_theta,
    )
    save_to_disk(freqs_cis, "data/freqs_cis.pt")


def process_batch(batch, tokenizer, epoch_loss, learning_rate, beta1, beta2, epsilon, weight_decay, t):
    try:
        batch = batch.to('cpu')
        if torch.isnan(batch).any() or torch.isinf(batch).any():
            raise ValueError("NaNs or Infs detected in input batch")

        logging.debug(f"Batch (token IDs): {batch}")
        decoded_batch = [tokenizer.decode(tokens.tolist()) for tokens in batch]
        logging.debug(f"Batch (decoded text): {decoded_batch}")

        inputs = batch[:, :-1]
        targets = batch[:, 1:]

        save_to_disk(inputs, "data/inputs.pt")
        save_to_disk(targets, "data/targets.pt")
        
        run_command(f"python3 worker.py --task embed --batch data/inputs.pt --embedding_file data/embedding.pt --inputs data/inputs_embed.pt")
        wait_for_file("data/inputs_embed.pt")
        inputs_embed = load_from_disk("data/inputs_embed.pt")

        logging.debug(f"Embedded inputs: {inputs_embed}")

        check_for_nans(inputs_embed, "inputs after embedding")

        freqs_cis = load_from_disk("data/freqs_cis.pt")
        save_to_disk(freqs_cis, "data/freqs_cis.pt")

        seqlen = inputs_embed.shape[1]
        mask = None
        if seqlen > 1:
            start_pos = 0
            mask = torch.full((seqlen, seqlen), float('-inf'))
            mask = torch.triu(mask, diagonal=1)
            mask = torch.cat([torch.zeros((seqlen, start_pos)), mask], dim=1)
            save_to_disk(mask, "data/mask.pt")

        for layer_idx in range(model_args.n_layers):
            inputs_file = f"data/inputs_layer_{layer_idx}.pt"
            logits_file = f"data/logits_layer_{layer_idx}.pt"
            state_dict_file = f"data/layer_{layer_idx}.pt"
            save_to_disk(inputs_embed, inputs_file)
            run_command(f"python3 worker.py --task forward --layer_idx {layer_idx} --inputs {inputs_file} --state_dict {state_dict_file} --freqs_cis data/freqs_cis.pt --logits_file {logits_file} --mask data/mask.pt")
            wait_for_file(logits_file)
            inputs_embed = load_from_disk(logits_file)

            logging.debug(f"Inputs after layer {layer_idx}: {inputs_embed}")

            if torch.isnan(inputs_embed).any() or torch.isinf(inputs_embed).any():
                raise ValueError(f"NaNs or Infs detected in inputs after layer {layer_idx}")

            check_for_nans(inputs_embed, f"inputs after layer {layer_idx}")

        norm = RMSNorm(model_args.dim, eps=model_args.norm_eps)
        inputs_embed = norm(inputs_embed)
        save_to_disk(inputs_embed, "data/final_inputs.pt")

        run_command(f"python3 worker.py --task final_logits --inputs data/final_inputs.pt --state_dict data/output.pt --logits_file data/logits.pt")
        wait_for_file("data/logits.pt")
        logits = load_from_disk("data/logits.pt")

        logging.debug(f"Final logits: {logits}")

        if torch.isnan(logits).any() or torch.isinf(logits).any():
            raise ValueError("NaNs or Infs detected in final logits")

        check_for_nans(logits, "final logits")

        save_to_disk(logits, "data/logits.pt")
        run_command(f"python3 worker.py --task loss --logits data/logits.pt --targets data/targets.pt --loss_file data/loss.pt --logits_grad_file data/logits_grad.pt")
        wait_for_file("data/loss.pt")
        wait_for_file("data/logits_grad.pt")

        loss = load_from_disk("data/loss.pt")
        logging.debug(f"Loss: {loss:.4f}")
        print(f"Loss: {loss:.4f}")
        epoch_loss += loss
        logits_grad = load_from_disk("data/logits_grad.pt")

        logging.debug(f"Logits gradients: {logits_grad}")

        save_to_disk(logits_grad, "data/final_error.pt")
        run_command(f"python3 worker.py --task final_logits_backward --error data/final_error.pt --inputs data/final_inputs.pt --state_dict data/output.pt --error_output_file data/error_output.pt")


    except Exception as e:
        logging.error(f"Failed to process and save batch: {e}")
        raise


def run_command(command):
    try:
        result = subprocess.run(command, shell=True, text=True, capture_output=True, check=True)
        logging.info(result.stdout)
    except subprocess.CalledProcessError as e:
        logging.error(f"Command '{command}' failed with error: {e.stderr}")
        raise


def main():
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'
    os.environ['WORLD_SIZE'] = '1'
    os.environ['RANK'] = '0'

    dist.init_process_group(backend='nccl')
    initialize_model_parallel(model_parallel_size_=1)

    dataset = load_dataset("wikipedia", language="en", date="20240401", split='train', streaming=True, trust_remote_code=True)

    seq_len = 2048
    wiki_dataset = StreamingWikipediaDataset(dataset, tokenizer, seq_len)
    dataloader = DataLoader(wiki_dataset, batch_size=1, shuffle=False, num_workers=0)

    num_epochs = 3
    learning_rate = 1e-4
    beta1 = 0.9
    beta2 = 0.999
    epsilon = 1e-8
    weight_decay = 1e-2

    initialize_layers(model_args)

    for epoch in range(num_epochs):
        epoch_loss = 0

        for step, batch in enumerate(tqdm(dataloader, desc=f"Epoch {epoch + 1}/{num_epochs}")):
            try:
                process_batch(batch, tokenizer, epoch_loss, learning_rate, beta1, beta2, epsilon, weight_decay, step + 1)
            except Exception as e:
                logging.error(f"Error processing batch at step {step}: {e}")
                continue

        print(f"Epoch {epoch + 1} Loss: {epoch_loss / len(dataloader)}")

        dist.destroy_process_group()


def check_for_nans(tensor, name):
    if torch.isnan(tensor).any():
        print(f"NaNs detected in {name}")


if __name__ == "__main__":
    main()


SyntaxError: unmatched ')' (1639551878.py, line 153)