In [1]:
import torch
import torch.nn as nn

In [2]:
# NOTE(security): torch.load unpickles the file, so only load checkpoints from a trusted source.
# map_location="cpu" makes loading independent of the device the tensors were saved
# from; the following cell moves the tensors onto the GPU explicitly.
data = torch.load('/home/lliu/huffman/test/original_weights.pt', map_location="cpu")

In [3]:
# Move every tensor stored in `data` onto cuda:6. Values that are tensors are
# replaced by their device copy; values that are lists have each element
# replaced in place (elements are assumed to be tensors).
device = torch.device("cuda:6")

for key, value in data.items():
    if isinstance(value, torch.Tensor):
        data[key] = value.to(device)
    elif isinstance(value, list):
        for idx, element in enumerate(value):
            value[idx] = element.to(device)

In [None]:
# Display the loaded `data` mapping (its tensors were moved to `device` above).
data

In [5]:
import tqdm


@torch.jit.script
def assigment_step(W:torch.Tensor, H:torch.Tensor, quantized_vectors:torch.Tensor):
    """Assign every row of ``W`` to its closest quantized vector under ``H``.

    For a row ``W[i]`` and a column ``q_j`` of ``quantized_vectors`` the cost is
    the quadratic form ``(W[i] - q_j) @ H @ (W[i] - q_j)``. Each row is assigned
    to the column with the smallest cost.

    Args:
        W (torch.Tensor): weights, one vector per row, shape (m, n)
        H (torch.Tensor): matrix of the quadratic form, shape (n, n)
        quantized_vectors (torch.Tensor): quantized vectors stored as columns,
            shape (n, k) where k is the number of quantized vectors

    Returns:
        torch.Tensor: assignments of shape (m,), dtype long, on the device of ``W``
        torch.Tensor: 0-dim tensor holding the sum of the minimal costs of all rows
    """
    assert torch.all(torch.isfinite(W)), f"W is not finite, {W}, {W[~torch.isfinite(W)]}"
    assert torch.all(torch.isfinite(quantized_vectors)), f"quantized_vectors is not finite, {quantized_vectors}, {quantized_vectors[~torch.isfinite(quantized_vectors)]}"

    # Allocate the result on the same device as W so that writing the argmin
    # does not trigger a device-to-host copy on every row.
    assignments = torch.zeros(W.shape[0], dtype=torch.long, device=W.device)
    error = torch.tensor(0.0, dtype=W.dtype, device=W.device)

    # (k, n): one quantized vector per row; loop invariant, so built once.
    centroids = quantized_vectors.T

    # Rows are processed one at a time so only a (k, n) difference tensor is
    # materialised instead of an (m, k, n) one.
    for i in range(W.shape[0]):
        diff = W[i].unsqueeze(0) - centroids
        # errors[j] = diff[j] @ H @ diff[j]
        errors = torch.einsum('jk,kl,jl->j', diff, H, diff)
        best = torch.argmin(errors)
        assignments[i] = best
        error += errors[best]
    return assignments, error

@torch.jit.script
def update_step(W:torch.Tensor, prev_quantized:torch.Tensor, assignments:torch.Tensor)->torch.Tensor:
    """Recompute each quantized vector as the mean of the rows assigned to it.

    Args:
        W (torch.Tensor): weights, one vector per row, shape (m, n)
        prev_quantized (torch.Tensor): previous quantized vectors of shape (n, k);
            only its shape, dtype and device are used for the result
        assignments (torch.Tensor): cluster index of every row of ``W``, shape (m,)

    Returns:
        torch.Tensor: quantized vectors of shape (n, k). Column ``j`` is the mean
            of the rows of ``W`` assigned to ``j``; columns of clusters that
            received no row are left at zero.
    """

    #initialize the updated quantized vectors
    updated_quantized_vectors = torch.zeros_like(prev_quantized)

    # Clusters are indexed along dim 1 (the columns), so iterate over the
    # number of columns. Iterating over dim 0 would skip every cluster whose
    # index is >= the vector dimension.
    for i in range(updated_quantized_vectors.shape[1]):
        members = assignments == i
        if torch.any(members):
            updated_quantized_vectors[:,i] = W[members].mean(dim=0)

    return updated_quantized_vectors





