In [1]:
%pip install --upgrade jupyter ipywidgets
%pip install torch fairscale tiktoken==0.4.0 fair blobfile datasets mwparserfromhell

Note: you may need to restart the kernel to use updated packages.
Note: you may need to restart the kernel to use updated packages.


In [2]:
%pip install --upgrade jupyter ipywidgets
%pip install torch fairscale tiktoken==0.4.0 fair blobfile datasets mwparserfromhell

import logging
import subprocess
import torch
from common import save_to_disk, load_from_disk, save_layer_state_dict, load_layer_state_dict, model_args  # Import global model_args
from tokenizer import Tokenizer
from datasets import load_dataset
from torch.utils.data import DataLoader, IterableDataset
import torch.nn.functional as F
from tqdm import tqdm
import os
import time
from model import TransformerLayer

class StreamingWikipediaDataset(IterableDataset):
    """Iterable dataset that turns streamed records into fixed-length token rows.

    Every element of ``texts`` is indexed with ``['text']``; the resulting
    string is tokenized and returned as a 1-D tensor holding exactly
    ``seq_len`` token ids.
    """

    def __init__(self, texts, tokenizer, seq_len):
        """Store the record stream, the tokenizer and the target row length."""
        self.seq_len = seq_len
        self.tokenizer = tokenizer
        self.texts = texts

    def parse(self, text):
        """Encode ``text`` with BOS/EOS markers, then clip or right-pad to ``seq_len``."""
        token_ids = self.tokenizer.encode(text, bos=True, eos=True)
        missing = self.seq_len - len(token_ids)
        clipped = token_ids[:self.seq_len]
        # A negative ``missing`` repeats the pad list zero times, so rows that
        # were clipped receive no padding at all.
        padding = [self.tokenizer.pad_id] * missing
        return torch.tensor(clipped + padding)

    def __iter__(self):
        """Yield one tokenized tensor per record, in stream order."""
        for record in self.texts:
            article = record['text']
            yield self.parse(article)

def wait_for_file(filename, timeout=30, poll_interval=0.1):
    """Block until ``filename`` exists on disk.

    Args:
        filename: Path to poll for.
        timeout: Maximum number of seconds to wait for the file.
        poll_interval: Seconds to sleep between existence checks.

    Raises:
        TimeoutError: If the file has still not appeared once more than
            ``timeout`` seconds have elapsed.
    """
    # Use the monotonic clock: wall-clock time can jump (NTP, manual changes),
    # which would make the timeout fire early or never.
    deadline = time.monotonic() + timeout
    while not os.path.exists(filename):
        remaining = deadline - time.monotonic()
        if remaining < 0:
            raise TimeoutError(f"Timeout waiting for {filename}")
        # Never sleep past the deadline, so the timeout is honoured closely
        # even when poll_interval is large.
        time.sleep(min(poll_interval, remaining))

def initialize_layers(num_layers, vocab_size):
    """Create ``num_layers`` fresh transformer layers and save each state dict.

    Layer ``i`` is written to ``data/layer_{i}.pt``. ``vocab_size`` is accepted
    but not used by this function.
    """
    for idx in range(num_layers):
        hidden_dim = model_args.dim * model_args.ffn_dim_multiplier
        fresh_layer = TransformerLayer(model_args.dim, model_args.n_heads, hidden_dim)
        checkpoint_path = f"data/layer_{idx}.pt"
        save_layer_state_dict(fresh_layer.state_dict(), checkpoint_path)

