In [1]:
import torch
from transformers import AutoTokenizer,AutoModelForMaskedLM
from splade_preprocessor import SparseDocTextPreprocessor
from pymongo import MongoClient
import torch.autograd.profiler as profiler
from functools import partial
import sys

model_names = [
    "naver/splade_v2_max",
    "naver/splade_v2_distil",
    "naver/splade-cocondenser-ensembledistil",
    "naver/efficient-splade-VI-BT-large-query",
    "naver/efficient-splade-VI-BT-large-doc",
]
splade=  model_names[3]
def tokenizer2(func):
    def _tokenizer2(*args, **kwargs):
        to_return = func(*args, **kwargs)
        del to_return['token_type_ids']
        return to_return
    return partial(_tokenizer2, return_tensors="pt", padding="longest", truncation=True, max_length=128)

doc_text = MongoClient('localhost',27017)['catalogStore']['doc_text']
doc_samples = []
counter = 0
for doc in doc_text.find({}).limit(32):
    counter += 1
    doc_samples.append(doc['text'])
    if counter == 32:
        break

tokenizer = AutoTokenizer.from_pretrained(splade)
# spl_tokenizer = tokenizer2(tokenizer)
spl_tokenizer = partial(tokenizer, return_tensors="pt", padding="longest", truncation=True, max_length=128)


# spl_tokenizer("hell")
proc = SparseDocTextPreprocessor()
doc_samples = [proc.clean_text(doc) for doc in doc_samples]
tokens = spl_tokenizer(doc_samples, return_tensors="pt")

input_ids = tokens['input_ids']
token_type_ids = tokens["token_type_ids"]
attention_mask = tokens["attention_mask"]

In [2]:
# from torch.sparse import to_sparse_semi_structured, SparseSemiStructuredTensor


