In [6]:
import os
import time

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import tqdm

In [7]:
class low_rank_and_sparse(nn.Module):
    """Low-rank plus sparse approximation of a weight matrix.

    The weights ``W`` are approximated as

        W_hat = A @ B + M * (W - A @ B)

    where ``A @ B`` has rank ``k`` and ``M`` is a relaxed mask with values in
    [0, 1]. Where ``M`` is 1 the original weight is kept exactly; where it is 0
    the low-rank product is used. ``M`` is parameterised by the logits
    ``self.sparse`` passed through a stretched, clipped sigmoid.
    """

    def __init__(self,
                 k: int,
                 H: torch.Tensor,
                 Weights: torch.Tensor,
                 ):
        """Build the factors and the mask logits.

        Args:
            k (int): rank of the low-rank factorisation ``A @ B``.
            H (torch.Tensor): the Hessian used to weight the reconstruction
                error, shape (cols, cols).
            Weights (torch.Tensor): the weights to approximate, shape
                (rows, cols). Square (n, n) weights are the common case but
                are not required.
        """
        super().__init__()

        # Non-persistent buffers: ``.to()`` / ``.cuda()`` move them together with
        # the parameters, while ``state_dict()`` still holds only A, B, sparse.
        self.register_buffer("H", H, persistent=False)
        self.register_buffer("Weights", Weights, persistent=False)
        rows, cols = Weights.shape
        factory = dict(device=Weights.device, dtype=Weights.dtype)

        self.A = nn.Parameter(torch.empty(rows, k, **factory).normal_(0, 1), requires_grad=True)
        self.B = nn.Parameter(torch.empty(k, cols, **factory).uniform_(-1, 1), requires_grad=True)

        # Mask logits = logit(U(0, 1)), so the initial sigmoid mask is uniform.
        # uniform_ may return exactly 0, whose logit is -inf; clamping by the
        # dtype's eps keeps every initial logit finite.
        x = torch.empty(rows, cols, **factory).uniform_(0, 1)
        self.sparse = nn.Parameter(
            torch.logit(x, eps=torch.finfo(x.dtype).eps), requires_grad=True)

    def binary_penalty(self, sparse, beta):
        """Penalty pushing relaxed mask entries toward 0 or 1.

        Computes ``sum(1 - |2 * a_ij - 1| ** beta)``. Each term is 0 when
        ``a_ij`` is 0 or 1 and 1 when ``a_ij`` is 0.5.

        Args:
            sparse (torch.Tensor): relaxed mask values, expected in [0, 1]
                (i.e. already passed through the stretched sigmoid and clip).
            beta (float): exponent controlling how sharp the penalty is.

        Returns:
            torch.Tensor: the scalar penalty.
        """
        return torch.sum(1 - torch.abs(2 * sparse - 1) ** beta)

    def forward(self, beta, penalty_1_weight, penalty_2_weight):
        """Compute the penalised, Hessian-weighted reconstruction objective.

        Args:
            beta (float): exponent of the binary penalty
                ``1 - |2 * a_ij - 1| ** beta``.
            penalty_1_weight (float): weight of the binary penalty.
            penalty_2_weight (float): weight of the sparsity penalty (the sum
                of the mask values, which encourages few kept entries).

        Returns:
            tuple[torch.Tensor, ...]: scalar tensors
            ``(loss + penalty, loss, penalty, binary_penalty, sparse_penalty)``.
            Here ``loss = sum_i diff[i] @ H @ diff[i]`` with
            ``diff = W_hat - Weights``.
        """
        # Stretched sigmoid with range (-0.1, 1.1), clipped to [0, 1], so mask
        # entries can become exactly 0 or exactly 1.
        sparse_assignments = torch.clip(torch.sigmoid(self.sparse) * 1.2 - 0.1, 0, 1)

        low_rank_product = self.A @ self.B

        # Mask 1 -> keep the original weight; mask 0 -> use the low-rank
        # product. Fractional values interpolate between the two.
        W_hat = low_rank_product + sparse_assignments * (self.Weights - low_rank_product)

        diff = W_hat - self.Weights

        # Hessian-weighted reconstruction error summed over rows.
        loss = torch.einsum('ik,kl,il->', diff, self.H, diff)

        binary_penalty = self.binary_penalty(sparse_assignments, beta)

        # Mask values are non-negative, so their sum is an L1 sparsity penalty.
        sparse_penalty = torch.sum(sparse_assignments)
        penalty = penalty_1_weight * binary_penalty + penalty_2_weight * sparse_penalty
        return loss + penalty, loss, penalty, binary_penalty, sparse_penalty




