# Pretraining GPT2

In [3]:
GPT_CONFIG_124M = {
    "vocab_size": 50257,   # Vocabulary size
    "context_length": 256, # Shortened context length (orig: 1024)
    "emb_dim": 768,        # Embedding dimension
    "n_heads": 12,         # Number of attention heads
    "n_layers": 12,        # Number of layers
    "drop_rate": 0.0,      # Dropout rate
    "qkv_bias": False      # Query-key-value bias
}

## Torch Implementation

Import the libraries, and now we have separated the `GPTModel` implementation to a new file to make the notebook smaller, of course.

In [4]:
import torch
import tiktoken
from scripts.gpt2_model import GPTModel

torch.manual_seed(123)

<torch._C.Generator at 0x77375cfdb610>

We will just instantiate the model. Since our config already has 0 drop out, we don't need to turn on `eval` mode for the model.

In [5]:
model = GPTModel(GPT_CONFIG_124M)

Define some helpers. 

In [6]:
from scripts.generate import generate_text_simple

def text_to_token_ids(text, tokenizer):
    encoded = tokenizer.encode(text, allowed_special={'<|endoftext|>'})
    encoded_tensor = torch.tensor(encoded).unsqueeze(0) # add batch dimension
    return encoded_tensor

def token_ids_to_text(token_ids, tokenizer):
    flat = token_ids.squeeze(0) # remove batch dimension
    return tokenizer.decode(flat.tolist())


Let's test the model at the moment. Again we expect it to generate some garbage.

In [7]:
start_context = "Every effort moves you"
tokenizer = tiktoken.get_encoding("gpt2")

token_ids = generate_text_simple(
  model=model,
  idx=text_to_token_ids(start_context, tokenizer),
  max_new_tokens=10,
  context_size=GPT_CONFIG_124M["context_length"]
)

print("Output text:\n", token_ids_to_text(token_ids, tokenizer))

Output text:
 Every effort moves you rentingetic wasnم refres RexMeCHicular stren


## Loss Calculation

we have 2 batches here, so we add the corresponding "next" token into the targets while maintaining a consistent context length. this gives a "sliding window" effect.

In [10]:
inputs = torch.tensor([[16833, 3626, 6100],   # ["every effort moves",
                       [40,    1107, 588]])   #  "I really like"]

targets = torch.tensor([[3626, 6100, 345  ],  # [" effort moves you",
                        [1107,  588, 11311]]) #  " really like chocolate"]

inputs, targets

(tensor([[16833,  3626,  6100],
         [   40,  1107,   588]]),
 tensor([[ 3626,  6100,   345],
         [ 1107,   588, 11311]]))

In [11]:
torch.manual_seed(123)

logits = model(inputs)

probas = torch.softmax(logits, dim=-1) # Probability of each token in vocabulary
print(probas.shape) # Shape: (batch_size, num_tokens, vocab_size)


torch.Size([2, 3, 50257])


Probas contains for each token, there are vocab_size probabilities in the last dimension. We want to select the token ID (the index) of the highest number

In [8]:

probas

tensor([[[1.8849e-05, 1.5172e-05, 1.1687e-05,  ..., 2.2409e-05,
          6.9776e-06, 1.8776e-05],
         [9.1569e-06, 1.0062e-05, 7.8786e-06,  ..., 2.9090e-05,
          6.0103e-06, 1.3571e-05],
         [2.9877e-05, 8.8507e-06, 1.5741e-05,  ..., 3.5456e-05,
          1.4094e-05, 1.3526e-05]],

        [[1.2561e-05, 2.0538e-05, 1.4332e-05,  ..., 1.0389e-05,
          3.4784e-05, 1.4239e-05],
         [7.2731e-06, 1.7864e-05, 1.0565e-05,  ..., 2.1206e-05,
          1.1390e-05, 1.5559e-05],
         [2.9496e-05, 3.3605e-05, 4.1029e-05,  ..., 6.5249e-06,
          5.8203e-05, 1.3698e-05]]])

For each word, we find an individual token to succeed the text

In [9]:
token_ids = torch.argmax(probas, dim=-1, keepdim=True)
print("Token IDs:\n", token_ids)

Token IDs:
 tensor([[[16657],
         [  339],
         [42826]],

        [[49906],
         [29669],
         [41751]]])


But too bad! we're nowhere close to our target!

In [10]:
print(f"Targets batch 1: {token_ids_to_text(targets[0], tokenizer)}")
print(f"Outputs batch 1: {token_ids_to_text(token_ids[0].flatten(), tokenizer)}")

