In [9]:
%load_ext autoreload
%autoreload 2

import torch
import pyterrier as pt
import pandas as pd

from src.neural_ranker.ranker import NeuralRanker
from src.neural_ranker.produce_rankings import IRDataset, Processor
from src.llm.llm import LLM_zeroshot, LLM_query_exp
from domain_adaptation import self_training_domain_adaptation
from main import rank_with_base_model, contrastive_train_neural_ranker, rank_with_contrastive_model, pseudo_labels_fine_tune, rank_with_pseudo_labels_model
from eval import Model, evaluate_rankings

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [2]:
# Select the compute device: a CUDA GPU when one is available, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

if torch.cuda.is_available():
    bytes_per_gib = 1024**3
    gpu_mem_total = torch.cuda.get_device_properties(0).total_memory / bytes_per_gib
    # "Reserved" = memory held by PyTorch's caching allocator in *this* process;
    # "allocated" = the part of it currently backing live tensors.
    gpu_mem_reserved = torch.cuda.memory_reserved(0) / bytes_per_gib
    gpu_mem_allocated = torch.cuda.memory_allocated(0) / bytes_per_gib
    # `total - reserved` overstates free memory because it ignores other processes
    # and the CUDA context itself; mem_get_info asks the driver for the device-wide
    # free byte count (first element of the returned (free, total) pair).
    gpu_mem_free = torch.cuda.mem_get_info(0)[0] / bytes_per_gib
    print(f"GPU Memory: Total: {gpu_mem_total:.2f}GB, Reserved: {gpu_mem_reserved:.2f}GB, "
          f"Allocated: {gpu_mem_allocated:.2f}GB, Free: {gpu_mem_free:.2f}GB")

Using device: cuda
GPU Memory: Total: 15.99GB, Reserved: 0.00GB, Allocated: 0.00GB, Free: 15.99GB


In [3]:
# ir_datasets identifiers in PyTerrier's "irds:<id>" form.
# NOTE(review): `bair` presumably stands for BEIR (the id is beir/trec-covid); the name
# is left as-is because other cells may refer to it.
bair = 'irds:beir/trec-covid'
cord19 = 'irds:cord19/trec-covid'


def _load_all_documents(dataset_id):
    """Build an IRDataset for `dataset_id` with no cap on the number of documents."""
    return IRDataset(dataset_id, max_docs=None)


# NOTE(review): no visible cell reads bair_dataset (contrastive training reloads
# beir/trec-covid on its own, per its log) — drop it if no other cell needs it.
bair_dataset = _load_all_documents(bair)
cord19_dataset = _load_all_documents(cord19)

Loading up to None documents from irds:beir/trec-covid...


beir/trec-covid documents: 100%|██████████| 171332/171332 [00:00<00:00, 302268.09it/s]
Loading Documents: 171332it [00:00, 302267.07it/s]


Loading up to None documents from irds:cord19/trec-covid...


cord19/trec-covid documents: 100%|██████████| 192509/192509 [00:00<00:00, 310197.61it/s]
Loading Documents: 192509it [00:00, 310197.61it/s]


In [4]:
# Rank the CORD-19 TREC-COVID queries with the base model (evaluated below as Model.BASE).
# The recorded log shows cached document embeddings being reused when present.
# NOTE(review): PyTerrier warns that the topics expose several query fields
# (title/description/narrative) — confirm which field the ranker actually uses.
rank_with_base_model(cord19_dataset, device)

Checking for existing embeddings...
Loaded 192509 existing document embeddings.
Document encoding complete. Now ranking queries...
There are multiple query fields available: ('title', 'description', 'narrative'). To use with pyterrier, provide variant or modify dataframe to add query column.
Loaded 50 queries
Ranking 50 queries against 192509 documents...


Ranking Queries: 100%|██████████| 50/50 [00:27<00:00,  1.85it/s]


In [6]:
# Contrastively fine-tune the neural ranker. Per its log it loads the beir/trec-covid
# documents itself and saves the weights to models/contrastive_model.pt.
# NOTE(review): the recorded run stopped after 318 of 21417 steps because the loss fell
# below the early-stop threshold (0.000999). Combined with the near-zero contrastive
# evaluation further down, this looks like a degenerate/collapsed objective — review the
# loss threshold and augmentations before trusting this checkpoint.
contrastive_train_neural_ranker(device=device)

