In [1]:
# Automatically reload edited local modules (e.g. `util`, imported below) before each cell runs.
%load_ext autoreload
%autoreload 2
# Restrict which GPUs are visible. This only takes effect if set before CUDA is initialized
# (torch is imported later, in the Retrieval section).
%env CUDA_VISIBLE_DEVICES=3,1

env: CUDA_VISIBLE_DEVICES=3,1


In [2]:
import json
import os
import _pickle as cPickle
from tqdm.auto import tqdm

## Filter passages

In [3]:
# Query ids seen in any qrels file; filled by the train and dev qrels cells below and
# later used to keep only run lines whose query is known.
q_ids: set = set()

In [4]:
train_gold_pids = set()

# Each qrels row is tab-separated: query id, (ignored), passage id, (ignored).
# Record every query id globally and every judged passage id for the train split.
with open('data/msmarco-passage/qrels.train.txt') as qrels_file:
    for row in tqdm(qrels_file.readlines()):
        qid, _, pid, _ = row.rstrip().split('\t')
        q_ids.add(qid)
        train_gold_pids.add(pid)
print(f'{len(q_ids):,}')  # 502,939
print(f'{len(train_gold_pids):,}')  # 516,472

# Persist the gold passage ids (numerically sorted, one per line) unless already written.
train_gold_path = 'data/msmarco-passage/pids.train.gold.txt'
if not os.path.exists(train_gold_path):
    with open(train_gold_path, 'w') as out_file:
        out_file.writelines(f'{pid}\n' for pid in sorted(train_gold_pids, key=int))

HBox(children=(FloatProgress(value=0.0, max=532761.0), HTML(value='')))


502,939
516,472


In [5]:
dev_gold_pids = set()

# Same tab-separated qrels layout as the train file; dev query ids are added to the
# shared `q_ids` set built by the train cell.
with open('data/msmarco-passage/qrels.dev.small.txt') as qrels_file:
    for row in tqdm(qrels_file.readlines()):
        fields = row.rstrip().split('\t')
        qid, _, pid, _ = fields
        q_ids.add(qid)
        dev_gold_pids.add(pid)
print(f'{len(q_ids):,}')  # 509,919 = 502,939 + 6,980
print(f'{len(dev_gold_pids):,}')  # 7,433

# Write the dev gold passage ids once, sorted numerically.
dev_gold_path = 'data/msmarco-passage/pids.dev.small.gold.txt'
if not os.path.exists(dev_gold_path):
    with open(dev_gold_path, 'w') as out_file:
        out_file.writelines(f'{pid}\n' for pid in sorted(dev_gold_pids, key=int))

HBox(children=(FloatProgress(value=0.0, max=7437.0), HTML(value='')))


509,919
7,433


In [6]:
# All judged passages across the train and dev splits.
gold_pids = train_gold_pids.union(dev_gold_pids)
print(f'{len(gold_pids):,}')  # 523,598

# Write the combined gold passage ids once, sorted numerically.
gold_path = 'data/msmarco-passage/pids.gold.txt'
if not os.path.exists(gold_path):
    with open(gold_path, 'w') as out_file:
        out_file.writelines(f'{pid}\n' for pid in sorted(gold_pids, key=int))

523,598


In [9]:
import glob

# Keep passages ranked within the top `top_k` for any known query by any retrieval run,
# plus every gold passage.
top_k = 5
top_pids = set()

for file_name in sorted(glob.glob('data/msmarco-passage/run/*.train.tsv')):
    # e.g. ".../bm25tuned.train.tsv" -> "bm25tuned" (basename is portable, unlike split('/')).
    retrieval_method = os.path.basename(file_name).split('.')[0]
    with open(file_name) as f:
        # Stream lines instead of f.readlines(): each run file has ~40M lines, and
        # materializing them all as a list wastes several GB of memory.
        for line in tqdm(f, desc=retrieval_method, unit=' lines'):
            # Run rows are tab-separated: query id, passage id, rank.
            q_id, p_id, rank = line.rstrip().split('\t')
            if q_id not in q_ids:
                continue
            rank = int(rank)
            if rank <= top_k:
                top_pids.add(p_id)
    print(f'{retrieval_method} {len(top_pids):,}')

top_pids.update(gold_pids)
print(f'{len(top_pids):,}')
# gold: 516,472
#    1: 1,330,614 6.2GB
#    2: 1,999,058 9.3GB
#    3: 2,574,081 11.6GB
#    5: 3,526,992 16.1GB    3,531,017
#   10: 5,172,700 23.2(27.2)GB
#   20: 6,882,374 32.0GB
#  all: 8,841,823 40.5GB

with open(f'data/msmarco-passage/pids.train-top{top_k}.txt', 'w') as f:
    f.writelines(f'{p_id}\n' for p_id in sorted(top_pids, key=int))

HBox(children=(FloatProgress(value=0.0, max=40433869.0), HTML(value='')))


bm25tuned 1,629,659


