In [23]:
import os
import torch
import torch.nn.functional as F
import polars as pl
from torch.utils.data import DataLoader, TensorDataset
from vector_quantize_pytorch import ResidualVQ
import matplotlib.pyplot as plt
from tqdm.auto import tqdm
import numpy as np
from utils.file_config import FILE_CONFIG as fc
import utils.evaluation as eval_utils
from copy import deepcopy



# === Config ===
class Config:
    """Hyper-parameters shared by every RVQ run in this notebook.

    Per-run fields (``num_codebooks``, ``codebook_size``, ``use_cosine``) are
    not defined here: they are attached to a deep copy of a ``Config``
    instance, taken from one ``codebook_config`` entry, before a model is
    trained or loaded.
    """

    # NOTE(review): not read by any function shown in this notebook.
    alpha = 0
    num_epochs = 100
    batch_size = 2048 * 32  # 65,536 rows per optimisation step
    lr = 3e-4
    dim = 768  # passed to ResidualVQ(dim=...)
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    seed = 1234

    # Grid of RVQ variants, in this order: cosine runs first, then L2; within
    # each metric, codebook size 512 before 1024; within each size, 5/6/7
    # quantisers. `model_name` is the checkpoint file stem used by load_model.
    codebook_config = [
        {
            "num_codebooks": n_books,
            "codebook_size": book_size,
            "use_cosine": cosine,
            "model_name": f"rvq_{n_books}x{book_size}_{'cos' if cosine else 'l2'}",
        }
        for cosine in (True, False)
        for book_size in (512, 1024)
        for n_books in (5, 6, 7)
    ]

# Seed torch and NumPy (both imported above) from Config.seed so that random
# operations in later cells, such as DataLoader(shuffle=True), are repeatable
# across kernel restarts.
torch.manual_seed(Config.seed)
np.random.seed(Config.seed)


def load_embeddings(file_config, model, config):
    """Load label/expression embeddings for ``model`` plus the mapped-concept subset.

    Args:
        file_config: mapping with keys ``"path_all_concept"`` (parquet with
            ``id`` and ``idx`` columns), ``"mapped_concept"`` (CSV with an
            ``n.id`` column) and ``"embedding_save_path"`` (directory holding
            ``<model>.pt``).
        model: embedding checkpoint name (file stem, without ``.pt``).
        config: object exposing ``device``; all returned tensors live there.

    Returns:
        ``(labels_embeddings, expressions_embeddings, mapped_embeddings)`` where
        ``mapped_embeddings`` are the rows of ``labels_embeddings`` whose
        ``idx`` belongs to a concept listed in the mapped-concept CSV.
    """
    df_concept_all = pl.read_parquet(file_config["path_all_concept"])
    df_mapped = pl.read_csv(file_config["mapped_concept"])
    # Translate mapped concept ids into row indices of the embedding matrices.
    # maintain_order=True keeps the row order of `mapped_embeddings` deterministic.
    idx_mapped = (
        df_mapped.join(df_concept_all, left_on="n.id", right_on="id")["idx"]
        .unique(maintain_order=True)
        .to_list()
    )

    embedding_path = os.path.join(file_config["embedding_save_path"], f"{model}.pt")
    # Deserialize the checkpoint a single time, mapping tensors straight onto
    # the target device (also makes CUDA-saved files loadable on CPU).
    checkpoint = torch.load(embedding_path, map_location=config.device)
    full_embeddings_l = checkpoint["labels_embeddings"].to(config.device)
    full_embeddings_exp = checkpoint["expressions_embeddings"].to(config.device)
    mapped_embeddings = full_embeddings_l[idx_mapped, :]

    return full_embeddings_l, full_embeddings_exp, mapped_embeddings


