This is a preliminary toy code for converting a layer of a pre-trained model using our proposed Riemannian Metric.
In this notebook, we primarily compare our method with the typical Singular Value Decomposition method, which is easy to implement in this notebook without loading several packages and datasets.
Besides, the SVD-based method refers to an optimal compression method according to the Eckart–Young–Mirsky theorem.
Therefore, the superiority of our method over SVD on the task of matrix compression can mostly validate that our method's power.


In [1]:

import json
from pprint import pprint

import torch
import time
import random
import numpy as np
import torchvision

import torch.nn as nn
import torch.nn.functional as F


import torchvision.models as models
import torchvision.transforms as transforms

from datetime import datetime

SEED_ID = 42
torch.manual_seed(SEED_ID)

import copy

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)


cuda


In [2]:

resnet50_model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)


In [3]:

# We pick a random layer's weights as the weight matrix(-ices).
tar_mats_1 = resnet50_model.state_dict()['layer2.0.conv1.weight'].to(device)


In [4]:

def quant_mat_torch(_mat, num_bits, axis=0):

    min_vals, max_vals = torch.min(_mat, axis=axis).values, torch.max(_mat, axis=axis).values
    range_unit = (max_vals - min_vals)/(2**num_bits)
    q_mat = min_vals + torch.round((_mat - min_vals)/range_unit)*range_unit

    return q_mat


In [5]:

def compress_mat_via_svd(tar_mat, _ratio, full_matrices=True, num_bits=-1):

    U, S, Vh = torch.linalg.svd(tar_mat, full_matrices=full_matrices)
    #print(U.shape, S.shape, Vh.shape)

    num_row, num_col = tar_mat.shape[:2]

    simple_rank = min(num_row, num_col)
    keep_cols = round(simple_rank*_ratio)

    comp_U, comp_S, comp_Vh = U[:, :keep_cols], torch.diag_embed(S[:keep_cols]), Vh[:keep_cols, :]

    if num_bits != -1:
        comp_U = quant_mat_torch(comp_U, num_bits)
        comp_S = quant_mat_torch(comp_S, num_bits)
        comp_Vh = quant_mat_torch(comp_Vh, num_bits)

    rec_mat = comp_U @ comp_S @ comp_Vh
    svd_params = comp_U.numel() + S[:keep_cols].numel() + comp_Vh.numel()
    rel_error = torch.mean(torch.abs(tar_mat - rec_mat))/torch.std(tar_mat)

    return rec_mat, rel_error.item(), svd_params


In [6]:

# We pick a specific rank-2 matrix as the target matrix for brevity.
# In pratical case, one can implement the following compression directly on rank-4 tensors.
tar_mat = tar_mats_1[:,:,0,0]

print(tar_mat.shape)


torch.Size([128, 256])


In [7]:

# We test the compression performance using a typical SVD (Singular Value Decomposition)-based method.
svd_ratio, num_bits = 0.5, 8
rec_mat, rec_error, rec_params = compress_mat_via_svd(tar_mat, svd_ratio, full_matrices=False, num_bits=num_bits)

print('--- No.Params of Weights:', tar_mat.numel())
print('--- No.Params of SVD:', rec_params,
      '--- Compress Ratio:', round(num_bits/32 * rec_params/tar_mat.numel(),4),
      '--- Rel Errors:', round(rec_error, 4))



--- No.Params of Weights: 32768
--- No.Params of SVD: 24640 --- Compress Ratio: 0.188 --- Rel Errors: 0.1758


The compression technique via SVD method refers to the theoretically optimal
compression efficiency using data-free method.

The compression ratio is 0.188.
The resulted relative error is 0.1758.

In [8]:


class dynMat_rieM(nn.Module):
    def __init__(self, num_input, num_output, q_dim, metric_dim):
        super(dynMat_rieM, self).__init__()

        self.num_input = num_input
        self.num_output = num_output
        self.q_dim = q_dim
        self.metric_dim = metric_dim
        self.sigmoid = nn.Sigmoid()
        self.relu = nn.ReLU()
        self.softmax = nn.Softmax(dim=1)

        self.input_Qs = nn.Parameter(torch.rand(num_input, q_dim, device=device))
        self.output_Qs = nn.Parameter(torch.rand(num_output, q_dim, device=device))

        self.metric1_linear1 = nn.Linear(q_dim, metric_dim, device=device)
        self.metric2_linear1 = nn.Linear(q_dim, metric_dim, device=device)
        self.metric3_linear1 = nn.Linear(q_dim, metric_dim, device=device)
        self.metric4_linear1 = nn.Linear(q_dim, metric_dim, device=device)
        self.metric5_linear1 = nn.Linear(q_dim, metric_dim, device=device)
        self.metric6_linear1 = nn.Linear(q_dim, metric_dim, device=device)

        self.metric_linear2 = nn.Linear(metric_dim, 1, device=device)


    def compute_relVecs(self, cat_id, num_bits):

        if num_bits != -1:
            input_Qs = quant_mat_torch(self.input_Qs, num_bits)
            output_Qs = quant_mat_torch(self.output_Qs, num_bits)
        else:
            input_Qs = self.input_Qs
            output_Qs = self.output_Qs

        return input_Qs.unsqueeze(1) + torch.cat((output_Qs[:,cat_id:], output_Qs[:,:cat_id]), dim=1)


    def forward(self, num_bits):

        _relVecs1 = self.compute_relVecs(0, num_bits)
        _relVecs2 = self.compute_relVecs(1*self.q_dim//6, num_bits)
        _relVecs3 = self.compute_relVecs(2*self.q_dim//6, num_bits)
        _relVecs4 = self.compute_relVecs(3*self.q_dim//6, num_bits)
        _relVecs5 = self.compute_relVecs(4*self.q_dim//6, num_bits)
        _relVecs6 = self.compute_relVecs(5*self.q_dim//6, num_bits)

        Rm_dist = self.metric_linear2(
            self.sigmoid(self.metric1_linear1(_relVecs1))+
            self.sigmoid(self.metric2_linear1(_relVecs2))+
            self.sigmoid(self.metric3_linear1(_relVecs3))+
            self.sigmoid(self.metric4_linear1(_relVecs4))+
            self.sigmoid(self.metric5_linear1(_relVecs5))+
            self.sigmoid(self.metric6_linear1(_relVecs6))
        ).squeeze(2)

        return Rm_dist


In [9]:

def train_dynMat(_dynMat, _optimizer, tar_mat, _epochs, num_bits):

    fail_count = 0
    pre_err = 999
    errs_list = []

    for _ep in range(_epochs):
        # During training, we assume full-precision.
        res_mat = _dynMat(num_bits=-1)
        loss_W = torch.sum((tar_mat - res_mat)**2)
        _optimizer.zero_grad()
        loss_W.backward()
        _optimizer.step()

        if _ep%5000 == 0:
            # During inference, we use quantized RieM.
            res_mat = _dynMat(num_bits=num_bits)
            rel_error = torch.mean(torch.abs(tar_mat - res_mat))/torch.std(tar_mat)
            if _ep%50000 == 0: print(_ep, loss_W.item(), '--- Rel Error:', rel_error.item(), datetime.now().time())
            errs_list.append(round(rel_error.item(), 5))

            if pre_err - rel_error < 5e-6 and _ep > round(_epochs*0.8): fail_count += 1
            if fail_count == 5: break
            pre_err = rel_error.item()

        #if _ep%50000 == 0 and _ep != 0: print(errs_list)

    if _epochs > 100000:
        print(errs_list)
        print('--- Min:', min(errs_list))

    return _dynMat, _optimizer


In [10]:

def pick_configuration(num_input, num_output, q_dim, metric_dim, max_pick, tar_mat, num_bits):


    best_dynMat1_model = dynMat_rieM(num_input, num_output, q_dim, metric_dim)
    best_dynMat1_optimizer = torch.optim.Adam(best_dynMat1_model.parameters(), lr=5e-5)
    _params = sum([param.numel() for param in best_dynMat1_model.parameters()])
    print('--- No.Params:', _params)

    min_error = 999
    for pick_id in range(max_pick):
        _dynMat1_model = dynMat_rieM(num_input, num_output, q_dim, metric_dim)
        _dynMat1_model.to(device)
        _dynMat1_optimizer = torch.optim.Adam(_dynMat1_model.parameters(), lr=5e-5)

        _dynMat1_model, _dynMat1_optimizer = train_dynMat(_dynMat1_model, _dynMat1_optimizer, tar_mat, 5000, num_bits)

        res_mat = _dynMat1_model(num_bits=num_bits)
        rel_error = torch.mean(torch.abs(tar_mat - res_mat))/torch.std(tar_mat)

        print('--- curent pick_id:', pick_id, '--- with Error:', rel_error.item())
        if rel_error < min_error:
            print('######## Pick_ID:', pick_id, '#### with Error:', rel_error.item())
            min_error = rel_error
            best_dynMat1_model = _dynMat1_model
            best_dynMat1_optimizer = _dynMat1_optimizer

    return best_dynMat1_model, best_dynMat1_optimizer



In [11]:

# Pick a relatively optimal initialization.
dynMat1_model, dynMat1_optimizer = pick_configuration(tar_mat.shape[0], tar_mat.shape[1], 52, 14, 100, tar_mat, num_bits)



--- No.Params: 24435
0 167176.78125 --- Rel Error: 63.48075485229492 15:44:57.151827
--- curent pick_id: 0 --- with Error: 0.7266885638237
######## Pick_ID: 0 #### with Error: 0.7266885638237
0 99855.53125 --- Rel Error: 48.90410232543945 15:45:16.454490
--- curent pick_id: 1 --- with Error: 0.6974816918373108
######## Pick_ID: 1 #### with Error: 0.6974816918373108
0 254.0826416015625 --- Rel Error: 1.8995635509490967 15:45:35.538680
--- curent pick_id: 2 --- with Error: 0.6880909204483032
######## Pick_ID: 2 #### with Error: 0.6880909204483032
0 4324.02099609375 --- Rel Error: 9.623333930969238 15:45:54.591602
--- curent pick_id: 3 --- with Error: 0.6922196745872498
0 31455.33984375 --- Rel Error: 27.3243465423584 15:46:13.629633
--- curent pick_id: 4 --- with Error: 0.6930917501449585
0 275.38555908203125 --- Rel Error: 1.9787579774856567 15:46:32.804646
--- curent pick_id: 5 --- with Error: 0.6881184577941895
0 1036.686279296875 --- Rel Error: 4.29775857925415 15:46:51.990866
--- cu

In [12]:

dynMat1_params = sum([param.numel() for param in dynMat1_model.parameters()])
print('--- No.Params of Weights:', tar_mat.numel())
print('--- No.Params of RieM-struct:', dynMat1_params,
      '--- Compress Ratio:', round(num_bits/32 * dynMat1_params/tar_mat.numel(),4))

dynMat1_model, dynMat1_optimizer = train_dynMat(dynMat1_model, dynMat1_optimizer, tar_mat, 1000001, num_bits)


--- No.Params of Weights: 32768
--- No.Params of RieM-struct: 24435 --- Compress Ratio: 0.1864
0 39.576995849609375 --- Rel Error: 0.6870150566101074 16:17:08.869364
50000 3.4940133094787598 --- Rel Error: 0.22884352505207062 16:20:22.178826
100000 1.8874965906143188 --- Rel Error: 0.17025858163833618 16:23:35.084027
150000 1.3928956985473633 --- Rel Error: 0.14737416803836823 16:26:47.124452
200000 1.1604812145233154 --- Rel Error: 0.1348249763250351 16:29:59.241137
250000 1.0229363441467285 --- Rel Error: 0.1264714151620865 16:33:11.784349
300000 0.9300796389579773 --- Rel Error: 0.12116575986146927 16:36:24.315625
350000 0.9887763261795044 --- Rel Error: 0.13266977667808533 16:39:36.052911
400000 0.8109122514724731 --- Rel Error: 0.1133427694439888 16:42:47.231562
450000 0.7713472843170166 --- Rel Error: 0.11129428446292877 16:45:58.982576
500000 0.7383248805999756 --- Rel Error: 0.10917417705059052 16:49:10.615333
550000 0.7096027731895447 --- Rel Error: 0.10784029215574265 16:52:2

The RieM's compression ratio is 0.1864, which is even lower than SVD's 0.188.

We observe that the data-free compression via RieM has a relative error up to 0.1013, which is significantly lower than the relative error 0.1758 obtained via SVD.