In [1]:
# Llama-3.2-1B-Instruct preset. NOTE: the next cell reassigns model_id/CONFIG,
# so this selection is always overridden when cells run in order.
model_id = "meta-llama/Llama-3.2-1B-Instruct" 
CONFIG = "training_shakespeare_llama3_2_1B_fixed.yaml"
# max length 704
# accum 3 (max; 4 OOMs on n150)
# broken model

In [2]:
# Model presets: key -> (Hugging Face repo id, ttml YAML training config).
# Select exactly one via MODEL_KEY; assigning model_id/CONFIG several times in a
# row silently keeps only the last pair.
MODEL_PRESETS = {
    "qwen3-1.7b": ("Qwen/Qwen3-1.7B", "training_shakespeare_qwen3_1_7B.yaml"),
    # max context: TODO (note left unfinished)
    "qwen3-0.6b": ("Qwen/Qwen3-0.6B", "training_shakespeare_qwen3_0_6B.yaml"),
    "tinyllama": (
        "TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T",
        "training_shakespeare_tinyllama.yaml",  # must be working
    ),
}
MODEL_KEY = "tinyllama"
model_id, CONFIG = MODEL_PRESETS[MODEL_KEY]

In [3]:
import ttml

In [4]:
import os
import sys
from typing import Optional

import datasets
import numpy as np
import torch
from torch.utils.data import DataLoader
from transformers import AutoTokenizer
from huggingface_hub import hf_hub_download
from huggingface_hub import snapshot_download
from matplotlib import pyplot as plt
from tqdm.auto import tqdm

import ttml
from ttml.common.config import get_config, TrainingConfig, SchedulerConfig, DeviceConfig
from ttml.common.model_factory import TransformerModelFactory
from ttml.common.utils import round_up_to_tile, create_optimizer, initialize_device
from ttml.common.data import build_causal_mask

#import tt_serialization

In [5]:
class SpeedrunScheduler:
    """Piecewise-linear LR schedule with an optional linear beta1 warmup.

    For a zero-based step ``s`` with ``w = warmup_steps``, ``h = hold_steps`` and
    ``T = total_steps`` (negative w/h clamped to 0, T clamped to >= 1):

    * ``s <= w`` (only when ``w > 0``): linear ramp from 0 to ``max_lr``.
    * ``s <= w + h``: hold at ``max_lr``.
    * otherwise: linear decay from ``max_lr`` at ``w + h`` to ``min_lr`` at ``T``;
      steps past ``T`` stay at ``min_lr``.

    The decay fraction is clamped to [0, 1], so a config whose ``total_steps`` is
    smaller than ``warmup_steps + hold_steps`` never extrapolates the LR above
    ``max_lr``; such configs return ``min_lr`` once the hold phase is over.
    """

    def __init__(self, cfg: "SchedulerConfig"):
        # cfg must expose warmup_steps, hold_steps, total_steps, max_lr, min_lr,
        # beta1_start, beta1_end and beta1_warmup_steps (all read below).
        self.cfg = cfg

    def lr_at(self, step: int) -> float:
        """Return the learning rate for zero-based optimizer step ``step``."""
        s = max(0, step)
        w = max(0, self.cfg.warmup_steps)
        h = max(0, self.cfg.hold_steps)
        T = max(1, self.cfg.total_steps)
        peak = self.cfg.max_lr
        min_lr = self.cfg.min_lr

        if w > 0 and s <= w:
            # Linear warmup 0 -> max_lr. Skipped when w == 0 so that step 0 gets
            # max_lr instead of an LR of 0.
            return peak * (s / w)
        if s <= w + h:
            # Hold at max_lr.
            return peak

        # Linear decay from max_lr at (w + h) to min_lr at T.
        decay_start = w + h
        decay_len = T - decay_start
        if decay_len <= 0:
            # The schedule ends before (or exactly when) decay would begin.
            return min_lr
        frac = (min(s, T) - decay_start) / decay_len
        frac = min(1.0, max(0.0, frac))
        return peak + (min_lr - peak) * frac

    def beta1_at(self, step: int) -> Optional[float]:
        """Return beta1 for ``step``, or None when beta1 warmup is not configured.

        Interpolates linearly from ``beta1_start`` to ``beta1_end`` over
        ``beta1_warmup_steps`` steps and then stays at ``beta1_end``.
        """
        if (
            self.cfg.beta1_start is None
            or self.cfg.beta1_end is None
            or self.cfg.beta1_warmup_steps <= 0
        ):
            return None
        s = min(step, self.cfg.beta1_warmup_steps)
        t = s / float(self.cfg.beta1_warmup_steps)
        return (1.0 - t) * self.cfg.beta1_start + t * self.cfg.beta1_end



In [6]:
class OptimParamSetter:
    """Adapter that pushes scheduler outputs into an optimizer.

    Only the learning rate is supported (forwarded to ``optim.set_lr``);
    ``set_beta1`` always raises ``NotImplementedError``.
    """

    def __init__(self, optim):
        self.optim = optim
        # Flags initialised to False; nothing in this class reads or updates them.
        self._warned_lr, self._warned_beta1 = False, False

    def set_lr(self, lr: float):
        """Coerce ``lr`` to float and hand it to the wrapped optimizer."""
        value = float(lr)
        self.optim.set_lr(value)

    def set_beta1(self, beta1: float):
        """Unsupported: always raises NotImplementedError."""
        message = "set_beta1 is not implemented in TTML AdamW optimizer."
        raise NotImplementedError(message)


def build_logits_mask(vocab_size: int, padded_vocab_size: int) -> ttml.autograd.Tensor:
    """Build a [1, 1, 1, padded_vocab_size] bfloat16 mask over the vocabulary axis.

    Entries for real token ids (index < vocab_size) are 0. Entries for the extra
    ids introduced by padding the vocabulary are 1e4.
    NOTE(review): presumably consumed by ttml's sample_op to suppress padded ids —
    confirm how it applies the mask (sign convention).
    """
    mask = np.zeros((1, 1, 1, padded_vocab_size), dtype=np.float32)
    mask[..., vocab_size:] = 1e4
    return ttml.autograd.Tensor.from_numpy(
        mask, ttml.Layout.TILE, ttml.autograd.DataType.BFLOAT16
    )


