In [1]:
from transformers import AutoModelForCausalLM, AutoTokenizer
from huggingface_hub import login
from peft import PeftModel
import torch

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
print(torch.cuda.is_available())  # Should return True if CUDA is available
print(torch.cuda.device_count())  # Number of GPUs detected
print(torch.cuda.get_device_name(0))

True
1
NVIDIA GeForce GTX 1650


## Configurations

In [3]:
model_id = "mistralai/Mistral-7B-v0.1"
peft_model_gsm8k_id = "predibase/gsm8k"
peft_model_magicoder_id = "predibase/magicoder"
peft_model_gluecola_id = "predibase/glue_cola"
peft_model_legal_id = "predibase/legal"

adapter_name_gsm8k = "gsm8k"
adapter_name_magicoder = "magicoder"
adapter_name_gluecola = "glue_cola"
adapter_name_legal = "legal"

merged_adapter_name = "gsm8k_magicoder_gluecola_legal_avg"

In [4]:
device = "cuda" if torch.cuda.is_available() else "cpu"
if device == "cuda":
    compute_dtype = torch.bfloat16 # Or torch.float16 depending on your GPU
else:
    compute_dtype = torch.float32

## Load model and Adaptors

In [5]:
torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats()

In [6]:
base_model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=compute_dtype,
    device_map={"":"cuda"}, # Automatically distributes across GPUs if available/needed
    # offload_folder='offload/'
)
tokenizer = AutoTokenizer.from_pretrained(model_id)

Loading checkpoint shards: 100%|██████████| 2/2 [00:51<00:00, 25.67s/it]


In [11]:
model = PeftModel.from_pretrained(
    base_model,
    peft_model_gsm8k_id,
    adapter_name=adapter_name_gsm8k, # You can name the first adapter here
    # device_map="auto", # Apply device mapping here if needed
    low_cpu_mem_usage=True,
    offload_folder='offload/'
)
model.load_adapter(peft_model_magicoder_id, adapter_name=adapter_name_magicoder)
model.load_adapter(peft_model_gluecola_id, adapter_name=adapter_name_gluecola)
model.load_adapter(peft_model_legal_id, adapter_name=adapter_name_legal)
# model.to("cuda")



