In [1]:
%pip install --upgrade jupyter ipywidgets
%pip install torch fairscale tiktoken==0.4.0 fair blobfile datasets mwparserfromhell

Note: you may need to restart the kernel to use updated packages.
Note: you may need to restart the kernel to use updated packages.


In [1]:
import logging
import subprocess
import torch
import torch.distributed as dist
from common import save_to_disk, load_from_disk, save_layer_state_dict, load_layer_state_dict, model_args, tokenizer
from tokenizer import Tokenizer
from datasets import load_dataset
from torch.utils.data import DataLoader, IterableDataset
import torch.nn as nn
import os
import time
from tqdm import tqdm
from model import TransformerBlock, VocabParallelEmbedding, ColumnParallelLinear, RMSNorm, Transformer, precompute_freqs_cis
from fairscale.nn.model_parallel.initialize import initialize_model_parallel, model_parallel_is_initialized
from fairscale.nn.model_parallel.layers import (
    ColumnParallelLinear,
    RowParallelLinear,
    VocabParallelEmbedding,
)
import random

# Configure the root logger so the logging.info(...) progress messages below are emitted.
logging.basicConfig(level=logging.INFO)

# Fixed length that pad_collate_fn pads every input sequence to; taken from
# the shared model configuration imported from `common`.
MAX_SEQ_LEN = model_args.max_seq_len

class StreamingWikipediaDataset(IterableDataset):
    """Iterable dataset producing (prefix, next-token) samples from text records.

    ``texts`` is iterated lazily; every record is indexed with ``'text'`` and
    tokenized with ``tokenizer.encode(..., bos=True, eos=True)``.  Each sample
    uses a randomly drawn prefix length between 1 and ``max_seq_len``.
    """

    def __init__(self, texts, tokenizer, max_seq_len):
        self.texts, self.tokenizer, self.max_seq_len = texts, tokenizer, max_seq_len

    def parse(self, text):
        """Tokenize ``text`` and return a 1-D tensor of ``seq_len + 1`` token ids.

        ``seq_len`` is drawn uniformly from ``[1, max_seq_len]``.  Shorter
        token lists are right-padded with ``tokenizer.pad_id``; longer ones
        are cut off.
        """
        token_ids = self.tokenizer.encode(text, bos=True, eos=True)
        # Keep one position beyond the prefix: it becomes the prediction target.
        wanted = random.randint(1, self.max_seq_len) + 1
        shortfall = wanted - len(token_ids)
        if shortfall > 0:
            token_ids = token_ids + [self.tokenizer.pad_id] * shortfall
        else:
            token_ids = token_ids[:wanted]
        return torch.tensor(token_ids)

    def __iter__(self):
        for record in self.texts:
            sample = self.parse(record['text'])
            # Everything but the last id is the input; the last id is the target.
            yield sample[:-1], sample[-1]


def pad_collate_fn(batch, pad_id=None, max_seq_len=None):
    """Collate (input_ids, target) samples into fixed-length batch tensors.

    Args:
        batch: Sequence of ``(input_ids, target)`` pairs, where ``input_ids``
            is a 1-D tensor of token ids and ``target`` is a single token id.
        pad_id: Id used for right-padding.  Defaults to the module-level
            ``tokenizer.pad_id``.
        max_seq_len: Length every input is padded to.  Defaults to the
            module-level ``MAX_SEQ_LEN``.

    Returns:
        ``(inputs, targets)`` where ``inputs`` stacks the padded sequences
        along a new first dimension and ``targets`` holds one id per sample.
    """
    if pad_id is None:
        pad_id = tokenizer.pad_id
    if max_seq_len is None:
        max_seq_len = MAX_SEQ_LEN

    padded = []
    for item in batch:
        seq = item[0]
        pad_len = max(0, max_seq_len - len(seq))
        # Build the padding with the sequence's own dtype/device.  A bare
        # torch.tensor([]) is float32 when no padding is needed, which can
        # promote the whole batch of integer token ids to floating point.
        padding = seq.new_full((pad_len,), pad_id)
        padded.append(torch.cat([seq, padding]))

    inputs = torch.stack(padded)
    targets = torch.tensor([item[1] for item in batch])
    return inputs, targets


