In [1]:
import torch
import scipy
from torchmetrics.retrieval import RetrievalNormalizedDCG, RetrievalMAP
from src.dataset import TestDataset, OnlineCoverSongDataset
from src.evaluation import RetrievalEvaluation
from src.baselines.blocking import Blocker
from rapidfuzz import fuzz
import xgboost as xgb
import numpy as np


# Shared MAP metric instance used by the evaluation helper further down.
# empty_target_action="skip": per the torchmetrics retrieval API, queries that
# have no positive target are left out of the average instead of raising.
mean_average_precision = RetrievalMAP(empty_target_action="skip")

def mean_rank_1(preds, target):
    """
    Compute the mean rank of the first relevant item (MR1) over all queries.

    Args:
        preds (torch.Tensor): 2-D tensor of predicted scores, one row per query
            (higher scores indicate more relevant items).
        target (torch.Tensor): 2-D tensor of the same shape holding the true
            relationships (0/False for irrelevant, 1/True for relevant).
    Returns:
        torch.Tensor: Scalar tensor with the 1-based rank of the best-ranked
            relevant item, averaged over the queries that have at least one
            relevant item. Queries without relevant items are ignored.
    """
    num_candidates = preds.size(1)
    has_positives = torch.sum(target, 1) > 0

    # Order the candidates of every query by descending score and look up
    # whether the item at each rank position is relevant.
    _, spred = torch.topk(preds, num_candidates, dim=1)
    found = torch.gather(target, 1, spred)

    # Tiny per-position penalty so that topk picks the earliest rank position
    # among the relevant items. It is created on the same device as `found`
    # (a CPU-only offset breaks the subtraction for CUDA inputs), and `found`
    # is cast to float so that bool/int targets work as well.
    # The 1e-6 step only separates positions while position * 1e-6 < 1,
    # i.e. for fewer than one million candidates per query.
    temp = torch.arange(num_candidates, device=found.device).float() * 1e-6
    _, sel = torch.topk(found.float() - temp, 1, dim=1)

    # Rows without any relevant item must not contribute to the mean.
    sel = sel.float()
    sel[~has_positives] = torch.nan

    # +1 converts the 0-based position into a 1-based rank.
    mr1 = torch.nanmean(sel + 1)

    del sel, found, temp, spred, has_positives
    torch.cuda.empty_cache()
    return mr1


  from .autonotebook import tqdm as notebook_tqdm


In [24]:
def get_audio_preds(model, dataset):
    """Return the audio model's prediction matrix as a CPU tensor.

    The dataset is loaded through get_dataset, asked for the prediction
    matrix of `model`, and every -inf entry of that matrix is replaced by 0.
    """
    audio_scores = get_dataset(model, dataset).get_csi_pred_matrix(model).cpu()
    is_neg_inf = audio_scores == float('-inf')
    return torch.where(is_neg_inf, 0, audio_scores)

def get_fuzzy_preds(dataset):
    """Return fuzzy text-matching scores for the "svShort" task of `dataset`.

    A Blocker built on rapidfuzz's token_ratio scores the two data frames of
    the task against each other. The diagonal and any other -inf entry end up
    as 0 in the returned CPU tensor; all remaining scores are divided by 100
    (presumably mapping a 0-100 similarity to 0-1 - confirm against Blocker).
    """
    fuzzy_blocker = Blocker(blocking_func=fuzz.token_ratio, threshold=0.5)
    df_left, df_right = dataset.get_dfs_by_task("svShort")
    scores = fuzzy_blocker.predict(df_left, df_right).cpu()

    # Mask the diagonal in place first, then rescale.
    scores.fill_diagonal_(-float('inf'))
    scores = scores / 100

    # The masked entries (still -inf after the division) are mapped to 0.
    return torch.where(scores == float('-inf'), 0, scores)


def get_text_preds(model, dataset):
    """Return the text-model predictions for `dataset`.

    For "fuzzy" the scores are computed on the fly from the loaded dataset;
    for every other model they are read from preds/<model>/<dataset>/preds.pt.
    """
    if model == "fuzzy":
        return get_fuzzy_preds(get_dataset(model, dataset))

    # Precomputed predictions stored on disk for all non-fuzzy models.
    return torch.load(f"preds/{model}/{dataset}/preds.pt")


