In [1]:
import numpy as np
from texttable import Texttable

from dataset_handler import *
from semantic_search import *

In [2]:
# Relative locations of the dataset files and of the stored search indices.
# NOTE(review): `os` is not imported explicitly in this notebook's import cell;
# presumably it is exposed by one of the star imports — confirm, or add `import os`.
dataset_dir = os.path.join('data', 'dataset')
indices_dir = os.path.join('data', 'indices')

In [3]:
def get_search_results(query_embeddings, index):
    """Look up the single nearest indexed item for every query embedding.

    Args:
        query_embeddings: Batch of query vectors, passed unchanged to
            ``index.knnQueryBatch``.
        index: Object exposing ``knnQueryBatch(queries, k=..., num_threads=...)``
            that yields one ``(positions, scores)`` pair per query.

    Returns:
        A pair of lists ``(output_ids, match_scores)``, one entry per query:
        the top hit's position plus one, and the score reported for that hit.

    Raises:
        IndexError: If the index returns no hit for some query.
    """
    output_ids = []
    match_scores = []

    # Only the best hit is requested (k=1); the lookup runs on two threads.
    for result in index.knnQueryBatch(query_embeddings, k=1, num_threads=2):
        hit_positions = result[0]
        hit_scores = result[1]
        # Positions are shifted by one — presumably index positions are 0-based
        # while clause ids start at 1; verify against how the index was built.
        output_ids.append(hit_positions[0] + 1)
        # NOTE(review): for a cosine-distance index this score would be a distance
        # (lower = closer), not a similarity — confirm.
        match_scores.append(hit_scores[0])

    return output_ids, match_scores

In [4]:
def get_accuracy(true_ids, output_ids):
    """Return the fraction of positions where ``output_ids`` equals ``true_ids``.

    Args:
        true_ids: Sequence of expected ids.
        output_ids: Sequence of predicted ids; must have the same shape as
            ``true_ids``.

    Returns:
        The number of matching positions divided by ``len(true_ids)``.
        For empty input the accuracy is undefined and ``nan`` is returned.

    Raises:
        ValueError: If the two inputs do not have the same shape.
    """
    n = len(true_ids)
    expected = np.asarray(true_ids)
    predicted = np.asarray(output_ids)

    # Without this check NumPy would broadcast (e.g. a length-1 prediction list
    # against n labels) or collapse the comparison to a single False, silently
    # producing a meaningless accuracy.
    if expected.shape != predicted.shape:
        raise ValueError(
            'true_ids and output_ids must have the same shape, got '
            f'{expected.shape} and {predicted.shape}'
        )

    # Avoid the 0/0 division (and its RuntimeWarning) when there is nothing to score.
    if n == 0:
        return float('nan')

    acc = np.sum(expected == predicted) / n

    return acc

In [5]:
def print_result(clause_ids, output_ids, match_scores):
    """Print a three-column text table of expected ids, retrieved ids and scores.

    Args:
        clause_ids: Expected clause id for each query.
        output_ids: Retrieved id for each query.
        match_scores: Score reported for each retrieved id.

    The three sequences are zipped together, so the table has as many data
    rows as the shortest of them. Nothing is returned; the table is written
    to standard output.
    """
    column_titles = ['clause_id', 'output_id', 'match_score']
    body_rows = [row for row in zip(clause_ids, output_ids, match_scores)]

    report = Texttable()
    # The titles are inserted as an ordinary first row rather than as a
    # Texttable header, and the data rows are added with header=False.
    report.add_row(column_titles)
    report.add_rows(body_rows, header=False)

    print(report.draw())

In [6]:
# Load the dataset from dataset_dir. `load_dataset` is not defined in this notebook;
# presumably it comes from the `dataset_handler` star import — confirm.
clauses_dict, query_clauses = load_dataset(dataset_dir=dataset_dir)
# Show the size and the first rows of the query table as a quick sanity check.
print(query_clauses.shape)
query_clauses.head()

(49, 4)


Unnamed: 0,id,clause_id,query,clause
0,100,1,Remove any major changes to house before leaving.,The tenant shall at the termination of this ag...
1,200,2,Take license if you're carrying out business.,The tenant shall himself obtain the license fo...
2,201,2,Make sure you're having a license if doing som...,The tenant shall himself obtain the license fo...
3,300,3,I won't provide any insurance or security cover.,All kinds of security arrangements insurances ...
4,400,4,Stay good with neighbors.,The tenant shall keep good relationship with n...


