In [1]:
import os
import time
import random
import copy
import sys

import numpy as np
import torch

sys.path.append("/jet/home/azhang19/stat 260/STAT-260-Final-Project-Randomized-Linear-Algebra-Transformer")
from RLALLaMA3.LLaMA3 import ModelArgs
from RLALLaMA3.utils import (
    linear_warmup_cosine_decay_multiplicative,
    name_args,
    transformer_forward_pass,
    Args,
)
from RLALLaMA3.tasks import single_answer_seq_loss, get_dataset

def set_seed(seed):
    """
    Seed every RNG the notebook touches and make cuDNN deterministic.

    Seeds Python's `random`, NumPy's global generator and PyTorch (plus the
    current and all CUDA devices when CUDA is available), then disables cuDNN
    autotuning so kernels are chosen deterministically.

    Args:
        seed: integer seed applied to all generators.
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

    # Reproducibility over speed: no autotuned / nondeterministic cuDNN algorithms.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

torch.set_float32_matmul_precision("high")
device = "cuda" if torch.cuda.is_available() else "cpu"

# Seed every RNG (Python, NumPy, torch/CUDA) up front. Data generation, weight
# initialisation and the random sketches then repeat across Restart & Run All.
# set_seed was previously defined but never called.
SEED = 0
set_seed(SEED)

#import torch._dynamo
#torch._dynamo.config.suppress_errors = True

time_stamp = time.strftime("%Y-%m-%d %H-%M-%S", time.localtime())
print(f"Time stamp: {time_stamp}")

Time stamp: 2025-04-01 18-31-36


In [2]:
# Define the arguments (hard-coded instead of parse_args() so the notebook is self-contained).
# The standard_* values refer to a reference batch size of 512; the training-config
# cell rescales them to `batch_size`.
training_kwargs = dict(
    standard_lr=3.16e-3,
    standard_epoch=80000,
    standard_warmup_steps=4000,
    batch_size=768,
    min_lr=1e-4,
    grad_clip_max_norm=1.0,
    use_amp=True,
    use_compile=False,
)
data_kwargs = dict(
    task="number_add",
    max_level=20,
    random_seq_len=True,
    number_range=(0, 99),
)
model_kwargs = dict(
    dim=32,
    n_layers=2,
    n_heads=4,
    hidden_dim=112,
)
# NOTE(review): absolute, user-specific paths; only used by the (currently disabled) save calls.
save_kwargs = dict(
    save_path="/accounts/grad/zhangyunzhe2023/Neural-ODE/ckpt",
    final_save_path="/accounts/grad/zhangyunzhe2023/Neural-ODE/ckpt_final",
)

args = Args(**training_kwargs, **data_kwargs, **model_kwargs, **save_kwargs)

print(args, end="\n\n")

Args Configuration:

Training Parameters:
  standard_lr:        3.2e-03
  standard_epoch:     80000
  standard_warmup_steps: 4000
  batch_size:         768
  min_lr:             1.0e-04
  grad_clip_max_norm: 1.0
  use_amp:            True
  use_compile:       False

Data Parameters:
  task:              number_add
  max_level:         20
  random_seq_len:    True
  number_range:      (0, 99)

Model Architecture Parameters:
  dim:               32
  n_layers:          2
  n_heads:           4
  hidden_dim:        112

Save Path Parameters:
  save_path:         /accounts/grad/zhangyunzhe2023/Neural-ODE/ckpt
  final_save_path:   /accounts/grad/zhangyunzhe2023/Neural-ODE/ckpt_final



In [3]:
# Prepare the data: build the task dataset, its collate function, the vocabulary
# size and the max sequence length (batches padded to the longest sequence).
dataset, collate_fn, vocab_size, max_seq_len = get_dataset(args.task,
                                                           args.max_level,
                                                           args.random_seq_len,
                                                           args.number_range,
                                                           nested_tensor=False,
                                                           pad_to_longest=True)
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=args.batch_size,
    collate_fn=collate_fn,
    num_workers=torch.get_num_threads(),
    # Page-locked host memory only helps host->GPU copies. Without CUDA,
    # PyTorch just warns and ignores it, so only request it when it is useful.
    pin_memory=torch.cuda.is_available(),
)

In [4]:
def mean_seq_len(dataloader, num_samples=100):
    """
    Estimate the mean sequence length from the first few batches.

    Each batch's element ``[1]`` is expected to be a tensor of per-sample
    sequence lengths (matching the ``tokens, lengths, ...`` batch layout
    used in training). The per-batch mean lengths are averaged over the
    batches actually consumed.

    Args:
        dataloader: any iterable of batches (a DataLoader or a plain list);
            it does not need a sized ``.dataset``.
        num_samples: maximum number of *batches* to read (must be >= 1).

    Returns:
        float: mean of the per-batch mean lengths.

    Raises:
        ValueError: if ``num_samples < 1`` or the dataloader yields no batches.
    """
    if num_samples < 1:
        raise ValueError(f"num_samples must be >= 1, got {num_samples}")

    total_len = 0.0
    num_batches = 0
    for batch in dataloader:
        total_len += batch[1].float().mean().item()
        num_batches += 1
        # Stop before requesting another batch from the loader.
        if num_batches >= num_samples:
            break

    if num_batches == 0:
        raise ValueError("mean_seq_len: dataloader yielded no batches")
    # Divide by the batches actually seen. The loader may be shorter than num_samples.
    return total_len / num_batches

# Estimated mean length (used to scale the backward loss) and the padded maximum.
mean_len = mean_seq_len(dataloader)
print(f"Mean sequence length: {mean_len}\nMax sequence length: {max_seq_len}")

Mean sequence length: 34.37289093017578
Max sequence length: 66


In [5]:
# Prepare the model
from RLALLaMA3.LLaMA3 import Transformer

# Architecture sizes come from `args`; vocab_size / max_seq_len come from the dataset.
# The sketch_* options presumably configure the randomized-linear-algebra (sketched)
# projections in attention / feed-forward layers. TODO: confirm in RLALLaMA3.LLaMA3.
transformer_args = ModelArgs(
    dim=args.dim,
    n_layers=args.n_layers,
    n_heads=args.n_heads,
    hidden_dim=args.hidden_dim,
    vocab_size=vocab_size,
    norm_eps=1e-5,
    rope_theta=500000,
    max_seq_len=max_seq_len,
    sketch_mode='rademacher',
    attention_qkv_sketch_size=16,
    attention_out_sketch_size=16,
    feedforward_sketch_size_in=16,
    feedforward_sketch_size_out=64,
    deterministic=False,
)

model = Transformer(params=transformer_args).to(device).train()

In [6]:
# Training configuration.
# The standard_* hyper-parameters are specified for a reference batch size of 512.
# They are converted to per-sample quantities and then rescaled to the actual batch
# size: the learning rate grows linearly with the batch, while warmup and total step
# counts shrink so the number of samples seen stays fixed.
standard_lr = args.standard_lr / 512
standard_epoch = args.standard_epoch * 512
standard_warmup_steps = args.standard_warmup_steps * 512
batch_size = args.batch_size

lr = standard_lr * batch_size
warmup_steps = standard_warmup_steps // batch_size
epochs = standard_epoch // batch_size

print("Derived Parameters:")
print(f"lr: {lr}")
print(f"warmup_steps: {warmup_steps}")
print(f"epochs: {epochs}")
print(f"grad_clip_max_norm: {args.grad_clip_max_norm}", end="\n\n")

# The fused AdamW kernel is always available on CUDA. Requesting it on the CPU
# fallback fails on PyTorch versions without fused CPU support, so only fuse on CUDA.
optimizer = torch.optim.AdamW(model.parameters(), lr=lr, fused=(device == "cuda"))
# MultiplicativeLR multiplies the current LR by the lambda's value at every
# scheduler.step(), giving linear warmup followed by cosine decay towards min_lr.
scheduler = torch.optim.lr_scheduler.MultiplicativeLR(optimizer,
            lr_lambda=lambda step: linear_warmup_cosine_decay_multiplicative(step, warmup_steps, epochs, args.min_lr))

scaler = torch.amp.GradScaler(device, enabled=args.use_amp)

Derived Parameters:
lr: 0.00474
warmup_steps: 2666
epochs: 53333
grad_clip_max_norm: 1.0



In [7]:
# Save the model and arguments

def save_record(path, model, record, args, time_stamp, extra_info=None):
    """
    Save the model object, the metric record and run metadata with torch.save.

    The output path is ``{path}/{dir_name}/{file_stem}_{time_stamp}[_{extra_info}].pth``,
    where ``dir_name`` and ``file_stem`` come from ``name_args(args, "_")``. The
    directory is created if it does not exist.

    NOTE: the module object itself (not a state_dict) is stored. On recent PyTorch,
    loading it therefore needs ``torch.load(..., weights_only=False)``.
    """
    dir_name, file_stem = name_args(args, "_")
    target_dir = f"{path}/{dir_name}"
    os.makedirs(target_dir, exist_ok=True)

    suffix = f"_{time_stamp}" if extra_info is None else f"_{time_stamp}_{extra_info}"

    payload = dict(model=model, record=record, args=args, time_stamp=time_stamp)
    torch.save(payload, f"{target_dir}/{file_stem}{suffix}.pth")

In [8]:
# Backwards pass
def backward_pass(model, loss, optimizer, scaler, scheduler, grad_clip_max_norm):
    optimizer.zero_grad(set_to_none=True)
    scaler.scale(loss).backward()
    scaler.unscale_(optimizer)
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=grad_clip_max_norm)
    scaler.step(optimizer)
    scaler.update()
    scheduler.step()

In [9]:
# Compilation is toggled by args.use_compile, read once when this cell is executed.
@torch.compile(disable=not args.use_compile)
def train_step(model, train_data, mean_len, optimizer, scheduler, scaler, args):
    """
    Run one forward/backward/update step on a batch, unless the loss is non-finite.

    Args:
        model: the Transformer being trained.
        train_data: sequence ``(tokens, lengths, ans_starts, ans_lengths)`` of tensors
            already on the target device.
        mean_len: estimated mean sequence length; the backward loss is divided by it.
        optimizer, scheduler, scaler: passed through to ``backward_pass``.
        args: run configuration (uses ``use_amp``, ``batch_size``, ``grad_clip_max_norm``).

    Returns:
        - ``[total_loss]`` (a one-element list) if the loss is NaN or inf. No
          parameter, optimizer or scheduler update happens in that case.
        - Otherwise ``(data, safe_params)``:
            - ``data`` is an 8-element NumPy array
              ``[GPT_loss, 0, total_loss, full_seq_acc, ans_region_acc, ans_char_acc, 0, 0]``.
            - ``safe_params`` is a list of deep-copied state_dicts of
              ``[model, optimizer, scheduler]`` taken *before* this step's update.
    """
    # Use the batch's own device for autocast (shadows the notebook-level `device`).
    device = train_data[0].device
    
    # bfloat16 autocast for the forward pass and loss when args.use_amp is set.
    with torch.autocast(device_type=device.type, dtype=torch.bfloat16, enabled=args.use_amp):
        tokens, lengths, ans_starts, ans_lengths = train_data
        # All but the last token are fed in; presumably pred[:, t] scores tokens[:, t + 1]
        # (TODO: confirm against transformer_forward_pass / single_answer_seq_loss).
        pred = transformer_forward_pass(tokens[:, :-1], model)
        
        result = single_answer_seq_loss(pred, tokens, lengths, ans_starts, ans_lengths)
        GPT_loss, full_seq_acc, ans_region_acc, ans_char_acc = result
        # Normalize the GPT loss by the batch size but not the sequence length
        # (uses the configured args.batch_size, not the size of this particular batch).
        GPT_loss = GPT_loss / args.batch_size
        total_loss = GPT_loss
        # Only the backward loss is further divided by mean_len; the reported loss is not.
        total_loss_for_backward = total_loss / mean_len
    
    # Non-finite loss: return before any update so the caller can roll back.
    # The commented-out spike check is disabled (`smoothed_loss` is not in scope here).
    if torch.isnan(total_loss) or torch.isinf(total_loss):# or (total_loss > smoothed_loss * 1.1):
        return [total_loss]
    
    # Snapshot the pre-update state. The caller keeps it and restores it if a later
    # step produces a non-finite loss.
    with torch.no_grad():
        safe_params = [copy.deepcopy(i.state_dict()) for i in [model, optimizer, scheduler]]

    backward_pass(model, total_loss_for_backward, optimizer, scaler, scheduler, args.grad_clip_max_norm)
    
    # Index 1 is a placeholder zero (labelled "Energy Reg" by the training loop).
    data = [GPT_loss, 0, total_loss, full_seq_acc, ans_region_acc, ans_char_acc]
    
    # Placeholder zeros for indices 6-7 (labelled "Mean Steps" / "Std Steps" by the loop).
    data += [0, 0]
    
    # Collect the metrics into a host-side NumPy array for logging / `record`.
    with torch.inference_mode():
        data = torch.tensor(data).cpu().numpy()

    return data, safe_params

In [None]:
# Training loop: one optimizer step per batch until `epochs` successful steps are done.
# A batch whose loss is NaN/inf is skipped and model/optimizer/scheduler are rolled back.
record = np.zeros((epochs, 9))  # 8 metrics from train_step + cumulative NaN count
num_NaNs = 0
smoothed_loss = None

# Column labels for `record` (loop-invariant, so built once).
names = ["GPT loss", "Energy Reg", "Total_loss", "Full Seq Acc",
         "Ans Region Acc", "Ans Char Acc", "Mean Steps", "Std Steps", "Num NaNs"]

# Initial snapshot, used if the very first step already diverges.
safe_params = [copy.deepcopy(i.state_dict()) for i in [model, optimizer, scheduler]]

epoch = 0

for train_data in dataloader:
    if epoch >= epochs:
        break

    # Batches come from pinned memory on CUDA, so the copy can be asynchronous.
    # Later kernels on the same stream are ordered after it. No-op for CPU targets.
    train_data = [x.to(device, non_blocking=True) for x in train_data]

    t0 = time.time()

    result = train_step(model, train_data, mean_len, optimizer, scheduler, scaler, args)

    if len(result) == 1:
        # Non-finite loss: train_step returned [total_loss] without updating anything.
        # safe_params is the pre-update snapshot from the previous successful step, so
        # restoring it also undoes that step. Two things are not rolled back:
        # `epoch` is not decremented, and the GradScaler state is left as is.
        total_loss = result[0]
        num_NaNs += 1
        print(f"Epoch: {epoch}")
        print("Instability detected")
        print(f"Total Loss: {total_loss.item()}\n")
        model.load_state_dict(safe_params[0])
        optimizer.load_state_dict(safe_params[1])
        scheduler.load_state_dict(safe_params[2])
        optimizer.zero_grad(set_to_none=True)
        continue

    data, safe_params = result
    # Exponential moving average of the total loss (decay 0.99), seeded with the first value.
    step_loss = data[2].item()
    smoothed_loss = step_loss if smoothed_loss is None else 0.99 * smoothed_loss + 0.01 * step_loss
    epoch += 1

    record[epoch - 1, :-1] = data
    record[epoch - 1, -1] = num_NaNs

    print(f"Epoch: {epoch}")
    for name, value in zip(names, data):
        print(f"{name}: {value}")
    print(f"Smoothed Loss: {smoothed_loss}")
    print(f"Time: {time.time() - t0}\n")

    # Periodic checkpointing is disabled; uncomment to save every 100 steps.
    # if epoch % 100 == 0:
    #     save_record(args.save_path, model, record, args, time_stamp)

# Report a silent early stop if the dataloader ran out before the planned step count.
if epoch < epochs:
    print(f"Warning: dataloader exhausted after {epoch} of {epochs} planned steps.")

Epoch: 1
GPT loss: 94.88225555419922
Energy Reg: 0.0
Total_loss: 94.88225555419922
Full Seq Acc: 0.07002757489681244
Ans Region Acc: 0.0
Ans Char Acc: 0.0813223123550415
Mean Steps: 0.0
Std Steps: 0.0
Smoothed Loss: 94.88225555419922
Time: 0.4295940399169922

Epoch: 2
GPT loss: 95.43161010742188
Energy Reg: 0.0
Total_loss: 95.43161010742188
Full Seq Acc: 0.11606083065271378
Ans Region Acc: 0.0
Ans Char Acc: 0.0723208412528038
Mean Steps: 0.0
Std Steps: 0.0
Smoothed Loss: 94.88774909973144
Time: 0.07771754264831543

Epoch: 3
GPT loss: 95.12255859375
Energy Reg: 0.0
Total_loss: 95.12255859375
Full Seq Acc: 0.11765364557504654
Ans Region Acc: 0.0
Ans Char Acc: 0.06352324783802032
Mean Steps: 0.0
Std Steps: 0.0
Smoothed Loss: 94.89009719467161
Time: 0.048958539962768555

Epoch: 4
GPT loss: 95.12603759765625
Energy Reg: 0.0
Total_loss: 95.12603759765625
Full Seq Acc: 0.12103992700576782
Ans Region Acc: 0.0
Ans Char Acc: 0.07344262301921844
Mean Steps: 0.0
Std Steps: 0.0
Smoothed Loss: 94.89

Epoch: 100
GPT loss: 86.0082015991211
Energy Reg: 0.0
Total_loss: 86.0082015991211
Full Seq Acc: 0.25461894273757935
Ans Region Acc: 0.0
Ans Char Acc: 0.0398288331925869
Mean Steps: 0.0
Std Steps: 0.0
Smoothed Loss: 92.43370221371526
Time: 0.044524192810058594

Epoch: 101
GPT loss: 86.84423828125
Energy Reg: 0.0
Total_loss: 86.84423828125
Full Seq Acc: 0.2552184760570526
Ans Region Acc: 0.0
Ans Char Acc: 0.030154049396514893
Mean Steps: 0.0
Std Steps: 0.0
Smoothed Loss: 92.37780757439062
Time: 0.04427194595336914



KeyboardInterrupt: 

In [None]:
#save_record(args.final_save_path, model, record, args, time_stamp, smoothed_loss)