In [1]:
import numpy as np
import onnxruntime as ort
import os
# NOTE(review): presumably a workaround for the "duplicate OpenMP runtime" abort
# that can occur when several libraries each bundle their own OpenMP copy —
# confirm it is still needed in this environment.
os.environ['KMP_DUPLICATE_LIB_OK']='True'

In [2]:
# Download the model from
# https://data.vespa.oath.cloud/onnx_models/bpr-question-encoder.onnx
# into the current working directory; the relative path below is resolved there.

# Create the inference session without passing explicit session options.
session = ort.InferenceSession("bpr-question-encoder.onnx")

In [3]:
# Feed for the encoder: one sequence of 10 token ids, with every position
# assigned segment 0 and attended to (mask of ones).
# Presumably these ids are the tokenization of the query quoted in the BPR
# sample further down — verify against the tokenizer.
input_dict = dict(
    input_ids=np.array([[101, 2054, 2003, 1996, 13747, 3137, 1999, 1996, 2088, 102]]),
    token_type_ids=np.zeros((1, 10), dtype=int),
    attention_mask=np.ones((1, 10), dtype=int),
)

In [None]:
# The same token ids as input_dict['input_ids'] minus the first (101) and
# last (102) entries, expressed as floats as in the Vespa query URL below.
vespa_token_input = [
    float(token_id)
    for token_id in (2054, 2003, 1996, 13747, 3137, 1999, 1996, 2088)
]

In [5]:
# Run the model and unpack the single requested output ('output_0').
(out,) = session.run(output_names=['output_0'], input_feed=input_dict)

In [6]:
# Binarize out[0][0] (positive -> 1, otherwise 0), pack eight bits into each
# byte, and cast the bytes to signed int8 so values above 127 show as negative.
# Presumably out[0][0] is the CLS-token embedding — confirm against the model.
onnx_rt_packed = np.packbits(out[0][0] > 0).astype(np.int8)

In [7]:
# Display the packed int8 encoding computed with ONNX Runtime.
onnx_rt_packed

array([ 103,  118,  -54,   -4, -100,  -11,   86,  -40,   86,  -41,  -26,
         -2,  -82,  -80, -102,  -53,   74,  -75,  -83,  122,   58,  100,
         87,   -5,  115,   39,  -96,   -2,   49,  -22,   59,   69,  -76,
         74,   61,   -2, -119,  -28,  -58,   17,  -12,  -89,   44,  -13,
        -62,  -67,  103,   32,   71,   60,  123,   51, -118,   38,   81,
        -70,  -40,  -95,  -48, -126,   23,  113,  -18,  -79,  -80,   72,
        -18, -102,  118,   64,   58,  -75,  -91,   94,  116, -123,  104,
        -73,   -5,  -89, -115,  -39,  123,  -27,   71,  -14,   50,  126,
        -98,  -30,   85,   31,   45,  -62,   70,  -86], dtype=int8)

This is the reference encoding, as produced by the sample code from https://github.com/studio-ousia/bpr for the query below:
<pre>
query_embeddings = retriever.encode_queries(["what is the tallest mountain in the world"])
</pre>

In [8]:
# Reference packed encoding (96 int8 values) pasted from the BPR sample-code
# run described above; used as the expected value in the comparisons below.
bpr_encoder_packed = np.array([ 103,  118,  -54,   -4, -100,  -11,   86,  -40,   86,  -41,  -26,
         -2,  -82,  -80, -102,  -53,   74,  -75,  -83,  122,   58,  100,
         87,   -5,  115,   39,  -96,   -2,   49,  -22,   59,   69,  -76,
         74,   61,   -2, -119,  -28,  -58,   17,  -12,  -89,   44,  -13,
        -62,  -67,  103,   32,   71,   60,  123,   51, -118,   38,   81,
        -70,  -40,  -95,  -48, -126,   23,  113,  -18,  -79,  -80,   72,
        -18, -102,  118,   64,   58,  -75,  -91,   94,  116, -123,  104,
        -73,   -5,  -89, -115,  -39,  123,  -27,   71,  -14,   50,  126,
        -98,  -30,   85,   31,   45,  -62,   70,  -86], dtype=np.int8)

