In [1]:
import os
import sys
from tqdm.notebook import tqdm
import numpy as np
from collections import Counter, defaultdict

from nltk.tokenize import sent_tokenize

In [2]:
import roberta_ses
from roberta_ses.interface import Roberta_SES_Entailment

# from sentence_transformers import SentenceTransformer

# import faiss

## Entailment between sentences

In [3]:
### 0 = contradict, 1 = neutral, 2 = entailment 

In [5]:
# ses = Roberta_SES_Entailment(
#     roberta_path='/Users/ytshao/Desktop/Yutong/models/roberta-large-mnli/',
#     ckpt_path='/Users/ytshao/Desktop/Yutong/external_repos/Roberta_SES/checkpoints/roberta-large/epoch=2-valid_loss=-0.2620-valid_acc_end=0.9223.ckpt',
#     max_length=512,
#     device_name='cpu')

ses = Roberta_SES_Entailment(
    roberta_path='/Users/ytshao/Desktop/Yutong/models/roberta-large/',
    ckpt_path='/Users/ytshao/Desktop/Yutong/external_repos/Roberta_SES/checkpoints/roberta-large/epoch=2-valid_loss=-0.2620-valid_acc_end=0.9223.ckpt',
    output_classes=3,
    max_length=512,
    device_name='cpu')

# ses = Roberta_SES_Entailment(
#     roberta_path='/Users/ytshao/Desktop/Yutong/models/roberta-base',
#     ckpt_path='/Users/ytshao/Desktop/Yutong/external_repos/Roberta_SES/checkpoints/roberta-base/epoch=4-valid_loss=-0.6472-valid_acc_end=0.9173.ckpt',
#     max_length=512,
#     device_name='cpu')

In [6]:
ses.model.bert_config

RobertaConfig {
  "architectures": [
    "RobertaForMaskedLM"
  ],
  "attention_probs_dropout_prob": 0.1,
  "bos_token_id": 0,
  "eos_token_id": 2,
  "gradient_checkpointing": false,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 1024,
  "initializer_range": 0.02,
  "intermediate_size": 4096,
  "layer_norm_eps": 1e-05,
  "max_position_embeddings": 514,
  "model_type": "roberta",
  "num_attention_heads": 16,
  "num_hidden_layers": 24,
  "pad_token_id": 1,
  "type_vocab_size": 1,
  "vocab_size": 50265
}

### Manual examples

In [7]:
pred = ses.predict(
    "I like sports",  # premise 
    "I like soccer",  # hypothesis 
)

pred

(tensor(1), tensor([1.3612e-03, 9.9769e-01, 9.4710e-04]))

In [8]:
pred = ses.predict(
    "I like soccer",  # premise 
    "I like sports",  # hypothesis 
)

pred

(tensor(2), tensor([0.0016, 0.3138, 0.6846]))

In [None]:
pred = ses.predict(
    "I like apple",  # premise 
    "I like apple and banana",  # hypothesis 
)

pred

In [None]:
pred = ses.predict(
    "I like apple and banana",  # premise 
    "I like apple",  # hypothesis 
)

pred

### Load AMASum

In [56]:
dataset_dir = "/Users/ytshao/Desktop/Yutong/dataset/AMASum/form_min_10_max_100_revs"
split = "valid"
dataset_path = os.path.join(dataset_dir, f"{split}.source")
target_path = os.path.join(dataset_dir, f"{split}.target")

In [59]:
# List[<product>:List[<doc>:List[<sent>:str]]]
products = []

with open(dataset_path, 'r') as f:
    for l in tqdm(f):
        product_docs = l.strip().split('</s>')
        product_docs_sents = []
        for doc in product_docs:
            sents = sent_tokenize(doc)
            sents = [sent for sent in sents if len(sent) > 3]
            product_docs_sents.append(sents)
        products.append(product_docs_sents)


0it [00:00, ?it/s]

In [84]:
# List[<product>:List[<summs>:List[verd, pros, cons]]]
product_summs = []

with open(target_path, 'r') as f:
    for l in tqdm(f):
        summs = l.strip().split('</s>')
        product_summs.append(summs)
        

0it [00:00, ?it/s]

In [85]:
n_products = len(products)
n_docs = sum([len(product) for product in products])
n_sents = sum([sum([len(doc) for doc in product]) for product in products])
n_products, n_docs, n_sents

(3302, 256764, 1086461)

In [86]:
len(product_summs)

3302

In [79]:
short_sent_counter = Counter()

for product in products:
    for doc in product:
        for sent in doc:
            if len(sent) <= 10:
                short_sent_counter[len(sent)] += 1

sorted(short_sent_counter.most_common())

[(4, 1243), (5, 1564), (6, 1431), (7, 1847), (8, 2628), (9, 3063), (10, 4958)]

### Check: review sentences vs. gold summ
Too slow to predict for all

In [None]:
entail_counter = Counter()

pbar = tqdm(total=n_sents)
for product, summ in zip(products, product_summs):
    for doc in product:
        for sent in doc:
            pred, probs = ses.predict(sent, summ)
            pred = pred.item()
            if pred == 1:
                res_str = 'neutral'
            elif pred == 0:
                res_str = f'contra-{probs[0].item():.2f}'
            elif pred == 2:
                res_str = f'entail-{probs[2].item():.2f}'
            entail_counter[res_str] += 1
            
            pbar.update(1)
            