class TransformerMLM(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # with profiler.record_function("model init"):
        self.model = AutoModelForMaskedLM.from_pretrained(splade, torchscript = True)
        # self.model.to('cuda')  # type: ignore
        self.model.eval()


    def forward(self, input_ids, token_type_ids, attention_mask):
        # with profiler.record_function("model forward"):
        with torch.cuda.amp.autocast(enabled=True):  # type: ignore
            with torch.no_grad():
                # This model produces a tuple as an output
                return self.model(input_ids, token_type_ids, attention_mask)[0]
                

    
class SparseModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.bertlm = TransformerMLM()
        self.bertlm.eval()
        # self.bertlm = self.bertlm.to("cuda")

    def forward(self, input_ids, token_type_ids, attention_mask):
        with torch.cuda.amp.autocast(enabled=True):  # type: ignore
            with torch.no_grad():
                mlm_logits = self.bertlm(input_ids, token_type_ids, attention_mask)
                mlm_logits, _ = torch.max(
                    torch.log(1+torch.relu(mlm_logits))*attention_mask.unsqueeze(-1),
                dim=1
                )
                del _
                return mlm_logits         


sm = SparseModel()
# sm = sm.to("cuda")
sm = sm.eval()
traced_model = torch.jit.trace(sm, [input_ids,token_type_ids,attention_mask])
torch_jit_model_path = splade.replace("naver/","splade_models/")+'.pt'
onnx_model_path = splade.replace("naver/","splade_onnx/")+'.onnx'
traced_model.save(torch_jit_model_path) # type: ignore
del sm 
sm = torch.jit.load(torch_jit_model_path)
# sm = sm.to("cuda")
sm = sm.eval()
torch.onnx.export(
    sm,
    (input_ids,token_type_ids,attention_mask),
    onnx_model_path,
    do_constant_folding=True,
    input_names=["input_ids","token_type_ids","attention_mask"],
    output_names=["sparse_embeddings"],
    dynamic_axes={
        "input_ids": {0: "batch_size", 1: "sequence_length"},  # variable lenght axes
        "token_type_ids": {0: "batch_size", 1: "sequence_length"},  # variable lenght axes
        "attention_mask": {0: "batch_size", 1: "sequence_length"},
        "sparse_embeddings":{0: "batch_size"}
        })



In [3]:
mlm = sm(input_ids,attention_mask)

In [11]:
a,b = mlm.nonzero(as_tuple=True)
a = mlm.nonzero()[:,0]
b = mlm.nonzero()[:,1]

In [12]:
mlm[a,b]

tensor([0.4475, 0.2231, 0.3381,  ..., 0.4599, 0.1003, 0.0896], device='cuda:0',
       grad_fn=<IndexBackward0>)

In [137]:
class SpareseResults(torch.nn.Module):
    def __init__(self):
        super().__init__()
    

    def forward(self,mlm_logits):
        with torch.no_grad():
            batch_size = mlm_logits.size(0)
            mlm_nz = mlm_logits.nonzero()
            vec_indices = torch.vstack((mlm_nz[:,0], mlm_nz[:,1]))
            vec_values = mlm_logits[mlm_nz[:,0], mlm_nz[:,1]]
            del mlm_logits
            results = torch.zeros((batch_size, 2,  512),device="cuda")
            for row in range(batch_size):
                indices = torch.zeros(512,device="cuda")  # type: ignore
                values = torch.zeros(512,device="cuda")  # type: ignore
                mask = vec_indices[0].eq(row)
                row_indices = torch.masked_select(vec_indices[1], mask)
                indices[:row_indices.shape[0]] = row_indices
                row_values = torch.masked_select(vec_values, mask)
                values[:row_values.shape[0]] = row_values
                result = torch.vstack((indices, values))
                results[row] = result
            return results

sp = SpareseResults()
sp = sp.to("cuda")
res = sp(mlm)

In [138]:
sp = torch.jit.script(sp)

In [145]:
mlm.dtype

torch.float32

In [140]:
n = torch.compile(sp)

In [143]:
n.save("splade_models/sparse.pt")

In [None]:
torch.jit.save(sp)

In [8]:
tokens = spl_tokenizer(doc_samples, return_tensors="pt")

input_ids = tokens['input_ids'].to("cuda")
attention_mask = tokens["attention_mask"].to("cuda")


In [9]:
%%time
indices = sm(input_ids = input_ids, attention_mask=attention_mask)


CPU times: user 21.1 ms, sys: 1.09 ms, total: 22.1 ms
Wall time: 21.7 ms


In [2]:
onnx_model_path = splade.replace("naver/","splade_onnx/")+'.onnx'




In [4]:
import torch
import onnxruntime as ort

In [9]:
ort_session = ort.InferenceSession(onnx_model_path, providers=["CUDAExecutionProvider"], )

[0;93m2024-02-12 16:25:02.220570132 [W:onnxruntime:, transformer_memcpy.cc:74 ApplyImpl] 257 Memcpy nodes are added to the graph main_graph for CUDAExecutionProvider. It might have negative impact on performance (including unable to run CUDA graph). Set session_options.log_severity_level=1 to see the detail logs before this message.[m
[0;93m2024-02-12 16:25:02.225388700 [W:onnxruntime:, session_state.cc:1166 VerifyEachNodeIsAssignedToAnEp] Some nodes were not assigned to the preferred execution providers which may or may not have an negative impact on performance. e.g. ORT explicitly assigns shape related ops to CPU to improve perf.[m
[0;93m2024-02-12 16:25:02.225393861 [W:onnxruntime:, session_state.cc:1168 VerifyEachNodeIsAssignedToAnEp] Rerunning with verbose output on a non-minimal build will show node assignments.[m


<onnxruntime.capi.onnxruntime_pybind11_state.NodeArg at 0x70e2f5f91d30>

In [6]:
import onnx
onnx_model = onnx.load("splade_onnx/splade_model_coco_ensemb_opt.onnx")
onnx.checker.check_model(onnx_model)

In [7]:
tk = {k:tokens[k].numpy() for k in tokens}