## Testing the loading and saving of MultiLoRA models

In [1]:
from stein_lora import MultiLoraConfig, MultiLoraModel
import peft
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM
import torch as t

device = t.device("cuda") if t.cuda.is_available() else t.device("cpu")

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Load a basic gpt2 model
base_model1 = AutoModelForCausalLM.from_pretrained("gpt2").to(device)
base_model2 = AutoModelForCausalLM.from_pretrained("gpt2").to(device)
# Apply regular LoRA
lora_config = LoraConfig(r=4)
lora_model = get_peft_model(base_model1, lora_config)

# Apply Multi-LoRA
multi_lora_config = MultiLoraConfig(r=4, K=5)
multi_lora_model = get_peft_model(base_model2, multi_lora_config)



### Regular LoRA

In [29]:
# save the lora model
lora_model.save_pretrained("temp/lora_model")

In [30]:
# load the lora model
lora_model2 = AutoModelForCausalLM.from_pretrained("temp/lora_model").to(device) 

In [31]:
# check if the parameters are the same
for p1, p2 in zip(lora_model.parameters(), lora_model2.parameters()):
    assert t.allclose(p1, p2)

In [32]:
for x,y in zip([x for x in lora_model.state_dict()], [x for x in lora_model2.state_dict()]):
    print(x)
    print(y)

base_model.model.transformer.wte.weight
transformer.wte.weight
base_model.model.transformer.wpe.weight
transformer.wpe.weight
base_model.model.transformer.h.0.ln_1.weight
transformer.h.0.ln_1.weight
base_model.model.transformer.h.0.ln_1.bias
transformer.h.0.ln_1.bias
base_model.model.transformer.h.0.attn.c_attn.base_layer.weight
transformer.h.0.attn.c_attn.base_layer.weight
base_model.model.transformer.h.0.attn.c_attn.base_layer.bias
transformer.h.0.attn.c_attn.base_layer.bias
base_model.model.transformer.h.0.attn.c_attn.lora_A.default.weight
transformer.h.0.attn.c_attn.lora_A.default.weight
base_model.model.transformer.h.0.attn.c_attn.lora_B.default.weight
transformer.h.0.attn.c_attn.lora_B.default.weight
base_model.model.transformer.h.0.attn.c_proj.weight
transformer.h.0.attn.c_proj.weight
base_model.model.transformer.h.0.attn.c_proj.bias
transformer.h.0.attn.c_proj.bias
base_model.model.transformer.h.0.ln_2.weight
transformer.h.0.ln_2.weight
base_model.model.transformer.h.0.ln_2.bia

In [33]:
assert lora_model.config == lora_model2.config
assert lora_model.peft_config == lora_model2.peft_config

AssertionError: 

In [34]:

assert all(t.allclose(x,y) for x,y in zip([lora_model.state_dict()[x] for x in lora_model.state_dict()],
                                 [lora_model2.state_dict()[x] for x in lora_model2.state_dict()]))

### MultiLoRA (hacking around for a solution)

In [None]:
# save the multi-lora model

# hack to let us save the model
multi_lora_model.peft_config['default'].peft_type = peft.PeftType.LORA

multi_lora_model.save_pretrained("temp/multi_lora_model")

In [None]:
# open the config and change the peft_type to multi_lora
with open("temp/multi_lora_model/adapter_config.json", "r") as f:
    config = f.read()

config = config.replace('"peft_type": "LORA"', '"peft_type": "MultiLORA"')
# config = config.replace('"peft_type": "MultiLORA"', '"peft_type": "LORA"') # revert the change


with open("temp/multi_lora_model/adapter_config.json", "w") as f:
    f.write(config)

In [None]:
# add MultiLoraModel to the peft tuner mapping
peft.mapping.PEFT_TYPE_TO_TUNER_MAPPING['MultiLORA'] = MultiLoraModel
peft.mapping.PEFT_TYPE_TO_CONFIG_MAPPING['MultiLORA'] = MultiLoraConfig

