In [1]:
%pip install --upgrade jupyter ipywidgets
%pip install torch fairscale tiktoken==0.4.0 fair blobfile datasets mwparserfromhell

Note: you may need to restart the kernel to use updated packages.
Note: you may need to restart the kernel to use updated packages.


In [2]:
import logging
import subprocess
import torch
from common import save_to_disk, load_from_disk, save_layer_state_dict, load_layer_state_dict, model_args  # Import global model_args
from tokenizer import Tokenizer
from datasets import load_dataset
from torch.utils.data import DataLoader, IterableDataset
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
import os
import time
from model import TransformerLayer

class StreamingWikipediaDataset(IterableDataset):
    def __init__(self, texts, tokenizer, seq_len):
        self.texts = texts
        self.tokenizer = tokenizer
        self.seq_len = seq_len

    def parse(self, text):
        tokens = self.tokenizer.encode(text, bos=True, eos=True)
        tokens = tokens[:self.seq_len] + [self.tokenizer.pad_id] * (self.seq_len - len(tokens))
        return torch.tensor(tokens)

    def __iter__(self):
        for item in self.texts:
            yield self.parse(item['text'])

def wait_for_file(filename, timeout=30):
    start_time = time.time()
    while not os.path.exists(filename):
        if time.time() - start_time > timeout:
            raise TimeoutError(f"Timeout waiting for {filename}")
        time.sleep(0.1)

def initialize_layers(num_layers, vocab_size):
    for layer_idx in range(num_layers):
        layer = TransformerLayer(model_args.dim, model_args.n_heads, model_args.dim * model_args.ffn_dim_multiplier)
        save_layer_state_dict(layer.state_dict(), f"data/layer_{layer_idx}.pt")
    
    # Save the state dict for the embedding and fully connected layers
    embedding = nn.Embedding(vocab_size, model_args.dim)
    save_layer_state_dict(embedding.state_dict(), "data/embedding.pt")
    
    fc = nn.Linear(model_args.dim, vocab_size)
    save_layer_state_dict(fc.state_dict(), "data/layer_fc.pt")

def main():
    # Initialize the tokenizer
    tokenizer = Tokenizer(encoding_name='cl100k_base')
    
    # Load the dataset using streaming
    dataset = load_dataset("wikipedia", language="en", date="20240401", split='train', streaming=True)
    
    # Prepare the dataset and dataloader
    seq_len = 2048
    wiki_dataset = StreamingWikipediaDataset(dataset, tokenizer, seq_len)
    dataloader = DataLoader(wiki_dataset, batch_size=1, shuffle=False, num_workers=0)

    num_epochs = 3
    learning_rate = 1e-4
    beta1 = 0.9
    beta2 = 0.999
    epsilon = 1e-8

    # Initialize layer state dictionaries
    initialize_layers(num_layers=model_args.n_layers, vocab_size=tokenizer.get_vocab_size())

    for epoch in range(num_epochs):
        epoch_loss = 0

        for step, batch in enumerate(tqdm(dataloader, desc=f"Epoch {epoch + 1}/{num_epochs}")):
            try:
                process_batch(batch, tokenizer, epoch_loss, learning_rate, beta1, beta2, epsilon, step + 1)
            except Exception as e:
                logging.error(f"Error processing batch at step {step}: {e}")
                continue 

        print(f"Epoch {epoch + 1} Loss: {epoch_loss / len(dataloader)}")

