In [None]:
%set_env CUDA_VISIBLE_DEVICES=1,2,3,4,5,6,7
from transformers import LlamaModel, LlamaForCausalLM, LlamaTokenizer, GenerationConfig
from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding
import torch
import transformers

In [None]:
class LlamaRotaryEmbeddingExt(torch.nn.Module):
    """Rotary position embedding whose base is rescaled by a factor ``alpha``.

    The rotary base is replaced by ``base * alpha ** (dim / (dim - 2))`` before
    the inverse frequencies are derived, so ``alpha=1`` leaves the base
    unchanged. cos/sin tables are precomputed for ``max_position_embeddings``
    positions and rebuilt on demand if a longer sequence is requested.

    Args:
        dim: Size of the rotary dimension. ``dim == 2`` raises
            ``ZeroDivisionError`` because of the ``dim / (dim - 2)`` exponent.
        max_position_embeddings: Number of positions to precompute.
        base: Rotary base before rescaling.
        alpha: Scaling factor applied to the base.
        device: Device on which ``inv_freq`` and the cached tables are created.
    """

    def __init__(self, dim, max_position_embeddings=16384, base=10000, alpha=8, device=None):
        super().__init__()
        base = base * alpha ** (dim / (dim-2))
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
        self.register_buffer("inv_freq", inv_freq)

        # Build here to make `torch.jit.trace` work.
        self._set_cos_sin_cache(max_position_embeddings, device=self.inv_freq.device)

    def _set_cos_sin_cache(self, seq_len, device):
        """(Re)build the cached cos/sin tables for positions ``0 .. seq_len - 1``."""
        self.max_seq_len_cached = seq_len
        # Keep both einsum operands on the same device even if the caller's
        # tensor lives elsewhere than the registered `inv_freq` buffer.
        inv_freq = self.inv_freq.to(device)
        t = torch.arange(seq_len, device=device, dtype=inv_freq.dtype)
        freqs = torch.einsum("i,j->ij", t, inv_freq)
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
        # Non-persistent: the tables are derived data and stay out of the state dict.
        self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False)
        self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False)

    def forward(self, x, seq_len=None):
        """Return the ``(cos, sin)`` tables for the first ``seq_len`` positions.

        Args:
            x: Tensor laid out as [bs, num_attention_heads, seq_len, head_size];
                only its dtype, device and (as a fallback) shape are used.
            seq_len: Number of positions to return. When omitted it is taken
                from ``x.shape[-2]``.

        Returns:
            Tuple ``(cos, sin)``, each of shape ``[1, 1, seq_len, dim]`` and
            cast to ``x.dtype``.
        """
        # x: [bs, num_attention_heads, seq_len, head_size]
        if seq_len is None:
            # The declared default used to crash on the comparison below.
            seq_len = x.shape[-2]
        # This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case.
        if seq_len > self.max_seq_len_cached:
            self._set_cos_sin_cache(seq_len, device=x.device)
        return (
            self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
            self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
        )


# Load Model

In [None]:
# Both the weights and the tokenizer come from the same merged checkpoint,
# so the name is spelled once and shared.
CHECKPOINT = "chinese-alpaca-plus-7b-merged"

# device_map="auto" is forwarded to from_pretrained for device placement.
model = LlamaForCausalLM.from_pretrained(CHECKPOINT, device_map="auto")
tokenizer = LlamaTokenizer.from_pretrained(CHECKPOINT, use_fast=False)

# Inject into Model

In [None]:
def inject(alpha=1):
    """Swap every decoder layer's rotary embedding for a ``LlamaRotaryEmbeddingExt``.

    Operates on the module-level ``model`` in place. Each replacement is built
    with the per-head dimension from the model config and is created on the
    device of the layer's current ``inv_freq`` buffer.

    Args:
        alpha: Scaling factor forwarded to ``LlamaRotaryEmbeddingExt``.
    """
    cfg = model.config
    per_head_dim = cfg.hidden_size // cfg.num_attention_heads
    for decoder_layer in model.base_model.layers:
        attention = decoder_layer.self_attn
        target_device = attention.rotary_emb.inv_freq.device
        attention.rotary_emb = LlamaRotaryEmbeddingExt(per_head_dim, alpha=alpha, device=target_device)

