In [82]:
import matplotlib.pyplot as plt
import seaborn as sns
from wordcloud import WordCloud
import pandas as pd
import numpy
import torch 
from transformers import BartForSequenceClassification, BartTokenizer, Trainer, TrainingArguments
from datasets import load_dataset
import torch
import random

In [83]:
def retrieve_text_via_conf(data, conf_score, rand=True, type="entailment"):
    """Print every premise/hypothesis pair stored under one confidence bucket.

    Args:
        data: mapping of bucket key -> {"text": [...], "hyp": [...], ...}.
        conf_score: key of the bucket to print, e.g. "80-90".
        rand: accepted but not used by this function.
        type: label name; it only appears in the header line.

    Returns:
        None. Everything is written to stdout.
    """
    divider = "-----------------------------------------------------------------"

    print(type, " - text for scores with confidence score = ", conf_score)
    print(divider)

    # The bucket is looked up only after the header has been printed.
    premises = data[conf_score]["text"]
    total = len(premises)

    for position in range(total):
        print("text: ", premises[position])
        print("hypo: ", data[conf_score]["hyp"][position])
        print()

    # A closing divider follows the last pair; an empty bucket prints none.
    if total > 0:
        print(divider)

    return None

# Confidence buckets as (key, inclusive lower bound), highest bucket first.
# Each bucket's upper bound is the previous entry's lower bound, so the ranges
# cover the whole axis with no gaps; the lowest bucket is open-ended downwards.
_CONF_BUCKETS = (
    ("97-100", 0.97),
    ("95-97", 0.95),
    ("90-95", 0.90),
    ("80-90", 0.80),
    ("60-80", 0.60),
    ("40-60", 0.40),
    ("0-40", float("-inf")),
)


def _bucket_for_score(score):
    """Return the key of the bucket containing ``score`` (None when score is NaN)."""
    for key, lower_bound in _CONF_BUCKETS:
        if score >= lower_bound:
            return key
    return None


def return_high_conf(data, label_index=2):
    """Score every row of ``data`` with the global model and group rows by confidence.

    Each row's premise/hypothesis pair is tokenized with the module-level
    ``tokenizer`` and passed through the module-level ``model``. The softmax
    probability at ``label_index`` decides which bucket the row lands in.

    Buckets are half-open, ``lower <= score < upper``, except "97-100", which has
    no upper bound. A score of exactly 0.97 therefore lands in "97-100";
    previously it matched no branch and the row was silently dropped.

    Args:
        data: DataFrame-like object whose ``iterrows()`` yields rows that have
            "premise", "hypothesis" and "label" entries.
        label_index: position of the class whose probability is bucketed.
            The default 2 was written against the label order
            (contradiction, neutral, entailment).
            NOTE(review): assumed to match the loaded model's id2label; verify.

    Returns:
        dict mapping each bucket key ("97-100", "95-97", "90-95", "80-90",
        "60-80", "40-60", "0-40") to {"text": [...], "label": [...], "hyp": [...]},
        three parallel lists with one entry per row in that bucket.
    """
    high_conf_dict = {
        key: {"text": [], "label": [], "hyp": []} for key, _ in _CONF_BUCKETS
    }

    for _, row in data.iterrows():
        premise = row["premise"]
        hypothesis = row["hypothesis"]
        label = row["label"]

        inputs = tokenizer(
            premise,
            hypothesis,
            return_tensors="pt",
            truncation=True,
            padding=True,
        )

        # Inference only: do not build an autograd graph for every forward pass.
        with torch.no_grad():
            logits = model(**inputs).logits
        scores = torch.softmax(logits, dim=-1).squeeze().tolist()

        bucket_key = _bucket_for_score(scores[label_index])
        if bucket_key is None:
            # NaN compares false against every bound; such rows stay unbucketed.
            continue

        bucket = high_conf_dict[bucket_key]
        bucket["text"].append(premise)
        bucket["label"].append(label)
        bucket["hyp"].append(hypothesis)

    return high_conf_dict

