In [None]:
import re
import json
import numpy as np
import pandas as pd
import itertools

from sentence_transformers import SentenceTransformer
from sentence_transformers import util
from functools import reduce
from underthesea import sent_tokenize
from tqdm.notebook import tqdm
from functools import reduce

In [None]:
# Name of the sentence-embedding checkpoint passed to SentenceTransformer when
# the model is loaded.
# Alternative checkpoint, currently disabled:
# MODEL_NAME = "keepitreal/vietnamese-sbert"
MODEL_NAME = "Oztobuzz/sbert_mnr_5_epoch_v1"

In [None]:
# Load the embedding model once; the embed_* helpers encode with this
# module-level instance.
# No explicit `.cuda()` call: when no device is given, SentenceTransformer
# selects a GPU if one is available and otherwise stays on CPU, so the
# notebook no longer crashes on machines without CUDA.
model = SentenceTransformer(MODEL_NAME)

In [None]:
# Path of the input JSON file. Despite the name, this is a file path (it is
# passed straight to open()), not a directory.
INPUT_DIR = '../data/private/ise-dsc01-private-test-offcial.json'

In [None]:
# Read the input file: a JSON object mapping each sample id to a record from
# which the 'context' and 'claim' fields are used.
# The file is opened in a `with` block so the handle is closed, and decoded
# explicitly as UTF-8 rather than with the platform default encoding.
with open(INPUT_DIR, encoding='utf-8') as infile:
    dict_data = json.load(infile)

# (record, sample id) pairs in file order; `index`, `context` and `claim`
# below are parallel lists derived from these pairs.
data = [(record, key) for key, record in dict_data.items()]
index = [key for _, key in data]
context = [record['context'] for record, _ in data]
claim = [record['claim'] for record, _ in data]

In [None]:
def embed_claim(claim):
    """
    Encode a list of claims with the module-level embedding model.

    Parameters
    ----------
    claim: List[str]
        Claim texts to encode

    Returns
    ----------
    np.array
        Output of ``model.encode`` for the given claims
    """
    encode_options = {'batch_size': 64, 'show_progress_bar': True}
    claim_vectors = model.encode(claim, **encode_options)
    return claim_vectors

In [None]:
def split_context(context, split_type):
    """
    Split every context into smaller pieces (paragraphs or sentences).

    Parameters
    ----------
    context: List[str]
    split_type: 'PARAGRAPH' or 'SENTENCE'
        'PARAGRAPH' splits each context on a blank line (two consecutive
        newline characters).
        'SENTENCE' first runs ``sent_tokenize`` on each context and then
        splits every resulting sentence again on runs of two or more dots
        (for example '...'). Pieces produced by that second split are kept
        as-is, including empty strings.

    Returns
    ----------
    List[List[str]]
        One list of pieces per input context, in the input order

    Raises
    ----------
    ValueError
        If split_type is neither 'PARAGRAPH' nor 'SENTENCE'
    """
    if split_type == 'PARAGRAPH':
        return [text.split('\n\n') for text in context]

    if split_type == 'SENTENCE':
        # Two or more consecutive dots act as an additional sentence boundary.
        dot_run = re.compile(r'[.][.]+')
        return [
            [piece for sent in sent_tokenize(text) for piece in dot_run.split(sent)]
            for text in context
        ]

    # ValueError (still an Exception subclass) identifies a bad argument value
    # more precisely than a bare Exception.
    raise ValueError("Please use PARAGRAPH or SENTENCE for split_type")

In [None]:
def embed_context(context):
    """
    Encode every piece of every context and regroup the result per context.

    All pieces are flattened into a single list so the model encodes them in
    one call; the flat result is then sliced back into one chunk per context.

    Parameters:
    ----------
    context: List[List[str]]
        For each sample, the list of its context pieces (paragraphs or sentences)
    Returns
    ---------
    context_embedding: List[np.array()]
        One slice of the encoder output per input context, covering exactly
        that context's pieces (first axis = number of pieces in the context)
    """
    # Remember how many pieces each context contributes before flattening.
    piece_counts = list(map(len, context))

    flat_texts = [piece for pieces in context for piece in pieces]
    flat_vectors = model.encode(flat_texts, batch_size=64, show_progress_bar=True)

    # Running totals of the piece counts give the slice boundaries.
    boundaries = [0, *itertools.accumulate(piece_counts)]
    return [flat_vectors[lo:hi] for lo, hi in zip(boundaries, boundaries[1:])]

