In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [44]:
# Hyperparameters for the codebook experiment.

# Number of codebook entries: 2**(2*3) = 64.
# NOTE(review): presumably 3 bits per weight x 2 weights per code vector — confirm.
codebook_size = 2**(2*3)
# Length of each code vector; the weights are viewed as rows of this many values.
codebook_dim = 2

# Preferred accelerator is GPU 7. Fall back to the first GPU, then to the CPU,
# so the notebook still runs on machines with fewer (or no) CUDA devices.
_preferred_gpu_index = 7
if torch.cuda.is_available() and torch.cuda.device_count() > _preferred_gpu_index:
    device = torch.device(f"cuda:{_preferred_gpu_index}")
elif torch.cuda.is_available():
    device = torch.device("cuda:0")
    print(f"cuda:{_preferred_gpu_index} is not available; using {device} instead")
else:
    device = torch.device("cpu")
    print(f"CUDA is not available; using {device} instead")


In [45]:
# Load the saved tensors and move every tensor (bare or inside a list) onto `device`.
# NOTE: torch.load unpickles the file, so only load checkpoints from a trusted source.
data = torch.load("/home/lliu/huffman/test/original_weights2.pt")
for name, value in data.items():
    if isinstance(value, torch.Tensor):
        data[name] = value.to(device)
    elif isinstance(value, list):
        # Overwrite the entries in place so the list object itself is preserved.
        for index, tensor in enumerate(value):
            value[index] = tensor.to(device)

#if a list stack the tensors
def stack_tensors(tensors):
    """Combine a sequence of tensors into one tensor along a new leading axis."""
    stacked = torch.stack(tensors, 0)
    return stacked

# Collapse each list-valued entry of `data` into a single stacked tensor.
for key in list(data):
    entry = data[key]
    if not isinstance(entry, list):
        continue
    data[key] = stack_tensors(entry)
    

In [None]:

# Read the "H" entry of the checkpoint, then move it to the working device.
# NOTE(review): H comes from original_weights.pt while `data` is loaded from
# original_weights2.pt — confirm both files describe the same layer.
H = torch.load("/home/lliu/huffman/test/original_weights.pt")["H"]
H = H.to(device)
H

In [47]:
# Scale H by the size of its first dimension.
# NOTE(review): H is rebound to the scaled value, so re-running this cell divides
# it again; reload H before re-executing.
H = H/H.shape[0]
# Weight tensor that the rest of the notebook approximates with a codebook.
weights = data["weights"]

In [None]:
# Build the initial codebook with k-means: the weights are viewed as rows of
# `codebook_dim` values and clustered into `codebook_size` centroids.
from sklearn.cluster import KMeans
import numpy as np

kmeans = KMeans(n_clusters=codebook_size, random_state=0)
kmeans = kmeans.fit(weights.reshape(-1, codebook_dim).cpu().numpy())

