# In [None]:
from transformers.models.llama.modeling_llama import (
    LlamaAttention, 
    LlamaDecoderLayer, 
    LlamaModel, 
    LlamaForCausalLM
)
from transformers import LlamaConfig
from transformers import AutoTokenizer

import torch
import torch.nn as nn
import math

from transformers.modeling_utils import load_sharded_checkpoint


class LoRALinear(nn.Module):
    """Frozen linear projection plus a trainable low-rank (LoRA) update.

    Computes ``x @ W.T + scale * (dropout(x) @ A.T) @ B.T`` where ``W`` is the
    frozen base weight, ``A`` (``r x in_features``) and ``B``
    (``out_features x r``) are the trainable adapter factors, and
    ``scale = lora_alpha / r``. ``B`` starts at zero, so a freshly built layer
    reproduces the base projection exactly. There is no bias term.

    Args:
        in_features: Size of the last dimension of the input.
        out_features: Size of the last dimension of the output.
        r: Rank of the low-rank update (used as a divisor, so must be non-zero).
        lora_alpha: Numerator of the adapter scaling factor.
        dropout: Dropout probability applied to the adapter branch only;
            ``0`` (or less) disables it.
        base_weight: Optional ``(out_features, in_features)`` tensor copied into
            the frozen base weight. When ``None`` the base weight is
            Kaiming-uniform initialised.
    """

    def __init__(self, in_features, out_features, r=8, lora_alpha=32, dropout=0.2, base_weight=None):
        super().__init__()
        self.r = r
        self.lora_alpha = lora_alpha
        self.scale = lora_alpha / r

        # Frozen base weight, laid out like nn.Linear: (out_features, in_features).
        self.weight = nn.Parameter(torch.empty(out_features, in_features))
        if base_weight is not None:
            self.weight.data.copy_(base_weight)
        else:
            nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        self.weight.requires_grad = False

        # A gets a random init and B starts at zero, so the adapter's
        # contribution is exactly zero until B is trained.
        self.lora_A = nn.Parameter(torch.empty(r, in_features))
        self.lora_B = nn.Parameter(torch.zeros(out_features, r))
        nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))

        # Identity (not None) when dropout is disabled: forward() always calls
        # self.dropout, so None would raise TypeError for dropout=0.
        self.dropout = nn.Dropout(dropout) if dropout > 0 else nn.Identity()

    def forward(self, x):
        """Return the base projection of ``x`` plus the scaled LoRA update."""
        base = torch.nn.functional.linear(x, self.weight)
        update = (self.dropout(x) @ self.lora_A.T) @ self.lora_B.T
        return base + update * self.scale


class CustomLlamaAttention(LlamaAttention):
    """LlamaAttention whose ``q_proj`` and ``v_proj`` are wrapped with LoRA.

    The parent constructor builds the stock projections. Each one is then
    replaced by a ``LoRALinear`` whose frozen base weight is a copy of the
    original weight. Its in/out sizes are taken from that weight's shape, so
    this also works when ``v_proj`` is narrower than ``hidden_size`` (grouped-
    query attention with ``num_key_value_heads < num_attention_heads``).

    LoRA hyper-parameters are read from ``config.lora_r`` and
    ``config.lora_alpha``. If absent, LoRALinear's defaults (8 and 32) are used.

    NOTE(review): LoRALinear has no bias, so any bias on the original
    projections (e.g. ``config.attention_bias=True``) is dropped here.
    """

    def __init__(self, config, layer_idx):
        super().__init__(config, layer_idx)
        r = getattr(config, "lora_r", 8)
        lora_alpha = getattr(config, "lora_alpha", 32)
        self.q_proj = self._wrap_with_lora(self.q_proj, r, lora_alpha)
        self.v_proj = self._wrap_with_lora(self.v_proj, r, lora_alpha)

    @staticmethod
    def _wrap_with_lora(linear, r, lora_alpha):
        """Build a LoRALinear matching ``linear``'s shape and seeded with its weight."""
        # nn.Linear stores weight as (out_features, in_features).
        out_features, in_features = linear.weight.shape
        # LoRALinear copies base_weight into its own parameter, so no clone is needed.
        return LoRALinear(
            in_features, out_features,
            r=r, lora_alpha=lora_alpha,
            base_weight=linear.weight.data,
        )


class CustomLlamaDecoderLayer(LlamaDecoderLayer):
    """Llama decoder layer that uses the LoRA-enabled ``CustomLlamaAttention``."""

    def __init__(self, config, layer_idx):
        super().__init__(config, layer_idx)
        # The parent constructor installs a stock attention module; swap it
        # for the LoRA-wrapped variant built with the same config and index.
        lora_attention = CustomLlamaAttention(config, layer_idx)
        self.self_attn = lora_attention


class CustomLlamaModel(LlamaModel):
    """LlamaModel whose decoder stack is built from ``CustomLlamaDecoderLayer``."""

    def __init__(self, config):
        super().__init__(config)
        # Rebuild the whole decoder stack, one LoRA-enabled layer per index.
        decoder_stack = nn.ModuleList()
        for idx in range(config.num_hidden_layers):
            decoder_stack.append(CustomLlamaDecoderLayer(config, layer_idx=idx))
        self.layers = decoder_stack


class CustomLlamaForCausalLM(LlamaForCausalLM):
    """Causal-LM head on top of the LoRA-enabled ``CustomLlamaModel`` backbone."""

    def __init__(self, config: LlamaConfig):
        super().__init__(config)
        self.model = CustomLlamaModel(config)
        # super().__init__ may have tied lm_head to the embedding of the
        # LlamaModel that was just replaced. Re-tie against the new backbone so
        # lm_head and model.embed_tokens share one tensor again. Per
        # transformers' PreTrainedModel.tie_weights, this only ties when
        # config.tie_word_embeddings is set.
        self.tie_weights()


if __name__ == '__main__':
    model_name = "/home/xwj/Model/llama2-7b-hf"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    config = LlamaConfig.from_pretrained(model_name)
    config.lora_r = 8
    config.lora_alpha = 32

    # NOTE(review): no dtype is set anywhere, so the model is built and moved
    # in float32; consider a half-precision dtype if GPU memory is tight.
    model = CustomLlamaForCausalLM(config)

    # strict=False is required because the checkpoint has no lora_A / lora_B
    # tensors. However, it would also silently skip any *base* weight whose key
    # failed to match, leaving that weight randomly initialised. Verify that
    # only the LoRA adapters (and a tied lm_head, if tying is on) are missing.
    load_result = load_sharded_checkpoint(model, model_name, strict=False)
    allowed_missing = {"lm_head.weight"} if getattr(config, "tie_word_embeddings", False) else set()
    missing_base = [
        key for key in load_result.missing_keys
        if "lora_" not in key and key not in allowed_missing
    ]
    if missing_base:
        raise RuntimeError(
            f"Checkpoint is missing {len(missing_base)} non-LoRA weight(s), "
            f"e.g. {missing_base[:5]}"
        )

    model.to(device)
    model.eval()

    prompt = "The capital of France is"
    inputs = tokenizer(prompt, return_tensors='pt').to(device)
    outputs = model.generate(**inputs, max_new_tokens=50)
    print(tokenizer.decode(outputs[0], skip_special_tokens=True))