In [8]:
# Checkpoint holding {'weights': (rows, cols), 'H': (cols, cols)}.
# Override the location with the LRS_DATA_PATH environment variable.
DATA_PATH = os.environ.get("LRS_DATA_PATH", "/home/lliu/huffman/test/original_weights.pt")

# NOTE: torch.load unpickles the file and can execute arbitrary code;
# only load checkpoints from a trusted source.
data = torch.load(DATA_PATH)
missing_keys = {"weights", "H"} - set(data)
assert not missing_keys, f"checkpoint is missing keys: {missing_keys}"

weights = data['weights'].float()
Hessian = data['H'].float()
# The reconstruction loss computes diff @ H @ diff^T row-wise, so H must be (cols, cols).
assert Hessian.shape == (weights.shape[1], weights.shape[1]), "Hessian/weights shape mismatch"

In [13]:

# Seed the RNG so the random initialisation of A, B and the mask logits is reproducible.
SEED = 0
RANK = 128  # rank k of the low-rank factor A @ B
torch.manual_seed(SEED)
quantizer = low_rank_and_sparse(RANK, Hessian, weights)


In [14]:
def sparse_penalty_scheduler(current_penalty, current_weight, multiple_up=1.001, multiple_down=0.999,
                             target_sparsity=0.005, n=Hessian.shape[0]):
    """Multiplicatively adapt the weight of the sparsity penalty.

    The mask density is estimated as ``current_penalty / n**2``. While it is
    above ``target_sparsity`` the weight is scaled by ``multiple_up``;
    otherwise it is scaled by ``multiple_down``.

    Note: the default ``n`` is read from the global ``Hessian`` once, when this
    cell is executed.

    Returns:
        float: the updated penalty weight.
    """
    density = current_penalty / n ** 2
    step = multiple_up if density > target_sparsity else multiple_down
    return current_weight * step

In [None]:
# Adam over every trainable tensor of the model (A, B and the mask logits).
learning_rate = 1e-3
optimizer = torch.optim.Adam(quantizer.parameters(), lr=learning_rate)

def cosine_annealing(i, T, scale):
    """Cosine schedule falling from ``2 * scale`` at ``i = 0`` to 0 at ``i = T``.

    Args:
        i: current step.
        T: total number of steps.
        scale: half of the starting value.

    Returns:
        ``scale * (1 + cos(pi * i / T))``.
    """
    # Use np.pi (not a truncated literal) so the schedule reaches exactly 0 at i = T.
    return scale * (1 + np.cos(np.pi * i / T))

# Train the quantizer.
# beta anneals linearly from beta_init (sharp binary penalty) toward 1.
beta_init = 10
T = 1000
LOG_EVERY = 100  # iterations between progress lines

penalty_1_weight = 0.01
penalty_2_weight = 0.01

losses = []
errors = []
penalties = []
sparse_penalties = []
binary_penalties = []

hessian_dim = Hessian.shape[0]            # normaliser for the logged error
n_mask_entries = quantizer.sparse.numel()  # normaliser for the logged density

for i in tqdm.tqdm(range(T)):
    optimizer.zero_grad()
    beta = (1 - i / T) * (beta_init - 1) + 1
    loss, error, penalty, binary_penalty, sparse_penalty = quantizer(beta, penalty_1_weight, penalty_2_weight)
    loss.backward()
    optimizer.step()

    # Copy each scalar to the host once; .item() synchronises with the device.
    loss_val = loss.item()
    error_val = error.item()
    penalty_val = penalty.item()
    binary_val = binary_penalty.item()
    sparse_val = sparse_penalty.item()

    losses.append(loss_val)
    errors.append(error_val)
    penalties.append(penalty_val)
    sparse_penalties.append(sparse_val)
    binary_penalties.append(binary_val)

    # Nudge the sparsity weight toward the scheduler's target density.
    penalty_2_weight = sparse_penalty_scheduler(sparse_val, penalty_2_weight)

    if i % LOG_EVERY == LOG_EVERY - 1:
        tqdm.tqdm.write(f"Loss: {loss_val}, Error: {error_val/hessian_dim}, Penalty: {penalty_val} Beta: {beta}, \n" +
                        f"sparse penalty: {sparse_val/n_mask_entries}, binary penalty: {binary_val} penalty_2_weight: {penalty_2_weight}")