def process_batch(batch, tokenizer, epoch_loss, learning_rate, beta1, beta2, epsilon, t):
    try:
        batch = batch.to('cpu')
        save_to_disk(batch, "data/batch.pt")
        run_command(f"python3 worker.py --task embed --batch data/batch.pt --embedding_file data/embedding.pt --inputs data/inputs.pt")
        wait_for_file("data/inputs.pt")
        inputs = load_from_disk("data/inputs.pt")

        for layer_idx in range(model_args.n_layers):
            inputs_file = f"data/inputs_layer_{layer_idx}.pt"
            logits_file = f"data/logits_layer_{layer_idx}.pt"
            state_dict_file = f"data/layer_{layer_idx}.pt"
            save_to_disk(inputs, inputs_file)
            run_command(f"python3 worker.py --task forward --layer_idx {layer_idx} --inputs {inputs_file} --state_dict {state_dict_file} --logits_file {logits_file}")
            wait_for_file(logits_file)
            inputs = load_from_disk(logits_file)

        save_to_disk(inputs, "data/final_inputs.pt")
        run_command(f"python3 worker.py --task final_logits --inputs data/final_inputs.pt --state_dict data/layer_fc.pt --logits_file data/logits.pt")
        wait_for_file("data/logits.pt")
        logits = load_from_disk("data/logits.pt")
        
        save_to_disk(batch, "data/targets.pt")
        save_to_disk(logits, "data/logits.pt")
        run_command(f"python3 worker.py --task loss --logits data/logits.pt --targets data/targets.pt --loss_file data/loss.pt --logits_grad_file data/logits_grad.pt")
        wait_for_file("data/loss.pt")
        wait_for_file("data/logits_grad.pt")
        
        loss = load_from_disk("data/loss.pt")
        print(f"Loss: {loss:.4f}")
        epoch_loss += loss
        logits_grad = load_from_disk("data/logits_grad.pt")
        
        # Backpropagate through the final fully connected layer
        save_to_disk(logits_grad, "data/final_error.pt")
        run_command(f"python3 worker.py --task final_logits_backward --error data/final_error.pt --inputs data/final_inputs.pt --state_dict data/layer_fc.pt --error_output_file data/error_layer_fc.pt")
        wait_for_file("data/error_layer_fc.pt")
        final_error, fc_grads = load_from_disk("data/error_layer_fc.pt")
        save_to_disk(fc_grads, "data/fc_grads.pt")
        run_command(f"python3 worker.py --task apply_adam --layer_idx -1 --grads data/fc_grads.pt --learning_rate {learning_rate} --beta1 {beta1} --beta2 {beta2} --epsilon {epsilon} --t {t}")

        error = final_error[0]  # Use the input gradient for the next backward pass
        for layer_idx in reversed(range(model_args.n_layers)):
            inputs_file = f"data/inputs_layer_{layer_idx}.pt"
            error_output_file = f"data/error_layer_{layer_idx}.pt"
            state_dict_file = f"data/layer_{layer_idx}.pt"
            save_to_disk(error, "data/error_for_layer.pt")
            run_command(f"python3 worker.py --task backward --layer_idx {layer_idx} --error data/error_for_layer.pt --inputs {inputs_file} --state_dict {state_dict_file} --error_output_file {error_output_file}")
            wait_for_file(error_output_file)
            error_grads = load_from_disk(error_output_file)
            if error_grads is None or len(error_grads) != 2:
                raise ValueError(f"Failed to load error and grads from {error_output_file}")
            error, grads = error_grads
            error = error.view(1, -1, model_args.dim)  # Reshape to match the output shape of the previous layer
            save_to_disk(grads, f"data/grads_layer_{layer_idx}.pt")
            run_command(f"python3 worker.py --task apply_adam --layer_idx {layer_idx} --grads data/grads_layer_{layer_idx}.pt --learning_rate {learning_rate} --beta1 {beta1} --beta2 {beta2} --epsilon {epsilon} --t {t}")

        # Backpropagate through the embedding layer
        error = error.view(batch.size(0), -1, model_args.dim)  # Reshape to match the embedding layer output
        save_to_disk(error, "data/embedding_error.pt")
        run_command(f"python3 worker.py --task embed_backward --error data/embedding_error.pt --batch data/batch.pt --embedding_file data/embedding.pt --error_output_file data/embedding_grads.pt")
        wait_for_file("data/embedding_grads.pt")
        embedding_grads = load_from_disk("data/embedding_grads.pt")
        save_to_disk(embedding_grads, "data/embedding_grads.pt")
        run_command(f"python3 worker.py --task apply_adam --layer_idx -2 --grads data/embedding_grads.pt --learning_rate {learning_rate} --beta1 {beta1} --beta2 {beta2} --epsilon {epsilon} --t {t}")

    except Exception as e:
        logging.error(f"Failed to process and save batch: {e}")
        raise