Loading dataset...
Loading up to None documents from irds:beir/trec-covid...


beir/trec-covid documents: 100%|██████████| 171332/171332 [00:00<00:00, 318056.10it/s]
Loading Documents: 171332it [00:00, 318647.87it/s]


Processing documents...


Processing documents: 100%|██████████| 171332/171332 [00:00<00:00, 1146138.45it/s]


Loaded 171332 documents for contrastive training.

Initializing augmentor...
Creating contrastive dataset...

Creating dataloader...

Dataloader configuration: batch_size=8, num_workers=0

Initializing model...
Model has 109,482,240 trainable parameters

Initializing trainer...

Training for up to 3 epochs with early stopping (patience=2, min_delta=0.001)...


Epoch 1/3:   1%|▏         | 318/21417 [01:29<1:39:02,  3.55it/s, loss=0.0010]

Loss 0.000999 is below threshold. Stopping early.

Saving model to models/contrastive_model.pt...





Model saved successfully!


In [13]:
# Re-encode every CORD-19 document with the contrastive model, then rank the queries.
# NOTE(review): the torch.load echo in the recorded output shows this call loading
# "models/domain_adapted_model.pt", while the training cell saved
# "models/contrastive_model.pt" — check the checkpoint path inside
# main.rank_with_contrastive_model; these rankings may not come from the model trained above.
rank_with_contrastive_model(cord19_dataset, device)

  ranker.load_state_dict(torch.load("models/domain_adapted_model.pt", map_location=mydevice))


Checking for existing embeddings...
No existing embeddings found or file corrupted. Starting from scratch.

Processing chunk 1/39 (documents 0 to 4999)...


Encoding Chunk 1: 100%|██████████| 79/79 [00:24<00:00,  3.20it/s]



Processing chunk 2/39 (documents 5000 to 9999)...


Encoding Chunk 2: 100%|██████████| 79/79 [00:23<00:00,  3.39it/s]



Processing chunk 3/39 (documents 10000 to 14999)...


Encoding Chunk 3: 100%|██████████| 79/79 [00:23<00:00,  3.43it/s]



Processing chunk 4/39 (documents 15000 to 19999)...


Encoding Chunk 4: 100%|██████████| 79/79 [00:20<00:00,  3.79it/s]



Processing chunk 5/39 (documents 20000 to 24999)...


Encoding Chunk 5: 100%|██████████| 79/79 [00:22<00:00,  3.59it/s]



Processing chunk 6/39 (documents 25000 to 29999)...


Encoding Chunk 6: 100%|██████████| 79/79 [00:23<00:00,  3.35it/s]



Processing chunk 7/39 (documents 30000 to 34999)...


Encoding Chunk 7: 100%|██████████| 79/79 [00:24<00:00,  3.24it/s]



Processing chunk 8/39 (documents 35000 to 39999)...


Encoding Chunk 8: 100%|██████████| 79/79 [00:24<00:00,  3.24it/s]



Processing chunk 9/39 (documents 40000 to 44999)...


Encoding Chunk 9: 100%|██████████| 79/79 [00:24<00:00,  3.24it/s]



Processing chunk 10/39 (documents 45000 to 49999)...


Encoding Chunk 10: 100%|██████████| 79/79 [00:24<00:00,  3.23it/s]



Processing chunk 11/39 (documents 50000 to 54999)...


Encoding Chunk 11: 100%|██████████| 79/79 [00:24<00:00,  3.24it/s]



Processing chunk 12/39 (documents 55000 to 59999)...


Encoding Chunk 12: 100%|██████████| 79/79 [00:24<00:00,  3.24it/s]



Processing chunk 13/39 (documents 60000 to 64999)...


Encoding Chunk 13: 100%|██████████| 79/79 [00:24<00:00,  3.24it/s]



Processing chunk 14/39 (documents 65000 to 69999)...


Encoding Chunk 14: 100%|██████████| 79/79 [00:24<00:00,  3.22it/s]



Processing chunk 15/39 (documents 70000 to 74999)...