In [None]:
# Total objective and penalty term over training (log scale).
fig, ax = plt.subplots()
ax.plot(losses, label="total objective")
ax.plot(penalties, label="penalty")
ax.set_yscale("log")
ax.set(xlabel="iteration", ylabel="value", title="Training curves")
ax.legend();


In [None]:
# Skip the initial transient so the tail of the error curve is readable.
# The burn-in is capped at a fraction of the run, so the slice is never
# empty when T is small.
burn_in = min(10_000, len(errors) // 10)
fig, ax = plt.subplots()
ax.plot(range(burn_in, len(errors)), errors[burn_in:])
ax.set_yscale("log")
ax.set(xlabel="iteration", ylabel="reconstruction error", title="Reconstruction error after burn-in");

In [None]:
# Hessian-weighted reconstruction error at the last training step (unnormalised).
final_error = errors[-1]
final_error


In [None]:
# Lowest reconstruction error seen during training, divided by the number of weight rows.
best_error = min(errors)
best_error / weights.shape[0]

In [None]:
# Harden the relaxed mask by rounding the sigmoid (entries > 0.5 -> 1, else 0).
# This is the same decision boundary as the stretched sigmoid used in training,
# since 1.2*s - 0.1 > 0.5 iff s > 0.5. No gradients are needed here, so
# avoid building an autograd graph over the full-size tensors.
with torch.no_grad():
    sparse_assignments = torch.round(torch.sigmoid(quantizer.sparse))
    low_rank_product = quantizer.A @ quantizer.B
    # Mask 1 -> original weight, mask 0 -> low-rank product.
    W_hat = low_rank_product + sparse_assignments * (quantizer.Weights - low_rank_product)
    diff = W_hat - quantizer.Weights

print(W_hat)
print(quantizer.Weights)
print(diff)
# Mean absolute error relative to the mean absolute weight.
print(torch.mean(torch.abs(diff)) / torch.mean(torch.abs(quantizer.Weights)))

In [None]:
fig, ax = plt.subplots()
# aspect='auto' stretches the matrix to fill the axes (pixels need not be square).
im = ax.imshow(sparse_assignments.detach().cpu().numpy(), interpolation='nearest', aspect='auto')
fig.colorbar(im, ax=ax)
ax.set(xlabel="column", ylabel="row", title="Hardened mask (1 = keep original weight)")
plt.show()


In [None]:
# Fraction of entries that keep the original weight (mask density).
# Normalise by the mask's own size: `H` is only defined by a later cell.
sparse_assignments.mean()

In [None]:
fig, ax = plt.subplots()
ax.hist(sparse_assignments.detach().cpu().numpy().flatten(), bins=100)
ax.set(xlabel="mask value", ylabel="count", title="Distribution of hardened mask values");

In [None]:
# low_rank_and_sparse defines no `codebook` attribute; show the shapes of its
# trainable tensors instead.
print({name: tuple(p.shape) for name, p in quantizer.named_parameters()})

In [None]:
# The target matrix the model approximates.
original_weights = quantizer.Weights
print(original_weights)

In [6]:
# Hessian-aware pruning baseline: prune `weights` column by column and
# compensate each pruned column's error on the not-yet-processed columns.
# NOTE(review): this appears to be adapted from SparseGPT's `fasterprune`
# (the commented `self.quantizer` remnant suggests it came from a class
# method) -- confirm against the upstream implementation.
# Side effects: rebinds many globals (W, H, Hinv, mask, Losses, tmp, ...).
# In particular `H` ends up holding the upper Cholesky factor of the damped
# inverse Hessian, not the Hessian itself.
W = weights.clone()
H = Hessian.clone()
# Input columns whose Hessian diagonal is zero.
dead = torch.diag(H) == 0

columns = W.shape[1]
rows = W.shape[0]
# Give dead inputs a unit diagonal and zero their weight columns.
# Presumably this keeps H positive definite for the Cholesky factorisations below.
H[dead, dead] = 1
W[:, dead] = 0
# Pruning configuration.
blocksize = 128  # number of columns processed per block
prunen = 0  # N of N:M semi-structured pruning; 0 selects the unstructured path
prunem = 0  # M of N:M semi-structured pruning
percdamp = 0.01  # diagonal dampening as a fraction of mean(diag(H))
sparsity = 0.5  # fraction of each block to prune (unstructured path)
dev = W.device

# Per-row accumulated pruning loss (filled below; not used elsewhere in this notebook).
Losses = torch.zeros(rows, device=dev)

# Dampen the diagonal, then replace H by an upper Cholesky factor of its inverse:
# cholesky -> cholesky_inverse yields H^{-1}; cholesky(upper=True) yields U
# with U^T U = H^{-1}.
damp = percdamp * torch.mean(torch.diag(H))
diag = torch.arange(columns, device=dev)
H[diag, diag] += damp
H = torch.linalg.cholesky(H)
H = torch.cholesky_inverse(H)
H = torch.linalg.cholesky(H, upper=True)
Hinv = H

# No precomputed mask: it is chosen block by block below.
mask = None

# Process the columns in blocks of `blocksize`.
for i1 in range(0, columns, blocksize):
    i2 = min(i1 + blocksize, columns)
    count = i2 - i1

    W1 = W[:, i1:i2].clone()  # working copy of the block, updated in place
    Q1 = torch.zeros_like(W1)  # pruned values of the block
    Err1 = torch.zeros_like(W1)  # scaled per-column errors, propagated to later blocks
    Losses1 = torch.zeros_like(W1)
    Hinv1 = Hinv[i1:i2, i1:i2]

    if prunen == 0: 
        if mask is not None:
            mask1 = mask[:, i1:i2]
        else:
            # Unstructured: score = w^2 / [Hinv1]_jj^2; prune (mask True) every
            # entry whose score is <= the value at sorted index
            # int(numel * sparsity). Because of `<=`, at least that many + 1
            # entries are pruned.
            tmp = W1 ** 2 / (torch.diag(Hinv1).reshape((1, -1))) ** 2
            thresh = torch.sort(tmp.flatten())[0][int(tmp.numel() * sparsity)]
            mask1 = tmp <= thresh
    else:
        # N:M path: start with nothing pruned; groups are filled in below.
        mask1 = torch.zeros_like(W1) == 1

    for i in range(count):
        w = W1[:, i]
        d = Hinv1[i, i]

        # N:M path: at the start of each group of `prunem` columns, prune the
        # `prunen` lowest-scoring entries of every row within that group.
        if prunen != 0 and i % prunem == 0:
            tmp = W1[:, i:(i + prunem)] ** 2 / (torch.diag(Hinv1)[i:(i + prunem)].reshape((1, -1))) ** 2
            mask1.scatter_(1, i + torch.topk(tmp, prunen, dim=1, largest=False)[1], True)

        # Zero the pruned entries of this column.
        q = w.clone()
        q[mask1[:, i]] = 0

        # Leftover quantisation hook (inactive; `self` and `quantize` are not
        # defined in this notebook).
        # if hasattr(self, 'quantizer'):
        #     q = quantize(
        #         q.unsqueeze(1), self.quantizer.scale, self.quantizer.zero, self.quantizer.maxq
        #     ).flatten()

        Q1[:, i] = q
        Losses1[:, i] = (w - q) ** 2 / d ** 2

        # Spread this column's error over the remaining columns of the block,
        # using row i of the inverse-Hessian factor.
        err1 = (w - q) / d
        W1[:, i:] -= err1.unsqueeze(1).matmul(Hinv1[i, i:].unsqueeze(0))
        Err1[:, i] = err1

    W[:, i1:i2] = Q1
    Losses += torch.sum(Losses1, 1) / 2

    # Propagate the block's accumulated error to all later columns.
    W[:, i2:] -= Err1.matmul(Hinv[i1:i2, i2:])

In [None]:
# Evaluate the pruned weights W with the same metrics as the low-rank + sparse model.
sparsegpt_diff = weights - W
# Mean absolute error relative to the mean absolute weight.
print(torch.mean(torch.abs(sparsegpt_diff)) / torch.mean(torch.abs(weights)))
# Hessian-weighted error, normalised by the Hessian size from the data.
# (`H` was overwritten by its inverse-Cholesky factor in the pruning cell.)
torch.einsum('ik,kl,il->', sparsegpt_diff, Hessian, sparsegpt_diff) / Hessian.shape[0]