In [1]:
import os
import time
import random
import copy
import sys

import numpy as np
import torch
from torchtnt.utils.flops import FlopTensorDispatchMode

sys.path.append("/jet/home/azhang19/stat 260/STAT-260-Final-Project-Randomized-Linear-Algebra-Transformer")
from RLALLaMA3.LLaMA3 import ModelArgs
from RLALLaMA3.utils import (
    linear_warmup_cosine_decay_multiplicative,
    name_args,
    transformer_forward_pass,
    Args,
)
from RLALLaMA3.tasks import single_answer_seq_loss, get_dataset

def set_seed(seed):
    """Seed every random number generator used in this notebook.

    Seeds Python's `random`, NumPy's global generator and PyTorch (CPU, plus
    all CUDA devices when CUDA is available), then asks cuDNN for
    deterministic kernels and turns its auto-tuner off.

    Args:
        seed: integer seed handed to every generator.
    """
    # CPU-side generators.
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

    # GPU generators: the current device first, then every visible device.
    if torch.cuda.is_available():
        for cuda_seed_fn in (torch.cuda.manual_seed, torch.cuda.manual_seed_all):
            cuda_seed_fn(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

# Global runtime configuration shared by every later cell.
# "high" lets float32 matmuls use faster, reduced-precision internals
# (see the PyTorch documentation for torch.set_float32_matmul_precision).
torch.set_float32_matmul_precision("high")
# Use the GPU when one is visible, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# NOTE(review): set_seed is defined above but is not called in any visible
# cell, so runs are not seeded — confirm whether that is intentional.

# Disabled switch that would make torch._dynamo swallow compile errors:
#import torch._dynamo
#torch._dynamo.config.suppress_errors = True

# Timestamp of this run; save_record appends it to checkpoint file names.
time_stamp = time.strftime("%Y-%m-%d %H-%M-%S", time.localtime())
print(f"Time stamp: {time_stamp}")

Time stamp: 2025-04-03 20-04-16


In [13]:
# Scratch sanity check of the sketched matmul: sum of squared errors of
# projection_sketch_mm(A, B, 128) against the exact product A @ B, divided by
# the sum of squared entries of the exact product.
# NOTE(review): the third argument (128) is presumably the sketch size — confirm in RLACore.
# NOTE(review): the recorded output (1.9707) is above 1, i.e. the error carries
# more energy than the exact product itself; worth investigating.
# NOTE(review): this cell's execution count (13) is out of order and reused by
# a later cell; it may not reflect a fresh Restart & Run All.
from RLALLaMA3.RLACore import projection_sketch_mm
A = torch.randn(1, 1, 256, 256).to(device)
B = torch.randn(256, 256).to(device)
gt = A @ B  # exact reference product
(projection_sketch_mm(A, B, 128) - gt).square().sum() / gt.square().sum()

tensor(1.9707)

In [2]:
# Define the run configuration. Every later cell reads hyperparameters from `args`.
# (Command-line parsing is disabled in the notebook: the line below is dead code.)
##args = parse_args()
args = Args(
    # Training
    # The standard_* values are rescaled by batch_size / 512 in the
    # "Training configuration" cell before they are used.
    standard_lr=3.16e-3,
    standard_epoch=80000,
    standard_warmup_steps=4000,
    batch_size=1024,
    min_lr=1e-4,
    grad_clip_max_norm=1.0,
    use_amp=True,
    use_compile=True,
    # model_type is not passed here; the printed configuration shows it resolves to 'node'

    # Data
    task="number_add",
    max_level=40,
    random_seq_len=True,
    number_range=(0, 99), # Explicitly provided, overrides default_factory if needed

    # Model
    dim=256,
    n_layers=2,
    n_heads=4,
    hidden_dim=896,

    # RLA parameters (all passed explicitly here; whether they match the
    # dataclass defaults cannot be seen from this notebook)
    sketch_mode='rademacher',
    deterministic=True,
    attention_qkv_sketch_size=48,
    attention_out_sketch_size=36,
    feedforward_sketch_size_in=48,
    feedforward_sketch_size_out=48,
    attention_score_sketch_size=24,
    attention_weighed_sum_sketch_size=24,

    # Save
    # NOTE(review): absolute, user-specific paths. The only uses of these paths
    # in the visible cells (the save_record calls) are commented out.
    save_path="/accounts/grad/zhangyunzhe2023/Neural-ODE/ckpt",
    final_save_path="/accounts/grad/zhangyunzhe2023/Neural-ODE/ckpt_final",
)


# Echo the resolved configuration so it is stored with the notebook output.
print(args, end="\n\n")

Args Configuration:

Training Parameters:
  model_type:         node
  standard_lr:        3.2e-03
  standard_epoch:     80000
  standard_warmup_steps: 4000
  batch_size:         1024
  min_lr:             1.0e-04
  grad_clip_max_norm: 1.0
  use_amp:            True
  use_compile:        True

Data Parameters:
  task:               number_add
  max_level:          40
  random_seq_len:     True
  number_range:       (0, 99)

Model Architecture Parameters:
  dim:                256
  n_layers:           2
  n_heads:            4
  hidden_dim:         896

RLA Parameters:
  sketch_mode:        rademacher
  deterministic:      True
  attn_qkv_sketch:    48
  attn_out_sketch:    36
  ffn_in_sketch:      48
  ffn_out_sketch:     48
  attn_score_sketch:  24
  attn_sum_sketch:    24

Save Path Parameters:
  save_path:          /accounts/grad/zhangyunzhe2023/Neural-ODE/ckpt
  final_save_path:    /accounts/grad/zhangyunzhe2023/Neural-ODE/ckpt_final



In [3]:
# Prepare the data
# get_dataset returns the dataset, its collate function, the vocabulary size
# and the maximum sequence length; the last two are used to size the model.
# NOTE(review): nested_tensor=False / pad_to_longest=True presumably produce
# dense batches padded to the longest sequence — confirm in RLALLaMA3.tasks.
dataset, collate_fn, vocab_size, max_seq_len = get_dataset(args.task,
                                                           args.max_level,
                                                           args.random_seq_len,
                                                           args.number_range,
                                                           nested_tensor=False,
                                                           pad_to_longest=True)
# One worker per torch CPU thread; pinned host memory for the host-to-device copies.
# No shuffle or sampler argument is passed.
dataloader = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, collate_fn=collate_fn,
                                         num_workers=torch.get_num_threads(), pin_memory=True)

