# Blend Anything

Blend any two LoRA models. They should be from the same base model, but that is not strictly required.

This uses a very small amount of VRAM (usually < 4GB for Flux/Qwen models) by streaming the layers and blending each one individually.

This can resize the LoRAs while blending them. You can target any rank, larger or smaller than the original; each layer's rank is capped at the smaller of its input and output dimensions.

In [9]:
from __future__ import annotations
from typing import Dict, Iterable, Optional, Tuple, Set, Literal, List, Callable
import torch
from tqdm.notebook import tqdm
import math

In [2]:
def dyn_weights_proportional(
    target_total_strength: float = 1.0,
    floor: float = 0.02,
) -> Callable[[str, Optional[float], Optional[float]], Tuple[float, float]]:
    """
    Build a dynamic-weight callback whose raw weights are proportional to strength.

    w_raw ∝ max(s, floor); the pair is then scaled so that
    (w1*s1 + w2*s2) == target_total_strength.

    A strength that is None, NaN or infinite is treated as missing: it gets the
    `floor` as raw weight and contributes 0 to the predicted combined strength.
    If both strengths are missing, the callback returns (1, 1).

    Args:
        target_total_strength: Desired value of w1*s1 + w2*s2 after scaling.
        floor: Lower bound applied to each raw weight before normalisation.

    Returns:
        A function (base, s1, s2) -> (w1, w2). `base` is accepted for interface
        compatibility and is not used.
    """
    eps = 1e-12

    def _usable(s: Optional[float]) -> bool:
        return s is not None and math.isfinite(s)

    def fn(base: str, s1: Optional[float], s2: Optional[float]) -> Tuple[float, float]:
        ok1, ok2 = _usable(s1), _usable(s2)
        if not ok1 and not ok2:
            return (1.0, 1.0)
        # Sanitised strengths: the same values feed both the raw weights and
        # the prediction, so a NaN/inf input can no longer poison the scale.
        v1 = s1 if ok1 else 0.0
        v2 = s2 if ok2 else 0.0
        a = max(floor, v1) if ok1 else floor
        b = max(floor, v2) if ok2 else floor
        denom = a + b
        p1 = a / (denom + eps)
        p2 = b / (denom + eps)
        # predicted combined strength (upper bound)
        pred = p1 * v1 + p2 * v2
        scale = target_total_strength / max(pred, eps)
        return (scale * p1, scale * p2)

    return fn


In [3]:
def dyn_weights_softmax(
    target_total_strength: float = 1.0,
    temperature: float = 1.0,
    floor: float = 0.0,
) -> Callable[[str, Optional[float], Optional[float]], Tuple[float, float]]:
    """
    Build a dynamic-weight callback based on a softmax of the two strengths.

    w_raw = softmax([s1, s2]/T), with optional floor, then rescale so
    w1*s1 + w2*s2 == target_total_strength.

    A strength that is None, NaN or infinite is treated as missing: it gets a
    softmax probability of 0 (before the floor) and contributes 0 to the
    predicted combined strength. If both are missing the callback returns (1, 1).

    Args:
        target_total_strength: Desired value of w1*s1 + w2*s2 after scaling.
        temperature: Softmax temperature; clamped to at least 1e-8.
        floor: If > 0, each probability is raised to at least this value and the
            pair is renormalised to sum to 1.

    Returns:
        A function (base, s1, s2) -> (w1, w2). `base` is not used.
    """
    eps = 1e-12
    t = max(temperature, 1e-8)

    def fn(base: str, s1: Optional[float], s2: Optional[float]) -> Tuple[float, float]:
        ok1 = s1 is not None and math.isfinite(s1)
        ok2 = s2 is not None and math.isfinite(s2)
        if not (ok1 or ok2):
            return (1.0, 1.0)
        # Subtract the largest usable strength so every exponent is <= 0
        # (no overflow); missing entries get zero mass.
        m = max(s1 if ok1 else float("-inf"), s2 if ok2 else float("-inf"))
        e1 = math.exp((s1 - m) / t) if ok1 else 0.0
        e2 = math.exp((s2 - m) / t) if ok2 else 0.0
        Z = e1 + e2 + eps
        p1 = e1 / Z
        p2 = e2 / Z
        if floor > 0:
            p1 = max(p1, floor)
            p2 = max(p2, floor)
            total = p1 + p2
            p1, p2 = p1 / total, p2 / total
        # Use the sanitised strengths here as well; the raw values could be
        # NaN/inf and would turn the scale (and both weights) into NaN.
        pred = p1 * (s1 if ok1 else 0.0) + p2 * (s2 if ok2 else 0.0)
        scale = target_total_strength / max(pred, eps)
        return (scale * p1, scale * p2)

    return fn


