In [1]:
from transformers import AutoModelForCausalLM, AutoTokenizer
from huggingface_hub import login
from peft import PeftModel, PeftConfig, get_peft_model
import torch
from tqdm import tqdm

In [2]:
print(torch.cuda.is_available())  # Should return True if CUDA is available
print(torch.cuda.device_count())  # Number of GPUs detected
print(torch.cuda.get_device_name(0))

True
1
NVIDIA GeForce GTX 1650


## Configurations

In [3]:
model_id = "mistralai/Mistral-7B-v0.1"
peft_model_gsm8k_id = "predibase/gsm8k"
peft_model_magicoder_id = "predibase/magicoder"
peft_model_gluecola_id = "predibase/glue_cola"
peft_model_hellaswag_id = "predibase/hellaswag"

adapter_name_gsm8k = "gsm8k"
adapter_name_magicoder = "magicoder"
adapter_name_gluecola = "glue_cola"
adapter_name_hellaswag = "hellaswag"

merged_adapter_name = "gsm8k_magicoder_hellaswag"

In [4]:
device = "cuda" if torch.cuda.is_available() else "cpu"
if device == "cuda":
    compute_dtype = torch.bfloat16 # Or torch.float16 depending on your GPU
else:
    compute_dtype = torch.float32

## Load model and Adaptors

In [5]:
torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats()

In [6]:
base_model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=compute_dtype,
    device_map={"":"cuda"}, # Automatically distributes across GPUs if available/needed
    # offload_folder='offload/'
)
tokenizer = AutoTokenizer.from_pretrained(model_id)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [7]:
model = PeftModel.from_pretrained(
    base_model,
    peft_model_gsm8k_id,
    adapter_name=adapter_name_gsm8k, # You can name the first adapter here
    # device_map="auto", # Apply device mapping here if needed
    low_cpu_mem_usage=True,
    offload_folder='offload/'
)
model.load_adapter(peft_model_magicoder_id, adapter_name=adapter_name_magicoder)
# model.load_adapter(peft_model_gluecola_id, adapter_name=adapter_name_gluecola)
model.load_adapter(peft_model_hellaswag_id, adapter_name=adapter_name_hellaswag)
# model.to("cuda")

<All keys matched successfully>

In [8]:
state_dict = model.state_dict()
# merging 2
# lora_a = {k: v for k, v in state_dict.items() if adapter_name_gluecola in k and "lora_" in k}
# lora_b = {k.replace(adapter_name_hellaswag, adapter_name_gluecola): v for k, v in state_dict.items() if adapter_name_hellaswag in k and "lora_" in k}

# combined_lora = {}
# for k in tqdm(lora_a, desc="Combining LoRA weights"):
#     combined_lora[k] = lora_a[k] + lora_b[k]

# merging 3
lora_a = {k: v for k, v in state_dict.items() if adapter_name_gsm8k in k and "lora_" in k}
lora_b = {k.replace(adapter_name_magicoder, adapter_name_gsm8k): v for k, v in state_dict.items() if adapter_name_magicoder in k and "lora_" in k}
lora_c = {k.replace(adapter_name_hellaswag, adapter_name_gsm8k): v for k, v in state_dict.items() if adapter_name_hellaswag in k and "lora_" in k}

combined_lora = {}
for k in tqdm(lora_a, desc="Combining LoRA weights"):
    combined_lora[k] = lora_a[k] + lora_b[k] + lora_c[k]

# merge 4
# lora_a = {k: v for k, v in state_dict.items() if adapter_name_gsm8k in k and "lora_" in k}
# lora_b = {k.replace(adapter_name_gluecola, adapter_name_gsm8k): v for k, v in state_dict.items() if adapter_name_gluecola in k and "lora_" in k}
# lora_c = {k.replace(adapter_name_magicoder, adapter_name_gsm8k): v for k, v in state_dict.items() if adapter_name_magicoder in k and "lora_" in k}
# lora_d = {k.replace(adapter_name_hellaswag, adapter_name_gsm8k): v for k, v in state_dict.items() if adapter_name_hellaswag in k and "lora_" in k}

# combined_lora = {}
# for k in tqdm(lora_a, desc="Combining LoRA weights"):
#     combined_lora[k] = lora_a[k] + lora_b[k] + lora_c[k] + lora_d[k]

Combining LoRA weights: 100%|██████████| 128/128 [00:00<00:00, 4453.00it/s]