_IncompatibleKeys(missing_keys=['base_model.model.model.layers.0.self_attn.q_proj.lora_A.gsm8k_magicoder_gluecola_legal_avg.weight', 'base_model.model.model.layers.0.self_attn.q_proj.lora_B.gsm8k_magicoder_gluecola_legal_avg.weight', 'base_model.model.model.layers.0.self_attn.v_proj.lora_A.gsm8k_magicoder_gluecola_legal_avg.weight', 'base_model.model.model.layers.0.self_attn.v_proj.lora_B.gsm8k_magicoder_gluecola_legal_avg.weight', 'base_model.model.model.layers.1.self_attn.q_proj.lora_A.gsm8k_magicoder_gluecola_legal_avg.weight', 'base_model.model.model.layers.1.self_attn.q_proj.lora_B.gsm8k_magicoder_gluecola_legal_avg.weight', 'base_model.model.model.layers.1.self_attn.v_proj.lora_A.gsm8k_magicoder_gluecola_legal_avg.weight', 'base_model.model.model.layers.1.self_attn.v_proj.lora_B.gsm8k_magicoder_gluecola_legal_avg.weight', 'base_model.model.model.layers.2.self_attn.q_proj.lora_A.gsm8k_magicoder_gluecola_legal_avg.weight', 'base_model.model.model.layers.2.self_attn.q_proj.lora_B.gs

In [12]:
model.add_weighted_adapter(
    adapters=[adapter_name_gsm8k, adapter_name_magicoder, adapter_name_gluecola, adapter_name_legal],
    weights=[0.25, 0.25, 0.25, 0.25],
    adapter_name=merged_adapter_name,
    combination_type="svd" # 'linear' is the default for weighted sum
)

In [13]:
model.set_adapter(merged_adapter_name)

save_directory = f"weights/{merged_adapter_name}_svd"
print(f"\nSaving the merged adapter '{merged_adapter_name}' to {save_directory}...")
model.save_pretrained(save_directory, selected_adapters=[merged_adapter_name])
tokenizer.save_pretrained(save_directory) 


Saving the merged adapter 'gsm8k_magicoder_gluecola_legal_avg' to weights/gsm8k_magicoder_gluecola_legal_avg_svd...


('weights/gsm8k_magicoder_gluecola_legal_avg_svd\\tokenizer_config.json',
 'weights/gsm8k_magicoder_gluecola_legal_avg_svd\\special_tokens_map.json',
 'weights/gsm8k_magicoder_gluecola_legal_avg_svd\\tokenizer.json')

## Testing

In [14]:
print("\n--- Testing the merged model ---")
# Make sure the merged adapter is active (we did this in step 5)
print(f"Current active adapter: {model.active_adapter}") # Verify it's the merged one

prompt_gsm8k = "What is 5 * 8 + 3?" # Example GSM8K style
prompt_magicoder = "def fibonacci(n):" # Example Magicoder style
prompt_gluecola = 'Determine if the sentence below is syntactically and semantically correct. If it is syntactically and semantically correct, respond "1". Otherwise, respond "0".\n\nSentence: Every senator seems to become more corrupt, as he talks to more lobbyists.\n\nLabel: '
prompt_legal = 'You are a helpful, precise, detailed, and concise artificial intelligence assistant with a deep expertise in reading and interpreting legal documents. You are very intelligent and sharp, having a keen ability to discern the essential type of the legal clause from the text of the legal clause itself.\nIn this task, you are asked to determine clause type from clause text.\nYou will be evaluated based on the following criteria: - The generated answer is always one word (hyphens and underscores allowed). - The generated answer is best possible brief categorization of clause text.\nCategorize the clause text into a succinct clause type:\n### Clause Text: Any notice request consent claim demand approval waiver or other communication hereunder to Permitted Transferee shall be delivered or sent to Permitted Transferee at the address set forth on the signature page hereto in accordance with Section 7 1 of the Tax Receivable Agreement \n### Clause Type:'

if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

for prompt in [prompt_gsm8k, prompt_gluecola, prompt_magicoder, prompt_legal]:
    print(f"\nPrompt: {prompt}")
    inputs = tokenizer(prompt, return_tensors="pt", padding=True).to(model.device)
    with torch.no_grad(): 
        outputs = model.generate(
            **inputs,
            max_new_tokens=50,
            temperature=0.7,
            pad_token_id=tokenizer.pad_token_id # Important for generation
            )
    decoded_output = tokenizer.decode(outputs[0], skip_special_tokens=True)
    print(f"Generated Output:\n{decoded_output}")


--- Testing the merged model ---
Current active adapter: gsm8k_magicoder_gluecola_legal_avg

Prompt: What is 5 * 8 + 3?
Generated Output:
What is 5 * 8 + 3?

Prompt: Determine if the sentence below is syntactically and semantically correct. If it is syntactically and semantically correct, respond "1". Otherwise, respond "0".

Sentence: Every senator seems to become more corrupt, as he talks to more lobbyists.

Label: 
Generated Output:
Determine if the sentence below is syntactically and semantically correct. If it is syntactically and semantically correct, respond "1". Otherwise, respond "0".

Sentence: Every senator seems to become more corrupt, as he talks to more lobbyists.

Label: 

Prompt: def fibonacci(n):
Generated Output:
def fibonacci(n):
    if n == 0:
        return 0
    if n == 1:
        return 1
    return fibonacci(n - 1) + fibonacci(n - 2)

Prompt: You are a helpful, precise, detailed, and concise artificial intelligence assistant with a deep expertise in reading and