In [4]:
def dyn_weights_powerlaw(
    target_total_strength: float = 1.0,
    p: float = 0.5,      # 0.0 ~ equal weights, 1.0 ~ proportional, >1.0 ~ extra sharp
    floor: float = 0.02,
) -> Callable[[str, Optional[float], Optional[float]], Tuple[float, float]]:
    """
    Build a dynamic-weight callback with raw weights w_raw ∝ max(s, floor) ** p.

    The pair is normalised to sum to 1 and then rescaled so that
    w1*s1 + w2*s2 == target_total_strength.

    A strength that is None, NaN or infinite is treated as missing (value 0).
    If both strengths are missing, the callback returns (1, 1).

    Args:
        target_total_strength: Desired value of w1*s1 + w2*s2 after scaling.
        p: Exponent applied to each floored strength; negative values are
            clamped to 0.
        floor: Lower bound applied to each strength before exponentiation.

    Returns:
        A function (base, s1, s2) -> (w1, w2). `base` is not used.
    """
    eps = 1e-12
    exponent = max(p, 0.0)

    def _clean(s: Optional[float]) -> Optional[float]:
        return s if (s is not None and math.isfinite(s)) else None

    def fn(base: str, s1: Optional[float], s2: Optional[float]) -> Tuple[float, float]:
        c1, c2 = _clean(s1), _clean(s2)
        if c1 is None and c2 is None:
            return (1.0, 1.0)
        v1 = c1 or 0.0
        v2 = c2 or 0.0
        # Clamp the base at 0: a negative base with a fractional exponent
        # would produce a complex number in Python.
        a = max(floor, v1, 0.0) ** exponent
        b = max(floor, v2, 0.0) ** exponent
        denom = a + b + eps
        p1 = a / denom
        p2 = b / denom
        # Predicted combined strength, computed from the sanitised values so a
        # NaN/inf input cannot leak into the scale.
        pred = p1 * v1 + p2 * v2
        scale = target_total_strength / max(pred, eps)
        return (scale * p1, scale * p2)

    return fn


In [5]:
def dyn_weights_margin_sigmoid(
    target_total_strength: float = 1.0,
    temperature: float = 1.0,
    bias: float = 0.0,     # + favors lora1, - favors lora2
    floor: float = 0.02
) -> Callable[[str, Optional[float], Optional[float]], Tuple[float, float]]:
    """
    Build a dynamic-weight callback driven by the strength margin.

    p1 = sigmoid((s1 - s2 + bias)/T), p2 = 1 - p1, then rescale to target.

    A strength that is None, NaN or infinite is treated as 0. If both are
    unusable the callback returns (1, 1).

    Args:
        target_total_strength: Desired value of w1*s1 + w2*s2 after scaling.
        temperature: Sigmoid temperature; clamped to at least 1e-8.
        bias: Added to the margin s1 - s2 before the sigmoid.
        floor: p1 is clamped into [floor, 1 - floor] before p2 is derived.

    Returns:
        A function (base, s1, s2) -> (w1, w2). `base` is not used.
    """
    import math
    eps = 1e-12
    t = max(temperature, 1e-8)

    def sigm(x: float) -> float:
        # Branch on the sign so math.exp only ever sees a non-positive
        # argument; the naive 1/(1+exp(-x)) raises OverflowError for very
        # negative x (e.g. a small temperature).
        if x >= 0:
            return 1.0 / (1.0 + math.exp(-x))
        z = math.exp(x)
        return z / (1.0 + z)

    def fn(base: str, s1: Optional[float], s2: Optional[float]) -> Tuple[float, float]:
        ok1 = s1 is not None and math.isfinite(s1)
        ok2 = s2 is not None and math.isfinite(s2)
        if not (ok1 or ok2):
            return (1.0, 1.0)
        a = s1 if ok1 else 0.0
        b = s2 if ok2 else 0.0
        p1 = sigm(((a - b) + bias) / t)
        p1 = max(floor, min(1.0 - floor, p1))
        p2 = 1.0 - p1
        pred = p1 * a + p2 * b
        scale = target_total_strength / max(pred, eps)
        return (scale * p1, scale * p2)

    return fn


