In [4]:
import os, json, pickle
import numpy as np
import torch

# ===========================
# 1. Load data
# ===========================
# Experiment configuration: dataset name plus the relative paths (without the
# ".pkl" suffix) of the three pre-computed embedding files.
config = {}
config["dataset"] = "beauty"
config["text_encoder"]  = "qwen3vl_emb2b/text"
config["image_encoder"] = "qwen3vl_emb2b/image"
config["mm_encoder"]    = "qwen3vl_emb2b/text_image"

dataset_path = f'./dataset/{config["dataset"]}'

# Context managers guarantee each file handle is closed as soon as it has
# been read, instead of relying on garbage collection.
with open(f'{dataset_path}/id_map.json', "r") as f:
    id_map = json.load(f)["item2id"]

# NOTE(security): pickle.load and torch.load(weights_only=False) can execute
# arbitrary code embedded in the file. Only load artifacts you produced
# yourself or otherwise trust.
with open(f'{dataset_path}/{config["text_encoder"]}.pkl', "rb") as f:
    text_feat_raw = pickle.load(f)
with open(f'{dataset_path}/{config["image_encoder"]}.pkl', "rb") as f:
    image_feat_raw = pickle.load(f)
with open(f'{dataset_path}/{config["mm_encoder"]}.pkl', "rb") as f:
    mm_feat_raw = pickle.load(f)

test_data = torch.load('./saved_data/test_data.pth', weights_only=False)

# ===========================
# 2. Prepare targets and queries
# ===========================
inter_feat = test_data.dataset.inter_feat

# Ground-truth item per sample, plus a string version of each id.
target = inter_feat['item_id']
target_str = [str(item.item()) for item in target]

# Interaction history per sample and the number of entries in it.
user_sequence = inter_feat['item_id_list']
seq_lens = inter_feat['item_length']

# The query for each sample is the item at position (length - 1) of its
# history.
last_item_indices = seq_lens - 1
last_items = [
    seq[pos].item() for seq, pos in zip(user_sequence, last_item_indices)
]
last_items_str = [str(item_id) for item_id in last_items]

# NOTE(review): one extra hard-coded (query=7849, target=7850) sample is
# appended after the string lists were built, so it is evaluated but is not
# present in target_str / last_items_str. Its purpose is not visible here --
# confirm it is intentional.
last_items += [7849]
target = torch.cat((target, torch.tensor([7850])), dim=0)

# ===========================
# 3. Feature mapping
# ===========================
# Re-order the raw embedding rows so that row i of each *_mapped_feat array
# corresponds to internal item id i of the test dataset.
item_tokens = test_data.dataset.field2id_token['item_id']
item_num = len(item_tokens)

text_mapped_feat = np.zeros((item_num, text_feat_raw.shape[1]))
image_mapped_feat = np.zeros((item_num, image_feat_raw.shape[1]))
mm_mapped_feat = np.zeros((item_num, mm_feat_raw.shape[1]))

for row, token in enumerate(item_tokens):
    # The padding token has no embedding; its row stays all-zero.
    if token != '[PAD]':
        # id_map values are shifted down by one to index the raw arrays
        # (presumably id_map is 1-based -- verify against the dataset).
        raw_row = int(id_map[token]) - 1
        text_mapped_feat[row] = text_feat_raw[raw_row]
        image_mapped_feat[row] = image_feat_raw[raw_row]
        mm_mapped_feat[row] = mm_feat_raw[raw_row]