def run_command(command):
    try:
        result = subprocess.run(command, shell=True, text=True, capture_output=True, check=True)
        logging.info(result.stdout)
    except subprocess.CalledProcessError as e:
        logging.error(f"Command '{command}' failed with error: {e.stderr}")
        raise

if __name__ == "__main__":
    main()


You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this dataset from the next major release of `datasets`.


Layer state dict saved to data/layer_0.pt
Layer state dict saved to data/layer_1.pt
Layer state dict saved to data/layer_2.pt
Layer state dict saved to data/layer_3.pt
Layer state dict saved to data/layer_4.pt
Layer state dict saved to data/layer_5.pt
Layer state dict saved to data/embedding.pt
Layer state dict saved to data/layer_fc.pt


Epoch 1/3: 0it [00:00, ?it/s]

Extracting content from https://dumps.wikimedia.org/enwiki/20240401/enwiki-20240401-pages-articles-multistream1.xml-p1p41242.bz2
Saved to data/batch.pt
Loaded from data/inputs.pt
Saved to data/inputs_layer_0.pt
Loaded from data/logits_layer_0.pt
Saved to data/inputs_layer_1.pt
Loaded from data/logits_layer_1.pt
Saved to data/inputs_layer_2.pt
Loaded from data/logits_layer_2.pt
Saved to data/inputs_layer_3.pt
Loaded from data/logits_layer_3.pt
Saved to data/inputs_layer_4.pt
Loaded from data/logits_layer_4.pt
Saved to data/inputs_layer_5.pt
Loaded from data/logits_layer_5.pt
Saved to data/final_inputs.pt
Loaded from data/logits.pt
Saved to data/targets.pt
Saved to data/logits.pt
Loaded from data/loss.pt
Loss: 11.7371
Loaded from data/logits_grad.pt
Saved to data/final_error.pt
Loaded from data/error_layer_fc.pt
Layer state dict loaded from data/layer_fc.pt
Layer state dict saved to data/layer_fc.pt
Saved to data/error_for_layer.pt
Loaded from data/error_layer_5.pt
Layer state dict loade

Epoch 1/3: 1it [01:19, 79.65s/it]

Loaded from data/embedding_grads.pt
Layer state dict loaded from data/embedding.pt
Layer state dict saved to data/embedding.pt
Saved to data/batch.pt
Loaded from data/inputs.pt
Saved to data/inputs_layer_0.pt
Loaded from data/logits_layer_0.pt
Saved to data/inputs_layer_1.pt
Loaded from data/logits_layer_1.pt
Saved to data/inputs_layer_2.pt
Loaded from data/logits_layer_2.pt
Saved to data/inputs_layer_3.pt
Loaded from data/logits_layer_3.pt
Saved to data/inputs_layer_4.pt
Loaded from data/logits_layer_4.pt
Saved to data/inputs_layer_5.pt
Loaded from data/logits_layer_5.pt
Saved to data/final_inputs.pt
Loaded from data/logits.pt
Saved to data/targets.pt
Saved to data/logits.pt
Loaded from data/loss.pt
Loss: 11.9691
Loaded from data/logits_grad.pt
Saved to data/final_error.pt
Loaded from data/error_layer_fc.pt
Layer state dict loaded from data/layer_fc.pt
Layer state dict saved to data/layer_fc.pt
Saved to data/error_for_layer.pt
Loaded from data/error_layer_5.pt
Layer state dict loaded 

Epoch 1/3: 2it [01:47, 49.12s/it]

