In [1]:
from sys import modules

IN_COLAB = 'google.colab' in modules
if IN_COLAB:
    !pip install -q ir_axioms[examples] python-terrier

In [2]:
from pyterrier import started, init

# Initialize PyTerrier only if it has not been started yet, so that
# re-running this cell in the same kernel does not initialize it twice.
pyterrier_started = started()
if not pyterrier_started:
    init(tqdm="auto")

PyTerrier 0.8.0 has loaded Terrier 5.6 (built by craigmacdonald on 2021-09-17 13:27)

No etc/terrier.properties, using terrier.default.properties for bootstrap configuration.


In [3]:
from pyterrier.datasets import get_dataset, Dataset

# Name of the corpus; it is also reused for the sub-dataset identifiers below
# and for the name of the index directory.
dataset_name = "msmarco-passage"

# The full corpus (used for indexing).
dataset: Dataset = get_dataset(f"irds:{dataset_name}")

# Judged TREC DL 2019 / 2020 subsets, used as training and test data.
dataset_train: Dataset
dataset_test: Dataset
dataset_train, dataset_test = (
    get_dataset(f"irds:{dataset_name}/trec-dl-2019/judged"),
    get_dataset(f"irds:{dataset_name}/trec-dl-2020/judged"),
)

In [4]:
from pathlib import Path

# Root folder for everything this notebook generates.
cache_dir = Path("cache/")

# One index folder per corpus: only the part of the dataset name before the
# first "/" is used as the folder name.
corpus_name = dataset_name.partition("/")[0]
index_dir = cache_dir.joinpath("indices", corpus_name)

In [5]:
from pyterrier.index import IterDictIndexer

# Build the index only when no finished index is present.
# Checking for the directory alone is not enough: an interrupted indexing run
# leaves the directory behind, which would then be mistaken for a usable index.
# PyTerrier's own examples test for the `data.properties` file instead.
if not (index_dir / "data.properties").exists():
    indexer = IterDictIndexer(str(index_dir.absolute()))
    indexer.index(
        dataset.get_corpus_iter(),
        fields=["text"]
    )

In [6]:
from pyterrier.batchretrieve import BatchRetrieve

# Baseline retriever: BM25 scoring on top of the index built above,
# addressed by its absolute path.
index_path = str(index_dir.absolute())
bm25 = BatchRetrieve(index_path, wmodel="BM25")

In [7]:
from ir_axioms.axiom import (
    ArgUC, QTArg, QTPArg, aSL, PROX1, PROX2, PROX3, PROX4, PROX5, TFC1, TFC3, RS_TF, RS_TF_IDF, RS_BM25, RS_PL2, RS_QL,
    AND, LEN_AND, M_AND, LEN_M_AND, DIV, LEN_DIV, M_TDC, LEN_M_TDC, STMC1, STMC1_f, STMC2, STMC2_f, LNC1, TF_LNC, LB1,
    REG, ANTI_REG, REG_f, ANTI_REG_f, ASPECT_REG, ASPECT_REG_f, ORIG
)

# Axiom types that are instantiated without arguments and then wrapped with
# the `~` operator. The order matters: it determines the order of the entries
# in `axioms`.
tilde_axiom_types = (
    ArgUC, QTArg, QTPArg, aSL,
    LNC1, TF_LNC, LB1,
    PROX1, PROX2, PROX3, PROX4, PROX5,
    REG, REG_f, ANTI_REG, ANTI_REG_f, ASPECT_REG, ASPECT_REG_f,
    AND, LEN_AND, M_AND, LEN_M_AND, DIV, LEN_DIV,
    RS_TF, RS_TF_IDF, RS_BM25, RS_PL2, RS_QL,
    TFC1, TFC3, M_TDC, LEN_M_TDC,
    STMC1, STMC1_f, STMC2, STMC2_f,
)
axioms = [~axiom_type() for axiom_type in tilde_axiom_types]

# ORIG is the only axiom that is appended without the `~` operator.
axioms.append(ORIG())

In [8]:
from statistics import mean, variance
from ir_axioms.backend.pyterrier.transformers import AggregatedAxiomaticPreference

# Each aggregation reduces a collection of preference values to one number:
#   1. the plain mean,
#   2. the fraction of values that are >= 0,
#   3. the fraction of values that are <= 0.
aggregations = [
    mean,
    lambda prefs: mean(float(pref >= 0) for pref in prefs),
    lambda prefs: mean(float(pref <= 0) for pref in prefs),
]

# Cut the BM25 ranking to its top 20 results per query ...
bm25_top_20 = bm25 % 20

# ... and compute the aggregated axiom preferences for these results.
axiom_preferences = AggregatedAxiomaticPreference(
    axioms=axioms,
    index=index_dir,
    aggregations=aggregations,
    dataset=dataset_name,
    verbose=True,
)
features = bm25_top_20 >> axiom_preferences

