In [1]:
# --- Imports (hoisted to the top of the cell, original relative order kept) ---
import torch
from models.base_model import ViSynoSenseEmbedding
from transformers import PhobertTokenizerFast
from utils.span_extractor import SpanExtractor
from utils.process_data import text_normalize

# Checkpoint directory used for both the tokenizer and the embedding model.
MODEL_DIR = "BERT_BASE-20250820T150905Z-1-001/BERT_BASE"

# Use the GPU when PyTorch reports one, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# The tokenizer is shared by the span extractor and the model loader.
tokenizer = PhobertTokenizerFast.from_pretrained(MODEL_DIR)
span_ex = SpanExtractor(tokenizer)
model = ViSynoSenseEmbedding.from_pretrained(MODEL_DIR, tokenizer).to(device)

  from .autonotebook import tqdm as notebook_tqdm


MemoryError: 

In [None]:
import faiss
import json

# Locations of the prebuilt FAISS index and its companion metadata file.
index_path = "/content/drive/MyDrive/UIT/ViSynoSense_models/index_mean.faiss"
metadata_path = "/content/drive/MyDrive/UIT/ViSynoSense_models/metadata_mean.json"

# Read the FAISS index from disk.
index = faiss.read_index(index_path)

# Read the JSON metadata (UTF-8) that accompanies the index.
with open(metadata_path, "r", encoding="utf-8") as metadata_file:
    metadata = json.load(metadata_file)


'2.8.0+cpu'

In [2]:
import torch

# Sanity check: the cell displays whether PyTorch reports an available CUDA device.
torch.cuda.is_available()

True

In [None]:
def pipeline(query, target):
    """Embed the ``target`` expression of ``query`` with the loaded model.

    The query is normalised with ``text_normalize`` and tokenised; the span
    indices returned by ``span_ex.get_span_indices(query_norm, target)`` are
    turned into a tensor with a leading batch dimension, and both are passed
    to ``model`` in evaluation mode.

    Args:
        query: Raw sentence that contains the target expression.
        target: Word or phrase whose span is looked up in the normalised query.

    Returns:
        The value of ``model(tokenized_query, span)``. It is computed under
        ``torch.no_grad()``, so it carries no autograd graph.
        Presumably a ``(1, hidden_dim)`` tensor -- TODO confirm against
        ``ViSynoSenseEmbedding``.
    """
    query_norm = text_normalize(query)
    tokenized_query = tokenizer(query_norm, return_tensors="pt").to(device)
    span_idx = span_ex.get_span_indices(query_norm, target)
    # unsqueeze(0) adds the batch dimension for the single query.
    # NOTE(review): torch.Tensor(...) builds a float32 tensor; confirm the
    # model expects float span indices rather than integer ones.
    span = torch.Tensor(span_idx).unsqueeze(0).to(device)
    model.eval()
    # Inference only: skip autograd bookkeeping so no graph is kept alive.
    with torch.no_grad():
        query_vec = model(tokenized_query, span)
    return query_vec

In [None]:
import torch.nn.functional as F


# Three example (sentence, target) pairs.
query_1, target_1 = "Tôi đang khoan.", "Khoan"
query_2, target_2 = "khoan này bị mất mũi khoan.", "mũi khoan"
query_3, target_3 = "Khoan là việc rất tiện lợi.", "Khoan"

# Embed the target of each sentence, in the same order as above.
query_vec_1 = pipeline(query_1, target_1)
query_vec_2 = pipeline(query_2, target_2)
query_vec_3 = pipeline(query_3, target_3)


def cosine_similarity(vec1, vec2):
    """Return the cosine similarity of two tensors as a Python float.

    The similarity is taken along ``dim=1`` and converted with ``.item()``,
    so the result must hold exactly one element (e.g. two ``(1, d)`` inputs).
    """
    similarity = F.cosine_similarity(vec1, vec2, dim=1)
    return similarity.item()


# Compare example 1 against example 3 and report the score.
sim_1 = cosine_similarity(query_vec_1, query_vec_3)
print(f"Similarity between 1: {target_1}  and  3: {target_3}: {sim_1:.4f}")

