In [None]:
import os
import torch
import torch.nn.functional as F
import polars as pl
from torch.utils.data import DataLoader, TensorDataset
from vector_quantize_pytorch import ResidualVQ
import matplotlib.pyplot as plt
from tqdm.auto import tqdm
import numpy as np
from utils.file_config import FILE_CONFIG as fc
import utils.evaluation as eval_utils
from copy import deepcopy



# === Config ===
class Config:
    alpha = 0
    num_epochs = 100
    batch_size = 2048*32
    lr = 3e-4
    dim = 768
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    seed = 1234
    codebook_config =[
    {'num_codebooks': 5, 'codebook_size': 512, 'use_cosine': True},
    {'num_codebooks': 6, 'codebook_size': 512, 'use_cosine': True},
    {'num_codebooks': 7, 'codebook_size': 512, 'use_cosine': True},
    {'num_codebooks': 5, 'codebook_size': 1024, 'use_cosine': True},
    {'num_codebooks': 6, 'codebook_size': 1024, 'use_cosine': True},
    {'num_codebooks': 7, 'codebook_size': 1024, 'use_cosine': True},

    {'num_codebooks': 5, 'codebook_size': 512, 'use_cosine': False},
    {'num_codebooks': 6, 'codebook_size': 512, 'use_cosine': False},
    {'num_codebooks': 7, 'codebook_size': 512, 'use_cosine': False},
    {'num_codebooks': 5, 'codebook_size': 1024, 'use_cosine': False},
    {'num_codebooks': 6, 'codebook_size': 1024, 'use_cosine': False},
    {'num_codebooks': 7, 'codebook_size': 1024, 'use_cosine': False},
]

torch.manual_seed(Config.seed)


def load_embeddings(file_config, model, config):
    df_concept_all = pl.read_parquet(file_config["path_all_concept"])
    df_mapped = pl.read_csv(file_config["mapped_concept"])
    idx_mapped = df_mapped.join(df_concept_all, left_on="n.id", right_on="id")["idx"].unique().to_list()
    embedding_path = file_config["embedding_save_path"] + f"/{model}.pt"
    full_embeddings_l = torch.load(embedding_path)["labels_embeddings"].to(config.device)
    full_embeddings_exp = torch.load(embedding_path)["expressions_embeddings"].to(config.device)
    mapped_embeddings = full_embeddings_l[idx_mapped, :]

    return full_embeddings_l, full_embeddings_exp, mapped_embeddings


def train(model, train_loader, config):
    
    opt = torch.optim.AdamW(model.parameters(), lr=config.lr)

    model.train()

    for epoch in range(config.num_epochs):
        print(f"\n=== Epoch {epoch + 1}/{config.num_epochs} ===")
        epoch_losses = []
        for x_batch in train_loader:

            x = x_batch[0].to(config.device)
            opt.zero_grad()

            quantized, _, _ = model(x)
            # out = out.clamp(-1., 1.)
            if config.use_cosine:
                loss = 1 - F.cosine_similarity(x, quantized, dim=-1).mean()
            else:
                loss = F.mse_loss(x, quantized)

            epoch_losses.append(loss.item())
            loss.backward()
            opt.step()
        
        avg_epoch_loss = sum(epoch_losses) / len(epoch_losses)
        print(f"Epoch {epoch+1} | Avg batch loss: {avg_epoch_loss:.4f}")

    return model

def evaluate(model, embeddings):
    model.eval()
    with torch.no_grad():
        embeddings = embeddings.to(Config.device)
        quantized, indices, _ = model(embeddings)
        cos_sim = F.cosine_similarity(embeddings, quantized, dim=-1).mean()
        print(f"Average cosine similarity: {cos_sim.item():.4f}")
        return quantized, indices
    
def get_dataloader(embeddings,config):
    dataset = TensorDataset(embeddings)
    return DataLoader(dataset, batch_size=config.batch_size, shuffle=True)

def run_all_configs(config_list, train_loader, base_config):
    results = []

    for conf in config_list:
        cfg = deepcopy(base_config)
        cfg.num_codebooks = conf["num_codebooks"]
        cfg.codebook_size = conf["codebook_size"]
        cfg.use_cosine = conf["use_cosine"]

        config_id = f"{cfg.num_codebooks}x{cfg.codebook_size}_{'cos' if cfg.use_cosine else 'l2'}"
        print(f"\n=== Training {config_id} ===")

        # Instantiate model
        model = ResidualVQ(
            dim=cfg.dim,
            num_quantizers=cfg.num_codebooks,
            codebook_size=cfg.codebook_size,
            learnable_codebook=True,
            ema_update=False,
            # kmeans_init=True,
            # kmeans_iters=10,
            use_cosine_sim=cfg.use_cosine
        ).to(cfg.device)

        # Train
        model = train(model, train_loader, cfg)

        # Save result
        torch.save(model.state_dict(), f"rvq_{config_id}.pt")
        results.append({
            "config": conf,
            "config_id": config_id,
            "model": model
        })

    return results


In [None]:

rvq_vanilla = ResidualVQ(
    dim=768,
    num_quantizers=4,
    codebook_size=512,
    learnable_codebook=True,
    ema_update=False,
    kmeans_init=True,
    kmeans_iters=10,
).to(Config.device)

rvq_cosine = ResidualVQ(
    dim=768,
    num_quantizers=4,
    codebook_size=512,
    learnable_codebook=True,
    ema_update=False,
    use_cosine_sim=True,
).to(Config.device)


In [None]:
cfg = Config()

# === Dataset & DataLoader ===
full_embeddings_l,full_embeddings_exp, mapped_embeddings = load_embeddings(file_config=fc, model= "sapbert_lora_triplet16", config=cfg)
train_loader = get_dataloader(torch.concat((full_embeddings_l, full_embeddings_exp)), cfg)

# === train ===

cfg.use_cosine = False  # Set to True if you want to use cosine similarity
rvq_vanilla = train(rvq_vanilla, train_loader, cfg)

cfg.use_cosine = True  # Set to True if you want to use cosine similarity
rvq_cosine = train(rvq_cosine, train_loader, cfg)

# eval

In [21]:
df_concept_all = pl.read_parquet(fc["path_all_concept"])
df_concept_all_idx = set(df_concept_all["idx"].unique().to_list())

df_concept_train = pl.read_parquet(fc["training_triplet_idx"])
df_concept_train_idx = set(df_concept_train["idx"].unique().to_list())

df_concept_test_idx = df_concept_all_idx - df_concept_train_idx
df_concept_test = list(df_concept_test_idx)

id2idx = dict(zip(df_concept_all["id"], df_concept_all["idx"]))

df_concept_test_fd = df_concept_all.filter(pl.col("idx").is_in(df_concept_test)).filter(pl.col("status") == "defined")["idx"].unique().to_list()

# eval task 1

In [None]:
mrrs_1 = []
models_1 = []
ranks_1 = {}

embeddings = torch.load(fc["embedding_save_path"] + f"/sapbert_lora_triplet16.pt")
embedding_exp = embeddings["expressions_embeddings"]
embedding_label = embeddings["labels_embeddings"]
embedding_exp_q,_ = evaluate(rvq_vanilla, embedding_exp)
embedding_label_q,_ = evaluate(rvq_vanilla, embedding_label)
rank = eval_utils.top_k_array_by_batch(df_concept_test_fd, embedding_exp_q, embedding_label_q,cfg.device, 100)
mrr_rank = eval_utils.compute_mmr(rank)
mrrs_1.append(mrr_rank)
models_1.append("vanilla_4_512")
ranks_1["vanilla_4_512"] = rank
print(f"MRR: {mrr_rank}")



In [None]:

cfg.use_cosine = True  # Set to True if you want to use cosine similarity
rvq_cosine = train(rvq_cosine, train_loader, cfg)


In [None]:

embeddings = torch.load(fc["embedding_save_path"] + f"/sapbert_lora_triplet16.pt")
embedding_exp = embeddings["expressions_embeddings"]
embedding_label = embeddings["labels_embeddings"]
embedding_exp_q,_ = evaluate(rvq_cosine, embedding_exp)
embedding_label_q,_ = evaluate(rvq_cosine, embedding_label)
rank = eval_utils.top_k_array_by_batch(df_concept_test_fd, embedding_exp_q, embedding_label_q,cfg.device, 100)
mrr_rank = eval_utils.compute_mmr(rank)
mrrs_1.append(mrr_rank)
models_1.append("cosin_4_512")
ranks_1["cosin_4_512"] = rank
print(f"MRR: {mrr_rank}")


Average cosine similarity: 0.7597
Average cosine similarity: 0.7502
Processing batch 1/435 (0-100)


  query_matrix = torch.tensor(query_matrix, dtype=torch.float32).to(device)
  candidate_matrix = torch.tensor(candidate_matrix, dtype=torch.float32).to(device)


Processing batch 2/435 (100-200)
Processing batch 3/435 (200-300)
Processing batch 4/435 (300-400)
Processing batch 5/435 (400-500)
Processing batch 6/435 (500-600)
Processing batch 7/435 (600-700)
Processing batch 8/435 (700-800)
Processing batch 9/435 (800-900)
Processing batch 10/435 (900-1000)
Processing batch 11/435 (1000-1100)
Processing batch 12/435 (1100-1200)
Processing batch 13/435 (1200-1300)
Processing batch 14/435 (1300-1400)
Processing batch 15/435 (1400-1500)
Processing batch 16/435 (1500-1600)
Processing batch 17/435 (1600-1700)
Processing batch 18/435 (1700-1800)
Processing batch 19/435 (1800-1900)
Processing batch 20/435 (1900-2000)
Processing batch 21/435 (2000-2100)
Processing batch 22/435 (2100-2200)
Processing batch 23/435 (2200-2300)
Processing batch 24/435 (2300-2400)
Processing batch 25/435 (2400-2500)
Processing batch 26/435 (2500-2600)
Processing batch 27/435 (2600-2700)
Processing batch 28/435 (2700-2800)
Processing batch 29/435 (2800-2900)
Processing batch 