Loaded from data/embedding_grads.pt
Layer state dict loaded from data/embedding.pt
Layer state dict saved to data/embedding.pt
Saved to data/batch.pt
Loaded from data/inputs.pt
Saved to data/inputs_layer_0.pt
Loaded from data/logits_layer_0.pt
Saved to data/inputs_layer_1.pt
Loaded from data/logits_layer_1.pt
Saved to data/inputs_layer_2.pt
Loaded from data/logits_layer_2.pt
Saved to data/inputs_layer_3.pt
Loaded from data/logits_layer_3.pt
Saved to data/inputs_layer_4.pt
Loaded from data/logits_layer_4.pt
Saved to data/inputs_layer_5.pt
Loaded from data/logits_layer_5.pt
Saved to data/final_inputs.pt
Loaded from data/logits.pt
Saved to data/targets.pt
Saved to data/logits.pt
Loaded from data/loss.pt
Loss: 315.8104
Loaded from data/logits_grad.pt
Saved to data/final_error.pt
Loaded from data/error_layer_fc.pt
Layer state dict loaded from data/layer_fc.pt
Layer state dict saved to data/layer_fc.pt
Saved to data/error_for_layer.pt
Loaded from data/error_layer_5.pt
Layer state dict loaded

Epoch 1/3: 3it [02:15, 39.32s/it]

Loaded from data/embedding_grads.pt
Layer state dict loaded from data/embedding.pt
Layer state dict saved to data/embedding.pt
Saved to data/batch.pt
Loaded from data/inputs.pt
Saved to data/inputs_layer_0.pt
Loaded from data/logits_layer_0.pt
Saved to data/inputs_layer_1.pt
Loaded from data/logits_layer_1.pt
Saved to data/inputs_layer_2.pt
Loaded from data/logits_layer_2.pt
Saved to data/inputs_layer_3.pt
Loaded from data/logits_layer_3.pt
Saved to data/inputs_layer_4.pt
Loaded from data/logits_layer_4.pt
Saved to data/inputs_layer_5.pt
Loaded from data/logits_layer_5.pt
Saved to data/final_inputs.pt
Loaded from data/logits.pt
Saved to data/targets.pt
Saved to data/logits.pt
Loaded from data/loss.pt
Loss: 298946.4688
Loaded from data/logits_grad.pt
Saved to data/final_error.pt
Loaded from data/error_layer_fc.pt
Layer state dict loaded from data/layer_fc.pt
Layer state dict saved to data/layer_fc.pt
Saved to data/error_for_layer.pt
Loaded from data/error_layer_5.pt
Layer state dict loa

Epoch 1/3: 4it [02:41, 34.38s/it]

Loaded from data/embedding_grads.pt
Layer state dict loaded from data/embedding.pt
Layer state dict saved to data/embedding.pt
Saved to data/batch.pt
Loaded from data/inputs.pt
Saved to data/inputs_layer_0.pt
Loaded from data/logits_layer_0.pt
Saved to data/inputs_layer_1.pt
Loaded from data/logits_layer_1.pt
Saved to data/inputs_layer_2.pt
Loaded from data/logits_layer_2.pt
Saved to data/inputs_layer_3.pt
Loaded from data/logits_layer_3.pt
Saved to data/inputs_layer_4.pt
Loaded from data/logits_layer_4.pt
Saved to data/inputs_layer_5.pt
Loaded from data/logits_layer_5.pt
Saved to data/final_inputs.pt
Loaded from data/logits.pt
Saved to data/targets.pt
Saved to data/logits.pt
Loaded from data/loss.pt
Loss: 1050592706188804096.0000
Loaded from data/logits_grad.pt
Saved to data/final_error.pt
Loaded from data/error_layer_fc.pt
Layer state dict loaded from data/layer_fc.pt
Layer state dict saved to data/layer_fc.pt
Saved to data/error_for_layer.pt
Loaded from data/error_layer_5.pt
Layer s

Epoch 1/3: 5it [03:12, 33.10s/it]