In [9]:
# Quick check of the feature pipeline on the first training topic only.
first_topic = dataset_train.get_topics()[:1]
features.transform(first_topic)["features"]

Aggregating query axiom preferences:   0%|          | 0/1 [00:00<?, ?query/s]

0     [-0.25, 0.75, 1.0, -0.25, 0.75, 1.0, 0.05, 0.9...
1     [-0.15, 0.8, 0.95, -0.15, 0.8, 0.95, 0.05, 0.8...
2     [-0.2, 0.8, 1.0, -0.05, 0.9, 0.95, -0.1, 0.8, ...
3     [-0.1, 0.85, 0.95, -0.05, 0.9, 0.95, 0.15, 0.9...
4     [0.3, 1.0, 0.7, 0.25, 1.0, 0.75, -0.25, 0.75, ...
5     [0.2, 1.0, 0.8, 0.2, 1.0, 0.8, -0.2, 0.8, 1.0,...
6     [0.1, 0.95, 0.85, 0.25, 1.0, 0.75, -0.25, 0.75...
7     [0.2, 1.0, 0.8, 0.2, 1.0, 0.8, -0.05, 0.85, 0....
8     [0.0, 1.0, 1.0, 0.0, 1.0, 1.0, 0.0, 1.0, 1.0, ...
9     [0.0, 1.0, 1.0, 0.0, 1.0, 1.0, 0.0, 1.0, 1.0, ...
10    [-0.05, 0.95, 1.0, -0.05, 0.95, 1.0, 0.05, 1.0...
11    [0.05, 1.0, 0.95, 0.05, 1.0, 0.95, -0.05, 0.95...
12    [0.1, 0.95, 0.85, -0.3, 0.7, 1.0, 0.0, 0.85, 0...
13    [0.0, 1.0, 1.0, 0.0, 1.0, 1.0, 0.0, 1.0, 1.0, ...
14    [-0.15, 0.85, 1.0, -0.05, 0.9, 0.95, 0.25, 1.0...
15    [-0.05, 0.95, 1.0, -0.05, 0.95, 1.0, -0.05, 0....
16    [0.1, 1.0, 0.9, 0.1, 1.0, 0.9, 0.0, 0.95, 0.95...
17    [0.05, 0.9, 0.85, -0.05, 0.9, 0.95, 0.25, 

In [10]:
from lightgbm import LGBMRanker
from pyterrier.ltr import apply_learned_model

# Hyperparameters of the LambdaMART ranker, collected in one place.
lambda_mart_params = dict(
    task="train",
    num_leaves=32,
    objective="lambdarank",
    metric="ndcg",
    ndcg_eval_at=[10],
    learning_rate=.1,
    num_iterations=2000,
    importance_type="gain",
)
lambda_mart = LGBMRanker(**lambda_mart_params)

# Append the learned model as a stage after the feature pipeline.
ltr_stage = apply_learned_model(lambda_mart, form="ltr")
ltr = features >> ltr_stage

In [11]:
# Training data: all topics and judgments of the training dataset.
train_topics = dataset_train.get_topics()
train_qrels = dataset_train.get_qrels()

# Validation data: only the first 10 topics of the test dataset.
validation_topics = dataset_test.get_topics()[:10]
validation_qrels = dataset_test.get_qrels()

ltr.fit(train_topics, train_qrels, validation_topics, validation_qrels)

Aggregating query axiom preferences:   0%|          | 0/43 [00:00<?, ?query/s]

Aggregating query axiom preferences:   0%|          | 0/10 [00:00<?, ?query/s]



[1]	valid_0's ndcg@10: 0.57498
[2]	valid_0's ndcg@10: 0.612858
[3]	valid_0's ndcg@10: 0.661236
[4]	valid_0's ndcg@10: 0.632767
[5]	valid_0's ndcg@10: 0.637499
[6]	valid_0's ndcg@10: 0.646993
[7]	valid_0's ndcg@10: 0.662497
[8]	valid_0's ndcg@10: 0.646235
[9]	valid_0's ndcg@10: 0.690048
[10]	valid_0's ndcg@10: 0.672579
[11]	valid_0's ndcg@10: 0.67564
[12]	valid_0's ndcg@10: 0.66026
[13]	valid_0's ndcg@10: 0.649855
[14]	valid_0's ndcg@10: 0.649434
[15]	valid_0's ndcg@10: 0.64581
[16]	valid_0's ndcg@10: 0.688721
[17]	valid_0's ndcg@10: 0.685211
[18]	valid_0's ndcg@10: 0.663255
[19]	valid_0's ndcg@10: 0.660161
[20]	valid_0's ndcg@10: 0.656774
[21]	valid_0's ndcg@10: 0.651487
[22]	valid_0's ndcg@10: 0.623734
[23]	valid_0's ndcg@10: 0.633398
[24]	valid_0's ndcg@10: 0.636775
[25]	valid_0's ndcg@10: 0.661007
[26]	valid_0's ndcg@10: 0.638607
[27]	valid_0's ndcg@10: 0.640335
[28]	valid_0's ndcg@10: 0.639702
[29]	valid_0's ndcg@10: 0.62548
[30]	valid_0's ndcg@10: 0.625979
[31]	valid_0's ndcg@10: 

In [12]:
from pyterrier.pipelines import Experiment
from ir_measures import nDCG, MAP, RR

# Systems under comparison and their display names (same order).
systems = [bm25, ltr ^ bm25]
system_names = ["BM25", "Axiomatic LTR"]

# Skip the first 10 test topics: they were passed to `ltr.fit` as validation
# topics, so only the remaining topics are used for evaluation.
evaluation_topics = dataset_test.get_topics()[10:]
measures = [nDCG @ 10, RR, MAP]

# Run the comparison and order the systems by nDCG@10, best first.
experiment = Experiment(
    systems,
    evaluation_topics,
    dataset_test.get_qrels(),
    measures,
    system_names,
    verbose=True,
).sort_values(by="nDCG@10", ascending=False)

pt.Experiment:   0%|          | 0/2 [00:00<?, ?system/s]

Aggregating query axiom preferences:   0%|          | 0/44 [00:00<?, ?query/s]

In [13]:
# Show the evaluation results (already sorted by nDCG@10, descending).
experiment

Unnamed: 0,name,nDCG@10,RR,AP
1,Axiomatic LTR,0.480283,0.791757,0.352412
0,BM25,0.470551,0.799107,0.345938


In [14]:
from numpy import ndarray

# Reshape the flat importance vector of the trained ranker into a matrix with
# one column per aggregation function.
# NOTE(review): rows presumably correspond to the entries of `axioms`, in
# order; confirm against the feature layout of the transformer.
n_aggregations = len(aggregations)
feature_importance: ndarray = lambda_mart.feature_importances_.reshape(-1, n_aggregations)
feature_importance

array([[2.11072235e+00, 1.45804139e+01, 6.76848184e-01],
       [3.62154975e+00, 4.77092407e+00, 3.47168799e+00],
       [1.42671734e+01, 3.16071471e+01, 1.02099511e+01],
       [8.84088085e+00, 4.66749941e+00, 2.62178812e+00],
       [9.47046489e-01, 5.24948858e-01, 6.16518576e+00],
       [1.82372136e+00, 9.82869358e+00, 4.62238961e+00],
       [6.62425878e+00, 1.39257909e+01, 9.64649354e+00],
       [9.56909546e+00, 1.22768332e+01, 2.11051225e+01],
       [1.17356163e+01, 1.54202767e+01, 6.24066069e+01],
       [0.00000000e+00, 0.00000000e+00, 2.44809505e-01],
       [1.05234939e+00, 6.75013890e+00, 2.93704625e+00],
       [5.96841491e+00, 1.14042244e+01, 6.21733276e+00],
       [2.81036635e+00, 1.91805216e+01, 1.37271448e+01],
       [1.61359699e+01, 2.04250198e+01, 1.27456253e+01],
       [5.87392950e-01, 1.57819854e+01, 1.96993307e+01],
       [6.83497229e+00, 1.94608750e+01, 2.28027103e+01],
       [0.00000000e+00, 0.00000000e+00, 0.00000000e+00],
       [0.00000000e+00, 0.00000

In [15]:
# Column sums: total importance per aggregation function.
feature_importance.sum(axis=0)

array([320.28658604, 428.31172486, 369.14182992])

In [16]:
# Row sums: total importance per row of the matrix, summed over all
# aggregation functions.
feature_importance.sum(axis=1)

array([1.73679844e+01, 1.18641618e+01, 5.60842717e+01, 1.61301684e+01,
       7.63718111e+00, 1.62748046e+01, 3.01965432e+01, 4.29510512e+01,
       8.95625000e+01, 2.44809505e-01, 1.07395345e+01, 2.35899721e+01,
       3.57180328e+01, 4.93066151e+01, 3.60687090e+01, 4.90985575e+01,
       0.00000000e+00, 0.00000000e+00, 1.22962191e+01, 4.69134086e+00,
       8.18784949e+01, 3.36715723e+00, 6.47696989e+01, 1.43321731e+01,
       3.43342023e+01, 6.66354204e+01, 1.68008471e+01, 3.40550952e+01,
       8.82531386e+01, 2.46431726e+01, 0.00000000e+00, 5.74661699e-02,
       0.00000000e+00, 6.93797970e+01, 5.74371600e+01, 1.34854773e+01,
       3.48309140e+01, 3.65746911e+00])