In [1]:
# %%
from utils import *
from trainer import Trainer
# %%
device = 'cuda:0'

base_model = HookedTransformer.from_pretrained(
    "pythia-160m-deduped", 
    device=device, 
)

checkpoint_mid_model = HookedTransformer.from_pretrained(
    "EleutherAI/pythia-160m-deduped", 
    device=device, 
    checkpoint_value = 512
)

The `GPTNeoXSdpaAttention` class is deprecated in favor of simply modifying the `config._attn_implementation`attribute of the `GPTNeoXAttention` class! It will be removed in v4.48


Loaded pretrained model pythia-160m-deduped into HookedTransformer
Loaded pretrained model EleutherAI/pythia-160m-deduped into HookedTransformer


In [2]:
import torch
from datasets import load_dataset
from transformers import AutoTokenizer
from tqdm import tqdm

def compile_all_tokens(sequence_length=256, batch_size=512, max_batches=1000, device='cuda:0'):
    """
    Iterates through the pile_dedup_sample dataset and compiles a new all_tokens tensor.

    Args:
        sequence_length (int): The length of each token sequence.
        batch_size (int): The number of sequences per batch.
        max_batches (int, optional): The maximum number of batches to process. Defaults to None.
        device (str): The device to store the tensor on.

    Returns:
        torch.Tensor: A tensor containing all tokenized sequences.
    """
    # Load dataset
    pile_dedup_sample = load_dataset('EleutherAI/the_pile_deduplicated', streaming=True, split='train')

    # Initialize tokenizer
    tokenizer = AutoTokenizer.from_pretrained('EleutherAI/pythia-160m-deduped')

    all_token_batches = []

    current_batch = []
    batch_count = 0

    print("Compiling all_tokens tensor...")
    pbar = tqdm(total=max_batches)
    for sample in pile_dedup_sample:
        # Tokenize the text
        tokens = tokenizer.encode(sample['text'])
        
        # Get as many sequence_length chunks as possible from tokens
        for i in range(0, len(tokens) - sequence_length + 1, sequence_length):
            current_batch.append(tokens[i:i + sequence_length])
            
            # If batch is full, add to all_token_batches
            if len(current_batch) == batch_size:
                batch_tensor = torch.tensor(current_batch, dtype=torch.int32)
                all_token_batches.append(batch_tensor)
                current_batch = []
                batch_count += 1
                pbar.update(1)
                # Check if we've reached the maximum number of batches
                if max_batches and batch_count >= max_batches:
                    break
        
        # Break outer loop if we've hit max_batches
        if max_batches and batch_count >= max_batches:
            break

    pbar.close()

    # Handle the last batch if it's not empty
    if current_batch:
        batch_tensor = torch.tensor(current_batch, dtype=torch.int32)
        all_token_batches.append(batch_tensor)

    # Concatenate all batches into a single tensor
    all_tokens = torch.cat(all_token_batches, dim=0).to(device)

    print(f"Compiled all_tokens tensor with shape: {all_tokens.shape}")
    return all_tokens

In [3]:
actual_toks = compile_all_tokens(batch_size = 128, max_batches = 128)

Resolving data files:   0%|          | 0/1650 [00:00<?, ?it/s]

Compiling all_tokens tensor...


100%|██████████| 128/128 [00:24<00:00,  5.14it/s]

Compiled all_tokens tensor with shape: torch.Size([16384, 256])





In [5]:
256*128*1000

32768000

In [4]:
default_cfg = {
    "seed": 49,
    "batch_size": 128,
    "buffer_mult": 128,
    "lr": 5e-5,
    "num_tokens": 400_000_000,
    "l1_coeff": 2,
    "beta1": 0.9,
    "beta2": 0.999,
    "d_in": base_model.cfg.d_model,
    "dict_size": 768*8*2,
    "seq_len": 256,
    "enc_dtype": "fp32",
    "model_name": "pythia-160m-deduped",
    "site": "resid_pre",
    "device": "cuda:0",
    "model_batch_size": 4,
    "log_every": 100,
    "save_every": 30000,
    "dec_init_norm": 0.08,
    "hook_point": "blocks.5.hook_resid_pre",
    "wandb_project": "crosscoder-fun",
    "wandb_run_name": "toy-run-0",
}
cfg = arg_parse_update_cfg(default_cfg)

trainer = Trainer(cfg, base_model, checkpoint_mid_model, actual_toks)
trainer.train()
# %%

In IPython - skipped argparse


Estimating norm scaling factor: 100%|██████████| 100/100 [00:03<00:00, 25.58it/s]
Estimating norm scaling factor: 100%|██████████| 100/100 [00:03<00:00, 25.01it/s]


Refreshing the buffer!