In [7]:
# Pull the query texts and their labelled clause ids out as two parallel lists.
queries = query_clauses['query'].tolist()
clause_ids = query_clauses['clause_id'].tolist()
# Both lengths are displayed so a mismatch is visible immediately.
len(queries), len(clause_ids)

(49, 49)

In [8]:
# Load the clauses that will be indexed and searched. `load_clauses` is not defined
# in this notebook; presumably it comes from a star import — confirm.
clauses = load_clauses(dataset_dir=dataset_dir)
len(clauses)

45

In [9]:
# Load the tokenizer and model used to embed both clauses and queries.
# `load_models` is not defined in this notebook; presumably it comes from the
# `semantic_search` star import — confirm.
tokenizer, model = load_models()
# Displaying the pair prints both reprs, which is a very long output.
tokenizer, model

(PreTrainedTokenizerFast(name_or_path='sentence-transformers/roberta-base-nli-stsb-mean-tokens', vocab_size=50262, model_max_len=1000000000000000019884624838656, is_fast=True, padding_side='right', special_tokens={'bos_token': '<s>', 'eos_token': '</s>', 'unk_token': '<unk>', 'sep_token': '</s>', 'pad_token': '<pad>', 'cls_token': '<s>', 'mask_token': AddedToken("<mask>", rstrip=False, lstrip=True, single_word=False, normalized=False)}),
 RobertaModel(
   (embeddings): RobertaEmbeddings(
     (word_embeddings): Embedding(50265, 768, padding_idx=1)
     (position_embeddings): Embedding(514, 768, padding_idx=1)
     (token_type_embeddings): Embedding(1, 768)
     (LayerNorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
     (dropout): Dropout(p=0.1, inplace=False)
   )
   (encoder): RobertaEncoder(
     (layer): ModuleList(
       (0): RobertaLayer(
         (attention): RobertaAttention(
           (self): RobertaSelfAttention(
             (query): Linear(in_features=768, out

In [10]:
# Embed every clause with the loaded tokenizer and model; these are the vectors that get indexed.
clause_embeddings = get_embeddings(clauses, tokenizer, model)
clause_embeddings.shape

torch.Size([45, 768])

In [11]:
# Create an index from the clause embeddings under indices_dir, then load it from
# the same path. NOTE(review): judging by the name, create_and_store_index writes
# the index to disk on every run — confirm.
create_and_store_index(clause_embeddings, name=os.path.join(indices_dir, 'roberta_base_nli_stsb'))
index = load_index(name=os.path.join(indices_dir, 'roberta_base_nli_stsb'))
index

<nmslib.FloatIndex method='hnsw' space='cosinesimil' at 0x55e9dcd47670>

In [12]:
# Embed the queries with the same tokenizer and model that were used for the clauses.
query_embeddings = get_embeddings(queries, tokenizer, model)
query_embeddings.shape

torch.Size([49, 768])

In [13]:
# Retrieve the top-1 result for every query and print it next to the labelled clause id.
output_ids, match_scores = get_search_results(query_embeddings, index)
print_result(clause_ids, output_ids, match_scores)

+-----------+-----------+-------------+
| clause_id | output_id | match_score |
+-----------+-----------+-------------+
| 1         | 44        | 0.397       |
+-----------+-----------+-------------+
| 2         | 2         | 0.365       |
+-----------+-----------+-------------+
| 2         | 2         | 0.423       |
+-----------+-----------+-------------+
| 3         | 12        | 0.474       |
+-----------+-----------+-------------+
| 4         | 4         | 0.100       |
+-----------+-----------+-------------+
| 5         | 5         | 0.272       |
+-----------+-----------+-------------+
| 6         | 6         | 0.268       |
+-----------+-----------+-------------+
| 7         | 36        | 0.253       |
+-----------+-----------+-------------+
| 8         | 8         | 0.478       |
+-----------+-----------+-------------+
| 9         | 9         | 0.325       |
+-----------+-----------+-------------+
| 10        | 10        | 0.184       |
+-----------+-----------+-------------+


In [14]:
# Fraction of queries whose top-1 retrieved id equals the labelled clause id.
get_accuracy(clause_ids, output_ids)

0.7346938775510204