print(peft.mapping.PEFT_TYPE_TO_TUNER_MAPPING)
print(peft.mapping.PEFT_TYPE_TO_CONFIG_MAPPING)

In [None]:
# change peft.utils.save_and_load.set_peft_model_state_dict(...) to allow for MultiLoraModel

def set_peft_model_state_dict_with_multilora(
    model, peft_model_state_dict, adapter_name="default", ignore_mismatched_sizes: bool = False
):
    """
    Set the state dict of the Peft model.

    Args:
        model ([`PeftModel`]):
            The Peft model.
        peft_model_state_dict (`dict`):
            The state dict of the Peft model.
        adapter_name (`str`, *optional*, defaults to `"default"`):
            The name of the adapter whose state dict should be set.
        ignore_mismatched_sizes (`bool`, *optional*, defaults to `False`):
            Whether to ignore mismatched in the state dict.
    """
    config = model.peft_config[adapter_name]
    state_dict = {}
    if getattr(model, "modules_to_save", None) is not None:
        for key, value in peft_model_state_dict.items():
            if any(module_name in key for module_name in model.modules_to_save):
                for module_name in model.modules_to_save:
                    if module_name in key:
                        key = key.replace(module_name, f"{module_name}.modules_to_save.{adapter_name}")
                        break
            state_dict[key] = value
    else:
        state_dict = peft_model_state_dict

    if config.peft_type in (
        PeftType.LORA,
        PeftType.LOHA,
        PeftType.LOKR,
        PeftType.ADALORA,
        PeftType.IA3,
        PeftType.OFT,
        PeftType.POLY,
        PeftType.LN_TUNING,
        PeftType.BOFT,
        PeftType.VERA,
        PeftType.FOURIERFT,
        PeftType.HRA,
        'MultiLORA'      ##################### NEW LINE
    ):
        peft_model_state_dict = {}
        parameter_prefix = {
            PeftType.IA3: "ia3_",
            PeftType.LORA: "lora_",
            PeftType.ADALORA: "lora_",
            PeftType.LOHA: "hada_",
            PeftType.LOKR: "lokr_",
            PeftType.OFT: "oft_",
            PeftType.POLY: "poly_",
            PeftType.BOFT: "boft_",
            PeftType.LN_TUNING: "ln_tuning_",
            PeftType.VERA: "vera_lambda_",
            PeftType.FOURIERFT: "fourierft_",
            PeftType.HRA: "hra_",
            'MultiLORA': "lora_"     ##################### NEW LINE
        }[config.peft_type]
        for k, v in state_dict.items():
            if parameter_prefix in k:
                suffix = k.split(parameter_prefix)[1]
                if "." in suffix:
                    suffix_to_replace = ".".join(suffix.split(".")[1:])
                    k = k.replace(suffix_to_replace, f"{adapter_name}.{suffix_to_replace}")
                else:
                    k = f"{k}.{adapter_name}"
                peft_model_state_dict[k] = v
            else:
                peft_model_state_dict[k] = v

        if config.peft_type == PeftType.ADALORA:
            rank_pattern = config.rank_pattern
            if rank_pattern is not None:
                model.resize_modules_by_rank_pattern(rank_pattern, adapter_name)
        elif config.peft_type == PeftType.VERA:
            if config.save_projection and "base_model.vera_A" not in peft_model_state_dict:
                raise ValueError(
                    "Specified to load vera_A and vera_B from state dictionary however they were not present!"
                )
            elif not config.save_projection and "base_model.vera_A" in peft_model_state_dict:
                warnings.warn(
                    "Specified to not load vera_A and vera_B from state dictionary however they are present in state"
                    " dictionary! Consider using them to ensure checkpoint loading is correct on all platforms using"
                    " `peft_config.save_projection = True`"
                )
            elif not config.save_projection:  # and no vera_A in state dictionary
                warnings.warn(
                    "Specified to not load vera_A and vera_B from state dictionary. This means we will be relying on"
                    " PRNG initialisation to restore these projections using `config.projection_prng_key`, which may"
                    " not be accurate on all system configurations."
                )
        elif config.peft_type == PeftType.LORA:
            # Here we take care of a refactor of DoRA which changed lora_magnitude_vector from a ParameterDict to a
            # ModuleDict with a DoraLayer instance. The old parameter is now the "weight" attribute of that layer.
            old_dora_suffix = f"lora_magnitude_vector.{adapter_name}"

            def renamed_dora_weights(k):
                if k.endswith(old_dora_suffix):
                    k = k + ".weight"
                return k

            peft_model_state_dict = {renamed_dora_weights(k): v for k, v in peft_model_state_dict.items()}

    elif config.is_prompt_learning or config.peft_type == PeftType.ADAPTION_PROMPT:
        peft_model_state_dict = state_dict
    elif config.peft_type == PeftType.XLORA:
        peft_model_state_dict = state_dict
    else:
        print(config.peft_type)
        raise NotImplementedError

    peft_model_state_dict, mismatched_keys = _find_mismatched_keys(
        model, peft_model_state_dict, ignore_mismatched_sizes=ignore_mismatched_sizes
    )
    load_result = model.load_state_dict(peft_model_state_dict, strict=False)
    if config.is_prompt_learning:
        model.prompt_encoder[adapter_name].embedding.load_state_dict(
            {"weight": peft_model_state_dict["prompt_embeddings"]}, strict=True
        )

    if config.peft_type == PeftType.MULTITASK_PROMPT_TUNING:
        model.prompt_encoder[adapter_name].load_state_dict(peft_model_state_dict, strict=False)

    if mismatched_keys:
        # see https://github.com/huggingface/transformers/blob/09f9f566de83eef1f13ee83b5a1bbeebde5c80c1/src/transformers/modeling_utils.py#L4039
        mismatched_warning = "\n".join(
            [
                f"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated"
                for key, shape1, shape2 in mismatched_keys
            ]
        )
        msg = (
            f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint "
            f"and are being ignored because you passed `ignore_mismatched_sizes=True`: {mismatched_warning}."
        )
        warnings.warn(msg)
    return load_result