In [4]:
def mean_seq_len(dataloader, num_samples=100):
    """
    Estimate the mean sequence length from the first batches of a dataloader.

    Each batch must expose its per-sequence lengths at index 1 as a tensor.
    The estimate is the average of the per-batch mean lengths over the
    batches that were actually drawn, so it stays correct when the loader
    yields fewer than `num_samples` batches.

    Args:
        dataloader: iterable of batches (e.g. a torch DataLoader).
        num_samples: maximum number of batches to draw. Despite its name it
            counts batches, not individual sequences.

    Returns:
        The mean sequence length as a Python float.

    Raises:
        ValueError: if `num_samples` is not positive or the loader is empty.
    """
    if num_samples <= 0:
        raise ValueError("num_samples must be a positive number of batches")

    total_len = 0.0
    num_batches = 0
    for batch in dataloader:
        total_len += batch[1].float().mean().item()
        num_batches += 1
        # Stop right after the last needed batch so no extra batch is fetched.
        if num_batches >= num_samples:
            break

    if num_batches == 0:
        raise ValueError("dataloader yielded no batches")
    # Divide by what was consumed, not by what was requested.
    return total_len / num_batches

# Estimate the average sequence length from the first batches of the loader.
# train_step later divides the loss used for the backward pass by this value.
mean_len = mean_seq_len(dataloader)
print(f"Mean sequence length: {mean_len}")
print(f"Max sequence length: {max_seq_len}")

Mean sequence length: 63.691806640625
Max sequence length: 126


In [5]:
# Prepare the model

# Architecture and sketch settings all come from `args`; the vocabulary size
# and maximum sequence length come from the dataset.
transformer_args = ModelArgs(
    dim=args.dim,
    n_layers=args.n_layers,
    n_heads=args.n_heads,
    hidden_dim=args.hidden_dim,
    vocab_size=vocab_size,
    norm_eps=1e-5,
    rope_theta=500000,
    max_seq_len=max_seq_len,

    # Read the sketch mode from the configuration instead of hard-coding it,
    # so that changing args.sketch_mode actually changes the model.
    sketch_mode = args.sketch_mode,
    attention_qkv_sketch_size = args.attention_qkv_sketch_size,
    attention_out_sketch_size = args.attention_out_sketch_size,
    feedforward_sketch_size_in = args.feedforward_sketch_size_in,
    feedforward_sketch_size_out = args.feedforward_sketch_size_out,
    attention_score_sketch_size = args.attention_score_sketch_size,
    attention_weighed_sum_sketch_size = args.attention_weighed_sum_sketch_size,
    deterministic = args.deterministic,
)