def get_model_mode(model):
    """Map a text-model name to its mode string.

    Returns "tvShort", "rLong" or "rShort" for the known model names and
    None for any other name.
    """
    if model in ("fuzzy", "sentence-transformers"):
        return "tvShort"
    if model in ("ditto", "rsupcon"):
        return "rLong"
    if model == "hiergat_split":
        return "rShort"
    return None


def get_dataset(model, dataset,
                csi_path="/data/csi_datasets/",
                metadata_path="/data/yt_metadata.parquet"):
    """Build the dataset object matching `model` for the dataset name `dataset`.

    Args:
        model (str): Model name; "sentence-transformers" selects
            OnlineCoverSongDataset, every other name selects TestDataset.
        dataset (str): Name of the dataset to load.
        csi_path (str): Directory passed to the dataset class as the location
            of the CSI datasets. Defaults to the path previously hard-coded.
        metadata_path (str): Metadata parquet file passed to the dataset
            class. Defaults to the path previously hard-coded.
    Returns:
        OnlineCoverSongDataset | TestDataset: The constructed dataset.
    """
    if model == "sentence-transformers":
        # This dataset class additionally receives the model's mode string.
        return OnlineCoverSongDataset(
            dataset,
            csi_path,
            metadata_path,
            get_model_mode(model)
        )
    return TestDataset(
        dataset,
        csi_path,
        metadata_path,
        tokenizer="roberta-base"
    )


def get_ensemble_data(text_model, audio_model, dataset):
    """Assemble flattened learning-to-rank arrays for a text/audio ensemble.

    Returns:
        tuple: (X, y, qids) as numpy arrays, where X has one row per cell of
            the target matrix with the text score in column 0 and the audio
            score in column 1, y holds the flattened targets, and qids holds
            the row index of the target matrix each cell came from.
    """
    data = get_dataset(text_model, dataset)

    # Score matrices of both modalities as numpy arrays.
    text_scores = get_text_preds(text_model, dataset).cpu().numpy()
    audio_scores = get_audio_preds(audio_model, dataset).cpu().numpy()

    # Ground truth as float64 on the CPU.
    relevance = data.get_target_matrix().to(float).cpu()

    # Every cell gets the index of its row as query id.
    num_rows, num_cols = relevance.shape
    row_ids = torch.arange(num_rows).view(-1, 1).expand(-1, num_cols).cpu()

    labels = relevance.numpy().flatten()
    text_column = text_scores.reshape(-1, 1)
    audio_column = audio_scores.reshape(-1, 1)
    features = np.concatenate([text_column, audio_column], axis=1)

    return features, labels, row_ids.numpy().flatten()



# Fuzzy Matching

In [25]:
# Flattened (features, labels, query ids) for the fuzzy-text + cqtnet-audio
# ensemble: "shs100k_1000" is used for fitting, "shs100k2_val" for evaluation
# during fitting.
X_train, y_train, qids_train = get_ensemble_data("fuzzy", "cqtnet", "shs100k_1000")
X_val, y_val, qids_val = get_ensemble_data("fuzzy", "cqtnet", "shs100k2_val")

# XGBoost ranker with the "rank:map" objective; the qid arrays tell it which
# rows belong to the same query. `model_fuzzy` is reused by later cells.
model_fuzzy = xgb.XGBRanker(objective="rank:map")
model_fuzzy.fit(X_train, y_train, qid=qids_train, eval_set=[(X_val, y_val)], eval_qid=[qids_val])


