In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import time

In [2]:
#time the time it takes to do a multiplication with sparse tensors

device = torch.device("cuda:0")
dtype = torch.float16
n = 1000
x = torch.randn(100, n, n).to(device).to(dtype)



In [24]:
class Quantizer3(nn.Module):
    def __init__(self, weights, 
                 mask = None,
                 scale_init = None):
        super(Quantizer3, self).__init__()
        self.weights = weights

        if mask is None:
            self.mask = torch.ones_like(weights, dtype=torch.bool)
        else:
            self.mask = mask

        #if we do not have a scale we need to find one
        #default will be 1/3*range of weights
        if scale_init is None:
            self.scale = 1/3*(weights.max()-weights.min())
        else:
            self.scale = scale_init
        print(self.scale)
        
        #set scale to be a parameter
        self.scale_activation = F.softplus
        self.scale = nn.Parameter(
            torch.log(torch.exp(self.scale) - 1)
        )

    def forward(self, x):

        # scale = self.scale_activation(self.scale)
        scale = self.scale
        if torch.all(self.mask):
            quantized_weights = torch.clip(torch.round(self.weights/scale), min=-1, max=1)
            return x @ self.weights, quantized_weights
        quantized_weights = self.weights.clone()
        quantized = torch.clip(torch.round(self.weights[self.mask]/scale), min=-1, max=1)
        assert ~torch.isnan(quantized).any()
        quantized_weights[self.mask] = scale * quantized
        return x.reshape(-1, quantized_weights.shape[1]) @ quantized_weights, quantized

In [25]:
weights = torch.randn(1000,1000).to(device).to(dtype)
mask = torch.randn(1000,1000).to(device).to(dtype) > 0

quantizer = Quantizer3(weights, mask,torch.tensor(1))
quantizer.to(device).to(dtype)


optimizer = torch.optim.SGD(quantizer.parameters(), lr=0.0001)


tensor(1)


In [26]:
y,_ = quantizer(x)
y_actual = x.reshape(-1, weights.shape[1]) @ weights

losses = []
for i in range(500):
    y,q = quantizer(x)
    optimizer.zero_grad()
    loss = F.mse_loss(y, y_actual)
    loss.backward()
    #print the gradients
    # print(quantizer.scale.grad)
    # print(quantizer.scale_activation(quantizer.scale).item())
    optimizer.step()
    print(quantizer.scale.grad)
    # print(quantizer.scale_activation(quantizer.scale).item())
    losses.append(loss.item())
    # print(quantizer.scale_activation(quantizer.scale))
    # print(loss.item(), quantizer.scale_activation(quantizer.scale).item(), loss,torch.any(torch.isnan(q)), torch.any(torch.isnan(y)))
    # print()
    # raise Exception("stop")