def train(model, train_loader, config):
    """Optimise ``model`` to reconstruct its inputs and return the same model.

    Args:
        model: module whose forward returns a 3-tuple; only the first element
            (the reconstruction) is used, the other two are ignored.
        train_loader: iterable of batches whose first element is the input tensor.
        config: object exposing ``lr``, ``num_epochs``, ``device`` and
            ``use_cosine`` (True -> ``1 - mean cosine similarity`` loss,
            False -> MSE loss).

    Returns:
        ``model`` itself, trained in place and left in training mode.
    """
    opt = torch.optim.AdamW(model.parameters(), lr=config.lr)
    model.train()

    for epoch in range(config.num_epochs):
        print(f"\n=== Epoch {epoch + 1}/{config.num_epochs} ===")
        # Accumulate on-device so there is one host sync per epoch, not per step.
        loss_sum = 0.0
        n_batches = 0
        for x_batch in tqdm(train_loader):
            x = x_batch[0].to(config.device)
            opt.zero_grad()

            quantized, _, _ = model(x)
            if config.use_cosine:
                loss = 1 - F.cosine_similarity(x, quantized, dim=-1).mean()
            else:
                loss = F.mse_loss(x, quantized)

            loss.backward()
            opt.step()
            loss_sum = loss_sum + loss.detach()
            n_batches += 1

        # An empty loader would otherwise divide by zero.
        if n_batches == 0:
            print(f"Epoch {epoch+1} | No batches in train_loader")
            continue
        avg_epoch_loss = float(loss_sum) / n_batches
        print(f"Epoch {epoch+1} | Avg batch loss: {avg_epoch_loss:.4f}")
    return model


def evaluate(model, embeddings, device=None):
    """Quantise ``embeddings`` with ``model`` and print reconstruction quality.

    Prints the mean cosine similarity between each input row and its
    quantised reconstruction.

    Args:
        model: module whose forward returns ``(quantized, indices, _)``.
        embeddings: tensor to quantise; moved to ``device`` first.
        device: target device; ``None`` (default) means ``Config.device``.

    Returns:
        ``(quantized, indices)`` as produced by ``model``. The model is left
        in eval mode and no gradients are recorded.
    """
    if device is None:
        device = Config.device
    model.eval()
    with torch.no_grad():
        embeddings = embeddings.to(device)
        quantized, indices, _ = model(embeddings)
        cos_sim = F.cosine_similarity(embeddings, quantized, dim=-1).mean()
        print(f"Average cosine similarity: {cos_sim.item():.4f}")
    return quantized, indices
    
def get_dataloader(embeddings, config):
    """Wrap a single tensor in a shuffling DataLoader.

    Each batch is a 1-tuple ``(rows,)`` of at most ``config.batch_size`` rows.
    """
    return DataLoader(
        TensorDataset(embeddings),
        batch_size=config.batch_size,
        shuffle=True,
    )


