<a href="https://colab.research.google.com/github/sinamps/tensor-networks/blob/main/sina_monarch_v2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

This is version 1. I am going to try the full matrix svd, full matrix svd with permutation, and block-wise svd, and block-wise svd with permutation, and finally monarch code for block-wise svd with permutation.

In [1]:
!pip install transformers
!pip install pytorch_pretrained_bert --upgrade
!pip install einops

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [2]:
from transformers import AutoConfig, AutoTokenizer, BertForPreTraining
# from transformers import cached_path, WEIGHTS_NAME, TF2_WEIGHTS_NAME, TF_WEIGHTS_NAME 
# from transformers.file_utils import is_remote_url, hf_bucket_url
import torch, os, sys
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR
import locale
locale.getpreferredencoding = lambda: "UTF-8"
from einops import rearrange
import math
from torch.nn import functional as F

In [3]:
device = torch.device("cuda:0")
model = BertForPreTraining.from_pretrained("bert-large-uncased")
state_dict = model.state_dict()
key_to_load = 'bert.encoder.layer.0.attention.self.query.weight'
fixed_matrix = state_dict[key_to_load].to(device)
fixed_matrix.shape

torch.Size([1024, 1024])

In [4]:
def reconstruct_from_blocks(blocks):
    
    # blocks is a list of list of blocks
    
    block_size_x = blocks[0][0].shape[0]
    block_size_y = blocks[0][0].shape[1]
    
    x_size = len(blocks)*block_size_x
    y_size = len(blocks[0])*block_size_y
    
    new_matrix = torch.zeros((x_size, y_size))
    
    for i in range(len(blocks)):
        x_index = i*block_size_x
        for j in range(len(blocks[0])):
            y_index = j*block_size_y
            new_matrix[x_index:x_index+block_size_x,y_index:y_index + block_size_y] = blocks[i][j]
    
    return new_matrix