def wait_for_file(filename, timeout=30, poll_interval=0.1):
    """Block until ``filename`` exists on disk.

    Args:
        filename: Path to poll for.
        timeout: Maximum number of seconds to wait.
        poll_interval: Seconds to sleep between existence checks.

    Raises:
        TimeoutError: If the file has not appeared once ``timeout`` seconds
            have elapsed.
    """
    # A monotonic clock keeps the timeout immune to wall-clock adjustments.
    deadline = time.monotonic() + timeout
    while not os.path.exists(filename):
        if time.monotonic() > deadline:
            raise TimeoutError(f"Timeout waiting for {filename}")
        time.sleep(poll_interval)


def initialize_layers(model_args):
    """Create freshly initialised model pieces and write them under ``data/``.

    For each of ``model_args.n_layers`` transformer blocks, plus the token
    embedding and the output projection, this applies Xavier-uniform
    initialisation to the supported weight-bearing modules (zeroing any
    bias), verifies that no parameter contains NaN/Inf, and saves the state
    dict.  The rotary frequency table from ``precompute_freqs_cis`` is saved
    last.

    Raises:
        ValueError: If any initialised parameter contains NaN or Inf.
    """
    initialisable = (nn.Linear, ColumnParallelLinear, RowParallelLinear, VocabParallelEmbedding)

    def init_weights(module):
        # Modules of other types are left with their default initialisation.
        if not isinstance(module, initialisable):
            return
        nn.init.xavier_uniform_(module.weight)
        bias = getattr(module, 'bias', None)
        if bias is not None:
            nn.init.zeros_(bias)

    def ensure_finite(module, message):
        # Abort with `message` as soon as one parameter holds a NaN or Inf.
        for param in module.parameters():
            if torch.isnan(param).any() or torch.isinf(param).any():
                raise ValueError(message)

    for layer_idx in range(model_args.n_layers):
        block = TransformerBlock(layer_idx, model_args)
        block.apply(init_weights)
        ensure_finite(block, f"NaNs or Infs detected in weights of layer {layer_idx}")
        save_layer_state_dict(block.state_dict(), f"data/layer_{layer_idx}.pt")

    embedding = VocabParallelEmbedding(model_args.vocab_size, model_args.dim)
    embedding.apply(init_weights)
    ensure_finite(embedding, "NaNs or Infs detected in embedding weights")
    save_layer_state_dict(embedding.state_dict(), "data/embedding.pt")

    output = ColumnParallelLinear(model_args.dim, model_args.vocab_size, bias=False)
    output.apply(init_weights)
    ensure_finite(output, "NaNs or Infs detected in output weights")
    save_layer_state_dict(output.state_dict(), "data/output.pt")

    head_dim = model_args.dim // model_args.n_heads
    table_len = model_args.max_seq_len * 2
    freqs_cis = precompute_freqs_cis(head_dim, table_len, model_args.rope_theta)
    save_to_disk(freqs_cis, "data/freqs_cis.pt")