In [None]:
def retrieve(claim_embedding, context_embedding, context, top_k=1, threshold=None):
    """
    Retrieve, for every claim, the best matching pieces of its own context.

    The three inputs are parallel: claim i is searched only against
    context_embedding[i], and hits are mapped back to text through context[i].

    Parameters
    ----------
    claim_embedding: sequence of embeddings
        One embedding per claim
    context_embedding: List[np.array()]
        For each claim, the embeddings of its context pieces
    context: List[List[str]]
        For each claim, the text pieces aligned with context_embedding
    top_k: int
        Maximum number of hits requested from the search per claim
    threshold: float or None
        Minimum score a hit must reach to be kept; None disables the filter

    Returns
    ----------
    retrieve_result: List[List[Dict]]
        For each claim, a list of dicts with keys 'evidence' (text piece),
        'score' (rounded to 4 decimals) and 'id' (position of the piece in
        context[i]). If no hit survives the filters, the list holds a single
        placeholder {'evidence': '', 'score': 1.0, 'id': -1}.
    """
    retrieve_result = list()
    # Iterate over the embeddings that were passed in, not over the
    # notebook-level `claim` list, so the function does not depend on globals.
    for i in tqdm(range(len(claim_embedding))):
        # A single query is searched, so only the first result list is relevant.
        hits = util.semantic_search(claim_embedding[i], context_embedding[i], top_k=top_k)[0]

        kept = list()
        for hit in hits:
            # `is None` keeps a threshold of 0.0 usable; a truthiness test on
            # the threshold would reject every hit in that case.
            if threshold is not None and hit['score'] < threshold:
                continue
            row_evidence = context[i][hit['corpus_id']]
            # Pieces of 10 characters or fewer are not kept as evidence.
            if len(row_evidence) > 10:
                kept.append({'evidence': row_evidence, 'score': round(hit['score'], 4), 'id': hit['corpus_id']})

        if not kept:
            # Keep the output aligned with the claims even when nothing matched.
            print("empty case")
            kept.append({'evidence': '', 'score': 1.0, 'id': -1})
        retrieve_result.append(kept)
    return retrieve_result

In [None]:
def add_comma(text):
    """
    Append a full stop to a text that does not already end with one.

    Despite the function name, the character added is '.', not ','. The check
    ignores surrounding whitespace, but the full stop is appended to the text
    exactly as given.

    Parameters
    ----------
    text: str

    Returns
    ----------
    str
        text with a trailing '.' added when needed. Empty or whitespace-only
        text, and any non-str value, is returned unchanged.
    """
    # Explicit guards replace a bare `except:` so that unrelated errors
    # (including KeyboardInterrupt) are no longer silently swallowed.
    if not isinstance(text, str):
        return text

    stripped = text.strip()
    if not stripped:
        # Nothing to terminate.
        return text

    if stripped[-1] != '.':
        text += '.'
    return text

In [None]:
# Embed every claim once; both retrieval stages below reuse these embeddings.
claim_embedding = embed_claim(
    claim=claim
)

# Stage 1: Get Top 5 Paragraphs
paragraph = split_context(
    context=context,
    split_type='PARAGRAPH'
)

context_embedding = embed_context(
    context=paragraph
)

paragraph_result = retrieve(
    claim_embedding=claim_embedding,
    context_embedding=context_embedding,
    context=paragraph,
    top_k=5,
    threshold=None,
)

# Stage 2: Get Top 5 Sentences
# The paragraphs retrieved for each claim are joined into one text (each one
# terminated with '.' by add_comma) and that text is split into sentences.
sentence = split_context(
    context=[" ".join([add_comma(x['evidence']) for x in y]) for y in paragraph_result],
    split_type='SENTENCE'
)

sentence_embedding = embed_context(
    context=sentence
)

sentence_result = retrieve(
    claim_embedding=claim_embedding,
    context_embedding=sentence_embedding,
    context=sentence,
    top_k=5,
    threshold=None,
)

# Append to original dict
# The input file is read again to start from a fresh copy of the records; it
# is opened in a `with` block so the handle is closed, and decoded as UTF-8.
with open(INPUT_DIR, encoding='utf-8') as infile:
    retrieve_result = json.load(infile)
for sample_order, sample_number in enumerate(index):
    retrieve_result[sample_number]['evidence'] = [x['evidence'] for x in sentence_result[sample_order]]

In [None]:
# Write the samples, now carrying their 'evidence' lists, to the result file as JSON.
with open(
    "../result/retrieve_result/private_test_retrieval_v1_top5_top_5.json",
    mode="w",
) as outfile:
    json.dump(obj=retrieve_result, fp=outfile)