def main():
    """Stream Wikipedia, push every batch through the worker pipeline, report epoch loss.

    A batch that raises is logged and skipped so one failure does not abort the
    run. The epoch loss is the mean of the losses returned by ``process_batch``.
    The streaming dataset defines no ``__len__``, so batches are counted here
    instead of calling ``len(dataloader)``, which raises ``TypeError`` for it.
    """
    # Initialize the tokenizer
    tokenizer = Tokenizer(encoding_name='cl100k_base')

    # Load the dataset using streaming
    dataset = load_dataset("wikipedia", language="en", date="20240401", split='train', streaming=True)

    # Prepare the dataset and dataloader
    seq_len = 2048
    wiki_dataset = StreamingWikipediaDataset(dataset, tokenizer, seq_len)
    dataloader = DataLoader(wiki_dataset, batch_size=1, shuffle=False, num_workers=0)

    num_epochs = 3
    gradient_accumulation_steps = 4

    # Initialize layer state dictionaries
    initialize_layers(num_layers=model_args.n_layers, vocab_size=tokenizer.get_vocab_size())

    for epoch in range(num_epochs):
        epoch_loss = 0.0
        reported_batches = 0  # batches whose loss was actually returned
        failed_batches = 0

        for step, batch in enumerate(tqdm(dataloader, desc=f"Epoch {epoch + 1}/{num_epochs}")):
            try:
                batch_loss = process_batch(batch, tokenizer, gradient_accumulation_steps, epoch_loss)
            except Exception as e:
                failed_batches += 1
                logging.error(f"Error processing batch at step {step}: {e}")
                continue

            # A float passed into process_batch cannot be updated in place by
            # the callee, so the loss has to come back as a return value.
            # ``None`` is tolerated for a process_batch that returns nothing.
            if batch_loss is not None:
                epoch_loss += batch_loss
                reported_batches += 1

        if failed_batches:
            logging.warning(f"Epoch {epoch + 1}: {failed_batches} batch(es) failed and were skipped")

        # Mean over batches that reported a loss; NaN makes "nothing succeeded"
        # visible instead of printing a misleading 0.
        mean_loss = epoch_loss / reported_batches if reported_batches else float("nan")
        print(f"Epoch {epoch + 1} Loss: {mean_loss}")

def process_batch(batch, tokenizer, gradient_accumulation_steps, epoch_loss):
    """Run one forward and backward pass for ``batch`` through ``worker.py``.

    Tensors are exchanged with the worker through files under ``data/``: the
    batch is embedded, passed through every layer in order, turned into logits,
    and the resulting error is sent back through the layers in reverse order.
    Each layer's returned gradients are applied with ``update_layer``.

    Args:
        batch: Tensor of token ids; also used as the cross-entropy target.
        tokenizer: Supplies ``pad_id``, which is ignored in the loss.
        gradient_accumulation_steps: Divisor applied to the loss before
            ``backward()``.
        epoch_loss: Unused; kept so existing callers keep working. A number
            passed here cannot be updated in place, so read the return value.

    Returns:
        float: The un-scaled cross-entropy loss of this batch.

    Raises:
        ValueError: If a backward result is not an ``(error, grads)`` pair.
        Exception: Any other failure is logged and re-raised unchanged.
    """
    try:
        batch = batch.to('cpu')
        save_to_disk(batch, "data/batch.pt")
        run_command(f"python3 worker.py --task embed --batch data/batch.pt")
        wait_for_file("data/inputs.pt")
        inputs = load_from_disk("data/inputs.pt")

        for layer_idx in range(model_args.n_layers):
            worker_forward(layer_idx, inputs)
            wait_for_file(f"data/activations_layer_{layer_idx}.pt")
            inputs = load_from_disk(f"data/activations_layer_{layer_idx}.pt")

        worker_final_logits(inputs)
        wait_for_file("data/logits.pt")
        logits = load_from_disk("data/logits.pt")
        # Whether the loaded tensor tracks gradients depends on how the worker
        # saved it. Re-wrap it as a leaf that requires grad so ``logits.grad``
        # is populated by ``backward()`` either way.
        logits = logits.detach().requires_grad_(True)

        # NOTE(review): targets are the un-shifted input ids. If the model is a
        # causal LM, logits[:, :-1] should probably be scored against
        # batch[:, 1:] -- confirm against worker.py / model.py.
        batch_loss = F.cross_entropy(logits.view(-1, logits.size(-1)), batch.view(-1), ignore_index=tokenizer.pad_id)
        loss = batch_loss / gradient_accumulation_steps

        loss.backward()
        error = logits.grad
        for layer_idx in reversed(range(model_args.n_layers)):
            worker_backward(layer_idx, error)
            error_file = f"data/error_layer_{layer_idx}.pt"
            wait_for_file(error_file)
            error_grads = load_from_disk(error_file)
            # Validate before touching the payload so a bad file raises the
            # descriptive ValueError rather than a TypeError from the prints.
            if error_grads is None or len(error_grads) != 2:
                print("ERROR GRADS: ", error_grads)
                raise ValueError(f"Failed to load error and grads from {error_file}")
            print("ERROR GRADS SHAPE: ", error_grads[0].shape)
            print("ERROR GRADS LEN: ", len(error_grads))
            error, grads = error_grads
            update_layer(layer_idx, grads)

        return batch_loss.item()

    except Exception as e:
        logging.error(f"Failed to process and save batch: {e}")
        raise