def vector_quantize(W, H, k, max_iters = 1000,
                    max_init_iters = 10,
                    convergence_threshold = 1e-3):
    """Cluster the rows of ``W`` into ``k`` quantized vectors with random restarts.

    Each restart picks ``k`` random rows of ``W`` as the initial quantized
    vectors and then alternates ``assigment_step`` / ``update_step`` until the
    assignments stop changing or ``max_iters`` is reached. The restart with the
    smallest error reported by ``assigment_step`` is returned.

    Args:
        W (torch.tensor): weights, one vector per row, shape (m, n)
        H (torch.tensor): matrix of shape (n, n) passed to ``assigment_step``
        k (int): number of quantized vectors (fewer are used if ``W`` has
            fewer than ``k`` rows)
        max_iters (int): maximum number of assignment/update iterations per restart
        max_init_iters (int): number of random restarts
        convergence_threshold (float): kept for backward compatibility and
            unused; assignments are integer labels, so convergence is tested
            with exact equality

    Returns:
        torch.tensor: quantized vectors of shape (n, k)
        torch.tensor: assignment of every row of ``W`` to a quantized vector

    Raises:
        ValueError: if ``max_iters`` or ``max_init_iters`` is smaller than 1
    """
    if max_iters < 1 or max_init_iters < 1:
        raise ValueError("max_iters and max_init_iters must both be at least 1")
    assert torch.all(torch.isfinite(H)), f"H is not finite, {H}, {H[~torch.isfinite(H)]}"

    min_error = float('inf')        # float used for comparisons (NaN counts as inf)
    best_error = min_error          # error as reported by assigment_step, for printing
    best_quantized_vectors = None
    best_assignments = None

    # The context manager closes the progress bar even if a step raises.
    with tqdm.tqdm(total=max_init_iters*max_iters) as bar:
        for _restart in range(max_init_iters):
            # Initialize the quantized vectors from k distinct random ROWS of W
            # (rows are what gets clustered, so sample from W.shape[0]).
            indexs = torch.randperm(W.shape[0])[:k]
            quantized_vectors = W[indexs,:].T
            assignments = None
            converged = False
            for step in range(max_iters):
                bar.update(1)
                updated_assignments, error = assigment_step(W, H, quantized_vectors)
                quantized_vectors = update_step(W, quantized_vectors, updated_assignments)
                if assignments is not None and torch.equal(assignments, updated_assignments):
                    print(f"Converged after {step} iterations, error {error}")
                    converged = True
                    # account for the iterations that were skipped
                    bar.update(max_iters - step - 1)
                    break
                assignments = updated_assignments
            if not converged:
                print("warning: did not converge")

            run_error = float(error)
            if run_error != run_error:
                # NaN never compares smaller; rank it like an infinite error
                run_error = float('inf')
            # Always keep the first restart so a result exists even when every
            # error overflowed (e.g. in half precision).
            if best_quantized_vectors is None or run_error < min_error:
                min_error = run_error
                best_error = error
                best_quantized_vectors = quantized_vectors
                best_assignments = assignments
    print("quantized with best error", best_error)
    return best_quantized_vectors, best_assignments

In [None]:
# vector_quantize draws its initial quantized vectors with torch.randperm;
# seed the global RNG so this cell gives the same result on every run.
torch.manual_seed(0)

# Work in half precision.
W = data["weights"].half()
H = data["H"].half()

# H is divided by its size before being used as the error metric; 256 is the
# number of quantized vectors.
quantized_vectors, assignments = vector_quantize(W, H/H.shape[0], 256)

In [7]:
# Replace every row of W by its assigned quantized vector: pick the assigned
# column for each row, then transpose so rows line up with the rows of W.
assigned_columns = quantized_vectors[:, assignments]
quantized_weights = assigned_columns.transpose(0, 1)

In [None]:
# Squared reconstruction error relative to the squared norm of W.
squared_residual = (quantized_weights - W).pow(2).sum()
squared_norm = W.pow(2).sum()
squared_residual / squared_norm