In [9]:
# Byte-by-byte comparison of the BPR reference with the ONNX Runtime result.
np.equal(bpr_encoder_packed, onnx_rt_packed)

array([ True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True])

OK, the ONNX Runtime encoding matches the BPR reference byte for byte. Now let us try again with graph optimizations enabled.

In [11]:
# Session options with the graph optimization level set to ORT_ENABLE_ALL.
# NOTE(review): per the onnxruntime docs ORT_ENABLE_ALL appears to already be
# the default level, so this may not differ from the first session — confirm.
sess_options = ort.SessionOptions()
sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL

In [12]:
# Rebind `session` to a new inference session that uses the options above.
# NOTE(review): this loads "bpr-query-encoder2.onnx", not the
# "bpr-question-encoder.onnx" file the first session used — confirm which
# model file is intended and where it comes from.
session = ort.InferenceSession("bpr-query-encoder2.onnx", sess_options=sess_options)

In [13]:
# Re-run the same inputs through the new session, then binarize out[0][0]
# (positive -> 1, otherwise 0), pack to bytes and cast to signed int8.
(out,) = session.run(output_names=['output_0'], input_feed=input_dict)
onnx_rt_packed_optimized = np.packbits(out[0][0] > 0).astype(np.int8)

In [14]:
# Display the packed int8 encoding from the second session.
onnx_rt_packed_optimized

array([ 103,  118,  -54,   -4, -100,  -11,   86,  -40,   86,  -41,  -26,
         -2,  -82,  -80, -102,  -53,   74,  -75,  -83,  122,   58,  100,
         87,   -5,  115,   39,  -96,   -2,   49,  -22,   59,   69,  -76,
         74,   61,   -2, -119,  -28,  -58,   17,  -12,  -89,   44,  -13,
        -62,  -67,  103,   32,   71,   60,  123,   51, -118,   38,   81,
        -70,  -40,  -95,  -48, -126,   23,  113,  -18,  -79,  -80,   72,
        -18, -102,  118,   64,   58,  -75,  -91,   94,  116, -123,  104,
        -73,   -5,  -89, -115,  -39,  123,  -27,   71,  -14,   50,  126,
        -98,  -30,   85,   31,   45,  -62,   70,  -86], dtype=int8)

In [15]:
# Byte-by-byte comparison of the second session's result with the BPR reference.
np.equal(onnx_rt_packed_optimized, bpr_encoder_packed)

array([ True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True])

Next, request the encoding from a local Vespa instance using the `question_encoder` rank profile, passing the token ids without the leading 101 / trailing 102 and padded with zeros:

http://localhost:8080/search/?query=sddocname:query&tracelevel=3&searchChain=vespa&ranking.features.query(query_token_ids)=[2054.0,%202003.0,%201996.0,%2013747.0,%203137.0,%201999.0,%201996.0,%202088.0,%200.0,%200.0,%200.0,%200.0,%200.0,%200.0,%200.0,%200.0,%200.0,%200.0,%200.0,%200.0,%200.0,%200.0,%200.0,%200.0,%200.0,%200.0,%200.0,%200.0,%200.0,%200.0,%200.0,%200.0]&ranking=question_encoder&restrict=query

