In [1]:
%pip install --upgrade jupyter ipywidgets
%pip install torch fairscale tiktoken==0.4.0 fair blobfile datasets mwparserfromhell

Collecting ipywidgets
  Obtaining dependency information for ipywidgets from https://files.pythonhosted.org/packages/70/1a/7edeedb1c089d63ccd8bd5c0612334774e90cf9337de9fe6c82d90081791/ipywidgets-8.1.2-py3-none-any.whl.metadata
  Downloading ipywidgets-8.1.2-py3-none-any.whl.metadata (2.4 kB)
Collecting comm>=0.1.3 (from ipywidgets)
  Obtaining dependency information for comm>=0.1.3 from https://files.pythonhosted.org/packages/e6/75/49e5bfe642f71f272236b5b2d2691cf915a7283cc0ceda56357b61daa538/comm-0.2.2-py3-none-any.whl.metadata
  Downloading comm-0.2.2-py3-none-any.whl.metadata (3.7 kB)
Collecting widgetsnbextension~=4.0.10 (from ipywidgets)
  Obtaining dependency information for widgetsnbextension~=4.0.10 from https://files.pythonhosted.org/packages/99/bc/82a8c3985209ca7c0a61b383c80e015fd92e74f8ba0ec1af98f9d6ca8dce/widgetsnbextension-4.0.10-py3-none-any.whl.metadata
  Downloading widgetsnbextension-4.0.10-py3-none-any.whl.metadata (1.6 kB)
Collecting jupyterlab-widgets~=3.0.10 (from i

In [2]:
import logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

import torch
from common import save_to_disk, load_from_disk, save_layer_state_dict, load_layer_state_dict
from tokenizer import Tokenizer
from datasets import load_dataset
from torch.utils.data import DataLoader, IterableDataset
import torch.nn.functional as F
from tqdm import tqdm
import os
import time
from model import ModelArgs, TransformerLayer

class StreamingWikipediaDataset(IterableDataset):
    def __init__(self, texts, tokenizer, seq_len):
        self.texts = texts
        self.tokenizer = tokenizer
        self.seq_len = seq_len

    def parse(self, text):
        tokens = self.tokenizer.encode(text, bos=True, eos=True)
        tokens = tokens[:self.seq_len] + [self.tokenizer.pad_id] * (self.seq_len - len(tokens))
        return torch.tensor(tokens)

    def __iter__(self):
        for item in self.texts:
            yield self.parse(item['text'])

def wait_for_file(filename, timeout=30):
    start_time = time.time()
    while not os.path.exists(filename):
        if time.time() - start_time > timeout:
            raise TimeoutError(f"Timeout waiting for {filename}")
        time.sleep(0.1)

def initialize_layers(num_layers, vocab_size):
    for layer_idx in range(num_layers):
        model_args = ModelArgs(
            vocab_size=vocab_size,
            dim=512,
            n_layers=6,
            n_heads=8,
            ffn_dim_multiplier=4
        )
        layer = TransformerLayer(model_args.dim, model_args.n_heads, model_args.dim * model_args.ffn_dim_multiplier)
        save_layer_state_dict(layer.state_dict(), f"data/layer_{layer_idx}.pt")

def main():
    # Initialize the tokenizer
    tokenizer = Tokenizer(encoding_name='cl100k_base')
    
    # Load the dataset using streaming
    dataset = load_dataset("wikipedia", language="en", date="20240401", split='train', streaming=True)
    
    # Prepare the dataset and dataloader
    seq_len = 2048
    wiki_dataset = StreamingWikipediaDataset(dataset, tokenizer, seq_len)
    dataloader = DataLoader(wiki_dataset, batch_size=1, shuffle=False, num_workers=0)

    num_epochs = 3
    gradient_accumulation_steps = 4

    # Initialize layer state dictionaries
    initialize_layers(num_layers=6, vocab_size=tokenizer.get_vocab_size())

    for epoch in range(num_epochs):
        epoch_loss = 0

        for step, batch in enumerate(tqdm(dataloader, desc=f"Epoch {epoch + 1}/{num_epochs}")):
            try:
                process_batch(batch, tokenizer, gradient_accumulation_steps, epoch_loss)
            except Exception as e:
                logging.error(f"Error processing batch at step {step}: {e}")
                continue 

        print(f"Epoch {epoch + 1} Loss: {epoch_loss / len(dataloader)}")

def process_batch(batch, tokenizer, gradient_accumulation_steps, epoch_loss):
    try:
        batch = batch.to('cpu')
        save_to_disk(batch, "data/batch.pt")
        os.system(f"python worker.py --task embed --batch data/batch.pt")
        wait_for_file("data/inputs.pt")
        inputs = load_from_disk("data/inputs.pt")

        for layer_idx in range(6):
            worker_forward(layer_idx, inputs)
            wait_for_file(f"data/activations_layer_{layer_idx}.pt")
            inputs = load_from_disk(f"data/activations_layer_{layer_idx}.pt")

        worker_final_logits(inputs)
        wait_for_file("data/logits.pt")
        logits = load_from_disk("data/logits.pt")
        
        loss = F.cross_entropy(logits.view(-1, logits.size(-1)), batch.view(-1), ignore_index=tokenizer.pad_id)
        loss = loss / gradient_accumulation_steps
        epoch_loss += loss.item() * gradient_accumulation_steps

        logits.retain_grad()
        loss.backward()
        error = logits.grad
        for layer_idx in reversed(range(6)):
            worker_backward(layer_idx, error)
            wait_for_file(f"data/error_layer_{layer_idx}.pt")
            error_grads = load_from_disk(f"data/error_layer_{layer_idx}.pt")
            if error_grads is None or len(error_grads) != 2:
                print("ERROR GRADS: ", error_grads)
                raise ValueError(f"Failed to load error and grads from {f'data/error_layer_{layer_idx}.pt'}")
            error, grads = error_grads
            update_layer(layer_idx, grads)

    except Exception as e:
        logging.error(f"Failed to process and save batch: {e}")
        raise

def worker_forward(layer_idx, inputs):
    save_to_disk(inputs, f"data/inputs_layer_{layer_idx}.pt")
    os.system(f"python worker.py --task forward --layer_idx {layer_idx} --inputs data/inputs_layer_{layer_idx}.pt")

def worker_backward(layer_idx, error):
    save_to_disk(error, f"data/error_layer_{layer_idx}.pt")
    os.system(f"python worker.py --task backward --layer_idx {layer_idx} --error data/error_layer_{layer_idx}.pt")

def worker_final_logits(inputs):
    save_to_disk(inputs, "data/final_inputs.pt")
    os.system(f"python worker.py --task final_logits --inputs data/final_inputs.pt")

def update_layer(layer_idx, grads):
    state_dict = load_layer_state_dict(f"data/layer_{layer_idx}.pt")
    if state_dict is None:
        raise ValueError(f"Failed to load state dict for layer {layer_idx}")
    for param, grad in zip(state_dict.values(), grads):
        param.data -= grad
    save_layer_state_dict(state_dict, f"data/layer_{layer_idx}.pt")

if __name__ == "__main__":
    main()


2024-05-16 16:47:42,007 - INFO - NumExpr defaulting to 8 threads.


RuntimeError: Parent directory data does not exist.