In [1]:
import torch
import pandas as pd
import numpy as np
from itertools import combinations
from nltk.tokenize import sent_tokenize
from tqdm import tqdm

# Run the NLI model on the GPU when one is visible to torch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
device

'cuda'

## Load NLI Model

In [2]:
from transformers import AutoTokenizer, AutoModelForSequenceClassification
# NLI checkpoint used to score document/summary sentence pairs.
model_ckpt = "tals/albert-xlarge-vitaminc-mnli"
# use_fast=False selects the slow (pure-Python) tokenizer implementation.
tokenizer = AutoTokenizer.from_pretrained(model_ckpt, use_fast=False)
# Load the classification head first, then move the weights to the chosen device.
model = AutoModelForSequenceClassification.from_pretrained(model_ckpt)
model = model.to(device)

In [3]:
def split_sentences(text, min_chars=10):
    """Split ``text`` into sentences with NLTK and drop very short fragments.

    Args:
        text: Raw text to segment.
        min_chars: Sentences whose length is ``<= min_chars`` characters are
            discarded. Defaults to 10, the previously hard-coded threshold.

    Returns:
        List of sentence strings in their original order.
    """
    # NLTK's sent_tokenize relies on the Punkt sentence models being installed.
    return [sent for sent in sent_tokenize(text) if len(sent) > min_chars]

def is_consecutive_by_one(numbers):
    """Return True if ``numbers`` is a run with one constant step of +1 or -1.

    Empty and single-element sequences are trivially consecutive. Both
    ascending (``[2, 3, 4]``) and descending (``[4, 3, 2]``) runs qualify, but
    sequences that change direction (``[1, 2, 1]``) do not.

    Args:
        numbers: An indexable, sliceable sequence of numbers (list or tuple).

    Returns:
        bool
    """
    # Collect every adjacent difference; a genuine run has exactly one distinct step.
    steps = {b - a for a, b in zip(numbers, numbers[1:])}
    return len(steps) <= 1 and steps <= {1, -1}

In [4]:
# Example inputs for the scoring cell below:
#   doc_sample          - the source document (lower-cased, whitespace-tokenized text).
#   summary_sample      - a summary of that document.
#   atomic_facts_sample - appears to be the summary decomposed into short,
#                         single-claim sentences (TODO confirm provenance).
doc_sample = "looking after elderly parents can be difficult at the best of times . but lu xincai, this man takes caring for lu xincai's alzheimer ' s - suffering mother to another level . lu xincai, a security guard from china has touched hearts across the country because lu xincai takes lu xincai's 84 - year - old mother with lu xincai to work on the back of lu xincai's motorbike every single day , reported the people ' s daily online . lu xincai, lu xincai , who lives in zhejiang province in eastern china , says that lu xincai is scared lu xincai's mother will get lost if lu xincai leaves her at home by herself because she suffers from the degenerative disease . devoted : lu xincai, lu xincai takes lu xincai's 84 - year - old mother to work with lu xincai on the back of lu xincai's motorbike every day . lu xincai ties a sash around both of their waists to make sure she does n ' t fall off she would often go up to the mountains to collect firewood and there were a few occasions when she got lost after dark . when mr lu ' s father passed away earlier this year , lu xincai decided to take lu xincai's mother with lu xincai to work because there was no one else who could look after her . lu xincai's wife works in a different city and lu xincai's son is still in school . after helping lu xincai's mother to get up at 5 am every morning , lu xincai puts her on the back seat of lu xincai's motorbike and ties a sash around both of their waists to ensure that she does not fall off . mr lu said that lu xincai rides the four kilometres to work slowly to make sure lu xincai's mother feels safe and so that they can chat along the way . the whole journey takes an hour . even when at work lu xincai checks up on lu xincai's mother , who has been given her own room by lu xincai's employers , a bank , to make sure that she has not wandered off somewhere . 
lu xincai said that lu xincai's mother devoted her life to caring for her children , and now lu xincai feels like lu xincai has a duty to care for her in return . vulnerable : lu xincai's elderly mother suffers from alzheimer ' s and used to get lost when she was left alone lu xincai said : ` lu xincai was an apple in lu xincai's mum ' s eye , and now she ' s lu xincai's apple . ' ` our mother carried us on her back to the fields when she went to work on the farm and collect firewood when we were young . ' lu xincai added : ` only if lu xincai see her will lu xincai feel relaxed . otherwise lu xincai would be afraid is she had wandered away . '"
summary_sample = "lu xincai takes Lu Xincai's 84 - year - old mother to work with Lu Xincai on the back of Lu Xincai's motorbike every day . Lu Xincai's mother suffers from alzheimer ' s and used to get lost when she was left alone . Lu Xincai ties a sash around both of their waists to ensure that she does not fall off ."
atomic_facts_sample = "Lu Xincai has a 84 - year - old mother. Lu Xincai takes his mother to work with him. Lu Xincai's mother works with him on the back of his motorbike. Lu Xincai uses a motorbike to take his mother to work. Lu Xincai has a mother. Lu Xincai's mother suffers from Alzheimer's. Lu Xincai's mother used to get lost when she was left alone. Lu Xincai ties a sash. The sash is around both of their waists. The purpose of the sash is to ensure that she does not fall off."