print(f"Targets batch 2: {token_ids_to_text(targets[1], tokenizer)}")
print(f"Outputs batch 2: {token_ids_to_text(token_ids[1].flatten(), tokenizer)}")

Targets batch 1:  effort moves you
Outputs batch 1:  Armed heNetflix
Targets batch 2:  really like chocolate
Outputs batch 2:  pressuring empoweredfaith


So what did this probabilities look like for the _real_ indices for targets? You'll find them to be small. The goal now is to increase these probabilities.

In [11]:
text_idx = 0
target_probas_1 = probas[text_idx, [0, 1, 2], targets[text_idx]]
print("Text 1:", target_probas_1)

text_idx = 1
target_probas_2 = probas[text_idx, [0, 1, 2], targets[text_idx]]
print("Text 2:", target_probas_2)

Text 1: tensor([7.4541e-05, 3.1061e-05, 1.1563e-05])
Text 2: tensor([1.0337e-05, 5.6776e-05, 4.7559e-06])


Take the log for individual probabilities and then concatenate them.

In [12]:
# Compute logarithm of all token probabilities
log_probas = torch.log(torch.cat((target_probas_1, target_probas_2)))
print(log_probas)

tensor([ -9.5042, -10.3796, -11.3677, -11.4798,  -9.7764, -12.2561])


Take the average and negate to make it positive.

In [13]:
# Calculate the average probability for each token
avg_log_probas = torch.mean(log_probas)
print(avg_log_probas)

tensor(-10.7940)


In [14]:
neg_avg_log_probas = avg_log_probas * -1
print(neg_avg_log_probas)


tensor(10.7940)


The above is basically what we call "cross entropy loss"

## Cross Entropy

Lets just use torch cross_entropy to do everything we did above.

In [15]:
# Logits have shape (batch_size, num_tokens, vocab_size)
print("Logits shape:", logits.shape)

# Targets have shape (batch_size, num_tokens)
print("Targets shape:", targets.shape)

Logits shape: torch.Size([2, 3, 50257])
Targets shape: torch.Size([2, 3])


In [16]:
logits_flat = logits.flatten(0, 1)
targets_flat = targets.flatten()

print("Flattened logits:", logits_flat.shape)
print("Flattened targets:", targets_flat.shape)

Flattened logits: torch.Size([6, 50257])
Flattened targets: torch.Size([6])


In [17]:
loss = torch.nn.functional.cross_entropy(logits_flat, targets_flat)
print(loss)


tensor(10.7940)


We can calculate the perplexity too.

In [18]:
perplexity = torch.exp(loss)
print(perplexity)

tensor(48725.8203)


## Data Stuff

In [36]:
with open("data/the-verdict.txt", "r") as f:
  text_data = f.read()

text_data[:99]

'I HAD always thought Jack Gisburn rather a cheap genius--though a good fellow enough--so it was no '

In [37]:
total_characters = len(text_data)
total_tokens = len(tokenizer.encode(text_data))

print("Characters:", total_characters)
print("Tokens:", total_tokens)

Characters: 20479
Tokens: 5145


In [38]:
from scripts.prepare_data import create_dataloader_v1

# Train/validation ratio
train_ratio = 0.90
split_idx = int(train_ratio * len(text_data))
train_data = text_data[:split_idx]
val_data = text_data[split_idx:]

torch.manual_seed(123)

train_loader = create_dataloader_v1(
    train_data,
    batch_size=2,
    max_length=GPT_CONFIG_124M["context_length"],
    stride=GPT_CONFIG_124M["context_length"],
    drop_last=True,
    shuffle=True,
    num_workers=0
)

val_loader = create_dataloader_v1(
    val_data,
    batch_size=2,
    max_length=GPT_CONFIG_124M["context_length"],
    stride=GPT_CONFIG_124M["context_length"],
    drop_last=False,
    shuffle=False,
    num_workers=0
)

In [39]:
# Sanity check

if total_tokens * (train_ratio) < GPT_CONFIG_124M["context_length"]:
    print("Not enough tokens for the training loader. "
          "Try to lower the `GPT_CONFIG_124M['context_length']` or "
          "increase the `training_ratio`")

if total_tokens * (1-train_ratio) < GPT_CONFIG_124M["context_length"]:
    print("Not enough tokens for the validation loader. "
          "Try to lower the `GPT_CONFIG_124M['context_length']` or "
          "decrease the `training_ratio`")