peft.utils.save_and_load.set_peft_model_state_dict = set_peft_model_state_dict_with_multilora

In [None]:
# load the multi-lora model
multi_lora_model2 = AutoModelForCausalLM.from_pretrained("temp/multi_lora_model").to(device)

In [None]:
# load multilora model manually
base_model3 = AutoModelForCausalLM.from_pretrained("gpt2").to(device)
saved_multilora_config = MultiLoraConfig.from_pretrained("temp/multi_lora_model")

multi_lora_model3 = MultiLoraModel(base_model3, saved_multilora_config, adapter_name='default').to(device)

In [None]:
saved_multilora_config

In [None]:
# subsitute in the saved lora weights from adapter_model.safetensors
state_dict = peft.utils.save_and_load.load_peft_weights("temp/multi_lora_model", adapter_name='default')
state_dict = {k: v for k, v in state_dict.items() if "lora" in k}
# multi_lora_model3.load_state_dict(state_dict, strict=False)

# print(state_dict)

for k, v in state_dict.items():
    lora_key = k.replace("base_model.", "").replace(".weight", ".default.weight")
    if lora_key in multi_lora_model3.state_dict():
        print(k)
        multi_lora_model3.state_dict()[lora_key].copy_(v)

In [None]:
# ml3 = [t for t,v  in multi_lora_model3.state_dict.items()]
# ml = [t for t in multi_lora_model.parameters()]
# state_dict.keys()

# ml3[0]

ml3 = [p for p in multi_lora_model3.state_dict()]
print(multi_lora_model3.state_dict().keys())

state_dict['base_model.model.transformer.h.0.attn.c_attn.lora_A.weight']
state_dict

In [None]:
# check if the parameters are the same
for p1, p2 in zip(multi_lora_model.parameters(), multi_lora_model3.parameters()):
    assert t.allclose(p1, p2)