def process_batch(batch, tokenizer, epoch_loss, learning_rate, beta1, beta2, epsilon, weight_decay, t):
    """Run one training step by delegating each stage to ``worker.py``.

    Tensors are exchanged with the worker through files under ``data/``.
    The stages, in order, are: embed, per-layer forward, final norm and
    logits, loss, backward through the output layer, the transformer layers
    and the embedding, and finally an AdamW update for every parameter group.

    Args:
        batch: ``(inputs, targets)`` pair as produced by the DataLoader.
        tokenizer: Tokenizer used only to decode a sample for display.
        epoch_loss: Running loss passed in by the caller.  Numbers are
            immutable, so the in-function accumulation is NOT visible to the
            caller; use the return value instead.
        learning_rate, beta1, beta2, epsilon, weight_decay: Hyper-parameters
            forwarded verbatim to the ``apply_adamw`` worker task.
        t: Step counter forwarded to ``apply_adamw`` as ``--t``.

    Returns:
        The loss loaded from ``data/loss.pt`` for this batch.

    Raises:
        ValueError: If NaN/Inf values are found in the inputs, targets,
            per-layer activations or final logits.
        Exception: Any other failure is logged and re-raised unchanged.
    """
    try:
        inputs, targets = batch
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        inputs, targets = inputs.to(device), targets.to(device)

        if torch.isnan(inputs).any() or torch.isinf(inputs).any():
            raise ValueError("NaNs or Infs detected in input batch")

        if torch.isnan(targets).any() or torch.isinf(targets).any():
            raise ValueError("NaNs or Infs detected in target batch")

        logging.info(f"Batch (token IDs): {inputs.shape}")
        logging.info(f"Targets: {targets.shape}")

        # Print the actual sample input and target
        input_text = tokenizer.decode(inputs[0].tolist())
        target_text = tokenizer.decode([targets[0].item()])
        print(f"Sample input: {input_text}")
        print(f"Expected next token: {target_text}")

        # Ensure targets tensor has correct shape [batch_size, seq_len]
        if targets.dim() == 1:
            targets = targets.unsqueeze(1)  # Add dimension if missing

        # Broadcast the single target id across every sequence position.
        if targets.shape[1] != inputs.shape[1]:
            targets = targets.expand(inputs.shape[0], inputs.shape[1])

        save_to_disk(inputs, "data/inputs.pt")
        save_to_disk(targets, "data/targets.pt")

        run_command(f"python3 worker.py --task embed --batch data/inputs.pt --embedding_file data/embedding.pt --outputs data/inputs_embed.pt")
        wait_for_file("data/inputs_embed.pt")
        inputs_embed = load_from_disk("data/inputs_embed.pt")

        logging.info(f"Embedded inputs: {inputs_embed.shape}")

        check_for_nans(inputs_embed, "inputs after embedding")

        freqs_cis = load_from_disk("data/freqs_cis.pt")

        # Additive mask: -inf strictly above the diagonal, zero elsewhere.
        seqlen = inputs_embed.shape[1]
        mask = None
        if seqlen > 1:
            mask = torch.triu(torch.full((seqlen, seqlen), float('-inf')), diagonal=1).to(device)
            save_to_disk(mask, "data/mask.pt")

        # Forward pass: the output of each layer is the input of the next.
        for layer_idx in range(model_args.n_layers):
            inputs_file = f"data/inputs_layer_{layer_idx}.pt"
            logits_file = f"data/logits_layer_{layer_idx}.pt"
            state_dict_file = f"data/layer_{layer_idx}.pt"
            # Each layer's input is kept on disk; the backward pass reads it again.
            save_to_disk(inputs_embed, inputs_file)
            run_command(f"python3 worker.py --task forward --layer_idx {layer_idx} --inputs {inputs_file} --state_dict {state_dict_file} --freqs_cis data/freqs_cis.pt --outputs {logits_file} --mask data/mask.pt")
            wait_for_file(logits_file)
            inputs_embed, inputs, layer = load_from_disk(logits_file)

            logging.info(f"Inputs after layer {layer_idx}: {inputs_embed.shape}")

            if torch.isnan(inputs_embed).any() or torch.isinf(inputs_embed).any():
                raise ValueError(f"NaNs or Infs detected in inputs after layer {layer_idx}")

            check_for_nans(inputs_embed, f"inputs after layer {layer_idx}")

        # NOTE(review): this RMSNorm is constructed fresh on every call, so
        # its parameters are never loaded from or saved to disk here.
        norm = RMSNorm(model_args.dim, eps=model_args.norm_eps).to(device)
        inputs_embed = norm(inputs_embed)
        save_to_disk(inputs_embed, "data/final_inputs.pt")

        run_command(f"python3 worker.py --task final_logits --inputs data/final_inputs.pt --state_dict data/output.pt --logits_file data/logits.pt")
        wait_for_file("data/logits.pt")
        logits, inputs_embed, output_layer = load_from_disk("data/logits.pt")

        logging.info(f"Final logits: {logits.shape}")

        if torch.isnan(logits).any() or torch.isinf(logits).any():
            raise ValueError("NaNs or Infs detected in final logits")

        check_for_nans(logits, "final logits")

        run_command(f"python3 worker.py --task loss --logits data/logits.pt --targets data/targets.pt --loss_file data/loss.pt --logits_grad_file data/logits_grad.pt")
        wait_for_file("data/loss.pt")
        wait_for_file("data/logits_grad.pt")

        loss = load_from_disk("data/loss.pt")
        logging.info(f"Loss: {loss:.4f}")
        print(f"Loss: {loss:.4f}")
        # Rebinds the local name only; the caller receives the loss through
        # the return value at the end of this function.
        epoch_loss += loss
        logits_grad, inputs, output_layer = load_from_disk("data/logits_grad.pt")

        logging.info(f"Logits gradients: {logits_grad.shape}")

        save_to_disk(logits_grad, "data/final_error.pt")
        final_logits_error_file = "data/error_output_final_logits.pt"
        run_command(f"python3 worker.py --task final_logits_backward --error data/final_error.pt --state_dict data/output.pt --inputs data/final_inputs.pt --error_output_file {final_logits_error_file}")

        # Backpropagate through the transformer layers
        for layer_idx in reversed(range(model_args.n_layers)):
            inputs_file = f"data/inputs_layer_{layer_idx}.pt"
            state_dict_file = f"data/layer_{layer_idx}.pt"
            error_output_file = f"data/error_layer_{layer_idx}.pt"
            # The last layer consumes the output-layer error; every other
            # layer consumes the error written by the layer above it.
            if layer_idx == model_args.n_layers - 1:
                error_file = final_logits_error_file
            else:
                error_file = f"data/error_layer_{layer_idx + 1}.pt"

            run_command(f"python3 worker.py --task backward --inputs {inputs_file} --layer_idx {layer_idx} --error {error_file} --state_dict {state_dict_file} --error_output_file {error_output_file} --freqs_cis data/freqs_cis.pt --mask data/mask.pt")

            wait_for_file(error_output_file)

        # Backpropagate through the embedding layer
        embed_error_file = "data/error_output_embedding.pt"
        run_command(f"python3 worker.py --task embed_backward --error data/error_layer_0.pt --batch data/inputs.pt --embedding_file data/embedding.pt --error_output_file {embed_error_file}")

        # Update weights using AdamW optimizer
        for layer_idx in range(model_args.n_layers):
            error_output_file = f"data/error_layer_{layer_idx}.pt"
            run_command(f"python3 worker.py --task apply_adamw --layer_idx {layer_idx} --grads {error_output_file} --learning_rate {learning_rate} --beta1 {beta1} --beta2 {beta2} --epsilon {epsilon} --weight_decay {weight_decay} --t {t}")

        # Update weights for embedding and output layers
        # (presumably layer_idx -1 / -2 select the output / embedding weights
        # inside worker.py -- confirm against the worker implementation).
        run_command(f"python3 worker.py --task apply_adamw --layer_idx -1 --grads {final_logits_error_file} --learning_rate {learning_rate} --beta1 {beta1} --beta2 {beta2} --epsilon {epsilon} --weight_decay {weight_decay} --t {t}")
        run_command(f"python3 worker.py --task apply_adamw --layer_idx -2 --grads {embed_error_file} --learning_rate {learning_rate} --beta1 {beta1} --beta2 {beta2} --epsilon {epsilon} --weight_decay {weight_decay} --t {t}")

        return loss

    except Exception as e:
        logging.error(f"Failed to process and save batch: {e}")
        raise