HBox(children=(FloatProgress(value=0.0, max=40434110.0), HTML(value='')))


expanded-bm25tuned 2,806,293


HBox(children=(FloatProgress(value=0.0, max=40436550.0), HTML(value='')))


unicoil-b8 3,409,557
3,531,017


## Build index

In [10]:
%%time
# CPU times: user 13min 23s, sys: 56.1 s, total: 14min 19s
# Wall time: 14min 17s
# 40.5GB
import glob
import gzip

inverted_index = {}

max_weight, min_weight = float('-inf'), float('inf')

for file_name in tqdm(sorted(glob.glob('data/msmarco-passage/vec/unicoil-b8/*.jsonl.gz'))):
    with gzip.open(file_name, 'r') as f:
        for line in f:
            p = json.loads(line)
            p_id = p['id']
            if p_id not in top_pids:
                continue
            for term, weight in p['vector'].items():
                if weight < min_weight:
                    min_weight = weight
                elif weight > max_weight:
                    max_weight = weight
                if weight <= 0:
                    assert weight == 0, f"'{term}' = {weight}"
                    continue
                if term == '[SEP]':
                    continue
                if term not in inverted_index:
                    inverted_index[term] = []
                inverted_index[term].append((p_id, weight))
            # del p
print(len(inverted_index))  # 27677
print(max_weight, min_weight)  # 270 0

HBox(children=(FloatProgress(value=0.0), HTML(value='')))


27668
269 0
CPU times: user 9min 12s, sys: 34.6 s, total: 9min 47s
Wall time: 9min 47s


In [11]:
%%time
# CPU times: user 6min 1s, sys: 1min 41s, total: 7min 42s
# Wall time: 8min 5s
with open(f'data/msmarco-passage/index/unicoil-b8.top{top_k}.pkl', 'wb') as f:
    cPickle.dump(inverted_index, f)

CPU times: user 2min 22s, sys: 54.7 s, total: 3min 17s
Wall time: 3min 16s


## Load index

In [12]:
# Which passage subset to load; must match a previously written `pids.train-<section>.txt`
# and `unicoil-b8.<section>.pkl`.
section: str = 'top5'

In [13]:
from util import load_pids

# Passage ids for the chosen subset and their id -> row-index mapping.
pids_path = f'data/msmarco-passage/pids.train-{section}.txt'
pids, pid2idx = load_pids(pids_path)
print(*(len(container) for container in (pids, pid2idx)))

3531017 3531017


In [5]:
%%time
# CPU times: 2min 10s, sys: 27.3 s, total: 2min 37s
# Wall time: 2min 37s
# all: 40.5GB (48GB)
with open(f'data/msmarco-passage/index/unicoil-b8.{section}.pkl', 'rb') as f:
    inverted_index = cPickle.load(f)
print(len(inverted_index))

27668
CPU times: user 53.1 s, sys: 10.6 s, total: 1min 3s
Wall time: 1min 3s


## Retrieval

In [14]:
import time
import torch
from transformers import AutoTokenizer
# Its vocabulary supplies the column index for each index term in the passage-term matrix below.
tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')

In [15]:
# Matrix shape: one row per loaded passage, one column per tokenizer vocabulary entry.
n_p, n_t = len(pids), len(tokenizer.vocab)

In [16]:
# top1: 1.9GB, top5: 4.3GB, top10: 7.6GB
# Flatten the inverted index into COO triplets: row = passage index, col = vocab id, value = weight.
# `tokenizer.vocab` is read once here. For the fast tokenizer AutoTokenizer returns by default,
# it is a property that rebuilds the whole vocab dict via get_vocab() on every access.
term2idx = tokenizer.vocab
indices = ([], [])
values = []
for term, postings in inverted_index.items():
    t_idx = term2idx[term]
    for p_id, weight in postings:
        p_idx = pid2idx[p_id]
        indices[0].append(p_idx)
        indices[1].append(t_idx)
        values.append(weight)
print(f'{len(values):,}')
# Density of the passage x vocabulary matrix, in percent.
print(f'{len(values) / (n_p * n_t) * 100: .3f}%')  # 0.213%

226,058,431
 0.210%


In [10]:
# Rebind possibly-large intermediates to None so that any objects they held from a
# previous run (presumably big tensors) can be reclaimed before rebuilding.
csr_Q = dense_Q = dense_Q_ = dense_P = dense_P_ = S = None

In [19]:
# top1: 1.6GB, top5: 4.3GB, top10: 6.4GB
# Sparse (n_p x n_t) passage-term weight matrix on the CPU, built from the COO lists above.
coo_P_cpu = torch.sparse_coo_tensor(
    indices,
    values,
    size=(n_p, n_t),
    dtype=torch.float32,
    device='cpu',
    requires_grad=False,
)
coo_P_path = f'data/msmarco-passage/matrix/unicoil-b8.coo-{section}.pt'
torch.save(coo_P_cpu, coo_P_path)
# A CSR copy can be derived later with: csr_P_cpu = coo_P_cpu.to_sparse_csr()