class CollateFn:
    """Collate (question_ids, answer_ids) pairs into fixed-length causal-LM batches.

    Each row is laid out as ``question + answer`` and right-padded with
    ``eos_token_id`` up to ``max_sequence_length``. If the pair is too long, the
    question is truncated first; if the answer alone fills the window, only the
    answer is kept.

    Args:
        eos_token_id: Token id used for padding (and therefore as the target that
            follows the last answer token).
        max_sequence_length: Fixed length T of every row.
        padded_vocab_size: Stored on the instance; the collation does not use it.
    """

    def __init__(self, eos_token_id, max_sequence_length, padded_vocab_size):
        self.eos_token_id = eos_token_id
        self.max_sequence_length = max_sequence_length
        self.padded_vocab_size = padded_vocab_size

    def collate_fn(self, batch):
        """Build ``(X_np, y_np, loss_scaler_np)`` for a list of samples.

        Args:
            batch: Sequence of samples where ``sample[0]`` holds question token ids
                and ``sample[1]`` answer token ids (anything len()-able/sliceable).

        Returns:
            X_np: uint32 [B, 1, 1, T] input ids.
            y_np: uint32 [B, T] next-token targets (inputs shifted left by one, the
                last position filled with eos).
            loss_scaler_np: float32 [B, 1, T, 1] per-position loss weights. A
                position gets weight 0 when its target is still a question token or
                when its input is padding (eos). The remaining weights are rescaled
                to sum to B * T, so a mean over all positions equals the mean over
                trained positions. If no position is trainable the weights stay all
                zero (instead of turning into NaN through a division by zero).
        """
        T = self.max_sequence_length
        batch_size = len(batch)
        data_np = np.full((batch_size, T), self.eos_token_id, dtype=np.uint32)
        mask_lens = []  # number of question tokens kept in each row

        for i, sample in enumerate(batch):
            x_tokens, y_tokens = sample[0], sample[1]
            if len(x_tokens) + len(y_tokens) > T:
                available_space = T - len(y_tokens)
                if available_space > 0:
                    # Too long: truncate the question, keep the whole answer.
                    x_tokens = x_tokens[:available_space]
                else:
                    # The answer alone fills the window: keep only its prefix.
                    y_tokens = y_tokens[:T]
                    x_tokens = []
            n_x = len(x_tokens)
            data_np[i, :n_x] = x_tokens
            data_np[i, n_x : n_x + len(y_tokens)] = y_tokens
            mask_lens.append(n_x)

        X_np = np.expand_dims(data_np, axis=(1, 2))  # [B, 1, 1, T]

        y_np = np.full((batch_size, T), self.eos_token_id, dtype=np.uint32)
        y_np[:, :-1] = data_np[:, 1:]  # shift left by one

        loss_scaler_np = np.ones((batch_size, 1, T, 1), dtype=np.float32)
        for i, mask_len in enumerate(mask_lens):
            # Position t is trained to predict token t + 1. Position mask_len - 1
            # (the last question token) predicts the FIRST answer token, so only
            # the positions before it are masked out.
            loss_scaler_np[i, :, : max(mask_len - 1, 0), :] = 0.0
            pad_positions = data_np[i] == self.eos_token_id
            loss_scaler_np[i, :, pad_positions, :] = 0.0

        active_weight = float(np.sum(loss_scaler_np))
        if active_weight > 0:
            loss_scaler_np *= T * batch_size / active_weight

        return X_np, y_np, loss_scaler_np

    def __call__(self, batch):
        """Alias for :meth:`collate_fn` so the instance can be used as ``collate_fn``."""
        return self.collate_fn(batch)



In [7]:
def get_batch_generator(
    dataloader,
    batch_size,
    max_sequence_length,
    padded_vocab_size,
    tokenizer,
    device_config=None,
):
    """Endlessly cycle over ``dataloader`` and yield ``(X, y, loss_scaler)`` ttml tensors.

    Each dataloader item must be a ``(X_np, y_np, loss_scaler_np)`` triple of
    numpy arrays. X and y become ROW_MAJOR uint32 tensors; the loss scaler becomes a
    TILE bfloat16 tensor. When ``device_config`` is given, tensors are sharded along
    dim 0 across the device mesh.

    ``batch_size``, ``max_sequence_length``, ``padded_vocab_size`` and ``tokenizer``
    are accepted for interface compatibility but are not used here.
    """
    if device_config is None:
        mapper = None
    else:
        mesh_device = ttml.autograd.AutoContext.get_instance().get_device()
        mapper = ttml.core.distributed.shard_tensor_to_mesh_mapper(mesh_device, 0)

    def _to_tt(array, layout, dtype):
        return ttml.autograd.Tensor.from_numpy(array, layout, dtype, mapper)

    while True:
        for X_np, y_np, loss_scaler_np in dataloader:
            yield (
                _to_tt(X_np, ttml.Layout.ROW_MAJOR, ttml.autograd.DataType.UINT32),
                _to_tt(y_np, ttml.Layout.ROW_MAJOR, ttml.autograd.DataType.UINT32),
                _to_tt(loss_scaler_np, ttml.Layout.TILE, ttml.autograd.DataType.BFLOAT16),
            )


In [8]:
def generate_text_tt(
    model,
    tokenizer,
    question,
    max_sequence_length,
    causal_mask,
    temperature,
    logits_mask_tensor,
    max_gen_tokens,
    pad_token_id=None,
    return_with_prompt=False,
):
    """Autoregressively generate a continuation of ``question`` with a ttml model.

    Each step runs a full forward pass over a fixed window of ``max_sequence_length``
    tokens. When the context is longer than that, the window slides to keep the most
    recent tokens. The unused tail of the window is filled with ``pad_token_id``. The
    token sampled at the last real position is appended. Generation stops at EOS or
    after ``max_gen_tokens`` tokens. Nothing is printed; the decoded text is returned.

    Args:
        model: ttml model, called as ``model(tokens, causal_mask)``.
        tokenizer: Hugging Face tokenizer used to encode the prompt and decode output.
        question: Prompt text.
        max_sequence_length: Window length fed to the model.
        causal_mask: Causal mask tensor passed to the model.
        temperature: Sampling temperature forwarded to ``sample_op`` (0.0 behaves
            like greedy/argmax).
        logits_mask_tensor: Vocabulary mask forwarded to ``sample_op``.
        max_gen_tokens: Maximum number of generated tokens.
        pad_token_id: Filler id for the window; defaults to
            ``tokenizer.pad_token_id``, or 0 if that is missing/None.
        return_with_prompt: If True, return prompt + generation instead of only the
            generated text.

    Returns:
        The decoded text.

    Side effects: puts ``model`` in eval mode. Gradient mode is DISABLED during
    generation and set to ENABLED on exit, including when an exception is raised.
    """
    model.eval()
    ctx = ttml.autograd.AutoContext.get_instance()
    ctx.set_gradient_mode(ttml.autograd.GradMode.DISABLED)
    try:
        # NOTE(review): encode() may add special tokens (e.g. BOS) while training data
        # is tokenized with add_special_tokens=False — confirm this is intended.
        context_tokens = list(tokenizer.encode(question))
        if pad_token_id is None:
            pad_token_id = getattr(tokenizer, "pad_token_id", None)
            if pad_token_id is None:
                pad_token_id = 0

        device = ctx.get_device()
        composer = ttml.core.distributed.concat_mesh_to_tensor_composer(device, 0)

        # Preallocated [1, 1, 1, T] input buffer, refilled on every step.
        padded_prompt_tokens = np.full(
            (1, 1, 1, max_sequence_length), pad_token_id, dtype=np.uint32
        )
        generated_tokens = []
        for _ in tqdm(range(max_gen_tokens)):
            # Sliding window: keep only the most recent max_sequence_length tokens.
            window = context_tokens[-max_sequence_length:]

            # Refill fully so no stale ids from the previous step remain.
            padded_prompt_tokens[...] = pad_token_id
            padded_prompt_tokens[0, 0, 0, : len(window)] = np.asarray(
                window, dtype=np.uint32
            )
            padded_prompt_tensor = ttml.autograd.Tensor.from_numpy(
                padded_prompt_tokens, ttml.Layout.ROW_MAJOR, ttml.autograd.DataType.UINT32
            )

            logits = model(padded_prompt_tensor, causal_mask)  # [1, 1, T, V]

            # Sample a next token for every position ([1, 1, T, 1]) with the
            # caller's temperature (previously hard-coded to 0.0).
            next_token_tensor = ttml.ops.sample.sample_op(
                logits,
                float(temperature),
                np.random.randint(low=10_000_000),
                logits_mask_tensor,
            )

            # Keep the prediction made at the last real (non-pad) position.
            next_token_idx = len(window) - 1
            sampled = next_token_tensor.to_numpy(composer=composer).reshape(-1, 1)
            next_token = int(sampled[next_token_idx][0])

            if next_token == tokenizer.eos_token_id:
                break

            generated_tokens.append(next_token)
            context_tokens.append(next_token)

        if return_with_prompt:
            return tokenizer.decode(context_tokens)
        return tokenizer.decode(generated_tokens)
    finally:
        ctx.set_gradient_mode(ttml.autograd.GradMode.ENABLED)