def run_command(command):
    """Execute ``command`` through the shell and log what it printed.

    Standard output of a successful run is logged at INFO level.  When the
    command exits with a non-zero status, its standard error is logged at
    ERROR level and the ``subprocess.CalledProcessError`` is re-raised.
    """
    try:
        completed = subprocess.run(
            command,
            shell=True,
            check=True,
            capture_output=True,
            text=True,
        )
    except subprocess.CalledProcessError as exc:
        logging.error(f"Command '{command}' failed with error: {exc.stderr}")
        raise
    else:
        logging.info(completed.stdout)


def main():
    """Set up a single-process distributed context and run the training loop.

    Streams the Wikipedia dataset, initialises all layer files on disk via
    ``initialize_layers`` and then calls ``process_batch`` for every batch of
    every epoch.  A failing batch is logged and skipped.  The process group
    is destroyed exactly once, after the last epoch or on error.
    """
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'
    os.environ['WORLD_SIZE'] = '1'
    os.environ['RANK'] = '0'

    # NCCL requires CUDA devices; fall back to gloo on CPU-only machines.
    backend = 'nccl' if torch.cuda.is_available() else 'gloo'
    dist.init_process_group(backend=backend)
    try:
        initialize_model_parallel(model_parallel_size_=1)

        dataset = load_dataset("wikipedia", language="en", date="20240401", split='train', streaming=True, trust_remote_code=True)

        max_seq_len = MAX_SEQ_LEN
        batch_size = 2  # Adjust batch size here
        wiki_dataset = StreamingWikipediaDataset(dataset, tokenizer, max_seq_len)
        dataloader = DataLoader(wiki_dataset, batch_size=batch_size, shuffle=False, num_workers=6, collate_fn=pad_collate_fn)

        num_epochs = 3
        learning_rate = 1e-4
        beta1 = 0.9
        beta2 = 0.999
        epsilon = 1e-8
        weight_decay = 1e-2

        initialize_layers(model_args)

        # Optimizer step counter passed as `t`; it keeps increasing across
        # epochs instead of restarting at 1 (presumably used for AdamW bias
        # correction in worker.py -- confirm against the worker).
        global_step = 0

        for epoch in range(num_epochs):
            epoch_loss = 0.0
            counted_batches = 0

            for step, batch in enumerate(tqdm(dataloader, desc=f"Epoch {epoch + 1}/{num_epochs}")):
                global_step += 1
                try:
                    batch_loss = process_batch(batch, tokenizer, epoch_loss, learning_rate, beta1, beta2, epsilon, weight_decay, global_step)
                except Exception as e:
                    logging.error(f"Error processing batch at step {step}: {e}")
                    continue

                # process_batch cannot update epoch_loss in place (numbers
                # are immutable), so accumulate whatever loss it reports.
                if batch_loss is not None:
                    epoch_loss += float(batch_loss)
                    counted_batches += 1

            # The streaming dataset has no length, so average over the
            # batches that actually reported a loss.
            mean_loss = epoch_loss / counted_batches if counted_batches else float('nan')
            print(f"Epoch {epoch + 1} Loss: {mean_loss}")
    finally:
        # Tear the group down once, after all epochs, not inside the epoch loop.
        dist.destroy_process_group()