In [None]:
# Display the reconstructed (quantized) weight matrix.
quantized_weights

In [10]:
# Load the second checkpoint (this rebinds `data`) and move the tensors held
# in list-valued entries onto cuda:6. Entries that are plain tensors are
# deliberately left on the device they were loaded to.
data = torch.load("/home/lliu/huffman/test/original_weights2.pt")

device = torch.device("cuda:6")

for value in data.values():
    if isinstance(value, list):
        # replace the elements in place, one by one
        for idx, element in enumerate(value):
            value[idx] = element.to(device)


In [11]:
# Collapse every list of tensors into a single tensor with a new leading axis.
for key, value in data.items():
    if isinstance(value, list):
        data[key] = torch.stack(value, dim=0)

In [12]:
# Flatten the recorded activations to 2-D and match the dtype of W.
# x is multiplied as `x @ quantized_weights.T`, so its last dimension must be
# the number of COLUMNS of W; the outputs have one entry per ROW of W.
# (For a square W both sizes coincide.)
x = data["Input"].reshape(-1, W.shape[1]).to(W.dtype)
y = data["Output"].reshape(-1, W.shape[0]).to(W.dtype)

In [None]:
# Outputs produced by the quantized weights, followed by their mean squared
# error relative to the mean squared value of the reference outputs.
y_hat = torch.matmul(x, quantized_weights.T)
print(y_hat)
output_mse = (y_hat - y).pow(2).mean().item()
reference_power = y.pow(2).mean().item()
print(output_mse/reference_power)

In [14]:
#plot a pca of the weights
import matplotlib.pyplot as plt

from sklearn.decomposition import PCA
#also get a TSNE
from sklearn.manifold import TSNE

# 2-D t-SNE embedding with a fixed seed so the picture is reproducible.
low_dim = TSNE(n_components=2, random_state=0)

# Embed the ROWS of W: `assignments` holds one cluster label per row, so the
# embedded points must be rows for the colours in the scatter plot to match.
# NOTE: this rebinds `x` and `y`, which previously held the input/output
# activations; re-run the cell that defines them before re-evaluating outputs.
x,y = low_dim.fit_transform(W.cpu().numpy()).T

In [15]:
#sort the quantized vectors by cosine similarity
from sklearn.metrics.pairwise import cosine_similarity

# One quantized vector per row, as a NumPy array on the host.
centroids_np = quantized_vectors.T.cpu().numpy()
# Pairwise cosine similarity between all quantized vectors.
cosine_sim = cosine_similarity(centroids_np)

In [16]:
# Order the quantized vectors by ascending cosine similarity to vector 0 and
# map each vector index to its position in that ordering.
similarity_order = torch.argsort(torch.tensor(cosine_sim[:, 0]))
lookup_map = {index: rank for rank, index in enumerate(similarity_order.tolist())}

In [None]:
# Display the mapping from quantized-vector index to its similarity rank.
lookup_map

In [None]:
# Indices of the quantized vectors sorted by ascending similarity to vector 0.
torch.tensor(cosine_sim[:, 0]).argsort()

In [None]:


# Colour every point by the similarity rank of its assigned quantized vector.
# The rank is normalised by the number of quantized vectors (one entry per
# vector in `lookup_map`) instead of a hard-coded 256, so the colour scale
# stays correct if the codebook size changes.
n_codewords = len(lookup_map)
colors = [plt.cm.viridis(lookup_map[i]/n_codewords) for i in assignments.cpu().numpy()]

plt.figure(figsize=(10,10)) 
plt.yscale('linear')
plt.scatter(x, y, c=colors)

In [None]:
#plot out the pca with the colors of the assignments
# `PCA` is imported in an earlier cell but no PCA object is ever created or
# fitted, so fit a 2-component PCA on the rows of W here and project once.
pca = PCA(n_components=2)
W_pca = pca.fit_transform(W.cpu().numpy())
plt.scatter(W_pca[:,0], W_pca[:,1], c=assignments.cpu().numpy())