## Goals:

1. Load the weights 
2. Kronecker Decomp.:
   * $K$ into $K_1, K_2$
   * $Q$ into $Q_1, Q_2$
3. check the attention
   * Project the matrix
   * Optimize $K Q^T$ and check correctness of compute

In [49]:
!pip install 


Collecting einops
  Downloading einops-0.7.0-py3-none-any.whl (44 kB)
[K     |████████████████████████████████| 44 kB 219 kB/s eta 0:00:011
[?25hInstalling collected packages: einops
Successfully installed einops-0.7.0


In [52]:
from einops import rearrange
from torch import Tensor
from typing import Tuple

def kronecker_decompose(A: Tensor, m: int, n: int, *, k: int = 1, niter: int = 10) -> Tuple[Tensor, Tensor]:
    """

      Frobenius-optimal decomposition of `A` into a sum of `k` Kronecker products.
      Algorithm from Van Loan and Pitsianis (1993),
      "Approximation with Kronecker Products"
      <https://bit.ly/46hT5aY>.

    Args:
        A: Matrix or batch of matrices to decompose, of shape (..., m * m2, n * n2)
        m: Desired number of rows in the left Kronecker factor(s)
        n: Desired number of columns in the left Kronecker factor(s)
        k: Number of Kronecker factors
        niter: Number of iterations for the low rank SVD algorithm
    Returns:
        Tuple of Kronecker factors (`left`, `right`) of shape `(..., k, m, n)` and
        `(..., k, A.shape[-2] // m, A.shape[-1] // n)` respectively.

    Raises:
        AssertionError: If the dimensions of `A` are not compatible with the desired
            number of rows and columns in the left Kronecker factor.

    """

    m2, n2 = A.shape[-2] // m, A.shape[-1] // n
    assert A.shape[-2:] == (m * m2, n * n2), "Dimensions do not match"

    # Reshape and permute A, then perform SVD
    A = rearrange(A, "... (m m2) (n n2) -> ... (m n) (m2 n2)", m=m, m2=m2, n=n, n2=n2)
    u, s, v = torch.svd_lowrank(A, q=k, niter=niter)

    # Unflatten the factors
    u = rearrange(u, "... (m n) k -> ... k m n", m=m, n=n, k=k)
    v = rearrange(v, "... (m2 n2) k -> ... k m2 n2", m2=m2, n2=n2, k=k)

    scale = s[..., None, None].sqrt()
    return u * scale, v * scale

In [26]:
import torch

device = "cuda"


checkpoint = torch.load('out-shakespeare-char/ckpt.pt')
checkpoint_model_args = checkpoint['model_args']

#model_args = dict(n_layer=n_layer, n_head=n_head, n_embd=n_embd, block_size=block_size,
#                  bias=bias, vocab_size=None, dropout=dropout) # start with model_args from command line

              

In [71]:
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("Device", device)

Device cuda


In [29]:
checkpoint_model_args.keys(), checkpoint.keys()

(dict_keys(['n_layer', 'n_head', 'n_embd', 'block_size', 'bias', 'vocab_size', 'dropout']),
 dict_keys(['model', 'optimizer', 'model_args', 'iter_num', 'best_val_loss', 'config']))

In [53]:
unwanted_prefix = '_orig_mod.'
for k,v in list(state_dict.items()):
    if k.startswith(unwanted_prefix):
        state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)

for i, j in state_dict.items():
    print(i)

transformer.wte.weight
transformer.wpe.weight
transformer.h.0.ln_1.weight
transformer.h.0.attn.c_attn.weight
transformer.h.0.attn.c_proj.weight
transformer.h.0.ln_2.weight
transformer.h.0.mlp.c_fc.weight
transformer.h.0.mlp.c_proj.weight
transformer.h.1.ln_1.weight
transformer.h.1.attn.c_attn.weight
transformer.h.1.attn.c_proj.weight
transformer.h.1.ln_2.weight
transformer.h.1.mlp.c_fc.weight
transformer.h.1.mlp.c_proj.weight
transformer.h.2.ln_1.weight
transformer.h.2.attn.c_attn.weight
transformer.h.2.attn.c_proj.weight
transformer.h.2.ln_2.weight
transformer.h.2.mlp.c_fc.weight
transformer.h.2.mlp.c_proj.weight
transformer.h.3.ln_1.weight
transformer.h.3.attn.c_attn.weight
transformer.h.3.attn.c_proj.weight
transformer.h.3.ln_2.weight
transformer.h.3.mlp.c_fc.weight
transformer.h.3.mlp.c_proj.weight
transformer.h.4.ln_1.weight
transformer.h.4.attn.c_attn.weight
transformer.h.4.attn.c_proj.weight
transformer.h.4.ln_2.weight
transformer.h.4.mlp.c_fc.weight
transformer.h.4.mlp.c_proj.w

In [59]:
k_q_v = state_dict["transformer.h.0.attn.c_attn.weight"]

kk,q,v = k_q_v.split(384, dim=0)

k_q_v.shape, kk.shape, q.shape, v.shape

(torch.Size([1152, 384]),
 torch.Size([384, 384]),
 torch.Size([384, 384]),
 torch.Size([384, 384]))

In [103]:
k1, k2  = kronecker_decompose(kk, 384, 192 ,k=2)
q1, q2  = kronecker_decompose(q, 384, 192, k=2)
v1, v2  = kronecker_decompose(v, 384, 192, k=2)

k1[0].shape, k2[0].shape

(torch.Size([384, 192]), torch.Size([1, 2]))

In [107]:
s = torch.zeros(384, 384)

s = s.to(device)

for i in range(1):
    s += torch.kron(q1[i], q2[i])

s

tensor([[-0.0309,  0.0374,  0.0180,  ...,  0.0112,  0.0632, -0.0764],
        [-0.0830,  0.1004, -0.0485,  ...,  0.0371,  0.0017, -0.0020],
        [-0.0405,  0.0490,  0.0092,  ...,  0.0081, -0.0278,  0.0336],
        ...,
        [ 0.0804, -0.0972,  0.0399,  ..., -0.0204, -0.0327,  0.0396],
        [-0.0142,  0.0172, -0.0006,  ..., -0.0371, -0.0044,  0.0054],
        [-0.0096,  0.0116, -0.0151,  ...,  0.0752, -0.0128,  0.0154]],
       device='cuda:0')

In [111]:
K = torch.kron(k1[1], k2[1])
Q = torch.kron(q1[1], q2[1])

K.shape, Q.shape

(torch.Size([384, 384]), torch.Size([384, 384]))

In [113]:
checkpoint['config']

{'out_dir': 'out-shakespeare-char',
 'eval_interval': 250,
 'log_interval': 10,
 'eval_iters': 200,
 'eval_only': False,
 'always_save_checkpoint': False,
 'init_from': 'scratch',
 'wandb_log': False,
 'wandb_project': 'shakespeare-char',
 'wandb_run_name': 'mini-gpt',
 'dataset': 'shakespeare_char',
 'gradient_accumulation_steps': 1,
 'batch_size': 64,
 'block_size': 256,
 'n_layer': 6,
 'n_head': 6,
 'n_embd': 384,
 'dropout': 0.2,
 'bias': False,
 'learning_rate': 0.001,
 'max_iters': 5000,
 'weight_decay': 0.1,
 'beta1': 0.9,
 'beta2': 0.99,
 'grad_clip': 1.0,
 'decay_lr': True,
 'warmup_iters': 100,
 'lr_decay_iters': 5000,
 'min_lr': 0.0001,
 'backend': 'nccl',
 'device': 'cuda',
 'dtype': 'bfloat16',
 'compile': True}

In [117]:
batch_size, seqlen, n_embd = 64, 256, 384

x = torch.randn(batch_size, seqlen, n_embd, device = device)


In [123]:
xk = x@K
xq = x@Q.transpose()

xk.shape, xq.shape

(torch.Size([64, 256, 384]), torch.Size([64, 256, 384]))

In [125]:
384 / 6

64.0

In [130]:
from torch import nn

m = nn.Linear(20, 30)
ii = torch.randn(128, 20)
output = m.forward(ii)
output.shape


Add this to your pytorch series    XW   always batch sizes first

SyntaxError: invalid syntax (1452955411.py, line 9)