In [9]:
def validate(
    tt_model,
    tokenizer,
    val_batch_generator,
    testing_data,
    loss_fn,
    causal_mask,
    logits_mask_tensor,
    max_sequence_length,
    max_gen_tokens,
    current_step,
    eval_batch_count=4,
    checks_count=4,
):
    """Compute validation loss and append sample generations to ``validation.txt``.

    The loss pass runs ``eval_batch_count`` batches from ``val_batch_generator``
    with gradients disabled. Per-position losses are weighted by each batch's loss
    scaler and then averaged. The generation pass decodes the first ``checks_count``
    questions of ``testing_data`` and generates answers with temperature 0.0. Each
    question, each generated answer, and the mean loss are appended to
    ``validation.txt``.

    Args:
        tt_model: ttml model under evaluation.
        tokenizer: Tokenizer used to decode questions and generations.
        val_batch_generator: Iterator yielding ``(X, y, loss_scaler)`` tensors.
        testing_data: Indexable dataset of ``(question_ids, answer_ids)`` pairs.
        loss_fn: Loss called as ``loss_fn(logits, targets, ReduceType.NONE)``.
        causal_mask: Causal mask tensor passed to the model.
        logits_mask_tensor: Vocabulary mask forwarded to generation.
        max_sequence_length: Window length used for generation.
        max_gen_tokens: Maximum number of tokens generated per check.
        current_step: Step number written to the log.
        eval_batch_count: Number of validation batches to average (default 4).
        checks_count: Number of generation checks to log (default 4), capped at
            ``len(testing_data)``.

    Returns:
        ``np.mean`` of the per-batch validation losses.

    On exit (including on exceptions) gradient mode is set to ENABLED and the
    model is put back into train mode.
    """
    ctx = ttml.autograd.AutoContext.get_instance()
    ctx.set_gradient_mode(ttml.autograd.GradMode.DISABLED)
    tt_model.eval()
    try:
        cur_val_losses = []
        for _ in range(eval_batch_count):
            val_X, val_y, val_loss_scaler = next(val_batch_generator)
            val_logits = tt_model(val_X, causal_mask)

            # Per-position loss, weighted by the collator's scaler, then averaged.
            val_loss = loss_fn(val_logits, val_y, ttml.ops.ReduceType.NONE)
            val_loss = val_loss * val_loss_scaler
            val_loss = ttml.ops.unary.mean(val_loss)
            cur_val_losses.append(get_loss_over_devices(val_loss))
        mean_val_loss = np.mean(cur_val_losses)

        with open("validation.txt", "a+") as val_file:
            val_file.write(f"Validation at step {current_step}\n")
            for check in range(min(checks_count, len(testing_data))):
                val_file.write(f"Validation check: {check}\n")
                val_file.write("====================================\n")

                tokenized_question, tokenized_answer = testing_data[check]
                question = tokenizer.decode(tokenized_question, skip_special_tokens=True)

                val_file.write(f"Question: {question}\n")
                val_file.write("====================================\n")

                gen_text = generate_text_tt(
                    tt_model,
                    tokenizer,
                    question,
                    max_sequence_length,
                    causal_mask,
                    0.0,
                    logits_mask_tensor,
                    max_gen_tokens,
                )

                val_file.write(f"Generated Answer: {gen_text}\n")
                val_file.write("\n====================================\n")

            val_file.write(
                "Last validation loss: {:.4f}\n\n\n".format(mean_val_loss)
            )
    finally:
        ctx.set_gradient_mode(ttml.autograd.GradMode.ENABLED)
        tt_model.train()
    return mean_val_loss



In [10]:
def adjust_logits(logits, binary_mask, add_mask):
    """Return ``binary_mask * logits + add_mask``.

    Multiplies the logits by ``binary_mask`` (presumably 0/1), then adds the
    additive mask. Works for any operands that support ``*`` and ``+``.
    """
    return (binary_mask * logits) + add_mask


def get_loss_over_devices(loss):
    """Gather ``loss`` from every device of the mesh and return the mean of all elements.

    Per-device shards are concatenated along dim 0 with ttml's mesh composer,
    copied to host as a numpy array, and averaged.
    """
    mesh = ttml.autograd.AutoContext.get_instance().get_device()
    return loss.to_numpy(
        composer=ttml.core.distributed.concat_mesh_to_tensor_composer(mesh, 0)
    ).mean()


def tokenize_dataset(data, tokenizer):
    """Tokenize the "question" and "answer" fields of every sample, without special tokens.

    Returns ``(question_ids, answer_ids)``: whatever ``tokenizer(...)["input_ids"]``
    yields with ``return_tensors="np"``. Questions are tokenized before answers.
    """
    def encode(texts):
        return tokenizer(texts, return_tensors="np", add_special_tokens=False)["input_ids"]

    questions = [row["question"] for row in data]
    answers = [row["answer"] for row in data]
    return encode(questions), encode(answers)


