# Hyper-Merge: Advanced Weight Merging for Stable-Diffusion

Enable IPython's autoreload extension, then import the required modules.

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
from pathlib import Path
from tqdm.auto import tqdm

import torch

from hyper_merge.checkpoint import load_ckpt, save_ckpt
from hyper_merge.models import make_model_average, make_diff_model, make_hyper_model
from hyper_merge.lora import make_lora

## ⚙️ Setup Environment Configuration ⚙️

Define the `device` and `dtype` for PyTorch computations.

In [None]:
# All computations below target the CUDA device and use half precision (float16).
device, dtype = torch.device('cuda'), torch.half

## 📦 Load Pre-trained Models 📦

Specify the paths to your models here.

In [None]:
# Checkpoint files to merge; fill this list in before running the cell.
models_paths = [
    # TODO add your models paths here!
]

# Load every listed checkpoint in the configured dtype, reporting progress.
loading_bar = tqdm(models_paths, desc='Loading models')
models = [load_ckpt(ckpt_path, dtype=dtype) for ckpt_path in loading_bar]

## 🌐 Create an Average Model 🌐


In [None]:
avg_model_path = Path('models/avg-model.safetensors')

# Build and cache the average model only when no cached checkpoint exists yet.
if not avg_model_path.exists():
    # Make sure the output directory is there before saving
    # (harmless if `save_ckpt` already creates it).
    avg_model_path.parent.mkdir(parents=True, exist_ok=True)

    avg_model = make_model_average(models, dtype=dtype, device=device)
    save_ckpt(avg_model, avg_model_path)

# Load the average model into GPU. `device` is passed explicitly, the same way
# the differential model is loaded, so both checkpoints end up on one device.
avg_model = load_ckpt(avg_model_path, dtype=dtype, device=device)

## 🧬 Create the Hyper-Merged Model 🧬

In [None]:
diff_model_path = Path('models/diff-model.safetensors')
lambda_path = Path('models/lambda.pt')

# The differential model and its multipliers (λ) are produced together, so
# rebuild both when either file is missing. Checking only the diff model would
# let `torch.load` below fail if the lambda file had been removed.
if not (diff_model_path.exists() and lambda_path.exists()):
    # Make sure the output directory is there before saving
    # (harmless if `save_ckpt` already creates it).
    diff_model_path.parent.mkdir(parents=True, exist_ok=True)

    diff_model, λ = make_diff_model(models, device, dtype, iterations=14)
    save_ckpt(diff_model, diff_model_path)

    torch.save(λ, lambda_path)

# Always read both artifacts back from disk, placing them on `device`.
diff_model = load_ckpt(diff_model_path, dtype=dtype, device=device) # Load into GPU
λ = torch.load(lambda_path, map_location=device)

## 🔍 Visualize the Multipliers (λ) 🔍


In [None]:
# Show the multipliers (λ) that were loaded from `lambda_path`.
print(λ)

## Create a LoRA from the differential weights

In [None]:
# Turn the differential model into a rank-64 LoRA ...
lora = make_lora(diff_model, rank=64)

# ... and write it next to the other generated checkpoints.
lora_path = Path('models/hyper-lora.safetensors')
save_ckpt(lora, lora_path)

## 🌟 Generate Super Models 🌟

Create one hyper-model per source model, using that model's lambda (λ) multiplier.

In [None]:
# Pair each multiplier in λ with its source checkpoint path and save one
# hyper-model per pair, to give the closest match to the real model.
# The pairs are materialized into a list before being handed to tqdm.
multiplier_path_pairs = list(zip(λ, models_paths))

for multiplier, model_path in tqdm(multiplier_path_pairs):
    # The output file name carries the source checkpoint's stem and the multiplier.
    name = Path(model_path).stem
    hyper_model_path = Path(f'models/hyper-model_({name})_[{multiplier}].safetensors')

    hyper_model = make_hyper_model(avg_model, diff_model, multiplier, device, dtype)
    save_ckpt(hyper_model, hyper_model_path)

## 🎯 Use a Specific Multiplier 🎯


In [None]:
# Build and save a single hyper-model for a hand-picked multiplier.
multiplier = 1 # Define your specific multiplier here

# The multiplier is embedded in the output file name.
hyper_model_path = Path(f'models/hyper-model_{multiplier}.safetensors')

hyper_model = make_hyper_model(avg_model, diff_model, multiplier, device, dtype)
save_ckpt(hyper_model, hyper_model_path)

# Step the multiplier by a quarter. NOTE: re-running this cell resets it to the
# value assigned at the top before this increment is applied again.
multiplier += 0.25