In [None]:
from helpers.eqx import named_parameters

# Show every array leaf of the model together with its shape.
for param_name, param_value in named_parameters(gpt):
    print(param_name, param_value.shape)

In [None]:
# `key` is also used to initialise the model; derive an independent subkey
# instead of reusing it, because reusing a JAX PRNG key repeats its randomness.
temp_key = jax.random.split(key)[1]
temp = jax.random.normal(temp_key, (5,))
temp.shape

In [None]:
# Unpack the length of the 1-D `temp` array (raises ValueError if it is not 1-D).
(t,) = temp.shape
t

In [None]:
# Optimizer hyper-parameters.
# b1 / b2 are Adam's exponential decay rates for the first / second moment
# estimates and must be close to 1; values near 0 (e.g. 0.001 / 0.005)
# effectively disable moment averaging. 0.9 / 0.95 are the values commonly
# used for GPT training (e.g. nanoGPT).
learning_rate = 5e-5
b1 = 0.9
b2 = 0.95

In [None]:
def create_conditional_tree(tree, condition):
    """Map ``condition`` over every leaf of ``tree``.

    Args:
      tree: The input JAX pytree (e.g. an Equinox module).
      condition: A function that takes a single leaf and returns a boolean.

    Returns:
      A pytree with the same structure as ``tree`` whose leaves are
      ``condition(leaf)`` evaluated on the corresponding input leaf.
    """
    # Each leaf is judged on its own; `condition` is called once per leaf.
    return jax.tree_util.tree_map(condition, tree)

# Example usage: mark every array leaf with more than two dimensions.
# Leaves without an `ndim` attribute (non-array leaves, which
# `named_parameters` explicitly skips) map to False instead of raising
# AttributeError.
conditional_tree = create_conditional_tree(gpt, lambda x: hasattr(x, "ndim") and x.ndim > 2)

print(conditional_tree)

In [None]:
import equinox as eqx
import jax
import optax

from helpers.eqx import named_parameters

param_dict = {pn: p for pn, p in named_parameters(gpt)}
# Create optim groups: every parameter with at least two dimensions is weight
# decayed, all others are not, i.e. all weight tensors in matmuls + embeddings
# decay while biases and layernorm parameters don't.
decay_params = [p for p in param_dict.values() if p.ndim >= 2]
nodecay_params = [p for p in param_dict.values() if p.ndim < 2]

# Boolean mask with the same pytree structure as `gpt`, as optax's `mask`
# expects: True on array leaves to decay, False everywhere else. Deciding per
# leaf by rank is linear in the number of leaves, unlike comparing each leaf's
# values against every decay tensor.
decay_param_tree = jax.tree_util.tree_map(lambda leaf: eqx.is_array(leaf) and leaf.ndim >= 2, gpt)

num_decay_params = sum(p.size for p in decay_params)
num_nodecay_params = sum(p.size for p in nodecay_params)
print(f"num decayed parameter tensors: {len(decay_params)}, with {num_decay_params:,} parameters")
print(f"num non-decayed parameter tensors: {len(nodecay_params)}, with {num_nodecay_params:,} parameters")

# Create the AdamW optimizer; weight decay is applied only where the mask is True.
optimizer = optax.adamw(learning_rate=learning_rate, b1=b1, b2=b2, weight_decay=1e-1, mask=decay_param_tree)

# TODO: run this and check the result.

In [1]:
import sys

# Make the parent directory importable (presumably where `model` lives; it is
# imported in the next cell). Guarded so that re-running this cell does not
# keep prepending duplicate entries to sys.path.
if '..' not in sys.path:
    sys.path.insert(0, '..')

In [2]:
from model import GPT, GPTConfig
import jax
import equinox as eqx
import jax.numpy as jnp

In [3]:
# Fixed seed, then build the model from the default configuration.
key = jax.random.PRNGKey(seed=0)
gpt = GPT.create_instance(GPTConfig(), key)

In [None]:
# `key` was already used to initialise the model; derive an independent subkey
# instead of reusing it, because reusing a JAX PRNG key repeats its randomness.
temp_key = jax.random.split(key)[1]
temp = jax.random.normal(temp_key, (5,))
temp.shape

In [None]:
# Optimizer hyper-parameters.
# b1 / b2 are Adam's exponential decay rates for the first / second moment
# estimates and must be close to 1; values near 0 (e.g. 0.001 / 0.005)
# effectively disable moment averaging. 0.9 / 0.95 are the values commonly
# used for GPT training (e.g. nanoGPT).
learning_rate = 5e-5
b1 = 0.9
b2 = 0.95

In [15]:
import equinox as eqx
import jax
import optax

from helpers.eqx import named_parameters