# ===========================
# 4. Recommendation Functions
# ===========================
def topk_recommend(query_item_ids, all_item_emb, k=5, exclude_history=False, history_sets=None, exclude_self=True):
    """Return the k items most cosine-similar to each query item.

    Args:
        query_item_ids: Sequence of N integer row indices into
            ``all_item_emb``, one per query (e.g. each user's last item).
        all_item_emb: Array of shape (item_num, emb_dim) holding the
            embedding of every item.
        k: Number of items to return per query. Values larger than
            ``item_num`` are clamped to ``item_num``.
        exclude_history: If True (and ``history_sets`` is given), the items
            in ``history_sets[i]`` are masked out for query ``i``.
        history_sets: Optional sequence of N collections of item indices.
        exclude_self: If True, the query item itself is masked out.

    Returns:
        Tuple ``(topk_indices, topk_scores)``, both of shape
        (N, min(k, item_num)), sorted by descending score within each row.
        Masked items carry a score of ``-inf``; they only show up in the
        result when a row has fewer than ``k`` unmasked candidates.

    Raises:
        ValueError: If ``k`` is negative.
    """
    if k < 0:
        raise ValueError(f"k must be non-negative, got {k}")

    # Force an integer dtype so an empty query list still yields a valid
    # index array (np.array([]) would default to float64).
    query_item_ids = np.asarray(query_item_ids, dtype=np.int64)
    N = len(query_item_ids)
    item_num = all_item_emb.shape[0]

    # Query embeddings, shape (N, emb_dim).
    query_emb = all_item_emb[query_item_ids]

    # Cosine similarity; the epsilon keeps all-zero rows from dividing by 0.
    query_norm = np.linalg.norm(query_emb, axis=1, keepdims=True) + 1e-12
    item_norm = np.linalg.norm(all_item_emb, axis=1, keepdims=True) + 1e-12

    query_emb_normed = query_emb / query_norm
    item_emb_normed = all_item_emb / item_norm

    scores = query_emb_normed @ item_emb_normed.T  # (N, item_num)

    # Mask each query's own item (vectorized).
    if exclude_self:
        scores[np.arange(N), query_item_ids] = -np.inf

    # Mask previously seen items per query.
    if exclude_history and history_sets is not None:
        for i in range(N):
            if len(history_sets[i]) > 0:
                scores[i, list(history_sets[i])] = -np.inf

    # Clamp k to the number of candidates; nothing to rank when it is zero.
    k = min(k, item_num)
    if k == 0:
        return np.empty((N, 0), dtype=np.intp), np.empty((N, 0), dtype=scores.dtype)

    # argpartition finds the k best columns without a full sort. Masking
    # already happened above, so exactly k columns are needed. kth must be a
    # valid column index, hence the clamp to item_num - 1 (the unclamped
    # value raised ValueError whenever k >= item_num).
    kth = min(k, item_num - 1)
    topk_indices = np.argpartition(-scores, kth, axis=1)[:, :k]

    # Order the selected k columns of each row by descending score.
    topk_scores = np.take_along_axis(scores, topk_indices, axis=1)
    sorted_idx = np.argsort(-topk_scores, axis=1)
    topk_indices = np.take_along_axis(topk_indices, sorted_idx, axis=1)
    topk_scores = np.take_along_axis(topk_scores, sorted_idx, axis=1)

    return topk_indices, topk_scores


def recall_at_k(topk_ids, target):
    """Compute Recall@k for one ground-truth item per sample.

    Args:
        topk_ids: Array of shape (N, k) with the recommended item ids.
        target: The N ground-truth item ids, as a torch tensor or any
            array-like. It is flattened, so (N,) and (N, 1) are both accepted.

    Returns:
        Tuple ``(recall, hit)`` where ``hit`` is a boolean array of shape
        (N,) with ``hit[i]`` True when ``target[i]`` appears in
        ``topk_ids[i]``, and ``recall`` is the mean of ``hit`` (0.0 when
        there are no samples).

    Raises:
        ValueError: If ``topk_ids`` is not 2-D or its row count differs from
            the number of targets.
    """
    if isinstance(target, torch.Tensor):
        # detach() so tensors that track gradients can be converted too.
        target_np = target.detach().cpu().numpy()
    else:
        target_np = np.asarray(target)
    # Flatten so an (N, 1) target cannot broadcast into an (N, N, k) compare.
    target_np = target_np.reshape(-1)

    topk_ids = np.asarray(topk_ids)
    if topk_ids.ndim != 2 or topk_ids.shape[0] != target_np.shape[0]:
        raise ValueError(
            f"topk_ids must have shape (N, k) with N == len(target); "
            f"got topk_ids.shape={topk_ids.shape}, len(target)={target_np.shape[0]}"
        )

    # Vectorized membership test: compare every row against its own target.
    hit = (topk_ids == target_np[:, np.newaxis]).any(axis=1)

    # np.mean of an empty array is NaN (with a RuntimeWarning); report 0.0.
    recall = np.mean(hit) if hit.size else 0.0
    return recall, hit

# ===========================
# 5. Run top-k recommendation
# ===========================
k = 5
query_ids = last_items

# Arguments shared by all three modalities; only the embedding table differs.
topk_kwargs = dict(
    query_item_ids=query_ids,
    k=k,
    exclude_history=False,
    history_sets=None,
    exclude_self=True,
)

# Text embeddings
text_topk_ids, text_topk_scores = topk_recommend(all_item_emb=text_mapped_feat, **topk_kwargs)

# Image embeddings
img_topk_ids, img_topk_scores = topk_recommend(all_item_emb=image_mapped_feat, **topk_kwargs)

# Multimodal (text + image) embeddings
mm_topk_ids, mm_topk_scores = topk_recommend(all_item_emb=mm_mapped_feat, **topk_kwargs)

# Recall and per-sample hit masks for each modality
text_recall, text_hit = recall_at_k(text_topk_ids, target)
img_recall, img_hit = recall_at_k(img_topk_ids, target)
mm_recall, mm_hit = recall_at_k(mm_topk_ids, target)

print(f"Recall@{k} (Text):  {text_recall:.4f}")
print(f"Recall@{k} (Image): {img_recall:.4f}")
print(f"Recall@{k} (MM):    {mm_recall:.4f}")

# ===========================
# 6. Count samples per hit pattern
# ===========================
# Per-modality miss masks, named once so each pattern below reads as a plain
# conjunction.
text_miss = ~text_hit
img_miss = ~img_hit
mm_miss = ~mm_hit

