In [1]:
import torch
import torch.nn as nn
import numpy as np 
import random

import tqdm
import matplotlib.pyplot as plt

In [2]:
device = torch.device("cuda:5")

H = torch.load("test/original_weights.pt")["H"].to(device).float()
weights = torch.load("test/original_weights.pt")["weights"].to(device).float()   

In [35]:
torch.random.manual_seed(0)
np.random.seed(0)
torch.cuda.manual_seed(0)
random.seed(0)
def create_mask(data,percent_top):
    """
    data: torch.tensor of shape (n)
    percent_top: float, the percentage of the top values to keep
    """

    threshold = torch.quantile(data, 1-percent_top/100)
    return data < threshold


d = 64
percent_dense_rowise = 2
percent_dense_columnwise = 3


row_mask = create_mask(torch.norm(weights, dim = 1), percent_dense_rowise)
column_mask = create_mask(torch.norm(weights, dim = 0), percent_dense_columnwise) & create_mask(torch.norm(H, dim = 0), percent_dense_columnwise)

print("row_mask.sum() = ", row_mask.sum())
print("column_mask.sum() = ", column_mask.sum())


def mask_round(mask, d):

    while mask.sum() % d != 0:
        mask[torch.randint(0, mask.shape[0], (1,))] = False

    return mask

row_mask = mask_round(row_mask, d)
column_mask = mask_round(column_mask, d)

mask = row_mask.unsqueeze(1) & column_mask.unsqueeze(0)

print("row_mask.sum() = ", row_mask.sum())
print("column_mask.sum() = ", column_mask.sum())

row_mask.sum() =  tensor(4014, device='cuda:5')
column_mask.sum() =  tensor(3957, device='cuda:5')
row_mask.sum() =  tensor(3968, device='cuda:5')
column_mask.sum() =  tensor(3904, device='cuda:5')


In [37]:
torch.random.manual_seed(0)
np.random.seed(0)
torch.cuda.manual_seed(0)
random.seed(0)
import tqdm
import sklearn.cluster

n_iters = 25
d = 64
k_low_rank = 3
keep_top = 0.01

# print("average bits:", (np.log2(k_magnitude_codebook) + np.log2(k_coseine_codebook))/d)

def get_bytes(bits):
    #return as MB
    return bits/8/1024/1024


overhead = k_low_rank * 16 * d

encoding_bits = (16 * k_low_rank)/d * torch.sum(mask).item()

sparse_bits = 16 * torch.sum(~mask).item()
print("sparse bits:", sparse_bits/(weights.shape[0] * weights.shape[1]),
        "encoding bits:", encoding_bits/(weights.shape[0] * weights.shape[1]),
        "overhead:", overhead/(weights.shape[0] * weights.shape[1]))
print("bits per value:", (sparse_bits + encoding_bits + overhead)/(weights.shape[0] * weights.shape[1]))
# raise ValueError


weights_masked = weights[row_mask,:][:,column_mask]
print("weights_masked.shape = ", weights_masked.shape)
subvector_assignments = torch.arange(weights_masked.shape[1]).reshape((-1, d))

weights_reshaped = weights_masked[:,subvector_assignments] 

sparse bits: 1.2265625 encoding bits: 0.6925048828125 overhead: 0.00018310546875
bits per value: 1.91925048828125
weights_masked.shape =  torch.Size([3968, 3904])


In [38]:
weights_reshaped.shape

torch.Size([3968, 61, 64])

In [39]:
class low_rank_quantizers(nn.Module):
    def __init__(self, d1,d2, k, d3): 
        super(low_rank_quantizers, self).__init__()

        self.weights = nn.parameter.Parameter(torch.randn(d1 * d2, k))
        self.codebook = nn.parameter.Parameter(torch.rand(k, d3))

        self.d1, self.d2, self.k, self.d3 = d1, d2, k, d3

    def forward(self):

        return (self.weights @ self.codebook).reshape(self.d1, self.d2, self.d3)

In [40]:
model = low_rank_quantizers(weights_reshaped.shape[0], weights_reshaped.shape[1], k_low_rank, d)
model.to(device)