Encoding Chunk 15: 100%|██████████| 79/79 [00:24<00:00,  3.25it/s]



Processing chunk 16/39 (documents 75000 to 79999)...


Encoding Chunk 16: 100%|██████████| 79/79 [00:24<00:00,  3.25it/s]



Processing chunk 17/39 (documents 80000 to 84999)...


Encoding Chunk 17: 100%|██████████| 79/79 [00:24<00:00,  3.26it/s]



Processing chunk 18/39 (documents 85000 to 89999)...


Encoding Chunk 18: 100%|██████████| 79/79 [00:24<00:00,  3.26it/s]



Processing chunk 19/39 (documents 90000 to 94999)...


Encoding Chunk 19: 100%|██████████| 79/79 [00:23<00:00,  3.33it/s]



Processing chunk 20/39 (documents 95000 to 99999)...


Encoding Chunk 20: 100%|██████████| 79/79 [00:23<00:00,  3.32it/s]



Processing chunk 21/39 (documents 100000 to 104999)...


Encoding Chunk 21: 100%|██████████| 79/79 [00:23<00:00,  3.29it/s]



Processing chunk 22/39 (documents 105000 to 109999)...


Encoding Chunk 22: 100%|██████████| 79/79 [00:23<00:00,  3.30it/s]



Processing chunk 23/39 (documents 110000 to 114999)...


Encoding Chunk 23: 100%|██████████| 79/79 [00:23<00:00,  3.32it/s]



Processing chunk 24/39 (documents 115000 to 119999)...


Encoding Chunk 24: 100%|██████████| 79/79 [00:24<00:00,  3.29it/s]



Processing chunk 25/39 (documents 120000 to 124999)...


Encoding Chunk 25: 100%|██████████| 79/79 [00:23<00:00,  3.33it/s]



Processing chunk 26/39 (documents 125000 to 129999)...


Encoding Chunk 26: 100%|██████████| 79/79 [00:23<00:00,  3.31it/s]



Processing chunk 27/39 (documents 130000 to 134999)...


Encoding Chunk 27: 100%|██████████| 79/79 [00:24<00:00,  3.29it/s]



Processing chunk 28/39 (documents 135000 to 139999)...


Encoding Chunk 28: 100%|██████████| 79/79 [00:24<00:00,  3.25it/s]



Processing chunk 29/39 (documents 140000 to 144999)...


Encoding Chunk 29: 100%|██████████| 79/79 [00:24<00:00,  3.26it/s]



Processing chunk 30/39 (documents 145000 to 149999)...


Encoding Chunk 30: 100%|██████████| 79/79 [00:24<00:00,  3.25it/s]



Processing chunk 31/39 (documents 150000 to 154999)...


Encoding Chunk 31: 100%|██████████| 79/79 [00:24<00:00,  3.25it/s]



Processing chunk 32/39 (documents 155000 to 159999)...


Encoding Chunk 32: 100%|██████████| 79/79 [00:24<00:00,  3.26it/s]



Processing chunk 33/39 (documents 160000 to 164999)...


Encoding Chunk 33: 100%|██████████| 79/79 [00:25<00:00,  3.15it/s]



Processing chunk 34/39 (documents 165000 to 169999)...


Encoding Chunk 34: 100%|██████████| 79/79 [00:24<00:00,  3.22it/s]



Processing chunk 35/39 (documents 170000 to 174999)...


Encoding Chunk 35: 100%|██████████| 79/79 [00:24<00:00,  3.18it/s]



Processing chunk 36/39 (documents 175000 to 179999)...


Encoding Chunk 36: 100%|██████████| 79/79 [00:25<00:00,  3.10it/s]



Processing chunk 37/39 (documents 180000 to 184999)...


Encoding Chunk 37: 100%|██████████| 79/79 [00:25<00:00,  3.09it/s]



Processing chunk 38/39 (documents 185000 to 189999)...


Encoding Chunk 38: 100%|██████████| 79/79 [00:24<00:00,  3.21it/s]



Processing chunk 39/39 (documents 190000 to 192508)...


Encoding Chunk 39: 100%|██████████| 40/40 [00:12<00:00,  3.25it/s]