def worker_forward(layer_idx, inputs):
    """Save ``inputs`` for layer ``layer_idx`` and invoke the worker's ``forward`` task on that file."""
    inputs_path = f"data/inputs_layer_{layer_idx}.pt"
    save_to_disk(inputs, inputs_path)
    command = f"python3 worker.py --task forward --layer_idx {layer_idx} --inputs {inputs_path}"
    run_command(command)

def worker_backward(layer_idx, error):
    """Save the incoming ``error`` for layer ``layer_idx`` and invoke the worker's ``backward`` task on it."""
    error_path = "data/input_error_layer_{}.pt".format(layer_idx)
    save_to_disk(error, error_path)
    cli_parts = [
        "python3 worker.py",
        "--task backward",
        f"--layer_idx {layer_idx}",
        f"--error {error_path}",
    ]
    run_command(" ".join(cli_parts))

def worker_final_logits(inputs):
    """Save the last layer's output and invoke the worker's ``final_logits`` task on that file."""
    final_inputs_path = "data/final_inputs.pt"
    save_to_disk(inputs, final_inputs_path)
    run_command("python3 worker.py --task final_logits --inputs " + final_inputs_path)

def update_layer(layer_idx, grads, learning_rate=1.0):
    """Apply one gradient-descent step to the saved weights of layer ``layer_idx``.

    The state dict is loaded from ``data/layer_{layer_idx}.pt``, each entry is
    decremented by its paired gradient scaled by ``learning_rate``, and the
    result is written back to the same file.

    Args:
        layer_idx: Index of the layer whose checkpoint is updated.
        grads: Iterable of gradients, paired positionally with the state-dict
            values.
        learning_rate: Step size. The default of 1.0 subtracts the raw gradient.

    Raises:
        ValueError: If the layer's state dict cannot be loaded.
    """
    state_path = f"data/layer_{layer_idx}.pt"
    state_dict = load_layer_state_dict(state_path)
    if state_dict is None:
        raise ValueError(f"Failed to load state dict for layer {layer_idx}")

    params = list(state_dict.values())
    grads = list(grads)
    if len(params) != len(grads):
        # zip() below stops at the shorter sequence, so a mismatch would leave
        # some entries un-updated or mis-paired without any sign.
        # NOTE(review): if the state dict holds non-parameter buffers, the
        # positional pairing may already be off -- confirm against worker.py.
        logging.warning(
            f"Layer {layer_idx}: {len(grads)} gradients for {len(params)} state-dict entries; "
            "updating only the overlapping prefix"
        )

    for param, grad in zip(params, grads):
        # Skip the multiply at the default rate so the step is exactly the
        # raw gradient, with no dtype promotion.
        step = grad if learning_rate == 1 else grad * learning_rate
        param.data -= step
    save_layer_state_dict(state_dict, state_path)

def run_command(command):
    """Run ``command`` through the shell, logging its stdout on success.

    Args:
        command: Shell command line. It is executed with ``shell=True``, so it
            must never be built from untrusted input.

    Raises:
        subprocess.CalledProcessError: If the command exits non-zero. The
            failure is logged before the exception is re-raised.
    """
    try:
        result = subprocess.run(command, shell=True, text=True, capture_output=True, check=True)
    except subprocess.CalledProcessError as e:
        # Include the return code: a process killed by a signal (e.g. SIGKILL
        # from the OOM killer) usually leaves stderr empty, and a negative
        # return code is then the only clue to what happened.
        logging.error(f"Command '{command}' failed with return code {e.returncode}: {e.stderr}")
        raise
    logging.info(result.stdout)

# Entry point: start training when this code runs as the top-level program.
# In a notebook kernel __name__ is also "__main__", so executing the cell
# launches main() immediately.
if __name__ == "__main__":
    main()


2024-05-16 19:10:07,422 - INFO - NumExpr defaulting to 8 threads.


Layer state dict saved to data/layer_0.pt
Layer state dict saved to data/layer_1.pt
Layer state dict saved to data/layer_2.pt
Layer state dict saved to data/layer_3.pt
Layer state dict saved to data/layer_4.pt
Layer state dict saved to data/layer_5.pt