In [None]:
def compute_metrics(X_test, y_test, qids, ltr_model):
    """
    Score a fitted ranking model and report MAP and MR1.

    Args:
        X_test (np.ndarray): Flattened feature rows, one per (query, candidate)
            pair.
        y_test (np.ndarray): Flattened relevance labels aligned with X_test.
        qids (np.ndarray): Flattened query ids aligned with X_test. The pairs
            of one query must be contiguous and every query must have the same
            number of candidates (the layout get_ensemble_data produces).
        ltr_model: Fitted model exposing ``predict(X)``.
    Returns:
        tuple: (map_result, mr1_result) as returned by the module-level
            ``mean_average_precision`` metric and by ``mean_rank_1``.
    Raises:
        ValueError: If qids is empty or its length is not a multiple of the
            number of candidates of the first query.
    """
    qids = np.asarray(qids).reshape(-1)
    if qids.size == 0:
        raise ValueError("qids must not be empty")

    # Derive the matrix shape from the query ids instead of assuming a square
    # matrix: the number of entries carrying the first query id is the number
    # of candidates per query.
    per_query = int(np.count_nonzero(qids == qids[0]))
    num_queries, remainder = divmod(qids.size, per_query)
    if remainder:
        raise ValueError(
            "qids cannot be arranged as a (queries x candidates) matrix: "
            f"{qids.size} entries, {per_query} candidates for the first query"
        )

    def unflatten(flat):
        # Back from the flattened layout to one row per query.
        return torch.tensor(np.asarray(flat).reshape(num_queries, per_query))

    preds = unflatten(ltr_model.predict(X_test))

    # Min-max normalize the scores; guard the degenerate case of identical
    # scores, which would otherwise divide by zero and yield NaNs.
    lowest, highest = torch.min(preds), torch.max(preds)
    span = highest - lowest
    if span == 0:
        preds = torch.zeros_like(preds)
    else:
        preds = (preds - lowest) / span

    target = unflatten(y_test)
    indexes = unflatten(qids)

    map_result = mean_average_precision(preds.cpu(), target.cpu(), indexes.cpu())
    mr1_result = mean_rank_1(preds, target)
    return map_result, mr1_result

# Evaluate the fuzzy + cqtnet ensemble on "shs100k2_test".
X_test, y_test, qids_test = get_ensemble_data("fuzzy", "cqtnet", "shs100k2_test")
mapr, mr1r = compute_metrics(X_test, y_test, qids_test, model_fuzzy)
# Trailing expression: displays the (MAP, MR1) pair as the cell output.
mapr, mr1r


(tensor(0.7514), tensor(15.5964))

In [None]:
# Evaluate the same fuzzy + cqtnet ranker (fitted on "shs100k_1000") on "da-tacos".
X_test, y_test, qids_test = get_ensemble_data("fuzzy", "cqtnet", "da-tacos")
mapr, mr1r = compute_metrics(X_test, y_test, qids_test, model_fuzzy)
# Trailing expression: displays the (MAP, MR1) pair as the cell output.
mapr, mr1r


(tensor(0.8044), tensor(3.8126))

# S-BERT

In [21]:
# Flattened (features, labels, query ids) for the sentence-transformers-text +
# cqtnet-audio ensemble: "shs100k_1000" for fitting, "shs100k2_val" for
# evaluation during fitting. NOTE: this overwrites the X_train/X_val globals
# of the fuzzy-matching section.
X_train, y_train, qids_train = get_ensemble_data("sentence-transformers", "cqtnet", "shs100k_1000")
X_val, y_val, qids_val = get_ensemble_data("sentence-transformers", "cqtnet", "shs100k2_val")

# XGBoost ranker with the "rank:map" objective; the qid arrays tell it which
# rows belong to the same query. `model_sbert` is reused by later cells.
model_sbert = xgb.XGBRanker(objective="rank:map")
model_sbert.fit(X_train, y_train, qid=qids_train, eval_set=[(X_val, y_val)], eval_qid=[qids_val])



In [22]:
# Evaluate the sentence-transformers + cqtnet ensemble on "shs100k2_test".
X_test, y_test, qids_test = get_ensemble_data("sentence-transformers", "cqtnet", "shs100k2_test")
mapr, mr1r = compute_metrics(X_test, y_test, qids_test, model_sbert)
# Trailing expression: displays the (MAP, MR1) pair as the cell output.
mapr, mr1r


(tensor(0.8449), tensor(12.3406))

In [23]:
# Evaluate the sentence-transformers + cqtnet ensemble on "da-tacos".
# NOTE(review): the recorded output of this cell is a KeyboardInterrupt, so no
# result for this combination is stored in the notebook - re-run to obtain it.
X_test, y_test, qids_test = get_ensemble_data("sentence-transformers", "cqtnet", "da-tacos")
mapr, mr1r = compute_metrics(X_test, y_test, qids_test, model_sbert)
# Trailing expression: displays the (MAP, MR1) pair as the cell output.
mapr, mr1r


KeyboardInterrupt: 

# Datensatz für die LambdaMART-Optimierung
- Option a) "Overfit" auf dem Validierungsdatensatz
- Option b) Cross-Validation auf dem Validierungsdatensatz
- Option c) Fitten auf einem Subset des Trainingssets, Validierung auf dem Validierungsset
    - Offene Frage: Wie wird das Subset des Trainingssets ausgewählt?