In [4]:
from PIL import Image
from model import AggMInterface
from data import DInterface
import yaml
import warnings
warnings.filterwarnings("ignore")

# 指定模型的超参数配置文件路径和checkpoint文件路径
config_path = "logs/dinov2_backbone_dinov2_large/lightning_logs/version_14/hparams.yaml"
checkpoint_path = "logs/dinov2_backbone_dinov2_large/lightning_logs/version_14/checkpoints/dinov2_backbone_epoch(39)_step(39080)_R1[0.9135]_R5[0.9595]_R10[0.9649].ckpt"

# 加载yaml文件，获取模型超参数配置
with open(config_path) as f:
    config = yaml.safe_load(f)

# 根据配置初始化数据模块
data_module = DInterface(**config)  # 数据模块初始化，传入配置参数
transform = data_module.valid_transform  # 获取验证集的数据变换方法

# 根据checkpoint文件路径加载模型，并设置为评估模式
model = AggMInterface.load_from_checkpoint(checkpoint_path)
model.eval()
model = model.model.model
model.requires_grad_(False)

正在尝试从以下路径导入模块: .dinov2_backbone
正在查找类: Dinov2Backbone


DinoVisionTransformer(
  (patch_embed): PatchEmbed(
    (proj): Conv2d(3, 1024, kernel_size=(14, 14), stride=(14, 14))
    (norm): Identity()
  )
  (blocks): ModuleList(
    (0-23): 24 x NestedTensorBlock(
      (norm1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)
      (attn): MemEffAttention(
        (qkv): Linear(in_features=1024, out_features=3072, bias=True)
        (attn_drop): Dropout(p=0.0, inplace=False)
        (proj): Linear(in_features=1024, out_features=1024, bias=True)
        (proj_drop): Dropout(p=0.0, inplace=False)
      )
      (ls1): LayerScale()
      (drop_path1): Identity()
      (norm2): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)
      (mlp): Mlp(
        (fc1): Linear(in_features=1024, out_features=4096, bias=True)
        (act): GELU(approximate='none')
        (fc2): Linear(in_features=4096, out_features=1024, bias=True)
        (drop): Dropout(p=0.0, inplace=False)
      )
      (ls2): LayerScale()
      (drop_path2): Identity()
    )


In [10]:
from utils.hook_func import compute_visual_similarity, create_attention_similarity_plot
from scipy.stats import mode
import numpy as np
import os
import matplotlib.pyplot as plt

# plt.rcParams.update({
#     'font.family': 'Times New Roman',
#     'font.size': 28,
#     'axes.labelsize': 14,
#     'axes.titlesize': 16,
#     'legend.fontsize': 12,
#     'xtick.labelsize': 12,
#     'ytick.labelsize': 12,
#     'axes.linewidth': 1.2,
#     'lines.linewidth': 2.5,
#     'lines.markersize': 10,
#     'grid.linewidth': 0.8,
# })

source_image_path = "sample_imgs/nordland/4/query/@0@26504.8@@@@@11132@@@@@@@@.jpg" 
target_image_path = "sample_imgs/nordland/4/ref/@0@26528.6@@@@@11142@@@@@@@@.jpg"

image_size = 560
query_point = (252, 181)
layer_idxs = list(range(10, 24))
device = "cuda"

for layer_idx in layer_idxs:   
    print(f"Processing layer {layer_idx}")
    # 计算视觉相似度
    similarity_maps, source_img, target_img = compute_visual_similarity(
        source_image_path,
        target_image_path,
        model,
        query_point,
        layer_idx,
        image_size,
        device
    )

    # 计算注意力图最大值位置
    key_max = mode(np.argwhere(similarity_maps["key"].max() == similarity_maps["key"]), axis=0).mode[:2]
    query_max = mode(np.argwhere(similarity_maps["query"].max()==similarity_maps["query"]), axis=0).mode[:2]  
    value_max = mode(np.argwhere(similarity_maps["value"].max()==similarity_maps["value"]), axis=0).mode[:2]
    token_max = mode(np.argwhere(similarity_maps["token"].max()==similarity_maps["token"]), axis=0).mode[:2]

    # 将最大值位置存储在字典中
    max_positions = {
        "key": key_max,
        "query": query_max,
        "value": value_max,
        "token": token_max
    }

    # 可视化注意力相似度
    fig = create_attention_similarity_plot(
        source_image=source_img,
        target_image=target_img,
        attention_maps=similarity_maps,
        source_point=query_point,
        max_positions=max_positions,
        fig_size=(36, 8),
        dpi=600
    )

    # 可以选择保存图形
    # 设置保存路径
    save=True
    if save:
        save_folder = "some_result_images/query_key_value_token/07"
        if not os.path.isdir(save_folder):
            os.makedirs(save_folder)
            print(f"Directory created: {save_folder}")
        else:
            print(f"Destination directory '{save_folder}' already exists!")
        save_fname = f"{layer_idx}_I{source_image_path.split('/')[-1]}-{target_image_path.split('/')[-1]}_Px{query_point[0]}_Py{query_point[1]}.png"
        if save_folder:
            fig.savefig(os.path.join(save_folder, save_fname), bbox_inches='tight')

    # 可以选择显示图形
    # plt.show(fig)

    # 清理图形以释放内存
    plt.close(fig)


Processing layer 10
Directory created: some_result_images/query_key_value_token/07
Processing layer 11
Destination directory 'some_result_images/query_key_value_token/07' already exists!
Processing layer 12
Destination directory 'some_result_images/query_key_value_token/07' already exists!
Processing layer 13
Destination directory 'some_result_images/query_key_value_token/07' already exists!
Processing layer 14
Destination directory 'some_result_images/query_key_value_token/07' already exists!
Processing layer 15
Destination directory 'some_result_images/query_key_value_token/07' already exists!
Processing layer 16
Destination directory 'some_result_images/query_key_value_token/07' already exists!
Processing layer 17
Destination directory 'some_result_images/query_key_value_token/07' already exists!
Processing layer 18
Destination directory 'some_result_images/query_key_value_token/07' already exists!
Processing layer 19
Destination directory 'some_result_images/query_key_value_token/0