In [40]:
print("Train loader:")
for x, y in train_loader:
    print(x.shape, y.shape)

print("\nValidation loader:")
for x, y in val_loader:
    print(x.shape, y.shape)
    

Train loader:
torch.Size([2, 256]) torch.Size([2, 256])
torch.Size([2, 256]) torch.Size([2, 256])
torch.Size([2, 256]) torch.Size([2, 256])
torch.Size([2, 256]) torch.Size([2, 256])
torch.Size([2, 256]) torch.Size([2, 256])
torch.Size([2, 256]) torch.Size([2, 256])
torch.Size([2, 256]) torch.Size([2, 256])
torch.Size([2, 256]) torch.Size([2, 256])
torch.Size([2, 256]) torch.Size([2, 256])

Validation loader:
torch.Size([2, 256]) torch.Size([2, 256])


In [41]:
train_tokens = 0
for input_batch, target_batch in train_loader:
    train_tokens += input_batch.numel()

val_tokens = 0
for input_batch, target_batch in val_loader:
    val_tokens += input_batch.numel()

print("Training tokens:", train_tokens)
print("Validation tokens:", val_tokens)
print("All tokens:", train_tokens + val_tokens)

Training tokens: 4608
Validation tokens: 512
All tokens: 5120


Given the input and targets, we forward pass through the model for the input batch. Then given the logits, we flatten on the dimensions and calculate the cross entropy loss.

In [25]:
def calc_loss_batch(input_batch, target_batch, model):
    input_batch, target_batch = input_batch, target_batch
    logits = model(input_batch)
    loss = torch.nn.functional.cross_entropy(logits.flatten(0, 1), target_batch.flatten())

    return loss


In [26]:

def calc_loss_loader(data_loader, model, num_batches=None):
    total_loss = 0.
    if len(data_loader) == 0:
        return float("nan")
    elif num_batches is None:
        num_batches = len(data_loader)
    else:
        # Reduce the number of batches to match the total number of batches in the data loader
        # if num_batches exceeds the number of batches in the data loader
        num_batches = min(num_batches, len(data_loader))
    for i, (input_batch, target_batch) in enumerate(data_loader):
        if i < num_batches:
            loss = calc_loss_batch(input_batch, target_batch, model)
            total_loss += loss.item()
        else:
            break
    return total_loss / num_batches

In [None]:
torch.manual_seed(123) # For reproducibility due to the shuffling in the data loader

train_loss = calc_loss_loader(train_loader, model)
val_loss = calc_loss_loader(val_loader, model)

print("Training loss:", train_loss)
print("Validation loss:", val_loss)

Training loss: 10.98758347829183
Validation loss: 10.98110580444336


## Training!!

In [28]:
def train_model_simple(model, train_loader, val_loader, optimizer, num_epochs,
                       eval_freq, eval_iter, start_context, tokenizer):
    # Initialize lists to track losses and tokens seen
    train_losses, val_losses, track_tokens_seen = [], [], []
    tokens_seen, global_step = 0, -1

    # Main training loop
    for epoch in range(num_epochs):
        
        for input_batch, target_batch in train_loader:
            optimizer.zero_grad() # Reset loss gradients from previous batch iteration
            loss = calc_loss_batch(input_batch, target_batch, model)
            loss.backward() # Calculate loss gradients
            optimizer.step() # Update model weights using loss gradients
            tokens_seen += input_batch.numel()
            global_step += 1

            # Optional evaluation step
            if global_step % eval_freq == 0:
                train_loss, val_loss = evaluate_model(
                    model, train_loader, val_loader, eval_iter)
                train_losses.append(train_loss)
                val_losses.append(val_loss)
                track_tokens_seen.append(tokens_seen)
                print(f"Ep {epoch+1} (Step {global_step:06d}): "
                      f"Train loss {train_loss:.3f}, Val loss {val_loss:.3f}")

        # Print a sample text after each epoch
        generate_and_print_sample(
            model, tokenizer, start_context
        )

    return train_losses, val_losses, track_tokens_seen


def evaluate_model(model, train_loader, val_loader, eval_iter):
    train_loss = calc_loss_loader(train_loader, model, num_batches=eval_iter)
    val_loss = calc_loss_loader(val_loader, model, num_batches=eval_iter)
    return train_loss, val_loss