In [49]:
class reconstruct(torch.autograd.Function):
    """Codebook reconstruction with a hand-written gradient for the assignments.

    Shapes (fixed by the broadcasting used below):
        assignments: (rows, n_blocks, codebook_size)
        codebook:    (codebook_size, codebook_dim)
        output:      (rows, n_blocks * codebook_dim)

    All shapes are derived from the arguments, so the function does not depend
    on notebook-level globals.
    """

    @staticmethod
    def forward(ctx, assignments, codebook):
        """Return the assignment-weighted sum of codebook rows for every block."""
        ctx.save_for_backward(assignments, codebook)
        # (rows, blocks, K, 1) * (K, D) -> sum over K -> (rows, blocks, D)
        per_block = torch.sum(assignments.unsqueeze(-1)*codebook, dim=2)
        # Concatenate the blocks of each row into one flat vector.
        return per_block.reshape(assignments.shape[0], -1)

    @staticmethod
    def backward(ctx, grad_output):
        """Return (grad_assignments, grad_codebook) for the forward inputs.

        The codebook gradient is the exact gradient of the forward sum. The
        assignment gradient is a surrogate: each entry is weighted by
        -sign(d)/(1+|d|), where d is the difference between the currently
        strongest codebook row of the block and the candidate row.
        """
        assignments, codebook = ctx.saved_tensors
        codebook_dim = codebook.shape[-1]
        # (rows, blocks * D) -> (rows, blocks, D)
        grad_per_vector_block = grad_output.reshape(
            grad_output.shape[0], grad_output.shape[1]//codebook_dim, codebook_dim
        )
        # Codebook row with the largest assignment in each block: (rows, blocks, D)
        codebook_max = codebook[torch.argmax(assignments, dim=2), :]
        # Difference to every codebook row: (rows, blocks, K, D)
        custom_grad_func = codebook_max.unsqueeze(2) - codebook
        custom_grad_func = -torch.sign(custom_grad_func)/(1+torch.abs(custom_grad_func))
        # Contract over D -> (rows, blocks, K), matching `assignments`.
        grad_assignments = torch.sum(grad_per_vector_block.unsqueeze(2)*custom_grad_func, dim=3)
        # (rows, blocks, D, 1) * (rows, blocks, 1, K) summed over rows and blocks
        # gives (D, K); transpose to (K, D), matching `codebook`.
        grad_codebook = torch.sum(
            grad_per_vector_block.unsqueeze(-1)*assignments.unsqueeze(2), dim=(0, 1)
        ).T
        return grad_assignments, grad_codebook
    

In [10]:
# codebook = torch.tensor(kmeans.cluster_centers_, device=device, dtype = weights.dtype)

In [None]:

# Fix the torch seeds (CPU and CUDA) for repeatability.
torch.random.manual_seed(0)
torch.cuda.manual_seed(0)

