In [4]:
import sys
import logging
import bitsandbytes as bnb
from bitsandbytes.nn import Linear4bit
import bitsandbytes as bnb
import tqdm
import datasets
from datasets import load_dataset
from peft import LoraConfig,PeftConfig, PeftModel, PeftModelForCausalLM
import torch
import transformers
from trl import SFTTrainer
from typing import List, Dict, Any, Tuple
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments, BitsAndBytesConfig
import safetensors
import torch.nn as nn
from functools import partial


  from .autonotebook import tqdm as notebook_tqdm


In [5]:
# Identifier of the base checkpoint handed to `from_pretrained` in the cells below.
base_model = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"

# Adapter directory first, then the safetensors weight file that lives inside it.
adapter_path = "/scratch/tathagato/redo_adapter_experiments/length/length/"
safetensor_path = (
    "/scratch/tathagato/redo_adapter_experiments/length/length/adapter_model.safetensors"
)

In [6]:
# Checkpoint identifier, repeated here so this cell does not depend on the one above.
base_model = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"

# 4-bit NF4 quantization settings with double quantization enabled.
# NOTE(review): the compute dtype here is float16 while `torch_dtype` below is
# bfloat16 -- confirm that the mismatch is intentional.
nf4_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.float16,
)

# Keyword arguments for loading the quantized checkpoint; `model_kwargs` is
# reused further down when the base model for the PEFT adapter is loaded.
model_kwargs = {
    "use_cache": False,
    "trust_remote_code": True,
    "torch_dtype": torch.bfloat16,
    "device_map": None,
    "cache_dir": "/scratch/tathagato",
    "attn_implementation": "eager",
    "quantization_config": nf4_config,
}

quantized_model = AutoModelForCausalLM.from_pretrained(base_model, **model_kwargs)
tokenizer = AutoTokenizer.from_pretrained(base_model, cache_dir="/scratch/tathagato")

`low_cpu_mem_usage` was None, now set to True since model is quantized.


In [7]:
# Load the same checkpoint without any quantization config, for comparison.
non_quantized_model = AutoModelForCausalLM.from_pretrained(
    base_model,
    cache_dir="/scratch/tathagato",
    use_cache=False,
    trust_remote_code=True,
)

