In [2]:
%load_ext autoreload
%autoreload 2

In [4]:
import torch
import transformers
from finetune_peft import get_peft_config, PEFTArguments
from peft import get_peft_config, get_peft_model, LoraConfig, TaskType

from peft.tuners.lora import Linear
import torch.nn.functional as F
from peft.utils.other import transpose

In [3]:
model_path = "/home/ubuntu/llama-weights/7B/llama-7b"
tokenizer_path = model_path

torch.set_default_tensor_type(torch.cuda.HalfTensor)
model = transformers.LlamaForCausalLM.from_pretrained(model_path)

peft_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM, inference_mode=False, r=8, lora_alpha=32, lora_dropout=0.1
)

model = get_peft_model(model, peft_config)
torch.set_default_tensor_type(torch.cuda.FloatTensor)

tokenizer = transformers.LlamaTokenizer.from_pretrained(tokenizer_path)
batch = tokenizer("The LLaMA language model is", return_tensors="pt")

with torch.no_grad():
    out1 = model.generate(
        input_ids=batch["input_ids"],
        attention_mask=torch.ones_like(batch["input_ids"]),
        max_length=200,
    )
print(tokenizer.decode(out1[0]))

In [8]:
model.base_model.model.model.layers[0].self_attn #.q_proj.lora_A.default.weight

