In [None]:
from datasets import load_dataset

num_titles = 10000
val_frac = 0.1
seed = 1337
ds = load_dataset("julien040/hacker-news-posts", split="train", cache_dir="./data").shuffle(seed=seed)
titles = [row["title"].strip() for row in ds.take(num_titles)]
n = int(num_titles * (1 - val_frac))

In [3]:
from omegaconf import OmegaConf
from dataclasses import dataclass
@dataclass
class Hyperparameters:

    seed: int
    epochs: int
    val_frac: float
    num_titles: int
    vocab_size: int
    context_length: int  # Added context_length parameter

    log_file: str
    model_architecture: str 
    
    batch_size: int
    lr: float
    weight_decay: float
    scheduler: str # none, linear, cosine
    optimizer: str
    evals_per_epoch: float


from model.gpt import GPT, GPTConfig

@dataclass
class AttnConfig:
    d_model: int
    n_head: int
    block_size: int
    dropout: float

cfg = OmegaConf.load("config/hyperparams.yaml")
            # Update cfg with args

hparams = OmegaConf.to_container(cfg.hyperparams, resolve=True)
modelparams = OmegaConf.to_container(cfg.model_configs[hparams['model_architecture']], resolve=True)
attnparams = OmegaConf.to_container(cfg.attn_configs[modelparams['attention_layer']], resolve=True)

args = Hyperparameters(**hparams)

attn = AttnConfig(
    d_model=modelparams['d_model'],
    n_head=attnparams['n_head'],
    block_size=args.context_length,
    dropout=modelparams['dropout']
)

cfg = GPTConfig(
    vocab_size=args.vocab_size,
    block_size=args.context_length,
    attn_config = attn,
    activation_function = 'gelu',
    **modelparams
)

In [2]:
import wandb

wandb.init()

cfg = wandb.config

print("Learning rate:", cfg.lr)
print("Batch size:", cfg.batch_size)
print("Attention type:", cfg["sparse.attn_type"])


