In [55]:
import torch
import torch.nn as nn
import numpy as np 
import random

import tqdm
import matplotlib.pyplot as plt

In [2]:
device = torch.device("cuda:5")

H = torch.load("test/original_weights.pt")["H"].to(device).float()
weights = torch.load("test/original_weights.pt")["weights"].to(device).float()    

In [3]:
damping = 1e-3

H = H + damping * torch.eye(H.shape[0]).to(device)

In [4]:
H_chol = torch.linalg.cholesky(H)
H_inv = torch.cholesky_inverse(H_chol)

In [5]:
d = 4
n_blocks = int(np.ceil(H.shape[0] / d))
weights_reshaped = weights.reshape((weights.shape[0], n_blocks, d))

block_inverses = []
for i in range(int(n_blocks)):
    start = i * d
    end = (i + 1) * d
    block = H[start:end, start:end]
    block_chol = torch.linalg.cholesky(block)
    block_inv = torch.cholesky_inverse(block_chol)
    block_inverses.append(block_inv)

block_inverses = torch.stack(block_inverses)
print(block_inverses.shape)

torch.Size([1024, 4, 4])


In [6]:
(torch.empty((4,3,5)) @ torch.empty((4,5,2))).shape

torch.Size([4, 3, 2])

In [7]:
removal_errors = weights_reshaped.unsqueeze(-2) @ block_inverses.unsqueeze(0) @ weights_reshaped.unsqueeze(-1)

removal_errors = removal_errors.squeeze(-1).squeeze(-1)
print(removal_errors.shape)

torch.Size([4096, 1024])


In [56]:
frac_to_remove = 0.125
structure_size = 8
n_to_remove = int(frac_to_remove * structure_size)

sparse_mask = torch.ones_like(removal_errors).reshape(removal_errors.shape[0], -1, structure_size)

sparse_mask = sparse_mask.scatter_(-1, removal_errors.reshape(removal_errors.shape[0], -1, structure_size).argsort(-1)[:, :,:n_to_remove], 0)

sparse_mask = sparse_mask.reshape(removal_errors.shape)

In [57]:
sparse_mask = sparse_mask.repeat_interleave(d, dim=-1)
sparse_mask = sparse_mask.bool()

In [58]:
sparse_mask[0,:23]

tensor([ True,  True,  True,  True,  True,  True,  True,  True, False, False,
        False, False,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True], device='cuda:5')

In [59]:
sparse_mask.sum(1)

tensor([3584, 3584, 3584,  ..., 3584, 3584, 3584], device='cuda:5')

In [60]:
changes_non_masked = torch.zeros(sparse_mask.shape[0] * (sparse_mask.shape[-1] - torch.sum(sparse_mask[0]).item())).to(device).requires_grad_(True)





lr = 1e-1
lr_multiple = 0.9
prev_loss = 1e9
n_steps = 1000

for i in range(n_steps):
    diff = torch.zeros_like(weights)
    diff[~sparse_mask] = changes_non_masked
    diff[sparse_mask] = -weights[sparse_mask]
    # print(diff)
    loss = torch.einsum('ik,kl,il->', diff, H/H.shape[0], diff)
    print(loss.item())
    loss.backward()
    # print(changes_non_masked.grad)
    if loss.item() > prev_loss:
        print("Decreasing learning rate")
        lr *= lr_multiple
    prev_loss = loss.item()
    with torch.no_grad():
        changes_non_masked.grad = torch.clamp(changes_non_masked.grad, -1, 1)
        changes_non_masked -= lr * changes_non_masked.grad
        changes_non_masked.grad.zero_()



2208.296875
1870.84228515625
1646.024169921875
1489.6251220703125
1376.3428955078125
1292.00732421875
1228.068603515625
1179.02734375
1140.774658203125
1110.45556640625
1086.05712890625
1066.09765625
1049.5567626953125
1035.72314453125
1024.0380859375
1014.0767822265625
1005.5115356445312
998.0770263671875
991.5503540039062
985.755615234375
980.5457763671875
975.8153686523438
971.4874877929688
967.5029296875
963.8155517578125
960.38818359375
957.19775390625
954.218505859375
951.423583984375
948.7957763671875
946.3216552734375
943.9798583984375
941.7530517578125
939.62744140625
937.5919189453125
935.6373291015625
933.7556762695312
931.9404907226562
930.1863403320312
928.48828125
926.8419189453125
925.2438354492188
923.6904296875
922.178955078125
920.706787109375
919.271484375
917.8709106445312
916.5030517578125
915.166259765625
913.8587646484375
912.579345703125
911.326171875
910.098388671875
908.8946533203125
907.7137451171875
906.554931640625
905.4171142578125
904.29931640625
903.2008

KeyboardInterrupt: 

In [61]:
delta_weights = torch.zeros_like(weights)

for i in tqdm.tqdm(range(weights.shape[0])):
    # print((torch.cholesky_inverse(torch.cholesky(H[sparse_mask[i],:][:,sparse_mask[i]])) @ weights[i,sparse_mask[i]]).shape)
    delta_weights[i] = - H_inv[:,sparse_mask[i]] @ torch.cholesky_inverse(torch.linalg.cholesky(H_inv[sparse_mask[i],:][:,sparse_mask[i]])) @ weights[i,sparse_mask[i]]

100%|██████████| 4096/4096 [01:30<00:00, 45.09it/s]


In [30]:
torch.max(torch.abs(weights_new[sparse_mask]))

tensor(0.0049, device='cuda:5')

In [54]:
weights_new = weights + delta_weights
weights_new[sparse_mask] = 0
weights_new

diff = weights - weights_new

average_error = torch.sum(torch.abs(diff)**1)/torch.sum(torch.abs(weights)**1)

H_error = torch.einsum('ik,kl,il->', diff, H/H.shape[0], diff)

free, total = torch.cuda.mem_get_info(device)
print(f"free = {free/1024/1024}, total = {total/1024/1024}")

print(f"average error {average_error}, H error {H_error}")

free = 46991.0, total = 48676.75
average error 2.3888800144195557, H error 16.494220733642578


In [33]:
weights_new

tensor([[ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
        [-0.0235, -0.0183, -0.1044,  ...,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  ...,  0.0031,  0.0086, -0.0074],
        ...,
        [-0.0741, -0.0847,  0.0871,  ...,  0.0643,  0.0177,  0.0345],
        [-0.0340,  0.0067,  0.0181,  ..., -0.0323, -0.0030, -0.0301],
        [ 0.0604, -0.0908,  0.1693,  ...,  0.0000,  0.0000,  0.0000]],
       device='cuda:5')