In [1]:
%pip install --upgrade jupyter ipywidgets
%pip install torch fairscale tiktoken==0.4.0 fair blobfile datasets mwparserfromhell

Note: you may need to restart the kernel to use updated packages.
Note: you may need to restart the kernel to use updated packages.


In [2]:
import logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

import torch
from common import save_to_disk, load_from_disk, save_layer_state_dict, load_layer_state_dict
from tokenizer import Tokenizer
from datasets import load_dataset
from torch.utils.data import DataLoader, IterableDataset
import torch.nn.functional as F
from tqdm import tqdm
import os
import time
from model import ModelArgs, TransformerLayer

class StreamingWikipediaDataset(IterableDataset):
    def __init__(self, texts, tokenizer, seq_len):
        self.texts = texts
        self.tokenizer = tokenizer
        self.seq_len = seq_len

    def parse(self, text):
        tokens = self.tokenizer.encode(text, bos=True, eos=True)
        tokens = tokens[:self.seq_len] + [self.tokenizer.pad_id] * (self.seq_len - len(tokens))
        return torch.tensor(tokens)

    def __iter__(self):
        for item in self.texts:
            yield self.parse(item['text'])

def wait_for_file(filename, timeout=30):
    start_time = time.time()
    while not os.path.exists(filename):
        if time.time() - start_time > timeout:
            raise TimeoutError(f"Timeout waiting for {filename}")
        time.sleep(0.1)

def initialize_layers(num_layers, vocab_size):
    for layer_idx in range(num_layers):
        model_args = ModelArgs(
            vocab_size=vocab_size,
            dim=512,
            n_layers=6,
            n_heads=8,
            ffn_dim_multiplier=4
        )
        layer = TransformerLayer(model_args.dim, model_args.n_heads, model_args.dim * model_args.ffn_dim_multiplier)
        save_layer_state_dict(layer.state_dict(), f"data/layer_{layer_idx}.pt")

def main():
    # Initialize the tokenizer
    tokenizer = Tokenizer(encoding_name='cl100k_base')
    
    # Load the dataset using streaming
    dataset = load_dataset("wikipedia", language="en", date="20240401", split='train', streaming=True)
    
    # Prepare the dataset and dataloader
    seq_len = 2048
    wiki_dataset = StreamingWikipediaDataset(dataset, tokenizer, seq_len)
    dataloader = DataLoader(wiki_dataset, batch_size=1, shuffle=False, num_workers=0)

    num_epochs = 3
    gradient_accumulation_steps = 4

    # Initialize layer state dictionaries
    initialize_layers(num_layers=6, vocab_size=tokenizer.get_vocab_size())

    for epoch in range(num_epochs):
        epoch_loss = 0

        for step, batch in enumerate(tqdm(dataloader, desc=f"Epoch {epoch + 1}/{num_epochs}")):
            try:
                process_batch(batch, tokenizer, gradient_accumulation_steps, epoch_loss)
            except Exception as e:
                logging.error(f"Error processing batch at step {step}: {e}")
                continue 

        print(f"Epoch {epoch + 1} Loss: {epoch_loss / len(dataloader)}")

def process_batch(batch, tokenizer, gradient_accumulation_steps, epoch_loss):
    try:
        batch = batch.to('cpu')
        save_to_disk(batch, "data/batch.pt")
        os.system(f"python worker.py --task embed --batch data/batch.pt")
        wait_for_file("data/inputs.pt")
        inputs = load_from_disk("data/inputs.pt")

        for layer_idx in range(6):
            worker_forward(layer_idx, inputs)
            wait_for_file(f"data/activations_layer_{layer_idx}.pt")
            inputs = load_from_disk(f"data/activations_layer_{layer_idx}.pt")

        worker_final_logits(inputs)
        wait_for_file("data/logits.pt")
        logits = load_from_disk("data/logits.pt")
        
        loss = F.cross_entropy(logits.view(-1, logits.size(-1)), batch.view(-1), ignore_index=tokenizer.pad_id)
        loss = loss / gradient_accumulation_steps
        epoch_loss += loss.item() * gradient_accumulation_steps

        logits.retain_grad()
        loss.backward()
        error = logits.grad
        for layer_idx in reversed(range(6)):
            worker_backward(layer_idx, error)
            wait_for_file(f"data/error_layer_{layer_idx}.pt")
            error_grads = load_from_disk(f"data/error_layer_{layer_idx}.pt")
            print("ERROR GRADS SHAPE: ", error_grads[0].shape)
            if error_grads is None or len(error_grads) != 1:
                print("ERROR GRADS: ", error_grads)
                raise ValueError(f"Failed to load error and grads from {f'data/error_layer_{layer_idx}.pt'}")
            error, grads = error_grads
            update_layer(layer_idx, grads)

    except Exception as e:
        logging.error(f"Failed to process and save batch: {e}")
        raise

