In [1]:
# Reload edited local modules (e.g. `util`) automatically before each cell runs,
# and restrict CUDA to GPUs 0 and 1 (set here, before torch is imported further down).
%load_ext autoreload
%autoreload 2
%env CUDA_VISIBLE_DEVICES=0,1

env: CUDA_VISIBLE_DEVICES=0,1


In [2]:
# import joblib
import json
import os
import _pickle as cPickle
from tqdm.auto import tqdm

In [3]:
retrieval_model = "ance-bf"  # prefix of the run file (run/<name>.<split>.tsv) and the matrix file (matrix/<name>.pt)

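## Filter passages

Keep only the passages that are gold for some query (from the qrels files) or ranked near the top by one of the retrieval runs, and write their ids to a sorted pid list.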

In [4]:
split = "train"  # selects run/*.<split>.tsv and, via `'train' in split`, which gold pids are kept

In [5]:
# Query ids seen in any qrels file; the train and dev qrels cells below both add to it.
q_ids = set()

In [6]:
train_gold_pids = set()
train_gold_path = 'data/msmarco-passage/pids.train.gold.txt'

# Each qrels line has 4 tab-separated fields; only the 1st (query id) and
# 3rd (passage id) are used.
with open('data/msmarco-passage/qrels.train.txt') as f:
    for line in tqdm(f.readlines()):
        q_id, _, p_id, _ = line.rstrip().split('\t')
        q_ids.add(q_id)
        train_gold_pids.add(p_id)
print(f'{len(q_ids):,}')  # 502,939
print(f'{len(train_gold_pids):,}')  # 516,472

# Written once, sorted numerically; an existing file is reused as-is.
if not os.path.exists(train_gold_path):
    print(f"writing to {train_gold_path}")
    with open(train_gold_path, 'w') as f:
        f.writelines(f'{p_id}\n' for p_id in sorted(train_gold_pids, key=int))

HBox(children=(FloatProgress(value=0.0, max=532761.0), HTML(value='')))


502,939
516,472


In [7]:
dev_gold_pids = set()
dev_gold_path = 'data/msmarco-passage/pids.dev.small.gold.txt'

with open('data/msmarco-passage/qrels.dev.small.txt') as qrels_file:
    rows = qrels_file.readlines()
for row in tqdm(rows):
    q_id, _, p_id, _ = row.rstrip().split('\t')
    q_ids.add(q_id)  # extends the train query ids collected in the previous cell
    dev_gold_pids.add(p_id)
print(f'{len(q_ids):,}')  # 509,919 = 502,939 + 6,980
print(f'{len(dev_gold_pids):,}')  # 7,433

# Write the numerically sorted dev gold pid list only if it is not cached yet.
if not os.path.exists(dev_gold_path):
    ordered_pids = sorted(dev_gold_pids, key=int)
    with open(dev_gold_path, 'w') as out:
        out.write(''.join(f'{pid}\n' for pid in ordered_pids))

HBox(children=(FloatProgress(value=0.0, max=7437.0), HTML(value='')))


509,919
7,433


In [8]:
# Gold passages of both splits, written once as a numerically sorted list.
gold_pids = train_gold_pids.union(dev_gold_pids)
print(f'{len(gold_pids):,}')  # 523,598

gold_path = 'data/msmarco-passage/pids.gold.txt'
if not os.path.exists(gold_path):
    with open(gold_path, 'w') as out:
        out.writelines(f'{pid}\n' for pid in sorted(gold_pids, key=int))

523,598


In [9]:
import glob

# A passage is kept if it is ranked within the top `top_k_for_this` of the main
# run (`retrieval_model`), within the top `top_k_for_other` of any other run for
# this split, or is a gold passage of this split.
top_k_for_this = 20
top_k_for_other = 5
top_pids = set()


def _add_top_pids(run_path, top_k, query_ids, pid_set):
    """Add to ``pid_set`` every passage ranked <= ``top_k`` in a run file.

    Each run line is ``q_id<TAB>p_id<TAB>rank``; lines whose ``q_id`` is not
    in ``query_ids`` are ignored.
    """
    with open(run_path) as f:
        for line in tqdm(f.readlines()):
            q_id, p_id, rank = line.rstrip().split('\t')
            # `and` short-circuits, so rank is only parsed for known queries.
            if q_id in query_ids and int(rank) <= top_k:
                pid_set.add(p_id)


_add_top_pids(f'data/msmarco-passage/run/{retrieval_model}.{split}.tsv', top_k_for_this, q_ids, top_pids)
print(f'{retrieval_model} {len(top_pids):,}')

for file_name in sorted(glob.glob(f'data/msmarco-passage/run/*.{split}.tsv')):
    # basename copes with whichever path separator glob returns on this OS.
    rm = os.path.basename(file_name).split('.')[0]
    if rm == retrieval_model:
        continue  # main run was already added above with the larger cut-off
    _add_top_pids(file_name, top_k_for_other, q_ids, top_pids)
    print(f'{rm} {len(top_pids):,}')

top_pids.update(train_gold_pids if 'train' in split else dev_gold_pids)
print(f'{len(top_pids):,}')
# NOTE(review): the size table below appears to come from an earlier
# configuration; with top_k_for_this=20 / top_k_for_other=5 the recorded
# output of this cell is 5,284,422.
# gold: 516,472
#    1: 1,330,614 6.2GB
#    2: 1,999,058 9.3GB
#    3: 2,574,081 11.6GB
#    5: 3,526,992 16.1GB
#   10: 3,863,152 23.2(27.2)GB
#   20: 6,882,374 32.0GB
#  all: 8,841,823 40.5GB

# The file name records only top_k_for_this; top_k_for_other and the set of
# runs found by glob are not part of it.
with open(f'data/msmarco-passage/pids.{split}.{retrieval_model}.top{top_k_for_this}.txt', 'w') as f:
    f.writelines(f'{p_id}\n' for p_id in sorted(top_pids, key=int))

HBox(children=(FloatProgress(value=0.0, max=40436550.0), HTML(value='')))


ance-bf 4,251,534


HBox(children=(FloatProgress(value=0.0, max=40433869.0), HTML(value='')))


bm25tuned 4,765,824


HBox(children=(FloatProgress(value=0.0, max=40434110.0), HTML(value='')))


expanded-bm25tuned 5,129,711


HBox(children=(FloatProgress(value=0.0, max=40436550.0), HTML(value='')))


unicoil-b8 5,248,980
5,284,422


## Load full index

## Get full P matrix

In [10]:
import torch

# Load onto the CPU regardless of the device the tensor was saved from: the full
# matrix is large, and rows are selected on the host before anything is moved to a GPU.
# NOTE: torch.load unpickles its input; only load files produced by this pipeline.
matrix_full = torch.load(f'data/msmarco-passage/matrix/{retrieval_model}.pt', map_location='cpu')

In [11]:
# (rows = passages indexed by all_pid2idx, cols = vector width)
matrix_full.size()

torch.Size([8841823, 768])

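## Get partial P matrix

Select the rows of the full passage matrix that belong to the filtered pid list, and save them as a smaller matrix.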

In [12]:
# Must name a pid list written by the filtering cell above (pids.<split>.<model>.<selection>.txt).
split, selection = "train", "top20"

In [13]:
from util import load_pids

# Element [0] of load_pids is used as the pid list, element [1] as the
# pid -> row-index map of matrix_full (consumed by index_select below).
selection_pids_path = f'data/msmarco-passage/pids.{split}.{retrieval_model}.{selection}.txt'
pids = load_pids(selection_pids_path)[0]
all_pid2idx = load_pids('data/msmarco-passage/pids.all.txt')[1]
print(len(pids), len(all_pid2idx))

5284422 8841823


In [14]:
# Row of matrix_full for each selected pid. The dtype is explicit so index_select
# also works when `pids` is empty (torch.tensor([]) would default to float32).
indices = torch.tensor([all_pid2idx[pid] for pid in pids], dtype=torch.long)

In [15]:
matrix_selection = matrix_full.index_select(dim=0, index=indices)
print(matrix_selection.shape)  # torch.Size([5284422, 768]) for train / top20
# One row per selected pid, with the same vector width as the full matrix.
assert matrix_selection.shape == (len(pids), matrix_full.size(1))

torch.Size([5284422, 768])


In [16]:
# Persist the selected rows under a name mirroring the pid list they came from.
torch.save(
    matrix_selection,
    f'data/msmarco-passage/matrix/{split}.{retrieval_model}.{selection}.pt',
)

## Retrieval

Time the dense scoring step `S = P @ Q` on a GPU, using random vectors as stand-in queries.

In [45]:
import time

In [46]:
# Passage matrix, query matrix and score matrix; (re)allocated on the GPU below.
P = Q = S = None

In [47]:
device = "cuda:1"  # second of the GPUs exposed via CUDA_VISIBLE_DEVICES=0,1

In [48]:
# Drop the previous device copy before uploading, so old and new P never coexist.
# NOTE(review): if an S computed from the old P with a grad-requiring Q still
# exists, its autograd graph may keep the old P alive until S is cleared.
if P is not None:
    P = None
P = matrix_selection.to(device)  # swap in matrix_full to score against every passage

In [49]:
# Release memory held by PyTorch's CUDA caching allocator that no tensor is using anymore.
torch.cuda.empty_cache()

In [50]:
# Random stand-in query matrix: one column per query, rows match P's vector width.
# requires_grad=True means torch.mm below records an autograd graph for S.
num_queries = 96
query_seed = 0
if Q is not None:
    Q = None  # drop the old tensor before allocating the new one
query_rng = torch.Generator(device=device).manual_seed(query_seed)
Q = torch.rand((P.size(1), num_queries), generator=query_rng, dtype=torch.float,
               device=device, requires_grad=True)

In [51]:
%%time
# top10: 3863152  12.5-14.0GB  384 392 416 423 527 µs 354 ms
# all  : 8841823  27.1-30.4GB  380 450 451 456 µs
if S is not None:
    del S
    S = None
start = time.time()
S = torch.mm(P, Q)
print(time.time() - start)
print(S.shape)

0.8207974433898926
torch.Size([3859912, 96])
CPU times: user 252 ms, sys: 153 ms, total: 405 ms
Wall time: 821 ms
