Notebook to explore caching computed keys and values. Problems to solve:
- How to use flax to store these as variables
- Will they fit in memory?

In [23]:
from nimblegpt import get_config_for

In [24]:
config = get_config_for("gpt2")

In [25]:
config

attn_pdrop: 0.1
block_size: 1024
embd_pdrop: 0.1
model_type: gpt2
n_embd: 768
n_head: 12
n_layer: 12
resid_pdrop: 0.1
vocab_size: 50257

In [26]:
n_feat = config.n_embd // config.n_head
# Per head, K and V are each [config.block_size, n_feat]. A full cache holds both
# (the factor of 2) for every head of every layer, so the head count must be
# part of the product; without it the figure describes a single head only.
n_cache_params = config.block_size * n_feat * 2 * config.n_head * config.n_layer
f"{n_cache_params:,}"

'18,874,368'

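For the full 1024-token context, the K/V cache is about 1.57M values per attention head across all 12 layers (~6 MB in float32), or about 18.9M values (~75 MB in float32) across all 12 heads.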

# Test how parameters are named when using pmodel

In [27]:
import jax
import jax.numpy as jnp
import flax.linen as nn
from functools import partial, partialmethod

In [28]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [29]:
from nimblegpt.kernels import pmodel
from nimblegpt.jmodel import JSingleHeadCausalSelfAttention
from nimblegpt import param_shapes

In [30]:
def partialclass(cls, *args, **kwargs):
    """Return a subclass of ``cls`` whose ``__init__`` has arguments pre-bound.

    Args:
        cls: The class to specialise.
        *args: Positional arguments bound ahead of any passed at construction.
        **kwargs: Keyword arguments bound as defaults; a caller may still
            override them by passing the same keyword explicitly.

    Returns:
        A new subclass of ``cls``. It keeps the ``__name__``, ``__qualname__``,
        ``__module__`` and ``__doc__`` of ``cls`` so that reprs and any
        name-based logic see the wrapped class rather than a generic
        ``PartialClass``.
    """
    class PartialClass(cls):
        __init__ = partialmethod(cls.__init__, *args, **kwargs)

    # Without this every specialised class would be called "PartialClass" and
    # report the module of this helper instead of the module that defines cls.
    for attr in ("__module__", "__name__", "__qualname__", "__doc__"):
        setattr(PartialClass, attr, getattr(cls, attr))
    return PartialClass

In [31]:
# Probe which "name" a Flax module class ends up with when it is set three ways.
class A(nn.Module):
    name = "meme"  # the A.__dict__ dump in this notebook reports this entry as None
    __name__ = "meme3"  # kept as an ordinary entry in the class namespace
# Rebind the class's own __name__ after creation; this is the value that
# A().__class__.__name__ reports.
A.__name__ = "meme2"

In [32]:
A().__class__.__name__

'meme2'

In [33]:
# Plain subclass of A, used to compare its reported class name against A's.
class B(A):
    pass

In [34]:
B().__class__.__name__

'B'

In [35]:
A.name

In [36]:
A().name

In [37]:
A.__dict__

mappingproxy({'__module__': '__main__',
              'name': None,
              '__name__': 'meme3',
              '__doc__': 'A(parent: Union[Type[flax.linen.module.Module], Type[flax.core.scope.Scope], Type[flax.linen.module._Sentinel], NoneType] = <flax.linen.module._Sentinel object at 0x7f317eb37040>, name: Optional[str] = None)',
              '__annotations__': {'parent': typing.Union[typing.Type[flax.linen.module.Module], typing.Type[flax.core.scope.Scope], typing.Type[flax.linen.module._Sentinel], NoneType],
               'name': typing.Optional[str]},
              'parent': <flax.linen.module.ParentDescriptor at 0x7f31b1ea6440>,
              '__dataclass_params__': _DataclassParams(init=True,repr=False,eq=True,order=False,unsafe_hash=True,frozen=False),
              '__dataclass_fields__': {'parent': Field(name='parent',type=typing.Union[typing.Type[flax.linen.module.Module], typing.Type[flax.core.scope.Scope], typing.Type[flax.linen.module._Sentinel], NoneType],default=

In [38]:
# Control case with no base class: rebinds `B` to a plain class so a class-level
# `name` attribute can be read back from both the class and its instances.
class B():
    name = "meme2"

In [39]:
B.name

'meme2'

In [40]:
B().name

'meme2'

In [41]:
# Build the GPT module from the gpt2 config, passing the single-head attention
# class to use. NOTE(review): presumably PGPT wraps this class once per head — confirm.
gpt_module = pmodel.PGPT.MakeWithSHCSA(config, JSingleHeadCausalSelfAttention)

In [42]:
JSingleHeadCausalSelfAttention.__init__??

[0;31mSignature:[0m
[0mJSingleHeadCausalSelfAttention[0m[0;34m.[0m[0m__init__[0m[0;34m([0m[0;34m[0m
[0;34m[0m    [0mself[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mn_feat[0m[0;34m:[0m [0mint[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mparent[0m[0;34m:[0m [0mUnion[0m[0;34m[[0m[0mType[0m[0;34m[[0m[0mflax[0m[0;34m.[0m[0mlinen[0m[0;34m.[0m[0mmodule[0m[0;34m.[0m[0mModule[0m[0;34m][0m[0;34m,[0m [0mType[0m[0;34m[[0m[0mflax[0m[0;34m.[0m[0mcore[0m[0;34m.[0m[0mscope[0m[0;34m.[0m[0mScope[0m[0;34m][0m[0;34m,[0m [0mType[0m[0;34m[[0m[0mflax[0m[0;34m.[0m[0mlinen[0m[0;34m.[0m[0mmodule[0m[0;34m.[0m[0m_Sentinel[0m[0;34m][0m[0;34m,[0m [0mNoneType[0m[0;34m][0m [0;34m=[0m [0;34m<[0m[0mflax[0m[0;34m.[0m[0mlinen[0m[0;34m.[0m[0mmodule[0m[0;34m.[0m[0m_Sentinel[0m [0mobject[0m [0mat[0m [0;36m0x7f317eb37040[0m[0;34m>[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mname[0m[0;34m:[0m [0mO

In [43]:
# Initialise one stand-alone attention head (n_feat=2) on a (3, 4) array of ones
# with a fixed PRNG key, to inspect the names in its parameter tree.
JSingleHeadCausalSelfAttention(n_feat=2).init(jax.random.PRNGKey(0), jnp.ones((3, 4)))

FrozenDict({
    params: {
        Dense_0: {
            kernel: Array([[-0.38087067, -0.2863471 , -0.17491162,  0.07998355,  0.378069  ,
                     0.9020071 ],
                   [ 0.3810474 ,  0.5744289 , -0.79939866, -0.18784852,  0.61980057,
                    -0.6187186 ],
                   [-0.01715958, -0.4304841 ,  0.10749505, -0.85787666,  0.5338018 ,
                     0.9018709 ],
                   [-0.26826772,  0.44272065, -0.07614978,  0.26189852,  0.8921967 ,
                     0.39000404]], dtype=float32),
            bias: Array([0., 0., 0., 0., 0., 0.], dtype=float32),
        },
    },
})

In [44]:
gpt_module.name

In [45]:
class Test:
    """Probe object that echoes attribute assignments instead of storing them."""

    def __setattr__(self, name, val):
        # Deliberately does not delegate to object.__setattr__: the assignment is
        # only reported on stdout and nothing is written to the instance.
        print(f"{name!s} {val!s}")


t = Test()

In [46]:
t.name = "meme"

name meme


In [47]:
setattr(t, "name", "meme")

name meme


In [48]:
t.__dict__["name"] = "meme"

In [49]:
t.name

'meme'

In [50]:
# Initialise the full model's variables with a fixed PRNG key on a length-3 int32
# dummy input (presumably token ids — confirm against PGPT's call signature).
params = gpt_module.init(jax.random.PRNGKey(0), jnp.ones((3,), dtype=jnp.int32))

In [51]:
# Forward pass with the freshly initialised variables on the same length-3 dummy
# input, as a smoke test that the assembled module runs end to end.
gpt_module.apply(params, jnp.ones((3,), dtype=jnp.int32))

Array([[ 1.1961181 ,  0.99947613, -1.8678186 , ..., -0.02459382,
         0.46241692,  1.0360956 ],
       [ 0.937554  ,  1.1457242 , -1.8385999 , ..., -0.32137394,
         0.5954865 ,  0.5541202 ],
       [ 1.0446986 ,  1.1978903 , -1.7543857 , ..., -0.48017293,
         0.79395825,  0.53017473]], dtype=float32)

In [52]:
param_shapes(params)

{'params': {'Embed_0': {'embedding': '(50257, 768)'},
  'Embed_1': {'embedding': '(1024, 768)'},
  'Block_0': {'LayerNorm_0': {'scale': '(768)', 'bias': '(768)'},
   'CausalSelfAttention_0': {'VmapSingleHeadCausalSelfAttention_0': {'Dense_0': {'bias': '(12, 192)',
      'kernel': '(12, 768, 192)'}},
    'Dense_0': {'kernel': '(768, 768)', 'bias': '(768)'}},
   'LayerNorm_1': {'scale': '(768)', 'bias': '(768)'},
   'Dense_0': {'kernel': '(768, 3072)', 'bias': '(3072)'},
   'Dense_1': {'kernel': '(3072, 768)', 'bias': '(768)'}},
  'Block_1': {'LayerNorm_0': {'scale': '(768)', 'bias': '(768)'},
   'CausalSelfAttention_0': {'VmapSingleHeadCausalSelfAttention_0': {'Dense_0': {'bias': '(12, 192)',
      'kernel': '(12, 768, 192)'}},
    'Dense_0': {'kernel': '(768, 768)', 'bias': '(768)'}},
   'LayerNorm_1': {'scale': '(768)', 'bias': '(768)'},
   'Dense_0': {'kernel': '(768, 3072)', 'bias': '(3072)'},
   'Dense_1': {'kernel': '(3072, 768)', 'bias': '(768)'}},
  'Block_2': {'LayerNorm_0': {'

In [53]:
from dataclasses import dataclass

class MC:
    """Attribute-lookup tracer: prints the name of every attribute that is missing.

    ``__getattr__`` is only consulted after normal lookup has failed, so
    attributes that exist are returned without being printed.
    """

    def __getattr__(cls, key):
        """Report the missing attribute ``key``, then defer or fail.

        Returns:
            Whatever the next ``__getattr__`` in the MRO returns, if one exists.

        Raises:
            AttributeError: If no other class in the MRO provides a fallback.
        """
        print(key)
        # `object` defines no __getattr__, so calling super().__getattr__
        # unconditionally fails with a misleading "'super' object has no
        # attribute '__getattr__'". Delegate only when a cooperating base class
        # really provides the hook.
        fallback = getattr(super(), "__getattr__", None)
        if fallback is not None:
            return fallback(key)
        raise AttributeError(
            f"{type(cls).__name__!r} object has no attribute {key!r}"
        )

@dataclass
class A:
    """Field-less dataclass that renames its own class whenever it is instantiated."""

    def __post_init__(self):
        # Runs at the end of the dataclass-generated __init__. The rename is
        # applied to the instance's concrete class, which may be a subclass of A.
        type(self).__name__ = "meme"

In [54]:
A.__name__, A().__class__.__name__

('A', 'meme')

In [55]:
type(A()), A.__class__

(__main__.A, type)

In [56]:
type(A()).__name__

'meme'

In [57]:
A.__class__

type

In [58]:
# Undecorated subclass of the dataclass A; it inherits A's __post_init__, so
# instantiating B renames B itself.
class B(A):
    pass

In [59]:
B.__name__, B().__class__.__name__

('B', 'meme')

In [60]:
type(B()).__name__

'meme'

In [61]:
B.__name__

'meme'

In [62]:
B().__class__.__name__

'meme'

In [69]:
# Build a throwaway subclass of B, instantiate it, and find A's position in the
# MRO of the instance's class.
type("C", (B,), {})().__class__.mro().index(A)

2

In [1]:
# Baseline: a plain class-level attribute is readable directly from the class.
class P:
    name = "meme"

In [2]:
P.name

'meme'

In [6]:
from dataclasses import dataclass

@dataclass
class C:
    # Required field with no default; later cells supply it via functools.partial.
    name: str

In [7]:
from functools import partial

In [8]:
# Pre-bind the required `name` field; passing name=... explicitly still overrides it.
C2 = partial(C, name="meme")

In [9]:
C2()

C(name='meme')

In [10]:
C2(name="meme2")

C(name='meme2')

In [11]:
def name(name):
    """Build a class decorator that pre-binds ``name`` as a keyword argument.

    The decorator replaces the decorated class with
    ``functools.partial(cls, name=name)``. The resulting binding is therefore a
    partial object, not the class itself. Callers can still override the bound
    value by passing ``name=...`` explicitly.
    """
    def with_default_name(cls):
        bound = partial(cls, name=name)
        return bound

    return with_default_name

In [12]:
# `name(...)` is applied last, so `C` is bound to the functools.partial object it
# returns rather than to the dataclass itself.
@name("meme")
@dataclass
class C:
    name: str  # required field; the decorator pre-binds "meme" for it

In [13]:
C(name = "meme2").name

'meme2'

In [14]:
@dataclass
class D:
    """Field-less dataclass used to probe two ways of overriding a class name.

    * Subclass creation: ``__init_subclass__`` renames each new subclass to
      ``"meme2"``.
    * Instantiation: ``__post_init__`` renames the instance's concrete class to
      ``"meme"``, replacing whatever name it had before.
    """

    def __post_init__(self):
        # Called by the dataclass-generated __init__. This mutates the class,
        # not the instance, so the rename is visible to every other instance.
        self.__class__.__name__ = "meme"

    def __init_subclass__(cls, **kwargs) -> None:
        # Chain to the next __init_subclass__ in the MRO so that mixins and
        # framework base classes relying on this hook still run, and forward any
        # class keyword arguments they expect.
        super().__init_subclass__(**kwargs)
        cls.__name__ = "meme2"

In [15]:
# Creating this subclass triggers D.__init_subclass__, which renames it.
class E(D):
    pass

In [16]:
E.__name__

'meme2'

In [17]:
E.mro()

[__main__.E, __main__.D, object]

In [18]:
type(E)

type

In [24]:
type(type("custom", (E,), {}))

type

In [19]:
import inspect

In [20]:
from nimblegpt.jmodel import JSingleHeadCausalSelfAttention

  from .autonotebook import tqdm as notebook_tqdm


In [21]:
inspect.getfile(JSingleHeadCausalSelfAttention)

'/home/trist/nimbleGPT/nimblegpt/jmodel.py'

In [22]:
# A subclass created via type(...) no longer resolves to jmodel.py (compare the
# previous cell). NOTE(review): presumably the class takes its __module__ from a
# metaclass frame in abc.py rather than from this notebook — confirm.
inspect.getfile(type("custom", (JSingleHeadCausalSelfAttention,), {}))

'/opt/conda/lib/python3.10/abc.py'

In [23]:
type("custom", (JSingleHeadCausalSelfAttention,), {}).__module__

'abc'