# Blend Anything

Blend any two LoRA models. They should normally be trained on the same base model, but this is not strictly required.

This uses a very small amount of VRAM (usually < 4GB for Flux/Qwen models) by streaming the layers and blending each one individually.

This can also resize the LoRAs while blending them. You can target any rank, larger or smaller than the original; for each layer the rank is capped at the smaller of that layer's input and output dimensions.

In [1]:
from __future__ import annotations
from typing import Dict, Iterable, Optional, Tuple, Set, Literal
import torch
from tqdm.notebook import tqdm

In [7]:
# Key suffixes that mark the two low-rank factors of a LoRA layer in a state dict.
# The part of a key that precedes the suffix is treated as the layer's "base" name.
_A = ".lora_A.weight"
_B = ".lora_B.weight"

# ---------- key pairing ----------
def _pair_bases(sd: Dict[str, torch.Tensor]) -> Set[str]:
    """Return the layer base names that have BOTH a lora_A and a lora_B key in ``sd``.

    A base name is a key with its ``_A`` / ``_B`` suffix stripped. Layers that
    only provide one of the two factors are left out.
    """
    with_a: Set[str] = set()
    with_b: Set[str] = set()
    for key in sd:
        if key.endswith(_A):
            with_a.add(key[:-len(_A)])
        if key.endswith(_B):
            with_b.add(key[:-len(_B)])
    return with_a.intersection(with_b)

def _iter_bases_union(sd1: Dict[str, torch.Tensor],
                      sd2: Optional[Dict[str, torch.Tensor]] = None,
                      include_bases: Optional[Iterable[str]] = None) -> Iterable[str]:
    """Yield, in sorted order, every complete layer base found in ``sd1`` or ``sd2``.

    When ``include_bases`` is given, only bases that also appear in it are
    yielded. Iteration is wrapped in a tqdm progress bar.
    """
    candidates = _pair_bases(sd1)
    if sd2:
        candidates = candidates.union(_pair_bases(sd2))
    if include_bases is not None:
        candidates = candidates.intersection(include_bases)
    # Sorting keeps the processing order (and thus the output) deterministic.
    yield from tqdm(sorted(candidates))