def factors(n):
    return [(i, n // i) for i in range(1, math.floor(math.sqrt(n)) + 1) if n % i == 0]


def low_rank_project(M, rank):
    """Supports batches of matrices as well.
    """
    U, S, Vt = torch.linalg.svd(M)
    S_sqrt = S[..., :rank].sqrt()
    U = U[..., :rank] * rearrange(S_sqrt, '... rank -> ... 1 rank')
    Vt = rearrange(S_sqrt, '... rank -> ... rank 1') * Vt[..., :rank, :]
    return U, Vt


def calculate_all_norms(matrix_1, matrix_2):
    
#     frob = torch.norm(matrix_1.to(matrix_2.device) - matrix_2, p='fro')
    difference_matrix = matrix_1.to(matrix_2.device) - matrix_2
    
    frob = torch.linalg.matrix_norm(difference_matrix, ord='fro')
    
    nuc = torch.linalg.matrix_norm(difference_matrix, ord='nuc')
    
    spectral = torch.linalg.matrix_norm(difference_matrix, ord=2)
    
    norms = {
        'fro': frob,
        'nuc': nuc,
        'spectral': spectral,        
    }
    return norms

In [5]:
def get_svd(A, r):
  
    """
    Finds the two low-rank matrices of rank r from the matrix A.

    Args:
    A: The input matrix of size m*n.
    r: The rank of the low-rank matrices.

    Returns:
    U: The first low-rank matrix of size m*r.
    V: The second low-rank matrix of size r*n.
    """

    # Get the singular value decomposition of A.
    U, S, V = torch.linalg.svd(A)

    # Keep the first r singular values and their corresponding singular vectors.
    U = U[:, :r]
    S = S[:r]
    V = V[:r, :]

    # Multiply U with S.
    U = U @ S.diag()
    
    reconstructed = torch.matmul(U, V)
#     print(U.shape, V.shape, reconstructed.shape)
    # Return the two low-rank matrices.
    return U, V, reconstructed

In [6]:
def get_new_svd(A, r):
  
    """
    Finds the two low-rank matrices of rank r from the matrix A.

    Args:
    A: The input matrix of size m*n.
    r: The rank of the low-rank matrices.

    Returns:
    U: The first low-rank matrix of size m*r.
    V: The second low-rank matrix of size r*n.
    """
    '''
    # Get the singular value decomposition of A.
    U, S, V = torch.linalg.svd(A)

    # Keep the first r singular values and their corresponding singular vectors.
    U = U[:, :r]
    S = S[:r]
    V = V[:r, :]

    # Multiply U with S.
    U = U @ S.diag()
    
    U = U.cdouble()
    V = V.cdouble()
    '''
    U, V = low_rank_project(A, rank=r)
    # Vp = rearrange(V, 'k r 1 s -> r k s')
    Vp = rearrange(V, 'i j -> j i')
    # Up = rearrange(U, 'k r s 1 -> k s r')
    Up = rearrange(U, 'i j -> j i')
    reconstructed = torch.matmul(Vp, Up)
#     print(U.shape, V.shape, reconstructed.shape)
    # Return the two low-rank matrices.
    return Vp, Up, reconstructed

In [7]:
def blockdiag_butterfly_project(M, sizes=None):
    """Only works for square matrices for now
    """
    m, n = M.shape
    if m != n:
        raise NotImplementedError('Only support square matrices')
    if sizes is None:
        # Find the factors that are closest to sqrt(n)
        sizes = factors(n)[-1]
        # Larger factor first is probably more efficient, idk
        sizes = (sizes[1], sizes[0])
    assert n == sizes[0] * sizes[1]
    M_permuted_batched = rearrange(M, '(p k) (r s) -> k r p s', k=sizes[1], r=sizes[0])
    U, Vt = low_rank_project(M_permuted_batched, rank=1)
    w1_bfly = rearrange(Vt, 'k r 1 s -> r k s')
    w2_bfly = rearrange(U, 'k r s 1 -> k s r')
    return w1_bfly, w2_bfly

In [8]:
class BlockdiagButterflyMultiply(torch.autograd.Function):

    """This is a faster implementation, with careful memory copies for the fastest
    bmm performance.
    The backward pass is also written manually with careful memory copies.
    Arguments:
        x: (batch, n)
        w1_bfly: (k, q, p), where k = n / p
        w2_bfly: (l, s, r), where l = k * q / r = n * q / (p * r)
    Outputs:
        out: (batch, m), where m = l * s = n * s * q / (p * r)
    """

    @staticmethod
    @torch.cuda.amp.custom_fwd(cast_inputs=torch.float16)
    def forward(ctx, x, w1_bfly, w2_bfly):
        batch_shape, n = x.shape[:-1], x.shape[-1]
        batch_dim = np.prod(batch_shape)
        k, q, p = w1_bfly.shape
        l, s, r = w2_bfly.shape
        assert k * p == n
        assert l * r == k * q
        x_reshaped = x.reshape(batch_dim, k, p).transpose(0, 1)
        out1 = torch.empty(batch_dim, k, q, device=x.device, dtype=x.dtype).transpose(0, 1)
        out1 = torch.bmm(x_reshaped, w1_bfly.transpose(-1, -2), out=out1)
        out1 = out1.transpose(0, 1).reshape(batch_dim, r, l).transpose(-1, -2).contiguous().transpose(0, 1)
        out2 = torch.empty(batch_dim, l, s, device=x.device, dtype=x.dtype).transpose(0, 1)
        out2 = torch.bmm(out1, w2_bfly.transpose(-1, -2), out=out2)
        out2 = out2.permute(1, 2, 0).reshape(*batch_shape, s * l)
        ctx.save_for_backward(x, w1_bfly, w2_bfly, out1)
        return out2

    @staticmethod
    @torch.cuda.amp.custom_bwd
    def backward(ctx, dout):
        x, w1_bfly, w2_bfly, out1 = ctx.saved_tensors
        batch_shape, n = x.shape[:-1], x.shape[-1]
        batch_dim = np.prod(batch_shape)
        k, q, p = w1_bfly.shape
        l, s, r = w2_bfly.shape
        # assert k * p == n
        # assert l * r == k * q
        dx, dw1_bfly, dw2_bfly = None, None, None
        # dout_reshaped = dout.reshape(batch_dim, sqrtn, sqrtn).permute(2, 1, 0).contiguous()
        dout_reshaped = dout.reshape(batch_dim, s, l).transpose(-1, -2).contiguous()
        dout_reshaped = dout_reshaped.transpose(0, 1)
        if ctx.needs_input_grad[2]:
            # dw2_bfly = torch.empty(l, s, r, device=w2_bfly.device, dtype=w2_bfly.dtype)
            # dw2_bfly = torch.bmm(dout_reshaped.transpose(-1, -2), out1, out=dw2_bfly)
            dw2_bfly = torch.bmm(dout_reshaped.transpose(-1, -2), out1.conj())
        if ctx.needs_input_grad[1] or ctx.needs_input_grad[0]:
            dout1 = torch.empty(batch_dim, l, r, device=x.device, dtype=x.dtype).transpose(0, 1)
            dout1 = torch.bmm(dout_reshaped, w2_bfly.conj(), out=dout1)
            dout1 = dout1.transpose(0, 1).transpose(-1, -2).contiguous().reshape(batch_dim, k, q).transpose(0, 1)
            # dout1 = dout1.permute(1, 2, 0).contiguous().transpose(0, 1)
            if ctx.needs_input_grad[0]:
                dx = torch.empty(batch_dim, k, p, device=x.device, dtype=x.dtype)
                dx = torch.bmm(dout1, w1_bfly.conj(), out=dx.transpose(0, 1)).transpose(0, 1).reshape(*batch_shape, n)
            if ctx.needs_input_grad[1]:
                x_reshaped = x.reshape(batch_dim, k, p).transpose(0, 1)
                dw1_bfly = torch.bmm(dout1.transpose(-1, -2), x_reshaped.conj())
        return dx, dw1_bfly, dw2_bfly

blockdiag_butterfly_multiply = BlockdiagButterflyMultiply.apply

This is the code based on monarch's code to get L and R from monarch decomposition.

In [9]:
def get_monarch(fixed_matrix):
    # set seed
    # torch.random.manual_seed(0)
    n = fixed_matrix.shape[1]
    log_n = torch.log2(torch.tensor(n))
    # n = 1 << log_n
    log_n = int(log_n)
    sqrtn = 1 << (log_n // 2)
    batch_size = 1
    eye = torch.eye(n, device=device)
    # myweights = fixed_matrix.cfloat()
    # transform = torch.fft.fft if direction == 'fft' else torch.fft.ifft
    # dft = transform(eye, norm='ortho').t()
    # dft = transform(myweights, norm='ortho').t()
    # perm = bitreversal_permutation(n)
    # We don't actually need the bitreversal permutation, any permutation that swap
    # the axes of the sqrtn x sqrtn input will work.
    perm = rearrange(torch.arange(n, device=device), '(i j) -> (j i)', i=sqrtn)
    # The BP (butterfly - permutation) decomposition of FFT / iFFT
    # Converting to complex128 makes the approximation an order of magnitude more accurate
    w1_fft_projected, w2_fft_projected = blockdiag_butterfly_project(fixed_matrix[:, perm])
    # w1_fft_projected, w2_fft_projected = w1_fft_projected.cfloat(), w2_fft_projected.cfloat()
    recons = blockdiag_butterfly_multiply(eye, w1_fft_projected, w2_fft_projected)
    # recons = torch.matmul(w1_fft_projected, w2_fft_projected).cfloat()
    # fft_projected = blockdiag_butterfly_multiply(myweights, w1_fft_projected, w2_fft_projected).t()
    # fft_projected = blockdiag_butterfly_multiply(myinput, w1_fft_projected, w2_fft_projected)
    # fft_projected = torch.matmul(myweights_t, w1_fft_projected)
    # fft_projected = torch.matmul(fft_projected, w2_fft_projected).t()
    # print("shapes")
    # print(fft_projected.shape)
    # print(myinput.shape)
    # print(myinput[:, perm].shape)
    # print("max abs difference:", (fft_projected - myinput[:, perm]).abs().max())
    # assert torch.allclose(fft_projected, dft[:, perm], rtol=1e-4, atol=1e-4)
    # print(torch.norm(fft_projected - myinput[:, perm], p='fro'))
    # return fft_projected
    # x = torch.randn(batch_size, n, dtype=torch.complex64, device=
    return w1_fft_projected, w2_fft_projected, recons

In [10]:
def split_and_run_svd(fixed_matrix, block_size_x, block_size_y, rank, decomp):
    
    # Should return the full reconstructed matrix along with the block-wise low-rank decompositions
    # x for rows, y for columns
    # assuming that each index is perfectly divisible by the block sizes
    
    n_blocks_x = int(fixed_matrix.shape[0]/block_size_x)
    n_blocks_y = int(fixed_matrix.shape[1]/block_size_y)
    
    print(f'Num Blocks in x (rows): {n_blocks_x}')
    print(f'Num Blocks in y (cols): {n_blocks_y}')
    
    num_params = np.prod([n_blocks_x, n_blocks_y, (block_size_x+block_size_y)*rank])

    print('Num parameters in the resulting decomposition ', num_params)
    print('Num parameters in the original matrix ', fixed_matrix.shape[0]*fixed_matrix.shape[1])
    
    
    # partition original matrix into blocks
    
    fixed_matrix_blocks = list()
    # lmf_blocks = list()
    for i in range(n_blocks_x):
        x_index = i*block_size_x
        column_block_list = []
        for j in range(n_blocks_y):
            y_index = j*block_size_y
            block = fixed_matrix[x_index:x_index+block_size_x, y_index:y_index + block_size_y]
            column_block_list.append(block)
        fixed_matrix_blocks.append(column_block_list)
    
    low_ranks = list()
    low_ranks_blocks = list() # reconstructed blocks from low-rank decompositions
    
    for i in range(len(fixed_matrix_blocks)):
        low_ranks_cols = list()
        low_ranks_cols_blocks = list()
        
        for j in range(len(fixed_matrix_blocks[0])):
            if decomp == 'svd':
              left, right, reconstructed = get_svd(fixed_matrix_blocks[i][j], rank)
            elif decomp == 'svd_p':
              left, right, reconstructed = get_new_svd(fixed_matrix_blocks[i][j], rank)
            elif decomp == 'monarch':
              left, right, reconstructed = get_monarch(fixed_matrix_blocks[i][j])
            else:
              raise Exception("No decomposition is specified!")
                
            # left, right, reconstructed = get_monarch(fixed_matrix_blocks[i][j])

#             left, right, reconstructed = get_sgd(fixed_matrix_blocks[i][j], rank)
#             print(left.shape, right.shape, reconstructed.shape)
            low_ranks_cols.append((left, right))
            low_ranks_cols_blocks.append(reconstructed)
#         print(f'{i+1}/{len(fixed_matrix_blocks)} done')
        
        low_ranks.append(low_ranks_cols)
        low_ranks_blocks.append(low_ranks_cols_blocks)
    return low_ranks, low_ranks_blocks, fixed_matrix_blocks, num_params

In [11]:
def find_closest_rank_full_matrix_decomposition(fixed_matrix, desired_num_params):
    
    rank = int(desired_num_params / (fixed_matrix.shape[0] + fixed_matrix.shape[1]))
    num_p = (fixed_matrix.shape[0] + fixed_matrix.shape[1])*rank
    
    print(f'Rank that gives parameters closest to {desired_num_params} is rank:{rank} with params: {num_p}')
    
    left, right, reconstructed = get_svd(fixed_matrix, rank)
    
    return left, right, reconstructed

In [12]:
svd_low_ranks, svd_low_ranks_blocks, svd_fixed_matrix_blocks, svd_num_params_block_svd = split_and_run_svd(fixed_matrix, block_size_x=32, block_size_y=32, rank=1, decomp='svd')
ma_low_ranks, ma_low_ranks_blocks, ma_fixed_matrix_blocks, ma_num_params_block_svd = split_and_run_svd(fixed_matrix, block_size_x=32, block_size_y=32, rank=1, decomp='monarch')

Num Blocks in x (rows): 32
Num Blocks in y (cols): 32
Num parameters in the resulting decomposition  65536
Num parameters in the original matrix  1048576
Num Blocks in x (rows): 32
Num Blocks in y (cols): 32
Num parameters in the resulting decomposition  65536
Num parameters in the original matrix  1048576


In [13]:
perm = rearrange(torch.arange(1024, device=device), '(i j) -> (j i)', i=32)

reconstructed_matrix_svd = reconstruct_from_blocks(svd_fixed_matrix_blocks)
reconstructed_matrix_ma = reconstruct_from_blocks(svd_fixed_matrix_blocks)
print("diff: reconstructed from blocks vs. original:")
print(torch.norm(reconstructed_matrix_svd.to(fixed_matrix.device) - fixed_matrix, p='fro'))
assert torch.norm(reconstructed_matrix_svd.to(device) - reconstructed_matrix_ma.to(device), p='fro') == torch.tensor(0)
print("\n")
reconstructed_block_svd = reconstruct_from_blocks(svd_low_ranks_blocks)
print("diff: reconstructed from low-rank svd blocks vs. original:")
print(torch.norm(reconstructed_block_svd.to(fixed_matrix.device) - fixed_matrix, p='fro'))
print("\n")
reconstructed_block_monarch = reconstruct_from_blocks(ma_low_ranks_blocks)
print("diff: reconstructed from low-rank monarch blocks vs. original:")
print(torch.norm(reconstructed_block_monarch.to(fixed_matrix.device) - fixed_matrix[:, perm], p='fro'))
print("\n")
_, _, reconstructed_svd_full_matched_rank = find_closest_rank_full_matrix_decomposition(fixed_matrix, svd_num_params_block_svd)
print("diff: reconstructed from full matrix svd with matching rank vs. original:")
print(torch.norm(reconstructed_svd_full_matched_rank.to(fixed_matrix.device) - fixed_matrix, p='fro'))
print("\n")
_, _, reconstructed_svd_full = get_svd(fixed_matrix, 1)
print("diff: reconstructed from full matrix svd with rank 1 vs. original:")
print(torch.norm(reconstructed_svd_full.to(fixed_matrix.device) - fixed_matrix, p='fro'))
print("\n")
_, _, reconstructed_monarch_full = get_monarch(fixed_matrix)
print("diff: reconstructed from full matrix monarch vs. original:")
print(torch.norm(reconstructed_monarch_full.to(fixed_matrix.device) - fixed_matrix[:, perm], p='fro'))


diff: reconstructed from blocks vs. original:
tensor(0., device='cuda:0')


diff: reconstructed from low-rank svd blocks vs. original:
tensor(33.2463, device='cuda:0')


diff: reconstructed from low-rank monarch blocks vs. original:
tensor(44.6162, device='cuda:0')


Rank that gives parameters closest to 65536 is rank:32 with params: 65536
diff: reconstructed from full matrix svd with matching rank vs. original:
tensor(31.0410, device='cuda:0')


diff: reconstructed from full matrix svd with rank 1 vs. original:
tensor(35.6478, device='cuda:0')


diff: reconstructed from full matrix monarch vs. original:
tensor(38.3040, device='cuda:0')


In [14]:
print(calculate_all_norms(fixed_matrix, reconstructed_monarch_full))
print(calculate_all_norms(fixed_matrix, reconstructed_block_monarch))
print(calculate_all_norms(fixed_matrix, reconstructed_svd_full_matched_rank))
print(calculate_all_norms(fixed_matrix, reconstructed_block_svd))

{'fro': tensor(38.3195, device='cuda:0'), 'nuc': tensor(955.1649, device='cuda:0'), 'spectral': tensor(5.2226, device='cuda:0')}
{'fro': tensor(44.5996), 'nuc': tensor(1155.3595), 'spectral': tensor(5.3252)}
{'fro': tensor(31.0410, device='cuda:0'), 'nuc': tensor(785.1836, device='cuda:0'), 'spectral': tensor(2.3623, device='cuda:0')}
{'fro': tensor(33.2460), 'nuc': tensor(845.9597), 'spectral': tensor(3.5511)}


In [21]:
# torch.manual_seed(0)
better = list()
counter = 0
for i in range(500):
  myinput = torch.rand(16, 1024).to(device)
  # assert torch.allclose(myrecons, fixed_matrix[:, perm].cfloat(), rtol=1e-4, atol=1e-4)
  true_res = torch.matmul(myinput, fixed_matrix.t())

  #_, _, recons_fullsvd = get_svd(fixed_matrix, 1)
  #full_svd_res = torch.matmul(myinput, recons_fullsvd)
  #print(torch.norm(full_svd_res - true_res, p='fro'))

  # L, R, ma_recons = get_monarch(fixed_matrix)
  # monarch_res = blockdiag_butterfly_multiply(myinput[:, perm], L, R)
  # monarch_res_2 = torch.matmul(myinput[:, perm], ma_recons)
  # print(torch.norm(monarch_res - true_res, p='fro'))
  # mon = torch.norm(monarch_res_2 - true_res, p='fro')



  # print("\n")
  # ysvd = torch.matmul(myinput, reconstructed_svd_full)
  # print(torch.norm(ysvd - true_res, p='fro'))
  ysvd_matchedrank = torch.matmul(myinput, reconstructed_svd_full_matched_rank)
  svd = torch.norm(ysvd_matchedrank - true_res, p='fro')
  print(svd)
  
  ymonarch = torch.matmul(myinput, reconstructed_monarch_full)
  mon = torch.norm(ymonarch - true_res, p='fro')
  print(mon)
  # print("\n")
  # yblocksvd = torch.matmul(myinput, reconstructed_block_svd.to(device))
  # print(torch.norm(yblocksvd - true_res, p='fro'))
  # yblockmonarch = torch.matmul(myinput, reconstructed_block_monarch.to(device))
  # print(torch.norm(yblockmonarch - true_res, p='fro'))
  print("\n")
  if mon <= svd:
    counter = counter + 1

print(counter)

tensor(88.2836, device='cuda:0')
tensor(77.7914, device='cuda:0')


tensor(89.7439, device='cuda:0')
tensor(78.8039, device='cuda:0')


tensor(87.7617, device='cuda:0')
tensor(78.3627, device='cuda:0')


tensor(87.4255, device='cuda:0')
tensor(76.7803, device='cuda:0')


tensor(89.0154, device='cuda:0')
tensor(78.4655, device='cuda:0')


tensor(89.1359, device='cuda:0')
tensor(78.8584, device='cuda:0')


tensor(87.0653, device='cuda:0')
tensor(77.8473, device='cuda:0')


tensor(88.8364, device='cuda:0')
tensor(78.2332, device='cuda:0')


tensor(87.1475, device='cuda:0')
tensor(76.7762, device='cuda:0')


tensor(87.7558, device='cuda:0')
tensor(78.1693, device='cuda:0')


tensor(87.5617, device='cuda:0')
tensor(77.9478, device='cuda:0')


tensor(88.4784, device='cuda:0')
tensor(77.9631, device='cuda:0')


tensor(87.6916, device='cuda:0')
tensor(77.7030, device='cuda:0')


tensor(88.6739, device='cuda:0')
tensor(79.1525, device='cuda:0')


tensor(88.1446, device='cuda:0')
tensor(77.4345,

In [17]:
print(calculate_all_norms(reconstructed_matrix_svd, fixed_matrix))
print(calculate_all_norms(reconstructed_svd_full, fixed_matrix))

# print(calculate_all_norms(ysvd, true_res))
# print(calculate_all_norms(ymonarch, true_res))

{'fro': tensor(0., device='cuda:0'), 'nuc': tensor(0., device='cuda:0'), 'spectral': tensor(0., device='cuda:0')}
{'fro': tensor(35.6478, device='cuda:0'), 'nuc': tensor(880.5148, device='cuda:0'), 'spectral': tensor(4.7447, device='cuda:0')}