In [21]:
vespa_float_cls_tensor = [-0.09260391443967819, 0.09589125216007233, 0.21242474019527435, -0.04614098742604256, -0.09816332161426544, 0.0022295061498880386, 0.25977623462677, 0.19362443685531616, -0.0577840581536293, 0.1469210684299469, 0.09795565903186798, -0.005916787311434746, -0.08090659230947495, 0.14698192477226257, 0.0571487620472908, -0.13508020341396332, 0.061348576098680496, 0.06454582512378693, -0.020564429461956024, 0.013327296823263168, 0.015428952872753143, -0.16201762855052948, 0.15164946019649506, -0.19514136016368866, 0.21722018718719482, 0.050431061536073685, 0.1042608767747879, 0.21470504999160767, 0.040510982275009155, 0.2714185118675232, -0.13171374797821045, -0.01953350007534027, 0.059845030307769775, -0.02458948642015457, -0.03563278540968895, 0.04487576708197594, 0.13419713079929352, 0.08207549154758453, -0.03494885191321373, -0.09698593616485596, 0.3261881172657013, -0.010116110555827618, -0.0578533411026001, 0.09321344643831253, -0.11521147191524506, 0.10882143676280975, -1.010148286819458, 0.26237091422080994, -0.048911720514297485, 0.009055964648723602, -0.07108787447214127, 0.16291537880897522, -0.041095323860645294, 0.0012842994183301926, 0.13866513967514038, -0.07895149290561676, 0.04259653016924858, 0.3193230628967285, -0.04639390856027603, -0.01963479444384575, 0.10477026551961899, 0.04003889858722687, -0.08939782530069351, -0.05315198749303818, -0.11286215484142303, 0.06553882360458374, -0.07896851748228073, 0.07460647821426392, -0.1286059319972992, 0.2446172684431076, 0.02571595087647438, -0.057992592453956604, 0.05339517444372177, 0.2022450715303421, -0.12797623872756958, 0.19967642426490784, -0.14365379512310028, 0.24453739821910858, -0.020549222826957703, 0.007649587467312813, 0.11799938976764679, -0.043237701058387756, 0.02701548859477043, 0.009973357431590557, -0.10037989169359207, 0.16931360960006714, 0.13949066400527954, -0.04259189963340759, 0.039685383439064026, 0.2693621516227722, -0.019999029114842415, 
0.06925217807292938, 0.12707862257957458, 0.13505995273590088, 0.1656278371810913, -0.23442891240119934, 0.09819380193948746, -0.04803745448589325, 0.3211420476436615, -0.103200763463974, 0.1088438481092453, 0.08143369853496552, 0.21309879422187805, -0.009197831153869629, 0.0225190669298172, -0.10090905427932739, 0.0452352911233902, 0.11151620745658875, -0.09008897840976715, -2.964568853378296, -0.27820324897766113, 0.059886496514081955, 0.027348121628165245, -0.08743882924318314, 0.07554551959037781, 0.30660775303840637, 0.42628368735313416, -0.0056546591222286224, 0.14483177661895752, -0.12268850207328796, 0.20162233710289001, 0.6552745699882507, -0.2734993100166321, -0.13707345724105835, 0.10907363146543503, -0.019954536110162735, 0.08651627600193024, 0.08239337801933289, -0.1556563824415207, 0.11532501876354218, -0.1929844617843628, -0.10569953918457031, 0.16933532059192657, -0.031564801931381226, 0.1599547564983368, -0.145606130361557, 0.10939224064350128, -0.16906075179576874, 0.009226318448781967, 0.06946643441915512, -0.09436118602752686, 0.08933039754629135, -3.399688959121704, 0.14799463748931885, 0.18573321402072906, -0.01618058979511261, 0.0995749980211258, -0.14707837998867035, 0.11239468306303024, 0.10110116750001907, -0.07808751612901688, 0.09183278679847717, -0.19675502181053162, 0.14690743386745453, 0.08030713349580765, 0.1514114886522293, 0.22655093669891357, 0.06690926849842072, 0.030970923602581024, -0.13069367408752441, 0.046533383429050446, -0.022184375673532486, 0.09195462614297867, -0.006205245852470398, 0.16636697947978973, -0.19082200527191162, 0.0952625721693039, -0.18309977650642395, -0.18764472007751465, -0.02726544812321663, 0.03730660676956177, -0.12031123042106628, -0.15611138939857483, 0.12438170611858368, -0.0917474627494812, -0.20210915803909302, -0.055290587246418, 0.08981101214885712, -0.16422665119171143, 0.05147716403007507, -0.008933499455451965, 0.2517159879207611, 0.007549408823251724, 0.18720273673534393, 
0.05067440867424011, 0.10908827185630798, 0.2904185652732849, 0.09176517277956009, 0.05896022915840149, -0.26033446192741394, 0.05770315229892731, 0.02749565988779068, -0.045660339295864105, 0.01725747622549534, 0.32759931683540344, 0.07115371525287628, -0.196461021900177, -0.15418460965156555, 0.034262701869010925, 0.2406788468360901, -0.06697729974985123, -0.18513698875904083, 0.15222302079200745, -0.018652789294719696, -0.16288313269615173, 3.5397963523864746, 0.047475822269916534, -0.029598545283079147, 0.21041342616081238, -0.10810661315917969, 0.08022855967283249, -0.1220695972442627, -0.05542696639895439, 0.06272751092910767, -0.10783660411834717, -0.06665987521409988, 0.08372732251882553, 0.03777021914720535, 0.16203203797340393, 0.06612756103277206, 0.17296920716762543, 0.1813896745443344, 0.005956074222922325, -0.0967722237110138, -0.0936691090464592, -0.060441579669713974, 0.14334669709205627, 0.09660568088293076, -0.09458868205547333, -0.4472626745700836, -0.04547855630517006, 0.041651442646980286, 0.03151153028011322, 0.24052341282367706, 0.11923091113567352, -0.15489743649959564, 0.046028900891542435, -0.07901570200920105, 0.46063822507858276, -0.1302708238363266, -0.15184596180915833, -0.1766522079706192, 0.001175716519355774, 0.20589812099933624, 0.06117139384150505, -0.04783254861831665, 0.3234087824821472, 0.15599209070205688, -0.3401312232017517, 0.17274488508701324, -0.1529773771762848, -0.06629449874162674, -0.09689021110534668, 0.14565058052539825, -0.10706362873315811, 0.02130233496427536, 0.057746391743421555, -0.20963287353515625, 0.10481417924165726, 0.08598046749830246, -0.006778445094823837, -0.04849717393517494, -0.002570774406194687, -0.07070109993219376, 0.0033854953944683075, 0.04048818349838257, -0.03480708599090576, -0.11402352154254913, 0.2319347858428955, -0.06295306980609894, 0.07707850635051727, -0.10404856503009796, 0.043426260352134705, -4.724329471588135, 0.08771122246980667, 0.05875295400619507, 0.11834660172462463, 
0.10960329324007034, -0.20537850260734558, 0.014008618891239166, 0.1192534863948822, 0.22244328260421753, 0.13085076212882996, 0.1938590705394745, 0.09972349554300308, -0.014157291501760483, 0.2983853816986084, -0.2204110026359558, 0.04608481377363205, -0.037306685000658035, -0.007790294475853443, -0.02639736235141754, 0.05394582450389862, -0.12494518607854843, -0.03887142241001129, 0.09672394394874573, 0.169552281498909, 0.10999514907598495, 0.023979581892490387, -0.08255065977573395, -0.13196642696857452, 0.023779021576046944, -0.07003059983253479, 0.056901965290308, 0.07635859400033951, 0.049644459038972855, -0.07043056190013885, 0.008377042599022388, -3.8319551944732666, 0.06807567179203033, 0.14672136306762695, -0.18802712857723236, -0.16151760518550873, 0.03524532914161682, -0.0986337661743164, 0.04190768301486969, -0.17881789803504944, -0.07331064343452454, 0.004497811198234558, -0.042436957359313965, 0.05011773481965065, 0.05396943539381027, 0.16991108655929565, 0.1332002729177475, -0.018424615263938904, 0.017033834010362625, -0.14435794949531555, 0.02471233531832695, 0.11985760182142258, -0.08728329837322235, 0.06381426006555557, -0.039797235280275345, -0.07781567424535751, 0.16644003987312317, -0.05759374797344208, 0.06960498541593552, -0.10186906903982162, -0.15559545159339905, 0.17665408551692963, -0.14355453848838806, 0.02464250475168228, 0.18078447878360748, 0.055751826614141464, -0.09287013858556747, 0.12496763467788696, 0.11202673614025116, 0.30275899171829224, 0.04122347757220268, -0.01570598967373371, -0.04713970795273781, 0.13260185718536377, 0.1070210188627243, 0.3097829222679138, -0.024926267564296722, -0.0903494656085968, -0.09552019834518433, -0.07847079634666443, -0.03517051413655281, 0.12309668213129044, 0.019575342535972595, 0.9876673221588135, -0.07435482740402222, 0.11128734052181244, 0.12167977541685104, 0.2279254049062729, 0.18554288148880005, -0.0657505989074707, 0.14152726531028748, -0.11927416920661926, 0.20808571577072144, 
0.059014804661273956, -0.27570903301239014, 0.04008174687623978, -0.02461499720811844, -0.0035793185234069824, 0.10887035727500916, -0.08042659610509872, -0.1290692389011383, 0.12982527911663055, 0.058926306664943695, -0.07350589334964752, -0.09034866094589233, -0.0330803208053112, -0.2037447690963745, 0.04903719946742058, -0.005954273045063019, -0.19483613967895508, -0.12785013020038605, -0.2806346118450165, 0.034609146416187286, 0.09210485219955444, 0.12370320409536362, -0.024112962186336517, -0.08130928874015808, -0.05252479761838913, 0.13000911474227905, 0.17019885778427124, 0.13797327876091003, 0.005850182846188545, -0.16243796050548553, -0.03424453362822533, -0.023659592494368553, 0.332891583442688, 0.0021173283457756042, 0.06294301897287369, -0.3099365830421448, 0.2125529795885086, 0.0670425221323967, 0.01794932782649994, -0.01564336009323597, 0.15104757249355316, 0.026846513152122498, -0.11467166990041733, -0.16062766313552856, -0.008904796093702316, 0.2472052276134491, 0.14941571652889252, -0.14336717128753662, -0.0474875383079052, 0.022629493847489357, 0.043266892433166504, -0.11785262823104858, 0.7521278262138367, 0.0023015476763248444, -0.13927781581878662, -0.07951948046684265, 0.10987243801355362, -0.14803074300289154, -0.0917365625500679, 0.09290768951177597, 0.01185966283082962, -0.1835382580757141, -0.12200754880905151, 0.08121014386415482, -0.1992967277765274, 0.13080470263957977, -0.20277372002601624, -0.025698266923427582, -0.10249563306570053, 0.0532715767621994, 0.16297924518585205, -0.3343513011932373, 0.058987557888031006, 0.1238107830286026, 0.12833775579929352, -0.11993405222892761, 0.11858583986759186, -0.0009658858180046082, 0.20207957923412323, 0.09975779056549072, -0.11352461576461792, 0.4505535364151001, 0.2798716127872467, -0.11140313744544983, -0.1520804911851883, 0.014625165611505508, 0.1361844837665558, -0.02500903606414795, -0.042276859283447266, -0.0536951869726181, -0.12383407354354858, -0.023282332345843315, 
0.0720507949590683, 0.2816048264503479, 0.14635038375854492, 0.19177396595478058, 0.025820191949605942, -0.02209986373782158, 0.030616868287324905, -0.1328410506248474, -0.7023516893386841, -0.034257154911756516, -0.014234527945518494, -0.01196136325597763, -0.016605868935585022, -0.04308260977268219, -0.016224127262830734, -0.12332071363925934, 0.2814399302005768, -0.0860934779047966, -0.22393254935741425, -0.038413096219301224, -0.1344522088766098, 0.16330277919769287, -0.09319069236516953, 0.10578152537345886, 0.24894562363624573, 0.17435382306575775, 0.013369807042181492, 0.1548963189125061, 0.06639961898326874, 0.33721473813056946, -0.22020190954208374, -0.0263831689953804, -0.10054904222488403, 0.17592252790927887, 0.0030411742627620697, 0.1440598964691162, 0.10014630854129791, -0.26280009746551514, 0.08357219398021698, -0.02680831402540207, 0.12832222878932953, 0.05822794884443283, 0.11399063467979431, -0.11243091523647308, 0.051253825426101685, 0.07304824888706207, -0.17112381756305695, -0.05557689443230629, -0.0382172130048275, 0.07970664650201797, 0.13745081424713135, -0.11511969566345215, 0.34776973724365234, 0.0013090521097183228, -0.10678171366453171, -0.10754184424877167, -0.03680908679962158, -0.11517737805843353, -0.06648803502321243, 0.0385674424469471, -0.006416200660169125, 0.009098835289478302, -0.012128714472055435, -0.040153618901968, 0.00547989085316658, -0.08469443768262863, 0.2498491257429123, 0.07882604748010635, 0.011398494243621826, -0.16152513027191162, 0.12611845135688782, 0.45807787775993347, -0.0687701553106308, -0.0454833060503006, 0.14332722127437592, -0.2586124837398529, -0.07269644737243652, 0.028274694457650185, 0.014544717967510223, -0.06326843053102493, 0.04953804612159729, 0.026799358427524567, -0.04919631406664848, 0.02290835976600647, -0.007035437971353531, 0.1975128948688507, -0.06952586770057678, 0.13582096993923187, 0.04226251319050789, -0.18306797742843628, -0.17523817718029022, 0.2390468567609787, -0.08613499999046326, 
-0.014560339972376823, 0.03554582968354225, -0.21383002400398254, -0.08761914819478989, 0.09325301647186279, -0.16551825404167175, 0.012385450303554535, 0.39200371503829956, 0.026104405522346497, 0.041204825043678284, -0.0811271220445633, 0.018162909895181656, -0.10059411823749542, 0.11126667261123657, -0.058525439351797104, -0.01384500041604042, 0.19287627935409546, 0.06926106661558151, -0.02889631688594818, -0.24038714170455933, 0.05199429392814636, 0.08908367902040482, -0.04986542835831642, 0.037662290036678314, 0.026423979550600052, 0.009480863809585571, 0.2232363522052765, -0.008996419608592987, 0.1566999852657318, -0.016166921705007553, 0.1188608705997467, -0.11231216788291931, 0.14666584134101868, 2.51878023147583, 0.15585148334503174, 0.0554804801940918, -0.1365218162536621, -0.09763389825820923, 0.045212697237730026, 0.1139029860496521, 0.07191610336303711, -0.05138470605015755, 0.13530945777893066, -0.2148897349834442, 0.0018538348376750946, -0.1255156248807907, -0.14137017726898193, -0.08672071993350983, -0.05282869189977646, 0.002223718911409378, 0.13166718184947968, -0.032217759639024734, 0.09405176341533661, -0.016005437821149826, 0.11101282387971878, 0.17081227898597717, -0.24205616116523743, 0.12085014581680298, -0.017135102301836014, -0.15541934967041016, -0.2158358097076416, 0.13087555766105652, -0.2549808621406555, 0.10793474316596985, 0.1406404674053192, -0.07237599790096283, 0.2281767874956131, 0.132211372256279, 0.03219960629940033, 0.06858889013528824, 0.15230831503868103, 0.09438962489366531, 0.18310034275054932, 0.05054670572280884, 0.07559661567211151, 0.03571196272969246, 0.04292334243655205, 0.09503374993801117, 0.01143386960029602, 0.052276287227869034, -0.1682429015636444, -0.04236971586942673, 0.22122381627559662, -0.06667859852313995, -0.06371711194515228, -0.03640083223581314, -0.003083576448261738, -0.08543500304222107, -0.32224974036216736, -0.05161238834261894, 0.11770026385784149, -0.06602110713720322, 0.25934213399887085, 
0.21215009689331055, 0.04266522079706192, -0.0799938440322876, 0.06520034372806549, 0.14988946914672852, -0.1145656406879425, -0.10210753977298737, 0.017618805170059204, -0.031066114082932472, 0.27783724665641785, 0.03729860484600067, 0.13934554159641266, -0.033291399478912354, -0.04513583332300186, 0.05115635320544243, 0.26040977239608765, -0.001785092055797577, -0.05910230055451393, 0.03586268424987793, -3.4086198806762695, -0.02737981267273426, 0.09774677455425262, -0.017572371289134026, 0.03859999030828476, -0.09794627130031586, 0.2378385365009308, -0.06944010406732559, -0.004423182457685471, -0.04711540415883064, 0.19552381336688995, 0.12733332812786102, 0.09700190275907516, 0.030968017876148224, 0.056188128888607025, 0.02051575481891632, 0.22856026887893677, -0.299394816160202, 0.03377850353717804, 0.09814099222421646, -0.21119776368141174, -0.06467742472887039, 0.030762087553739548, 0.09838669002056122, 0.021888965740799904, 0.00665653683245182, -0.06396031379699707, 0.09961600601673126, -0.011195443570613861, -0.13691335916519165, 0.12073533982038498, 0.03988261520862579, -0.035032521933317184, 0.02244553714990616, 0.13529717922210693, 0.09776811301708221, 0.04201053828001022, 0.11066734790802002, 0.02868419699370861, -0.13093703985214233, 0.04278012737631798, 0.1975778490304947, -0.0999005138874054, 0.019111257046461105, -0.10779550671577454, 0.042171910405159, 0.15418961644172668, 0.18657580018043518, -0.14580248296260834, -0.04038548469543457, -0.2212972640991211, -0.005243539810180664, -0.1287483274936676, -0.024056918919086456, 0.07094519585371017, -0.02272043190896511, 0.0674278736114502, -0.11763018369674683, 0.0663069486618042, -0.011795202270150185, -0.06854239106178284, -0.158526211977005, -0.12321586161851883, -0.012642555870115757, 0.17990009486675262, 0.07283437997102737, 0.007345527410507202, 0.01567596197128296, 0.10693997144699097, -0.24579988420009613, -0.11011457443237305, 0.05615675449371338, -0.06260638684034348, -0.03848991543054581, 
0.191745787858963, -0.04179578274488449, 0.13549813628196716, 0.327109694480896, 0.17386959493160248, -0.14949959516525269, -0.29855877161026, -0.24083510041236877, 0.027013003826141357, -0.053055278956890106, -0.09138717502355576, -8.004268646240234, -0.005018599331378937, -0.07175980508327484, -0.2903030514717102, -0.15034055709838867, 0.23229360580444336, 0.08595450222492218, -0.10982118546962738, 0.16416652500629425, -0.07101142406463623, 0.23411497473716736, -0.017857130616903305, 0.08447383344173431, 0.006400817073881626, 0.2419111728668213, 0.03781512379646301]

