# LambdaMART
This notebook shows how we train [LambdaMART](https://www.microsoft.com/en-us/research/uploads/prod/2016/02/MSR-TR-2010-82.pdf) (via XGBoost's `XGBRanker`) to combine text-based and audio-based pairwise predictions for online video cover song identification.

## Requirements
- Set the variables `DATASET_PATH` and `METADATA_PARQUET_PATH` at the beginning of the first code cell.
- Prepare the data as described in the `README`.
- Provide the pairwise predictions of each text model `MODEL` on each dataset `DATASET` at `preds/MODEL/DATASET/preds.pt` (not needed for the `fuzzy` baseline, whose predictions are computed on the fly).



In [6]:
import torch
import os
import scipy
from torchmetrics.retrieval import RetrievalNormalizedDCG, RetrievalMAP
from src.dataset import TestDataset, OnlineCoverSongDataset
from src.evaluation import RetrievalEvaluation
from src.baselines.blocking import Blocker
from rapidfuzz import fuzz
import xgboost as xgb
import numpy as np

# Location of the CSI datasets; passed to the dataset classes as `csi_path`.
# Adjust to your environment before running (see Requirements).
DATASET_PATH = "/data/csi_datasets/"
# Metadata parquet file; passed to the dataset classes as `metadata_path`.
METADATA_PARQUET_PATH = "/data/yt_metadata.parquet"

# Shared mean-average-precision metric; with empty_target_action="skip", queries
# without any relevant item are left out of the average instead of being scored.
mean_average_precision = RetrievalMAP(empty_target_action="skip")

def mean_rank_1(preds, target):
    """
    Compute the mean rank of the first relevant item per query (MR1).

    For every row (query), candidates are ordered by descending score and the
    1-based position of the first relevant candidate is recorded. Rows without
    any relevant candidate are ignored in the mean.

    Args:
        preds (torch.Tensor): 2-D tensor (queries x candidates) of predicted
            scores; higher scores indicate more relevant items.
        target (torch.Tensor): tensor of the same shape and on the same device
            as ``preds``; any non-zero / True entry marks a relevant item.
    Returns:
        torch.Tensor: 0-dim float tensor holding the mean rank over queries that
            have at least one relevant item (NaN if no query has one, or if
            there are no candidates).
    """
    if preds.size(1) == 0:
        return torch.tensor(float("nan"))

    # Float 0/1 relevance mask: works for bool, integer and float targets alike
    # and stays on the inputs' device (no hard-coded .cpu()).
    relevant = (target != 0).float()
    has_positives = relevant.sum(dim=1) > 0

    # Candidate indices sorted by descending score, per query.
    _, order = torch.topk(preds, preds.size(1), dim=1)
    found = torch.gather(relevant, 1, order)

    # argmax returns the index of the *first* maximal value, i.e. the position
    # of the first relevant candidate in the ranking.
    first_hit = found.argmax(dim=1).float()
    first_hit[~has_positives] = torch.nan

    return torch.nanmean(first_hit + 1)


In [7]:
def get_audio_preds(model, dataset):
    """
    Return the audio model's pairwise prediction matrix for ``dataset``.

    The matrix comes from the dataset object's ``get_csi_pred_matrix`` and is
    moved to the CPU; any ``-inf`` entries are replaced with 0.

    Args:
        model: name of the audio model.
        dataset: name of the dataset.
    Returns:
        torch.Tensor: CPU tensor of pairwise scores without ``-inf`` values.
    """
    csi_matrix = get_dataset(model, dataset).get_csi_pred_matrix(model).cpu()
    return csi_matrix.masked_fill(torch.isneginf(csi_matrix), 0)

def get_fuzzy_preds(dataset):
    """
    Compute fuzzy string-matching scores between the dataset's two tables.

    Uses ``Blocker`` with ``fuzz.token_ratio`` (threshold 0.5) on the
    ``"svShort"`` task tables. The diagonal is masked out, scores are divided
    by 100 and every ``-inf`` entry (including the masked diagonal) becomes 0.

    Args:
        dataset: dataset object providing ``get_dfs_by_task``.
    Returns:
        torch.Tensor: CPU tensor of pairwise fuzzy scores.
    """
    # NOTE(review): get_model_mode maps "fuzzy" to "tvShort", while this uses
    # "svShort" -- confirm which task is intended.
    blocker = Blocker(blocking_func=fuzz.token_ratio, threshold=0.5)
    left_df, right_df = dataset.get_dfs_by_task("svShort")
    scores = blocker.predict(left_df, right_df).cpu()

    # Mask self-matches (in place), then rescale; presumably the raw fuzz
    # scores are on a 0-100 scale -- verify against Blocker.
    scores.fill_diagonal_(-float('inf'))
    scores = scores / 100
    return scores.masked_fill(torch.isneginf(scores), 0)


def get_text_preds(model, dataset):
    """
    Return the text model's pairwise prediction matrix for ``dataset``.

    ``"fuzzy"`` scores are computed on the fly with fuzzy string matching.
    Every other model's predictions are loaded from
    ``preds/<model>/<dataset>/preds.pt``.

    Args:
        model: name of the text model.
        dataset: name of the dataset.
    Returns:
        torch.Tensor: pairwise score matrix; loaded predictions are mapped to CPU.
    """
    if model == "fuzzy":
        return get_fuzzy_preds(get_dataset(model, dataset))
    # map_location="cpu" lets prediction files saved from a GPU run load on
    # CPU-only machines; the scores are consumed on the CPU anyway.
    path = os.path.join("preds", model, dataset, "preds.pt")
    return torch.load(path, map_location="cpu")


def get_model_mode(model):
    """
    Map a text model name to the dataset mode it expects.

    Returns ``"tvShort"`` for fuzzy / sentence-transformers, ``"rLong"`` for
    ditto / rsupcon, ``"rShort"`` for hiergat_split and ``None`` otherwise.
    """
    if model in ("fuzzy", "sentence-transformers"):
        return "tvShort"
    if model in ("ditto", "rsupcon"):
        return "rLong"
    if model == "hiergat_split":
        return "rShort"
    return None


def get_dataset(model, dataset):
    """
    Build the dataset object used to obtain predictions and targets.

    ``"sentence-transformers"`` gets an ``OnlineCoverSongDataset`` in the mode
    given by ``get_model_mode``. Every other model gets a ``TestDataset`` with
    the ``"roberta-base"`` tokenizer. Both read from ``DATASET_PATH`` and
    ``METADATA_PARQUET_PATH``.
    """
    if model != "sentence-transformers":
        return TestDataset(
            dataset,
            DATASET_PATH,
            METADATA_PARQUET_PATH,
            tokenizer="roberta-base",
        )
    return OnlineCoverSongDataset(
        dataset,
        DATASET_PATH,
        METADATA_PARQUET_PATH,
        get_model_mode(model),
    )


def get_ensemble_data(text_model, audio_model, dataset):
    
    data = get_dataset(text_model, dataset)
    
    # get preds
    text_preds = get_text_preds(text_model, dataset).cpu().numpy()
    audio_preds = get_audio_preds(audio_model, dataset).cpu().numpy()

    # get ground truth
    Y = data.get_target_matrix().to(float).cpu()
    
    # get indexes
    m, n = Y.shape
    indexes = torch.arange(m).view(-1, 1).expand(-1, n).cpu()

    # last transform
    y_train = Y.cpu().numpy().flatten()
    X_train = np.concatenate([text_preds.reshape(-1, 1), audio_preds.reshape(-1, 1)], axis=1)

    # get query info array
    qids = indexes.cpu().numpy().flatten()
    return X_train, y_train, qids

def compute_metrics(X_test, y_test, qids, ltr_model, out_path):

    preds = ltr_model.predict(X_test)
    # unflatten
    def unflatten(t):
        return torch.tensor(t.reshape((int(np.sqrt(len(t))), int(np.sqrt(len(t))))))
    
    preds = unflatten(preds)
    # normalize
    preds = (preds - torch.min(preds)) / (torch.max(preds) - torch.min(preds))
    
    target = unflatten(y_test)
    indexes = unflatten(qids)

    torch.save(preds, os.path.join(out_path, "ypreds.pt"))
    torch.save(target, os.path.join(out_path, "ytrue.pt"))

    map_result = mean_average_precision(preds.cpu(), target.cpu(), indexes.cpu())
    mr1_result = mean_rank_1(preds, target)
    return map_result, mr1_result


In [8]:
# XGBRanker hyperparameters: LambdaMART optimising MAP. With the "topk" pair
# method, lambdarank_num_pair_per_sample (here 50) is the number of top-ranked
# documents per query used to build training pairs.
params = dict(
    objective="rank:map",
    lambdarank_pair_method="topk",
    lambdarank_num_pair_per_sample=50,
)


# Fuzzy Matching

In [9]:

# Train LambdaMART on fuzzy-matching (text) + CoverHunter (audio) scores.
# The validation split is only evaluated and logged during boosting; no early
# stopping is configured in `params`.
X_train, y_train, qids_train = get_ensemble_data(
    text_model="fuzzy", audio_model="coverhunter", dataset="shs100k_1000"
)
X_val, y_val, qids_val = get_ensemble_data(
    text_model="fuzzy", audio_model="coverhunter", dataset="shs100k2_val"
)

model_fuzzy_ch = xgb.XGBRanker(**params)
model_fuzzy_ch.fit(
    X_train,
    y_train,
    qid=qids_train,
    eval_set=[(X_val, y_val)],
    eval_qid=[qids_val],
)



In [10]:
# Evaluate the fuzzy + CoverHunter ranker on the `shs100k2_test` split;
# compute_metrics writes ypreds.pt / ytrue.pt into preds/<text>_<audio>/<dataset>/.
text_model, audio_model, dataset = "fuzzy", "coverhunter", "shs100k2_test"
out_path = os.path.join("preds", "_".join((text_model, audio_model)), dataset)
os.makedirs(out_path, exist_ok=True)

X_test, y_test, qids_test = get_ensemble_data(text_model, audio_model, dataset)
mapr, mr1r = compute_metrics(
    X_test, y_test, qids_test, ltr_model=model_fuzzy_ch, out_path=out_path
)
mapr, mr1r


(tensor(0.8973), tensor(4.3458))

# S-BERT

In [15]:
# Train LambdaMART on sentence-transformers (text) + CoverHunter (audio) scores.
# The validation split is only evaluated and logged during boosting; no early
# stopping is configured in `params`.
X_train, y_train, qids_train = get_ensemble_data(
    text_model="sentence-transformers", audio_model="coverhunter", dataset="shs100k_1000"
)
X_val, y_val, qids_val = get_ensemble_data(
    text_model="sentence-transformers", audio_model="coverhunter", dataset="shs100k2_val"
)

model_sbert_ch = xgb.XGBRanker(**params)
model_sbert_ch.fit(
    X_train,
    y_train,
    qid=qids_train,
    eval_set=[(X_val, y_val)],
    eval_qid=[qids_val],
)



In [16]:
# Evaluate the sentence-transformers + CoverHunter ranker on `shs100k2_test`;
# compute_metrics writes ypreds.pt / ytrue.pt into preds/<text>_<audio>/<dataset>/.
text_model, audio_model, dataset = "sentence-transformers", "coverhunter", "shs100k2_test"
out_path = os.path.join("preds", "_".join((text_model, audio_model)), dataset)
os.makedirs(out_path, exist_ok=True)

X_test, y_test, qids_test = get_ensemble_data(text_model, audio_model, dataset)
mapr, mr1r = compute_metrics(
    X_test, y_test, qids_test, ltr_model=model_sbert_ch, out_path=out_path
)
mapr, mr1r


(tensor(0.9303), tensor(3.8029))