optimizer = torch.optim.Adam(model.parameters(), lr = 1e-1)

lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size = 1, gamma = 0.9)
free, total = torch.cuda.mem_get_info(device)
print(f"free = {free/1024/1024}, total = {total/1024/1024}")

# raise ValueError

n_iters = 10000
lr = 1e-3
clamp_gradients = 1e-1
prev_loss = 1e10

lambda_1 = 1
lambda_2 = 1000

prev_H_error = 1e10

for i in range(n_iters):
    weights_reconstructed = torch.empty_like(weights_masked)
    
    tmp = model()
    weights_reconstructed[:,subvector_assignments] = tmp



    weights_quantized = torch.empty_like(weights)

    weights_quantized[mask] = weights_reconstructed.flatten()
    weights_quantized[~mask] = weights[~mask]

    diff = weights - weights_quantized
    average_error = torch.sum(torch.abs(diff)**1)/torch.sum(torch.abs(weights)**1)

    H_error = torch.einsum('ik,kl,il->', diff, H/H.shape[0], diff)


    if i % (n_iters//10) == 0:
        print(f"average error {average_error}, H error {H_error}")
    # print(f"average error {average_error}, H error {H_error}")

    if H_error > prev_H_error:
        print("H error increased")
        lr_scheduler.step()
        print("lr = ", lr_scheduler.get_last_lr())

    prev_H_error = H_error
    
    optimizer.zero_grad()
    H_error.backward()
    optimizer.step()

free = 46849.0, total = 48676.75
average error 34.04465866088867, H error 8654.181640625
H error increased
lr =  [0.09000000000000001]
H error increased
lr =  [0.08100000000000002]
H error increased
lr =  [0.07290000000000002]
H error increased
lr =  [0.06561000000000002]
H error increased
lr =  [0.05904900000000002]
H error increased
lr =  [0.05314410000000002]
average error 1.4173794984817505, H error 23.564693450927734
average error 1.407180666923523, H error 20.67806053161621
H error increased
lr =  [0.04782969000000002]
H error increased
lr =  [0.043046721000000024]
average error 1.415610432624817, H error 20.336143493652344
H error increased
lr =  [0.03874204890000002]
H error increased
lr =  [0.03486784401000002]
H error increased
lr =  [0.03138105960900001]
H error increased
lr =  [0.028242953648100012]
H error increased
lr =  [0.025418658283290013]
H error increased
lr =  [0.022876792454961013]
H error increased
lr =  [0.020589113209464913]
average error 1.421230673789978, H e

In [41]:
H_error

tensor(20.1729, device='cuda:5', grad_fn=<ViewBackward0>)

In [42]:
weights_quantized

tensor([[-0.0460,  0.0039, -0.0506,  ..., -0.0152, -0.0212, -0.0279],
        [-0.0129,  0.0014, -0.0127,  ..., -0.0024, -0.0125, -0.0106],
        [ 0.0112,  0.0043,  0.0134,  ...,  0.0100,  0.0020,  0.0076],
        ...,
        [ 0.0147, -0.0099, -0.0214,  ..., -0.0049,  0.0427,  0.0095],
        [-0.0204, -0.0035, -0.0539,  ..., -0.0182,  0.0262,  0.0080],
        [ 0.0287, -0.0051,  0.0744,  ...,  0.0021, -0.0246, -0.0282]],
       device='cuda:5', grad_fn=<IndexPutBackward0>)

In [43]:
weights

tensor([[-0.0096, -0.0301,  0.0085,  ...,  0.0178, -0.0052, -0.0365],
        [-0.0029, -0.0101,  0.0100,  ...,  0.0147,  0.0040, -0.0104],
        [-0.0004,  0.0139, -0.0074,  ..., -0.0083, -0.0070,  0.0146],
        ...,
        [-0.0107, -0.0061,  0.0310,  ..., -0.0052, -0.0143,  0.0236],
        [-0.0104, -0.0213, -0.0129,  ..., -0.0199, -0.0143, -0.0103],
        [ 0.0184,  0.0119,  0.0195,  ...,  0.0343, -0.0327, -0.0355]],
       device='cuda:5')