# Print the module tree, then the weight shape of the first layer's key projection.
print(non_quantized_model)
print(non_quantized_model.model.layers[0].self_attn.k_proj.weight.shape)

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(32000, 2048)
    (layers): ModuleList(
      (0-21): 22 x LlamaDecoderLayer(
        (self_attn): LlamaSdpaAttention(
          (q_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (k_proj): Linear(in_features=2048, out_features=256, bias=False)
          (v_proj): Linear(in_features=2048, out_features=256, bias=False)
          (o_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=2048, out_features=5632, bias=False)
          (up_proj): Linear(in_features=2048, out_features=5632, bias=False)
          (down_proj): Linear(in_features=5632, out_features=2048, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm()
        (post_attention_layernorm): LlamaRMSNorm()
      )
    )
    (norm): LlamaRMSNorm()
  )
  (lm_head): Line

In [8]:

class CascadedLoRALinear4bit(torch.nn.Module):
    """Wrap a linear layer with two switchable low-rank (LoRA-style) updates.

    The wrapped ``linear`` is stored as ``base_layer`` and is always evaluated.
    Two optional additive updates are computed from the same input:

    * adapter 1: ``lora_B(lora_A(x))`` with widths ``in_dim -> rank_1 -> out_dim``;
    * adapter 2: ``lora_B2(lora_B1(lora_A2(lora_A1(x))))`` with widths
      ``in_dim -> rank_2 -> rank_1 -> rank_2 -> out_dim``.

    Each update is multiplied by ``alpha / rank`` before being added to the base
    output. Every factor lives in an ``nn.ModuleDict`` keyed by ``adapter_name``.

    Args:
        linear: The layer to wrap; it is called as ``linear(x)``.
        in_dim: Input feature size of ``linear``.
        out_dim: Output feature size of ``linear``.
        rank_1: Rank of adapter 1 (also the middle width of adapter 2).
        rank_2: Rank of adapter 2.
        alpha_1: Scaling numerator for adapter 1.
        alpha_2: Scaling numerator for adapter 2.
        adapter_name: Key under which all adapter sub-modules are registered.
        dropout: Dropout probability applied to the adapter input, or ``None``
            for no dropout.

    Attributes:
        is_first_layer_being_trained / is_second_layar_being_trained: whether the
            parameters of adapter 1 / adapter 2 require gradients. The spelling
            ``layar`` is kept so existing code that reads the flag keeps working.
        is_first_layer_being_used_for_inference /
        is_second_layer_being_used_for_inference: whether adapter 1 / adapter 2
            contributes to the output of ``forward``.
    """

    def __init__(self, linear, in_dim, out_dim, rank_1 = 64, rank_2 = 32, alpha_1 = 16, alpha_2 = 16, adapter_name = "default" , dropout = None):
        super().__init__()
        self.base_layer = linear
        self.adapter_name = adapter_name
        self.rank_1 = rank_1
        self.rank_2 = rank_2
        self.alpha_1 = alpha_1
        self.alpha_2 = alpha_2
        # LoRA convention: the low-rank update is scaled by alpha / rank, and each
        # adapter uses its own alpha.
        self.scaling_1 = self.alpha_1 / self.rank_1
        self.scaling_2 = self.alpha_2 / self.rank_2

        std_dev_1 = float(rank_1) ** -0.5
        std_dev_2 = float(rank_2) ** -0.5

        # Always register a dropout entry so forward() needs no special case.
        self.lora_dropout = nn.ModuleDict(
            {
                adapter_name : torch.nn.Dropout(dropout) if dropout is not None else torch.nn.Identity()
            }
        )

        # Only the LAST factor of each adapter starts at zero. The update is then
        # exactly zero at initialisation, yet gradients can still reach the
        # adapter: a product with two or more zero factors has zero gradient with
        # respect to every factor and would never train.

        # Adapter 1: in_dim -> rank_1 -> out_dim.
        self.lora_A = nn.ModuleDict({adapter_name : self._low_rank_linear(in_dim, rank_1, std_dev_1)})
        self.lora_B = nn.ModuleDict({adapter_name : self._low_rank_linear(rank_1, out_dim, None)})

        # Adapter 2: in_dim -> rank_2 -> rank_1 -> rank_2 -> out_dim.
        # NOTE(review): the init scale of lora_A2 / lora_B1 mirrors lora_A1 / lora_A;
        # adjust if a different scheme is wanted.
        self.lora_A1 = nn.ModuleDict({adapter_name : self._low_rank_linear(in_dim, rank_2, std_dev_2)})
        self.lora_A2 = nn.ModuleDict({adapter_name : self._low_rank_linear(rank_2, rank_1, std_dev_2)})
        self.lora_B1 = nn.ModuleDict({adapter_name : self._low_rank_linear(rank_1, rank_2, std_dev_1)})
        self.lora_B2 = nn.ModuleDict({adapter_name : self._low_rank_linear(rank_2, out_dim, None)})

        self.is_second_layar_being_trained = False
        self.is_first_layer_being_trained = False
        self.is_first_layer_being_used_for_inference = True
        self.is_second_layer_being_used_for_inference = True

        # Make requires_grad agree with the flags from the start.
        self.set_gradients_for_all_layer()

    @staticmethod
    def _low_rank_linear(in_features, out_features, std):
        """Return a bias-free linear map with N(0, std^2) weights, or zeros if std is None."""
        layer = torch.nn.Linear(in_features, out_features, bias = False)
        with torch.no_grad():
            if std is None:
                layer.weight.zero_()
            else:
                layer.weight.normal_(mean = 0.0, std = std)
        return layer

    def set_gradients_for_all_layer(self):
        """Set ``requires_grad`` on every adapter parameter according to the training flags."""
        name = self.adapter_name
        # requires_grad_ acts on the parameters; assigning `.requires_grad` on the
        # nn.Linear module itself would only create an unused attribute.
        for factor in (self.lora_A1, self.lora_A2, self.lora_B1, self.lora_B2):
            factor[name].requires_grad_(self.is_second_layar_being_trained)
        for factor in (self.lora_A, self.lora_B):
            factor[name].requires_grad_(self.is_first_layer_being_trained)

    def tune_the_first_adapter(self):
        """Mark adapter 1 as trainable (takes effect immediately)."""
        self.is_first_layer_being_trained = True
        self.set_gradients_for_all_layer()

    def freeze_the_first_adapter(self):
        """Freeze adapter 1 (takes effect immediately)."""
        self.is_first_layer_being_trained = False
        self.set_gradients_for_all_layer()

    def tune_the_second_adapter(self):
        """Mark adapter 2 as trainable (takes effect immediately)."""
        self.is_second_layar_being_trained = True
        self.set_gradients_for_all_layer()

    def freeze_the_second_adapter(self):
        """Freeze adapter 2 (takes effect immediately)."""
        self.is_second_layar_being_trained = False
        self.set_gradients_for_all_layer()

    def forward(self, x):
        """Return ``base_layer(x)`` plus the scaled update of every enabled adapter.

        With both inference flags off, the plain base-layer output is returned.
        """
        # Re-sync in case the flags were modified directly on the instance.
        self.set_gradients_for_all_layer()
        name = self.adapter_name
        use_first = self.is_first_layer_being_used_for_inference
        use_second = self.is_second_layer_being_used_for_inference

        output = self.base_layer(x)
        if not (use_first or use_second):
            return output

        # Run the adapters in their own dtype and cast each update back, so a
        # half-precision input does not clash with the adapter weights.
        adapter_dtype = self.lora_A[name].weight.dtype
        adapter_input = self.lora_dropout[name](x.to(adapter_dtype))

        if use_first:
            update_1 = self.lora_B[name](self.lora_A[name](adapter_input))
            output = output + (self.scaling_1 * update_1).to(output.dtype)
        if use_second:
            hidden = self.lora_A2[name](self.lora_A1[name](adapter_input))
            update_2 = self.lora_B2[name](self.lora_B1[name](hidden))
            output = output + (self.scaling_2 * update_2).to(output.dtype)
        return output
    







In [9]:

# Rank and alpha for the first and second adapter, respectively.
rank_1, rank_2 = 64, 32
alpha_1, alpha_2 = 16, 16

# Key under which adapter sub-modules are registered, and the dropout probability.
adapter_name = "test"
dropout = 0.05

# Names of the linear sub-modules to wrap: attention projections, then MLP projections.
target_modules = ['q_proj', 'k_proj', 'v_proj', 'o_proj'] + ['gate_proj', 'up_proj', 'down_proj']


In [30]:
def replace_with_cascaded_lora(module, target_modules = target_modules, rank_1 = 64, rank_2 = 32, alpha_1 = 16 , alpha_2 = 16 , adapter_name = "default" , dropout = None):
    """Recursively swap targeted ``Linear4bit`` children for ``CascadedLoRALinear4bit``.

    The tree under ``module`` is modified in place. A child is wrapped when it is
    a ``bnb.nn.Linear4bit`` AND its attribute name is in ``target_modules``; every
    other child is descended into with the same settings.

    Args:
        module: Root of the module tree to rewrite.
        target_modules: Attribute names of the linear layers to wrap.
        rank_1, rank_2, alpha_1, alpha_2, adapter_name, dropout: Forwarded
            unchanged to ``CascadedLoRALinear4bit`` for every wrapped layer.
    """
    # Snapshot the children first: setattr below rewrites the very mapping that
    # named_children() iterates over.
    for name, child in list(module.named_children()):
        if isinstance(child, bnb.nn.Linear4bit) and name in target_modules:
            wrapped = CascadedLoRALinear4bit(child, child.in_features, child.out_features, rank_1, rank_2, alpha_1, alpha_2, adapter_name , dropout = dropout)
            setattr(module, name, wrapped)
        else:
            # Pass `dropout` through; hard-coding None here would drop the caller's
            # setting for every layer that is not a direct child of the root.
            replace_with_cascaded_lora(child, target_modules, rank_1, rank_2, alpha_1, alpha_2, adapter_name , dropout = dropout)
def print_device_and_dtype(model, file = sys.stdout):
    """Write the device, dtype and tensor type of every sub-module of ``model``.

    For each entry of ``model.named_modules()`` the FIRST parameter found (the
    module's own or a descendant's) is inspected; modules without any parameter
    are reported as ``'No parameters'``.

    Args:
        model: Object exposing ``named_modules()`` (e.g. a ``torch.nn.Module``).
        file: Either a writable text stream (anything with a ``write`` method,
            such as ``sys.stdout`` or ``io.StringIO``), which is written to and
            left open, or a filesystem path, which is opened in ``"w"`` mode
            (truncating any existing file) and closed afterwards.
    """
    def _report(stream):
        # Emit one five-line record per module.
        for name, module in model.named_modules():
            first_param = next(module.parameters(), None)
            if first_param is None:
                device = dtype = tensor_type = 'No parameters'
            else:
                device = first_param.device
                dtype = first_param.dtype
                tensor_type = first_param.type()

            print(f"Module: {name}", file = stream)
            print(f"  Device: {device}", file = stream)
            print(f"  Dtype: {dtype}", file = stream)
            print(f"  Type: {tensor_type}", file = stream)
            print(" ", file = stream)

    # Streams are used as-is; anything else is treated as a path to (over)write.
    if hasattr(file, "write"):
        _report(file)
        return

    with open(file, "w") as handle:
        _report(handle)
# Function to ensure every submodule's parameters and buffers are on one device
def move_to_device(model, device):
    """Move each sub-module of ``model`` that owns off-device tensors to ``device``.

    Every module's directly owned parameters AND buffers are checked, so modules
    that hold only buffers, or whose tensors are spread over several devices, are
    not skipped. Modules already on the target device are left alone.

    Args:
        model: Object exposing ``named_modules()`` (e.g. a ``torch.nn.Module``).
        device: Target device as a ``torch.device`` or a device string.
    """
    device = torch.device(device)

    def _is_on_target(tensor):
        # A target without an index (e.g. "cuda") matches any index of that type.
        if tensor.device.type != device.type:
            return False
        return device.index is None or tensor.device.index == device.index

    for _name, module in model.named_modules():
        own_tensors = list(module.parameters(recurse = False)) + list(module.buffers(recurse = False))
        if not all(_is_on_target(tensor) for tensor in own_tensors):
            # Module.to moves the module's tensors in place.
            module.to(device)



# Wrap the targeted projections using the hyper-parameters from the configuration
# cell; calling with no keyword arguments would silently fall back to the
# function defaults (adapter name "default", no dropout).
replace_with_cascaded_lora(
    quantized_model,
    target_modules = target_modules,
    rank_1 = rank_1,
    rank_2 = rank_2,
    alpha_1 = alpha_1,
    alpha_2 = alpha_2,
    adapter_name = adapter_name,
    dropout = dropout,
)
# The adapter layers are created after the model was loaded, so move anything that
# is not yet on the target device, then dump the resulting layout to a file.
move_to_device(quantized_model, torch.device("cuda:0" if torch.cuda.is_available() else "cpu"))
print_device_and_dtype(quantized_model, file = "cascaded_lora_structure.txt")



In [25]:
ls

cascaded_lora.ipynb          extract_outputs.ipynb       test_model.ipynb
cascaded_lora_structure.txt  peft_adapter_structure.txt
cascaded_model.ipynb         test_dataset.ipynb


In [24]:
# Load a second copy of the base checkpoint with the same quantized settings,
# attach the trained PEFT adapter under the name "test", and dump its layout.
base_model_path = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
base_model_quantized = AutoModelForCausalLM.from_pretrained(
    base_model_path,
    **model_kwargs,
)
adapter_model = PeftModelForCausalLM.from_pretrained(
    base_model_quantized,
    adapter_path,
    adapter_name = "test",
)
print_device_and_dtype(adapter_model, file = "./peft_adapter_structure.txt")

`low_cpu_mem_usage` was None, now set to True since model is quantized.