def generate_and_print_sample(model, tokenizer, start_context):
    context_size = model.pos_emb.weight.shape[0]
    encoded = text_to_token_ids(start_context, tokenizer)
    token_ids = generate_text_simple(
        model=model, idx=encoded,
        max_new_tokens=50, context_size=context_size
    )
    decoded_text = token_ids_to_text(token_ids, tokenizer)
    print(decoded_text.replace("\n", " "))  # Compact print format

In [29]:
from scripts.perf_timer import PerfTimer

torch.manual_seed(123)
model = GPTModel(GPT_CONFIG_124M)
optimizer = torch.optim.AdamW(model.parameters(), lr=0.0004, weight_decay=0.1)

num_epochs = 10

timer = PerfTimer()

timer.start()
train_losses, val_losses, tokens_seen = train_model_simple(
    model, train_loader, val_loader, optimizer,
    num_epochs=num_epochs, eval_freq=5, eval_iter=5,
    start_context="Every effort moves you", tokenizer=tokenizer
)
timer.stop()

print(f"Took this long to train: {timer.elapsed_ms()} ms")


Ep 1 (Step 000000): Train loss 9.794, Val loss 9.909
Ep 1 (Step 000005): Train loss 8.038, Val loss 8.324
Every effort moves you,,,,,,,,,,,,,,.                                   
Ep 2 (Step 000010): Train loss 6.598, Val loss 7.041
Ep 2 (Step 000015): Train loss 5.996, Val loss 6.575
Every effort moves you, and, and, and, and, and, and, and. ", and, and, and, and, and, and, and, and, and, and, and, and, and, and, and, and,
Ep 3 (Step 000020): Train loss 5.558, Val loss 6.444
Ep 3 (Step 000025): Train loss 5.956, Val loss 7.675
Every effort moves you.                                                 
Ep 4 (Step 000030): Train loss 4.243, Val loss 6.283
Ep 4 (Step 000035): Train loss 4.304, Val loss 6.210
Every effort moves you.               "I--and's--and it's had been, and I had been the, and I had been the honour of the, and he had been, I had
Ep 5 (Step 000040): Train loss 3.443, Val loss 6.196
Every effort moves you know it was not a littleI glanced after him, and I had been his eye

In [29]:
tokenizer = tiktoken.get_encoding("gpt2")

token_ids = generate_text_simple(
    model=model,
    idx=text_to_token_ids("Every effort moves you", tokenizer),
    max_new_tokens=25,
    context_size=GPT_CONFIG_124M["context_length"]
)

print("Output text:\n", token_ids_to_text(token_ids, tokenizer))

Output text:
 Every effort moves you rentingetic wasnم refres RexMeCHicular stren Mortgage TT remember gard ACTIONSussedOND Land Engeleddedemate breaths proxies GalaxyForm


## TTNN

In [1]:
GPT_CONFIG_124M = {
    "vocab_size": 50257,   # Vocabulary size
    "context_length": 256, # Shortened context length (orig: 1024)
    "emb_dim": 768,        # Embedding dimension
    "n_heads": 12,         # Number of attention heads
    "n_layers": 12,        # Number of layers
    "drop_rate": 0.0,      # Dropout rate
    "qkv_bias": False      # Query-key-value bias
}

In [2]:
import ttnn
import tiktoken
import torch
from torch import nn
from scripts.text_helpers import text_to_token_ids, token_ids_to_text

torch.manual_seed(123)

device = None

2025-05-11 04:48:20.124 | DEBUG    | ttnn.library_tweaks:prepare_dir_as_metal_home:54 - Existing installation of 0.57.0rc60+any detected
2025-05-11 04:48:20.148 | DEBUG    | ttnn:<module>:83 - Initial ttnn.CONFIG:
Config{cache_path=/home/avgdev/.cache/ttnn,model_cache_path=/home/avgdev/.cache/ttnn/models,tmp_dir=/tmp/ttnn,enable_model_cache=false,enable_fast_runtime_mode=true,throw_exception_on_fallback=false,enable_logging=false,enable_graph_report=false,enable_detailed_buffer_report=false,enable_detailed_tensor_report=false,enable_comparison_mode=false,comparison_mode_should_raise_exception=false,comparison_mode_pcc=0.9999,root_report_path=generated/ttnn/reports,report_name=std::nullopt,std::nullopt}


## Initialize the GPT Model

In [3]:
if device:
  ttnn.close_device(device)

from scripts.gpt2_model_ttnn import GPTModel_ttnn

device_id = 0
device = ttnn.open_device(device_id=device_id)