from RLALLaMA3.LLaMA3 import Transformer
model = Transformer(params=transformer_args)

# Move to the selected device and start in training mode.
model = model.to(device).train()

In [6]:
# Count forward-pass FLOPs on one random full-size batch, first with the
# configured deterministic setting and then with deterministic mode forced on.
# When args.deterministic is already True the two measurements are the same
# configuration, which is why the recorded ratio is 1.0.
with torch.inference_mode():
    model.deterministic_mode(args.deterministic)
    model.eval()
    # Check the model
    # Random token ids of shape (batch_size, max_seq_len); the next cell reuses toy_input.
    toy_input = torch.randint(0, vocab_size, (args.batch_size, max_seq_len), device=device)
    with FlopTensorDispatchMode(model) as ftdm:
        res = model(toy_input)
    # flop_counts[''] is presumably the root-module entry (per-op counts) — confirm in torchtnt.
    flops_0 = sum(i for i in ftdm.flop_counts[''].values())

    model.deterministic_mode(True)
    with FlopTensorDispatchMode(model) as ftdm:
        res = model(toy_input)
    flops_1 = sum(i for i in ftdm.flop_counts[''].values())
# Restore the configured mode and return to training mode for the cells below.
model.deterministic_mode(args.deterministic)
model.train()

print(f"Flops (Set): {flops_0}")
print(f"Flops (Deterministic): {flops_1}")
# Ratio is deterministic / configured.
print(f"Flops (Ratio): {flops_1 / flops_0}")

Flops (Set): 16647192576
Flops (Deterministic): 16647192576
Flops (Ratio): 1.0


In [7]:
# Count forward-pass FLOPs with deterministic mode forced on. Unlike the
# previous cell, this runs with the model in train mode and outside
# torch.inference_mode().
# deterministic_mode is called for its side effect only, as in the previous
# cell; `model` is deliberately not rebound to its return value.
model.deterministic_mode(True)
with FlopTensorDispatchMode(model) as ftdm:
    res = model(toy_input)
print(sum(ftdm.flop_counts[''].values()))
# Restore the configured mode so this measurement does not leak into training.
model.deterministic_mode(args.deterministic)

262622674944


In [8]:
def rla_mm_flops_ratio(m, d, n, k):
    """Return k/m + k/n + k/d.

    Presumably the cost of a sketched (m x d) @ (d x n) product with sketch
    size k relative to the exact product — confirm against RLACore.

    Args:
        m, d, n: matrix dimensions (must be non-zero).
        k: sketch size.
    """
    # Summed in the order m, n, d.
    return sum(k / dim for dim in (m, n, d))

def k_for_target_ratio(m, d, n, target_ratio):
    """Solve k/m + k/n + k/d = target_ratio for k.

    Args:
        m, d, n: matrix dimensions.
        target_ratio: desired value of k/m + k/n + k/d.

    Returns:
        The real-valued k (not rounded; the caller decides how to turn it
        into an integer sketch size), or None when any dimension is not
        strictly positive.
    """
    # Validate before taking reciprocals: a zero dimension would otherwise
    # raise ZeroDivisionError instead of reaching the "invalid dims" result.
    if m <= 0 or d <= 0 or n <= 0:
        return None
    # k * (1/n + 1/m + 1/d) = target_ratio  =>  k = target_ratio / (1/n + 1/m + 1/d)
    inv_sum = (1/n) + (1/m) + (1/d)
    return target_ratio / inv_sum

In [9]:
# Ratio passed to k_for_target_ratio for every matmul reported below.
target_flop_ratio = 0.5 # Example target ratio (e.g., 50% reduction)

# Model dimensions
D = args.dim
H = args.hidden_dim
L = max_seq_len  # the maximum sequence length stands in for the sequence dimension
N_h = args.n_heads
# Per-head dimension; falls back to the full model dimension (with a warning)
# when the head count is invalid or does not divide D.
if N_h <= 0 or D % N_h != 0:
    print(f"Warning: Invalid n_heads ({N_h}) or dim ({D}). Setting head_dim to D.")
    D_h = D
else:
    D_h = D // N_h

# Output dimension for QKV linear layer
# (3 * D: presumably Q, K and V produced by one fused projection — confirm in LLaMA3)
D_qkv_out = D * 3
# Output dimension for FFN input layer
# (2 * H: presumably two fused projections, e.g. gate and up — confirm in LLaMA3)
D_ffn_in_out = H * 2

