In [31]:
import data
import data_hf
from modelling import model
import jax.numpy as jnp
import jax
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
# Wide, compact NumPy printing: up to 30 edge items per axis, effectively no
# line wrapping, and floats rendered with 3 significant digits.
np.set_printoptions(edgeitems=30, linewidth=100000, 
    formatter=dict(float=lambda x: "%.3g" % x))

# Development convenience: re-import the project modules and reload each one so
# that edits made on disk are picked up without restarting the kernel.
from importlib import reload
from tqdm import tqdm
model = reload(model)
import data
import data_hf
data = reload(data)
data_hf = reload(data_hf)
import data_shae
data_shae = reload(data_shae)
import finetune
finetune = reload(finetune)
import download_data
download_data = reload(download_data)

import numpy as np
import pandas as pd
import umap
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.preprocessing import LabelEncoder
from google.cloud import storage

In [11]:
# Base-model configuration. Keyword arguments are grouped by theme; every value
# is the same as before, only the ordering and commentary differ.
cfg = model.Config(
    # Layer / attention dimensions.
    num_layers=12,
    d_model=2048,
    ffw_multiplier=4,
    query_heads=8,
    key_heads=8,
    key_dim=128,
    # Token space and context length.
    vocab_size=8,
    max_seq_len=8192,
    # Attention options.
    causal=True,
    use_attn_kernel=True,
    # Parameter dtypes.
    weight_dtype_at_rest=jnp.float32,
    active_weight_dtype=jnp.bfloat16,
    # Device mesh and the sharding rules used with it.
    mesh=model.create_mesh(),
    rules=model.fsdp_rules,
    # Learning-rate / step-count fields.
    max_lr=3e-5,
    min_lr=3e-6,
    warmup_steps=50,
    total_steps=10000,
    # NOTE(review): presumably this is what makes the forward pass expose
    # `layer_6_activations` in its internals -- confirm in `model.forward`.
    return_sae_intermediates=True,
)

# Restore the pretrained weights and optimizer state from the checkpoint
# directory, and remember the latest step recorded by the manager.
checkpoint_dir = "gs://minformer_data/pretrained_ckpt/v1"
ckpt_manager = model.make_mngr(path=checkpoint_dir)
weights, opt_state = model.load(ckpt_manager, cfg)
start_step = ckpt_manager.latest_step()



In [12]:
# Data pipeline: only stage-2 TFRecord shards are used; stage 1 is left empty.
batch_size = 8
stage_2 = [ "gs://minformer_data/shae_8k/tfrecords/record_*.tfrecord"]
# NOTE(review): `iter` shadows the Python builtin of the same name. The name is
# kept because later cells read it; rename it everywhere at once if desired.
iter = data_shae.create_iterator(
    stage_1=[], stage_2=stage_2, batch_size=batch_size, shuffle=True
)
# Alias for the batch-processing function from `model`.
process_batch = model.process_batch_shae

In [13]:
# Pull a single batch from the iterator (this consumes it from the stream).
batch = next(iter)

Found 0 files for stage 1
Found 2868 files for stage 2


In [14]:
from functools import partial

def fwd(weights, x, segment_ids):
    """Run the base model and pick out the output at the last non-padding position.

    Args:
        weights: Model weights handed to `model.forward`.
        x: Model inputs handed to `model.forward`.
        segment_ids: Per-position segment ids; positions with an id > 0 are
            counted as valid.

    Returns:
        A pair `(last_xs, internals)`: the third output of `model.forward`
        gathered along axis 1 at index `(number of valid positions - 1)`,
        and the internals returned by `model.forward`.
    """
    _unused, internals, hidden = model.forward(x, segment_ids, weights, cfg)
    # Count of positions per row whose segment id is positive.
    valid_count = jnp.sum(segment_ids > 0, axis=-1)
    # Zero-based index of the last valid position, shaped to broadcast against
    # a rank-3 array when gathering along axis 1.
    last_index = valid_count[:, None, None] - 1
    last_hidden = jnp.take_along_axis(hidden, last_index, 1)
    return last_hidden, internals

def input_shardings(
    mesh, rules
) -> dict[str, jax.sharding.NamedSharding]:
    """Build the shardings used to place a model input batch on the mesh.

    Args:
        mesh: Device mesh passed through to `model._logical_to_sharding`.
        rules: Sharding rules passed through to `model._logical_to_sharding`.

    Returns:
        A dict with keys "x" and "segment_ids", each holding the result of
        mapping the logical axes ("batch", "sequence") through
        `model._logical_to_sharding`. The structure matches the batch dict
        that is handed to `jax.device_put`.
    """
    logical_axes = {
        "x": model.P("batch", "sequence"),
        "segment_ids": model.P("batch", "sequence"),
    }
    physical_axes = jax.tree.map(partial(model._logical_to_sharding, mesh=mesh, rules=rules), logical_axes)
    return physical_axes


# Rebind `fwd` to its jit-compiled version; later cells call the compiled function.
fwd = jax.jit(fwd)

In [15]:
# Scratch cell: display the `jax.random.PRNGKey` signature. Has no effect on later cells.
jax.random.PRNGKey

<function jax._src.random.PRNGKey(seed: 'int | ArrayLike', *, impl: 'PRNGSpecDesc | None' = None) -> 'KeyArray'>

In [27]:
# Make an SAE: two float32 matrices, d_model -> features and features -> d_model.

features = 4096