Epoch 1/3: 0it [00:00, ?it/s]2024-05-16 19:10:09,606 - INFO - generating examples from = https://dumps.wikimedia.org/enwiki/20240401/enwiki-20240401-pages-articles-multistream1.xml-p1p41242.bz2


Extracting content from https://dumps.wikimedia.org/enwiki/20240401/enwiki-20240401-pages-articles-multistream1.xml-p1p41242.bz2
Saved to data/batch.pt


2024-05-16 19:10:46,710 - INFO - Loaded from data/batch.pt
Saved to data/inputs.pt



Loaded from data/inputs.pt
Saved to data/inputs_layer_0.pt


2024-05-16 19:10:47,903 - INFO - Layer state dict loaded from data/layer_0.pt
Loaded from data/inputs_layer_0.pt
Saved to data/activations_layer_0.pt



Loaded from data/activations_layer_0.pt
Saved to data/inputs_layer_1.pt


2024-05-16 19:10:49,046 - INFO - Layer state dict loaded from data/layer_1.pt
Loaded from data/inputs_layer_1.pt
Saved to data/activations_layer_1.pt



Loaded from data/activations_layer_1.pt
Saved to data/inputs_layer_2.pt


2024-05-16 19:10:50,281 - INFO - Layer state dict loaded from data/layer_2.pt
Loaded from data/inputs_layer_2.pt
Saved to data/activations_layer_2.pt



Loaded from data/activations_layer_2.pt
Saved to data/inputs_layer_3.pt


2024-05-16 19:10:51,798 - INFO - Layer state dict loaded from data/layer_3.pt
Loaded from data/inputs_layer_3.pt
Saved to data/activations_layer_3.pt



Loaded from data/activations_layer_3.pt
Saved to data/inputs_layer_4.pt


2024-05-16 19:10:53,044 - INFO - Layer state dict loaded from data/layer_4.pt
Loaded from data/inputs_layer_4.pt
Saved to data/activations_layer_4.pt



Loaded from data/activations_layer_4.pt
Saved to data/inputs_layer_5.pt


2024-05-16 19:10:54,171 - INFO - Layer state dict loaded from data/layer_5.pt
Loaded from data/inputs_layer_5.pt
Saved to data/activations_layer_5.pt



Loaded from data/activations_layer_5.pt
Saved to data/final_inputs.pt


2024-05-16 19:10:56,799 - INFO - Loaded from data/final_inputs.pt
Saved to data/logits.pt



Loaded from data/logits.pt
Saved to data/input_error_layer_5.pt



2024-05-16 19:11:34,381 - ERROR - Failed to process and save batch: Command 'python3 worker.py --task backward --layer_idx 5 --error data/input_error_layer_5.pt' died with <Signals.SIGKILL: 9>.
2024-05-16 19:11:34,384 - ERROR - Error processing batch at step 0: Command 'python3 worker.py --task backward --layer_idx 5 --error data/input_error_layer_5.pt' died with <Signals.SIGKILL: 9>.
Epoch 1/3: 1it [01:24, 84.80s/it]

Saved to data/batch.pt


2024-05-16 19:11:36,997 - INFO - Loaded from data/batch.pt
Saved to data/inputs.pt



Loaded from data/inputs.pt
Saved to data/inputs_layer_0.pt


2024-05-16 19:11:38,204 - INFO - Layer state dict loaded from data/layer_0.pt
Loaded from data/inputs_layer_0.pt
Saved to data/activations_layer_0.pt



Loaded from data/activations_layer_0.pt
Saved to data/inputs_layer_1.pt


2024-05-16 19:11:39,390 - INFO - Layer state dict loaded from data/layer_1.pt
Loaded from data/inputs_layer_1.pt
Saved to data/activations_layer_1.pt



Loaded from data/activations_layer_1.pt
Saved to data/inputs_layer_2.pt


2024-05-16 19:11:40,563 - INFO - Layer state dict loaded from data/layer_2.pt
Loaded from data/inputs_layer_2.pt
Saved to data/activations_layer_2.pt



Loaded from data/activations_layer_2.pt
Saved to data/inputs_layer_3.pt


2024-05-16 19:11:41,713 - INFO - Layer state dict loaded from data/layer_3.pt
Loaded from data/inputs_layer_3.pt
Saved to data/activations_layer_3.pt



Loaded from data/activations_layer_3.pt
Saved to data/inputs_layer_4.pt