# Eval

In [None]:
import json
from rouge_chinese import Rouge
import jieba

In [None]:
rouge = Rouge()
def eval(output, repeat_times, generate_len=128):
    """Ask the model to summarise ``output`` repeated ``repeat_times`` times and score it.

    NOTE(review): this name shadows the built-in ``eval``; it is kept because
    other cells call it by this name.

    Args:
        output: Source text; it is repeated to build the prompt and also used
            as the ROUGE reference.
        repeat_times: How many times ``output`` is concatenated in the prompt.
        generate_len: Maximum number of new tokens to generate. Previously this
            argument was ignored and a hard-coded 64 was used.

    Returns:
        Tuple ``(rouge_l_f, prompt_token_count)``. The score is ``0.0`` when
        the model produces no text after the response marker.
    """
    prompt = f"Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\n\n{output*repeat_times}\n\n 请在提取上面文本的摘要\n\n### Response:\n\n"
    tokenized_sources = tokenizer(prompt, return_tensors="pt")
    prompt_token_count = tokenized_sources['input_ids'].size(1)
    result = model.generate(**tokenized_sources, max_new_tokens=generate_len, generation_config=GenerationConfig(output_scores=False))
    # Drop special tokens so markers such as EOS are not segmented and scored as text.
    decoded = tokenizer.decode(result[0], skip_special_tokens=True)
    output_result = decoded.split("### Response:\n\n")[-1]
    if not output_result.strip():
        # Nothing was generated: score it as zero overlap instead of handing
        # an empty string to the ROUGE scorer.
        return 0.0, prompt_token_count
    # get_scores expects (hypothesis, reference): the generation is the hypothesis.
    hypothesis = ' '.join(jieba.cut(output_result))
    reference = ' '.join(jieba.cut(output))
    return rouge.get_scores(hypothesis, reference)[0]['rouge-l']['f'], prompt_token_count

In [None]:
# Alpha values to sweep in the experiment loop.
test_alpha_list = (1, 2, 4)
# Results keyed by alpha; the experiment loop fills each entry with
# (prompt_token_count, f1) tuples.
data = dict()

In [None]:
# Sweep prompt length for each alpha: the test sentence is repeated 1, 6, 11, ...
# times until the ROUGE score drops to exactly zero (or the range is exhausted).
test_content = "主要依赖于相关司法解释文件的出台，从而呈现了紧急状态下的“应急释法刑事治理模式”。"
for alpha in test_alpha_list:
    data[alpha] = []
    print(f"### {alpha}")
    # The injected rotary embedding depends only on alpha, so build it once per
    # alpha instead of rebuilding every layer's cos/sin tables on each prompt.
    inject(alpha)
    for i in range(1, 1000, 5):
        f1, token = eval(test_content, i)
        print(f"{token}: {f1}")
        data[alpha].append((token, f1))
        if f1 == 0:
            # No overlap left with the reference: stop growing the prompt.
            break

In [None]:
import matplotlib.pyplot as plt
import numpy as np

In [None]:
# Plot score vs. prompt length, one line per alpha actually present in `data`,
# so the figure stays in sync with `test_alpha_list` instead of hard-coding keys.
plt.figure(1)
for alpha, points in data.items():
    # Each entry is a list of (token_count, f1) pairs -> columns 0 and 1.
    series = np.array(points)
    plt.plot(series[:, 0], series[:, 1], label=f"alpha={alpha}")
plt.xlabel("tokens")
# The plotted metric is the 'rouge-l' F-score returned by eval().
plt.ylabel("rouge-l f1 score")
plt.legend()
plt.savefig("result.png", dpi=200)