# ---------- math helpers ----------
@torch.no_grad()
def _delta_from_AB(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
    """Multiply the LoRA factors into the dense weight update.

    With A shaped [r, in] and B shaped [out, r], the result is Δ = B·A with
    shape [out, in].
    """
    return torch.matmul(B, A)

def _best_rank_for(deltaW: torch.Tensor, target_rank: Optional[int]) -> int:
    """Clamp a requested rank to what the 2-D matrix ``deltaW`` can support.

    Returns min(rows, cols) when ``target_rank`` is None; otherwise the
    requested rank (converted with ``int``) limited to at most min(rows, cols)
    and at least 1.
    """
    rows, cols = deltaW.shape
    ceiling = rows if rows < cols else cols
    if target_rank is None:
        return ceiling
    wanted = int(target_rank)
    if wanted > ceiling:
        wanted = ceiling
    # Never go below rank 1, even for zero or negative requests.
    return wanted if wanted >= 1 else 1

@torch.no_grad()
def _truncated_factorization(
    deltaW: torch.Tensor,
    target_rank: int,
    method: Literal["svd", "pca_lowrank"] = "svd",
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Return (A', B') with shapes [r, in], [out, r] such that B'@A' ≈ Δ.

    Args:
        deltaW: dense 2-D update matrix Δ of shape [out, in].
        target_rank: requested rank; clamped by ``_best_rank_for`` to what
            the matrix can support.
        method:
            - 'svd': exact economy SVD, then truncated (robust, a bit heavier)
            - 'pca_lowrank': randomized approximation via ``torch.pca_lowrank``
              (faster, lower memory on large mats)
            Any other value falls back to the 'svd' path.

    Returns:
        ``(A', B')`` as contiguous tensors. The singular values are split
        evenly (square root on each side) between the two factors.
    """
    r = _best_rank_for(deltaW, target_rank)

    if method == "pca_lowrank":
        # Δ ≈ U S Vh via PCA low-rank approx.
        # Oversample q a little for better accuracy, but never beyond
        # min(out, in): torch.pca_lowrank raises ValueError for larger q.
        q = min(r + 8, min(deltaW.shape))
        U, S, V = torch.pca_lowrank(deltaW, q=q, center=False)
        # keep top-r
        U = U[:, :r]
        S = S[:r]
        V = V[:, :r]
        Vh = V.T
    else:
        # exact economy SVD then truncate
        U, S, Vh = torch.linalg.svd(deltaW, full_matrices=False)
        U = U[:, :r]
        S = S[:r]
        Vh = Vh[:r, :]

    # Clamp guards the sqrt against tiny negative values from numerical noise.
    sroot = torch.sqrt(torch.clamp(S, min=0))
    Bp = U * sroot.unsqueeze(0)      # [out, r]
    Ap = sroot.unsqueeze(1) * Vh     # [r, in]
    return Ap.contiguous(), Bp.contiguous()

# ---------- device / dtype helpers ----------
def _to_dev_dtype(x: torch.Tensor, device: Optional[torch.device], dtype: Optional[torch.dtype]) -> torch.Tensor:
    """Move ``x`` to ``device`` and cast it to ``dtype``; ``None`` skips that step.

    The device transfer (non-blocking) happens first, then the dtype cast, and
    the cast is skipped when the tensor already has the requested dtype.
    """
    moved = x if device is None else x.to(device, non_blocking=True)
    if dtype is None or moved.dtype == dtype:
        return moved
    return moved.to(dtype)

def _maybe_empty_cuda():
    """Release PyTorch's cached CUDA memory; does nothing when CUDA is unavailable."""
    if not torch.cuda.is_available():
        return
    torch.cuda.empty_cache()

# ---------- main: streaming blend & rank-convert ----------
def _layer_delta(
    sd: Optional[Dict[str, torch.Tensor]],
    base: str,
    device: Optional[torch.device],
    dtype: Optional[torch.dtype],
) -> Optional[torch.Tensor]:
    """Return the dense update B @ A for layer ``base`` in ``sd``, or None if it is missing.

    The device/dtype copies of the two factors live only inside this call, so
    they are released as soon as the dense delta has been computed.
    """
    if not sd:
        return None
    A = sd.get(base + _A)
    B = sd.get(base + _B)
    if A is None or B is None:
        return None
    return _delta_from_AB(_to_dev_dtype(A, device, dtype), _to_dev_dtype(B, device, dtype))


@torch.no_grad()
def blend_and_convert_loras_streaming(
    lora1: Dict[str, torch.Tensor],
    lora2: Optional[Dict[str, torch.Tensor]] = None,
    w1: float = 1.0,
    w2: float = 1.0,
    target_rank: Optional[int] = 32,
    compute_device: Optional[str] = None,      # e.g. "cuda:0" or "cpu"
    compute_dtype: Optional[torch.dtype] = torch.float32,
    include_bases: Optional[Iterable[str]] = None,
    drop_near_zero: bool = True,
    zero_tol: float = 1e-12,
    factor_method: Literal["svd", "pca_lowrank"] = "svd",
) -> Dict[str, torch.Tensor]:
    """
    Streaming, low-memory LoRA blend/convert.

    Processes one layer at a time: the dense update Δ = B @ A of each input is
    formed, the two are combined as ``w1 * Δ1 + w2 * Δ2``, and the result is
    re-factorized to the requested rank. Only one layer's dense matrices and
    factorization workspace are alive on the compute device at any time.

    Args:
        lora1: first LoRA state dict (keys ending in ``.lora_A.weight`` /
            ``.lora_B.weight``).
        lora2: optional second LoRA state dict. When omitted or empty, ``lora1``
            is only rescaled by ``w1`` and rank-converted.
        w1: weight applied to the update from ``lora1``.
        w2: weight applied to the update from ``lora2``.
        target_rank: rank of the output factors, clamped per layer to what the
            layer's matrix supports; ``None`` means the maximum possible rank.
        compute_device: device string used for the math; a falsy value leaves
            tensors on the device they are already on.
        compute_dtype: dtype used for the math; ``None`` keeps the input dtypes.
        include_bases: if given, only these layer base names are processed.
        drop_near_zero: skip layers whose blended update has max |value| below
            ``zero_tol``.
        zero_tol: threshold used by ``drop_near_zero``.
        factor_method: factorization backend passed to
            ``_truncated_factorization``.

    Returns:
        A new state dict holding CPU ``.lora_A.weight`` / ``.lora_B.weight``
        tensors for every layer that was kept. A layer present in only one
        input is carried over using that input's weight. No other keys from the
        inputs are copied.
    """
    dev = torch.device(compute_device) if compute_device else None
    out_sd: Dict[str, torch.Tensor] = {}

    for base in _iter_bases_union(lora1, lora2, include_bases):
        d1 = _layer_delta(lora1, base, dev, compute_dtype)
        d2 = _layer_delta(lora2, base, dev, compute_dtype)

        # blended Δ
        if d1 is None and d2 is None:
            _maybe_empty_cuda()
            continue
        if d1 is None:
            blended = w2 * d2
        elif d2 is None:
            blended = w1 * d1
        else:
            blended = w1 * d1 + w2 * d2

        # The per-input deltas are no longer needed; drop them before the
        # factorization so they do not add to its peak memory.
        del d1, d2

        # optionally skip numerically tiny layers
        if drop_near_zero and blended.abs().max().item() < zero_tol:
            del blended
            _maybe_empty_cuda()
            continue

        # factor to target rank (on compute device), then move to CPU
        r = _best_rank_for(blended, target_rank)
        A_new, B_new = _truncated_factorization(blended, r, method=factor_method)
        out_sd[base + _A] = A_new.detach().to("cpu", copy=True)
        out_sd[base + _B] = B_new.detach().to("cpu", copy=True)

        # hard cleanup for current layer before next
        del A_new, B_new, blended
        _maybe_empty_cuda()

    return out_sd


In [3]:
import safetensors.torch

# Load the two LoRA checkpoints to blend from disk as tensor state dicts.
# These paths are machine-specific; point them at your own .safetensors files.
lora1 = safetensors.torch.load_file("/mnt/models/tensors/loras/qwen_image/nsfw_qwen_bs8_r32_lowlr_000005500.safetensors")
lora2 = safetensors.torch.load_file("/mnt/models/tensors/loras/qwen_image/nsfw_qwen_resume_detail_qha.safetensors")

In [None]:
# Blend lora1 (weight 0.5) with lora2 (weight 0.75) and re-factorize every
# layer to rank 16, doing the math in float32 on the first CUDA device.
blend16 = blend_and_convert_loras_streaming(
    lora1, 
    lora2,
    w1=0.5, 
    w2=0.75, 
    target_rank=16, 
    compute_device="cuda:0", 
    compute_dtype=torch.float32,
)

# Write the blended LoRA state dict to a new .safetensors file.
safetensors.torch.save_file(blend16, "/mnt/models/tensors/loras/qwen_image/nsfw_qwen_blend_5500_qha.safetensors")

  0%|          | 0/840 [00:00<?, ?it/s]