print(f"Calculating FLOPs ratios/needed k assuming L={L}, D={D}, H={H}, D_h={D_h}, N_h={N_h}")
print(f"Target FLOPs Ratio for k calculation: {target_flop_ratio:.2f}\n")

# Helper function to print results
def print_rla_part_info(part_name, m, d, n, current_k, target_ratio):
    """Print the FLOPs ratio of one matmul and the k that would hit a target ratio.

    Args:
        part_name: label printed in the header line.
        m, d, n: dimensions passed on to the ratio helpers.
        current_k: sketch size in use; a value <= 0 is reported as deterministic
            with a ratio of 1.
        target_ratio: ratio for which the required k is reported.
    """
    # Work out the two branch-specific lines first, then print in a fixed order.
    if current_k > 0:
        ratio = rla_mm_flops_ratio(m, d, n, current_k)
        header = f"{part_name} (Current k={current_k}):"
        ratio_line = f"  Current FLOPs Ratio ≈ {ratio:.4f}"
    else:
        header = f"{part_name} (Current k=N/A - Deterministic):"
        ratio_line = f"  Current FLOPs Ratio = 1.0000"
    print(header)
    print(f"  m={m}, d={d}, n={n}")
    print(ratio_line)

    k_needed = k_for_target_ratio(m, d, n, target_ratio)
    if k_needed is None:
        print(f"  Target ratio {target_ratio:.2f} not achievable or dims invalid.")
    else:
        print(f"  Approx k needed for ratio {target_ratio:.2f} ≈ {k_needed}")
    print("-" * 20)

# Report each of the six sketched matmuls. Arguments after the label are
# (m, d, n, current sketch size from args, target ratio).

# 1. Attention QKV Projection
print_rla_part_info("Attention QKV Projection", L, D, D_qkv_out, args.attention_qkv_sketch_size, target_flop_ratio)

# 2. Attention Output Projection
print_rla_part_info("Attention Output Projection", L, D, D, args.attention_out_sketch_size, target_flop_ratio)

# 3. FeedForward Input Projection
print_rla_part_info("FeedForward Input Projection", L, D, D_ffn_in_out, args.feedforward_sketch_size_in, target_flop_ratio)

# 4. FeedForward Output Projection
print_rla_part_info("FeedForward Output Projection", L, H, D, args.feedforward_sketch_size_out, target_flop_ratio)

# 5. SDPA QK^T Calculation (per head: the inner dimension is the head dimension D_h)
print_rla_part_info("SDPA QK^T Calculation", L, D_h, L, args.attention_score_sketch_size, target_flop_ratio)

# 6. SDPA Scores@V Calculation (per head: the inner dimension is the sequence length L)
print_rla_part_info("SDPA Scores@V Calculation", L, L, D_h, args.attention_weighed_sum_sketch_size, target_flop_ratio)

Calculating FLOPs ratios/needed k assuming L=126, D=256, H=896, D_h=64, N_h=4
Target FLOPs Ratio for k calculation: 0.50

Attention QKV Projection (Current k=48):
  m=126, d=256, n=768
  Current FLOPs Ratio ≈ 0.6310
  Approx k needed for ratio 0.50 ≈ 38.0377358490566
--------------------
Attention Output Projection (Current k=36):
  m=126, d=256, n=256
  Current FLOPs Ratio ≈ 0.5670
  Approx k needed for ratio 0.50 ≈ 31.748031496062993
--------------------
FeedForward Input Projection (Current k=48):
  m=126, d=256, n=1792
  Current FLOPs Ratio ≈ 0.5952
  Approx k needed for ratio 0.50 ≈ 40.32
--------------------
FeedForward Output Projection (Current k=48):
  m=126, d=896, n=256
  Current FLOPs Ratio ≈ 0.6220
  Approx k needed for ratio 0.50 ≈ 38.58373205741627
--------------------
SDPA QK^T Calculation (Current k=24):
  m=126, d=64, n=126
  Current FLOPs Ratio ≈ 0.7560
  Approx k needed for ratio 0.50 ≈ 15.874015748031496
--------------------
SDPA Scores@V Calculation (Current k=24)