100%|██████████| 16/16 [00:01<00:00, 11.92it/s]
ERROR:wandb.jupyter:Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mtim_hua[0m. Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.


  0%|          | 10/3125000 [00:00<17:31:17, 49.54it/s]

{'loss': 1196.2391357421875, 'l2_loss': 1196.2391357421875, 'l1_loss': 82.25203704833984, 'l0_loss': 6148.953125, 'l1_coeff': 0.0, 'lr': 5e-05, 'explained_variance': -0.14152106642723083, 'explained_variance_A': -0.3090880513191223, 'explained_variance_B': -0.044093161821365356}


  0%|          | 60/3125000 [00:00<9:44:19, 89.13it/s] 

Refreshing the buffer!


100%|██████████| 8/8 [00:00<00:00, 24.68it/s]
  0%|          | 111/3125000 [00:01<10:54:54, 79.53it/s]

{'loss': 445.2242736816406, 'l2_loss': 444.2995300292969, 'l1_loss': 722.4590454101562, 'l0_loss': 11339.421875, 'l1_coeff': 0.00128, 'lr': 5e-05, 'explained_variance': 0.590605616569519, 'explained_variance_A': 0.45075660943984985, 'explained_variance_B': 0.6657016277313232}


  0%|          | 121/3125000 [00:01<10:28:34, 82.86it/s]

Refreshing the buffer!


100%|██████████| 8/8 [00:00<00:00, 17.91it/s]
  0%|          | 188/3125000 [00:02<11:01:52, 78.69it/s]

Refreshing the buffer!


100%|██████████| 8/8 [00:00<00:00, 16.59it/s]
  0%|          | 216/3125000 [00:03<16:28:43, 52.67it/s]

{'loss': 374.96240234375, 'l2_loss': 372.26580810546875, 'l1_loss': 1053.35693359375, 'l0_loss': 11583.421875, 'l1_coeff': 0.00256, 'lr': 5e-05, 'explained_variance': 0.7047267556190491, 'explained_variance_A': 0.6066105961799622, 'explained_variance_B': 0.7531558275222778}


  0%|          | 246/3125000 [00:04<11:24:30, 76.08it/s]

Refreshing the buffer!


100%|██████████| 8/8 [00:00<00:00, 18.98it/s]
  0%|          | 307/3125000 [00:05<11:02:26, 78.61it/s]

{'loss': 282.74884033203125, 'l2_loss': 277.85748291015625, 'l1_loss': 1273.7899169921875, 'l0_loss': 11831.6328125, 'l1_coeff': 0.00384, 'lr': 5e-05, 'explained_variance': 0.7681727409362793, 'explained_variance_A': 0.681261420249939, 'explained_variance_B': 0.8019849061965942}
Refreshing the buffer!


100%|██████████| 8/8 [00:00<00:00, 18.25it/s]
  0%|          | 376/3125000 [00:06<10:44:19, 80.83it/s]

Refreshing the buffer!


100%|██████████| 8/8 [00:00<00:00, 19.10it/s]
  0%|          | 415/3125000 [00:07<13:34:31, 63.93it/s]

{'loss': 223.29476928710938, 'l2_loss': 216.51205444335938, 'l1_loss': 1324.7491455078125, 'l0_loss': 11907.7265625, 'l1_coeff': 0.00512, 'lr': 5e-05, 'explained_variance': 0.8006809949874878, 'explained_variance_A': 0.7192561626434326, 'explained_variance_B': 0.8477643728256226}


  0%|          | 434/3125000 [00:07<12:07:57, 71.54it/s]

Refreshing the buffer!


100%|██████████| 8/8 [00:00<00:00, 17.02it/s]
  0%|          | 494/3125000 [00:08<11:57:18, 72.60it/s]

{'loss': 208.29763793945312, 'l2_loss': 198.7939910888672, 'l1_loss': 1484.944580078125, 'l0_loss': 11846.265625, 'l1_coeff': 0.0064, 'lr': 5e-05, 'explained_variance': 0.8416138887405396, 'explained_variance_A': 0.782340943813324, 'explained_variance_B': 0.8581240177154541}
Refreshing the buffer!


100%|██████████| 8/8 [00:00<00:00, 17.71it/s]
  0%|          | 562/3125000 [00:09<11:19:41, 76.61it/s]

Refreshing the buffer!


100%|██████████| 8/8 [00:00<00:00, 16.24it/s]
  0%|          | 610/3125000 [00:10<12:53:57, 67.28it/s]

{'loss': 170.83360290527344, 'l2_loss': 159.59698486328125, 'l1_loss': 1463.1014404296875, 'l0_loss': 11847.7890625, 'l1_coeff': 0.00768, 'lr': 5e-05, 'explained_variance': 0.8532363176345825, 'explained_variance_A': 0.7932829260826111, 'explained_variance_B': 0.8856824636459351}


  0%|          | 620/3125000 [00:10<11:43:41, 74.00it/s]

Refreshing the buffer!


100%|██████████| 8/8 [00:00<00:00, 18.32it/s]
  0%|          | 690/3125000 [00:11<10:32:49, 82.28it/s]

Refreshing the buffer!


100%|██████████| 8/8 [00:00<00:00, 19.87it/s]
  0%|          | 718/3125000 [00:12<15:18:49, 56.67it/s]

{'loss': 143.1224365234375, 'l2_loss': 129.58837890625, 'l1_loss': 1510.497802734375, 'l0_loss': 11795.6796875, 'l1_coeff': 0.00896, 'lr': 5e-05, 'explained_variance': 0.8798946142196655, 'explained_variance_A': 0.8330066204071045, 'explained_variance_B': 0.9060719013214111}


  0%|          | 746/3125000 [00:13<11:35:50, 74.83it/s]

Refreshing the buffer!


100%|██████████| 8/8 [00:00<00:00, 18.55it/s]
  0%|          | 815/3125000 [00:14<10:46:11, 80.58it/s]

{'loss': 128.8684539794922, 'l2_loss': 113.14756774902344, 'l1_loss': 1535.242431640625, 'l0_loss': 11729.953125, 'l1_coeff': 0.01024, 'lr': 5e-05, 'explained_variance': 0.8980711698532104, 'explained_variance_A': 0.8574055433273315, 'explained_variance_B': 0.9199135303497314}
Refreshing the buffer!


100%|██████████| 8/8 [00:00<00:00, 16.47it/s]
  0%|          | 874/3125000 [00:15<11:15:50, 77.04it/s]

Refreshing the buffer!


100%|██████████| 8/8 [00:00<00:00, 18.26it/s]
  0%|          | 914/3125000 [00:16<13:46:02, 63.03it/s]

{'loss': 110.98519134521484, 'l2_loss': 93.09612274169922, 'l1_loss': 1552.87060546875, 'l0_loss': 11711.8984375, 'l1_coeff': 0.01152, 'lr': 5e-05, 'explained_variance': 0.9132499694824219, 'explained_variance_A': 0.8797569274902344, 'explained_variance_B': 0.9307528734207153}


  0%|          | 944/3125000 [00:16<12:13:52, 70.95it/s]

Refreshing the buffer!


100%|██████████| 8/8 [00:00<00:00, 16.16it/s]
  0%|          | 1003/3125000 [00:17<12:37:25, 68.74it/s]

{'loss': 108.40829467773438, 'l2_loss': 88.49311828613281, 'l1_loss': 1555.87353515625, 'l0_loss': 11600.8984375, 'l1_coeff': 0.0128, 'lr': 5e-05, 'explained_variance': 0.9183162450790405, 'explained_variance_A': 0.8912547826766968, 'explained_variance_B': 0.9322361350059509}
Refreshing the buffer!


100%|██████████| 8/8 [00:00<00:00, 19.03it/s]
  0%|          | 1069/3125000 [00:18<11:39:54, 74.39it/s]

Refreshing the buffer!


100%|██████████| 8/8 [00:00<00:00, 17.62it/s]
  0%|          | 1118/3125000 [00:19<13:13:28, 65.62it/s]

{'loss': 105.28305053710938, 'l2_loss': 82.33505249023438, 'l1_loss': 1629.8292236328125, 'l0_loss': 11448.546875, 'l1_coeff': 0.01408, 'lr': 5e-05, 'explained_variance': 0.9319586753845215, 'explained_variance_A': 0.9131718873977661, 'explained_variance_B': 0.937177300453186}


  0%|          | 1128/3125000 [00:20<11:57:43, 72.54it/s]

Refreshing the buffer!


100%|██████████| 8/8 [00:00<00:00, 16.80it/s]
  0%|          | 1187/3125000 [00:21<11:17:33, 76.84it/s]

Refreshing the buffer!


100%|██████████| 8/8 [00:00<00:00, 17.66it/s]
  0%|          | 1218/3125000 [00:21<15:18:02, 56.71it/s]

{'loss': 95.05500030517578, 'l2_loss': 69.89958190917969, 'l1_loss': 1637.72265625, 'l0_loss': 11321.4296875, 'l1_coeff': 0.01536, 'lr': 5e-05, 'explained_variance': 0.9398634433746338, 'explained_variance_A': 0.9261484146118164, 'explained_variance_B': 0.9451087713241577}


  0%|          | 1258/3125000 [00:22<10:41:56, 81.10it/s]

Refreshing the buffer!


100%|██████████| 8/8 [00:00<00:00, 15.96it/s]
  0%|          | 1318/3125000 [00:23<11:38:18, 74.55it/s]

{'loss': 84.62615966796875, 'l2_loss': 59.290550231933594, 'l1_loss': 1522.572509765625, 'l0_loss': 11218.109375, 'l1_coeff': 0.01664, 'lr': 5e-05, 'explained_variance': 0.9455729722976685, 'explained_variance_A': 0.9297982454299927, 'explained_variance_B': 0.9535121321678162}
Refreshing the buffer!


100%|██████████| 8/8 [00:00<00:00, 19.15it/s]
  0%|          | 1384/3125000 [00:24<9:02:18, 96.00it/s] 

Refreshing the buffer!


100%|██████████| 8/8 [00:00<00:00, 27.78it/s]
  0%|          | 1425/3125000 [00:24<9:58:53, 86.93it/s] 

{'loss': 76.20369720458984, 'l2_loss': 50.09722900390625, 'l1_loss': 1456.833984375, 'l0_loss': 10993.1015625, 'l1_coeff': 0.01792, 'lr': 5e-05, 'explained_variance': 0.9531610012054443, 'explained_variance_A': 0.9407783150672913, 'explained_variance_B': 0.9591221213340759}


  0%|          | 1439/3125000 [00:25<8:57:25, 96.87it/s]

Refreshing the buffer!


100%|██████████| 8/8 [00:00<00:00, 14.50it/s]
  0%|          | 1506/3125000 [00:26<11:39:48, 74.39it/s]

{'loss': 88.17947387695312, 'l2_loss': 58.55598068237305, 'l1_loss': 1542.89013671875, 'l0_loss': 10851.21875, 'l1_coeff': 0.0192, 'lr': 5e-05, 'explained_variance': 0.9521947503089905, 'explained_variance_A': 0.9452762603759766, 'explained_variance_B': 0.9498292803764343}
Refreshing the buffer!


100%|██████████| 8/8 [00:00<00:00, 23.53it/s]
  0%|          | 1571/3125000 [00:27<9:00:15, 96.36it/s] 

Refreshing the buffer!


100%|██████████| 8/8 [00:00<00:00, 24.74it/s]
  0%|          | 1611/3125000 [00:27<10:31:27, 82.44it/s]

{'loss': 76.43123626708984, 'l2_loss': 46.04279708862305, 'l1_loss': 1483.8106689453125, 'l0_loss': 10791.984375, 'l1_coeff': 0.02048, 'lr': 5e-05, 'explained_variance': 0.9579405784606934, 'explained_variance_A': 0.9513286352157593, 'explained_variance_B': 0.9610964059829712}


  0%|          | 1633/3125000 [00:28<10:40:50, 81.23it/s]

Refreshing the buffer!


100%|██████████| 8/8 [00:00<00:00, 15.16it/s]
  0%|          | 1695/3125000 [00:29<12:12:59, 71.02it/s]

Refreshing the buffer!


100%|██████████| 8/8 [00:00<00:00, 15.96it/s]
  0%|          | 1712/3125000 [00:30<22:53:00, 37.91it/s]

{'loss': 78.07183074951172, 'l2_loss': 45.22510528564453, 'l1_loss': 1509.500244140625, 'l0_loss': 10602.6171875, 'l1_coeff': 0.02176, 'lr': 5e-05, 'explained_variance': 0.9614866971969604, 'explained_variance_A': 0.9585089683532715, 'explained_variance_B': 0.959269642829895}


  0%|          | 1761/3125000 [00:30<11:33:28, 75.06it/s]

Refreshing the buffer!


100%|██████████| 8/8 [00:00<00:00, 16.24it/s]
  0%|          | 1819/3125000 [00:32<11:49:48, 73.33it/s]

{'loss': 79.27752685546875, 'l2_loss': 44.69355773925781, 'l1_loss': 1501.04052734375, 'l0_loss': 10418.9296875, 'l1_coeff': 0.02304, 'lr': 5e-05, 'explained_variance': 0.964844822883606, 'explained_variance_A': 0.9624900817871094, 'explained_variance_B': 0.9608471393585205}
Refreshing the buffer!


100%|██████████| 8/8 [00:00<00:00, 16.82it/s]
  0%|          | 1886/3125000 [00:33<11:01:23, 78.70it/s]

Refreshing the buffer!


100%|██████████| 8/8 [00:00<00:00, 18.93it/s]
  0%|          | 1916/3125000 [00:33<15:30:01, 55.97it/s]

{'loss': 72.38909912109375, 'l2_loss': 36.85110855102539, 'l1_loss': 1461.265869140625, 'l0_loss': 10219.90625, 'l1_coeff': 0.02432, 'lr': 5e-05, 'explained_variance': 0.9678746461868286, 'explained_variance_A': 0.9654660224914551, 'explained_variance_B': 0.9672067165374756}


  0%|          | 1944/3125000 [00:34<11:43:42, 73.97it/s]

Refreshing the buffer!


100%|██████████| 8/8 [00:00<00:00, 15.89it/s]
  0%|          | 2013/3125000 [00:35<10:46:53, 80.46it/s]

{'loss': 84.3740005493164, 'l2_loss': 47.00244140625, 'l1_loss': 1459.8265380859375, 'l0_loss': 10101.1796875, 'l1_coeff': 0.0256, 'lr': 5e-05, 'explained_variance': 0.9704200029373169, 'explained_variance_A': 0.9699584245681763, 'explained_variance_B': 0.9558425545692444}
Refreshing the buffer!


100%|██████████| 8/8 [00:00<00:00, 18.40it/s]
  0%|          | 2072/3125000 [00:36<11:38:08, 74.55it/s]

Refreshing the buffer!


100%|██████████| 8/8 [00:00<00:00, 16.65it/s]
  0%|          | 2119/3125000 [00:37<12:57:49, 66.91it/s]

{'loss': 75.4444580078125, 'l2_loss': 37.754051208496094, 'l1_loss': 1402.1728515625, 'l0_loss': 9886.4921875, 'l1_coeff': 0.02688, 'lr': 5e-05, 'explained_variance': 0.9743948578834534, 'explained_variance_A': 0.9730702638626099, 'explained_variance_B': 0.9664196968078613}


  0%|          | 2138/3125000 [00:37<11:11:38, 77.49it/s]

Refreshing the buffer!


100%|██████████| 8/8 [00:00<00:00, 16.36it/s]
  0%|          | 2196/3125000 [00:38<12:12:03, 71.10it/s]

{'loss': 65.39859008789062, 'l2_loss': 29.175479888916016, 'l1_loss': 1286.33203125, 'l0_loss': 9749.90625, 'l1_coeff': 0.02816, 'lr': 5e-05, 'explained_variance': 0.9732116460800171, 'explained_variance_A': 0.9710004329681396, 'explained_variance_B': 0.9735487699508667}
Refreshing the buffer!


100%|██████████| 8/8 [00:00<00:00, 18.03it/s]
  0%|          | 2265/3125000 [00:40<11:08:57, 77.80it/s]

Refreshing the buffer!


100%|██████████| 8/8 [00:00<00:00, 18.42it/s]
  0%|          | 2314/3125000 [00:41<13:01:23, 66.61it/s]

{'loss': 74.16581726074219, 'l2_loss': 31.922100067138672, 'l1_loss': 1434.908935546875, 'l0_loss': 9565.921875, 'l1_coeff': 0.02944, 'lr': 5e-05, 'explained_variance': 0.9778450727462769, 'explained_variance_A': 0.979667067527771, 'explained_variance_B': 0.9699521064758301}


  0%|          | 2323/3125000 [00:41<12:05:19, 71.75it/s]

Refreshing the buffer!


100%|██████████| 8/8 [00:00<00:00, 16.78it/s]
  0%|          | 2390/3125000 [00:42<11:13:29, 77.27it/s]

Refreshing the buffer!


100%|██████████| 8/8 [00:00<00:00, 18.72it/s]
  0%|          | 2419/3125000 [00:43<15:37:35, 55.51it/s]

{'loss': 64.96277618408203, 'l2_loss': 26.388341903686523, 'l1_loss': 1255.67822265625, 'l0_loss': 9403.28125, 'l1_coeff': 0.03072, 'lr': 5e-05, 'explained_variance': 0.975999116897583, 'explained_variance_A': 0.9741057753562927, 'explained_variance_B': 0.9762933254241943}


  0%|          | 2456/3125000 [00:43<12:48:53, 67.68it/s]

Refreshing the buffer!


100%|██████████| 8/8 [00:00<00:00, 16.76it/s]
  0%|          | 2510/3125000 [00:44<12:14:55, 70.81it/s]

{'loss': 62.888763427734375, 'l2_loss': 23.43719482421875, 'l1_loss': 1232.8614501953125, 'l0_loss': 9296.375, 'l1_coeff': 0.032, 'lr': 5e-05, 'explained_variance': 0.9782857894897461, 'explained_variance_A': 0.977840781211853, 'explained_variance_B': 0.9777477979660034}
Refreshing the buffer!


100%|██████████| 8/8 [00:00<00:00, 16.52it/s]
  0%|          | 2581/3125000 [00:46<10:33:38, 82.13it/s]

Refreshing the buffer!


100%|██████████| 8/8 [00:00<00:00, 23.87it/s]
  0%|          | 2619/3125000 [00:46<11:07:45, 77.93it/s]

{'loss': 62.67619705200195, 'l2_loss': 22.9033203125, 'l1_loss': 1195.0985107421875, 'l0_loss': 9069.921875, 'l1_coeff': 0.03328, 'lr': 5e-05, 'explained_variance': 0.9788900017738342, 'explained_variance_A': 0.97809898853302, 'explained_variance_B': 0.9786554574966431}


  0%|          | 2633/3125000 [00:46<9:31:55, 90.99it/s] 

Refreshing the buffer!


100%|██████████| 8/8 [00:00<00:00, 20.35it/s]
  0%|          | 2705/3125000 [00:47<10:53:26, 79.64it/s]

{'loss': 62.2692985534668, 'l2_loss': 21.223968505859375, 'l1_loss': 1187.654296875, 'l0_loss': 8928.1953125, 'l1_coeff': 0.03456, 'lr': 5e-05, 'explained_variance': 0.9806457757949829, 'explained_variance_A': 0.9801451563835144, 'explained_variance_B': 0.9803112745285034}
Refreshing the buffer!


100%|██████████| 8/8 [00:00<00:00, 16.48it/s]
  0%|          | 2763/3125000 [00:48<11:49:26, 73.35it/s]

Refreshing the buffer!


100%|██████████| 8/8 [00:00<00:00, 18.97it/s]
  0%|          | 2811/3125000 [00:49<12:26:58, 69.66it/s]

{'loss': 59.92285919189453, 'l2_loss': 19.132579803466797, 'l1_loss': 1138.12158203125, 'l0_loss': 8729.6484375, 'l1_coeff': 0.03584, 'lr': 5e-05, 'explained_variance': 0.982133150100708, 'explained_variance_A': 0.9802696108818054, 'explained_variance_B': 0.9824930429458618}


  0%|          | 2831/3125000 [00:50<10:37:54, 81.57it/s]

Refreshing the buffer!


100%|██████████| 8/8 [00:00<00:00, 17.82it/s]
  0%|          | 2891/3125000 [00:51<11:15:53, 76.99it/s]

Refreshing the buffer!


100%|██████████| 8/8 [00:00<00:00, 15.97it/s]
  0%|          | 2911/3125000 [00:51<19:58:40, 43.41it/s]

{'loss': 66.20458984375, 'l2_loss': 20.71979522705078, 'l1_loss': 1225.3448486328125, 'l0_loss': 8626.7734375, 'l1_coeff': 0.03712, 'lr': 5e-05, 'explained_variance': 0.9829503893852234, 'explained_variance_A': 0.9829859733581543, 'explained_variance_B': 0.98067307472229}


  0%|          | 2954/3125000 [00:52<11:02:39, 78.52it/s]

Refreshing the buffer!


100%|██████████| 8/8 [00:00<00:00, 18.76it/s]
  0%|          | 3012/3125000 [00:53<11:47:37, 73.53it/s]

{'loss': 65.09286499023438, 'l2_loss': 20.955669403076172, 'l1_loss': 1149.406005859375, 'l0_loss': 8505.3515625, 'l1_coeff': 0.0384, 'lr': 5e-05, 'explained_variance': 0.9807083606719971, 'explained_variance_A': 0.9803259372711182, 'explained_variance_B': 0.9803898334503174}


  0%|          | 3021/3125000 [00:53<11:52:37, 73.02it/s]

Refreshing the buffer!


100%|██████████| 8/8 [00:00<00:00, 11.35it/s]
  0%|          | 3078/3125000 [00:55<14:21:04, 60.43it/s]

Refreshing the buffer!


100%|██████████| 8/8 [00:00<00:00, 17.37it/s]
  0%|          | 3115/3125000 [00:55<15:15:24, 56.84it/s]

{'loss': 76.85281372070312, 'l2_loss': 23.97357940673828, 'l1_loss': 1332.6419677734375, 'l0_loss': 8292.4453125, 'l1_coeff': 0.03968, 'lr': 5e-05, 'explained_variance': 0.9840060472488403, 'explained_variance_A': 0.9869582653045654, 'explained_variance_B': 0.9769103527069092}


  0%|          | 3144/3125000 [00:56<11:18:32, 76.68it/s]

Refreshing the buffer!


100%|██████████| 8/8 [00:00<00:00, 16.18it/s]
  0%|          | 3211/3125000 [00:57<10:53:17, 79.64it/s]

{'loss': 64.2777099609375, 'l2_loss': 20.178346633911133, 'l1_loss': 1076.6446533203125, 'l0_loss': 8149.4453125, 'l1_coeff': 0.04096, 'lr': 5e-05, 'explained_variance': 0.9812037944793701, 'explained_variance_A': 0.9790123701095581, 'explained_variance_B': 0.9819427728652954}
Refreshing the buffer!


100%|██████████| 8/8 [00:00<00:00, 17.90it/s]
  0%|          | 3270/3125000 [00:58<11:20:42, 76.43it/s]

Refreshing the buffer!


100%|██████████| 8/8 [00:00<00:00, 16.27it/s]
  0%|          | 3319/3125000 [00:59<13:00:51, 66.63it/s]

{'loss': 62.970916748046875, 'l2_loss': 17.99297523498535, 'l1_loss': 1064.818603515625, 'l0_loss': 8010.140625, 'l1_coeff': 0.04224, 'lr': 5e-05, 'explained_variance': 0.9832982420921326, 'explained_variance_A': 0.981763482093811, 'explained_variance_B': 0.9837067127227783}


  0%|          | 3330/3125000 [00:59<11:32:40, 75.11it/s]

Refreshing the buffer!


100%|██████████| 8/8 [00:00<00:00, 17.90it/s]
  0%|          | 3399/3125000 [01:00<10:37:10, 81.65it/s]

{'loss': 75.37968444824219, 'l2_loss': 26.67446517944336, 'l1_loss': 1119.145751953125, 'l0_loss': 7913.953125, 'l1_coeff': 0.04352, 'lr': 5e-05, 'explained_variance': 0.9829422831535339, 'explained_variance_A': 0.9826867580413818, 'explained_variance_B': 0.975132405757904}
Refreshing the buffer!


100%|██████████| 8/8 [00:00<00:00, 15.66it/s]
  0%|          | 3460/3125000 [01:02<11:24:27, 76.01it/s]

Refreshing the buffer!


100%|██████████| 8/8 [00:00<00:00, 16.84it/s]
  0%|          | 3510/3125000 [01:03<12:24:18, 69.90it/s]

{'loss': 67.4717788696289, 'l2_loss': 18.937278747558594, 'l1_loss': 1083.359375, 'l0_loss': 7736.890625, 'l1_coeff': 0.0448, 'lr': 5e-05, 'explained_variance': 0.9843300580978394, 'explained_variance_A': 0.9831665754318237, 'explained_variance_B': 0.9829293489456177}


  0%|          | 3521/3125000 [01:03<11:08:07, 77.87it/s]

Refreshing the buffer!


100%|██████████| 8/8 [00:00<00:00, 17.50it/s]
  0%|          | 3590/3125000 [01:04<10:52:33, 79.72it/s]

Refreshing the buffer!


100%|██████████| 8/8 [00:00<00:00, 16.68it/s]
  0%|          | 3619/3125000 [01:05<16:14:48, 53.37it/s]

{'loss': 66.74794006347656, 'l2_loss': 17.36458396911621, 'l1_loss': 1071.6873779296875, 'l0_loss': 7593.5859375, 'l1_coeff': 0.04608, 'lr': 5e-05, 'explained_variance': 0.9849069118499756, 'explained_variance_A': 0.9839762449264526, 'explained_variance_B': 0.9842144250869751}


  0%|          | 3648/3125000 [01:05<11:35:29, 74.80it/s]

Refreshing the buffer!


100%|██████████| 8/8 [00:00<00:00, 19.05it/s]
  0%|          | 3712/3125000 [01:06<11:23:55, 76.06it/s]

{'loss': 63.79872512817383, 'l2_loss': 17.244823455810547, 'l1_loss': 982.9793701171875, 'l0_loss': 7455.9140625, 'l1_coeff': 0.04736, 'lr': 5e-05, 'explained_variance': 0.9838117361068726, 'explained_variance_A': 0.9811895489692688, 'explained_variance_B': 0.9848067760467529}
Refreshing the buffer!


100%|██████████| 8/8 [00:00<00:00, 10.45it/s]
  0%|          | 3779/3125000 [01:08<14:11:49, 61.07it/s]

Refreshing the buffer!


100%|██████████| 8/8 [00:00<00:00, 19.79it/s]
  0%|          | 3816/3125000 [01:09<14:12:42, 61.01it/s]

{'loss': 66.96625518798828, 'l2_loss': 18.287511825561523, 'l1_loss': 1000.7964477539062, 'l0_loss': 7367.0078125, 'l1_coeff': 0.04864, 'lr': 5e-05, 'explained_variance': 0.9830179810523987, 'explained_variance_A': 0.9810871481895447, 'explained_variance_B': 0.9835548400878906}


  0%|          | 3837/3125000 [01:09<11:07:13, 77.96it/s]

Refreshing the buffer!


100%|██████████| 8/8 [00:00<00:00, 14.56it/s]
  0%|          | 3902/3125000 [01:10<13:27:54, 64.39it/s]

{'loss': 66.82257080078125, 'l2_loss': 17.41907501220703, 'l1_loss': 989.6533203125, 'l0_loss': 7222.4921875, 'l1_coeff': 0.04992, 'lr': 5e-05, 'explained_variance': 0.9840904474258423, 'explained_variance_A': 0.9820393323898315, 'explained_variance_B': 0.9846476316452026}
Refreshing the buffer!


100%|██████████| 8/8 [00:00<00:00, 12.93it/s]
  0%|          | 3966/3125000 [01:12<15:05:27, 57.45it/s]

Refreshing the buffer!


100%|██████████| 8/8 [00:00<00:00, 14.24it/s]
  0%|          | 4019/3125000 [01:13<12:52:52, 67.30it/s]

{'loss': 74.26235961914062, 'l2_loss': 17.999221801757812, 'l1_loss': 1098.889404296875, 'l0_loss': 7132.4140625, 'l1_coeff': 0.0512, 'lr': 5e-05, 'explained_variance': 0.9853003025054932, 'explained_variance_A': 0.9854974746704102, 'explained_variance_B': 0.9838504791259766}


  0%|          | 4029/3125000 [01:13<11:43:51, 73.90it/s]

Refreshing the buffer!


100%|██████████| 8/8 [00:00<00:00, 17.95it/s]
  0%|          | 4088/3125000 [01:14<11:24:34, 75.98it/s]

Refreshing the buffer!


100%|██████████| 8/8 [00:00<00:00, 20.23it/s]
  0%|          | 4117/3125000 [01:15<16:34:32, 52.30it/s]

{'loss': 66.85575866699219, 'l2_loss': 16.087154388427734, 'l1_loss': 967.3894653320312, 'l0_loss': 6987.875, 'l1_coeff': 0.05248, 'lr': 5e-05, 'explained_variance': 0.9853802919387817, 'explained_variance_A': 0.9834029674530029, 'explained_variance_B': 0.9860292673110962}


  0%|          | 4154/3125000 [01:16<15:06:05, 57.41it/s]

Refreshing the buffer!


100%|██████████| 8/8 [00:00<00:00, 13.59it/s]
  0%|          | 4219/3125000 [01:17<11:02:11, 78.55it/s]

{'loss': 71.39600372314453, 'l2_loss': 18.13732147216797, 'l1_loss': 990.6749267578125, 'l0_loss': 6850.9453125, 'l1_coeff': 0.05376, 'lr': 5e-05, 'explained_variance': 0.9853328466415405, 'explained_variance_A': 0.9837422370910645, 'explained_variance_B': 0.9842305183410645}
Refreshing the buffer!


100%|██████████| 8/8 [00:00<00:00, 18.88it/s]
  0%|          | 4283/3125000 [01:18<9:58:05, 86.96it/s] 

Refreshing the buffer!


100%|██████████| 8/8 [00:00<00:00, 32.48it/s]
  0%|          | 4322/3125000 [01:18<10:01:08, 86.52it/s]

{'loss': 64.45153045654297, 'l2_loss': 15.300433158874512, 'l1_loss': 893.0069580078125, 'l0_loss': 6742.8046875, 'l1_coeff': 0.05504, 'lr': 5e-05, 'explained_variance': 0.9854601621627808, 'explained_variance_A': 0.9827118515968323, 'explained_variance_B': 0.9864538908004761}


  0%|          | 4336/3125000 [01:19<8:53:47, 97.44it/s] 

Refreshing the buffer!


100%|██████████| 8/8 [00:00<00:00, 14.83it/s]
  0%|          | 4403/3125000 [01:20<14:48:29, 58.54it/s]

{'loss': 72.37870788574219, 'l2_loss': 17.371700286865234, 'l1_loss': 976.6868896484375, 'l0_loss': 6639.3984375, 'l1_coeff': 0.05632, 'lr': 5e-05, 'explained_variance': 0.9848951101303101, 'explained_variance_A': 0.983425498008728, 'explained_variance_B': 0.9847656488418579}
Refreshing the buffer!


  0%|          | 0/8 [00:00<?, ?it/s]
  0%|          | 4409/3125000 [01:20<15:52:50, 54.58it/s]


FileNotFoundError: [Errno 2] No such file or directory: '/workspace/crosscoder-model-diff-replication/checkpoints'

In [4]:
?load_dataset

[0;31mSignature:[0m
[0mload_dataset[0m[0;34m([0m[0;34m[0m
[0;34m[0m    [0mpath[0m[0;34m:[0m [0mstr[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mname[0m[0;34m:[0m [0mOptional[0m[0;34m[[0m[0mstr[0m[0;34m][0m [0;34m=[0m [0;32mNone[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mdata_dir[0m[0;34m:[0m [0mOptional[0m[0;34m[[0m[0mstr[0m[0;34m][0m [0;34m=[0m [0;32mNone[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mdata_files[0m[0;34m:[0m [0mUnion[0m[0;34m[[0m[0mstr[0m[0;34m,[0m [0mSequence[0m[0;34m[[0m[0mstr[0m[0;34m][0m[0;34m,[0m [0mMapping[0m[0;34m[[0m[0mstr[0m[0;34m,[0m [0mUnion[0m[0;34m[[0m[0mstr[0m[0;34m,[0m [0mSequence[0m[0;34m[[0m[0mstr[0m[0;34m][0m[0;34m][0m[0;34m][0m[0;34m,[0m [0mNoneType[0m[0;34m][0m [0;34m=[0m [0;32mNone[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0msplit[0m[0;34m:[0m [0mUnion[0m[0;34m[[0m[0mstr[0m[0;34m,[0m [0mdatasets[0m[0;34m.[0m[0msplits[0m[0;34m