# One He-normal initializer shared by both matrices; each gets its own fixed key.
_he_init = jax.nn.initializers.he_normal(in_axis=0, out_axis=1)
expand = _he_init(jax.random.PRNGKey(0), (cfg.d_model, features), jnp.float32)
contract = _he_init(jax.random.PRNGKey(1), (features, cfg.d_model), jnp.float32)

# Weight on the L1 term of the SAE loss, and the learning rate handed to the
# optimizer update.
l1_coeff = 1e-3
lr = 3e-4

sae_weights = {'expand': expand, 'contract': contract}

def sae_shardings(
    mesh, rules
) -> dict[str, jax.sharding.NamedSharding]:
    """Build the shardings used to place the SAE weight dict on the mesh.

    Args:
        mesh: Device mesh passed through to `model._logical_to_sharding`.
        rules: Sharding rules passed through to `model._logical_to_sharding`.

    Returns:
        A dict with keys "expand" and "contract" (the same keys as the SAE
        weight dict), each holding the result of mapping that matrix's logical
        axes through `model._logical_to_sharding`.
    """
    # NOTE(review): the d_model dimension of each matrix is tagged with the
    # "batch" logical axis -- presumably to reuse the mesh axis that "batch"
    # maps to under `rules`; confirm this is intended.
    logical_axes = {
        "expand": model.P("batch", "ffw"),
        "contract": model.P("ffw", "batch"),
    }
    physical_axes = jax.tree.map(partial(model._logical_to_sharding, mesh=mesh, rules=rules), logical_axes)
    return physical_axes

# Place the SAE weights on devices according to `sae_shardings`, then create
# the matching optimizer state with the project's helper.
sae_weights = jax.device_put(sae_weights, sae_shardings(cfg.mesh, cfg.rules))
sae_opt_state = model.init_optimizer_state(sae_weights)

def fwd_sae(sae_weights, activations):
    """Compute the sparse-autoencoder loss for a batch of activations.

    Args:
        sae_weights: Dict with an 'expand' matrix (d_model x features) and a
            'contract' matrix (features x d_model).
        activations: Array whose trailing dimension is `cfg.d_model`
            (annotated [B, T, D] by the original author); it is flattened to
            rows of size `cfg.d_model`.

    Returns:
        `(loss, aux)` where `loss = reconstruction_loss + l1_loss` and `aux`
        holds 'latents', 'reconstruction_loss' and 'l1_loss'.
    """
    flat = activations.reshape(-1, cfg.d_model)  # [B*T, D]
    # Encode: linear map followed by ReLU.
    hidden = jax.nn.relu(jnp.einsum('bd,df->bf', flat, sae_weights['expand']))
    # Decode back to the model dimension.
    recon = jnp.einsum('bf,fd->bd', hidden, sae_weights['contract'])
    residual = recon - flat
    reconstruction_loss = jnp.mean(residual ** 2)
    # NOTE(review): this term is a *sum* over every row and feature while the
    # reconstruction term is a *mean*, so their relative weight depends on the
    # number of rows -- confirm this scaling is intended.
    l1_loss = l1_coeff * jnp.sum(hidden)
    aux = {
        'latents': hidden,
        'reconstruction_loss': reconstruction_loss,
        'l1_loss': l1_loss,
    }
    return reconstruction_loss + l1_loss, aux

# `fwd_sae` returns `(loss, aux)`, so `has_aux=True` is required: without it
# `jax.value_and_grad` would try to differentiate the whole tuple and fail
# when called. With it, the call returns `((loss, aux), grads)`.
grad_sae = jax.value_and_grad(fwd_sae, has_aux=True)

def update_weights_sae(weights, opt_state, activations, step):
    """Take one optimizer step on the SAE weights.

    Args:
        weights: Current SAE weight dict.
        opt_state: Optimizer state matching `weights`.
        activations: Activations fed to `fwd_sae`.
        step: Step value forwarded to `model.update_weights`.

    Returns:
        `(weights, opt_state, loss, internals)`: the updated weights and
        optimizer state, the scalar loss, and the aux dict from `fwd_sae`.
    """
    loss_and_grad = jax.value_and_grad(fwd_sae, has_aux=True)
    (loss, internals), grads = loss_and_grad(weights, activations)
    # `model.update_weights` returns three values; the third is not needed here.
    updated = model.update_weights(weights, grads, opt_state, lr, step, cfg, {})
    new_weights, new_opt_state, _ = updated
    return new_weights, new_opt_state, loss, internals
    

# Jit-compiled SAE update step, used by the training loop.
update_weights_fn = jax.jit(update_weights_sae)


In [14]:
# Train the SAE on layer-6 activations of the base model; the base `weights`
# are only read here, never updated.
# The input shardings are the same for every batch, so build them once.
batch_shardings = input_shardings(cfg.mesh, cfg.rules)
# Pass an advancing step to the optimizer. Counting starts at 1, the value
# that was previously passed (unchanged) on every iteration.
for step, batch in enumerate(iter, start=1):
    batch = jax.device_put({'x': batch['x'], 'segment_ids': batch['segment_ids']}, batch_shardings)
    _, internals = fwd(weights, batch['x'], batch['segment_ids'])
    activations = internals['layer_6_activations']
    # Keep the SAE aux dict under its own name so `internals` still refers to
    # the base model's internals after the loop.
    sae_weights, sae_opt_state, loss, sae_internals = update_weights_fn(sae_weights, sae_opt_state, activations, step=step)
    print(sae_internals['reconstruction_loss'], sae_internals['l1_loss'])



-


In [16]:
# Show the shape of every array in the `internals` pytree.
jax.tree.map(jnp.shape, internals)

{'layer_6_activations': (8, 8192, 2048),
 'layers': [{}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}]}