pbar.close()

In [71]:
sorted(entail_counter.most_common())

[('contra-0.37', 1),
 ('contra-0.38', 4),
 ('contra-0.40', 2),
 ('contra-0.41', 5),
 ('contra-0.42', 9),
 ('contra-0.43', 5),
 ('contra-0.44', 7),
 ('contra-0.45', 13),
 ('contra-0.46', 21),
 ('contra-0.47', 29),
 ('contra-0.48', 48),
 ('contra-0.49', 41),
 ('contra-0.50', 60),
 ('contra-0.51', 72),
 ('contra-0.52', 76),
 ('contra-0.53', 56),
 ('contra-0.54', 72),
 ('contra-0.55', 56),
 ('contra-0.56', 47),
 ('contra-0.57', 58),
 ('contra-0.58', 58),
 ('contra-0.59', 48),
 ('contra-0.60', 72),
 ('contra-0.61', 49),
 ('contra-0.62', 48),
 ('contra-0.63', 43),
 ('contra-0.64', 34),
 ('contra-0.65', 51),
 ('contra-0.66', 50),
 ('contra-0.67', 33),
 ('contra-0.68', 35),
 ('contra-0.69', 33),
 ('contra-0.70', 30),
 ('contra-0.71', 36),
 ('contra-0.72', 38),
 ('contra-0.73', 35),
 ('contra-0.74', 30),
 ('contra-0.75', 40),
 ('contra-0.76', 27),
 ('contra-0.77', 28),
 ('contra-0.78', 26),
 ('contra-0.79', 34),
 ('contra-0.80', 33),
 ('contra-0.81', 38),
 ('contra-0.82', 25),
 ('contra-0.83', 

#### Entail / contra pairs

In [87]:
entail_pairs = []
contra_pairs = []

pbar = tqdm(total=n_products)
for product, summs in zip(products, product_summs):
    summ = summs[0]  ## verd 
    for doc in product[:1]:
        for sent in doc[:1]:
            pred, probs = ses.predict(sent, summ)
            pred = pred.item()
            if pred == 0:
                contra_pairs.append((sent, summ, f'contra-{probs[0].item():.2f}'))
            elif pred == 2:
                entail_pairs.append((sent, summ, f'entail-{probs[2].item():.2f}'))
            
            pbar.update(1)
            
pbar.close()

  0%|          | 0/3302 [00:00<?, ?it/s]

In [88]:
entail_pairs

[('i recommend this, it\'s perfect for the our box garden, "eases bending over back pain" easy to scoot left and right while sitting and working harvest,',
  "A unique cart won't work for everyone, but provides both a portable garden tool compartment and handy seat for those who can use it. ",
  'entail-0.75'),
 ('I love this case for my new Mac book 12 inch, it fits perfectly, and protects my laptop.',
  'Highly protective and well padded, this case will keep your laptop safe in a fall. ',
  'entail-0.53'),
 ('Follow-up: This is a very nice piece except for the hardware.',
  'Stands out for its solid construction. This is a great option once you get past the “heavy fumes” period. ',
  'entail-0.56'),
 ('My son loves these PowerBar Protein Bars.',
  'For those looking to augment lean muscle growth, this is a tasty protein bomb. ',
  'entail-0.61'),
 ('I definitely recommend these.',
  'Worth considering for pet owners who need a digestive supplement with appealing flavor. If side effec

In [93]:
contra_pairs

[("Buyer beware: while the ad doesn't say it this is a costco warehouse version of the tommy Bahamas chair, so cheaper quality.",
  "If you are looking for the ultimate beach chair, this one has numerous features its competitors can't top — nearly everything you need in a chair specifically designed for the beach. ",
  'contra-0.68'),
 ('I purchased these to replace what should have been identical Riedel glasses that we broke.',
  'An affordable version of one of the most popular bowl shapes from the pricier Vinum range. ',
  'contra-0.56'),
 ('The programme removed viruses from my computer exactly as promised.',
  "If a malware infestation keeps you from booting Windows or running an antivirus scan, the bootable FixMeStick can save you. But don't ditch your existing antivirus; FixMeStick offers no real-time protection. ",
  'contra-0.96'),
 ('I purchased these less than two months ago and the black exterior coating started to peel off all of them last week.',
  'Stainless steel constr

In [94]:
len(entail_pairs), len(contra_pairs)

(77, 625)

## Clustering

In [98]:
sbert_model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')

In [102]:
embeddings = sbert_model.encode(["This is a sentence"])

In [109]:
embeddings.shape

(1, 384)

In [107]:
print(sbert_model.encode.__doc__)


        Computes sentence embeddings

        :param sentences: the sentences to embed
        :param batch_size: the batch size used for the computation
        :param show_progress_bar: Output a progress bar when encode sentences
        :param output_value:  Default sentence_embedding, to get sentence embeddings. Can be set to token_embeddings to get wordpiece token embeddings. Set to None, to get all output values
        :param convert_to_numpy: If true, the output is a list of numpy vectors. Else, it is a list of pytorch tensors.
        :param convert_to_tensor: If true, you get one large tensor as return. Overwrites any setting from convert_to_numpy
        :param device: Which torch.device to use for the computation
        :param normalize_embeddings: If set to true, returned vectors will have length 1. In that case, the faster dot-product (util.dot_score) instead of cosine similarity can be used.

        :return:
           By default, a list of tensors is returned. If con