model_ttnn = GPTModel_ttnn(GPT_CONFIG_124M, device)


                 Device | INFO     | Opening user mode device driver
[32m2025-05-11 04:48:22.049[0m | [1m[38;2;100;149;237mINFO    [0m | [36mSiliconDriver  [0m - Opened PCI device 0; KMD version: 1.33.0, IOMMU: disabled

[32m2025-05-11 04:48:22.059[0m | [1m[38;2;100;149;237mINFO    [0m | [36mSiliconDriver  [0m - Opened PCI device 0; KMD version: 1.33.0, IOMMU: disabled
[32m2025-05-11 04:48:22.061[0m | [1m[38;2;100;149;237mINFO    [0m | [36mSiliconDriver  [0m - Harvesting mask for chip 0 is 0x200 (physical layout: 0x1, logical: 0x200, simulated harvesting mask: 0x0).
[32m2025-05-11 04:48:22.061[0m | [1m[38;2;100;149;237mINFO    [0m | [36mSiliconDriver  [0m - Opened PCI device 0; KMD version: 1.33.0, IOMMU: disabled
[32m2025-05-11 04:48:22.061[0m | [1m[38;2;100;149;237mINFO    [0m | [36mSiliconDriver  [0m - Detected PCI devices: [0]
[32m2025-05-11 04:48:22.061[0m | [1m[38;2;100;149;237mINFO    [0m | [36mSiliconDriver  [0m - Using local chip ids: 

New chip! We now have 1 chips
Chip initialization complete (found )
Chip initializing complete...
 ARC

 [4/4] DRAM

 [16/16] ETH

 CPU

Chip detection complete (found )


In [4]:
from scripts.generate_ttnn import generate_text_simple_ttnn

start_context = "Every effort moves you"
tokenizer = tiktoken.get_encoding("gpt2")

inputs_ttnn = ttnn.from_torch(
  text_to_token_ids(start_context, tokenizer),
  layout=ttnn.TILE_LAYOUT,
  dtype=ttnn.uint32,
  device=device
)
token_ids = generate_text_simple_ttnn(
  model=model_ttnn,
  idx_ttnn=inputs_ttnn,
  max_new_tokens=10,
  context_size=GPT_CONFIG_124M["context_length"],
  device=device
)

print("Output text:\n", token_ids_to_text(token_ids, tokenizer))

Output text:
 Every effort moves you collagen HTTPSioliries Simone represented Compos Harrison circumstanceemp


In [5]:
from scripts.prepare_data import create_dataloader_v1

with open("data/the-verdict.txt", "r") as f:
  text_data = f.read()

# Train/validation ratio
train_ratio = 0.90
split_idx = int(train_ratio * len(text_data))
train_data = text_data[:split_idx]
val_data = text_data[split_idx:]

torch.manual_seed(123)

train_loader = create_dataloader_v1(
    train_data,
    batch_size=2,
    max_length=GPT_CONFIG_124M["context_length"],
    stride=GPT_CONFIG_124M["context_length"],
    drop_last=True,
    shuffle=True,
    num_workers=0
)

val_loader = create_dataloader_v1(
    val_data,
    batch_size=2,
    max_length=GPT_CONFIG_124M["context_length"],
    stride=GPT_CONFIG_124M["context_length"],
    drop_last=False,
    shuffle=False,
    num_workers=0
)

train_tokens = 0
for input_batch, target_batch in train_loader:
    train_tokens += input_batch.numel()

val_tokens = 0
for input_batch, target_batch in val_loader:
    val_tokens += input_batch.numel()


## Cross Entropy for TTNN

In [15]:
def softmax_ttnn(x_ttnn):
  return ttnn.div(
    ttnn.exp(x_ttnn),
    ttnn.sum(ttnn.exp(x_ttnn), dim=-1)
  )

In [32]:
def calc_loss_batch_ttnn(input_batch_ttnn, target_batch, model_ttnn, device):
    input_batch_ttnn, target_batch = input_batch_ttnn, target_batch
    logits_ttnn = model_ttnn(input_batch_ttnn)
    logits = ttnn.to_torch(logits_ttnn, device=device)

    loss = torch.nn.functional.cross_entropy(logits.flatten(0, 1), target_batch.flatten())

    return loss

In [34]:
torch.manual_seed(123)

# every effort moves
inputs = torch.tensor([[16833, 3626, 6100]]) 

# effort moves you
targets = torch.tensor([[3626, 6100, 345]]) 

inputs_ttnn = ttnn.from_torch(
  inputs,
  dtype=ttnn.bfloat16,
  layout=ttnn.TILE_LAYOUT,
  device=device
)

loss_torch = calc_loss_batch_ttnn(inputs_ttnn, targets, model_ttnn, device)
loss_torch


TorchTensor(10.3125, dtype=torch.bfloat16)

In [42]:
def calc_loss_loader_ttnn(data_loader, model_ttnn, num_batches=None, device=None):
    total_loss = 0.
    if len(data_loader) == 0:
        return float("nan")
    elif num_batches is None:
        num_batches = len(data_loader)
    else:
        # Reduce the number of batches to match the total number of batches in the data loader
        # if num_batches exceeds the number of batches in the data loader
        num_batches = min(num_batches, len(data_loader))
    for i, (input_batch, target_batch) in enumerate(data_loader):
        input_batch_ttnn = ttnn.from_torch(input_batch, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device)
        if i < num_batches:
            loss = calc_loss_batch_ttnn(input_batch_ttnn, target_batch, model_ttnn, device)
            total_loss += loss.item()
        else:
            break
    return total_loss / num_batches

In [43]:
torch.manual_seed(123) # For reproducibility due to the shuffling in the data loader

train_loss = calc_loss_loader_ttnn(train_loader, model_ttnn, device=device)
val_loss = calc_loss_loader_ttnn(val_loader, model_ttnn, device=device)

print("Training loss:", train_loss)
print("Validation loss:", val_loss)

Training loss: 10.979166666666666
Validation loss: 10.9375


In [None]:
def evaluate_model_ttnn(model_ttnn, train_loader, val_loader, eval_iter, device):
    train_loss = calc_loss_loader_ttnn(train_loader, model_ttnn, num_batches=eval_iter, device=device)
    val_loss = calc_loss_loader_ttnn(val_loader, model_ttnn, num_batches=eval_iter, device=device)
    return train_loss, val_loss

def generate_and_print_sample_ttnn(model_ttnn, tokenizer, start_context, device):
    context_size = model_ttnn.pos_emb.weight.shape[0]
    encoded = text_to_token_ids(start_context, tokenizer)
    token_ids = generate_text_simple_ttnn(
      model=model_ttnn,
      idx_ttnn=encoded,
      max_new_tokens=50,
      context_size=context_size,
      device=device
    )
    decoded_text = token_ids_to_text(token_ids, tokenizer)
    print(decoded_text.replace("\n", " "))  # Compact print format

The tricky part...

In [45]:
def train_model_simple(model_ttnn, train_loader, val_loader, optimizer, num_epochs,
                       eval_freq, eval_iter, start_context, tokenizer, device):
  # Initialize lists to track losses and tokens seen
  train_losses, val_losses = [], []
  global_step = -1

  # Main training loop
  for epoch in range(num_epochs):
      
    for input_batch, target_batch in train_loader:
      # THIS IS IN TORCH!
      optimizer.zero_grad() # Reset loss gradients from previous batch iteration

      # Transfer the input batch to be in ttnn to perform the forward pass.
      # Loss will be calculated on the CPU afterwards.
      input_batch_ttnn = ttnn.from_torch(input_batch, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device)
      loss = calc_loss_batch_ttnn(input_batch_ttnn, target_batch, model_ttnn, device=device)

      # 🚧 NOTE: All the weights, layers etc need to be updated
      loss.backward() # Calculate loss gradients
      optimizer.step() # Update model weights using loss gradients

      # 🚧 AT THIS POINT WE NEED TO UPDATE ALL THE INTERNAL TTNN WEIGHTS
      # Yeah, i know it sounds like double work.
      # Ex: model_ttnn.updateStuff()
      
      global_step += 1

      # Optional evaluation step
      if global_step % eval_freq == 0:
        train_loss, val_loss = evaluate_model(
            model, train_loader, val_loader, eval_iter)
        train_losses.append(train_loss)
        val_losses.append(val_loss)
        print(f"Ep {epoch+1} (Step {global_step:06d}): "
              f"Train loss {train_loss:.3f}, Val loss {val_loss:.3f}")

    # ✅ Print a sample text after each epoch
    generate_and_print_sample_ttnn(model_ttnn, tokenizer, start_context, device)

    return train_losses, val_losses

In [26]:
ttnn.close_device(device)

                  Metal | INFO     | Closing device 0
                  Metal | INFO     | Disabling and clearing program cache on device 0
