In [None]:
import torch
from torch import matmul, exp, log, abs
import torch.nn as nn
from torch.utils.data import DataLoader

from IPython.display import clear_output
from tqdm.notebook import tqdm

from transformers import AutoTokenizer,AutoModelForQuestionAnswering, AutoModelForMultipleChoice, AutoModelForSequenceClassification, AutoModel, BertModel, ElectraModel, set_seed
from transformers import TrainingArguments, Trainer
from transformers import DataCollatorWithPadding
from transformers.modeling_outputs import TokenClassifierOutput, SequenceClassifierOutput
from transformers import AdamW,get_scheduler

from datasets import load_metric, load_dataset

import random
import numpy as np
import evaluate
import math
from peft import get_peft_model, LoraConfig, TaskType, PeftConfig, PeftModel

import matplotlib.pyplot as plt

from numpy import linalg as LA

  torch.utils._pytree._register_pytree_node(
  torch.utils._pytree._register_pytree_node(


In [None]:
# Prefer the GPU whenever PyTorch can see one; otherwise run on the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [None]:
device  # echo the backend chosen above ('cuda' or 'cpu') as this cell's output

'cuda'

In [None]:
# Fix the seed of every random number generator this notebook touches so reruns repeat.
SEED = 42

for _seed_rng in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
    _seed_rng(SEED)
# Ask cuDNN for deterministic kernels so GPU results are repeatable.
torch.backends.cudnn.deterministic = True

# transformers' own seeding helper; presumably overlaps with the calls above.
set_seed(seed=SEED)

In [None]:
# Language-model training hyperparameters (not read by the cells shown in this notebook).
batch_size, epochs = 40, 3
lr, weight_decay = 1e-5, 0.1

# NOTE(review): presumably the LoRA alpha / rank used at training time — verify against the saved adapter configs.
ALPHA, RANK = 1.0, 1

## Load Models fine-tuned with LoRA

In [None]:

# Directories of the saved LoRA adapters to compare.
sst2_42_model_id = './models/bert_lora_sst2_42'
imdb_model_id = './models/bert_lora_imdb_42'


def _load_lora_classifier(adapter_dir):
    """Load an adapter's PeftConfig, build its base sequence classifier, and attach the adapter.

    Args:
        adapter_dir: path to a directory saved by PEFT containing the LoRA adapter.

    Returns:
        (config, model): the adapter's ``PeftConfig`` and the ``PeftModel``-wrapped classifier.
    """
    config = PeftConfig.from_pretrained(adapter_dir)
    base_model = AutoModelForSequenceClassification.from_pretrained(config.base_model_name_or_path)
    return config, PeftModel.from_pretrained(base_model, adapter_dir)


sst2_42_config, sst2_42_model = _load_lora_classifier(sst2_42_model_id)
imdb_config, imdb_model = _load_lora_classifier(imdb_model_id)

# Tokenizer for the SST-2 adapter's base model (the config defined just above).
# NOTE(review): assumes both adapters share the same base model / tokenizer — confirm.
tokenizer = AutoTokenizer.from_pretrained(sst2_42_config.base_model_name_or_path)

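## Grassmann similarity
$\phi(A, B, i, j) = \|U_A^{(i)\top} U_B^{(j)}\|_F^2 / \min(i, j)$, where $U^{(k)}$ holds the first $k$ left singular vectors.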

In [None]:
def grassman_dist(u_A, u_B, i, j):
    """Grassmann similarity between the leading column subspaces of ``u_A`` and ``u_B``.

    Computes ``phi = ||u_A[:, :i]^T u_B[:, :j]||_F^2 / min(i, j)``. For orthonormal
    columns this lies in [0, 1]: 1 when the smaller subspace is contained in the
    larger one, 0 when they are orthogonal. Despite the name, this is a
    *similarity*, not a distance.

    Args:
        u_A: 2-D torch tensor whose columns span the first subspace
            (e.g. left singular vectors of a matrix A).
        u_B: 2-D torch tensor whose columns span the second subspace.
        i: number of leading columns of ``u_A`` to use, in ``[1, u_A.shape[1]]``.
        j: number of leading columns of ``u_B`` to use, in ``[1, u_B.shape[1]]``.

    Returns:
        A 0-d NumPy array holding the similarity.

    Raises:
        ValueError: if ``i`` or ``j`` is outside the valid column range
            (``min(i, j) == 0`` would divide by zero, and a too-large value would be
            silently truncated by slicing while still normalizing by ``min(i, j)``).
    """
    if not 1 <= i <= u_A.shape[1] or not 1 <= j <= u_B.shape[1]:
        raise ValueError(
            f"i and j must be in [1, n_cols]; got i={i} (n_cols={u_A.shape[1]}), "
            f"j={j} (n_cols={u_B.shape[1]})"
        )
    overlap = torch.matmul(u_A[:, :i].T, u_B[:, :j])
    similarity = torch.linalg.matrix_norm(overlap, ord='fro') ** 2 / min(i, j)
    # detach + cpu so .numpy() also works for gradient-tracking or GPU tensors
    return similarity.detach().cpu().numpy()

## Define plotting function

In [None]:
def _value_lora_delta(model: nn.Module, layer_idx: int) -> torch.Tensor:
    """Return the detached LoRA update ``lora_B @ lora_A`` of one encoder layer's value projection."""
    value_proj = model.base_model.bert.encoder.layer[layer_idx].attention.self.value
    return torch.matmul(value_proj.lora_B.default.weight, value_proj.lora_A.default.weight).detach()


def _similarity_grid(u_A: torch.Tensor, u_B: torch.Tensor, max_rank: int) -> np.ndarray:
    """Return a (max_rank, max_rank) array whose [i-1, j-1] entry is grassman_dist(u_A, u_B, i, j)."""
    grid = np.zeros((max_rank, max_rank))
    for i in range(1, max_rank + 1):
        for j in range(1, max_rank + 1):
            grid[i - 1, j - 1] = grassman_dist(u_A, u_B, i, j)
    return grid


def plot_all_layers(modelA: nn.Module, modelB: nn.Module, max_rank: int, title: str,
                    color_max: float = 0.5, ncols: int = 4, show_title: bool = False,
                    return_data: bool = False):
    """Plot per-layer Grassmann similarity heatmaps between two LoRA-adapted BERT models.

    For every encoder layer, builds the LoRA update ``lora_B @ lora_A`` of the
    attention value projection in each model and takes its left singular vectors.
    It then evaluates ``grassman_dist`` for every subspace size pair ``(i, j)`` in
    ``1..max_rank``. Each layer gets one heatmap titled with its maximum
    similarity. The mean of those per-layer maxima is printed.

    Args:
        modelA, modelB: PEFT-wrapped BERT models exposing
            ``base_model.bert.encoder.layer[l].attention.self.value.lora_{A,B}``.
            The layer count is taken from ``modelA``.
        max_rank: largest subspace dimension compared; each heatmap is max_rank x max_rank.
        title: figure title; only drawn when ``show_title`` is True.
        color_max: upper limit of the colour scale (larger values saturate).
        ncols: number of subplot columns; rows are added as needed.
        show_title: draw ``title`` as the figure suptitle.
        return_data: also return the similarity grids and their average.

    Returns:
        None by default. When ``return_data`` is True, returns ``(all_layers, avg)``,
        where ``all_layers`` has shape ``(n_layers, max_rank, max_rank)``.

    Raises:
        ValueError: if ``modelA`` has no encoder layers.
    """
    n_layers = len(modelA.base_model.bert.encoder.layer)
    if n_layers == 0:
        raise ValueError("modelA has no encoder layers to compare")

    nrows = math.ceil(n_layers / ncols)
    # 3 inches per row reproduces the original 11x9 figure for a 12-layer, 4-column grid.
    fig, axs = plt.subplots(nrows=nrows, ncols=ncols, figsize=(11, 3 * nrows), squeeze=False)
    flat_axes = axs.ravel()

    all_layers = np.zeros((n_layers, max_rank, max_rank))
    layer_maxima = []
    for l in range(n_layers):
        # Only the left singular vectors U are needed for the subspace comparison.
        u_A = torch.linalg.svd(_value_lora_delta(modelA, l)).U
        u_B = torch.linalg.svd(_value_lora_delta(modelB, l)).U

        dist_matrix = _similarity_grid(u_A, u_B, max_rank)
        all_layers[l] = dist_matrix
        layer_max = np.max(dist_matrix)
        layer_maxima.append(layer_max)

        ax = flat_axes[l]
        im = ax.imshow(dist_matrix, cmap='hot', vmin=0, vmax=color_max)
        ax.set_title(f'L{l}: {layer_max:.2f}', fontsize=20)
        # Ticks sit at 0-based cell positions but are labelled with 1-based subspace sizes.
        ax.set_xticks(range(max_rank))
        ax.set_yticks(range(max_rank))
        ax.set_xticklabels(range(1, max_rank + 1))
        ax.set_yticklabels(range(1, max_rank + 1))
        ax.tick_params(axis='both', which='major', labelsize=12)

    # Blank out grid cells left over when n_layers is not a multiple of ncols.
    for ax in flat_axes[n_layers:]:
        ax.axis('off')

    cbar = fig.colorbar(im, ax=flat_axes.tolist())
    cbar.ax.tick_params(labelsize=20)
    if show_title:
        fig.suptitle(title)

    # Mean over layers of each layer's best-matching (i, j) similarity.
    avg = sum(layer_maxima) / n_layers
    print(avg)

    plt.show()
    if return_data:
        return all_layers, avg
    # return all_layers, avg

In [None]:
# Compare the SST-2 and IMDB LoRA value-projection subspaces for sizes 1..4.
plot_all_layers(modelA=sst2_42_model, modelB=imdb_model, max_rank=4, title='sst2 vs imdb')