def worker_forward(layer_idx, inputs):
    save_to_disk(inputs, f"data/inputs_layer_{layer_idx}.pt")
    os.system(f"python worker.py --task forward --layer_idx {layer_idx} --inputs data/inputs_layer_{layer_idx}.pt")

def worker_backward(layer_idx, error):
    save_to_disk(error, f"data/error_layer_{layer_idx}.pt")
    os.system(f"python worker.py --task backward --layer_idx {layer_idx} --error data/error_layer_{layer_idx}.pt")

def worker_final_logits(inputs):
    save_to_disk(inputs, "data/final_inputs.pt")
    os.system(f"python worker.py --task final_logits --inputs data/final_inputs.pt")

def update_layer(layer_idx, grads):
    state_dict = load_layer_state_dict(f"data/layer_{layer_idx}.pt")
    if state_dict is None:
        raise ValueError(f"Failed to load state dict for layer {layer_idx}")
    for param, grad in zip(state_dict.values(), grads):
        param.data -= grad
    save_layer_state_dict(state_dict, f"data/layer_{layer_idx}.pt")

if __name__ == "__main__":
    main()


2024-05-16 18:11:36,920 - INFO - NumExpr defaulting to 8 threads.


Layer state dict saved to data/layer_0.pt
Layer state dict saved to data/layer_1.pt
Layer state dict saved to data/layer_2.pt
Layer state dict saved to data/layer_3.pt
Layer state dict saved to data/layer_4.pt
Layer state dict saved to data/layer_5.pt


Epoch 1/3: 0it [00:00, ?it/s]2024-05-16 18:11:38,948 - INFO - generating examples from = https://dumps.wikimedia.org/enwiki/20240401/enwiki-20240401-pages-articles-multistream1.xml-p1p41242.bz2


Extracting content from https://dumps.wikimedia.org/enwiki/20240401/enwiki-20240401-pages-articles-multistream1.xml-p1p41242.bz2
Saved to data/batch.pt
Loaded from data/batch.pt
Saved to data/inputs.pt
Loaded from data/inputs.pt
Saved to data/inputs_layer_0.pt
Layer state dict loaded from data/layer_0.pt
Loaded from data/inputs_layer_0.pt




Saved to data/activations_layer_0.pt
Loaded from data/activations_layer_0.pt
Saved to data/inputs_layer_1.pt
Layer state dict loaded from data/layer_1.pt
Loaded from data/inputs_layer_1.pt




Saved to data/activations_layer_1.pt
Loaded from data/activations_layer_1.pt
Saved to data/inputs_layer_2.pt
Layer state dict loaded from data/layer_2.pt
Loaded from data/inputs_layer_2.pt




Saved to data/activations_layer_2.pt
Loaded from data/activations_layer_2.pt
Saved to data/inputs_layer_3.pt
Layer state dict loaded from data/layer_3.pt
Loaded from data/inputs_layer_3.pt




Saved to data/activations_layer_3.pt
Loaded from data/activations_layer_3.pt
Saved to data/inputs_layer_4.pt
Layer state dict loaded from data/layer_4.pt
Loaded from data/inputs_layer_4.pt