def check_for_nans(tensor, name):
    """Print a notice to stdout when ``tensor`` holds at least one NaN.

    ``name`` is only used to label the tensor in the printed message.
    Nothing is raised and nothing is returned.
    """
    contains_nan = bool(torch.isnan(tensor).any())
    if not contains_nan:
        return
    print(f"NaNs detected in {name}")


# Start training only when this file is executed directly, not when imported.
if __name__ == "__main__":
    main()


> initializing model parallel with size 1
> initializing ddp with size 1
> initializing pipeline with size 1
Layer state dict saved to data/layer_0.pt
Layer state dict saved to data/layer_1.pt
Layer state dict saved to data/layer_2.pt
Layer state dict saved to data/layer_3.pt
Layer state dict saved to data/embedding.pt
Layer state dict saved to data/output.pt
Saved to data/freqs_cis.pt


Epoch 1/3: 0it [00:00, ?it/s]INFO:datasets_modules.datasets.wikipedia.d41137e149b2ea90eead07e7e3f805119a8c22dd1d5b61651af8e3e3ee736001.wikipedia:generating examples from = https://dumps.wikimedia.org/enwiki/20240401/enwiki-20240401-pages-articles-multistream1.xml-p1p41242.bz2
INFO:datasets_modules.datasets.wikipedia.d41137e149b2ea90eead07e7e3f805119a8c22dd1d5b61651af8e3e3ee736001.wikipedia:generating examples from = https://dumps.wikimedia.org/enwiki/20240401/enwiki-20240401-pages-articles-multistream3.xml-p151574p311329.bz2
INFO:datasets_modules.datasets.wikipedia.d41137e149b2ea90eead07e7e3f805119a8c22dd1d5b61651af8e3e3ee736001.wikipedia:generating examples from = https://dumps.wikimedia.org/enwiki/20240401/enwiki-20240401-pages-articles-multistream4.xml-p311330p558391.bz2