# Exactly one embedding hits
text_only = (text_hit & img_miss & mm_miss).sum()
img_only = (text_miss & img_hit & mm_miss).sum()
mm_only = (text_miss & img_miss & mm_hit).sum()

# Exactly two embeddings hit
text_img_only = (text_hit & img_hit & mm_miss).sum()
text_mm_only = (text_hit & img_miss & mm_hit).sum()
img_mm_only = (text_miss & img_hit & mm_hit).sum()

# All three embeddings hit
all_three = (text_hit & img_hit & mm_hit).sum()

# No embedding hits
all_miss = (text_miss & img_miss & mm_miss).sum()

# Print the overlap report.
# The header uses k instead of a hard-coded 5 so it stays in sync with the
# value used for the recommendations above.
n_samples = len(text_hit)

print("\n" + "=" * 80)
print(f"Recall@{k} Analysis - Embedding Overlap")
print("=" * 80)

# Samples hit by exactly one embedding.
print("\n[1] Exclusive Hits - 오직 해당 임베딩만 맞춘 샘플")
print(f"  Text only:  {text_only:5d} samples ({text_only/n_samples*100:.2f}%)")
print(f"  Image only: {img_only:5d} samples ({img_only/n_samples*100:.2f}%)")
print(f"  MM only:    {mm_only:5d} samples ({mm_only/n_samples*100:.2f}%)")

# Samples hit by exactly two embeddings.
print("\n[2] Pairwise Hits - 정확히 두 개만 맞춘 샘플")
print(f"  Text + Image (MM X): {text_img_only:5d} samples ({text_img_only/n_samples*100:.2f}%)")
print(f"  Text + MM (Image X): {text_mm_only:5d} samples ({text_mm_only/n_samples*100:.2f}%)")
print(f"  Image + MM (Text X): {img_mm_only:5d} samples ({img_mm_only/n_samples*100:.2f}%)")

print("\n[3] Common & Missing")
print(f"  All three hit: {all_three:5d} samples ({all_three/n_samples*100:.2f}%)")
print(f"  All miss:      {all_miss:5d} samples ({all_miss/n_samples*100:.2f}%)")

# The eight patterns are mutually exclusive and exhaustive, so their counts
# must add up to the number of samples.
print("\n[4] Verification")
total_check = text_only + img_only + mm_only + text_img_only + text_mm_only + img_mm_only + all_three + all_miss
print(f"  Sum: {text_only} + {img_only} + {mm_only} + {text_img_only} + {text_mm_only} + {img_mm_only} + {all_three} + {all_miss} = {total_check}")
print(f"  Total samples: {n_samples}")
print(f"  Match: {total_check == n_samples}")

# Text vs. image only, ignoring the multimodal result.
print("\n[5] Text vs Image (기존 통계)")
print(f"  Text fail & Image hit: {np.sum(~text_hit & img_hit):5d}")
print(f"  Text hit & Image fail: {np.sum(text_hit & ~img_hit):5d}")
print(f"  Both hit:              {np.sum(text_hit & img_hit):5d}")
print(f"  Both fail:             {np.sum(~text_hit & ~img_hit):5d}")

print("=" * 80)

Recall@5 (Text):  0.0464
Recall@5 (Image): 0.0394
Recall@5 (MM):    0.0500

Recall@5 Analysis - Embedding Overlap

[1] Exclusive Hits - 오직 해당 임베딩만 맞춘 샘플
  Text only:    160 samples (0.72%)
  Image only:   182 samples (0.81%)
  MM only:      106 samples (0.47%)

[2] Pairwise Hits - 정확히 두 개만 맞춘 샘플
  Text + Image (MM X):    13 samples (0.06%)
  Text + MM (Image X):   327 samples (1.46%)
  Image + MM (Text X):   149 samples (0.67%)

[3] Common & Missing
  All three hit:   537 samples (2.40%)
  All miss:      20890 samples (93.41%)

[4] Verification
  Sum: 160 + 182 + 106 + 13 + 327 + 149 + 537 + 20890 = 22364
  Total samples: 22364
  Match: True

[5] Text vs Image (기존 통계)
  Text fail & Image hit:   331
  Text hit & Image fail:   487
  Both hit:                550
  Both fail:             20996


In [6]:
# Share of "image only" hits among all samples that at least one embedding
# got right. Derived from the hit masks computed above rather than from
# numbers copied by hand out of the printed report, so it cannot go stale
# when the data, encoder, or k changes.
total_correct = int(np.sum(text_hit | img_hit | mm_hit))
image_only = int(img_only)

ratio = image_only / total_correct
print(f"image_only / total_correct: {ratio}")

image_only / total_correct: 0.12347354138398914


In [None]:
# total_correct_qwen2 = 106+182+160+149+13+327+537
# image_only_qwen2 = 106

# ratio = image_only_qwen2 / total_correct_qwen2
# print(f"image_only / total_correct: {ratio}")