<a href="https://colab.research.google.com/github/p1atdev/SimilarityCalculator/blob/main/Similarity_Calculator.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# 準備
必要なもののインストールなど

In [None]:
# Install dependencies. The leading "!" is an IPython/Colab shell escape, so
# this line only runs inside a notebook, not as plain Python.
# NOTE(review): pytorch-lightning is never imported in this notebook;
# presumably it is needed so torch.load can unpickle some .ckpt files — confirm.
!pip install torch pytorch-lightning safetensors

## インポート

In [None]:
from safetensors.torch import load_file
import sys
import torch
from pathlib import Path
import torch.nn as nn
import torch.nn.functional as F
import requests
from pathlib import Path
import os

## 関数定義

In [None]:
def cal_cross_attn(to_q, to_k, to_v, rand_input):
    """Project ``rand_input`` with the given Q/K/V weights and mix the results.

    The caller pushes the same ``rand_input`` through two models' weights and
    compares the outputs with cosine similarity, so this acts as a fingerprint
    of one attention layer's projection weights.

    Args:
        to_q, to_k, to_v: Projection weight matrices laid out like
            ``nn.Linear.weight``, i.e. ``(out_features, in_features)``.
        rand_input: Input of shape ``(n_tokens, in_features)``. The final
            einsum additionally requires ``n_tokens == out_features``.

    Returns:
        Tensor of shape ``(n_tokens, out_features)`` without autograd history.
    """
    # nn.Linear stores its weight as (out_features, in_features), so the
    # constructor must be given (in_features, out_features). Reading the shape
    # the other way round only worked because the weights happened to be square.
    out_features, in_features = to_q.shape
    with torch.no_grad():  # pure inference: don't build an autograd graph
        # NOTE(review): constructing nn.Linear presumably draws its random
        # init from the global torch RNG, which shifts the seeded inputs the
        # caller draws afterwards. Kept so previously reported similarity
        # numbers stay reproducible — confirm before replacing with F.linear.
        attn_to_q = nn.Linear(in_features, out_features, bias=False)
        attn_to_k = nn.Linear(in_features, out_features, bias=False)
        attn_to_v = nn.Linear(in_features, out_features, bias=False)
        attn_to_q.load_state_dict({"weight": to_q})
        attn_to_k.load_state_dict({"weight": to_k})
        attn_to_v.load_state_dict({"weight": to_v})

        q = attn_to_q(rand_input)
        k = attn_to_k(rand_input)
        v = attn_to_v(rand_input)

        # Row-wise softmax of q @ k^T -> (n_tokens, n_tokens).
        scores = F.softmax(torch.einsum("ij, kj -> ik", q, k), dim=-1)
        # NOTE(review): "ik, jk -> ik" multiplies scores[i, k] by the column
        # sum of v instead of computing the usual scores @ v ("ij, jk -> ik").
        # Left unchanged so the similarity metric matches published results.
        return torch.einsum("ik, jk -> ik", scores, v)

def model_hash(filename):
    """Return a short 8-hex-digit SHA-256 hash of part of ``filename``.

    Only the 64 KiB window that starts at offset 1 MiB is hashed, so this is
    fast even for multi-gigabyte checkpoints. Files shorter than that window
    hash whatever bytes (possibly none) lie past the offset.

    Returns:
        The first 8 hex digits of the digest, or ``'NOFILE'`` when the file
        does not exist.
    """
    import hashlib

    window_start, window_size = 0x100000, 0x10000
    try:
        handle = open(filename, "rb")
    except FileNotFoundError:
        return 'NOFILE'
    with handle:
        handle.seek(window_start)
        window = handle.read(window_size)
    digest = hashlib.sha256(window).hexdigest()
    return digest[:8]
        
def load_model(path):
    """Load a checkpoint's tensors onto the CPU.

    Args:
        path: ``str`` or ``os.PathLike`` pointing at a ``.safetensors`` file
            (suffix matched case-insensitively) or at a ``torch.save``
            checkpoint such as ``.ckpt``.

    Returns:
        Mapping of parameter name to tensor. For torch checkpoints the nested
        ``"state_dict"`` entry is returned when present, otherwise the whole
        loaded object.
    """
    path = Path(path)  # accept plain strings as well as Path objects
    if path.suffix.lower() == ".safetensors":
        return load_file(path, device="cpu")
    # SECURITY: torch.load unpickles the file. This notebook downloads
    # checkpoints from user-supplied URLs, so only load files from sources
    # you trust (prefer .safetensors when available).
    ckpt = torch.load(path, map_location="cpu")
    return ckpt["state_dict"] if "state_dict" in ckpt else ckpt
        