Loaded from data/embedding_grads.pt
Layer state dict loaded from data/embedding.pt
Layer state dict saved to data/embedding.pt
Saved to data/batch.pt
Loaded from data/inputs.pt
Saved to data/inputs_layer_0.pt
Loaded from data/logits_layer_0.pt
Saved to data/inputs_layer_1.pt
Loaded from data/logits_layer_1.pt
Saved to data/inputs_layer_2.pt
Loaded from data/logits_layer_2.pt
Saved to data/inputs_layer_3.pt
Loaded from data/logits_layer_3.pt
Saved to data/inputs_layer_4.pt
Loaded from data/logits_layer_4.pt
Saved to data/inputs_layer_5.pt
Loaded from data/logits_layer_5.pt
Saved to data/final_inputs.pt
Loaded from data/logits.pt
Saved to data/targets.pt
Saved to data/logits.pt
Loaded from data/loss.pt
Loss: nan
Loaded from data/logits_grad.pt
Saved to data/final_error.pt
Loaded from data/error_layer_fc.pt
Layer state dict loaded from data/layer_fc.pt
Layer state dict saved to data/layer_fc.pt
Saved to data/error_for_layer.pt
Loaded from data/error_layer_5.pt
Layer state dict loaded from

Epoch 1/3: 6it [03:45, 32.96s/it]

Loaded from data/embedding_grads.pt
Layer state dict loaded from data/embedding.pt
Layer state dict saved to data/embedding.pt
Saved to data/batch.pt
Loaded from data/inputs.pt
Saved to data/inputs_layer_0.pt
Loaded from data/logits_layer_0.pt
Saved to data/inputs_layer_1.pt
Loaded from data/logits_layer_1.pt
Saved to data/inputs_layer_2.pt
Loaded from data/logits_layer_2.pt
Saved to data/inputs_layer_3.pt
Loaded from data/logits_layer_3.pt
Saved to data/inputs_layer_4.pt
Loaded from data/logits_layer_4.pt
Saved to data/inputs_layer_5.pt
Loaded from data/logits_layer_5.pt
Saved to data/final_inputs.pt
Loaded from data/logits.pt
Saved to data/targets.pt
Saved to data/logits.pt
Loaded from data/loss.pt
Loss: nan
Loaded from data/logits_grad.pt
Saved to data/final_error.pt
Loaded from data/error_layer_fc.pt
Layer state dict loaded from data/layer_fc.pt
Layer state dict saved to data/layer_fc.pt
Saved to data/error_for_layer.pt
Loaded from data/error_layer_5.pt
Layer state dict loaded from

Epoch 1/3: 7it [04:16, 32.46s/it]

Loaded from data/embedding_grads.pt
Layer state dict loaded from data/embedding.pt
Layer state dict saved to data/embedding.pt
Saved to data/batch.pt
Loaded from data/inputs.pt
Saved to data/inputs_layer_0.pt
Loaded from data/logits_layer_0.pt
Saved to data/inputs_layer_1.pt
Loaded from data/logits_layer_1.pt
Saved to data/inputs_layer_2.pt
Loaded from data/logits_layer_2.pt
Saved to data/inputs_layer_3.pt
Loaded from data/logits_layer_3.pt
Saved to data/inputs_layer_4.pt
Loaded from data/logits_layer_4.pt
Saved to data/inputs_layer_5.pt
Loaded from data/logits_layer_5.pt
Saved to data/final_inputs.pt
Loaded from data/logits.pt
Saved to data/targets.pt
Saved to data/logits.pt
Loaded from data/loss.pt
Loss: nan
Loaded from data/logits_grad.pt
Saved to data/final_error.pt
Loaded from data/error_layer_fc.pt
Layer state dict loaded from data/layer_fc.pt
Layer state dict saved to data/layer_fc.pt
Saved to data/error_for_layer.pt
Loaded from data/error_layer_5.pt
Layer state dict loaded from