In [None]:
import torch
import transformers
from transformers import AutoModelForCausalLM, AutoTokenizer
import datasets
import torch
import torch.nn as nn

In [None]:
# Location of the Baichuan-13B checkpoint handed to `from_pretrained`.
# The value below is a placeholder and must be replaced before running.
baichuan_13B_path = 'YOUR PATH HERE'
# Reference (unpruned) model. `trust_remote_code=True` allows transformers to
# run the modelling code shipped with the checkpoint, so only load checkpoints
# from a source you trust.
baichuan_13B = AutoModelForCausalLM.from_pretrained(baichuan_13B_path, trust_remote_code=True)

In [None]:
# Tokenizer loaded from the same checkpoint location as the model; it is used
# to encode the probe sentences fed to both models during the similarity check.
tokenizer = AutoTokenizer.from_pretrained(baichuan_13B_path, trust_remote_code=True)

In [None]:
# Search hyper-parameters for the layer-merging loop further down.
# How far the base-layer index moves down after a merge has been accepted.
INTERVAL = 1
# Size of one merge group: the base layer plus the MERGE_LAYERS - 1 layers
# directly above it are collapsed into a single layer.
MERGE_LAYERS = 3
# Upper layer index used to place the first merge attempt
# (the search starts at HIGHEST_LAY - MERGE_LAYERS).
HIGHEST_LAY = 39
# Lowest base-layer index the search is allowed to try.
LOWEST_LAY = 0
# A merge is kept only if the mean cosine similarity with the original model's
# last hidden state is strictly greater than this value.
THRESHOLD = 0.6

In [None]:
from copy import deepcopy
def merge_layers_return_model(model, merge_base_lay, merge_layer_num):
    """Return a copy of ``model`` with several layers collapsed into one.

    The ``merge_layer_num`` layers directly above ``merge_base_lay`` are folded
    into the base layer by adding their parameter differences relative to the
    *original* base layer::

        theta_base' = theta_base + sum_k (theta_k - theta_base)

    and are then removed from the copy's layer list.

    Only the five weight matrices listed in ``weight_paths`` are merged; every
    other parameter of the base layer is left exactly as copied.

    Args:
        model: Model exposing ``model.model.layers`` whose entries have
            ``mlp.gate_proj``, ``mlp.down_proj``, ``mlp.up_proj``,
            ``self_attn.W_pack`` and ``self_attn.o_proj`` sub-modules with a
            ``weight`` tensor. It is never modified.
        merge_base_lay: Index of the layer that absorbs the layers above it.
        merge_layer_num: Number of following layers to absorb. Clamped so the
            merge never runs past the last layer.

    Returns:
        A deep copy of ``model`` with the merged base layer and with the
        absorbed layers deleted.
    """
    weight_paths = (
        'mlp.gate_proj',
        'mlp.down_proj',
        'mlp.up_proj',
        'self_attn.W_pack',
        'self_attn.o_proj',
    )

    def _weight(layer, dotted_path):
        # Walk e.g. 'mlp.gate_proj' down from the layer and return its raw weight tensor.
        module = layer
        for attr in dotted_path.split('.'):
            module = getattr(module, attr)
        return module.weight.data

    source_layers = model.model.layers
    merge_layer_num = min(merge_layer_num, len(source_layers) - merge_base_lay - 1)

    model_copy = deepcopy(model)
    if merge_layer_num <= 0:
        # Nothing above the base layer to absorb: return the untouched copy.
        return model_copy

    merged_layers = model_copy.model.layers
    # Differences are taken against the ORIGINAL base weights. Subtracting the
    # copy's base weights (which are being updated in place) would turn every
    # step into `a += b - a`, i.e. a plain overwrite with the last layer.
    base_original = source_layers[merge_base_lay]
    base_merged = merged_layers[merge_base_lay]
    for diff_lay in range(merge_base_lay + 1, merge_base_lay + 1 + merge_layer_num):
        absorbed = source_layers[diff_lay]
        for path in weight_paths:
            _weight(base_merged, path).add_(
                _weight(absorbed, path) - _weight(base_original, path)
            )

    # Delete from the top down so the remaining indices stay valid.
    for diff_lay in range(merge_base_lay + merge_layer_num, merge_base_lay, -1):
        del merged_layers[diff_lay]
    return model_copy

In [None]:
import copy
# Working copy that the search loop replaces with progressively smaller models;
# `baichuan_13B` itself is kept as the reference for the similarity check.
baichuan_13B_copy_to_compress = copy.deepcopy(baichuan_13B)

In [None]:
import numpy as np
def cal_last_hidden_sim(model1, model2, tokenizer, sents):
    """Mean cosine similarity between two models' final hidden states.

    Every sentence in ``sents`` is tokenized once and the same encoding is fed
    to both models with ``output_hidden_states=True``. For each model the last
    entry of ``hidden_states`` is flattened into a single row vector, and the
    two rows are compared with cosine similarity.

    The mean and the list of per-sentence similarities are printed.

    Returns:
        The mean of the per-sentence similarities (as returned by ``np.mean``).
    """
    def _final_state_as_row(model, batch):
        # Forward pass without gradient tracking; keep only the last hidden state.
        with torch.no_grad():
            result = model(**batch, output_hidden_states=True)
        return result.hidden_states[-1].reshape(1, -1)

    per_sentence = []
    for text in sents:
        batch = tokenizer(text, return_tensors='pt')
        row_a = _final_state_as_row(model1, batch)
        row_b = _final_state_as_row(model2, batch)
        per_sentence.append(torch.cosine_similarity(row_a, row_b).item())

    average = np.mean(per_sentence)
    print(average, per_sentence)

    return average

In [None]:
# Greedy top-down search: try to collapse a group of layers starting at `lay`,
# keep the result only if the pruned model still resembles the original.
lay = HIGHEST_LAY - MERGE_LAYERS
# NOTE(review): `last_merge_flag` is not read anywhere in the visible cells.
last_merge_flag = False

# Probe sentences used to compare the original and the candidate model.
sents= ['Torreorgaz is a municipality in the', 'The 81st Mechanised Brigade () is a']

while lay >= LOWEST_LAY:
    print(lay)
    print('current model layer', len(baichuan_13B_copy_to_compress.model.layers))
    # Candidate: fold the MERGE_LAYERS - 1 layers above `lay` into layer `lay`.
    tmp_merged_model = merge_layers_return_model(baichuan_13B_copy_to_compress, lay, MERGE_LAYERS-1)
    # Similarity is always measured against the unpruned reference model.
    sim_value = cal_last_hidden_sim(baichuan_13B, tmp_merged_model, tokenizer, sents)
    if sim_value > THRESHOLD:
        # Accept the merge and continue the search from the smaller model.
        baichuan_13B_copy_to_compress = tmp_merged_model
        lay -= INTERVAL
        # The model just shrank; if `lay` now points past its end, move it
        # back inside the remaining layers.
        if lay >= len(baichuan_13B_copy_to_compress.model.layers):
            lay = len(baichuan_13B_copy_to_compress.model.layers) - 1 - MERGE_LAYERS
    else:
        # Reject the candidate and retry one layer lower.
        lay -= 1
    

In [None]:
# Record the pruned depth in the config so `num_hidden_layers` matches the
# number of layers actually left in the model.
baichuan_13B_copy_to_compress.config.num_hidden_layers = len(baichuan_13B_copy_to_compress.model.layers)