In [None]:
isinstance(multi_lora_model, MultiLoraModel)
multi_lora_model

## MultiLORA Correct Implementation

In [35]:
from stein_lora import save_multilora_weights, apply_saved_multilora_weights
# from stein_lora import *
# import stein_lora

save_multilora_weights(multi_lora_model, "temp/multi_lora_model")
# multi_lora_model.base_model.save_pretrained("temp/multi_lora_model")

base_model4 = AutoModelForCausalLM.from_pretrained("gpt2").to(device)
multi_lora_model4 = apply_saved_multilora_weights(base_model4, "temp/multi_lora_model")

In [36]:
# check if the parameters are the same
for p1, p2 in zip(multi_lora_model.parameters(), multi_lora_model4.parameters()):
    assert t.allclose(p1, p2)

In [37]:
for x,y in zip([x for x in multi_lora_model.state_dict()], [x for x in multi_lora_model4.state_dict()]):
    print(x)
    print(y)

base_model.model.transformer.wte.weight
model.transformer.wte.weight
base_model.model.transformer.wpe.weight
model.transformer.wpe.weight
base_model.model.transformer.h.0.ln_1.weight
model.transformer.h.0.ln_1.weight
base_model.model.transformer.h.0.ln_1.bias
model.transformer.h.0.ln_1.bias
base_model.model.transformer.h.0.attn.c_attn.base_layer.weight
model.transformer.h.0.attn.c_attn.base_layer.weight
base_model.model.transformer.h.0.attn.c_attn.base_layer.bias
model.transformer.h.0.attn.c_attn.base_layer.bias
base_model.model.transformer.h.0.attn.c_attn.lora_A.default.weight
model.transformer.h.0.attn.c_attn.lora_A.default.weight
base_model.model.transformer.h.0.attn.c_attn.lora_B.default.weight
model.transformer.h.0.attn.c_attn.lora_B.default.weight
base_model.model.transformer.h.0.attn.c_proj.weight
model.transformer.h.0.attn.c_proj.weight
base_model.model.transformer.h.0.attn.c_proj.bias
model.transformer.h.0.attn.c_proj.bias
base_model.model.transformer.h.0.ln_2.weight
model.tra

In [38]:
assert multi_lora_model.config == multi_lora_model4.config
assert multi_lora_model.peft_config == multi_lora_model4.peft_config

AssertionError: 

In [39]:
print(multi_lora_model.peft_config)
print(multi_lora_model4.peft_config)

{'default': MultiLoraConfig(peft_type='MultiLORA', auto_mapping=None, base_model_name_or_path='gpt2', revision=None, task_type=None, inference_mode=False, r=4, target_modules={'c_attn'}, lora_alpha=8, lora_dropout=0.0, fan_in_fan_out=False, bias='none', use_rslora=False, modules_to_save=None, init_lora_weights=True, layers_to_transform=None, layers_pattern=None, rank_pattern={}, alpha_pattern={}, megatron_config=None, megatron_core='megatron.core', loftq_config={}, use_dora=False, layer_replication=None, runtime_config=LoraRuntimeConfig(ephemeral_gpu_offload=False), K=5)}
{'default': MultiLoraConfig(peft_type='MultiLORA', auto_mapping={'base_model_class': 'GPT2LMHeadModel', 'parent_library': 'transformers.models.gpt2.modeling_gpt2'}, base_model_name_or_path='gpt2', revision=None, task_type=None, inference_mode=True, r=4, target_modules={'c_attn'}, lora_alpha=8, lora_dropout=0.0, fan_in_fan_out=False, bias='none', use_rslora=False, modules_to_save=None, init_lora_weights=True, layers_to

In [40]:
assert all(t.allclose(x,y) for x,y in zip([multi_lora_model.state_dict()[x]  for x in multi_lora_model.state_dict()],
                                          [multi_lora_model4.state_dict()[x] for x in multi_lora_model4.state_dict()]))

# What happens if we re-save a model and then load it again?

Do the state_dict keys change?