2024-05-16 19:11:42,859 - INFO - Layer state dict loaded from data/layer_4.pt
Loaded from data/inputs_layer_4.pt
Saved to data/activations_layer_4.pt



Loaded from data/activations_layer_4.pt
Saved to data/inputs_layer_5.pt


2024-05-16 19:11:43,987 - INFO - Layer state dict loaded from data/layer_5.pt
Loaded from data/inputs_layer_5.pt
Saved to data/activations_layer_5.pt



Loaded from data/activations_layer_5.pt
Saved to data/final_inputs.pt


2024-05-16 19:11:46,518 - INFO - Loaded from data/final_inputs.pt
Saved to data/logits.pt



Loaded from data/logits.pt
Saved to data/input_error_layer_5.pt



2024-05-16 19:12:21,831 - ERROR - Failed to process and save batch: Command 'python3 worker.py --task backward --layer_idx 5 --error data/input_error_layer_5.pt' died with <Signals.SIGKILL: 9>.
2024-05-16 19:12:21,833 - ERROR - Error processing batch at step 1: Command 'python3 worker.py --task backward --layer_idx 5 --error data/input_error_layer_5.pt' died with <Signals.SIGKILL: 9>.
Epoch 1/3: 2it [02:12, 62.84s/it]

Saved to data/batch.pt


2024-05-16 19:12:24,635 - INFO - Loaded from data/batch.pt
Saved to data/inputs.pt



Loaded from data/inputs.pt
Saved to data/inputs_layer_0.pt


2024-05-16 19:12:25,831 - INFO - Layer state dict loaded from data/layer_0.pt
Loaded from data/inputs_layer_0.pt
Saved to data/activations_layer_0.pt



Loaded from data/activations_layer_0.pt
Saved to data/inputs_layer_1.pt


2024-05-16 19:12:27,000 - INFO - Layer state dict loaded from data/layer_1.pt
Loaded from data/inputs_layer_1.pt
Saved to data/activations_layer_1.pt



Loaded from data/activations_layer_1.pt
Saved to data/inputs_layer_2.pt


2024-05-16 19:12:28,163 - INFO - Layer state dict loaded from data/layer_2.pt
Loaded from data/inputs_layer_2.pt
Saved to data/activations_layer_2.pt



Loaded from data/activations_layer_2.pt
Saved to data/inputs_layer_3.pt


2024-05-16 19:12:29,358 - INFO - Layer state dict loaded from data/layer_3.pt
Loaded from data/inputs_layer_3.pt
Saved to data/activations_layer_3.pt



Loaded from data/activations_layer_3.pt
Saved to data/inputs_layer_4.pt


2024-05-16 19:12:30,548 - INFO - Layer state dict loaded from data/layer_4.pt
Loaded from data/inputs_layer_4.pt
Saved to data/activations_layer_4.pt



Loaded from data/activations_layer_4.pt
Saved to data/inputs_layer_5.pt


2024-05-16 19:12:31,730 - INFO - Layer state dict loaded from data/layer_5.pt
Loaded from data/inputs_layer_5.pt
Saved to data/activations_layer_5.pt



Loaded from data/activations_layer_5.pt
Saved to data/final_inputs.pt


2024-05-16 19:12:34,406 - INFO - Loaded from data/final_inputs.pt
Saved to data/logits.pt



Loaded from data/logits.pt
Saved to data/input_error_layer_5.pt



2024-05-16 19:13:11,645 - ERROR - Failed to process and save batch: Command 'python3 worker.py --task backward --layer_idx 5 --error data/input_error_layer_5.pt' died with <Signals.SIGKILL: 9>.
2024-05-16 19:13:11,646 - ERROR - Error processing batch at step 2: Command 'python3 worker.py --task backward --layer_idx 5 --error data/input_error_layer_5.pt' died with <Signals.SIGKILL: 9>.
Epoch 1/3: 3it [03:02, 56.95s/it]

Saved to data/batch.pt


2024-05-16 19:13:14,879 - INFO - Loaded from data/batch.pt
Saved to data/inputs.pt



Loaded from data/inputs.pt
Saved to data/inputs_layer_0.pt


2024-05-16 19:13:16,123 - INFO - Layer state dict loaded from data/layer_0.pt
Loaded from data/inputs_layer_0.pt
Saved to data/activations_layer_0.pt



Loaded from data/activations_layer_0.pt
Saved to data/inputs_layer_1.pt


