In [69]:
from sys import modules

IN_COLAB = 'google.colab' in modules
if IN_COLAB:
    !pip install -q ir_axioms[examples] python-terrier

In [70]:
# Bring up PyTerrier once; re-running this cell skips initialization
# when PyTerrier reports that it has already been started.
from pyterrier import started, init

pyterrier_running = started()
if not pyterrier_running:
    init(tqdm="auto")

In [71]:
from pyterrier.datasets import get_dataset, Dataset

# Resolve the base corpus and the two judged TREC DL topic sets
# (2019 is bound to `dataset_train`, 2020 to `dataset_test`).
dataset_name = "msmarco-passage"
dataset: Dataset = get_dataset("irds:" + dataset_name)
dataset_train: Dataset = get_dataset(
    "irds:" + dataset_name + "/trec-dl-2019/judged"
)
dataset_test: Dataset = get_dataset(
    "irds:" + dataset_name + "/trec-dl-2020/judged"
)

In [72]:
from pathlib import Path

# Root folder for cached artifacts of this notebook.
cache_dir = Path("cache/")
# One index folder per base corpus: everything before the first "/" of the
# dataset name (the whole name when it contains no "/").
index_dir = cache_dir.joinpath("indices", dataset_name.partition("/")[0])

In [73]:
from pyterrier.index import IterDictIndexer

# Build the Terrier index only when no finished index is present.
# Check for `data.properties` instead of the bare directory: an empty or
# half-written index folder (e.g. after an interrupted run) would otherwise
# skip indexing and make the retrieval cell below fail.
# NOTE(review): `data.properties` is the marker PyTerrier's documentation uses
# to test for an existing index — confirm for the installed version.
if not (index_dir / "data.properties").exists():
    indexer = IterDictIndexer(str(index_dir.absolute()))
    # Index only the "text" field of each corpus document.
    indexer.index(
        dataset.get_corpus_iter(),
        fields=["text"]
    )

In [74]:
from pyterrier.batchretrieve import BatchRetrieve

# Baseline retriever: BM25 weighting model over the index built above,
# addressed by its absolute path.
bm25_index_path = str(index_dir.absolute())
bm25 = BatchRetrieve(bm25_index_path, wmodel="BM25")

09:41:25.180 [main] WARN org.terrier.structures.BaseCompressingMetaIndex - OutOfMemoryError: Structure meta reading lookup file directly from disk
09:41:25.262 [main] WARN org.terrier.structures.BaseCompressingMetaIndex - OutOfMemoryError: Structure meta reading data file directly from disk


In [75]:
from ir_axioms.axiom import (
    ArgUC, QTArg, QTPArg, aSL, PROX1, PROX2, PROX3, PROX4, PROX5, TFC1, TFC3, RS_TF, RS_TF_IDF, RS_BM25, RS_PL2, RS_QL,
    AND, LEN_AND, M_AND, LEN_M_AND, DIV, LEN_DIV, M_TDC, LEN_M_TDC, STMC1, STMC1_f, STMC2, STMC2_f, LNC1, TF_LNC, LB1,
    REG, ANTI_REG, REG_f, ANTI_REG_f, ASPECT_REG, ASPECT_REG_f, ORIG
)

# Axioms used as preference features, in a fixed order (the feature columns
# produced later follow this list). Every axiom except ORIG is instantiated
# and then wrapped with the unary `~` operator.
# NOTE(review): `~` is presumably ir_axioms' caching wrapper — confirm.
axioms = [
    ~axiom_class()
    for axiom_class in (
        # Argumentativeness axioms.
        ArgUC, QTArg, QTPArg, aSL,
        # Length normalization / lower bound.
        LNC1, TF_LNC, LB1,
        # Query term proximity.
        PROX1, PROX2, PROX3, PROX4, PROX5,
        # Regularization.
        REG, REG_f, ANTI_REG, ANTI_REG_f, ASPECT_REG, ASPECT_REG_f,
        # Query aspects / diversity.
        AND, LEN_AND, M_AND, LEN_M_AND, DIV, LEN_DIV,
        # Retrieval score based.
        RS_TF, RS_TF_IDF, RS_BM25, RS_PL2, RS_QL,
        # Term frequency.
        TFC1, TFC3, M_TDC, LEN_M_TDC,
        # Semantic term matching.
        STMC1, STMC1_f, STMC2, STMC2_f,
    )
] + [
    # The original-ranking axiom is used as is, without the `~` wrapper.
    ORIG()
]

In [76]:
from statistics import mean
from ir_axioms.backend.pyterrier.transformers import AggregatedAxiomaticPreference

def _share_non_negative(preferences):
    """Return the fraction of preference values that are >= 0."""
    return mean(float(preference >= 0) for preference in preferences)


def _share_non_positive(preferences):
    """Return the fraction of preference values that are <= 0."""
    return mean(float(preference <= 0) for preference in preferences)