Extracting content from https://dumps.wikimedia.org/enwiki/20240401/enwiki-20240401-pages-articles-multistream1.xml-p1p41242.bz2
Extracting content from

INFO:datasets_modules.datasets.wikipedia.d41137e149b2ea90eead07e7e3f805119a8c22dd1d5b61651af8e3e3ee736001.wikipedia:generating examples from = https://dumps.wikimedia.org/enwiki/20240401/enwiki-20240401-pages-articles-multistream2.xml-p41243p151573.bz2


Extracting content from

INFO:datasets_modules.datasets.wikipedia.d41137e149b2ea90eead07e7e3f805119a8c22dd1d5b61651af8e3e3ee736001.wikipedia:generating examples from = https://dumps.wikimedia.org/enwiki/20240401/enwiki-20240401-pages-articles-multistream5.xml-p558392p958045.bz2


  Extracting content fromExtracting content fromhttps://dumps.wikimedia.org/enwiki/20240401/enwiki-20240401-pages-articles-multistream3.xml-p151574p311329.bz2https://dumps.wikimedia.org/enwiki/20240401/enwiki-20240401-pages-articles-multistream4.xml-p311330p558391.bz2  

https://dumps.wikimedia.org/enwiki/20240401/enwiki-20240401-pages-articles-multistream2.xml-p41243p151573.bz2https://dumps.wikimedia.org/enwiki/20240401/enwiki-20240401-pages-articles-multistream5.xml-p558392p958045.bz2



INFO:datasets_modules.datasets.wikipedia.d41137e149b2ea90eead07e7e3f805119a8c22dd1d5b61651af8e3e3ee736001.wikipedia:generating examples from = https://dumps.wikimedia.org/enwiki/20240401/enwiki-20240401-pages-articles-multistream6.xml-p958046p1483661.bz2


Extracting content from https://dumps.wikimedia.org/enwiki/20240401/enwiki-20240401-pages-articles-multistream6.xml-p958046p1483661.bz2


Got disconnected from remote data host. Retrying in 5sec [1/20]
Got disconnected from remote data host. Retrying in 5sec [1/20]
Got disconnected from remote data host. Retrying in 5sec [1/20]
INFO:root:Batch (token IDs): torch.Size([2, 2048])
INFO:root:Targets: torch.Size([2])


Sample input: <bAnarchism is a political philosophy and movement that is against all forms of authority and seeks to abolish the institutions it claims maintain unnecessary coercion and hierarchy, typically including the state and capitalism. Anarchism advocates for the replacement of the state with stateless societies and voluntary free associations. As a historically left-wing movement, this reading of anarchism is placed on the farthest left of the political spectrum, usually described as the libertarian wing of the socialist movement (libertarian socialism).

Although traces of anarchist ideas are found all throughout history, modern anarchism emerged from the Enlightenment. During the latter half of the 19th and the first decades of the 20th century, the anarchist movement flourished in most parts of the world and had a significant role in workers' struggles for emancipation. Various anarchist schools of thought formed during this period. Anarchists have taken part in several revo

INFO:root:> initializing model parallel with size 1
> initializing ddp with size 1
> initializing pipeline with size 1
Loaded from data/inputs.pt
Layer state dict loaded from data/embedding.pt
Saved to data/inputs_embed.pt