def value_to_num(score_dict):
    """Replace each bucket in ``score_dict`` with the number of texts it holds.

    A bucket shaped like ``{"text": [...], "label": [...], "hyp": [...]}`` is
    counted by the length of its "text" list. Taking ``len()`` of such a bucket
    directly would count its three sub-keys and report 3 for every bucket.
    Any other sized value (for example a plain list) is counted with ``len()``.
    Values that are already integers are left alone, so calling this twice on
    the same dict is harmless.

    Args:
        score_dict: mapping of bucket key -> bucket contents.

    Returns:
        The same dict object, mutated in place, now mapping bucket key -> count.
        Pass a copy if the original texts are still needed afterwards.
    """
    for key, value in score_dict.items():
        if isinstance(value, int):
            # Already converted by an earlier call.
            continue
        if isinstance(value, dict) and "text" in value:
            value = value["text"]
        score_dict[key] = len(value)

    return score_dict

def graph_score_distribution(len_score_dict, data_name=""):
    """Draw a bar chart of how many texts fall into each confidence bucket.

    Args:
        len_score_dict: mapping of bucket key -> count; bars are drawn in the
            mapping's iteration order.
        data_name: text appended to the chart title.

    Returns:
        None. The figure is shown with ``plt.show()``.
    """
    bucket_names = list(len_score_dict.keys())
    bucket_counts = list(len_score_dict.values())

    plt.figure()
    # One bar() call per bucket, not a single vectorised call.
    for bucket_name, bucket_count in zip(bucket_names, bucket_counts):
        plt.bar(bucket_name, bucket_count)

    plt.title("confidence scores from current bart model - " + data_name)
    plt.xlabel("confidence level")
    plt.ylabel("text count")
    plt.show()

    return None
    
def get_unique_labeled_text(df, label="entailment"):
    """Return the rows of ``df`` whose "label" column equals ``label``.

    Rows are selected with a boolean mask, so this works for any index (not
    only a default 0..n-1 index) and does not grow the result row by row.
    Every column of a matching row keeps its value. Despite the function's
    name, duplicate rows are not removed.

    Args:
        df: DataFrame with a "label" column.
        label: label value to keep; defaults to "entailment".

    Returns:
        A new DataFrame with the same columns as ``df``, containing only the
        matching rows in their original order, re-indexed from 0. ``df`` itself
        is not modified.
    """
    matches = df["label"] == label
    return df.loc[matches].reset_index(drop=True)

In [89]:
# Checkpoint identifier; the tokenizer and the model below are both loaded from it.
model_name = 'facebook/bart-large-mnli'
# `tokenizer` and `model` are module-level globals that return_high_conf() reads directly.
tokenizer = BartTokenizer.from_pretrained(model_name)
# NOTE(review): `multi_label` looks like an argument of the zero-shot-classification
# pipeline rather than of from_pretrained(); confirm it has any effect here.
model = BartForSequenceClassification.from_pretrained(model_name, num_labels=3, multi_label=True)



In [90]:
# Load the test pairs; return_high_conf() reads the "premise", "hypothesis" and
# "label" entries of each row.
test_data_entail_sixEight = pd.read_csv("entailment_test_data_sixEight.csv")
# One model forward pass per row; the result groups the rows by confidence bucket.
test_1_entail = return_high_conf(test_data_entail_sixEight)

In [95]:
# Print every premise/hypothesis pair stored under the "80-90" bucket.
retrieve_text_via_conf(test_1_entail, "80-90")

entailment  - text for scores with confidence score =  80-90
-----------------------------------------------------------------
text:  Hey Sam. I’m also a wildland firefighter with the FS and a union rep (also a Sam). You have very broad rights and protections for the things you do or say in your personal time, and we need more folks talking publicly about what’s happening at our public lands agencies. Union member or not, I’d love to help. DM me and let’s get in touch.
hypo:  This text provides information or news.

text:  If you’re feeling brave and upset and want to share your story, try High Country News, they’ve had a lot of good coverage of the Forest Service mess under the current administration. Other potential media to contact include HuffPost and politico.
hypo:  This text provides information or news.

text:  Hey Sam, I’m a national parks reporter who covers all types of public land management. Would love to talk more, feel free to DM me
hypo:  This text provides information 