In [10]:
_A = ".lora_A.weight"
_B = ".lora_B.weight"

# ---------- key pairing ----------
def _pair_bases(sd: Dict[str, torch.Tensor]) -> Set[str]:
    """Return the key stems in `sd` that have both a `_A` and a `_B` entry."""
    a_stems: Set[str] = set()
    b_stems: Set[str] = set()
    for key in sd:
        # Strip the suffix to recover the shared stem of each factor key.
        if key.endswith(_A):
            a_stems.add(key[:-len(_A)])
        if key.endswith(_B):
            b_stems.add(key[:-len(_B)])
    return a_stems.intersection(b_stems)

def _iter_bases_union(sd1: Dict[str, torch.Tensor],
                      sd2: Optional[Dict[str, torch.Tensor]] = None,
                      include_bases: Optional[Iterable[str]] = None) -> Iterable[str]:
    """
    Yield, in sorted order and wrapped in a tqdm progress bar, every paired
    base found in `sd1` or (when given and non-empty) `sd2`.

    When `include_bases` is not None, only bases also listed there are yielded.
    """
    selected = _pair_bases(sd1)
    if sd2:
        selected = selected | _pair_bases(sd2)
    if include_bases is not None:
        selected = selected & set(include_bases)
    yield from tqdm(sorted(selected))

# ---------- norms ----------
@torch.no_grad()
def _spectral_norm_power(W: torch.Tensor, iters: int = 50) -> float:
    """
    Estimate the largest singular value of a 2-D matrix by power iteration.

    Args:
        W: Matrix of shape [m, n].
        iters: Number of power-iteration steps; values below 1 are treated as 1
            so that both iteration vectors are always defined.

    Returns:
        The absolute value of u · (W v) after the last step, as a Python float.
    """
    m, n = W.shape
    # Run reduced-precision inputs in float32; the iteration vectors are then
    # created in the working dtype so W @ v never hits a dtype mismatch
    # (previously a float64 or half-precision W failed against float32 vectors).
    if W.dtype in (torch.float16, torch.bfloat16):
        W = W.float()
    steps = max(1, int(iters))

    def _unit(x: torch.Tensor) -> torch.Tensor:
        return x / (x.norm() + 1e-12)

    # Start from the side with the smaller dimension.
    if m >= n:
        v = _unit(torch.randn(n, device=W.device, dtype=W.dtype))
        for _ in range(steps):
            u = _unit(W @ v)
            v = _unit(W.T @ u)
    else:
        u = _unit(torch.randn(m, device=W.device, dtype=W.dtype))
        for _ in range(steps):
            v = _unit(W.T @ u)
            u = _unit(W @ v)
    return float(abs((u @ (W @ v)).item()))