def train_all_rvq_configs(cfg_base, train_loader, embeddings_mapped):
    """Train, checkpoint and score one ResidualVQ per ``cfg_base.codebook_config`` entry.

    For each entry, a deep copy of ``cfg_base`` receives that entry's
    ``num_codebooks``, ``codebook_size`` and ``use_cosine``. A fresh
    ``ResidualVQ`` is then trained on ``train_loader`` and saved as
    ``rvq_<N>x<K>_<cos|l2>.pt`` under ``fc["model_q_save_path"]``. Finally it
    is scored on ``embeddings_mapped``.

    Returns:
        list of dicts, one per configuration in ``codebook_config`` order, with
        keys ``config_id``, ``num_codebooks``, ``codebook_size``,
        ``use_cosine`` and ``cos_sim_mapped`` (mean cosine similarity between
        ``embeddings_mapped`` and its quantised reconstruction).
    """
    save_dir = fc["model_q_save_path"]
    # torch.save does not create missing directories.
    os.makedirs(save_dir, exist_ok=True)

    results = []
    for run_spec in cfg_base.codebook_config:
        cfg = deepcopy(cfg_base)
        cfg.num_codebooks = run_spec['num_codebooks']
        cfg.codebook_size = run_spec['codebook_size']
        cfg.use_cosine = run_spec['use_cosine']
        config_id = f"{cfg.num_codebooks}x{cfg.codebook_size}_{'cos' if cfg.use_cosine else 'l2'}"

        print(f"\n=== Training {config_id} ===")

        model = ResidualVQ(
            dim=cfg.dim,
            num_quantizers=cfg.num_codebooks,
            codebook_size=cfg.codebook_size,
            learnable_codebook=True,
            ema_update=False,
            kmeans_init=True,
            kmeans_iters=10,
            use_cosine_sim=cfg.use_cosine
        ).to(cfg.device)

        model = train(model, train_loader, cfg)

        # File name matches the `model_name` entries used by load_model.
        torch.save(model.state_dict(), os.path.join(save_dir, f"rvq_{config_id}.pt"))
        quantized, _ = evaluate(model, embeddings_mapped)

        print(f"Model {config_id} trained and saved.")

        # evaluate() returns `quantized` on its own device; align before comparing.
        reference = embeddings_mapped.to(quantized.device)
        results.append({
            'config_id': config_id,
            'num_codebooks': cfg.num_codebooks,
            'codebook_size': cfg.codebook_size,
            'use_cosine': cfg.use_cosine,
            'cos_sim_mapped': F.cosine_similarity(reference, quantized, dim=-1).mean().item(),
        })

    return results

def load_model(cfg_base, config_dict):
    """Rebuild the ResidualVQ for one ``codebook_config`` entry and load its weights.

    Args:
        cfg_base: base configuration; deep-copied, never mutated.
        config_dict: one entry of ``Config.codebook_config`` with keys
            ``num_codebooks``, ``codebook_size``, ``use_cosine`` and
            ``model_name`` (the checkpoint file stem).

    Returns:
        ``(model, config_id)`` where ``config_id`` is ``config_dict['model_name']``.

    Raises:
        FileNotFoundError: if ``<model_name>.pt`` is missing from
            ``fc["model_q_save_path"]``.
    """
    # Merge config base with the specific codebook setting.
    cfg = deepcopy(cfg_base)
    cfg.num_codebooks = config_dict['num_codebooks']
    cfg.codebook_size = config_dict['codebook_size']
    cfg.use_cosine = config_dict['use_cosine']

    config_id = config_dict['model_name']  # already follows "rvq_XxY_cos/l2"

    # Fail fast, before allocating the model, when the checkpoint is missing.
    model_path = os.path.join(fc["model_q_save_path"], f"{config_id}.pt")
    if not os.path.exists(model_path):
        raise FileNotFoundError(f"Checkpoint not found at {model_path}")

    model = ResidualVQ(
        dim=cfg.dim,
        num_quantizers=cfg.num_codebooks,
        codebook_size=cfg.codebook_size,
        learnable_codebook=True,
        ema_update=False,
        kmeans_init=True,
        kmeans_iters=10,
        use_cosine_sim=cfg.use_cosine
    ).to(cfg.device)

    # map_location lets GPU-trained checkpoints load on a CPU-only machine.
    model.load_state_dict(torch.load(model_path, map_location=cfg.device))
    print(f"\n=== Model {config_id} loaded ===")

    return model, config_id


# Train all RVQ configurations

In [None]:
cfg = Config()

# === Dataset & DataLoader ===
# Label and expression embeddings are stacked into one training pool; the
# mapped-concept subset is kept aside to score each trained quantiser.
full_embeddings_l, full_embeddings_exp, mapped_embeddings = load_embeddings(
    file_config=fc,
    model="sapbert_lora_triplet16",
    config=cfg,
)
train_loader = get_dataloader(torch.cat([full_embeddings_l, full_embeddings_exp]), cfg)

# === Train Residual VQ for all configurations ===
result = train_all_rvq_configs(cfg, train_loader, mapped_embeddings)


