<a href="https://colab.research.google.com/github/yf591/sd-model-merge-tool/blob/main/05_Merge_Lora_Model_Ver1_0_0.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# LoRAモデル同士のマージ（MyDriveからの読み込みのみを想定）

In [None]:
from google.colab import drive

# Mount Google Drive so LoRA files under MyDrive can be read and the merged
# result can be written back there.
_DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(_DRIVE_MOUNT_POINT)

In [None]:
import os
import torch
from safetensors.torch import load_file, save_file
from typing import List, Dict
import ipywidgets as widgets
from IPython.display import display, clear_output
from getpass import getpass

In [None]:
#@title ### 関数の定義（設定）

def load_lora_weights(lora_path):
    """Load a LoRA safetensors file and return its tensor dict.

    Tensors are loaded onto the GPU when CUDA is available, otherwise onto
    the CPU. Any failure is reported with ``print`` and ``None`` is returned,
    so the caller can decide to skip this file.
    """
    try:
        target_device = "cuda" if torch.cuda.is_available() else "cpu"
        weights = load_file(lora_path, device=target_device)
    except Exception as e:
        print(f"エラー: LoRAファイルの読み込みに失敗しました ({lora_path}): {e}")
        return None
    return weights

def merge_lora_weights(lora_weights_list, alpha_list):
    """Merge several LoRA state dicts into one by a per-tensor weighted sum.

    The key set of the result is taken from the first LoRA. For each of those
    keys, every LoRA that also contains the key contributes ``alpha * tensor``.
    A LoRA that lacks the key contributes nothing, and a message is printed.
    Keys whose tensors cannot be combined (e.g. mismatched shapes) are
    reported and left out of the result.

    Args:
        lora_weights_list: List of state dicts (tensor name -> tensor).
        alpha_list: One weight per entry of ``lora_weights_list``.

    Returns:
        dict: The merged state dict, in the key order of the first LoRA.

    Raises:
        ValueError: If no LoRA is given, or if the two lists differ in length.
    """
    if not lora_weights_list:
        raise ValueError("マージするLoRAの重みがありません。")
    if len(lora_weights_list) != len(alpha_list):
        # zip() would otherwise silently drop the surplus entries.
        raise ValueError(
            f"LoRAの数 ({len(lora_weights_list)}) とアルファ値の数 "
            f"({len(alpha_list)}) が一致しません。"
        )

    merged_weights = {}
    # Walk the first LoRA's keys in their own order; iterating a set would
    # make the order (and the log output) vary from run to run.
    for key in lora_weights_list[0]:
        weighted = []
        for weights, alpha in zip(lora_weights_list, alpha_list):
            if key in weights:
                weighted.append((alpha, weights[key]))
            else:
                print(f"スキップ: {key}")
        # The first LoRA always owns ``key``, so ``weighted`` is never empty.
        try:
            merged_weights[key] = sum(alpha * tensor for alpha, tensor in weighted)
        except (RuntimeError, TypeError) as e:
            # e.g. tensors with incompatible shapes: drop this key only.
            print(f"スキップ: {key}, Error: {e}")
    return merged_weights

def save_merged_lora(merged_weights, output_path):
    """Save merged LoRA tensors to ``output_path`` with safetensors.

    Missing parent directories are created. A bare filename (no directory
    part) is written relative to the current working directory.

    Args:
        merged_weights: Dict of tensor name -> tensor, e.g. the output of
            ``merge_lora_weights``.
        output_path: Destination file path.
    """
    output_dir = os.path.dirname(output_path)
    # os.makedirs("") raises FileNotFoundError, so only create a directory
    # when the path actually contains one.
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    save_file(merged_weights, output_path)
    print(f"マージ済みLoRAを {output_path} に保存しました。")

In [None]:
#@title ### UI設定

#@markdown ### マージするLoraモデルを設定
num_loras = 2 #@param {type:"integer"}

# One path text box per LoRA.
lora_files = [
    widgets.Text(value="", description=f"LoRA{idx + 1}", layout=widgets.Layout(width='75%'))
    for idx in range(num_loras)
]

# Alpha widgets: the first num_loras-1 alphas are sliders; the last one is a
# read-only box holding whatever weight remains.
if num_loras <= 1:
    sliders = []
    alpha_n = widgets.FloatText(
        value=1,
        description="Alpha1",
        layout=widgets.Layout(width='50%'),
        disabled=True,
    )
else:
    sliders = [
        widgets.FloatSlider(
            value=1 / num_loras,
            min=0,
            max=1,
            step=0.01,
            description=f"Alpha{idx + 1}",
            layout=widgets.Layout(width='50%'),
        )
        for idx in range(num_loras - 1)
    ]
    alpha_n = widgets.FloatText(
        value=1 - sum(s.value for s in sliders),
        description=f"Alpha{num_loras}",
        layout=widgets.Layout(width='50%'),
        disabled=True,
    )