In [11]:
class TokenizedDataset(torch.utils.data.Dataset):
    """Map-style dataset pairing question token ids with answer token ids.

    ``ds[idx]`` returns ``(X[idx], y[idx])``; the length is fixed at construction
    time as ``len(X)``.
    """

    def __init__(self, X, y):
        self.X, self.y = X, y
        self.len = len(self.X)

    def __len__(self):
        return self.len

    def __getitem__(self, idx):
        question_ids = self.X[idx]
        answer_ids = self.y[idx]
        return question_ids, answer_ids


In [12]:
print("Loading tokenizer and config...")
# Disable tokenizer parallelism to avoid conflicts with DataLoader multiprocessing
# (and other forks such as shell `!` commands). The previous value "true"
# contradicted this stated intent.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
tokenizer = AutoTokenizer.from_pretrained(model_id)
yaml_config = get_config(CONFIG)

Loading tokenizer and config...


In [13]:
# Max context length configured in the YAML (overridden further below).
yaml_config['training_config']['transformer_config']['max_sequence_length']

2048

In [14]:
# Load dataset
print("Loading GSM8K dataset...")
training_data = datasets.load_dataset("gsm8k", "main", split="train")
testing_data = datasets.load_dataset("gsm8k", "main", split="test")

training_data_x, training_data_y = tokenize_dataset(training_data, tokenizer)
testing_data_x, testing_data_y = tokenize_dataset(testing_data, tokenizer)
training_data = TokenizedDataset(training_data_x, training_data_y)
testing_data = TokenizedDataset(testing_data_x, testing_data_y)

max_gen_tokens = max(max(s.shape[0] for s in training_data_y),
                     max(s.shape[0] for s in testing_data_y))
max_seq_lenght = max(max(s.shape[0] for s in training_data_x),
                     max(s.shape[0] for s in testing_data_x)) + max_gen_tokens

max_seq_lenght = round_up_to_tile(max_seq_lenght)
print(max_seq_lenght)




Loading GSM8K dataset...


Using custom data configuration main-f9306ececa7c2eca
Found cached dataset parquet (/home/ubuntu/.cache/huggingface/datasets/parquet/main-f9306ececa7c2eca/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec)
Using custom data configuration main-f9306ececa7c2eca
Found cached dataset parquet (/home/ubuntu/.cache/huggingface/datasets/parquet/main-f9306ececa7c2eca/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec)


800


In [15]:
# Inspect the training section of the YAML config.
yaml_config['training_config']

{'project_name': 'tt_train_qwen3',
 'model_type': 'qwen3',
 'seed': 5489,
 'model_save_interval': 1,
 'batch_size': 1,
 'num_epochs': 1,
 'max_steps': 5000,
 'learning_rate': 0.0003,
 'weight_decay': 0.01,
 'use_moreh_adamw': True,
 'use_kahan_summation': False,
 'use_clip_grad_norm': False,
 'clip_grad_norm_max_norm': 1.0,
 'tokenizer_path': 'data/qwen-tokenizer.json',
 'tokenizer_type': 'bpe',
 'transformer_config': {'num_heads': 16,
  'num_groups': 8,
  'embedding_dim': 2048,
  'head_dim': 128,
  'intermediate_dim': 6144,
  'dropout_prob': 0.0,
  'num_blocks': 28,
  'weight_tying': 'enabled',
  'vocab_size': 151936,
  'max_sequence_length': 2048,
  'runner_type': 'memory_efficient',
  'theta': 1000000.0,
  'rms_norm_eps': 1e-06}}

In [16]:
# Context length derived from the data (longest question + longest answer, tile-rounded).
max_seq_lenght

800

In [17]:
# Optional: uncomment to force a short context (e.g. for quick experiments).
#max_seq_lenght = 128

In [18]:
# Shrink the context to what GSM8K needs and override the accumulation/step counts.
train_cfg = yaml_config['training_config']
print('overrride max sequence length:', train_cfg['transformer_config']['max_sequence_length'], max_seq_lenght)
train_cfg['transformer_config']['max_sequence_length'] = max_seq_lenght
train_cfg['gradient_accumulation_steps'] = 128
train_cfg['max_steps'] = 500

overrride max sequence length: 2048 800


In [19]:
training_config = TrainingConfig(yaml_config)
scheduler_config = SchedulerConfig(yaml_config)

batch_size = training_config.batch_size

# initialize device
device_config = DeviceConfig(yaml_config)
# no need to initialize device if #devices=1
if device_config.total_devices() > 1:
    initialize_device(yaml_config)

# Download the weights (*.safetensors) and JSON configs into the local HF cache.
# Fetching only config.json left the weights present only as a side effect of a
# later `from_pretrained` call; load_from_safetensors presumably needs them in
# this same snapshot directory.
print("Downloading safetensors...")
snapshot_dir = snapshot_download(
    repo_id=model_id, allow_patterns=["*.safetensors", "*.json"]
)
# Keep a trailing separator, matching the directory-path format used before.
safetensors_path = os.path.join(snapshot_dir, "")

Downloading safetensors...


In [20]:
# Number of micro-batches accumulated per optimizer step in the training loop.
training_config.gradient_accumulation_steps

128

In [21]:
# Reference Hugging Face model loaded in bf16; used below for its vocab size and
# in the memory-inspection cells at the end.
from transformers import AutoModelForCausalLM
import torch

torch.manual_seed(42)
torch_model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [22]:
# Vocabulary size of the Hugging Face model (before the tile padding applied below).
orig_vocab_size = torch_model.vocab_size

In [23]:
# Build the ttml model factory from the YAML, then force its vocab size to the HF
# model's (presumably so the embedding / LM head match the checkpoint).
tt_model_factory = TransformerModelFactory(yaml_config)
factory_transformer_cfg = tt_model_factory.transformer_config
factory_transformer_cfg.vocab_size = orig_vocab_size

In [24]:
# Sequence length the ttml model is built with (the YAML value overridden above).
max_sequence_length = tt_model_factory.transformer_config.max_sequence_length

In [25]:
from time import time

In [26]:
# Set ttml's init mode to DISABLED before creating the model.
# NOTE(review): presumably skips random weight init because weights are loaded
# from safetensors afterwards — confirm.
ttml.autograd.AutoContext.get_instance().set_init_mode(ttml.autograd.InitMode.DISABLED)

In [27]:
# Create the ttml model and report how long construction took (seconds).
start_time = time()
tt_model = tt_model_factory.create_model()
elapsed = time() - start_time
print(f"Model created: {elapsed}")