def eval(model, n, input):
    """Fingerprint the ``attn1`` projections of UNet output block ``n``.

    Fetches the to_q / to_k / to_v weights of
    ``model.diffusion_model.output_blocks.{n}.1.transformer_blocks.0.attn1``
    from the state dict ``model`` and returns ``cal_cross_attn`` applied to
    ``input``. Raises ``KeyError`` if the state dict lacks any of them.

    Note: this name shadows the built-in ``eval`` (and the parameter shadows
    ``input``); both are kept because callers use them.
    """
    projection_weights = [
        model[f"model.diffusion_model.output_blocks.{n}.1.transformer_blocks.0.attn1.to_q.weight"],
        model[f"model.diffusion_model.output_blocks.{n}.1.transformer_blocks.0.attn1.to_k.weight"],
        model[f"model.diffusion_model.output_blocks.{n}.1.transformer_blocks.0.attn1.to_v.weight"],
    ]
    return cal_cross_attn(*projection_weights, input)

# 実行

In [None]:
#@title ckptファイルのURLの指定
# Direct-download URLs of the two checkpoints to compare: the base model and
# the target model whose similarity to the base is reported. The local file
# name is taken from the last path segment of each URL.
base_model_url = "https://huggingface.co/runwayml/stable-diffusion-v1-5/resolve/main/v1-5-pruned-emaonly.ckpt"#@param {type:"string"}
target_model_url = "https://huggingface.co/prompthero/openjourney/resolve/main/mdjrny-v4.ckpt"#@param {type:"string"}

In [None]:
#@title ckptファイルのダウンロード
# Download both checkpoints into ./models, skipping any that already exist.

from urllib.parse import urlparse

models_dir = Path("./models")
models_dir.mkdir(exist_ok=True)


def _model_filename(url):
    """Return the file name a URL points at, ignoring any query string or fragment."""
    # e.g. ".../model.ckpt?download=true" -> "model.ckpt"; a leftover query
    # would also break the suffix check used when loading the model.
    return urlparse(url).path.split("/")[-1]


def _download(url, dest, chunk_size=1 << 20):
    """Stream ``url`` to ``dest`` without holding the whole file in memory.

    Writes to ``<dest>.part`` and renames it only after a complete, successful
    transfer. An HTTP error or an interrupted download therefore never leaves
    a file that the existence check below would later mistake for a model.
    """
    tmp = dest.with_name(dest.name + ".part")
    with requests.get(url, stream=True, timeout=60) as response:
        response.raise_for_status()  # don't save an HTML error page as a model
        with open(tmp, mode="wb") as f:
            for chunk in response.iter_content(chunk_size=chunk_size):
                f.write(chunk)
    tmp.replace(dest)


base_model_name = _model_filename(base_model_url)
base_model_path = models_dir / base_model_name

target_model_name = _model_filename(target_model_url)
target_model_path = models_dir / target_model_name

urls = [base_model_url, target_model_url]

for url in urls:
    model_name = _model_filename(url)
    model_path = models_dir / model_name
    if not model_path.is_file():
        print(f"Downloading {url}")
        _download(url, model_path)

print("Finished")

In [None]:
#@title 比較の開始
# Feed identical seeded random inputs through the attn1 projections of UNet
# output blocks 3-10 of the base and target models, then report the mean
# cosine similarity of the results as a percentage.

seed = 123456789
torch.manual_seed(seed)
print(f"seed: {seed}")

# UNet output blocks whose attn1 weights are compared.
block_ids = range(3, 11)

model_a = load_model(base_model_path)

print()
print(f"base: {base_model_name} [{model_hash(base_model_path)}]")
print()

map_attn_a = {}
map_rand_input = {}
with torch.no_grad():  # inference only: don't record autograd graphs
    for n in block_ids:
        # Weights are laid out like nn.Linear.weight: (out_features, in_features).
        # The input's last dim must be in_features, and cal_cross_attn's final
        # einsum needs n_tokens == out_features. For the square attn1 weights
        # this draws exactly the same tensor as before.
        out_features, in_features = model_a[f"model.diffusion_model.output_blocks.{n}.1.transformer_blocks.0.attn1.to_q.weight"].shape
        rand_input = torch.randn([out_features, in_features])

        map_attn_a[n] = eval(model_a, n, rand_input)
        map_rand_input[n] = rand_input

# Free the base model before loading the target so only one is in memory.
del model_a

model_b = load_model(target_model_path)

sims = []
with torch.no_grad():
    for n in block_ids:
        attn_a = map_attn_a[n]
        attn_b = eval(model_b, n, map_rand_input[n])

        # Row-wise cosine similarity, averaged over rows.
        sim = torch.mean(torch.cosine_similarity(attn_a, attn_b))
        sims.append(sim)

del model_b

print(f"{target_model_name} [{model_hash(target_model_path)}] - {torch.mean(torch.stack(sims)) * 1e2:.2f}%")

seed: 123456789

base: v1-5-pruned-emaonly.ckpt [81761151]

mdjrny-v4.ckpt [ddc6edf2] - 99.96%