INFO:root:Embedded inputs: torch.Size([2, 2048, 512])


Loaded from data/inputs_embed.pt
Loaded from data/freqs_cis.pt
Saved to data/mask.pt
Saved to data/inputs_layer_0.pt


INFO:root:> initializing model parallel with size 1
> initializing ddp with size 1
> initializing pipeline with size 1
Loaded from data/inputs_layer_0.pt
Layer state dict loaded from data/layer_0.pt
Loaded from data/freqs_cis.pt
Loaded from data/mask.pt
Saved to data/logits_layer_0.pt

INFO:root:Inputs after layer 0: torch.Size([2, 2048, 512])


Loaded from data/logits_layer_0.pt
Saved to data/inputs_layer_1.pt


INFO:root:> initializing model parallel with size 1
> initializing ddp with size 1
> initializing pipeline with size 1
Loaded from data/inputs_layer_1.pt
Layer state dict loaded from data/layer_1.pt
Loaded from data/freqs_cis.pt
Loaded from data/mask.pt
Saved to data/logits_layer_1.pt

INFO:root:Inputs after layer 1: torch.Size([2, 2048, 512])


Loaded from data/logits_layer_1.pt
Saved to data/inputs_layer_2.pt


INFO:root:> initializing model parallel with size 1
> initializing ddp with size 1
> initializing pipeline with size 1
Loaded from data/inputs_layer_2.pt
Layer state dict loaded from data/layer_2.pt
Loaded from data/freqs_cis.pt
Loaded from data/mask.pt
Saved to data/logits_layer_2.pt

INFO:root:Inputs after layer 2: torch.Size([2, 2048, 512])


Loaded from data/logits_layer_2.pt
Saved to data/inputs_layer_3.pt


INFO:root:> initializing model parallel with size 1
> initializing ddp with size 1
> initializing pipeline with size 1
Loaded from data/inputs_layer_3.pt
Layer state dict loaded from data/layer_3.pt
Loaded from data/freqs_cis.pt
Loaded from data/mask.pt
Saved to data/logits_layer_3.pt

INFO:root:Inputs after layer 3: torch.Size([2, 2048, 512])


Loaded from data/logits_layer_3.pt
Saved to data/final_inputs.pt


INFO:root:> initializing model parallel with size 1
> initializing ddp with size 1
> initializing pipeline with size 1
Loaded from data/final_inputs.pt
Layer state dict loaded from data/output.pt
Saved to data/logits.pt

INFO:root:Final logits: torch.Size([2, 2048, 100277])


Loaded from data/logits.pt


INFO:root:> initializing model parallel with size 1
> initializing ddp with size 1
> initializing pipeline with size 1
Loaded from data/logits.pt
Loaded from data/targets.pt
Saved to data/loss.pt
Saved to data/logits_grad.pt

INFO:root:Loss: 11.4862


Loaded from data/loss.pt
Loss: 11.4862


INFO:root:Logits gradients: torch.Size([2, 2048, 100277])


Loaded from data/logits_grad.pt
Saved to data/final_error.pt


INFO:root:> initializing model parallel with size 1
> initializing ddp with size 1
> initializing pipeline with size 1
Loaded from data/final_inputs.pt
Loaded from data/final_error.pt
Layer state dict loaded from data/output.pt
Saved to data/error_output_final_logits.pt

INFO:root:> initializing model parallel with size 1
> initializing ddp with size 1
> initializing pipeline with size 1
Loaded from data/error_output_final_logits.pt
Loaded from data/inputs_layer_3.pt
Layer state dict loaded from data/layer_3.pt
Loaded from data/freqs_cis.pt
Loaded from data/mask.pt
Saved to data/error_layer_3.pt

INFO:root:> initializing model parallel with size 1
> initializing ddp with size 1
> initializing pipeline with size 1
Loaded from data/error_layer_3.pt
Loaded from data/inputs_layer_2.pt
Layer state dict loaded from data/layer_2.pt
Loaded from data/freqs_cis.pt
Loaded from data/mask.pt
Saved to data/error_layer_2.pt