param_dict = {pn: p for pn, p in named_parameters(gpt)}
# Create optim groups: every parameter with at least two dimensions is weight
# decayed, all others are not, i.e. all weight tensors in matmuls + embeddings
# decay while biases and layernorm parameters don't.
decay_params = [p for p in param_dict.values() if p.ndim >= 2]
nodecay_params = [p for p in param_dict.values() if p.ndim < 2]

# Boolean mask with the same pytree structure as `gpt`, as optax's `mask`
# expects: True on array leaves to decay, False everywhere else. Deciding per
# leaf by rank is linear in the number of leaves, unlike comparing each leaf's
# values against every decay tensor.
decay_param_tree = jax.tree_util.tree_map(lambda leaf: eqx.is_array(leaf) and leaf.ndim >= 2, gpt)

num_decay_params = sum(p.size for p in decay_params)
num_nodecay_params = sum(p.size for p in nodecay_params)
print(f"num decayed parameter tensors: {len(decay_params)}, with {num_decay_params:,} parameters")
print(f"num non-decayed parameter tensors: {len(nodecay_params)}, with {num_nodecay_params:,} parameters")

# Create the AdamW optimizer; weight decay is applied only where the mask is True.
optimizer = optax.adamw(learning_rate=learning_rate, b1=b1, b2=b2, weight_decay=1e-1, mask=decay_param_tree)

# TODO: run this and check the result.

In [14]:
# Number of tensors in the decayed and non-decayed groups.
tuple(map(len, (decay_params, nodecay_params)))

In [None]:
# Print the key path of every leaf in the model pytree.
for leaf_path, leaf in jax.tree_util.tree_leaves_with_path(gpt):
    print(leaf_path)

In [None]:
# All array leaves of the model, in tree-flattening order.
params = [leaf for leaf in jax.tree_util.tree_leaves(gpt) if eqx.is_array(leaf)]
for param in params:
    print(param.shape)

In [None]:
from typing import Callable

def named_parameters(model: eqx.Module, is_leaf: Callable = None):
    """Return ``(name, array)`` pairs for every array leaf of ``model``.

    The name of a leaf is built from its key path: dict keys are rendered as
    ``.<key>``, every other path entry uses its ``str`` form.

    Args:
      model: Pytree (typically an Equinox module) to walk.
      is_leaf: Optional predicate forwarded to
        ``jax.tree_util.tree_flatten_with_path`` to stop flattening at
        custom nodes.

    Returns:
      A list of ``(name, array)`` tuples in tree-flattening order. Leaves for
      which ``eqx.is_array`` is False are skipped.
    """

    def _path_to_name(path):
        # Join the path entries into a single dotted/indexed name.
        parts = []
        for entry in path:
            if isinstance(entry, jax.tree_util.DictKey):
                # str() so non-string dict keys (e.g. ints) don't raise TypeError.
                parts.append('.' + str(entry.key))
            else:
                parts.append(str(entry))
        return ''.join(parts)

    leaves_with_paths, _ = jax.tree_util.tree_flatten_with_path(model, is_leaf=is_leaf)
    return [(_path_to_name(path), p) for path, p in leaves_with_paths if eqx.is_array(p)]

In [None]:
# Show every array leaf of the model together with its shape.
for entry_name, entry in named_parameters(gpt):
    print(entry_name, entry.shape)

In [None]:
def find_sub_tree(model: eqx.Module, sub_tree_name: str, filter_fn: Callable = None):
    """Collect the nodes of ``model`` whose path name ends with ``sub_tree_name``.

    Path names are built like ``named_parameters``: dict keys are rendered as
    ``.<key>``, every other path entry uses its ``str`` form.

    ``filter_fn`` does double duty: it is passed as ``is_leaf`` to
    ``jax.tree_util.tree_flatten_with_path`` (so nodes it accepts are returned
    whole instead of being flattened further), and, when given, a node is only
    kept if ``filter_fn(node)`` is truthy.

    Args:
      model: Pytree (typically an Equinox module) to search.
      sub_tree_name: Suffix the node's path name must end with.
      filter_fn: Optional predicate used as ``is_leaf`` and as an extra filter.

    Returns:
      A list of matching nodes in tree-flattening order.
    """

    def _path_to_name(path):
        # Join the path entries into a single dotted/indexed name.
        parts = []
        for entry in path:
            if isinstance(entry, jax.tree_util.DictKey):
                # str() so non-string dict keys (e.g. ints) don't raise TypeError.
                parts.append('.' + str(entry.key))
            else:
                parts.append(str(entry))
        return ''.join(parts)

    leaves_with_paths, _ = jax.tree_util.tree_flatten_with_path(model, is_leaf=filter_fn)
    return [
        node
        for path, node in leaves_with_paths
        if _path_to_name(path).endswith(sub_tree_name) and (not filter_fn or filter_fn(node))
    ]