# Re-entrancy guard: assigning slider.value inside the handler fires the
# handler again (it is registered as a 'value' observer), so nested calls
# must return immediately instead of working on half-updated values.
_adjusting_alphas = False


def enforce_alpha_constraints(*args):
    """Keep the slider alphas summing to at most 1 and derive the last alpha.

    If the slider values sum to more than 1, they are rescaled proportionally.
    The read-only last alpha is then set to the remaining weight, rounded to
    remove float noise and clamped so it is never negative.
    """
    global _adjusting_alphas
    if _adjusting_alphas:
        return
    _adjusting_alphas = True
    try:
        total_alpha = sum(slider.value for slider in sliders)
        if total_alpha > 1.0:
            for slider in sliders:
                slider.value = slider.value / total_alpha
            # Recompute from the actual values; using the pre-normalization
            # total (> 1) here would make the last alpha negative.
            total_alpha = sum(slider.value for slider in sliders)
        alpha_n.value = max(0.0, round(1 - total_alpha, 6))
    finally:
        _adjusting_alphas = False

# Re-balance the alphas whenever any slider value changes.
for slider in sliders:
    slider.observe(enforce_alpha_constraints, names='value')

# Destination file for the merged LoRA (defaults to the WebUI loras folder on MyDrive).
output_file_widget = widgets.Text(
    value="/content/drive/MyDrive/sd-webui-google-colab-setup/stable-diffusion-webui/models/loras/merged_lora.safetensors",
    description="Output",
    layout=widgets.Layout(width='75%'),
)

In [None]:
#@title ### マージ実行関数

def execute_merge():
    """Read the UI settings, merge the selected LoRAs and save the result.

    A LoRA is used only when its path box is non-empty and its alpha is > 0.
    Files that fail to load are skipped. Because this runs as a button
    callback, every problem is reported with ``print`` rather than raised.
    """
    all_alphas = sliders + [alpha_n]
    # (path, alpha) pairs for each LoRA with a path and a positive weight.
    # Whitespace is stripped because pasted paths often carry stray spaces.
    selected = [
        (path_widget.value.strip(), alpha_widget.value)
        for path_widget, alpha_widget in zip(lora_files, all_alphas)
        if path_widget.value.strip() and alpha_widget.value > 0
    ]
    if not selected:
        print("エラー: マージ対象のLoRAがありません。パスとアルファ値を確認してください。")
        return

    output_path = output_file_widget.value.strip()
    if not output_path:
        print("エラー: 出力先のパスが空です。")
        return

    try:
        lora_weights_list = []
        alpha_list = []
        for path, alpha in selected:
            print(f"LoRAファイルを読み込み中: {path}")
            lora_weights = load_lora_weights(path)
            if lora_weights is None:
                print(f"エラー: {path} の読み込みに失敗しました。スキップします。")
                continue
            lora_weights_list.append(lora_weights)
            alpha_list.append(alpha)

        if not lora_weights_list:
            print("エラー: 読み込みに成功したLoRAがありません。マージを中止します。")
            return

        # A skipped file (or slider rounding) can leave weights that no
        # longer sum to 1, which changes the strength of the merged LoRA.
        alpha_total = sum(alpha_list)
        if abs(alpha_total - 1.0) > 1e-3:
            print(f"警告: アルファ値の合計が {alpha_total:.2f} です（1.0ではありません）。")

        for idx, weights in enumerate(lora_weights_list):
            print(f"LoRA{idx+1} keys: {list(weights.keys())[:5]} ...")

        print("LoRAをマージ中...")
        merged_weights = merge_lora_weights(lora_weights_list, alpha_list)
        if not merged_weights:
            print("エラー: マージ結果が空のため保存しません。")
            return

        print(f"マージ済みLoRAを保存中: {output_path}")
        save_merged_lora(merged_weights, output_path)
        print("LoRAマージ完了！")

    except Exception as e:
        # Top-level boundary of a UI callback: report the error instead of raising.
        print(f"エラーが発生しました: {e}")

In [None]:
#@title ### UIの表示と実行ボタン

# Path boxes first, then the alpha widgets (sliders only when there is more
# than one LoRA, followed by the read-only last alpha).
alpha_widgets = sliders + [alpha_n] if num_loras > 1 else [alpha_n]
ui = widgets.VBox(lora_files + alpha_widgets)


def _on_merge_clicked(_button):
    """Button callback: run the merge; the clicked button itself is unused."""
    execute_merge()


merge_button = widgets.Button(description="マージ実行")
merge_button.on_click(_on_merge_clicked)
display(ui, output_file_widget, merge_button)