Qwen3 configuration:
    Vocab size: 151936
    Max sequence length: 800
    Embedding dim (hidden_size): 2048
    Head dim: 128
    Attention output dim: 2048
    Intermediate dim: 6144
    Num heads: 16
    Num groups (KV heads): 8
    Dropout probability: 0
    Num blocks: 28
    Positional embedding type: RoPE
    Runner type: Memory efficient
    Weight tying: Enabled
    Theta: 1000000
    RMSNorm epsilon: 1e-06
2025-11-27 23:38:56.786 | info     |             UMD | Established cluster ETH FW version: 6.14.0 (topology_discovery_wormhole.cpp:359)
2025-11-27 23:38:56.791 | info     |          Device | Opening user mode device driver (tt_cluster.cpp:209)
2025-11-27 23:38:56.810 | info     |             UMD | Established cluster ETH FW version: 6.14.0 (topology_discovery_wormhole.cpp:359)
2025-11-27 23:38:56.850 | info     |             UMD | Established cluster ETH FW version: 6.14.0 (topology_discovery_wormhole.cpp:359)
2025-11-27 23:38:56.853 | info     |             UMD | Harvest

In [28]:
# Timing notes recorded from an earlier run:
# Model created: 133.07612419128418
# Model loaded: 27.28171706199646

In [29]:
# Timing notes recorded from another earlier run:
# Model created: 68.56190633773804
# Model loaded: 26.858034372329712

In [30]:
# Timing note recorded from another earlier run:
# Model created: 9.410035133361816


In [31]:
# Load the downloaded safetensors weights and report the load time (seconds).
start_time = time()
tt_model.load_from_safetensors(safetensors_path)
elapsed = time() - start_time
print(f"Model loaded: {elapsed}")

Model loaded: 32.21319389343262


In [32]:
# Round the vocabulary up to a multiple of 32 for tilization.
padded_vocab_size = round_up_to_tile(orig_vocab_size, 32)
if padded_vocab_size != orig_vocab_size:
    print(
        f"Padding vocab size for tilization: original {orig_vocab_size} -> padded {padded_vocab_size}"
    )

In [33]:

# Both loaders collate to the model's configured context length
# (max_sequence_length), the same length the causal mask is built with. The
# training loader previously used max_seq_lenght instead.
num_devices = device_config.total_devices()

training_dataloader = DataLoader(
    training_data,
    batch_size=batch_size,
    shuffle=True,  # Shuffle the dataset for each epoch
    drop_last=True,
    num_workers=0,
    collate_fn=CollateFn(
        tokenizer.eos_token_id, max_sequence_length, padded_vocab_size
    ),
)

testing_dataloader = DataLoader(
    testing_data,
    batch_size=training_config.validation_batch_size * num_devices,
    shuffle=False,  # Disable shuffling for validation
    drop_last=True,
    num_workers=0,
    collate_fn=CollateFn(
        tokenizer.eos_token_id, max_sequence_length, padded_vocab_size
    ),
)

In [34]:
# Device count from the device config (multiplies the validation batch size).
num_devices

1

In [35]:
# Setup training
# Setup training: pad-token fallback, optimizer, causal/logits masks and loss.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

optim = create_optimizer(tt_model, yaml_config)

causal_mask = ttml.autograd.Tensor.from_numpy(
    build_causal_mask(max_sequence_length),
    ttml.Layout.ROW_MAJOR,
    ttml.autograd.DataType.BFLOAT16,
)
logits_mask_tensor = build_logits_mask(orig_vocab_size, padded_vocab_size)

loss_fn = ttml.ops.loss.cross_entropy_loss
reduce = ttml.ops.ReduceType.NONE

In [36]:

# Training setup: switch to train mode, reset loss histories, and build the
# endless (X, y, loss_scaler) generators over both loaders.
tt_model.train()
train_losses = []
val_losses = []

shared_generator_args = (max_sequence_length, padded_vocab_size, tokenizer, device_config)
train_batch_generator = get_batch_generator(
    training_dataloader, batch_size, *shared_generator_args
)
val_batch_generator = get_batch_generator(
    testing_dataloader,
    training_config.validation_batch_size * num_devices,
    *shared_generator_args,
)

In [37]:
tokens_per_batch = batch_size * max_sequence_length
print("Tokens per micro-batch:", tokens_per_batch)
print(
    "Tokens per accumulated batch:",
    tokens_per_batch * training_config.gradient_accumulation_steps,
)

sched = SpeedrunScheduler(scheduler_config)
setter = OptimParamSetter(optim)

# Start a fresh validation log (validate() appends to it). A `with` block
# guarantees the file is closed even if a write fails.
with open("validation.txt", "w") as f:
    f.write("Validation log\n")
    f.write("===============\n")



Tokens per micro-batch: 800
Tokens per accumulated batch: 102400


In [38]:
# Micro-batches per optimizer step. NOTE: the training cell below reassigns this.
accum_steps = training_config.gradient_accumulation_steps

In [39]:
# ========== Training Loop ===========
# One iteration = one optimizer step built from `accum_steps` micro-batches whose
# gradients are accumulated (each micro loss is scaled by 1 / accum_steps).
print(
    f"Starting training for {training_config.epochs} epochs, max {training_config.steps} steps..."
)
bar = tqdm(range(1, training_config.steps + 1))

total_steps = 0  # zero-based index of the current optimizer step (== opt_step - 1)
last_val_loss = None  # None until the first validation, so no fake 0.0 is displayed
accum_steps = training_config.gradient_accumulation_steps

for opt_step in bar:
    # LR (and optional beta1) updated once per optimizer step
    optim.zero_grad()
    lr_now = sched.lr_at(opt_step - 1)  # zero-based inside scheduler
    setter.set_lr(lr_now)

    # True (unscaled) per-micro-batch losses, averaged below for reporting.
    micro_losses = []

    for _ in range(accum_steps):
        X, y, loss_scaler = next(train_batch_generator)

        # Forward
        logits = tt_model(X, causal_mask)  # [B,1,T,V]

        # Per-position CE weighted by the collator's loss scaler, then averaged.
        loss = loss_fn(logits, y, reduce)
        loss = loss * loss_scaler
        loss = ttml.ops.unary.mean(loss)  # scalar

        micro_losses.append(get_loss_over_devices(loss))

        # Scale for accumulation, backward, then drop the graph for the next micro-batch.
        scaled_loss = ttml.ops.binary.mul(loss, 1 / accum_steps)
        scaled_loss.backward(False)
        ttml.autograd.AutoContext.get_instance().reset_graph()

    # Synchronize gradients if DDP is enabled
    if device_config.enable_ddp:
        ttml.core.distributed.synchronize_parameters(tt_model.parameters())

    # Optimizer step after micro-steps
    optim.step()

    # Average loss across micro-steps (this corresponds to the optimizer step)
    step_loss = float(np.mean(micro_losses)) if micro_losses else 0.0
    train_losses.append(step_loss)

    # tqdm postfix; val_loss only appears once a validation has actually run.
    postfix = {"train_loss": f"{step_loss:.4f}", "lr": f"{lr_now:.6f}"}
    if last_val_loss is not None:
        postfix["val_loss"] = f"{last_val_loss:.4f}"
    bar.set_postfix(postfix, refresh=False)

    # Validate every eval_every steps and on the final step.
    if (
        total_steps % training_config.eval_every == 0
        or total_steps + 1 == training_config.steps
    ):
        last_val_loss = validate(
            tt_model,
            tokenizer,
            val_batch_generator,
            testing_data,
            loss_fn,
            causal_mask,
            logits_mask_tensor,
            max_sequence_length=max_seq_lenght,
            max_gen_tokens=max_gen_tokens,
            current_step=total_steps,
        )
        val_losses.append(last_val_loss)

    total_steps += 1