# Evaluation: build the held-out test concept split

In [4]:
df_concept_all = pl.read_parquet(fc["path_all_concept"])
df_concept_all_idx = set(df_concept_all["idx"].unique().to_list())

df_concept_train = pl.read_parquet(fc["training_triplet_idx"])
df_concept_train_idx = set(df_concept_train["idx"].unique().to_list())

# Test concepts: every concept idx that never appears in the training triplets.
df_concept_test_idx = df_concept_all_idx.difference(df_concept_train_idx)
df_concept_test = list(df_concept_test_idx)

id2idx = dict(zip(df_concept_all["id"], df_concept_all["idx"]))

# Evaluation set: held-out concepts whose status is "defined".
test_mask = pl.col("idx").is_in(df_concept_test) & (pl.col("status") == "defined")
df_concept_test_fd = df_concept_all.filter(test_mask)["idx"].unique().to_list()

In [22]:
cfg = Config()
# Keep every restored quantiser, keyed by its model_name, so later cells can
# evaluate them; previously the loaded models were discarded.
loaded_models = {}
for config in cfg.codebook_config:
    rvq_model, rvq_id = load_model(cfg, config)
    loaded_models[rvq_id] = rvq_model


TypeError: load_model() takes 1 positional argument but 2 were given

# Evaluation task 1: MRR of quantized expression vs. label embeddings on held-out concepts

In [None]:
# Task-1 results, accumulated by this and later evaluation cells: MRR per
# model, the model tags (same order), and the raw rank arrays keyed by tag.
mrrs_1, models_1, ranks_1 = [], [], {}

embeddings = torch.load(fc["embedding_save_path"] + "/sapbert_lora_triplet16.pt")
embedding_exp, embedding_label = (
    embeddings["expressions_embeddings"],
    embeddings["labels_embeddings"],
)

# NOTE(review): `rvq_vanilla` is not defined in any visible cell; it must come
# from an earlier (possibly deleted) cell, so this cell fails on a fresh kernel.
embedding_exp_q, _ = evaluate(rvq_vanilla, embedding_exp)
embedding_label_q, _ = evaluate(rvq_vanilla, embedding_label)

rank = eval_utils.top_k_array_by_batch(
    df_concept_test_fd, embedding_exp_q, embedding_label_q, cfg.device, 100
)
mrr_rank = eval_utils.compute_mmr(rank)

model_tag = "vanilla_4_512"
mrrs_1.append(mrr_rank)
models_1.append(model_tag)
ranks_1[model_tag] = rank
print(f"MRR: {mrr_rank}")



In [None]:

# NOTE(review): `rvq_cosine` is not created in any visible cell; it must
# already exist in the kernel. `train` trains the model in place and returns
# the same object, so re-running this cell continues training for another
# `cfg.num_epochs` epochs rather than starting fresh.
# This also mutates the shared `cfg` instance: later cells see use_cosine=True.
cfg.use_cosine = True  # Set to True if you want to use cosine similarity
rvq_cosine = train(rvq_cosine, train_loader, cfg)


In [None]:

# Score the cosine-trained quantiser on task 1 and append to the shared results.
embeddings = torch.load(fc["embedding_save_path"] + "/sapbert_lora_triplet16.pt")
embedding_exp, embedding_label = (
    embeddings["expressions_embeddings"],
    embeddings["labels_embeddings"],
)

embedding_exp_q, _ = evaluate(rvq_cosine, embedding_exp)
embedding_label_q, _ = evaluate(rvq_cosine, embedding_label)

rank = eval_utils.top_k_array_by_batch(
    df_concept_test_fd, embedding_exp_q, embedding_label_q, cfg.device, 100
)
mrr_rank = eval_utils.compute_mmr(rank)

model_tag = "cosin_4_512"
mrrs_1.append(mrr_rank)
models_1.append(model_tag)
ranks_1[model_tag] = rank
print(f"MRR: {mrr_rank}")