INFO:root:> initializing model parallel with size 1
> initializing ddp with 

Sample input: <bA hotline is a point-to-point communications link in which a call is automatically directed to the preselected destination without any additional action by the user when the end instrument goes off-hook. An example would be a phone that automatically connects to emergency services on picking up the receiver. Therefore, dedicated hotline phones do not need a rotary dial or keypad. A hotline can also be called an automatic signaling, ringdown, or off-hook service.

For crises and service 
True hotlines cannot be used to originate calls other than to preselected destinations.  However, in common or colloquial usage, a "hotline" often refers to a call center reachable by dialing a standard telephone number, or sometimes the phone numbers themselves.

This is especially the case with 24-hour, noncommercial numbers, such as police tip hotlines or suicide crisis hotlines, which are staffed around the clock and thereby give the appearance of real hotlines.  Increasingly, howeve

INFO:root:> initializing model parallel with size 1
> initializing ddp with size 1
> initializing pipeline with size 1
Loaded from data/inputs.pt
Layer state dict loaded from data/embedding.pt
Saved to data/inputs_embed.pt

INFO:root:Embedded inputs: torch.Size([2, 2048, 512])


Loaded from data/inputs_embed.pt
Loaded from data/freqs_cis.pt
Saved to data/mask.pt
Saved to data/inputs_layer_0.pt


INFO:root:> initializing model parallel with size 1
> initializing ddp with size 1
> initializing pipeline with size 1
Loaded from data/inputs_layer_0.pt
Layer state dict loaded from data/layer_0.pt
Loaded from data/freqs_cis.pt
Loaded from data/mask.pt
Saved to data/logits_layer_0.pt

INFO:root:Inputs after layer 0: torch.Size([2, 2048, 512])


Loaded from data/logits_layer_0.pt
Saved to data/inputs_layer_1.pt


INFO:root:> initializing model parallel with size 1
> initializing ddp with size 1
> initializing pipeline with size 1
Loaded from data/inputs_layer_1.pt
Layer state dict loaded from data/layer_1.pt
Loaded from data/freqs_cis.pt
Loaded from data/mask.pt
Saved to data/logits_layer_1.pt

INFO:root:Inputs after layer 1: torch.Size([2, 2048, 512])


Loaded from data/logits_layer_1.pt
Saved to data/inputs_layer_2.pt


INFO:root:> initializing model parallel with size 1
> initializing ddp with size 1
> initializing pipeline with size 1
Loaded from data/inputs_layer_2.pt
Layer state dict loaded from data/layer_2.pt
Loaded from data/freqs_cis.pt
Loaded from data/mask.pt
Saved to data/logits_layer_2.pt

INFO:root:Inputs after layer 2: torch.Size([2, 2048, 512])


Loaded from data/logits_layer_2.pt
Saved to data/inputs_layer_3.pt


INFO:root:> initializing model parallel with size 1
> initializing ddp with size 1
> initializing pipeline with size 1
Loaded from data/inputs_layer_3.pt
Layer state dict loaded from data/layer_3.pt
Loaded from data/freqs_cis.pt
Loaded from data/mask.pt
Saved to data/logits_layer_3.pt

INFO:root:Inputs after layer 3: torch.Size([2, 2048, 512])


Loaded from data/logits_layer_3.pt
Saved to data/final_inputs.pt


INFO:root:> initializing model parallel with size 1
> initializing ddp with size 1
> initializing pipeline with size 1
Loaded from data/final_inputs.pt
Layer state dict loaded from data/output.pt
Saved to data/logits.pt

INFO:root:Final logits: torch.Size([2, 2048, 100277])


Loaded from data/logits.pt


INFO:root:> initializing model parallel with size 1
> initializing ddp with size 1
> initializing pipeline with size 1
Loaded from data/logits.pt
Loaded from data/targets.pt
Saved to data/loss.pt
Saved to data/logits_grad.pt

INFO:root:Loss: 11.5948


Loaded from data/loss.pt
Loss: 11.5948