In [26]:
# Turn the pasted Vespa float list into a NumPy array and display it.
vespa_embedding_tensor = np.asarray(vespa_float_cls_tensor)
vespa_embedding_tensor

array([-9.26039144e-02,  9.58912522e-02,  2.12424740e-01, -4.61409874e-02,
       -9.81633216e-02,  2.22950615e-03,  2.59776235e-01,  1.93624437e-01,
       -5.77840582e-02,  1.46921068e-01,  9.79556590e-02, -5.91678731e-03,
       -8.09065923e-02,  1.46981925e-01,  5.71487620e-02, -1.35080203e-01,
        6.13485761e-02,  6.45458251e-02, -2.05644295e-02,  1.33272968e-02,
        1.54289529e-02, -1.62017629e-01,  1.51649460e-01, -1.95141360e-01,
        2.17220187e-01,  5.04310615e-02,  1.04260877e-01,  2.14705050e-01,
        4.05109823e-02,  2.71418512e-01, -1.31713748e-01, -1.95335001e-02,
        5.98450303e-02, -2.45894864e-02, -3.56327854e-02,  4.48757671e-02,
        1.34197131e-01,  8.20754915e-02, -3.49488519e-02, -9.69859362e-02,
        3.26188117e-01, -1.01161106e-02, -5.78533411e-02,  9.32134464e-02,
       -1.15211472e-01,  1.08821437e-01, -1.01014829e+00,  2.62370914e-01,
       -4.89117205e-02,  9.05596465e-03, -7.10878745e-02,  1.62915379e-01,
       -4.10953239e-02,  

In [23]:
# Apply the same binarize-and-pack step to the Vespa embedding:
# positive -> 1, otherwise 0; eight bits per byte; cast to signed int8.
vespa_pack = np.packbits(vespa_embedding_tensor > 0).astype(np.int8)

In [24]:
# Display the packed int8 encoding derived from the Vespa embedding.
vespa_pack

array([ 103,  102,  -38,   -4, -100, -107,   86,  -52,   86,  -43,  -74,
        -34,  -82,  -79,  -70,  -53,   74,  -75,  -83,  126,  -86,   36,
         87,   -5,  115,   38,  -92,   -2,   49,  -22,   59,   69,  -80,
        -54,  -67,   -6, -119,  -27,  -42,   82,  -11,  -91,   46,  -13,
       -125,  -67,  105,   48, -121,   30,   59,  -79, -101,   38,   81,
        -70,  -39, -125,  -24,    2,   23,  -15,  -21,  -79,  -80,   82,
        -20, -101,   86,   73,  122, -103,  -67,   94,  117,   13,  104,
        -73,   -1,  -28,    5,  -39,  115,   37,   71,  -10,  122,  111,
        -38,  -32,   84,   31,   37,  -60,    6,  -81], dtype=int8)

In [25]:
# Re-display the ONNX Runtime encoding for side-by-side reading with vespa_pack.
onnx_rt_packed_optimized

array([ 103,  118,  -54,   -4, -100,  -11,   86,  -40,   86,  -41,  -26,
         -2,  -82,  -80, -102,  -53,   74,  -75,  -83,  122,   58,  100,
         87,   -5,  115,   39,  -96,   -2,   49,  -22,   59,   69,  -76,
         74,   61,   -2, -119,  -28,  -58,   17,  -12,  -89,   44,  -13,
        -62,  -67,  103,   32,   71,   60,  123,   51, -118,   38,   81,
        -70,  -40,  -95,  -48, -126,   23,  113,  -18,  -79,  -80,   72,
        -18, -102,  118,   64,   58,  -75,  -91,   94,  116, -123,  104,
        -73,   -5,  -89, -115,  -39,  123,  -27,   71,  -14,   50,  126,
        -98,  -30,   85,   31,   45,  -62,   70,  -86], dtype=int8)

In [27]:
# Byte-by-byte comparison of the Vespa-derived encoding with the ONNX Runtime
# one. Each element covers eight packed bits, so a single differing bit makes
# the whole byte compare as False.
np.equal(vespa_pack, onnx_rt_packed_optimized)

array([ True, False, False,  True,  True, False,  True, False,  True,
       False, False, False,  True, False, False,  True,  True,  True,
        True, False, False, False,  True,  True,  True, False, False,
        True,  True,  True,  True,  True, False, False, False, False,
        True, False, False, False, False, False, False,  True, False,
        True, False, False, False, False, False, False, False,  True,
        True,  True, False, False, False, False,  True, False, False,
        True,  True, False, False, False, False, False, False, False,
       False,  True, False, False,  True,  True, False, False, False,
        True, False, False,  True, False, False, False, False, False,
       False,  True, False, False, False, False])