LlamaAttention(
  (q_proj): Linear(
    in_features=4096, out_features=4096, bias=False
    (lora_dropout): ModuleDict(
      (default): Dropout(p=0.1, inplace=False)
    )
    (lora_A): ModuleDict(
      (default): Linear(in_features=4096, out_features=8, bias=False)
    )
    (lora_B): ModuleDict(
      (default): Linear(in_features=8, out_features=4096, bias=False)
    )
    (lora_embedding_A): ParameterDict()
    (lora_embedding_B): ParameterDict()
  )
  (k_proj): Linear(in_features=4096, out_features=4096, bias=False)
  (v_proj): Linear(
    in_features=4096, out_features=4096, bias=False
    (lora_dropout): ModuleDict(
      (default): Dropout(p=0.1, inplace=False)
    )
    (lora_A): ModuleDict(
      (default): Linear(in_features=4096, out_features=8, bias=False)
    )
    (lora_B): ModuleDict(
      (default): Linear(in_features=8, out_features=4096, bias=False)
    )
    (lora_embedding_A): ParameterDict()
    (lora_embedding_B): ParameterDict()
  )
  (o_proj): Linear(in_feat

In [87]:
bs = 2
seq_len = 10
ctx_dim = 4096
rank = 8
scaling = 4.0
fan_in_fan_out = False
bias = None

In [88]:
# blora
x0 = torch.randn([bs, seq_len, ctx_dim])
weight = torch.randn(([ctx_dim, ctx_dim]))

lora_a0 = torch.nn.Linear(in_features=ctx_dim, out_features=rank, bias=False)
lora_b0 = torch.nn.Linear(in_features=rank, out_features=ctx_dim, bias=False)

lora_a1 = torch.nn.Linear(in_features=ctx_dim, out_features=rank, bias=False)
lora_b1 = torch.nn.Linear(in_features=rank, out_features=ctx_dim, bias=False)

lora1 = torch.cat([lora_a0.weight, lora_a1.weight], dim=1)
lora2 = torch.cat([lora_b0.weight, lora_b1.weight], dim=0)

# forward pass
result = F.linear(x0, transpose(weight, fan_in_fan_out), bias=bias)
x = x0.reshape(seq_len, -1)
x = x.to(lora1.dtype)

out = F.linear(x, transpose(lora1, fan_in_fan_out), bias=bias)
out = scaling * F.linear(out, transpose(lora2, fan_in_fan_out), bias=bias)
out = out.reshape(bs, seq_len, -1)
result += out

In [89]:
class BLora(torch.nn.Module):
    def __init__(
        self,
        lora1: list,
        lora2: list,
        weight: torch.Tensor,
    ):
        super().__init__()
        self.lora_a = torch.nn.Parameter(torch.cat([lora1[0].weight, lora2[0].weight], dim=1))
        self.lora_b = torch.nn.Parameter(torch.cat([lora1[1].weight, lora2[1].weight], dim=0))
        self.weight = torch.nn.Parameter(weight)
        print(f"lora_a: {self.lora_a.shape}, lora_b: {self.lora_b.shape}, weight: {self.weight.shape}")

    def forward(self, x: torch.Tensor):
        print(f"x: {x.shape}")
        result = F.linear(x, transpose(self.weight, fan_in_fan_out), bias=bias)
        print(f"result: {result.shape}")
        x = x.reshape(seq_len, -1)
        x = x.to(self.lora_a.dtype)
        print(f"x: {x.shape}")

        out = F.linear(x, transpose(self.lora_a, fan_in_fan_out), bias=bias)
        print(f"out: {out.shape}")
        out = scaling * F.linear(out, transpose(self.lora_b, fan_in_fan_out), bias=bias)
        print(f"out: {out.shape}")
        out = out.reshape(bs, seq_len, -1)
        print(f"out: {out.shape}")
        result += out
        return result
    
blora = BLora(lora1=[lora_a0, lora_b0], lora2=[lora_a1, lora_b1], weight=weight)
result2 = blora(x0)

lora_a: torch.Size([8, 8192]), lora_b: torch.Size([8192, 8]), weight: torch.Size([4096, 4096])
x: torch.Size([2, 10, 4096])
result: torch.Size([2, 10, 4096])
x: torch.Size([10, 8192])
out: torch.Size([10, 8])
out: torch.Size([10, 8192])
out: torch.Size([2, 10, 4096])


In [None]:
def forward(self, x: torch.Tensor):
    previous_dtype = x.dtype
    if self.active_adapter not in self.lora_A.keys():
        return F.linear(x, transpose(self.weight, self.fan_in_fan_out), bias=self.bias)
    if self.disable_adapters:
        if self.r[self.active_adapter] > 0 and self.merged:
            self.unmerge()
        result = F.linear(x, transpose(self.weight, self.fan_in_fan_out), bias=self.bias)
    elif self.r[self.active_adapter] > 0 and not self.merged:
        result = F.linear(x, transpose(self.weight, self.fan_in_fan_out), bias=self.bias)

        x = x.to(self.lora_A[self.active_adapter].weight.dtype)

        result += (
            self.lora_B[self.active_adapter](
                self.lora_A[self.active_adapter](self.lora_dropout[self.active_adapter](x))
            )
            * self.scaling[self.active_adapter]
        )
    else:
        result = F.linear(x, transpose(self.weight, self.fan_in_fan_out), bias=self.bias)

    result = result.to(previous_dtype)

    return result

In [None]:
# depreciated blora
x0 = torch.randn([bs, seq_len, ctx_dim])
weight = torch.randn(([ctx_dim, ctx_dim]))

lora_a0 = torch.nn.Linear(in_features=ctx_dim, out_features=rank, bias=False)
lora_b0 = torch.nn.Linear(in_features=rank, out_features=ctx_dim, bias=False)

lora_a1 = torch.nn.Linear(in_features=ctx_dim, out_features=rank, bias=False)
lora_b1 = torch.nn.Linear(in_features=rank, out_features=ctx_dim, bias=False)

lora1 = torch.nn.Linear(in_features=bs * ctx_dim, out_features=rank, bias=False)
lora2 = torch.nn.Linear(in_features=rank, out_features=bs * ctx_dim, bias=False)

lora1.weight = torch.nn.Parameter(torch.cat([lora_a0.weight, lora_a1.weight], dim=1))
lora2.weight = torch.nn.Parameter(torch.cat([lora_b0.weight, lora_b1.weight], dim=0))

# forward pass
result1 = F.linear(x0, transpose(weight, fan_in_fan_out), bias=bias)
x1 = x0.reshape(seq_len, -1)
x1 = x1.to(lora1.weight.dtype)
out1 = lora2(lora1(x1)) * scaling
out1 = out1.reshape(bs, seq_len, -1)
result1 += out1