In [10]:
# Training configuration
# The standard_* hyperparameters in `args` are expressed relative to a batch
# size of 512. They are first converted to per-sample quantities and then
# rescaled to the actual batch size: the learning rate grows linearly with
# batch_size, and the step counts shrink by the same factor.
standard_lr = args.standard_lr / 512
standard_epoch = args.standard_epoch * 512
standard_warmup_steps = args.standard_warmup_steps * 512
batch_size = args.batch_size

lr = standard_lr * batch_size
warmup_steps = standard_warmup_steps // batch_size
# "epochs" counts optimizer steps (one batch each): the training loop
# increments its epoch counter once per successful batch.
epochs = standard_epoch // batch_size

print("Derived Parameters:")
print(f"lr: {lr}")
print(f"warmup_steps: {warmup_steps}")
print(f"epochs: {epochs}")
print(f"grad_clip_max_norm: {args.grad_clip_max_norm}", end="\n\n")

optimizer = torch.optim.AdamW(model.parameters(), lr=lr, fused=True)
# MultiplicativeLR multiplies the current learning rate by the factor returned
# for each step; the factor comes from the project helper.
# NOTE(review): presumably linear warm-up followed by cosine decay towards
# args.min_lr — confirm in RLALLaMA3.utils.
scheduler = torch.optim.lr_scheduler.MultiplicativeLR(optimizer,
            lr_lambda=lambda step: linear_warmup_cosine_decay_multiplicative(step, warmup_steps, epochs, args.min_lr))

# The gradient scaler is active only when mixed precision is enabled.
scaler = torch.amp.GradScaler(device, enabled=args.use_amp)

Derived Parameters:
lr: 0.00632
warmup_steps: 2000
epochs: 40000
grad_clip_max_norm: 1.0



In [11]:
# Save the model and arguments

def save_record(path, model, record, args, time_stamp, extra_info=None):
    """Save the model, the metric record and the run arguments to one .pth file.

    The file goes to `{path}/{dict_name}/`, where the directory name and the
    base file name both come from `name_args(args, "_")`. The time stamp, and
    `extra_info` when given, are appended to the file name.

    Args:
        path: root directory for checkpoints.
        model: model object, stored as-is (the whole object, not a state_dict).
        record: training metrics to store alongside the model.
        args: run configuration; also determines the directory and file names.
        time_stamp: string appended to the file name.
        extra_info: optional value appended to the file name after the time stamp.
    """
    dict_name, base_name = name_args(args, "_")

    target_dir = f"{path}/{dict_name}"
    os.makedirs(target_dir, exist_ok=True)

    suffix = "" if extra_info is None else f"_{extra_info}"
    file_name = base_name + f"_{time_stamp}" + suffix

    payload = dict(model=model, record=record, args=args, time_stamp=time_stamp)
    torch.save(payload, f"{target_dir}/{file_name}.pth")

In [12]:
# Backwards pass
def backward_pass(model, loss, optimizer, scaler, scheduler, grad_clip_max_norm):
    """Run one optimisation step: backward, gradient clipping, optimizer and scheduler updates.

    Args:
        model: module whose parameters are clipped.
        loss: scalar loss to back-propagate.
        optimizer: optimizer stepped through the gradient scaler.
        scaler: gradient scaler used to scale the loss and step the optimizer.
        scheduler: learning-rate scheduler, stepped once per call.
        grad_clip_max_norm: maximum total gradient norm passed to clip_grad_norm_.
    """
    # Clear old gradients before accumulating the new ones.
    optimizer.zero_grad(set_to_none=True)
    scaler.scale(loss).backward()
    # Unscale first so the clipping threshold applies to the true gradient magnitudes.
    scaler.unscale_(optimizer)
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=grad_clip_max_norm)
    scaler.step(optimizer)
    scaler.update()
    scheduler.step()