Completed processing 192509 documents.
Document encoding complete. Now ranking queries...
There are multiple query fields available: ('title', 'description', 'narrative'). To use with pyterrier, provide variant or modify dataframe to add query column.
Loaded 50 queries
Ranking 50 queries against 192509 documents...


Ranking Queries: 100%|██████████| 50/50 [00:32<00:00,  1.53it/s]


In [11]:
# Score the base model's CORD-19 run (the log shows results also written under
# evaluation_results/base). Keep the metrics dict around for later comparison.
base_metrics = evaluate_rankings(model=Model.BASE, dataset_name=cord19)
base_metrics

Evaluating rankings with metrics@10...
Loaded 9625450 ranking entries

=== Evaluation Results ===
map: 0.1859
ndcg: 0.6866
ndcg_cut_10: 0.6570
P_10: 0.6880
recall_10: 0.0165
recip_rank: 0.8403

Evaluation results saved to evaluation_results/base\evaluation_results.csv

Per-query evaluation results:
Query 1:
    map: 0.1281
    recip_rank: 1.0000
    P_10: 0.5000
    recall_10: 0.0072
    ndcg: 0.6776
    ndcg_cut_10: 0.4959
-------------------------

Query 2:
    map: 0.0585
    recip_rank: 0.5000
    P_10: 0.6000
    recall_10: 0.0179
    ndcg: 0.5997
    ndcg_cut_10: 0.5527
-------------------------

Query 3:
    map: 0.1977
    recip_rank: 1.0000
    P_10: 0.9000
    recall_10: 0.0138
    ndcg: 0.7465
    ndcg_cut_10: 0.9364
-------------------------

Query 4:
    map: 0.0535
    recip_rank: 0.1000
    P_10: 0.1000
    recall_10: 0.0018
    ndcg: 0.5830
    ndcg_cut_10: 0.0636
-------------------------

Query 5:
    map: 0.1349
    recip_rank: 0.3333
    P_10: 0.4000
    recall_10: 

{'map': 0.18585182567772524,
 'ndcg': 0.6865586940416332,
 'ndcg_cut_10': 0.6569951008404138,
 'P_10': 0.6879999999999998,
 'recall_10': 0.01649486349187759,
 'recip_rank': 0.840343137254902}

In [14]:
# Score the contrastive model's CORD-19 run and keep the metrics dict for comparison.
# NOTE(review): the recorded run has P_10 = ndcg_cut_10 = 0 and MAP 0.003 (vs 0.186 for
# the base model) — see the checkpoint-path and early-stopping notes on the cells above.
contrastive_metrics = evaluate_rankings(model=Model.CONTRASTIVE, dataset_name=cord19)
contrastive_metrics

Evaluating rankings with metrics@10...
Loaded 9625450 ranking entries

=== Evaluation Results ===
map: 0.0033
ndcg: 0.4235
ndcg_cut_10: 0.0000
P_10: 0.0000
recall_10: 0.0000
recip_rank: 0.0026

Evaluation results saved to evaluation_results/contrastive\evaluation_results.csv

Per-query evaluation results:
Query 1:
    map: 0.0046
    recip_rank: 0.0021
    P_10: 0.0000
    recall_10: 0.0000
    ndcg: 0.4614
    ndcg_cut_10: 0.0000
-------------------------

Query 2:
    map: 0.0060
    recip_rank: 0.0078
    P_10: 0.0000
    recall_10: 0.0000
    ndcg: 0.4451
    ndcg_cut_10: 0.0000
-------------------------

Query 3:
    map: 0.0031
    recip_rank: 0.0004
    P_10: 0.0000
    recall_10: 0.0000
    ndcg: 0.4277
    ndcg_cut_10: 0.0000
-------------------------

Query 4:
    map: 0.0031
    recip_rank: 0.0022
    P_10: 0.0000
    recall_10: 0.0000
    ndcg: 0.4253
    ndcg_cut_10: 0.0000
-------------------------

Query 5:
    map: 0.0036
    recip_rank: 0.0002
    P_10: 0.0000
    reca

{'map': 0.003338411312322189,
 'ndcg': 0.42352346833935955,
 'ndcg_cut_10': 0.0,
 'P_10': 0.0,
 'recall_10': 0.0,
 'recip_rank': 0.002590893580737511}