💫 Checkpoint Merger


| | GitHub |
|:--|:-:|
| 🏠 **Profile** | [![GitHub](https://img.shields.io/badge/GitHub-%23121011.svg?logo=github&logoColor=white)](https://github.com/xLegende)|
| 📘 **Repos** | [![GitHub](https://img.shields.io/badge/GitHub-%23121011.svg?logo=github&logoColor=white)](https://github.com/xLegende/merger_colab)|

#@title 💫 Merger
"""
Checkpoint Merger (Safetensors) - Google Colab Notebook

Merges two Stable Diffusion safetensors checkpoints using a weighted average.
Files are loaded from and saved to your Google Drive.

Instructions:
1. Enter the paths to your two checkpoint files and the desired output path in the form fields.
2. Adjust the merge ratio using the slider.
3. Run the cell. It installs the required libraries, mounts your Google Drive (approve the authorization prompt), and merges the checkpoints.
4. The merged checkpoint is saved to the output path in your Google Drive.

Important Notes:
- This notebook assumes both checkpoints have compatible architectures (e.g., both are SDXL models).
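- Merging is memory-intensive: both checkpoints and the merged result are held in RAM at the same time. The merge ratio does not change memory use. If the Colab session runs out of memory, try smaller checkpoints (e.g., pruned or fp16 versions) or a high-RAM runtime.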
- Always back up your original checkpoints before merging.
- Experiment with different merge ratios to achieve the desired blend.
"""

# Install necessary libraries
print("Installing libraries...")
!pip install safetensors torch -q
print("Libraries installed.")

import os
import torch
import safetensors.torch
from google.colab import drive
from IPython.display import display, Markdown

# --- Mount Google Drive ---
print("Mounting Google Drive...")
drive.mount('/content/drive')
print("Google Drive mounted.")

# --- UI Parameters ---
display(Markdown("## Checkpoint Merge Parameters"))

checkpoint_path_1 = "/content/drive/MyDrive/models/checkpoint_base.safetensors" #@param {type:"string"}
checkpoint_path_2 = "/content/drive/MyDrive/models/checkpoint_other.safetensors" #@param {type:"string"}
output_path = "/content/drive/MyDrive/models/merged_new.safetensors" #@param {type:"string"}
merge_ratio = 0.59 #@param {type:"slider", min:0.0, max:1.0, step:0.01}

display(Markdown("---"))

# --- Checkpoint Merging Function ---
def merge_checkpoints(checkpoint_path_1, checkpoint_path_2, output_path, merge_ratio):
    """
    Merges two safetensors checkpoints using a weighted average and saves the result.

    Args:
        checkpoint_path_1 (str): Path to the first checkpoint file.
        checkpoint_path_2 (str): Path to the second checkpoint file.
        output_path (str): Path to save the merged checkpoint file.
        merge_ratio (float): Weight of checkpoint 1, from 0.0 to 1.0. Each shared tensor becomes
                             merge_ratio * checkpoint_1 + (1 - merge_ratio) * checkpoint_2,
                             so 1.0 keeps only checkpoint 1 and 0.0 keeps only checkpoint 2.
    """
    if not 0.0 <= merge_ratio <= 1.0:
        raise ValueError(f"merge_ratio must be between 0.0 and 1.0, got {merge_ratio}")

    # Check both paths up front so a typo in the second path fails before a slow load.
    for path in (checkpoint_path_1, checkpoint_path_2):
        if not os.path.exists(path):
            raise FileNotFoundError(f"Checkpoint file not found: {path}")

    print(f"Loading checkpoint 1: {checkpoint_path_1}")
    checkpoint_1 = safetensors.torch.load_file(checkpoint_path_1, device='cpu')
    print(f"Loading checkpoint 2: {checkpoint_path_2}")
    checkpoint_2 = safetensors.torch.load_file(checkpoint_path_2, device='cpu')

    print("Merging checkpoints...")
    merged_state_dict = {}
    kept_unchanged = 0
    for key, tensor_1 in checkpoint_1.items():
        tensor_2 = checkpoint_2.get(key)
        if tensor_2 is not None and tensor_1.is_floating_point() and tensor_1.shape == tensor_2.shape:
            # Blend in float32, then cast back to checkpoint 1's dtype (e.g., keep fp16 as fp16).
            merged = merge_ratio * tensor_1.float() + (1.0 - merge_ratio) * tensor_2.float()
            merged_state_dict[key] = merged.to(tensor_1.dtype)
        else:
            # Key missing from checkpoint 2, non-float tensor (e.g., integer position ids),
            # or shape mismatch: keep checkpoint 1's tensor unchanged.
            merged_state_dict[key] = tensor_1
            kept_unchanged += 1
    # Note: keys that exist only in checkpoint 2 are not copied into the result.
    print(f"Merged {len(merged_state_dict) - kept_unchanged} tensors; "
          f"kept {kept_unchanged} tensors from checkpoint 1 unchanged.")

    # Drop references to the source state dicts that are no longer needed.
    del checkpoint_1, checkpoint_2

    # Create the output folder if it does not exist yet.
    output_dir = os.path.dirname(output_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    print(f"Saving merged checkpoint to: {output_path}")
    safetensors.torch.save_file(merged_state_dict, output_path)
    print("Checkpoint merging complete!")

# --- Run Merging Process ---
if __name__ == "__main__":
    try:
        merge_checkpoints(checkpoint_path_1, checkpoint_path_2, output_path, merge_ratio)
        display(Markdown(f"**✅ Merging completed successfully! Merged checkpoint saved to:** `{output_path}`"))
    except FileNotFoundError as e:
        display(Markdown(f"**❌ Error:** `{e}`. Please check your checkpoint paths."))
    except Exception as e:
        import traceback
        traceback.print_exc()  # print the full stack trace in the cell output
        display(Markdown(f"**❌ An error occurred during merging:** `{e}`. See the traceback above for details."))

display(Markdown("---"))
display(Markdown("### Instructions after Merging:"))
display(Markdown("- You can now use the merged checkpoint in your Stable Diffusion setup."))
display(Markdown("- Test the merged checkpoint with various prompts to see the combined effect."))
display(Markdown("- Experiment with different `merge_ratio` values and different checkpoint combinations to achieve your desired results."))
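The merge loop is a per-tensor linear interpolation, `merge_ratio * A + (1 - merge_ratio) * B`. The toy example below checks the ratio semantics on two tiny hand-made state dicts, so it runs without any checkpoint files. `lerp_state_dicts` and the tensor names `w` and `ids` are hypothetical, chosen only for illustration. Blending in float32, casting back to the first checkpoint's dtype, and leaving non-float tensors untouched are one reasonable set of choices, not necessarily what every merge tool does.

```python
import torch


def lerp_state_dicts(sd_1, sd_2, merge_ratio):
    """Hypothetical helper: merge_ratio * sd_1 + (1 - merge_ratio) * sd_2 for each shared key."""
    if not 0.0 <= merge_ratio <= 1.0:
        raise ValueError("merge_ratio must be between 0.0 and 1.0")
    merged = {}
    for key, t1 in sd_1.items():
        t2 = sd_2.get(key)
        if t2 is not None and t1.is_floating_point() and t1.shape == t2.shape:
            # Compute in float32, then restore the first state dict's dtype.
            mixed = merge_ratio * t1.float() + (1.0 - merge_ratio) * t2.float()
            merged[key] = mixed.to(t1.dtype)
        else:
            # Missing key, integer tensor, or shape mismatch: copy from sd_1.
            merged[key] = t1.clone()
    return merged


# Toy "checkpoints": one fp16 weight and one integer buffer.
sd_a = {"w": torch.tensor([1.0, 2.0], dtype=torch.float16), "ids": torch.tensor([0, 1, 2])}
sd_b = {"w": torch.tensor([3.0, 6.0]), "ids": torch.tensor([5, 5, 5])}

merged = lerp_state_dicts(sd_a, sd_b, 0.25)
# w: 0.25*1 + 0.75*3 = 2.5 and 0.25*2 + 0.75*6 = 5.0; ids stay [0, 1, 2]
print(merged["w"], merged["ids"])
```

With `merge_ratio = 1.0` the result equals the first state dict and with `0.0` it equals the second (for the blended keys), which matches the slider description in the notebook.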