Saved to data/activations_layer_4.pt
Loaded from data/activations_layer_4.pt
Saved to data/inputs_layer_5.pt
Layer state dict loaded from data/layer_5.pt
Loaded from data/inputs_layer_5.pt




Saved to data/activations_layer_5.pt
Loaded from data/activations_layer_5.pt
Saved to data/final_inputs.pt
Loaded from data/final_inputs.pt
Saved to data/logits.pt
Loaded from data/logits.pt
Saved to data/error_layer_5.pt
Layer state dict loaded from data/layer_5.pt
Loaded from data/error_layer_5.pt




Running backward step
Shape of x: torch.Size([1, 2048, 100277])


2024-05-16 18:13:01,340 - ERROR - Failed to process and save batch: Failed to load error and grads from data/error_layer_5.pt
2024-05-16 18:13:01,356 - ERROR - Error processing batch at step 0: Failed to load error and grads from data/error_layer_5.pt
Epoch 1/3: 1it [01:22, 82.43s/it]

Loaded from data/error_layer_5.pt
ERROR GRADS SHAPE:  torch.Size([2048, 100277])
ERROR GRADS LEN:  1
ERROR GRADS:  tensor([[[1.0594e-09, 1.0891e-09, 1.5914e-09,  ..., 3.1967e-10,
          6.7050e-10, 5.4135e-10],
         [3.3859e-10, 8.9587e-10, 7.4177e-09,  ..., 3.6375e-10,
          4.8879e-10, 1.3713e-09],
         [4.4283e-10, 1.0711e-09, 1.3665e-10,  ..., 8.9546e-10,
          1.5640e-09, 1.2189e-09],
         ...,
         [4.2622e-10, 7.7656e-10, 1.6378e-09,  ..., 2.6334e-10,
          1.0670e-09, 7.9344e-10],
         [6.8544e-10, 4.5751e-10, 5.8359e-10,  ..., 2.0027e-09,
          3.3350e-10, 6.0974e-10],
         [7.1990e-10, 2.4505e-10, 2.0007e-09,  ..., 6.4085e-10,
          9.1735e-10, 1.4276e-09]]])
Saved to data/batch.pt
Loaded from data/batch.pt
Saved to data/inputs.pt
Loaded from data/inputs.pt
Saved to data/inputs_layer_0.pt
Layer state dict loaded from data/layer_0.pt
Loaded from data/inputs_layer_0.pt




Saved to data/activations_layer_0.pt
Loaded from data/activations_layer_0.pt
Saved to data/inputs_layer_1.pt
Layer state dict loaded from data/layer_1.pt
Loaded from data/inputs_layer_1.pt




Saved to data/activations_layer_1.pt
Loaded from data/activations_layer_1.pt
Saved to data/inputs_layer_2.pt
Layer state dict loaded from data/layer_2.pt
Loaded from data/inputs_layer_2.pt




Saved to data/activations_layer_2.pt
Loaded from data/activations_layer_2.pt
Saved to data/inputs_layer_3.pt
Layer state dict loaded from data/layer_3.pt
Loaded from data/inputs_layer_3.pt




Saved to data/activations_layer_3.pt
Loaded from data/activations_layer_3.pt
Saved to data/inputs_layer_4.pt
Layer state dict loaded from data/layer_4.pt
Loaded from data/inputs_layer_4.pt




Saved to data/activations_layer_4.pt
Loaded from data/activations_layer_4.pt
Saved to data/inputs_layer_5.pt
Layer state dict loaded from data/layer_5.pt
Loaded from data/inputs_layer_5.pt




Saved to data/activations_layer_5.pt
Loaded from data/activations_layer_5.pt
Saved to data/final_inputs.pt
Loaded from data/final_inputs.pt
Saved to data/logits.pt
Loaded from data/logits.pt
Saved to data/error_layer_5.pt
Layer state dict loaded from data/layer_5.pt
Loaded from data/error_layer_5.pt




Running backward step
Shape of x: torch.Size([1, 2048, 100277])