# ---------- math helpers ----------
@torch.no_grad()
def _delta_from_AB(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
    """Return the matrix product B·A (expected A: [r, in], B: [out, r] -> Δ: [out, in])."""
    delta = torch.matmul(B, A)
    return delta

def _best_rank_for(deltaW: torch.Tensor, target_rank: Optional[int]) -> int:
    """
    Clamp `target_rank` into [1, min(rows, cols)] for the 2-D `deltaW`.

    With `target_rank=None` the smaller of the two dimensions is returned.
    """
    rows, cols = deltaW.shape
    ceiling = rows if rows < cols else cols
    if target_rank is not None:
        wanted = int(target_rank)
        return max(1, min(wanted, ceiling))
    return ceiling

@torch.no_grad()
def _truncated_factorization(
    deltaW: torch.Tensor,
    target_rank: int,
    method: Literal["svd", "pca_lowrank"] = "svd",
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Factor ΔW (shape [out, in]) into A', B' such that B'@A' ≈ ΔW with original axes.

    The singular values are split evenly between the factors (each gets
    sqrt(S)).

    Args:
        deltaW: 2-D matrix to factor.
        target_rank: Requested rank; clamped by `_best_rank_for`.
        method: "pca_lowrank" uses `torch.pca_lowrank` (with 8 extra columns of
            oversampling, capped at min(out, in)); any other value uses a full
            `torch.linalg.svd`.

    Returns:
        (A', B') with shapes [r, in] and [out, r], contiguous, in the dtype of
        `deltaW`.
    """
    r = _best_rank_for(deltaW, target_rank)
    out_dtype = deltaW.dtype
    # torch.linalg.svd does not accept float16/bfloat16 input, so decompose in
    # float32 and cast the factors back afterwards.
    work = deltaW.float() if out_dtype in (torch.float16, torch.bfloat16) else deltaW
    if method == "pca_lowrank":
        q = min(r + 8, min(*work.shape))
        U, S, V = torch.pca_lowrank(work, q=q, center=False)
        U, S, Vh = U[:, :r], S[:r], V[:, :r].T
    else:
        U, S, Vh = torch.linalg.svd(work, full_matrices=False)
        U, S, Vh = U[:, :r], S[:r], Vh[:r, :]
    # Clamp guards against tiny negative values before the square root.
    sroot = torch.sqrt(torch.clamp(S, min=0))
    Bp = (U * sroot.unsqueeze(0)).to(out_dtype)      # [out, r]
    Ap = (sroot.unsqueeze(1) * Vh).to(out_dtype)     # [r, in]
    return Ap.contiguous(), Bp.contiguous()

def _to_dev_dtype(x: torch.Tensor, device: Optional[torch.device], dtype: Optional[torch.dtype]) -> torch.Tensor:
    """
    Move `x` to `device` (non-blocking) and then cast it to `dtype`.

    Either step is skipped when its argument is None; the cast is also skipped
    when `x` already has the requested dtype.
    """
    moved = x if device is None else x.to(device, non_blocking=True)
    needs_cast = dtype is not None and moved.dtype != dtype
    return moved.to(dtype) if needs_cast else moved

def _maybe_empty_cuda():
    """Call `torch.cuda.empty_cache()` when CUDA is available; otherwise do nothing."""
    if not torch.cuda.is_available():
        return
    torch.cuda.empty_cache()

# ---------- per-block weighting utilities ----------
# Static rules are (prefix, multiplier for lora1, multiplier for lora2) triples.
BlockRules = List[Tuple[str, float, float]]

def _apply_block_rules(base: str, rules: Optional[BlockRules]) -> Tuple[float, float]:
    """
    Look up the multipliers for `base`.

    The rule with the longest prefix that `base` starts with wins; among
    equally long prefixes the earliest rule is used. Without rules, or without
    any match, the result is (1.0, 1.0).
    """
    neutral = (1.0, 1.0)
    if not rules:
        return neutral
    hits = [rule for rule in rules if base.startswith(rule[0])]
    if not hits:
        return neutral
    # max() keeps the first of several maximal items, matching a strict ">" scan.
    _, mul1, mul2 = max(hits, key=lambda rule: len(rule[0]))
    return (mul1, mul2)

# Dynamic callback: (base, s1, s2) -> (m1_dyn, m2_dyn)
DynWeightFn = Callable[[str, Optional[float], Optional[float]], Tuple[float, float]]

def softmax_spectral_weight_fn(temperature: float = 1.0, floor: float = 0.0) -> DynWeightFn:
    """
    Prefer the LoRA with larger spectral norm via softmax:
      w_i ∝ exp( (s_i - max(s)) / T )

    A norm that is None, NaN or infinite is treated as missing and receives
    zero softmax mass; if both are missing the callback returns (1.0, 1.0).
    'floor' lets a missing/zero norm still get a tiny weight if desired: when
    floor > 0, both weights are raised to at least `floor` and renormalised to
    sum to 1.

    Args:
        temperature: Softmax temperature; clamped to at least 1e-8.
        floor: Minimum weight applied before renormalisation (0 disables it).

    Returns:
        A function (base, s1, s2) -> (w1, w2). `base` is not used.
    """
    import math
    t = max(temperature, 1e-8)

    def fn(base: str, s1: Optional[float], s2: Optional[float]) -> Tuple[float, float]:
        # isfinite rejects +/-inf as well as NaN; an infinite norm would
        # otherwise yield exp(inf - inf) = NaN for both weights.
        ok1 = s1 is not None and math.isfinite(s1)
        ok2 = s2 is not None and math.isfinite(s2)
        if not (ok1 or ok2):  # both missing
            return 1.0, 1.0
        m = max(s1 if ok1 else float("-inf"), s2 if ok2 else float("-inf"))
        e1 = math.exp((s1 - m) / t) if ok1 else 0.0
        e2 = math.exp((s2 - m) / t) if ok2 else 0.0
        # The largest usable norm contributes exp(0) == 1, so denom >= 1.
        denom = e1 + e2
        w1 = e1 / denom
        w2 = e2 / denom
        if floor > 0.0:
            # blend toward a minimal non-zero weight
            w1 = max(w1, floor)
            w2 = max(w2, floor)
            total = w1 + w2
            w1, w2 = w1 / total, w2 / total
        return w1, w2

    return fn

# ---------- main: streaming blend & rank-convert with per-block weights ----------
def _layer_delta(
    sd: Optional[Dict[str, torch.Tensor]],
    base: str,
    dev: Optional[torch.device],
    dtype: Optional[torch.dtype],
) -> Optional[torch.Tensor]:
    """Return the dense delta B @ A for `base`, or None if `sd` lacks either factor."""
    if not sd:
        return None
    A = sd.get(base + _A)
    B = sd.get(base + _B)
    if A is None or B is None:
        return None
    # The device/dtype copies only live inside this call, so they are released
    # once the product exists instead of lingering into later layers.
    return _delta_from_AB(_to_dev_dtype(A, dev, dtype), _to_dev_dtype(B, dev, dtype))


@torch.no_grad()
def blend_and_convert_loras_streaming(
    lora1: Dict[str, torch.Tensor],
    lora2: Optional[Dict[str, torch.Tensor]] = None,
    w1: float = 1.0,
    w2: float = 1.0,
    target_rank: Optional[int] = 32,
    compute_device: Optional[str] = None,      # e.g., "cuda:0" or "cpu"
    compute_dtype: Optional[torch.dtype] = torch.float32,
    include_bases: Optional[Iterable[str]] = None,
    drop_near_zero: bool = True,
    zero_tol: float = 1e-12,
    factor_method: Literal["svd", "pca_lowrank"] = "svd",
    # NEW: per-block weighting
    block_rules: Optional[BlockRules] = None,                  # static prefix multipliers
    dynamic_weight_fn: Optional[DynWeightFn] = None,           # e.g., softmax of spectral norms
    dyn_norm_iters: int = 40,                                  # iterations for power-iteration
) -> Dict[str, torch.Tensor]:
    """
    Streaming LoRA blend/convert with per-block weights.

    Layers are processed one at a time: each layer's dense delta (B @ A) is
    built on the compute device, the two deltas are combined with per-layer
    weights, and the result is re-factorized to `target_rank` and stored on
    the CPU.

    Order of multipliers per layer:
        effective_w1 = w1 * rule_m1 * dyn_m1
        effective_w2 = w2 * rule_m2 * dyn_m2
    dynamic_weight_fn receives per-layer spectral norms (computed on-the-fly).

    Args:
        lora1: First LoRA state dict (keys ending in `_A` / `_B`).
        lora2: Optional second LoRA state dict; layers present in only one of
            the two are still emitted, scaled by that LoRA's effective weight.
        w1, w2: Global weights for lora1 / lora2.
        target_rank: Output rank per layer (clamped by `_best_rank_for`);
            None keeps the largest possible rank.
        compute_device: Device string for the per-layer math; None leaves
            tensors where they are.
        compute_dtype: Dtype for the per-layer math; None keeps each tensor's
            own dtype.
        include_bases: If given, only these layer bases are processed.
        drop_near_zero: Skip layers whose blended delta has max |value| below
            `zero_tol`.
        zero_tol: Threshold used by `drop_near_zero`.
        factor_method: Factorization backend passed to `_truncated_factorization`.
        block_rules: Static (prefix, m1, m2) multipliers; longest prefix wins.
        dynamic_weight_fn: Optional callback (base, s1, s2) -> (m1, m2); s1/s2
            are the spectral norms of the two deltas (None if absent).
        dyn_norm_iters: Power-iteration steps for the spectral-norm estimate.

    Returns:
        A new state dict on the CPU holding the re-factorized `_A` / `_B` tensors.

    Raises:
        RuntimeError: If a layer is present in both LoRAs with different delta
            shapes.
    """
    dev = torch.device(compute_device) if compute_device else None
    out_sd: Dict[str, torch.Tensor] = {}

    for base in _iter_bases_union(lora1, lora2, include_bases):
        # --- build the dense deltas for this layer (None when absent)
        d1 = _layer_delta(lora1, base, dev, compute_dtype)
        d2 = _layer_delta(lora2, base, dev, compute_dtype)

        # skip if neither present
        if d1 is None and d2 is None:
            _maybe_empty_cuda()
            continue

        # Fail with the layer name rather than an opaque size-mismatch error
        # from the addition further down.
        if d1 is not None and d2 is not None and d1.shape != d2.shape:
            raise RuntimeError(
                f"Cannot blend layer {base!r}: delta shapes differ "
                f"({tuple(d1.shape)} vs {tuple(d2.shape)})"
            )

        # --- per-block multipliers
        rule_m1, rule_m2 = _apply_block_rules(base, block_rules)

        # dynamic weights by spectral norm (only computed when a callback is set)
        dyn_m1 = dyn_m2 = 1.0
        if dynamic_weight_fn is not None:
            s1 = _spectral_norm_power(d1, iters=dyn_norm_iters) if d1 is not None else None
            s2 = _spectral_norm_power(d2, iters=dyn_norm_iters) if d2 is not None else None
            dyn_m1, dyn_m2 = dynamic_weight_fn(base, s1, s2)

        eff_w1 = w1 * rule_m1 * dyn_m1
        eff_w2 = w2 * rule_m2 * dyn_m2

        # --- blend
        if d1 is None:
            blended = eff_w2 * d2
        elif d2 is None:
            blended = eff_w1 * d1
        else:
            blended = eff_w1 * d1 + eff_w2 * d2

        # optionally skip tiny layers
        if drop_near_zero and blended.abs().max().item() < zero_tol:
            del blended, d1, d2
            _maybe_empty_cuda()
            continue

        # --- factorize to target rank and save on CPU
        r = _best_rank_for(blended, target_rank)
        A_new, B_new = _truncated_factorization(blended, r, method=factor_method)
        out_sd[base + _A] = A_new.detach().to("cpu", copy=True)
        out_sd[base + _B] = B_new.detach().to("cpu", copy=True)

        # cleanup: drop the large per-layer tensors before releasing the cache
        del A_new, B_new, blended, d1, d2
        _maybe_empty_cuda()

    return out_sd

In [11]:
import safetensors.torch

# Load the two LoRA checkpoints to blend; the results are passed to
# blend_and_convert_loras_streaming as its `lora1` / `lora2` state dicts.
lora1 = safetensors.torch.load_file("/mnt/models/tensors/loras/qwen_image/nsfw_qwen_bs8_r32_lowlr_000005000.safetensors")
lora2 = safetensors.torch.load_file("/mnt/models/tensors/loras/qwen_image/nsfw_qwen_resume_detail_qha.safetensors")

In [14]:
# Per-layer dynamic weights: softmax over the two spectral norms
# (temperature 0.5), with each weight floored at 0.05 before renormalising.
dyn_fn = softmax_spectral_weight_fn(temperature=0.5, floor=0.05)

# Blend both LoRAs layer by layer on the GPU in float32 and re-factorize
# every layer to rank 32 (default SVD factorization).
blend8 = blend_and_convert_loras_streaming(
    lora1, 
    lora2,
    w1=1.0, # 0.5, 
    w2=1.0, # 0.65, 
    target_rank=32, 
    compute_device="cuda:0", 
    compute_dtype=torch.float32,
    # factor_method='pca_lowrank',
    dynamic_weight_fn=dyn_fn, 
    dyn_norm_iters=40,
)

# Write the blended state dict into the same directory as the source LoRAs.
safetensors.torch.save_file(blend8, "/mnt/models/tensors/loras/qwen_image/nsfw_qwen_blend_5000_qha_softmax_r32.safetensors")

  0%|          | 0/840 [00:00<?, ?it/s]