In [1]:
!pip install torch transformers faiss-gpu tensorboard boto3

Collecting faiss-gpu
  Downloading faiss_gpu-1.7.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (1.4 kB)
Collecting botocore<1.30.0,>=1.29.100 (from boto3)
  Downloading botocore-1.29.165-py3-none-any.whl.metadata (5.9 kB)
Downloading faiss_gpu-1.7.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (85.5 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m85.5/85.5 MB[0m [31m20.3 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading botocore-1.29.165-py3-none-any.whl (11.0 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m11.0/11.0 MB[0m [31m92.9 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: faiss-gpu, botocore
  Attempting uninstall: botocore
    Found existing installation: botocore 1.34.69
    Uninstalling botocore-1.34.69:
      Successfully uninstalled botocore-1.34.69
[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This beh

In [2]:
import sys

from torch._C import dtype
sys.path += ['./']
import torch
from torch import nn
import transformers
if int(transformers.__version__[0]) <=3:
    from transformers.modeling_roberta import RobertaPreTrainedModel
else:
    from transformers.models.roberta.modeling_roberta import RobertaPreTrainedModel
from transformers import RobertaModel
import torch.nn.functional as F
from torch.cuda.amp import autocast

class EmbeddingMixin:
    """
    Mixin with helpers shared by the embedding models in this notebook.

    It does not inherit from any Hugging Face class itself: concrete models combine
    it with one (e.g. ``RobertaDot(BaseModelDot, RobertaPreTrainedModel)``) to get
    ``from_pretrained``, and each model defines its own backbone and forward.
    """
    def __init__(self, model_argobj):
        """
        Args:
            model_argobj: None, or an object exposing a boolean ``use_mean``.
                ``use_mean=True`` selects masked mean pooling in
                ``masked_mean_or_first``; otherwise the first token is used.
        """
        if model_argobj is None:
            self.use_mean = False
        else:
            self.use_mean = model_argobj.use_mean
        print("Using mean:", self.use_mean)

    def _init_weights(self, module):
        """ Initialize the weights of Linear / Embedding / Conv1d modules from N(0, 0.02). """
        if isinstance(module, (nn.Linear, nn.Embedding, nn.Conv1d)):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=0.02)

    def masked_mean(self, t, mask):
        """
        Mean of ``t`` over dim 1, counting only positions where ``mask`` is non-zero.

        Assumes ``t`` is (batch, seq, dim) and ``mask`` is (batch, seq) — TODO confirm
        for any new caller. Rows whose mask sums to zero return a zero vector rather
        than NaN (their numerator is already zero, so only the 0/0 is avoided).
        """
        s = torch.sum(t * mask.unsqueeze(-1).float(), dim=1)
        d = mask.sum(dim=1, keepdim=True).float()
        # Replace zero counts by 1 so fully-masked rows give 0 instead of 0/0 = NaN;
        # every other row is divided by its true count, exactly as before.
        d = d.masked_fill(d == 0, 1.0)
        return s / d

    def masked_mean_or_first(self, emb_all, mask):
        """
        Pool backbone output into one vector per sequence.

        Args:
            emb_all: tuple returned by the backbone; element 0 is the sequence output.
            mask: attention mask used for mean pooling.
        Returns:
            Masked mean of the sequence output if ``self.use_mean``, else the
            vector at position 0 of each sequence.
        """
        # emb_all is a tuple from bert - sequence output, pooler
        assert isinstance(emb_all, tuple)
        if self.use_mean:
            return self.masked_mean(emb_all[0], mask)
        else:
            return emb_all[0][:, 0]

    def query_emb(self, input_ids, attention_mask):
        """Embed a batch of queries; must be implemented by subclasses."""
        raise NotImplementedError("Please Implement this method")

    def body_emb(self, input_ids, attention_mask):
        """Embed a batch of passages/documents; must be implemented by subclasses."""
        raise NotImplementedError("Please Implement this method")


class BaseModelDot(EmbeddingMixin):
    """
    Dot-product embedding model: queries and passages go through the same encoder,
    pooling, projection head (``self.embeddingHead``) and ``self.norm``.

    Subclasses must implement ``_text_encode`` and define ``embeddingHead`` and ``norm``.
    """
    def _text_encode(self, input_ids, attention_mask):
        """
        Run the backbone encoder.

        Must return a tuple whose first element is the sequence output
        (see ``masked_mean_or_first``).
        """
        # Previously returned None (with a TODO to raise); raising makes a missing
        # override fail here with a clear message instead of as an opaque
        # assertion inside masked_mean_or_first.
        raise NotImplementedError("Please Implement this method")

    def query_emb(self, input_ids, attention_mask):
        """Encode, pool, project and normalise a batch of token ids."""
        outputs1 = self._text_encode(input_ids=input_ids,
                                attention_mask=attention_mask)
        full_emb = self.masked_mean_or_first(outputs1, attention_mask)
        query1 = self.norm(self.embeddingHead(full_emb))
        return query1

    def body_emb(self, input_ids, attention_mask):
        """Passages share the query encoder, so this delegates to ``query_emb``."""
        return self.query_emb(input_ids, attention_mask)

    def forward(self, input_ids, attention_mask, is_query, *args):
        """
        Embed a batch as queries (``is_query`` truthy) or as passages.

        Extra positional arguments are not supported.
        """
        assert len(args) == 0
        if is_query:
            return self.query_emb(input_ids, attention_mask)
        else:
            return self.body_emb(input_ids, attention_mask)

        
class RobertaDot(BaseModelDot, RobertaPreTrainedModel):
    """
    RoBERTa encoder (no pooling layer) followed by a linear projection and LayerNorm.

    The projection width is ``config.output_embedding_size`` when the config defines
    it, otherwise ``config.hidden_size``.
    """
    def __init__(self, config, model_argobj=None):
        BaseModelDot.__init__(self, model_argobj)
        RobertaPreTrainedModel.__init__(self, config)
        # masked_mean_or_first requires a plain tuple from the backbone, so turn off
        # ModelOutput objects. Parse the whole major version: the old
        # `__version__[0] == 4` test skipped every major other than 4 and would
        # misread a two-digit major.
        if int(transformers.__version__.split(".")[0]) >= 4:
            config.return_dict = False
        self.roberta = RobertaModel(config, add_pooling_layer=False)
        if hasattr(config, "output_embedding_size"):
            self.output_embedding_size = config.output_embedding_size
        else:
            self.output_embedding_size = config.hidden_size
        print("output_embedding_size", self.output_embedding_size)
        self.embeddingHead = nn.Linear(config.hidden_size, self.output_embedding_size)
        self.norm = nn.LayerNorm(self.output_embedding_size)
        # Resolves to EmbeddingMixin._init_weights via the MRO.
        self.apply(self._init_weights)

    def _text_encode(self, input_ids, attention_mask):
        """Run the RoBERTa encoder (tuple output, since return_dict is disabled above)."""
        outputs1 = self.roberta(input_ids=input_ids,
                                attention_mask=attention_mask)
        return outputs1


In [3]:
import sys
sys.path += ["./"]
import os
import math
import json
import torch
import pickle
import random
import logging
import numpy as np
from tqdm import tqdm
from collections import defaultdict
from torch.utils.data import Dataset
from typing import List

logger = logging.getLogger(__name__)


class TextTokenIdsCache:
    """
    Read-only, random-access view over a pre-tokenized corpus.

    Files (the two array files are looked up in ``data_dir`` first, then in
    ``data_dir/memmap``):
      * ``{prefix}_meta``        JSON with ``total_number``, ``embedding_size`` and ``type``
      * ``{prefix}.memmap``      (total_number, embedding_size) array of token ids
      * ``{prefix}_length.npy``  unpadded length of each row

    ``cache[i]`` returns row ``i`` truncated to its stored length.
    """
    def __init__(self, data_dir, prefix):
        # Context manager so the metadata file handle is closed promptly.
        with open(f"{data_dir}/{prefix}_meta") as meta_file:
            meta = json.load(meta_file)
        self.total_number = meta['total_number']
        # The meta key is 'embedding_size', but it is used as the padded row width.
        self.max_seq_len = meta['embedding_size']
        try:
            self.ids_arr = np.memmap(f"{data_dir}/{prefix}.memmap",
                shape=(self.total_number, self.max_seq_len),
                dtype=np.dtype(meta['type']), mode="r")
            self.lengths_arr = np.load(f"{data_dir}/{prefix}_length.npy")
        except FileNotFoundError:
            # Fall back to the layout with arrays in a 'memmap' sub-directory.
            self.ids_arr = np.memmap(f"{data_dir}/memmap/{prefix}.memmap",
                shape=(self.total_number, self.max_seq_len),
                dtype=np.dtype(meta['type']), mode="r")
            self.lengths_arr = np.load(f"{data_dir}/memmap/{prefix}_length.npy")
        assert len(self.lengths_arr) == self.total_number

    def __len__(self):
        return self.total_number

    def __getitem__(self, item):
        """Token ids of row ``item`` without padding."""
        return self.ids_arr[item, :self.lengths_arr[item]]

class SequenceDataset(Dataset):
    """
    Map-style dataset yielding token-id sequences capped at ``max_seq_length``.

    Over-long sequences keep their first token, as many following tokens as fit,
    and their final token; shorter sequences are returned unchanged. Each item is
    a dict with ``input_ids``, an all-ones ``attention_mask`` and the row ``id``.
    """
    def __init__(self, ids_cache, max_seq_length):
        self.ids_cache = ids_cache
        self.max_seq_length = max_seq_length

    def __len__(self):
        return len(self.ids_cache)

    def __getitem__(self, item):
        tokens = self.ids_cache[item].tolist()
        # Index one past the last "middle" token to keep.
        cut = min(self.max_seq_length, len(tokens)) - 1
        kept = [tokens[0], *tokens[1:cut], tokens[-1]]
        return {
            "input_ids": kept,
            "attention_mask": [1] * len(kept),
            "id": item,
        }
    


def pack_tensor_2D(lstlst, default, dtype, length=None):
    """
    Right-pad a list of 1-D integer sequences into a 2-D tensor.

    Args:
        lstlst: list of sequences (lists of numbers).
        default: fill value for padding positions.
        dtype: torch dtype of the result.
        length: row width; defaults to the longest sequence.
    Returns:
        Tensor of shape (len(lstlst), length).
    """
    rows = len(lstlst)
    if length is None:
        length = max(len(seq) for seq in lstlst)
    packed = default * torch.ones((rows, length), dtype=dtype)
    for row_idx, seq in enumerate(lstlst):
        packed[row_idx, :len(seq)] = torch.tensor(seq, dtype=dtype)
    return packed

    
    
def get_collate_function(max_seq_length):
    """
    Build a DataLoader ``collate_fn`` for ``SequenceDataset`` items.

    The returned function pads ``input_ids`` with 1 and ``attention_mask`` with 0
    (as int64 tensors) and returns ``(tensors_dict, ids)``. Its first 10 calls
    pad to exactly ``max_seq_length``; later calls pad only to the longest item
    in the batch. NOTE(review): the reason for the fixed-width warm-up batches is
    not stated here.
    """
    calls_so_far = 0

    def collate_function(batch):
        nonlocal calls_so_far
        pad_to = None
        if calls_so_far < 10:
            calls_so_far += 1
            pad_to = max_seq_length

        token_rows = [example["input_ids"] for example in batch]
        mask_rows = [example["attention_mask"] for example in batch]
        tensors = dict(
            input_ids=pack_tensor_2D(token_rows, default=1, dtype=torch.int64, length=pad_to),
            attention_mask=pack_tensor_2D(mask_rows, default=0, dtype=torch.int64, length=pad_to),
        )
        return tensors, [example['id'] for example in batch]

    return collate_function


In [4]:
import sys
sys.path += ['./']
import os
import faiss
import numpy as np
from tqdm import tqdm
from timeit import default_timer as timer


def index_retrieve(index, query_embeddings, topk, batch=None):
    """
    Search ``index`` for the ``topk`` nearest neighbours of every query and report timing.

    Args:
        index: object with a faiss-style ``search(queries, k) -> (distances, ids)``.
        query_embeddings: 2-D array with one query vector per row.
        topk: neighbours to return per query.
        batch: None to search all queries in one call (returns the ids array from
            ``index.search`` as-is); otherwise a positive chunk size, in which case
            a list of per-query id lists is returned.

    Raises:
        ValueError: if ``batch`` is not None and not positive (a non-positive chunk
            never advances the offset and would loop forever).
    """
    if batch is not None and batch <= 0:
        raise ValueError(f"batch must be a positive integer or None, got {batch!r}")
    num_queries = len(query_embeddings)
    print("Query Num", num_queries)
    start = timer()
    if batch is None:
        _, nearest_neighbors = index.search(query_embeddings, topk)
    else:
        nearest_neighbors = []
        # Context manager closes the progress bar even if a search raises.
        with tqdm(total=num_queries) as pbar:
            for offset in range(0, num_queries, batch):
                batch_query_embeddings = query_embeddings[offset:offset + batch]
                batch_nn = index.search(batch_query_embeddings, topk)[1]
                nearest_neighbors.extend(batch_nn.tolist())
                pbar.update(len(batch_query_embeddings))

    elapsed_time = timer() - start
    # max(..., 1) avoids ZeroDivisionError when there are no queries.
    elapsed_time_per_query = 1000 * elapsed_time / max(num_queries, 1)
    print(f"Elapsed Time: {elapsed_time:.1f}s, Elapsed Time per query: {elapsed_time_per_query:.1f}ms")
    return nearest_neighbors



def construct_flatindex_from_embeddings(embeddings, ids=None):
    """
    Build an exact ("Flat") inner-product faiss index over ``embeddings``.

    Args:
        embeddings: 2-D array, one vector per row; its column count is the index dim.
        ids: optional array of external ids; when given, the flat index is wrapped
            in ``IndexIDMap2`` and vectors are added under ``ids`` cast to int64.
    Returns:
        The populated faiss index.
    """
    dim = embeddings.shape[1]
    print('embedding shape: ' + str(embeddings.shape))
    flat_index = faiss.index_factory(dim, "Flat", faiss.METRIC_INNER_PRODUCT)
    if ids is None:
        flat_index.add(embeddings)
        return flat_index

    int64_ids = ids.astype(np.int64)
    print(int64_ids.shape, int64_ids.dtype)
    id_index = faiss.IndexIDMap2(flat_index)
    id_index.add_with_ids(embeddings, int64_ids)
    return id_index


# Per-device faiss GPU resources, created on the first multi-GPU conversion and reused afterwards.
gpu_resources = []

def convert_index_to_gpu(index, faiss_gpu_index, useFloat16=False):
    """
    Clone a CPU faiss index onto one or more GPUs.

    Args:
        index: CPU faiss index.
        faiss_gpu_index: a device id (int), a one-element list (treated as that
            id), or a list of device ids to shard the index across.
        useFloat16: forwarded to the cloner options.
    Returns:
        The GPU index produced by faiss.
    """
    global gpu_resources
    if type(faiss_gpu_index) is list and len(faiss_gpu_index) == 1:
        faiss_gpu_index = faiss_gpu_index[0]

    if isinstance(faiss_gpu_index, int):
        single_res = faiss.StandardGpuResources()
        single_res.setTempMemory(512 * 1024 * 1024)
        cloner_opts = faiss.GpuClonerOptions()
        cloner_opts.useFloat16 = useFloat16
        return faiss.index_cpu_to_gpu(single_res, faiss_gpu_index, index, cloner_opts)

    if not gpu_resources:
        import torch
        for _ in range(torch.cuda.device_count()):
            device_res = faiss.StandardGpuResources()
            device_res.setTempMemory(256 * 1024 * 1024)
            gpu_resources.append(device_res)

    assert isinstance(faiss_gpu_index, list)
    res_vector = faiss.GpuResourcesVector()
    dev_vector = faiss.IntVector()
    multi_opts = faiss.GpuMultipleClonerOptions()
    multi_opts.shard = True
    multi_opts.useFloat16 = useFloat16
    for device_id in faiss_gpu_index:
        dev_vector.push_back(device_id)
        res_vector.push_back(gpu_resources[device_id])
    return faiss.index_cpu_to_gpu_multiple(res_vector, dev_vector, index, multi_opts)
        


In [5]:
import sys
sys.path += ["./"]
import os
import torch
import faiss
import logging
import subprocess
import numpy as np
from tqdm import tqdm
from torch.utils.data import DataLoader
from transformers import RobertaConfig

logger = logging.getLogger(__name__)
logging.basicConfig(format = '%(asctime)s-%(levelname)s-%(name)s- %(message)s',
                        datefmt = '%d %H:%M:%S',
                        level = logging.INFO)


def evaluate(args, model):
    """
    Encode every preprocessed query with ``model`` and store the embeddings in a memmap.

    Token ids come from ``TextTokenIdsCache(args.preprocess_dir, "queries")``; the
    float32 result of shape (num_queries, embedding_dim) is written to
    ``args.qmemmap_path``, row i holding the embedding of query offset i.

    Args:
        args: namespace with preprocess_dir, max_seq_length, pergpu_eval_batch_size,
            n_gpu, model_device and qmemmap_path.
        model: embedding model whose forward accepts
            ``(input_ids, attention_mask, is_query)``.
    Returns:
        The flushed ``np.memmap`` of query embeddings.
    """
    dev_dataset = SequenceDataset(
            TextTokenIdsCache(args.preprocess_dir, "queries"),
            args.max_seq_length)
    collate_fn = get_collate_function(args.max_seq_length)
    batch_size = args.pergpu_eval_batch_size
    if args.n_gpu > 1:
        batch_size *= args.n_gpu
    dev_dataloader = DataLoader(dev_dataset,
        batch_size=batch_size, collate_fn=collate_fn)

    # Read the width before a DataParallel wrap hides the attribute; the caller
    # reshapes this file with model.output_embedding_size, so it must match.
    # Falls back to the previously hard-coded 768 for models without the attribute.
    embedding_dim = getattr(model, "output_embedding_size", 768)
    if args.n_gpu > 1:
        model = torch.nn.DataParallel(model)
    model.eval()
    qembedding_memmap = np.memmap(args.qmemmap_path, dtype="float32",
        shape=(len(dev_dataset), embedding_dim), mode="w+")
    with torch.no_grad():
        for batch, qoffsets in tqdm(dev_dataloader):
            batch = {k: v.to(args.model_device) for k, v in batch.items()}
            embeddings = model(
                input_ids=batch["input_ids"],
                attention_mask=batch["attention_mask"],
                is_query=True)
            qembedding_memmap[qoffsets] = embeddings.detach().cpu().numpy()
    # Persist to disk before callers re-open the file read-only.
    qembedding_memmap.flush()
    return qembedding_memmap


class Namespace:
    """
    Minimal attribute bag for notebook configuration (``Namespace(a=1).a == 1``).

    Provides a readable ``repr`` so that logging the config shows its fields
    instead of an object address.
    """
    def __init__(self, **kwargs):
        self.__dict__.update(kwargs)

    def __repr__(self):
        fields = ", ".join(f"{key}={value!r}" for key, value in self.__dict__.items())
        return f"{type(self).__name__}({fields})"

In [6]:

# Configuration for dev-set query encoding + flat-index retrieval.
args = Namespace(
    model_dir = "/kaggle/input/adore-pretrained",
    output_dir = "/kaggle/working/",
    preprocess_dir = "/kaggle/input/adore-preprocess/data/doc/preprocess",
    mode = "dev",
    topk = 100,
    # NOTE(review): this is a plain string, not an f-string, so "{args.data_type}" is a
    # literal directory name (args.data_type is never defined). The recorded run got past
    # the assert below, so the input dataset evidently has a folder with that literal
    # name — confirm before "fixing" it.
    dmemmap_path = "/kaggle/input/star-inference/{args.data_type}/evaluate/star/passages.memmap",
    max_seq_length = 64,
    pergpu_eval_batch_size = 32,
    no_cuda = False,
    faiss_gpus = None
)

assert os.path.exists(args.dmemmap_path)
os.makedirs(args.output_dir, exist_ok=True)
# Setup CUDA, GPU
args.use_gpu = torch.cuda.is_available() and not args.no_cuda
args.model_device = torch.device(f"cuda:0" if args.use_gpu else "cpu")
args.n_gpu = 1

# Setup logging
logger.warning("Model Device: %s, n_gpu: %s", args.model_device, args.n_gpu)
config = RobertaConfig.from_pretrained(args.model_dir)
model = RobertaDot.from_pretrained(args.model_dir, config=config)

model.to(args.model_device)
logger.info("Training/evaluation parameters %s", args)
# Evaluation: encode queries into args.qmemmap_path
args.qmemmap_path = f"{args.output_dir}/{args.mode}.qembed.memmap"
evaluate(args, model)

doc_embeddings = np.memmap(args.dmemmap_path,
    dtype=np.float32, mode="r").reshape(-1, model.output_embedding_size)

query_embeddings = np.memmap(args.qmemmap_path,
    dtype=np.float32, mode="r").reshape(-1, model.output_embedding_size)
# Release the encoder before building the index.
model = None
torch.cuda.empty_cache()

index = construct_flatindex_from_embeddings(doc_embeddings, None)
if args.faiss_gpus:
    index = convert_index_to_gpu(index, args.faiss_gpus, False)
else:
    # Cap OpenMP threads at the available cores: a fixed 32 oversubscribes hosts
    # with fewer cores and slows the CPU search.
    faiss.omp_set_num_threads(min(32, os.cpu_count() or 1))
nearest_neighbors = index_retrieve(index, query_embeddings, args.topk, batch=32)
# One "qid<TAB>pid<TAB>rank" line per retrieved passage; qid/pid are row offsets, rank is 1-based.
output_rank_file = os.path.join(args.output_dir, f"{args.mode}.rank.tsv")
with open(output_rank_file, 'w') as outputfile:
    for qid, neighbors in enumerate(nearest_neighbors):
        for rank, pid in enumerate(neighbors, start=1):
            outputfile.write(f"{qid}\t{pid}\t{rank}\n")

  return self.fget.__get__(instance, owner)()


Using mean: False
output_embedding_size 768


100%|██████████| 38/38 [00:02<00:00, 17.42it/s]


embedding shape: (563424, 768)
Query Num 1200


100%|██████████| 1200/1200 [00:55<00:00, 21.72it/s]

Elapsed Time: 55.2s, Elapsed Time per query: 46.0ms