# Compare example 2 against example 3 and report the score.
sim_2 = cosine_similarity(query_vec_2, query_vec_3)
print(f"Similarity between 2: {target_2} and 3:{target_3}: {sim_2:.4f}")

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import umap.umap_ as umap       # pip install umap-learn
from adjustText import adjust_text  # pip install adjustText
from matplotlib.lines import Line2D

# --- 1. Prepare the two queries and their targets ---
queries = [
    ("Đây là người viết thơ rất nổi tiếng.", "người viết thơ"),
    ("Anh ấy là người sáng tác thơ.",   "người sáng tác thơ"),
]
colors_query = ["red", "blue"]
colors_nn    = ["orange", "cyan"]
TOP_K = 10  # number of neighbours requested from the FAISS index per query

# --- 2. Collect each query vector and its top-K neighbours ---
all_vecs   = []
all_labels = []
all_colors = []

seen_idxs = set()
for qi, (q_text, q_target) in enumerate(queries):
    # Embed the query target and flatten it to a 1-D NumPy vector.
    q_vec = pipeline(q_text, q_target).detach().cpu().numpy().reshape(-1)
    # FAISS searches a batch of vectors, hence the added leading axis.
    D, I = index.search(q_vec[np.newaxis, :], k=TOP_K)

    # Add the query itself to the lists.
    all_vecs.append(q_vec)
    all_labels.append(f"Q{qi+1}: {q_target}")
    all_colors.append(colors_query[qi])

    # Add the top-K neighbours, skipping ids already collected for an earlier query.
    for idx in I[0]:
        idx = int(idx)
        # FAISS pads the result with -1 when fewer than k neighbours are found;
        # such a label is not a real id (and metadata[-1] would silently be the
        # last entry), so skip it.
        if idx < 0 or idx in seen_idxs:
            continue
        seen_idxs.add(idx)
        nn_vec = np.array(index.reconstruct(idx)).reshape(-1)
        all_vecs.append(nn_vec)
        # NOTE(review): assumes metadata is indexable by the FAISS id and each
        # entry has a 'word' key -- confirm against the metadata file.
        word = metadata[idx]['word']
        all_labels.append(word)
        all_colors.append(colors_nn[qi])

# --- 3. Reduce dimensionality with UMAP ---
# Stack the collected vectors into a single 2-D array, one row per point.
all_vecs = np.vstack(all_vecs)

# Two output components, cosine metric, fixed seed.
reducer = umap.UMAP(
    n_components=2,
    metric='cosine',
    random_state=42,
)
emb_2d = reducer.fit_transform(all_vecs)

# --- 4. Draw the scatter plot and spread out the labels ---
fig, ax = plt.subplots(figsize=(12, 8))
text_objs = []
for point, label, c in zip(emb_2d, all_labels, all_colors):
    x, y = point
    ax.scatter(x, y, color=c, s=80)
    # Keep the text artists so adjust_text can reposition them afterwards.
    text_objs.append(ax.text(x, y, label, fontsize=9))

# NOTE(review): `expand_text` and a scalar `force_text` may depend on the
# installed adjustText version -- confirm against its documentation.
adjust_text(
    text_objs,
    arrowprops=dict(arrowstyle="->", color='gray', lw=0.5),
    expand_text=(1.2, 1.2),
    force_text=0.5,
)

# --- 5. Build the legend by hand ---
# One marker per query, in the query colour.
query_handles = [
    Line2D([0], [0], marker='o', color='w', label=f"Query {pos+1}: {target}",
           markerfacecolor=colors_query[pos], markersize=10)
    for pos, (_sentence, target) in enumerate(queries)
]
# One marker per query's neighbour group, in the neighbour colour.
neighbour_handles = [
    Line2D([0], [0], marker='o', color='w', label=f"Top‑K of Query {pos+1}",
           markerfacecolor=colors_nn[pos], markersize=8)
    for pos in range(len(queries))
]
legend_elements = query_handles + neighbour_handles

plt.legend(handles=legend_elements, loc='upper right', framealpha=1)
plt.title("UMAP Visualization of Two Queries and Their Neighbors")
# Equal axis scaling first, then hide the axes.
plt.axis("equal")
plt.axis("off")
plt.tight_layout()
plt.show()