[34m[1mwandb[0m: [32m[41mERROR[0m The nbformat package was not found. It is required to save notebook history.


AttributeError: <class 'wandb.sdk.wandb_config.Config'> object has no attribute 'lr'

In [None]:
import torch

def sparseK(u, k):
    B, T = u.shape
    z_sorted, _ = torch.sort(u, dim=1, descending=True)  # (B, T)
    z_cumsum = torch.cumsum(z_sorted, dim=1)  # (B, T)

    # candidates: z and z-1
    beta_candidates = torch.cat([z_sorted, z_sorted - 1], dim=1)  # (B, 2*T)
    beta_sorted, _ = torch.sort(beta_candidates, dim=1, descending=True)

    tau = torch.zeros(B, device=u.device)
    done = torch.zeros(B, dtype=torch.bool, device=u.device)
    p = None

    batch_idx = torch.arange(B, device=u.device)

    print("Input u:\n", u)
    print("Sorted z:\n", z_sorted)
    print("Cumsum z:\n", z_cumsum)
    print("Beta candidates shape:", beta_candidates.shape)
    print("Beta sorted shape:", beta_sorted.shape)

    for i in range(beta_sorted.shape[1]):
        beta = beta_sorted[:, i][:, None]  # (B, 1)

        # indices
        u_idx = (z_sorted >= (beta + 1)).int().sum(dim=1) - 1
        w_idx = (z_sorted > beta).int().sum(dim=1) - 1

        denom = (w_idx - u_idx).float()
        denom = torch.where(denom == 0, torch.ones_like(denom), denom)

        tau_candidate = (
            (z_cumsum[batch_idx, w_idx] - z_cumsum[batch_idx, u_idx])
            + u_idx - k
        ) / denom

        # candidate projection
        p_candidate = torch.clamp(z_sorted - tau_candidate[:, None], 0, 1)
        sum_p = p_candidate.sum(dim=1)

        cond_sum = (sum_p - k).abs() < 1e-6
        cond = (z_sorted[batch_idx, w_idx] > tau_candidate) & \
               (z_sorted[batch_idx, u_idx] >= tau_candidate + 1) & \
               cond_sum

        print(f"\n--- Iter {i} ---")
        print("beta:", beta.squeeze())
        print("u_idx:", u_idx)
        print("w_idx:", w_idx)
        print("tau_candidate:", tau_candidate)
        print("sum_p:", sum_p)
        print("cond:", cond)

        tau[~done & cond] = tau_candidate[~done & cond]
        done = done | cond

        if done.all():
            p = p_candidate
            print(">>> Found valid tau at iter", i)
            break

    if p is None:
        print("No candidate satisfied condition → fallback to hard top-k")
        p = torch.clamp(z_sorted - tau[:, None], 0, 1)
        for b in range(B):
            if not torch.isclose(p[b].sum(), torch.tensor(float(k), device=u.device), atol=1e-6):
                topk_idx = torch.topk(z_sorted[b], k).indices
                p[b] = torch.zeros_like(z_sorted[b])
                p[b, topk_idx] = 1.0

    return p


In [7]:
import yaml

with open('config/sweep_gpt_sparse.yaml') as f:
    data = yaml.safe_load(f)
print(data)

{'program': 'train.py', 'method': 'random', 'metric': {'name': 'val_loss', 'goal': 'minimize'}, 'parameters': {'sparse.n_head': {'values': [4, 8, 16]}, 'sparse.num_verts': {'values': [4, 8, 16]}, 'sparse.sparseblocksize': {'values': [32, 64, 128]}, 'sparse.vertsize': {'values': [64, 128, 256]}, 'sparse.n_bctx': {'values': [1, 2, 4]}, 'sparse.intermediate_dim': {'values': [0, 64, 128]}}}


In [2]:
import wandb

In [None]:
def merge_dotted_keys(base_dict, update_dict, target_path=None):
    """
    Merge keys with dots into nested dicts.
    If target_path is given, merge inside that nested dict.
    """
    import copy
    merged = copy.deepcopy(base_dict)
    
    # if target_path is provided, get the nested dict
    if target_path:
        d = merged
        for k in target_path:
            d = d.setdefault(k, {})
    else:
        d = merged
    
    for key, value in update_dict.items():
        parts = key.split(".")
        curr = d
        for p in parts[:-1]:
            curr = curr.setdefault(p, {})
        curr[parts[-1]] = value
    return merged


In [3]:
sweep_id = wandb.sweep('config/sweep_gpt_sparse.yaml', project="gpt-from-scratch", entity="arc_agi")

CommError: dictionary update sequence element #0 has length 1; 2 is required

In [5]:
from omegaconf import OmegaConf

In [10]:
cfg = OmegaConf.load('config\sweep_gpt_sparse.yaml')
# Convert to a plain dictionary
cfg_dict = OmegaConf.to_container(cfg, resolve=True)


orig_cfg = OmegaConf.load('config\hparams_gpt_sparse.yaml')

In [17]:
sweep_config = 'config\sweep_gpt_sparse.yaml'
cfg = 'config\sweep_gpt_sparse.yaml'


In [None]:
def sweep_train():
    orig_cfg = OmegaConf.load('config\hparams_gpt_sparse.yaml')  # defaults
    with wandb.init() as run:
        print(run)
        sweep_cfg = OmegaConf.create({"hyperparams": dict(run.config)})
        cfg = OmegaConf.merge(orig_cfg, sweep_cfg)
        print(cfg)

In [23]:
cfg = OmegaConf.load(sweep_config)
# Convert to a plain dictionary
cfg_dict = OmegaConf.to_container(cfg, resolve=True)
print(cfg_dict)
sweep_id = wandb.sweep(cfg_dict, project="gpt-from-scratch", entity="arc_agi")
wandb.agent(sweep_id, function=sweep_train)

{'program': 'train.py', 'method': 'random', 'metric': {'name': 'val_loss', 'goal': 'minimize'}, 'parameters': {'sparse.n_head': {'values': [4, 8, 16]}, 'sparse.num_verts': {'values': [4, 8, 16]}, 'sparse.sparseblocksize': {'values': [32, 64, 128]}, 'sparse.vertsize': {'values': [64, 128, 256]}, 'sparse.n_bctx': {'values': [1, 2, 4]}, 'sparse.intermediate_dim': {'values': [0, 64, 128]}}}
Create sweep with ID: t35wi9ov
Sweep URL: https://wandb.ai/arc_agi/gpt-from-scratch/sweeps/t35wi9ov


[34m[1mwandb[0m: Agent Starting Run: hrcytqn4 with config:
[34m[1mwandb[0m: 	sparse.intermediate_dim: 128
[34m[1mwandb[0m: 	sparse.n_bctx: 2
[34m[1mwandb[0m: 	sparse.n_head: 4
[34m[1mwandb[0m: 	sparse.num_verts: 16
[34m[1mwandb[0m: 	sparse.sparseblocksize: 128
[34m[1mwandb[0m: 	sparse.vertsize: 128
Traceback (most recent call last):
  File "c:\Users\teeds\miniconda3\envs\llm_train\lib\site-packages\wandb\agents\pyagent.py", line 297, in _run_job
    self._function()
  File "C:\Users\teeds\AppData\Local\Temp\ipykernel_229868\385059661.py", line 2, in sweep_train
    orig_cfg = OmegaConf.load(args.orig_yaml)  # defaults
NameError: name 'args' is not defined

[34m[1mwandb[0m: [32m[41mERROR[0m Run hrcytqn4 errored: name 'args' is not defined
[34m[1mwandb[0m: Sweep Agent: Waiting for job.
[34m[1mwandb[0m: Job received.
[34m[1mwandb[0m: Agent Starting Run: 4gw97p56 with config:
[34m[1mwandb[0m: 	sparse.intermediate_dim: 64
[34m[1mwandb[0m: 	sparse.n_

In [15]:
for key, value in sweep_cfg['parameters'].items():
    print(key, value)

sparse.n_head {'values': 4}
sparse.num_verts {'values': 4}
sparse.sparseblocksize {'values': 32}
sparse.vertsize {'values': 64}
sparse.n_bctx {'values': 1}
sparse.intermediate_dim {'values': 0}


In [13]:
print(cfg)

{'hyperparams': {'seed': 1337, 'epochs': 7, 'val_frac': 0.1, 'num_titles': 100000, 'vocab_size': 16000, 'context_length': 256, 'model_architecture': 'gpt', 'log_file': './logs/mainrun.log', 'batch_size': 128, 'lr': 0.007, 'weight_decay': 0.0, 'scheduler': 'cosine', 'optimizer': 'adagrad', 'evals_per_epoch': 3}, 'model_configs': {'gpt': {'d_model': 256, 'hidden_layer': 256, 'n_layer': 6, 'dropout': 0.1, 'init_method': 'xavier', 'attention_layer': 'sparse'}, 'unet_gpt': {'d_model': 512, 'hidden_layer': 128, 'n_layer': 6, 'dropout': 0.1, 'init_method': 'xavier', 'attention_layer': 'sparse', 'bottleneck_sizes': [512, 256, 256, 128, 128, 256]}}, 'attn_configs': {'causal': {'n_head': 8, 'intermediate_dim': 0}, 'sparse': {'attn_type': 'fixed', 'n_head': 8, 'num_verts': 8, 'local_attn_ctx': 32, 'sparseblocksize': 64, 'vertsize': 128, 'n_bctx': 2, 'intermediate_dim': 0}}, 'program': 'train.py', 'method': 'random', 'metric': {'name': 'val_loss', 'goal': 'minimize'}, 'parameters': {'sparse.n_head

In [7]:
def merge_dotted_keys(base_dict, update_dict, target_path=None):
    """
    Merge keys with dots into nested dicts.
    If target_path is given, merge inside that nested dict.
    """
    import copy
    merged = copy.deepcopy(base_dict)
    
    # if target_path is provided, get the nested dict
    if target_path:
        d = merged
        for k in target_path:
            d = d.setdefault(k, {})
    else:
        d = merged
    
    for key, value in update_dict.items():
        parts = key.split(".")
        curr = d
        for p in parts[:-1]:
            curr = curr.setdefault(p, {})
        curr[parts[-1]] = value
    return merged


In [9]:
# Suppose cfg_dict comes from OmegaConf.to_container()
sweep_params = {k: v['values'][0] for k, v in cfg_dict['parameters'].items()}  # pick first value for example
hyperparams = cfg['hyperparams']
attn_configs = cfg['attn_configs']

hyperparams = merge_dotted_keys(base_dict=hyperparams, update_dict=sweep_params, target_path=['attn_configs'])


ConfigKeyError: Missing key hyperparams
    full_key: hyperparams
    object_type=dict