In [5]:
# Sentence-level NLI scoring of a summary against its source document:
#   1. Score every summary sentence / atomic fact against each document sentence.
#   2. If the best-supporting sentence is not clearly entailing (entailment is not
#      the arg-max class there), re-score windows of consecutive document sentences
#      around it, presumably to catch facts whose support spans neighbouring sentences.
#   3. A fact's score is its best entailment probability; the summary's score is the
#      minimum over its facts (i.e. its weakest-supported fact).

gran = 3 + 1  # windows cover 1 .. gran - 1 (= 3) consecutive document sentences


def _nli_probs(premise, hypothesis):
    """Return ``(entail, contradict, neutral)`` probabilities for one premise/hypothesis pair.

    Assumes label index 0 = entailment, 1 = contradiction, 2 = neutral.
    NOTE(review): confirm this order against ``model.config.id2label``.
    """
    features = tokenizer([premise], [hypothesis], padding=True, truncation=True, return_tensors="pt").to(device)
    with torch.no_grad():
        probs = torch.nn.functional.softmax(model(**features).logits, dim=-1)[0]
    return probs[0].item(), probs[1].item(), probs[2].item()


def _windows_around(center, n_items, max_size):
    """Yield index lists of every consecutive run of length 2..max_size containing ``center``.

    Windows are clipped to ``range(n_items)``. Length-1 windows are omitted
    because the single sentence at ``center`` has already been scored.
    """
    for size in range(2, max_size + 1):
        first_start = max(0, center - size + 1)
        last_start = min(center, n_items - size)
        for start in range(first_start, last_start + 1):
            yield list(range(start, start + size))


model.eval()  # inference only; set once rather than before every forward pass

analyze_scores = []
doc_sentences = split_sentences(doc_sample)
# Alternative: score the summary's own sentences instead of its atomic facts.
# summary_sentences = split_sentences(summary_sample)
summary_sentences = split_sentences(atomic_facts_sample)
if not doc_sentences or not summary_sentences:
    raise ValueError("need at least one document sentence and one summary sentence to score")

max_scores = []
for raw_summary_sentence in summary_sentences:
    summary_sentence = raw_summary_sentence.strip()

    # (entail, contradict, neutral) for this fact against each single document sentence.
    sentence_probs = [_nli_probs(doc_sentence.strip(), summary_sentence) for doc_sentence in doc_sentences]
    entail_scores = [probs[0] for probs in sentence_probs]
    max_entail_score = max(entail_scores)
    max_entail_idx = entail_scores.index(max_entail_score)  # first index wins on ties
    entail, contra, neutral = sentence_probs[max_entail_idx]

    # e > c and e > n: the best sentence clearly entails the fact on its own.
    if entail > contra and entail > neutral:
        max_scores.append(max_entail_score)
        continue

    # Otherwise merge neighbouring sentences around the best one and keep the best score.
    window_scores = [
        _nli_probs(" ".join(doc_sentences[i].strip() for i in window), summary_sentence)[0]
        for window in _windows_around(max_entail_idx, len(doc_sentences), gran - 1)
    ]
    max_scores.append(max([max_entail_score] + window_scores))

analyze_scores.append(min(max_scores))
print(analyze_scores)

[0.9676105976104736]