# Each aggregation collapses an iterable of preference values into one number;
# the order here fixes the order of the aggregated values per axiom.
aggregations = [
    _share_non_negative,
    _share_non_positive,
]
# Feature pipeline: cut the BM25 ranking to its top 20 results per query,
# then pass them to AggregatedAxiomaticPreference, configured with the axioms
# and aggregations defined above.
bm25_top_20 = bm25 % 20
features = bm25_top_20 >> AggregatedAxiomaticPreference(
    dataset=dataset_name,
    index=index_dir,
    axioms=axioms,
    aggregations=aggregations,
    verbose=True,
)

In [77]:
# Smoke test: run the feature pipeline on the first training topic only and
# display the resulting "features" column.
first_train_topic = dataset_train.get_topics().head(1)
features.transform(first_train_topic)["features"]

Aggregating query axiom preferences:   0%|          | 0/1 [00:00<?, ?query/s]

0     [0.95, 0.85, 0.95, 1.0, 0.85, 1.0, 0.95, 1.0, ...
1     [0.95, 1.0, 0.95, 0.95, 0.95, 0.95, 1.0, 1.0, ...
2     [1.0, 0.85, 0.95, 0.95, 0.8, 1.0, 0.95, 1.0, 1...
3     [1.0, 0.85, 0.95, 0.95, 1.0, 0.75, 0.95, 1.0, ...
4     [0.8, 0.95, 0.95, 0.95, 0.85, 0.8, 0.95, 1.0, ...
5     [0.9, 1.0, 0.95, 1.0, 0.9, 0.9, 0.95, 1.0, 1.0...
6     [0.9, 1.0, 1.0, 1.0, 0.9, 1.0, 1.0, 0.95, 0.95...
7     [1.0, 1.0, 0.95, 1.0, 0.95, 1.0, 1.0, 1.0, 1.0...
8     [1.0, 0.9, 1.0, 0.9, 0.95, 0.95, 1.0, 1.0, 1.0...
9     [1.0, 0.85, 0.85, 0.95, 0.85, 0.95, 1.0, 0.85,...
10    [0.85, 0.95, 1.0, 0.75, 0.75, 1.0, 0.85, 1.0, ...
11    [0.75, 1.0, 0.75, 1.0, 0.95, 0.85, 0.95, 1.0, ...
12    [0.75, 1.0, 0.8, 1.0, 1.0, 0.7, 0.85, 1.0, 0.9...
13    [1.0, 0.85, 0.95, 1.0, 0.95, 0.9, 1.0, 0.95, 1...
14    [1.0, 0.85, 0.95, 0.85, 0.85, 0.95, 1.0, 0.85,...
15    [0.9, 0.95, 0.95, 1.0, 1.0, 0.85, 0.9, 1.0, 1....
16    [0.95, 0.85, 1.0, 0.7, 0.75, 1.0, 1.0, 0.75, 1...
17    [0.85, 0.95, 0.95, 0.95, 0.95, 0.85, 0.95,

In [78]:
from lightgbm import LGBMRanker
from pyterrier.ltr import apply_learned_model

# LightGBM ranker that is trained on the axiomatic preference features.
lambda_mart = LGBMRanker(
    # Number of boosting rounds. `n_estimators` is the scikit-learn-API
    # parameter; the native alias `num_iterations` passed via **kwargs would
    # override it with a warning and leave `lambda_mart.n_estimators` at its
    # default, which is misleading when inspecting the model.
    n_estimators=1000,
    # Report nDCG on the validation set, cut off at ranks 5 and 10.
    metric="ndcg",
    eval_at=[5, 10],
    # Make `feature_importances_` gain-based (used in the analysis below).
    importance_type="gain",
)
# Full learning-to-rank pipeline: feature extraction followed by the learned
# model, applied through PyTerrier's "ltr" form.
ltr = features >> apply_learned_model(lambda_mart, form="ltr")

In [79]:
# Train on all but the last five TREC DL 2019 topics and validate on those
# last five; both splits are judged against the same qrels.
train_topics = dataset_train.get_topics()
train_qrels = dataset_train.get_qrels()
ltr.fit(
    train_topics[:-5],
    train_qrels,
    train_topics[-5:],
    train_qrels,
)

Aggregating query axiom preferences:   0%|          | 0/38 [00:00<?, ?query/s]

Aggregating query axiom preferences:   0%|          | 0/5 [00:00<?, ?query/s]



[1]	valid_0's ndcg@5: 0.327646	valid_0's ndcg@10: 0.417289
[2]	valid_0's ndcg@5: 0.291835	valid_0's ndcg@10: 0.413147
[3]	valid_0's ndcg@5: 0.209976	valid_0's ndcg@10: 0.400773
[4]	valid_0's ndcg@5: 0.320268	valid_0's ndcg@10: 0.454609
[5]	valid_0's ndcg@5: 0.320753	valid_0's ndcg@10: 0.471763
[6]	valid_0's ndcg@5: 0.375986	valid_0's ndcg@10: 0.552848
[7]	valid_0's ndcg@5: 0.458908	valid_0's ndcg@10: 0.59261
[8]	valid_0's ndcg@5: 0.437189	valid_0's ndcg@10: 0.58149
[9]	valid_0's ndcg@5: 0.391449	valid_0's ndcg@10: 0.529723
[10]	valid_0's ndcg@5: 0.429818	valid_0's ndcg@10: 0.596405
[11]	valid_0's ndcg@5: 0.41736	valid_0's ndcg@10: 0.601596
[12]	valid_0's ndcg@5: 0.391417	valid_0's ndcg@10: 0.577515
[13]	valid_0's ndcg@5: 0.372261	valid_0's ndcg@10: 0.561022
[14]	valid_0's ndcg@5: 0.371857	valid_0's ndcg@10: 0.549882
[15]	valid_0's ndcg@5: 0.361316	valid_0's ndcg@10: 0.546368
[16]	valid_0's ndcg@5: 0.373222	valid_0's ndcg@10: 0.555008
[17]	valid_0's ndcg@5: 0.371759	valid_0's ndcg@10: 0

In [80]:
from pyterrier.pipelines import Experiment
from ir_measures import nDCG, MAP, RR

# Compare plain BM25 with the axiomatic LTR pipeline on the TREC DL 2020
# topics. `ltr ^ bm25` combines the LTR results with the BM25 ranking via
# PyTerrier's `^` operator.
# The result table is ordered by nDCG@10, best system first.
experiment = (
    Experiment(
        [bm25, ltr ^ bm25],
        dataset_test.get_topics(),
        dataset_test.get_qrels(),
        [nDCG @ 10, RR, MAP],
        ["BM25", "Axiomatic LTR"],
        verbose=True,
    )
    .sort_values(by="nDCG@10", ascending=False)
)

pt.Experiment:   0%|          | 0/2 [00:00<?, ?system/s]

Aggregating query axiom preferences:   0%|          | 0/54 [00:00<?, ?query/s]

In [81]:
# Display the evaluation table (one row per system, sorted by nDCG@10).
experiment

Unnamed: 0,name,nDCG@10,RR,AP
1,Axiomatic LTR,0.493972,0.797126,0.364303
0,BM25,0.493627,0.802359,0.358724


In [82]:
from numpy import ndarray

# Arrange the flat importance vector of the trained ranker as a 2-D array
# with one column per aggregation function.
# NOTE(review): this assumes the features are laid out axiom by axiom (all
# aggregations of the first axiom, then the second, ...) — confirm against
# AggregatedAxiomaticPreference.
num_aggregations = len(aggregations)
feature_importance: ndarray = lambda_mart.feature_importances_.reshape(
    (-1, num_aggregations)
)
feature_importance

array([[1.38255146e+01, 9.83386161e+00],
       [8.19368405e+00, 4.91472697e+00],
       [3.33220371e+01, 1.39172834e+01],
       [3.66694007e+00, 8.02065335e+00],
       [1.67422781e+00, 1.72466187e+01],
       [2.76123816e+00, 1.68144900e+00],
       [1.34220797e+01, 5.80551589e+01],
       [1.06236556e+01, 1.91728264e+01],
       [1.40630147e+01, 7.52702802e+01],
       [0.00000000e+00, 1.74549064e-01],
       [8.73922969e+00, 1.21440273e+01],
       [1.02079806e+01, 6.50640350e+00],
       [1.64036227e+01, 1.96317276e+01],
       [2.96654888e+01, 2.66297964e+01],
       [1.68919878e+01, 8.42878147e+00],
       [2.45054185e+01, 3.05978061e+01],
       [0.00000000e+00, 0.00000000e+00],
       [0.00000000e+00, 0.00000000e+00],
       [1.07856668e+00, 7.81222256e+00],
       [5.06687596e-02, 8.38046363e-02],
       [1.67256700e+01, 1.23256429e+01],
       [1.32466777e-01, 1.40863353e+00],
       [1.89814018e+01, 5.25765021e+01],
       [2.92528407e+00, 5.20215224e+00],
       [1.738498

In [83]:
# Total importance per column (i.e. per aggregation function).
feature_importance.sum(axis=0)

array([398.40405254, 536.25463633])

In [84]:
# Total importance per row (one row per group of aggregated features).
feature_importance.sum(axis=1)

array([2.36593762e+01, 1.31084110e+01, 4.72393206e+01, 1.16875934e+01,
       1.89208465e+01, 4.44268716e+00, 7.14772386e+01, 2.97964820e+01,
       8.93332950e+01, 1.74549064e-01, 2.08832570e+01, 1.67143841e+01,
       3.60353503e+01, 5.62952852e+01, 2.53207693e+01, 5.51032246e+01,
       0.00000000e+00, 0.00000000e+00, 8.89078924e+00, 1.34473396e-01,
       2.90513129e+01, 1.54110030e+00, 7.15579039e+01, 8.12743631e+00,
       2.76903653e+01, 4.39325911e+01, 6.28104337e+00, 1.86844021e+01,
       5.67958105e+01, 1.66356523e+01, 0.00000000e+00, 1.12583999e-04,
       0.00000000e+00, 5.18634762e+01, 5.45844995e+01, 2.45963734e+00,
       1.24278428e+01, 3.80816968e+00])