In [13]:
@torch.compile(disable=not args.use_compile)
def train_step(model, train_data, mean_len, optimizer, scheduler, scaler, args):
    """Run one training step on a batch and report its metrics.

    Args:
        model: the transformer being trained.
        train_data: sequence of (tokens, lengths, ans_starts, ans_lengths)
            tensors, already on the training device.
        mean_len: value the loss is divided by before the backward pass.
        optimizer, scheduler, scaler: forwarded to backward_pass.
        args: run configuration (use_amp, batch_size, grad_clip_max_norm).

    Returns:
        Either a one-element list `[total_loss]` when the loss is NaN or Inf
        (no backward pass is run), or a tuple `(data, safe_params)`, where
        `data` is a NumPy array of 8 values
        [GPT_loss, 0, total_loss, full_seq_acc, ans_region_acc, ans_char_acc, 0, 0]
        and `safe_params` holds deep copies of the model, optimizer and
        scheduler state dicts taken BEFORE this step's update.
        The caller tells the two apart with len(result).
    """
    device = train_data[0].device
    
    with torch.autocast(device_type=device.type, dtype=torch.bfloat16, enabled=args.use_amp):
        tokens, lengths, ans_starts, ans_lengths = train_data
        # The model sees every token except the last one.
        pred = transformer_forward_pass(tokens[:, :-1], model)
        
        result = single_answer_seq_loss(pred, tokens, lengths, ans_starts, ans_lengths)
        GPT_loss, full_seq_acc, ans_region_acc, ans_char_acc = result
        # Normalize the GPT loss by the batch size but not the sequence length
        GPT_loss = GPT_loss / args.batch_size
        total_loss = GPT_loss
        # Only the loss used for the backward pass is divided by the mean
        # sequence length; the reported losses are not.
        total_loss_for_backward = total_loss / mean_len
    
    # Bail out before touching the weights when the loss is not finite.
    if torch.isnan(total_loss) or torch.isinf(total_loss):# or (total_loss > smoothed_loss * 1.1):
        return [total_loss]
    
    # Snapshot the pre-update state so the caller can roll back later.
    with torch.no_grad():
        safe_params = [copy.deepcopy(i.state_dict()) for i in [model, optimizer, scheduler]]

    backward_pass(model, total_loss_for_backward, optimizer, scaler, scheduler, args.grad_clip_max_norm)
    
    # The zeros are placeholders for the "Energy Reg", "Mean Steps" and
    # "Std Steps" columns named in the training loop.
    data = [GPT_loss, 0, total_loss, full_seq_acc, ans_region_acc, ans_char_acc]
    
    data += [0, 0]
    
    with torch.inference_mode():
        data = torch.tensor(data).cpu().numpy()

    return data, safe_params

In [None]:
# Column labels of `record`: the 8 values returned by train_step followed by
# the running count of unstable (NaN/Inf) steps.
names = ["GPT loss", "Energy Reg", "Total_loss", "Full Seq Acc",
         "Ans Region Acc", "Ans Char Acc", "Mean Steps", "Std Steps", "Num NaNs"]

# One row per successful optimizer step.
record = np.zeros((epochs, len(names)))
num_NaNs = 0
smoothed_loss = None

# Last known-good state, restored whenever a step reports a non-finite loss.
safe_params = [copy.deepcopy(i.state_dict()) for i in [model, optimizer, scheduler]]

# "epoch" counts successful optimizer steps, not passes over the dataset.
epoch = 0

for train_data in dataloader:
    if epoch >= epochs:
        break

    train_data = [x.to(device) for x in train_data]

    t0 = time.time()

    result = train_step(model, train_data, mean_len, optimizer, scheduler, scaler, args)

    # A one-element result means train_step found a NaN/Inf loss and returned
    # before its backward pass.
    if len(result) == 1:
        data = result
        total_loss = data[0]
        num_NaNs += 1
        print(f"Epoch: {epoch}")
        print("Instability detected")
        print(f"Total Loss: {total_loss.item()}\n")
        # Roll back to the snapshot taken by the last successful step, which
        # was captured before that step's update.
        model.load_state_dict(safe_params[0])
        optimizer.load_state_dict(safe_params[1])
        scheduler.load_state_dict(safe_params[2])
        optimizer.zero_grad(set_to_none=True)
        continue

    data, safe_params = result
    # Exponential moving average of the total loss, seeded with the first value.
    step_loss = data[2].item()
    smoothed_loss = step_loss if smoothed_loss is None else 0.99 * smoothed_loss + 0.01 * step_loss
    epoch = epoch + 1

    record[epoch - 1, :-1] = data
    record[epoch - 1, -1] = num_NaNs

    print(f"Epoch: {epoch}")
    # Append the NaN counter so every label, including "Num NaNs", is printed;
    # zipping with `data` alone silently dropped the ninth label.
    for name, value in zip(names, [*data, num_NaNs]):
        print(f"{name}: {value}")
    print(f"Smoothed Loss: {smoothed_loss}")
    print(f"Time: {time.time() - t0}\n")

    # Periodic checkpointing is currently disabled.
    if epoch % 100 == 0:
        #save_record(args.save_path, model, record, args, time_stamp)
        pass

In [None]:
#save_record(args.final_save_path, model, record, args, time_stamp, smoothed_loss)