In [9]:
config = PeftConfig.from_pretrained(peft_model_gsm8k_id)

# Update the adapter name to be the new one
config.peft_type = "LORA"
config.task_type = "CAUSAL_LM"  # or whatever matches your use case
config.inference_mode = False
config.adapter_name = merged_adapter_name

# Add the new adapter slot into the model
model.add_adapter(merged_adapter_name, config)

In [10]:
with torch.no_grad():
    for name, param in model.named_parameters():
        if merged_adapter_name in name and "lora_" in name:
            # Map to the combined key
            base_name = name.replace(merged_adapter_name, adapter_name_gsm8k)
            if base_name in combined_lora:
                param.copy_(combined_lora[base_name])

In [11]:
model.set_adapter(merged_adapter_name)

save_directory = f"weights/element_add/{merged_adapter_name}"
print(f"\nSaving the merged adapter '{merged_adapter_name}' to {save_directory}...")
model.save_pretrained(save_directory, selected_adapters=[merged_adapter_name])
tokenizer.save_pretrained(save_directory) 


Saving the merged adapter 'gsm8k_magicoder_hellaswag' to weights/element_add/gsm8k_magicoder_hellaswag...


('weights/element_add/gsm8k_magicoder_hellaswag\\tokenizer_config.json',
 'weights/element_add/gsm8k_magicoder_hellaswag\\special_tokens_map.json',
 'weights/element_add/gsm8k_magicoder_hellaswag\\tokenizer.json')

## Testing

In [12]:
print("\n--- Testing the merged model ---")
# Make sure the merged adapter is active (we did this in step 5)
print(f"Current active adapter: {model.active_adapter}") # Verify it's the merged one

prompt_gsm8k = "What is 5 * 8 + 3?" # Example GSM8K style
prompt_magicoder = "def fibonacci(n):" # Example Magicoder style
prompt_gluecola = 'Determine if the sentence below is syntactically and semantically correct. If it is syntactically and semantically correct, respond "1". Otherwise, respond "0".\n\nSentence: Every senator seems to become more corrupt, as he talks to more lobbyists.\n\nLabel: '
prompt_hellaswag = 'You are provided with an incomplete passage below as well as 4 endings in quotes and separated by commas, with only one of them being the correct ending. Treat the endings as being labelled 0, 1, 2, 3 in order. Please respond with the number corresponding to the correct ending for the passage.\n\n### Passage: The mother instructs them on how to brush their teeth while laughing. The boy helps his younger sister brush his teeth. she\n\n### Endings: [\'shows how to hit the mom and then kiss his dad as well.\' \'brushes past the camera, looking better soon after.\' \'glows from the center of the camera as a reaction.\' \'gets them some water to gargle in their mouths.\']\n\n### Correct Ending Number: '
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

for prompt in [prompt_magicoder, prompt_hellaswag, prompt_gsm8k, prompt_gluecola]:
    print(f"\nPrompt: {prompt}")
    inputs = tokenizer(prompt, return_tensors="pt", padding=True).to(model.device)
    with torch.no_grad(): 
        outputs = model.generate(
            **inputs,
            max_new_tokens=50,
            temperature=0.7,
            pad_token_id=tokenizer.pad_token_id # Important for generation
            )
    decoded_output = tokenizer.decode(outputs[0], skip_special_tokens=True)
    print(f"Generated Output:\n{decoded_output}")


--- Testing the merged model ---
Current active adapter: gsm8k_magicoder_hellaswag

Prompt: def fibonacci(n):




Generated Output:
def fibonacci(n):

Prompt: You are provided with an incomplete passage below as well as 4 endings in quotes and separated by commas, with only one of them being the correct ending. Treat the endings as being labelled 0, 1, 2, 3 in order. Please respond with the number corresponding to the correct ending for the passage.

### Passage: The mother instructs them on how to brush their teeth while laughing. The boy helps his younger sister brush his teeth. she

### Endings: ['shows how to hit the mom and then kiss his dad as well.' 'brushes past the camera, looking better soon after.' 'glows from the center of the camera as a reaction.' 'gets them some water to gargle in their mouths.']

### Correct Ending Number: 
Generated Output:
You are provided with an incomplete passage below as well as 4 endings in quotes and separated by commas, with only one of them being the correct ending. Treat the endings as being labelled 0, 1, 2, 3 in order. Please respond with the number co