# Trainable assignment scores, one per (row, block of `codebook_dim` weights,
# codebook entry), all initialised to 1.
assignment_shape = (weights.shape[0], weights.shape[1]//codebook_dim, codebook_size)
assignments = torch.ones(assignment_shape, dtype=weights.dtype, device=device).requires_grad_(True)

# Trainable codebook, initialised from the k-means centroids.
codebook = torch.tensor(kmeans.cluster_centers_, dtype=weights.dtype, device=device).requires_grad_(True)

# Saved layer inputs/outputs, flattened to rows of length weights.shape[0]
# and cast to the dtype of the weights.
x = torch.reshape(data["Input"], (-1, weights.shape[0])).to(dtype=weights.dtype)
y = torch.reshape(data["Output"], (-1, weights.shape[0])).to(dtype=weights.dtype)



def compute_reconstructed(assignments, codebook,activation = nn.Identity()):
    """Rebuild a weight matrix from per-block assignments and a codebook.

    Args:
        assignments: tensor of shape (rows, n_blocks, codebook_size) holding the
            weight given to each codebook entry for each block.
        codebook: tensor of shape (codebook_size, codebook_dim).
        activation: module/callable applied to the element-wise products before
            they are summed over the codebook axis. Defaults to the identity.

    Returns:
        Tensor of shape (rows, n_blocks * codebook_dim). The shape is derived
        from `assignments`, so the function does not rely on a global `weights`.
    """
    # (rows, blocks, K, 1) * (K, D) -> (rows, blocks, K, D)
    weighted_entries = activation(assignments.unsqueeze(-1)*codebook)
    # Sum over the codebook axis, then flatten the blocks of each row.
    return torch.sum(weighted_entries, dim=2).reshape(assignments.shape[0], -1)

# compute_reconstructed = reconstruct.apply

import tqdm

# Number of gradient steps and base step size.
n_iters, lr = 10000, 1

# The one-hot penalty exponent `beta` moves linearly from beta_init to
# beta_final during training; lambda_ weights that penalty in the total loss.
beta_init, beta_final = 10, 1
lambda_ = 1

# Loss histories appended to by the training loop.
losses, x_losses, one_hot_losses, reconstruction_losses = [], [], [], []

# Cast H to float16.
H = H.half()

# Softmax over the codebook axis (dim 2) is applied to the raw assignments.
# assignments_activation = nn.Identity()
assignments_activation = nn.Softmax(dim=2)

# Plain gradient descent on `assignments` and `codebook`.
# Each iteration: build the soft reconstruction, evaluate the three loss terms,
# backpropagate, log every 100 steps, then apply a manual update and clear the grads.
for i in tqdm.tqdm(range(n_iters)):
    assignments_soft = assignments_activation(assignments)
    reconstructed = compute_reconstructed(assignments_soft, codebook)
    diff = weights - reconstructed
    # Squared error summed over dim 1, averaged over rows, scaled by a fixed factor of 300.
    loss_reconstruction = torch.sum(diff**2, dim=(1)).mean()*300
    # Alternative output-space loss, currently disabled:
    # x_loss = torch.sum((x @ reconstructed.T - y)**2, dim=1).mean()
    # Quadratic form sum_i diff[i] @ H @ diff[i].
    x_loss = torch.einsum('ik,kl,il->', diff, H, diff)
    # Linear schedule for the penalty exponent, from beta_init towards beta_final.
    beta = beta_init + (beta_final - beta_init)*i/n_iters
    # 1 - |2p - 1|**beta is 0 for p in {0, 1} and largest at p = 0.5, so this
    # term pushes the soft assignments towards hard 0/1 values.
    # It is summed over dim 1 and then averaged over the remaining entries.
    one_hot_loss = torch.sum(1-torch.abs(2*assignments_soft-1)**beta,dim = 1).mean()
    # one_hot_loss = 0
    loss = loss_reconstruction + x_loss + lambda_*one_hot_loss
    loss.backward()
    if i%100 == 0:
        # The logged reconstruction loss has the factor 300 removed and is divided
        # by the mean squared row norm of `weights`, i.e. it is a relative error.
        print("loss", loss.item(), "loss_reconstruction", 1/300*loss_reconstruction.item()/(torch.sum(weights**2,dim = 1).mean().item()),
                 "x_loss", x_loss.item(), "one_hot_loss", one_hot_loss.item())
    losses.append(loss.item())
    x_losses.append(x_loss.item())
    # NOTE: the one-hot history is not recorded, so `one_hot_losses` stays empty.
    # one_hot_losses.append(one_hot_loss.item())
    reconstruction_losses.append(loss_reconstruction.item())
    # Stop as soon as the loss becomes inf/NaN.
    assert torch.isfinite(loss).all()

    # Manual parameter update; no_grad keeps the in-place edits out of the graph.
    with torch.no_grad():
        if i%100 == 0:

            # Gradient norms plus the current value and pending step of one entry of each parameter.
            print("assignments.grad", torch.norm(assignments.grad).item(),
                    "codebook.grad", torch.norm(codebook.grad).item())
            print("assignments", assignments[0,0,0].item(),lr*assignments.grad[0,0,0].item())
            print("codebook", codebook[0,0].item(), lr*1e-6*codebook.grad[0,0].item())
        assignments -= lr*assignments.grad
        # The codebook uses a step size 1e-6 times smaller than the assignments.
        codebook -= lr*1e-6*codebook.grad

        # Reset the gradients so the next backward() starts from scratch.
        assignments.grad = None
        codebook.grad = None
    # raise Exception


In [None]:
import matplotlib.pyplot as plt

# Plot both loss histories relative to their first value on a log-scaled y axis.
# Labels and a legend are set so the two curves can be told apart.
fig, ax = plt.subplots()
# ax.plot(losses)
ax.plot(np.array(x_losses)/x_losses[0], label="x_loss")
# ax.plot(one_hot_losses)
ax.plot(np.array(reconstruction_losses)/reconstruction_losses[0], label="loss_reconstruction")
ax.set_yscale("log")
ax.set_xlabel("iteration")
ax.set_ylabel("loss / initial loss")
ax.legend()
plt.show()

In [None]:
# Show the recorded x_loss history as an array.
# np.array builds an array from the values; the np.ndarray constructor would
# interpret its argument as a shape instead.
np.array(x_losses)

In [None]:
# Histogram (log-scaled counts) of all soft-assignment values from the last
# training iteration.
soft_values = assignments_soft.detach().cpu().numpy().reshape(-1)
plt.hist(soft_values, bins=100)
plt.yscale("log")

In [None]:
# Multiply the inputs by the original weights and shape the result like `y`.
# NOTE(review): this uses `x @ weights`, whereas the quantized evaluation further
# down uses `x @ reconstructed_final.T` — confirm which orientation is intended.
(x.reshape(-1, weights.shape[0]) @ weights).reshape(y.shape)

In [None]:
# Hard assignment: for every block keep the index of the largest score along
# the codebook axis.
assignments_final = torch.argmax(assignments, dim=2)
# assignments_final = torch.softmax(assignments, dim=2)
assignments_final.shape

In [None]:
# Display the hard codebook index selected for each block.
assignments_final

In [39]:
# Rebuild the weights from the hard assignments: each one-hot row selects
# exactly one codebook entry per block.
hard_one_hot = F.one_hot(assignments_final, codebook_size).half()
reconstructed_final = compute_reconstructed(hard_one_hot, codebook)

# reconstructed_final = compute_reconstructed(assignments_final, codebook)

In [None]:
# Output produced with the reconstructed weights, followed by its mean squared
# error relative to the mean square of the reference output `y`.
y_hat = x @ reconstructed_final.T
print(y_hat)
output_mse = torch.mean((y_hat - y)**2).item()
output_power = torch.mean(y**2).item()
print(output_mse/output_power)

In [None]:
# Display the reference outputs for comparison with the y_hat printed above.
y

In [None]:
# Mean squared error of the reconstructed weights relative to the mean square
# of the original weights.
weights_hat = reconstructed_final

weight_mse = torch.mean((weights_hat - weights)**2).item()
weight_power = torch.mean(weights**2).item()
print(weight_mse/weight_power)

In [None]:
# Display the original weight tensor.
weights

In [None]:
# Effective storage cost in bits per original weight.
# The compressed form stores one codebook index per block (the hard assignments,
# log2(codebook_size) bits each) plus the codebook itself at 16 bits per entry.
# The dense assignment-score tensor is a training-time variable and is not
# counted as stored data.
index_bits = int(np.ceil(np.log2(codebook_size)))
compressed_bits = assignments_final.numel()*index_bits + codebook.numel()*16
compressed_bits/(weights.numel()*16)*16

In [None]:
import time

import torch
import torch.nn as nn

def get_llama(model):
    """Load a LLaMA causal-LM checkpoint with selected torch initialisers disabled.

    Args:
        model: checkpoint identifier passed to `LlamaForCausalLM.from_pretrained`.

    Returns:
        The loaded model, with its `seqlen` attribute set to 2048.

    Side effect: `kaiming_uniform_`, `uniform_` and `normal_` in `torch.nn.init`
    are replaced by a no-op and are not restored afterwards.
    """
    import torch

    def skip(*args, **kwargs):
        pass

    # Presumably done to avoid spending time on random initialisation that the
    # checkpoint weights overwrite anyway — confirm.
    for init_name in ("kaiming_uniform_", "uniform_", "normal_"):
        setattr(torch.nn.init, init_name, skip)

    from transformers import LlamaForCausalLM
    loaded = LlamaForCausalLM.from_pretrained(model, torch_dtype='auto')
    loaded.seqlen = 2048
    return loaded

# Load the "huggyllama/llama-7b" checkpoint through get_llama.
model = get_llama("huggyllama/llama-7b")

In [None]:
# Display the self-attention module of the first decoder layer.
model.model.layers[0].self_attn

In [None]:
# Scratch arithmetic; evaluates to 3072.0.
# NOTE(review): the meaning of the factors is not stated — document or remove.
3/16*4096*4

In [None]:
# Copy the q/k/v/o projection weights of every decoder layer into one flat list,
# ordered layer by layer as q, k, v, o.
# NOTE: this rebinds `weights` from the tensor used in the cells above to a list.
weights = []
for layer in model.model.layers:
    attn = layer.self_attn
    weights.extend(
        proj.weight.clone()
        for proj in (attn.q_proj, attn.k_proj, attn.v_proj, attn.o_proj)
    )

weights