tensor(0., device='cuda:0', dtype=torch.float16)
tensor(0., device='cuda:0', dtype=torch.float16)
tensor(0., device='cuda:0', dtype=torch.float16)
tensor(0., device='cuda:0', dtype=torch.float16)
tensor(0., device='cuda:0', dtype=torch.float16)
tensor(0., device='cuda:0', dtype=torch.float16)
tensor(0., device='cuda:0', dtype=torch.float16)
tensor(0., device='cuda:0', dtype=torch.float16)
tensor(0., device='cuda:0', dtype=torch.float16)
tensor(0., device='cuda:0', dtype=torch.float16)
tensor(0., device='cuda:0', dtype=torch.float16)
tensor(0., device='cuda:0', dtype=torch.float16)
tensor(0., device='cuda:0', dtype=torch.float16)
tensor(0., device='cuda:0', dtype=torch.float16)
tensor(0., device='cuda:0', dtype=torch.float16)
tensor(0., device='cuda:0', dtype=torch.float16)
tensor(0., device='cuda:0', dtype=torch.float16)
tensor(0., device='cuda:0', dtype=torch.float16)
tensor(0., device='cuda:0', dtype=torch.float16)
tensor(0., device='cuda:0', dtype=torch.float16)
tensor(0., device='c

In [18]:
losses

[340.0,
 340.0,
 340.0,
 340.0,
 340.0,
 340.0,
 340.0,
 340.0,
 340.0,
 340.0,
 340.0,
 340.0,
 340.0,
 340.0,
 340.0,
 340.0,
 340.0,
 340.0,
 340.0,
 340.0,
 340.0,
 340.0,
 340.0,
 340.0,
 340.0,
 340.0,
 340.0,
 340.0,
 340.0,
 340.0,
 340.0,
 340.0,
 340.0,
 340.0,
 340.0,
 340.0,
 340.0,
 340.0,
 340.0,
 340.0,
 340.0,
 340.0,
 340.0,
 340.0,
 340.0,
 340.0,
 340.0,
 340.0,
 340.0,
 340.0,
 340.0,
 340.0,
 340.0,
 340.0,
 340.0,
 340.0,
 340.0,
 340.0,
 340.0,
 340.0,
 340.0,
 340.0,
 340.0,
 340.0,
 340.0,
 340.0,
 340.0,
 340.0,
 340.0,
 340.0,
 340.0,
 340.0,
 340.0,
 340.0,
 340.0,
 340.0,
 340.0,
 340.0,
 340.0,
 340.0,
 340.0,
 340.0,
 340.0,
 340.0,
 340.0,
 340.0,
 340.0,
 340.0,
 340.0,
 340.0,
 340.0,
 340.0,
 340.0,
 340.0,
 340.0,
 340.0,
 340.0,
 340.0,
 340.0,
 340.0,
 340.0,
 340.0,
 340.0,
 340.0,
 340.0,
 340.0,
 340.0,
 340.0,
 340.0,
 340.0,
 340.0,
 340.0,
 340.0,
 340.0,
 340.0,
 340.0,
 340.0,
 340.0,
 340.0,
 340.0,
 340.0,
 340.0,
 340.0,
 340.0,
 340.0,


In [None]:
_,q = quantizer(x)

In [None]:
torch.unique(q, return_counts=True)

In [None]:
quantizer.scale

In [None]:
import matplotlib.pyplot as plt
plt.plot(losses)

In [None]:
import numba.cuda
import numba as nb

threadsperblock = 32
blockspergrid = (an_array.size + (threadsperblock - 1)) // threadsperblock
increment_by_one[blockspergrid, threadsperblock](an_array)

@numba.cuda.jit(device=True)
def binary_sparse_matrix_multiplication(x, row_idxs):
    #x is a 3d tensor with shape (batch, n, n)
    #row idxs is a list of list of booleans, representing the column indicies that are 1 in the sparse matrix
    #returns a 3d tensor with shape (batch, n, n)
    #both are boolean tensors

    y = x.clone()
    # return x[:,:,row_idxs].sum(dim=3)

    # raise Exception("stop")
    for i, row_idx in enumerate(row_idxs):
        # for j in range(len(row_idx)):
        y[:,:,i] += x[:,:,row_idx].sum(dim=2) 
    return y


In [None]:
times_sparse = []
times_naive = []
sparsities = torch.logspace(-4, -1, 10)
with torch.no_grad():
    for sparsity in sparsities:
        w = torch.zeros_like(x[0])
        row,col = torch.randint(0,n, (2, int(n**2*sparsity)))
        w[row, col] = 1


        row_idxs = []
        for i in range(n):
            row_idxs.append(w[i]==1)
        print("here")
        print(row_idxs[0].shape)
        # raise Exception("stop")
        start = time.time()
        y_sparse = binary_sparse_matrix_multiplication(x, row_idxs)
        times_sparse.append(time.time()-start)

        
        start = time.time()
        y_naive = F.linear(x, w)
        times_naive.append(time.time()-start)

        # assert torch.allclose(y_sparse, y_naive, atol=1e-2)

        print(f"sparsity: {sparsity}, sparse: {times_sparse[-1]}, naive: {times_naive[-1]}")

In [None]:
times_sparse = []
times_naive = []
sparsities = torch.logspace(-4, -1, 10)
with torch.no_grad():
    for sparsity in sparsities:
        w = torch.zeros_like(x[0])
        row,col = torch.randint(0,n, (2, int(n**2*sparsity)))
        w[row, col] = 1


        row_idxs = []
        for i in range(n):
            row_idxs.append(torch.where(w[i]==1)[0].cpu())
        print("here")
        # print(row_idxs)
        # raise Exception("stop")
        start = time.time()
        y_sparse = binary_sparse_matrix_multiplication(x, row_idxs)
        times_sparse.append(time.time()-start)

        
        start = time.time()
        y_naive = F.linear(x, w)
        times_naive.append(time.time()-start)

        # assert torch.allclose(y_sparse, y_naive, atol=1e-2)

        print(f"sparsity: {sparsity}, sparse: {times_sparse[-1]}, naive: {times_naive[-1]}")



In [None]:
import matplotlib.pyplot as plt

plt.plot(sparsities,times_sparse, label="sparse")
plt.plot(sparsities,times_naive, label="naive")
plt.xscale("log")
plt.ylim(0,0.2)

In [None]:
import matplotlib.pyplot as plt

plt.plot(sparsities,times_sparse, label="sparse")
plt.plot(sparsities,times_naive, label="naive")
plt.xscale("log")
# plt.yscale("log")

In [None]:
y_sparse

In [None]:
i,j,k = torch.where(~torch.isclose(y_sparse, y_naive, atol=1e-3))

In [None]:
for l in range(len(i)):
    print(y_naive[i[l],j[l],k[l]].item(), y_sparse[i[l],j[l],k[l]].item())
    # print(x[i[l],:,k[l]], w[j[l],:])
    print(torch.sum(x[i[l],j[l],:]@w[k[l],:]).item())

In [None]:
torch.allclose((x.reshape(-1,n) @ w.T).reshape(-1,n,n), y_naive, atol=1e-3)

In [None]:
torch.max(torch.abs(y_sparse[4]-y_naive[4]))