In [1]:
import os

import torch
from huggingface_hub import login
from transformers import AutoModelForCausalLM, AutoTokenizer

  from .autonotebook import tqdm as notebook_tqdm


## Download Weights

In [None]:
# Hub identifiers: the GSM8K adapter and the model checkpoint it is loaded onto.
peft_model_gsm8k_id = "predibase/gsm8k"
model_id = "mistralai/Mistral-7B-v0.1"

# Instantiate the causal LM from `model_id` first, then load the adapter onto it.
gsm8k_model = AutoModelForCausalLM.from_pretrained(model_id)
gsm8k_model.load_adapter(peft_model_gsm8k_id)

In [None]:
# Model checkpoint and the Magicoder adapter that is loaded onto it.
model_id = "mistralai/Mistral-7B-v0.1"
peft_model_magicoder_id = "predibase/magicoder"
magicoder_model = AutoModelForCausalLM.from_pretrained(model_id)
magicoder_model.load_adapter(peft_model_magicoder_id)

# Save weights
# The "Load Weights" section reads both files from the `weights/` directory, so
# write them there (creating the directory if needed) instead of the working
# directory; otherwise a fresh Restart & Run All cannot find the checkpoints.
# `gsm8k_model` is defined by the preceding cell.
os.makedirs("weights", exist_ok=True)
torch.save(gsm8k_model.state_dict(), "weights/gsm8k_model_weights.pth")
torch.save(magicoder_model.state_dict(), "weights/magicoder_model_weights.pth")

## Load Weights

In [2]:
# Read both saved checkpoints back as state dicts.
# map_location="cpu" places every tensor in host memory regardless of the device
# it was saved from; weights_only=True restricts unpickling to tensors and basic
# containers, so a tampered checkpoint file cannot run arbitrary code on load.
gsm8k_state_dict = torch.load(
    "weights/gsm8k_model_weights.pth", map_location="cpu", weights_only=True
)
magicoder_state_dict = torch.load(
    "weights/magicoder_model_weights.pth", map_location="cpu", weights_only=True
)

In [3]:
# Verify that both checkpoints contain exactly the same tensor names.
# Pairing entries by position (zip) would silently truncate to the shorter dict
# and could average unrelated tensors if the key order ever differed, so the
# tensors are matched by name instead.
only_in_gsm8k = gsm8k_state_dict.keys() - magicoder_state_dict.keys()
only_in_magicoder = magicoder_state_dict.keys() - gsm8k_state_dict.keys()
assert not only_in_gsm8k and not only_in_magicoder, (
    f"Key mismatch: only in gsm8k={sorted(only_in_gsm8k)}, "
    f"only in magicoder={sorted(only_in_magicoder)}"
)

# Verify matching shapes
for name, gsm8k_param in gsm8k_state_dict.items():
    magicoder_param = magicoder_state_dict[name]
    assert gsm8k_param.shape == magicoder_param.shape, f"Shape mismatch in {name}"

# Element-wise mean of the two checkpoints, keyed in gsm8k's order.
# NOTE(review): if these state dicts hold separate low-rank adapter factors,
# averaging each factor is not the same as averaging their products — confirm
# that a plain per-tensor mean is the intended merge.
new_state_dict = {
    name: (gsm8k_param + magicoder_state_dict[name]) / 2
    for name, gsm8k_param in gsm8k_state_dict.items()
}

torch.save(new_state_dict, "mistral_combined_weights.pth")

: 