Starting training for 1 epochs, max 500 steps...


  0%|          | 0/500 [00:00<?, ?it/s]


All 310 parameters were successfully loaded and used.
2025-11-27 23:40:53.633 | info     |            Test | Small tensor algorithm selected (softmax_backward_w_small.cpp:18)
2025-11-27 23:40:59.234 | critical |          Always | Out of Memory: Not enough space to allocate 9830400 B DRAM buffer across 12 banks, where each bank needs to store 819200 B, but bank size is only 1071181792 B (assert.hpp:103)


RuntimeError: TT_FATAL @ /home/ubuntu/tt-metal/tt_metal/impl/allocator/bank_manager.cpp:431: address.has_value()
info:
Out of Memory: Not enough space to allocate 9830400 B DRAM buffer across 12 banks, where each bank needs to store 819200 B, but bank size is only 1071181792 B
backtrace:
 --- /home/ubuntu/tt-metal/build/lib/libtt_metal.so(+0x524cdd) [0x7efdb2191cdd]
 --- tt::tt_metal::BankManager::allocate_buffer(unsigned long, unsigned long, bool, CoreRangeSet const&, std::optional<unsigned int>, ttsl::StrongType<unsigned int, tt::tt_metal::AllocatorIDTag>)
 --- tt::tt_metal::Allocator::allocate_buffer(tt::tt_metal::Buffer*)
 --- tt::tt_metal::Buffer::allocate_impl()
 --- tt::tt_metal::Buffer::create(tt::tt_metal::IDevice*, unsigned long, unsigned long, tt::tt_metal::BufferType, tt::tt_metal::BufferShardingArgs const&, std::optional<bool>, std::optional<ttsl::StrongType<unsigned char, tt::tt_metal::SubDeviceIdTag> >)
 --- tt::tt_metal::distributed::MeshBuffer::create(std::variant<tt::tt_metal::distributed::ReplicatedBufferConfig, tt::tt_metal::distributed::ShardedBufferConfig> const&, tt::tt_metal::distributed::DeviceLocalBufferConfig const&, tt::tt_metal::distributed::MeshDevice*, std::optional<unsigned long>)
 --- tt::tt_metal::tensor_impl::allocate_device_buffer(tt::tt_metal::distributed::MeshDevice*, tt::tt_metal::TensorSpec const&)
 --- tt::tt_metal::allocate_tensor_on_device(tt::tt_metal::TensorSpec const&, tt::tt_metal::distributed::MeshDevice*)
 --- tt::tt_metal::create_device_tensor(tt::tt_metal::TensorSpec const&, tt::tt_metal::IDevice*)
 --- ttnn::operations::binary_ng::BinaryNgDeviceOperation::create_output_tensors(ttnn::operations::binary_ng::BinaryNgDeviceOperation::operation_attributes_t const&, ttnn::operations::binary_ng::BinaryNgDeviceOperation::tensor_args_t const&)
 --- ttnn::operations::binary_ng::BinaryNgDeviceOperation::tensor_return_value_t ttnn::device_operation::detail::launch_on_device<ttnn::operations::binary_ng::BinaryNgDeviceOperation>(ttnn::operations::binary_ng::BinaryNgDeviceOperation::operation_attributes_t const&, ttnn::operations::binary_ng::BinaryNgDeviceOperation::tensor_args_t const&)
 --- ttnn::operations::binary_ng::BinaryNgDeviceOperation::tensor_return_value_t ttnn::device_operation::detail::invoke<ttnn::operations::binary_ng::BinaryNgDeviceOperation>(ttnn::operations::binary_ng::BinaryNgDeviceOperation::operation_attributes_t const&, ttnn::operations::binary_ng::BinaryNgDeviceOperation::tensor_args_t const&)
 --- /home/ubuntu/tt-metal/build/lib/_ttnncpp.so(+0xb63256) [0x7efd9777e256]
 --- /home/ubuntu/tt-metal/build/lib/_ttnncpp.so(+0xb62f0c) [0x7efd9777df0c]
 --- /home/ubuntu/tt-metal/build/lib/_ttnncpp.so(+0xb421af) [0x7efd9775d1af]
 --- ttnn::operations::binary::BinaryOperation<(ttnn::operations::binary::BinaryOpType)2>::invoke(tt::tt_metal::Tensor const&, tt::tt_metal::Tensor const&, std::optional<tt::tt_metal::DataType const> const&, std::optional<tt::tt_metal::MemoryConfig> const&, std::optional<tt::tt_metal::Tensor> const&, std::span<ttnn::operations::unary::BasicUnaryWithParam<float, int, unsigned int> const, 18446744073709551615ul>, std::span<ttnn::operations::unary::BasicUnaryWithParam<float, int, unsigned int> const, 18446744073709551615ul>, std::span<ttnn::operations::unary::BasicUnaryWithParam<float, int, unsigned int> const, 18446744073709551615ul>, std::optional<bool> const&)
 --- /home/ubuntu/tt-metal/python_env/lib/python3.10/site-packages/ttml/_ttml.cpython-310-x86_64-linux-gnu.so(_ZNK4ttnn10decorators22registered_operation_tIXtlN7reflect6v1_2_512fixed_stringIcLm14EEEtlA15_cLc116ELc116ELc110ELc110ELc58ELc58ELc109ELc117ELc108ELc116ELc105ELc112ELc108ELc121EEEENS_10operations6binary15BinaryOperationILNS8_12BinaryOpTypeE2EEEE16invoke_compositeIJRN2tt8tt_metal6TensorERKSG_EEEDaDpOT_+0x8d) [0x7efd96ab02fd]
 --- /home/ubuntu/tt-metal/python_env/lib/python3.10/site-packages/ttml/_ttml.cpython-310-x86_64-linux-gnu.so(_ZNK4ttnn10decorators22registered_operation_tIXtlN7reflect6v1_2_512fixed_stringIcLm14EEEtlA15_cLc116ELc116ELc110ELc110ELc58ELc58ELc109ELc117ELc108ELc116ELc105ELc112ELc108ELc121EEEENS_10operations6binary15BinaryOperationILNS8_12BinaryOpTypeE2EEEE13traced_invokeIJRN2tt8tt_metal6TensorERKSG_EEEDaDpOT_+0x101) [0x7efd96ab0081]
 --- /home/ubuntu/tt-metal/python_env/lib/python3.10/site-packages/ttml/_ttml.cpython-310-x86_64-linux-gnu.so(+0x1c51e8) [0x7efd96aad1e8]
 --- ttml::autograd::Tensor::backward(bool)
 --- ttml::models::common::transformer::memory_efficient_runner<ttml::modules::ModuleBase&>(ttml::modules::ModuleBase&, std::shared_ptr<ttml::autograd::Tensor> const&, std::shared_ptr<ttml::autograd::Tensor> const&)::{lambda()#3}::operator()() const
 --- ttml::autograd::Tensor::backward(bool)
 --- /home/ubuntu/tt-metal/python_env/lib/python3.10/site-packages/ttml/_ttml.cpython-310-x86_64-linux-gnu.so(+0xff3c9) [0x7efd969e73c9]
 --- /home/ubuntu/tt-metal/python_env/lib/python3.10/site-packages/ttml/_ttml.cpython-310-x86_64-linux-gnu.so(+0x149f10) [0x7efd96a31f10]
 --- /home/ubuntu/tt-metal/python_env/bin/python3(_PyEval_EvalFrameDefault+0x807) [0x5585a90a75a7]
 --- /home/ubuntu/tt-metal/python_env/bin/python3(+0x259be6) [0x5585a918bbe6]
 --- /home/ubuntu/tt-metal/python_env/bin/python3(PyEval_EvalCode+0x86) [0x5585a918bab6]
 --- /home/ubuntu/tt-metal/python_env/bin/python3(+0x25f1cd) [0x5585a91911cd]
 --- /home/ubuntu/tt-metal/python_env/bin/python3(+0x18b309) [0x5585a90bd309]
 --- /home/ubuntu/tt-metal/python_env/bin/python3(_PyEval_EvalFrameDefault+0x6c0) [0x5585a90a7460]
 --- /home/ubuntu/tt-metal/python_env/bin/python3(+0x1a7210) [0x5585a90d9210]
 --- /home/ubuntu/tt-metal/python_env/bin/python3(_PyEval_EvalFrameDefault+0x2798) [0x5585a90a9538]
 --- /home/ubuntu/tt-metal/python_env/bin/python3(+0x1a7210) [0x5585a90d9210]
 --- /home/ubuntu/tt-metal/python_env/bin/python3(_PyEval_EvalFrameDefault+0x2798) [0x5585a90a9538]
 --- /home/ubuntu/tt-metal/python_env/bin/python3(+0x1a7210) [0x5585a90d9210]
 --- /home/ubuntu/tt-metal/python_env/bin/python3(+0x27713f) [0x5585a91a913f]
 --- /home/ubuntu/tt-metal/python_env/bin/python3(+0x195d6b) [0x5585a90c7d6b]
 --- /home/ubuntu/tt-metal/python_env/bin/python3(_PyEval_EvalFrameDefault+0x807) [0x5585a90a75a7]
 --- /home/ubuntu/tt-metal/python_env/bin/python3(_PyFunction_Vectorcall+0x7c) [0x5585a90bd0ac]
 --- /home/ubuntu/tt-metal/python_env/bin/python3(_PyEval_EvalFrameDefault+0x6c0) [0x5585a90a7460]
 --- /home/ubuntu/tt-metal/python_env/bin/python3(_PyFunction_Vectorcall+0x7c) [0x5585a90bd0ac]
 --- /home/ubuntu/tt-metal/python_env/bin/python3(_PyEval_EvalFrameDefault+0x807) [0x5585a90a75a7]
 --- /home/ubuntu/tt-metal/python_env/bin/python3(+0x1981d1) [0x5585a90ca1d1]
 --- /home/ubuntu/tt-metal/python_env/bin/python3(PyObject_Call+0x122) [0x5585a90cae72]
 --- /home/ubuntu/tt-metal/python_env/bin/python3(_PyEval_EvalFrameDefault+0x29f4) [0x5585a90a9794]
 --- /home/ubuntu/tt-metal/python_env/bin/python3(+0x1981d1) [0x5585a90ca1d1]
 --- /home/ubuntu/tt-metal/python_env/bin/python3(_PyEval_EvalFrameDefault+0x1999) [0x5585a90a8739]
 --- /home/ubuntu/tt-metal/python_env/bin/python3(+0x1a7210) [0x5585a90d9210]
 --- /home/ubuntu/tt-metal/python_env/bin/python3(_PyEval_EvalFrameDefault+0x2798) [0x5585a90a9538]
 --- /home/ubuntu/tt-metal/python_env/bin/python3(+0x1a7210) [0x5585a90d9210]
 --- /home/ubuntu/tt-metal/python_env/bin/python3(_PyEval_EvalFrameDefault+0x2798) [0x5585a90a9538]
 --- /home/ubuntu/tt-metal/python_env/bin/python3(+0x1a7210) [0x5585a90d9210]
 --- /home/ubuntu/tt-metal/python_env/bin/python3(_PyEval_EvalFrameDefault+0x2798) [0x5585a90a9538]
 --- /home/ubuntu/tt-metal/python_env/bin/python3(+0x1a7210) [0x5585a90d9210]
 --- /home/ubuntu/tt-metal/python_env/bin/python3(_PyEval_EvalFrameDefault+0x2798) [0x5585a90a9538]
 --- /home/ubuntu/tt-metal/python_env/bin/python3(+0x1a7210) [0x5585a90d9210]
 --- /home/ubuntu/tt-metal/python_env/bin/python3(_PyEval_EvalFrameDefault+0x2798) [0x5585a90a9538]
 --- /home/ubuntu/tt-metal/python_env/bin/python3(+0x1a7210) [0x5585a90d9210]
 --- /usr/lib/python3.10/lib-dynload/_asyncio.cpython-310-x86_64-linux-gnu.so(+0x928e) [0x7efdd651e28e]
 --- /usr/lib/python3.10/lib-dynload/_asyncio.cpython-310-x86_64-linux-gnu.so(+0x90a4) [0x7efdd651e0a4]
 --- /home/ubuntu/tt-metal/python_env/bin/python3(_PyObject_MakeTpCall+0x25b) [0x5585a90b312b]
 --- /home/ubuntu/tt-metal/python_env/bin/python3(+0x2c7f0a) [0x5585a91f9f0a]
 --- /home/ubuntu/tt-metal/python_env/bin/python3(+0x17e21f) [0x5585a90b021f]
 --- /home/ubuntu/tt-metal/python_env/bin/python3(_PyEval_EvalFrameDefault+0x69cc) [0x5585a90ad76c]
 --- /home/ubuntu/tt-metal/python_env/bin/python3(_PyFunction_Vectorcall+0x7c) [0x5585a90bd0ac]
 --- /home/ubuntu/tt-metal/python_env/bin/python3(_PyEval_EvalFrameDefault+0x807) [0x5585a90a75a7]
 --- /home/ubuntu/tt-metal/python_env/bin/python3(_PyFunction_Vectorcall+0x7c) [0x5585a90bd0ac]
 --- /home/ubuntu/tt-metal/python_env/bin/python3(_PyEval_EvalFrameDefault+0x807) [0x5585a90a75a7]
 --- /home/ubuntu/tt-metal/python_env/bin/python3(_PyFunction_Vectorcall+0x7c) [0x5585a90bd0ac]
 --- /home/ubuntu/tt-metal/python_env/bin/python3(_PyEval_EvalFrameDefault+0x807) [0x5585a90a75a7]
 --- /home/ubuntu/tt-metal/python_env/bin/python3(_PyFunction_Vectorcall+0x7c) [0x5585a90bd0ac]
 --- /home/ubuntu/tt-metal/python_env/bin/python3(_PyEval_EvalFrameDefault+0x807) [0x5585a90a75a7]
 --- /home/ubuntu/tt-metal/python_env/bin/python3(_PyFunction_Vectorcall+0x7c) [0x5585a90bd0ac]
 --- /home/ubuntu/tt-metal/python_env/bin/python3(_PyEval_EvalFrameDefault+0x807) [0x5585a90a75a7]
 --- /home/ubuntu/tt-metal/python_env/bin/python3(+0x1981d1) [0x5585a90ca1d1]
 --- /home/ubuntu/tt-metal/python_env/bin/python3(_PyEval_EvalFrameDefault+0x56f3) [0x5585a90ac493]
 --- /home/ubuntu/tt-metal/python_env/bin/python3(+0x259be6) [0x5585a918bbe6]
 --- /home/ubuntu/tt-metal/python_env/bin/python3(PyEval_EvalCode+0x86) [0x5585a918bab6]
 --- /home/ubuntu/tt-metal/python_env/bin/python3(+0x25f1cd) [0x5585a91911cd]
 --- /home/ubuntu/tt-metal/python_env/bin/python3(+0x18b309) [0x5585a90bd309]
 --- /home/ubuntu/tt-metal/python_env/bin/python3(_PyEval_EvalFrameDefault+0x6c0) [0x5585a90a7460]
 --- /home/ubuntu/tt-metal/python_env/bin/python3(_PyFunction_Vectorcall+0x7c) [0x5585a90bd0ac]
 --- /home/ubuntu/tt-metal/python_env/bin/python3(_PyEval_EvalFrameDefault+0x6c0) [0x5585a90a7460]
 --- /home/ubuntu/tt-metal/python_env/bin/python3(_PyFunction_Vectorcall+0x7c) [0x5585a90bd0ac]
 --- /home/ubuntu/tt-metal/python_env/bin/python3(+0x274b3d) [0x5585a91a6b3d]
 --- /home/ubuntu/tt-metal/python_env/bin/python3(Py_RunMain+0x128) [0x5585a91a58f8]
 --- /home/ubuntu/tt-metal/python_env/bin/python3(Py_BytesMain+0x2d) [0x5585a917fa8d]
 --- /lib/x86_64-linux-gnu/libc.so.6(+0x29d90) [0x7efdd7a27d90]
 --- /lib/x86_64-linux-gnu/libc.so.6(__libc_start_main+0x80) [0x7efdd7a27e40]
 --- /home/ubuntu/tt-metal/python_env/bin/python3(_start+0x25) [0x5585a917f985]


In [None]:
# Reset the Tenstorrent device(s) with tt-smi; presumably needed after the OOM
# above left the device in a bad state — confirm.
!tt-smi -r

In [None]:
# Log captured from an earlier run (one that loaded 130 parameter tensors), kept
# for reference. Commented out because the raw log is not valid Python and made
# this cell raise a SyntaxError on Run All.
# All 130 parameters were successfully loaded and used.
# 2025-11-27 09:04:42.192 | info     |            Test | Small tensor algorithm selected (softmax_backward_w_small.cpp:18)
# 2025-11-27 09:04:53.830 | critical |          Always | Out of Memory: Not enough space to allocate 525336576 B DRAM buffer across 12 banks, where each bank needs to store 43778048 B, but bank size is only 1071181792 B (assert.hpp:103)

In [None]:
# Size of the HF embedding table in MiB, at 2 bytes per element (bf16).
torch_model.model.embed_tokens.weight.data.numel() * 2 / 1024 / 1024

In [None]:
# Size of the failed DRAM allocation from the earlier OOM log, in MiB.
525336576 / 1024 / 1024

In [None]:
# Per-bank share of that failed allocation, in MiB.
43778048 / 1024 / 1024

In [None]:
# Total allocation / per-bank share (the OOM log reports 12 banks).
525336576 / 43778048

In [None]:
# DRAM bank size reported in the OOM log, in MiB.
1071181792 / 1024 / 1024

In [None]:
# Number of optimizer steps completed before the training loop stopped.
total_steps

In [None]:
# Shape of the last training input batch fetched by the training loop.
X.shape()

In [None]:
# Validation batch size per device (multiplied by num_devices for the loader).
training_config.validation_batch_size

In [None]:
from matplotlib import pyplot as plt

In [None]:
print("Training completed!")

# Plot training curves
print("Plotting training curves...")
fig, ax = plt.subplots(1, 1, figsize=(10, 5))
ax.plot(train_losses, color="blue", label="Train Loss")

# Validation ran at zero-based step s when s % eval_every == 0 or s is the last
# step (the training loop's rule). Rebuild those x positions instead of assuming
# uniform eval_every spacing, which misplaced the extra final-step validation.
val_steps = [
    step
    for step in range(len(train_losses))
    if step % training_config.eval_every == 0 or step + 1 == training_config.steps
]
n_val = min(len(val_steps), len(val_losses))
ax.plot(
    val_steps[:n_val],
    val_losses[:n_val],
    color="orange",
    label="Val Loss",
)
ax.set_title("Training Loss")
ax.set_xlabel("Steps")
ax.set_ylabel("Loss")
ax.legend()
plt.savefig("training_curves.png")
plt.show()