# Text search app with BERT models from Python

> Introducing the pyvespa ML API: cutting-edge Vespa search in a few lines of code.

- toc: true 
- badges: false
- comments: true
- categories: [vespa, pyvespa, BERT, transformers]

## Define the application

Start with a basic text search app that indexes the `title` field and ranks matches with BM25.

In [1]:
from vespa.package import ApplicationPackage, Field, FieldSet, RankProfile

app_package = ApplicationPackage(name="cord19")
app_package.schema.add_fields(
    Field(name = "cord_uid", type = "string", indexing = ["attribute", "summary"]),
    Field(name = "title", type = "string", indexing = ["index", "summary"], index = "enable-bm25")
)
app_package.schema.add_field_set(
    FieldSet(name = "default", fields = ["title"])
)
app_package.schema.add_rank_profile(
    RankProfile(name = "bm25", first_phase = "bm25(title)")
)

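Define your Vespa model configuration. `query_input_size` and `doc_input_size` set how many token positions are reserved for the query and for the document.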

In [2]:
from vespa.ml import BertModelConfig

bert_config = BertModelConfig(
    model_id="pretrained_bert_tiny",
    tokenizer="google/bert_uncased_L-2_H-128_A-2",
    model="google/bert_uncased_L-2_H-128_A-2",    
    query_input_size=32,
    doc_input_size=96
)

Some weights of the model checkpoint at google/bert_uncased_L-2_H-128_A-2 were not used when initializing BertForSequenceClassification: ['cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.predictions.decoder.bias', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification w
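
The two sizes in `BertModelConfig` add up to the sequence length the model sees (32 + 96 = 128 here). As a rough sketch of how a query and a document could be packed into the three inputs the exported model expects (`input_ids`, `token_type_ids`, `attention_mask`), assuming the usual BERT cross-encoder layout `[CLS] query [SEP] document [SEP]` and the special token ids 101 and 102 of the uncased BERT vocabulary. The helper `build_bert_inputs` is hypothetical and not part of pyvespa:

```python
from typing import Dict, List

CLS_ID = 101  # assumed [CLS] id in the uncased BERT vocabulary
SEP_ID = 102  # assumed [SEP] id in the uncased BERT vocabulary
PAD_ID = 0    # padding id


def build_bert_inputs(
    query_ids: List[int],
    doc_ids: List[int],
    query_input_size: int = 32,
    doc_input_size: int = 96,
) -> Dict[str, List[int]]:
    """Pack query and document token ids into fixed-length BERT inputs."""
    total = query_input_size + doc_input_size
    # Drop padding and truncate, leaving room for the special tokens.
    query = [t for t in query_ids if t != PAD_ID][: query_input_size - 2]
    doc = [t for t in doc_ids if t != PAD_ID][: doc_input_size - 1]
    first = [CLS_ID] + query + [SEP_ID]   # segment 0
    second = doc + [SEP_ID]               # segment 1
    used = len(first) + len(second)
    padding = total - used
    return {
        "input_ids": first + second + [PAD_ID] * padding,
        "token_type_ids": [0] * len(first) + [1] * len(second) + [0] * padding,
        "attention_mask": [1] * used + [0] * padding,
    }


inputs = build_bert_inputs([2023, 2003, 1037, 3231], [12137, 1024, 1037])
print(len(inputs["input_ids"]), inputs["input_ids"][:10])
```

The point of the fixed sizes is that every query and document yields tensors of the same shape, whatever the text length.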

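Create your model-based rank profile. BM25 on the title is the first-phase ranking, and the BERT model reranks the top 10 hits in the second phase.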

In [3]:
from vespa.package import SecondPhaseRanking

app_package.add_model_ranking(
    model_config=bert_config,
    inherits="default",
    first_phase="bm25(title)",
    second_phase=SecondPhaseRanking(
        rerank_count=10, expression="logit1"
    ),
)

Using framework PyTorch: 1.7.1
Found input input_ids with shape: {0: 'batch', 1: 'sequence'}
Found input token_type_ids with shape: {0: 'batch', 1: 'sequence'}
Found input attention_mask with shape: {0: 'batch', 1: 'sequence'}
Found output output_0 with shape: {0: 'batch'}
Ensuring inputs are in correct order
position_ids is not present in the generated input list.
Generated inputs order: ['input_ids', 'attention_mask', 'token_type_ids']


  position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]
  assert all(
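
The idea behind `first_phase` plus `SecondPhaseRanking(rerank_count=10, ...)` is to score every match with a cheap function and run the expensive model only on the best few. A minimal sketch of that pattern in plain Python follows. It is a simplification under stated assumptions (a single node, reranked hits always kept above the rest), not Vespa's implementation, and `two_phase_rank` is a hypothetical helper:

```python
from typing import Callable, List, Sequence, TypeVar

T = TypeVar("T")


def two_phase_rank(
    docs: Sequence[T],
    first_phase: Callable[[T], float],
    second_phase: Callable[[T], float],
    rerank_count: int = 10,
) -> List[T]:
    """Order docs by a cheap score, then reorder only the top rerank_count."""
    # Phase 1: cheap score for every matched document.
    ranked = sorted(docs, key=first_phase, reverse=True)
    # Phase 2: expensive score, computed only for the head of the list.
    head = sorted(ranked[:rerank_count], key=second_phase, reverse=True)
    # Assumption: reranked documents stay ahead of the untouched tail.
    return head + ranked[rerank_count:]


# Toy scorers standing in for bm25(title) and the BERT logit.
docs = ["aaa", "bb", "cccc", "d"]
print(two_phase_rank(docs, first_phase=len,
                     second_phase=lambda d: d.count("a"), rerank_count=2))
```

With `rerank_count=10`, the BERT model is evaluated for at most ten documents per query under this simplified view, which keeps the cost of the second phase bounded.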


## Deploy your application

In [6]:
from vespa.package import VespaDocker

vespa_docker = VespaDocker(
    port=8080,
    disk_folder="/Users/username/sample_application",  # must be an absolute path
    container_memory="8G"
)

app = vespa_docker.deploy(
    application_package = app_package,
)

Waiting for configuration server.
Waiting for application status.
Waiting for application status.
Finished deployment.


## Feed some data

In [20]:
from pandas import read_csv

parsed_feed = read_csv("https://thigm85.github.io/data/cord19/parsed_feed_100.csv")
parsed_feed = parsed_feed.head(100)

In [21]:
parsed_feed

| Unnamed: 0 | cord_uid | title | abstract |
|---|---|---|---|
| 0 | ug7v899j | Clinical features of culture-proven Mycoplasma... | OBJECTIVE: This retrospective chart review des... |
| 1 | 02tnwd4m | Nitric oxide: a pro-inflammatory mediator in l... | Inflammatory diseases of the respiratory tract... |
| 2 | ejv2xln0 | Surfactant protein-D and pulmonary host defense | Surfactant protein-D (SP-D) participates in th... |
| 3 | 2b73a28n | Role of endothelin-1 in lung disease | Endothelin-1 (ET-1) is a 21 amino acid peptide... |
| 4 | 9785vg6d | Gene expression in epithelial cells in respons... | Respiratory syncytial virus (RSV) and pneumoni... |
| ... | ... | ... | ... |
| 95 | 63bos83o | Global Surveillance of Emerging Influenza Viru... | BACKGROUND: Effective influenza surveillance r... |
| 96 | hqc7u9w3 | Transmission Parameters of the 2001 Foot and M... | Despite intensive ongoing research, key aspect... |
| 97 | 87zt7lew | Efficient replication of pneumonia virus of mi... | Pneumonia virus of mice (PVM; family Paramyxov... |
| 98 | wgxt36jv | Designing and conducting tabletop exercises to... | BACKGROUND: Since 2001, state and local health... |


In [8]:
for idx, row in parsed_feed.iterrows():
    fields = {
        "cord_uid": str(row["cord_uid"]),
        "title": str(row["title"]),
    }
    fields.update(
        bert_config.doc_fields(text = str(row["title"]))
    )
    response = app.feed_data_point(
        schema = "cord19",
        data_id = str(row["cord_uid"]),
        fields = fields,
    )

In [9]:
response.json()

{'pathId': '/document/v1/cord19/cord19/docid/qbldmef1',
 'id': 'id:cord19:cord19::qbldmef1'}
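
The response shows how the `data_id` you pass ends up in both the document id and the `/document/v1` path. A small sketch that rebuilds both strings, assuming the namespace equals the schema name (as the response above suggests). `document_id` and `document_path` are hypothetical helpers, not pyvespa functions:

```python
from urllib.parse import quote


def document_id(namespace: str, schema: str, data_id: str) -> str:
    """Vespa document id of the form id:<namespace>:<schema>::<data_id>."""
    return f"id:{namespace}:{schema}::{data_id}"


def document_path(namespace: str, schema: str, data_id: str) -> str:
    """Path of the document in the /document/v1 REST API."""
    # Percent-encode the user-supplied id so it is safe inside a URL path.
    return f"/document/v1/{namespace}/{schema}/docid/{quote(data_id, safe='')}"


print(document_id("cord19", "cord19", "qbldmef1"))
print(document_path("cord19", "cord19", "qbldmef1"))
```

This can be handy when you want to look up or delete a single document by its `cord_uid` later on.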

## Query your application

In [15]:
from vespa.query import QueryModel, RankProfile as Ranking, OR, QueryRankingFeature

result = app.query(
    query="this is a test", 
    query_model=QueryModel(
        query_properties=[
            QueryRankingFeature(  # this part makes sure you send the right query vector
                name=bert_config.query_token_ids_name,
                mapping=bert_config.query_tensor_mapping,
            )
        ],
        match_phase = OR(),
        rank_profile = Ranking(name="pretrained_bert_tiny")
    )
    
)

In [16]:
result.json

{'root': {'id': 'toplevel',
  'relevance': 1.0,
  'fields': {'totalCount': 26},
  'coverage': {'coverage': 100,
   'documents': 100,
   'full': True,
   'nodes': 1,
   'results': 1,
   'resultsFull': 1},
  'children': [{'id': 'id:cord19:cord19::l3z27806',
    'relevance': -0.011664621531963348,
    'source': 'cord19_content',
    'fields': {'sddocname': 'cord19',
     'documentid': 'id:cord19:cord19::l3z27806',
     'cord_uid': 'l3z27806',
     'title': 'GIDEON: a comprehensive Web-based resource for geographic medicine',
     'pretrained_bert_tiny_doc_token_ids': {'cells': [{'address': {'d0': '0'},
        'value': 12137.0},
       {'address': {'d0': '1'}, 'value': 1024.0},
       {'address': {'d0': '2'}, 'value': 1037.0},
       {'address': {'d0': '3'}, 'value': 7721.0},
       {'address': {'d0': '4'}, 'value': 4773.0},
       {'address': {'d0': '5'}, 'value': 1011.0},
       {'address': {'d0': '6'}, 'value': 2241.0},
       {'address': {'d0': '7'}, 'value': 7692.0},
       {'address
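
The raw response is verbose because every hit carries its document token ids as a tensor. A minimal sketch for pulling out the interesting parts, assuming the JSON shape shown above. `summarize_hits` and `cells_to_ids` are hypothetical helpers, and `sample` is a hand-built excerpt of the response, not real output:

```python
from typing import Any, Dict, List, Tuple


def summarize_hits(result_json: Dict[str, Any]) -> List[Tuple[str, str, float]]:
    """Return (cord_uid, title, relevance) for every hit in the response."""
    hits = result_json.get("root", {}).get("children", [])
    return [
        (hit["fields"]["cord_uid"], hit["fields"]["title"], hit["relevance"])
        for hit in hits
    ]


def cells_to_ids(tensor: Dict[str, Any]) -> List[int]:
    """Flatten a one-dimensional tensor in 'cells' form into a list of ints."""
    cells = sorted(tensor["cells"], key=lambda cell: int(cell["address"]["d0"]))
    return [int(cell["value"]) for cell in cells]


# Hand-built sample shaped like the (truncated) response above.
sample = {
    "root": {
        "children": [
            {
                "relevance": -0.011664621531963348,
                "fields": {
                    "cord_uid": "l3z27806",
                    "title": "GIDEON: a comprehensive Web-based resource for geographic medicine",
                    "pretrained_bert_tiny_doc_token_ids": {
                        "cells": [
                            {"address": {"d0": "1"}, "value": 1024.0},
                            {"address": {"d0": "0"}, "value": 12137.0},
                            {"address": {"d0": "2"}, "value": 1037.0},
                        ]
                    },
                },
            }
        ]
    }
}

print(summarize_hits(sample))
fields = sample["root"]["children"][0]["fields"]
print(cells_to_ids(fields["pretrained_bert_tiny_doc_token_ids"]))
```

Against the running app, the same function could be called as `summarize_hits(result.json)`.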

You can inspect what is sent to Vespa behind the scenes by setting `debug_request=True` and then reading the `request_body` attribute of the returned result.

In [17]:
debug_result = app.query(
    debug_request=True,
    query="this is a test", 
    query_model=QueryModel(
        query_properties=[
            QueryRankingFeature(  # this part makes sure you send the right query vector
                name=bert_config.query_token_ids_name,
                mapping=bert_config.query_tensor_mapping,
            )
        ],
        match_phase = OR(),
        rank_profile = Ranking(name="pretrained_bert_tiny")
    )
    
)

In [18]:
debug_result.request_body

{'yql': 'select * from sources * where ([{"grammar": "any"}]userInput("this is a test"));',
 'ranking': {'profile': 'pretrained_bert_tiny', 'listFeatures': 'false'},
 'ranking.features.query(pretrained_bert_tiny_query_token_ids)': '[2023, 2003, 1037, 3231, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]'}
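
The request body shows the query token ids padded with zeros to a fixed length. A rough sketch of how such a body could be put together by hand, assuming two of the `query_input_size=32` positions are reserved for special tokens so that 30 ids are sent. `pad_query_ids` and `build_request_body` are hypothetical helpers, and the token ids are passed in rather than computed by a tokenizer:

```python
from typing import Dict, List


def pad_query_ids(
    token_ids: List[int], query_input_size: int = 32, reserved: int = 2, pad_id: int = 0
) -> List[int]:
    """Truncate or zero-pad token ids to query_input_size - reserved entries."""
    target = query_input_size - reserved
    trimmed = list(token_ids)[:target]
    return trimmed + [pad_id] * (target - len(trimmed))


def build_request_body(
    query: str, token_ids: List[int], rank_profile: str = "pretrained_bert_tiny"
) -> Dict[str, object]:
    """Assemble a body shaped like the debug output above."""
    feature = f"ranking.features.query({rank_profile}_query_token_ids)"
    return {
        # OR semantics: match documents containing any of the query terms.
        "yql": f'select * from sources * where ([{{"grammar": "any"}}]userInput("{query}"));',
        "ranking": {"profile": rank_profile, "listFeatures": "false"},
        feature: str(pad_query_ids(token_ids)),
    }


body = build_request_body("this is a test", [2023, 2003, 1037, 3231])
for key, value in body.items():
    print(key, "->", value)
```

In normal use `QueryRankingFeature` and `bert_config.query_tensor_mapping` do this for you; the sketch only makes the shape of the request easier to follow.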