In [1]:
import torch
from common import save_to_disk, load_from_disk, save_layer_state_dict, load_layer_state_dict
from tokenizer import Tokenizer
from datasets import load_dataset
from torch.utils.data import DataLoader
from dataset import WikipediaDataset
import torch.nn.functional as F
from tqdm import tqdm
import os
import time

def wait_for_file(filename, timeout=30):
    start_time = time.time()
    while not os.path.exists(filename):
        if time.time() - start_time > timeout:
            raise TimeoutError(f"Timeout waiting for {filename}")
        time.sleep(0.1)

def main():
    # Initialize the tokenizer
    tokenizer = Tokenizer(encoding_name='cl100k_base')
    
    # Load the dataset
    dataset = load_dataset("wikipedia", language="en", date="20240401", split='train[:5%]', trust_remote_code=True)
    texts = dataset['text']

    # Prepare the dataset and dataloader
    seq_len = 2048
    wiki_dataset = WikipediaDataset(texts, tokenizer, seq_len)
    dataloader = DataLoader(wiki_dataset, batch_size=1, shuffle=True, num_workers=2)

    num_epochs = 3
    gradient_accumulation_steps = 4

    for epoch in range(num_epochs):
        epoch_loss = 0

        for step, batch in enumerate(tqdm(dataloader, desc=f"Epoch {epoch + 1}/{num_epochs}")):
            batch = batch.to('cpu')
            save_to_disk(batch, "data/batch.pt")
            os.system(f"python worker.py --task embed --batch data/batch.pt")
            wait_for_file("data/inputs.pt")
            inputs = load_from_disk("data/inputs.pt")

            # Forward pass through all layers
            for layer_idx in range(6):  # Assuming 6 layers
                worker_forward(layer_idx, inputs)
                wait_for_file(f"data/activations_layer_{layer_idx}.pt")
                inputs = load_from_disk(f"data/activations_layer_{layer_idx}.pt")

            # Dispatch worker to save final logits
            worker_final_logits(inputs)
            wait_for_file("data/logits.pt")
            logits = load_from_disk("data/logits.pt")
            
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), batch.view(-1), ignore_index=tokenizer.pad_id)
            loss = loss / gradient_accumulation_steps
            epoch_loss += loss.item() * gradient_accumulation_steps

            # Backward pass through all layers
            logits.retain_grad()  # Ensure logits have gradients
            loss.backward()
            error = logits.grad
            for layer_idx in reversed(range(6)):  # Assuming 6 layers
                worker_backward(layer_idx, error)
                wait_for_file(f"data/error_layer_{layer_idx}.pt")
                error_grads = load_from_disk(f"data/error_layer_{layer_idx}.pt")
                if error_grads is None:
                    raise ValueError(f"Failed to load error and grads from {f'data/error_layer_{layer_idx}.pt'}")
                error, grads = error_grads
                update_layer(layer_idx, grads)

        print(f"Epoch {epoch + 1} Loss: {epoch_loss / len(dataloader)}")

def worker_forward(layer_idx, inputs):
    save_to_disk(inputs, f"data/inputs_layer_{layer_idx}.pt")
    os.system(f"python worker.py --task forward --layer_idx {layer_idx} --inputs data/inputs_layer_{layer_idx}.pt")

def worker_backward(layer_idx, error):
    save_to_disk(error, f"data/error_layer_{layer_idx}.pt")
    os.system(f"python worker.py --task backward --layer_idx {layer_idx} --error data/error_layer_{layer_idx}.pt")

def worker_final_logits(inputs):
    save_to_disk(inputs, "data/final_inputs.pt")
    os.system(f"python worker.py --task final_logits --inputs data/final_inputs.pt")

def update_layer(layer_idx, grads):
    state_dict = load_layer_state_dict(f"data/layer_{layer_idx}.pt")
    if state_dict is None:
        raise ValueError(f"Failed to load state dict for layer {layer_idx}")
    for param, grad in zip(state_dict.values(), grads):
        param.data -= grad
    save_layer_state_dict(state_dict, f"data/layer_{layer_idx}.pt")

if __name__ == "__main__":
    main()


Epoch 1/3:   0%|          | 0/340203 [00:00<?, ?it/s]

Saved to data/batch.pt
Loaded from data/batch.pt
Saved to data/inputs.pt
Loaded from data/inputs.pt
Saved to data/inputs_layer_0.pt
Layer state dict loaded from data/layer_0.pt
Loaded from data/inputs_layer_0.pt
Saved to data/activations_layer_0.pt
Loaded from data/activations_layer_0.pt
Saved to data/inputs_layer_1.pt
Layer state dict loaded from data/layer_1.pt
Loaded from data/inputs_layer_1.pt
Saved to data/activations_layer_1.pt
Loaded from data/activations_layer_1.pt
Saved to data/inputs_layer_2.pt
Layer state dict loaded from data/layer_2.pt
Loaded from data/inputs_layer_2.pt
Saved to data/activations_layer_2.pt
Loaded from data/activations_layer_2.pt
Saved to data/inputs_layer_3.pt
Layer state dict loaded from data/layer_3.pt
Loaded from data/inputs_layer_3.pt
Saved to data/activations_layer_3.pt
Loaded from data/activations_layer_3.pt
Saved to data/inputs_layer_4.pt
Layer state dict loaded from data/layer_4.pt
Loaded from data/inputs_layer_4.pt
Saved to data/activations_layer_

Traceback (most recent call last):
  File "/home/user/Coding/ritser/magnum/spl/worker.py", line 112, in <module>
    backward_task(args.layer_idx, args.error)
  File "/home/user/Coding/ritser/magnum/spl/worker.py", line 93, in backward_task
    loss, grads = run_layer_step(layer, None, is_forward=False, next_error=error, optimizer=optimizer, scaler=scaler, loss_fn=loss_fn, pad_id=tokenizer.pad_id)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/user/Coding/ritser/magnum/spl/worker.py", line 17, in run_layer_step
    logits = layer(x)
             ^^^^^^^^
  File "/home/user/anaconda3/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/user/anaconda3/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1520, in _ca

Loaded from data/error_layer_5.pt





ValueError: not enough values to unpack (expected 2, got 1)