2024-05-16 19:13:17,335 - INFO - Layer state dict loaded from data/layer_1.pt
Loaded from data/inputs_layer_1.pt
Saved to data/activations_layer_1.pt



Loaded from data/activations_layer_1.pt
Saved to data/inputs_layer_2.pt


2024-05-16 19:13:18,626 - INFO - Layer state dict loaded from data/layer_2.pt
Loaded from data/inputs_layer_2.pt
Saved to data/activations_layer_2.pt



Loaded from data/activations_layer_2.pt
Saved to data/inputs_layer_3.pt


2024-05-16 19:13:20,043 - INFO - Layer state dict loaded from data/layer_3.pt
Loaded from data/inputs_layer_3.pt
Saved to data/activations_layer_3.pt



Loaded from data/activations_layer_3.pt
Saved to data/inputs_layer_4.pt


2024-05-16 19:13:21,692 - INFO - Layer state dict loaded from data/layer_4.pt
Loaded from data/inputs_layer_4.pt
Saved to data/activations_layer_4.pt



Loaded from data/activations_layer_4.pt
Saved to data/inputs_layer_5.pt


2024-05-16 19:13:22,994 - INFO - Layer state dict loaded from data/layer_5.pt
Loaded from data/inputs_layer_5.pt
Saved to data/activations_layer_5.pt



Loaded from data/activations_layer_5.pt
Saved to data/final_inputs.pt


2024-05-16 19:13:25,632 - INFO - Loaded from data/final_inputs.pt
Saved to data/logits.pt



Loaded from data/logits.pt
Saved to data/input_error_layer_5.pt



2024-05-16 19:14:04,908 - ERROR - Failed to process and save batch: Command 'python3 worker.py --task backward --layer_idx 5 --error data/input_error_layer_5.pt' died with <Signals.SIGKILL: 9>.
2024-05-16 19:14:04,909 - ERROR - Error processing batch at step 3: Command 'python3 worker.py --task backward --layer_idx 5 --error data/input_error_layer_5.pt' died with <Signals.SIGKILL: 9>.
Epoch 1/3: 4it [03:55, 55.44s/it]

Saved to data/batch.pt


2024-05-16 19:14:07,638 - INFO - Loaded from data/batch.pt
Saved to data/inputs.pt



Loaded from data/inputs.pt
Saved to data/inputs_layer_0.pt


2024-05-16 19:14:08,873 - INFO - Layer state dict loaded from data/layer_0.pt
Loaded from data/inputs_layer_0.pt
Saved to data/activations_layer_0.pt



Loaded from data/activations_layer_0.pt
Saved to data/inputs_layer_1.pt


2024-05-16 19:14:10,050 - INFO - Layer state dict loaded from data/layer_1.pt
Loaded from data/inputs_layer_1.pt
Saved to data/activations_layer_1.pt



Loaded from data/activations_layer_1.pt
Saved to data/inputs_layer_2.pt


2024-05-16 19:14:11,200 - INFO - Layer state dict loaded from data/layer_2.pt
Loaded from data/inputs_layer_2.pt
Saved to data/activations_layer_2.pt



Loaded from data/activations_layer_2.pt
Saved to data/inputs_layer_3.pt


2024-05-16 19:14:12,359 - INFO - Layer state dict loaded from data/layer_3.pt
Loaded from data/inputs_layer_3.pt
Saved to data/activations_layer_3.pt



Loaded from data/activations_layer_3.pt
Saved to data/inputs_layer_4.pt


2024-05-16 19:14:13,630 - INFO - Layer state dict loaded from data/layer_4.pt
Loaded from data/inputs_layer_4.pt
Saved to data/activations_layer_4.pt



Loaded from data/activations_layer_4.pt
Saved to data/inputs_layer_5.pt


2024-05-16 19:14:14,880 - INFO - Layer state dict loaded from data/layer_5.pt
Loaded from data/inputs_layer_5.pt
Saved to data/activations_layer_5.pt



Loaded from data/activations_layer_5.pt
Saved to data/final_inputs.pt


2024-05-16 19:14:17,536 - INFO - Loaded from data/final_inputs.pt
Saved to data/logits.pt



Loaded from data/logits.pt
Saved to data/input_error_layer_5.pt


Epoch 1/3: 4it [04:22, 65.51s/it]


KeyboardInterrupt: 