# Evaluating BART on Training and Testing Data.

### This notebook contains a class with functions designed to evaluate BART's accuracy on my specific Reddit label set.

In [2]:
import matplotlib.pyplot as plt
import seaborn as sns
from wordcloud import WordCloud
import pandas as pd
import numpy
import torch
from transformers import BartForSequenceClassification, BartTokenizer, Trainer, TrainingArguments
from datasets import load_dataset

In [8]:
class eval_BART():
    """Evaluate a BART MNLI model's entailment confidence on a premise/hypothesis label set.

    The class loads a tokenizer and a sequence-classification model once, scores
    every (premise, hypothesis) row of a pandas DataFrame, and groups the results
    by confidence interval. Helper methods turn those groupings into counts,
    ratios, bar charts and word clouds.
    """

    # Order of the model's output classes, as assumed throughout this class.
    LABELS = ("contradiction", "neutral", "entailment")
    ENTAILMENT_INDEX = LABELS.index("entailment")

    # (interval label, inclusive lower bound in percent), ordered from highest to lowest.
    # The lower bound of the last interval is 0 so that every probability lands in a bin.
    CONFIDENCE_INTERVALS = (
        ("97-100", 97),
        ("95-97", 95),
        ("90-95", 90),
        ("80-90", 80),
        ("60-80", 60),
        ("40-60", 40),
        ("0-40", 0),
    )

    # Hypothesis phrasings that are pre-seeded (with zero values) in every result dictionary.
    HYPOTHESES = (
        "This text is about politics.",
        "This text expresses gratitude.",
        "This text expresses frustration.",
        "This text is focused on solutions.",
        "This text provides information or news.",
        "This text expresses fear or panic.",
        "This text contains blaming.",
        "This text is seeking help or advice.",
        "This text is about wildfires.",
        "This text is about prescribed burns.",
        "This text is about fire management.",
    )

    def __init__(self, model_name='facebook/bart-large-mnli'):
        """Load the tokenizer and the classification model.

        Args:
            model_name: Hugging Face model identifier or local path. Defaults to
                the MNLI checkpoint the notebook was written for.
        """
        self.tokenizer = BartTokenizer.from_pretrained(model_name)
        # NOTE(review): `multi_label` looks like an option of the zero-shot pipeline
        # rather than of the model itself - confirm that it has any effect here.
        self.model = BartForSequenceClassification.from_pretrained(model_name, num_labels=3, multi_label=True)
        # This class only runs inference, so make sure dropout is disabled.
        self.model.eval()

    # ------------------------------------------------------------------
    # Private helpers shared by the two binning methods.
    # ------------------------------------------------------------------

    def _entailment_score(self, premise, hypothesis):
        """Return the softmax probability of the entailment class for one pair."""
        inputs = self.tokenizer(
            premise,
            hypothesis,
            return_tensors="pt",
            truncation=True,
            padding=True
        )
        # No gradients are needed for evaluation; skipping them avoids building
        # an autograd graph for every row.
        with torch.no_grad():
            logits = self.model(**inputs).logits
        scores = torch.softmax(logits, dim=-1).squeeze().tolist()
        return scores[self.ENTAILMENT_INDEX]

    @classmethod
    def _confidence_interval(cls, score):
        """Map a probability in [0, 1] to the label of its confidence interval.

        Lower bounds are inclusive, so a score of exactly 0.97 belongs to
        "97-100". Anything that matches no bound falls back to the lowest interval.
        """
        for interval, lower_pct in cls.CONFIDENCE_INTERVALS:
            if score >= lower_pct / 100:
                return interval
        return cls.CONFIDENCE_INTERVALS[-1][0]

    # ------------------------------------------------------------------
    # Scoring / binning
    # ------------------------------------------------------------------

    def bin_conf_scores_via_premise(self, data):
        """Score every row of `data` and group the premises by confidence interval.

        Args:
            data: pandas DataFrame with "premise" and "hypothesis" columns.

        Returns:
            dict mapping each confidence interval label to the list of premises
            whose entailment probability falls in that interval.
        """
        premises_by_interval = {interval: [] for interval, _ in self.CONFIDENCE_INTERVALS}

        for _, row in data.iterrows():
            premise = row["premise"]
            score = self._entailment_score(premise, row["hypothesis"])
            premises_by_interval[self._confidence_interval(score)].append(premise)

        return premises_by_interval

    def value_to_num(self, score_dict):
        """Convert the output of `bin_conf_scores_via_premise` into counts.

        The input dictionary is left untouched so that the premise lists remain
        available afterwards (for example for `iterate_word_cloud`).

        Returns:
            A new dict mapping each confidence interval label to the number of
            premises stored under it.
        """
        return {interval: len(premises) for interval, premises in score_dict.items()}

    def bin_conf_scores_via_premise_count(self, data):
        """Score every row of `data` and count hypotheses per confidence interval.

        Args:
            data: pandas DataFrame with "premise" and "hypothesis" columns.

        Returns:
            Nested dict: confidence interval label -> hypothesis phrase -> count.
            Every phrase in `HYPOTHESES` is present in every interval; a phrase
            that is not in that list is added to an interval the first time it
            is seen there instead of raising a KeyError.
        """
        counts_by_interval = {
            interval: dict.fromkeys(self.HYPOTHESES, 0)
            for interval, _ in self.CONFIDENCE_INTERVALS
        }

        for _, row in data.iterrows():
            hypothesis = row["hypothesis"]
            score = self._entailment_score(row["premise"], hypothesis)
            hypothesis_counts = counts_by_interval[self._confidence_interval(score)]
            hypothesis_counts[hypothesis] = hypothesis_counts.get(hypothesis, 0) + 1

        return counts_by_interval

    # ------------------------------------------------------------------
    # Aggregation
    # ------------------------------------------------------------------

    def eval_perc(self, data, thresh):
        """Summarise how many hypotheses reach a confidence threshold.

        Args:
            data: nested dict from `bin_conf_scores_via_premise_count`.
            thresh: threshold in percent (e.g. 90). An interval counts as passing
                when its lower bound is greater than or equal to `thresh`.

        Returns:
            dict mapping each hypothesis phrase to
            [count in passing intervals, total count].

        Raises:
            KeyError: if `data` contains an interval label that is not listed in
                `CONFIDENCE_INTERVALS`.
        """
        lower_bound_pct = dict(self.CONFIDENCE_INTERVALS)
        ratio_by_hypothesis = {hypothesis: [0, 0] for hypothesis in self.HYPOTHESES}

        for interval, hypothesis_counts in data.items():
            passes = lower_bound_pct[interval] >= thresh
            for hypothesis, count in hypothesis_counts.items():
                # Hypotheses outside the default list get their own entry on demand.
                passed_and_total = ratio_by_hypothesis.setdefault(hypothesis, [0, 0])
                if passes:
                    passed_and_total[0] += count
                passed_and_total[1] += count

        return ratio_by_hypothesis

    def turn_into_perc(self, data_dict):
        """Convert the [passed, total] lists from `eval_perc` into fractions.

        Returns:
            dict mapping each hypothesis phrase to passed / total, a fraction
            between 0 and 1 (it is not multiplied by 100). The value is 0 when
            either number is 0, which also avoids a division by zero.
        """
        fraction_by_hypothesis = dict.fromkeys(self.HYPOTHESES, 0)

        for hypothesis, passed_and_total in data_dict.items():
            passed, total = passed_and_total[0], passed_and_total[1]
            if passed == 0 or total == 0:
                fraction_by_hypothesis[hypothesis] = 0
            else:
                fraction_by_hypothesis[hypothesis] = passed / total

        return fraction_by_hypothesis

    def change_hyp_phrase(self, data, old_hyp, new_hyp):
        """Return a copy of `data` in which one hypothesis phrasing is replaced.

        Args:
            data: pandas DataFrame with a "hypothesis" column.
            old_hyp: phrasing to look for (exact match).
            new_hyp: phrasing to write in its place.

        Returns:
            A new DataFrame; the DataFrame passed in is not modified.
        """
        updated = data.copy()
        updated.loc[updated["hypothesis"] == old_hyp, "hypothesis"] = new_hyp
        return updated

    # ------------------------------------------------------------------
    # Plotting / reporting
    # ------------------------------------------------------------------

    def graph_score_distribution(self, len_score_dict, data_name=""):
        """Draw a bar chart with the number of texts in each confidence interval.

        Args:
            len_score_dict: dict from `value_to_num` (interval label -> count).
            data_name: name of the data set, appended to the plot title.

        Returns nothing; the figure is shown immediately.
        """
        plt.figure()
        # One bar() call per interval so every bar picks up its own colour.
        for interval, count in len_score_dict.items():
            plt.bar(interval, count)

        plt.title("confidence scores from current bart model - " + data_name)
        plt.xlabel("confidence level")
        plt.ylabel("text count")
        plt.show()

    def create_word_clouds(self, score_text_list, data_name=""):
        """Build, save and show a word cloud for a list of texts.

        Args:
            score_text_list: list of strings; they are joined with spaces.
            data_name: used in the title and as the file name (`<data_name>.png`).

        Returns nothing; the image is written to disk and shown.
        """
        fig_name = data_name + ".png"
        full_string_words = ' '.join(score_text_list)

        wordcloud = WordCloud(width=350,
                              height=150,
                              background_color='white',
                              colormap="coolwarm").generate(full_string_words)

        # Start from a fresh figure so successive clouds never draw over each other.
        plt.figure()
        plt.imshow(wordcloud, interpolation='bilinear')
        plt.title(data_name + " - Word Cloud")
        plt.axis('off')
        plt.savefig(fig_name, dpi=300, bbox_inches="tight")
        plt.show()

    def iterate_word_cloud(self, score_dict):
        """Create one word cloud per non-empty confidence interval.

        Args:
            score_dict: dict from `bin_conf_scores_via_premise`
                (interval label -> list of premises).

        Returns nothing.
        """
        for interval, premises in score_dict.items():
            if len(premises) > 0:
                self.create_word_clouds(premises, data_name=interval)

    def retrieve_text_via_conf(self, data, conf_score, rand=True, type="entailment"):
        """Print every text/hypothesis pair stored under one confidence interval.

        Args:
            data: dict in which `data[conf_score]` holds two parallel sequences
                under the keys "text" and "hyp". None of the other methods in
                this class produces that layout, so the caller has to build it.
            conf_score: key of the interval to print.
            rand: currently unused; kept so existing calls keep working.
            type: label printed in the header line.

        Returns nothing.
        """
        separator = "-----------------------------------------------------------------"
        print(type, " - text for scores with confidence score = ", conf_score)
        print(separator)

        texts = data[conf_score]["text"]
        for position in range(len(texts)):
            print("text: ", texts[position])
            print("hypo: ", data[conf_score]["hyp"][position])
            print()

        # Closing rule, only when at least one pair was printed.
        if len(texts) > 0:
            print(separator)

In [10]:
# Build the evaluator once; the constructor loads the tokenizer and the
# 'facebook/bart-large-mnli' model via from_pretrained.
BART_eval_class = eval_BART()

