In [1]:
import os

# comet_ml stays ahead of the ML libraries, as in the original cell order.
import comet_ml
from comet_ml import Experiment
import matplotlib.pyplot as plt
import pandas as pd
from datasets import load_dataset
from transformers import pipeline
from sklearn.metrics import classification_report

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# User-editable locations: adjust these to match your own directory layout.

# Fine-tuned classifier checkpoint that will be loaded for inference.
checkpoints_out_dir = "../checkpoints/clinc_plus/checkpoint-30500"

# Destination CSVs: rows the classifier agreed with / disagreed with.
correct_out_path = "../predictions/clinc_plus_augmented_data_correct.csv"
incorrect_out_path = "../predictions/clinc_plus_augmented_data_incorrect.csv"

# CSV holding the generated (augmented) sentences to be checked.
augmented_data_path = "../prompts/generated_text/ChatGPT.csv"

In [3]:
# Create the Comet experiment that receives this run's parameters and metrics.
#
# The API key is intentionally NOT stored in the notebook. Export COMET_API_KEY
# before starting the kernel; when the variable is unset, None is passed and
# comet_ml falls back to its own configuration lookup (e.g. a .comet.config file).
# SECURITY: any key that was ever committed with this notebook should be treated
# as leaked and rotated in the Comet account settings.
experiment = Experiment(
    api_key=os.environ.get("COMET_API_KEY"),
    project_name="CLINC",
    workspace="gdhanania",
)

# Hyperparameters used at inference time
hyper_params = {"test_batch_size": 16}

# Device handed to the transformers pipeline
device = 'cuda:0'

# Task name handed to the transformers pipeline
pipeline_task = 'text-classification'

# Record the hyperparameters on the experiment
experiment.log_parameters(hyper_params)

COMET INFO: Experiment is live on comet.com https://www.comet.com/gdhanania/clinc/41d987f4b4014ee7bcbb17f9dc295bf3



In [6]:
def verify(augmented_data_path, checkpoints_out_dir, correct_out_path, incorrect_out_path):
    """Classify the augmented sentences and split them by agreement with their label.

    The CSV at ``augmented_data_path`` must provide ``Sentence`` and ``Label``
    columns. Every sentence is run through the text-classification pipeline
    loaded from ``checkpoints_out_dir``; macro/weighted F1 and accuracy are
    printed and logged to Comet inside the ``experiment.test()`` context, and
    the rows are written to two CSV files depending on whether the prediction
    matches the provided label.

    Besides its arguments, this function reads the notebook-level globals
    ``pipeline_task``, ``device``, ``hyper_params`` and ``experiment``, so the
    setup cell has to be executed first.

    Args:
        augmented_data_path: Path of the CSV file with the augmented data.
        checkpoints_out_dir: Model checkpoint directory given to ``pipeline``.
        correct_out_path: CSV destination for rows where prediction == label.
        incorrect_out_path: CSV destination for rows where prediction != label.

    Returns:
        Tuple ``(result_df_correct, result_df_incorrect)`` of DataFrames with
        the columns ``index``, ``text``, ``label`` and ``predicted``.
    """
    # Create the output folders up front so a missing directory surfaces
    # before the inference pass rather than at the final to_csv call.
    for out_path in (correct_out_path, incorrect_out_path):
        out_dir = os.path.dirname(out_path)
        if out_dir:  # an empty dirname means "current directory"
            os.makedirs(out_dir, exist_ok=True)

    dataset = load_dataset("csv", data_files=augmented_data_path)
    dataset = dataset.rename_column("Sentence", "text")
    dataset = dataset.rename_column("Label", "label")
    dataset = dataset['train']

    # Read each column once and reuse it for inference, scoring and the table.
    texts = list(dataset['text'])
    true_labels = list(dataset['label'])

    classifier = pipeline(pipeline_task, model=checkpoints_out_dir, device=device)

    # Make predictions on the dataset
    predictions = classifier(texts, batch_size=hyper_params['test_batch_size'])

    # Convert the predictions to a list of labels
    predicted_labels = [p['label'] for p in predictions]

    with experiment.test():
        # zero_division=0 yields the same numbers as the default ("warn" also
        # scores undefined metrics as 0) but without one warning per average.
        # NOTE(review): with no `labels` argument the macro average covers
        # every class that occurs in either the true or the predicted labels,
        # so predicted classes that are absent from the data pull it down.
        report = classification_report(
            true_labels, predicted_labels, output_dict=True, zero_division=0
        )

        # report has three root variables 1. accuracy 2. macro avg 3. weighted avg
        macro_avg_f1_score = report['macro avg']['f1-score']
        weighted_avg_f1_score = report['weighted avg']['f1-score']
        accuracy = report['accuracy']

        print('Macro Average F1 score: {:.2f}'.format(macro_avg_f1_score))
        print('Weighted Average F1 score: {:.2f}'.format(weighted_avg_f1_score))
        print('Accuracy: {:.2f}%'.format(accuracy * 100))

        # Logging metrics
        experiment.log_metrics({
            "Macro Average F1 score": macro_avg_f1_score,
            "Weighted Average F1 score": weighted_avg_f1_score,
            "Accuracy": accuracy * 100,
        })

    # reset_index() keeps each row's original position as an explicit `index`
    # column, which is also written to the CSV files below.
    result_df = pd.DataFrame({
        'text': texts,
        'label': true_labels,
        'predicted': predicted_labels,
    }).reset_index()

    display(result_df)

    # Compute the agreement mask once and use it for both partitions.
    is_correct = result_df['label'] == result_df['predicted']
    result_df_correct = result_df[is_correct]
    result_df_incorrect = result_df[~is_correct]

    display(result_df_correct)
    display(result_df_incorrect)

    result_df_correct.to_csv(correct_out_path, encoding='utf-8', index=False)
    result_df_incorrect.to_csv(incorrect_out_path, encoding='utf-8', index=False)

    return (result_df_correct, result_df_incorrect)

In [7]:
# Run the verification with the paths configured above and keep both partitions.
result_df_correct, result_df_incorrect = verify(
    augmented_data_path=augmented_data_path,
    checkpoints_out_dir=checkpoints_out_dir,
    correct_out_path=correct_out_path,
    incorrect_out_path=incorrect_out_path,
)

Found cached dataset csv (/work/pi_adrozdov_umass_edu/gdhanania_umass_edu/hf_cache/datasets/csv/default-bff4dae9b497e255/0.0.0/6b34fb8fcf56f7c8ba51dc895bfa2bfbe43546f190a60fcf74bb5e8afdcc2317)
100%|██████████| 1/1 [00:00<00:00, 251.52it/s]


Macro Average F1 score: 0.05
Weighted Average F1 score: 0.43
Accuracy: 34.67%


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Unnamed: 0,index,text,label,predicted
0,0,How long does a replacement card last?,replacement_card_duration,replacement_card_duration
1,1,What is the duration of a replacement card?,replacement_card_duration,replacement_card_duration
2,2,Is there an expiration date on a replacement c...,replacement_card_duration,expiration_date
3,3,How many days is a replacement card valid for?,replacement_card_duration,replacement_card_duration
4,4,Can a replacement card be reissued if it expires?,replacement_card_duration,expiration_date
...,...,...,...,...
295,295,Are there any fees associated with paying your...,bill_due,international_fees
296,296,Do you prefer to pay your bills in person or o...,bill_due,pay_bill
297,297,What happens if you overpay your bill?,bill_due,pay_bill
298,298,Have you ever had a bill dismissed due to a le...,bill_due,card_declined


Unnamed: 0,index,text,label,predicted
0,0,How long does a replacement card last?,replacement_card_duration,replacement_card_duration
1,1,What is the duration of a replacement card?,replacement_card_duration,replacement_card_duration
3,3,How many days is a replacement card valid for?,replacement_card_duration,replacement_card_duration
6,6,Are replacement cards valid for the same durat...,replacement_card_duration,replacement_card_duration
7,7,How can I determine the duration of my replace...,replacement_card_duration,replacement_card_duration
...,...,...,...,...
284,284,What is the consequence for consistently missi...,bill_due,bill_due
285,285,Have you ever incurred a late fee for paying y...,bill_due,bill_due
289,289,Have you ever been charged interest for missin...,bill_due,bill_due
290,290,What are the common reasons for missing a bill...,bill_due,bill_due


Unnamed: 0,index,text,label,predicted
2,2,Is there an expiration date on a replacement c...,replacement_card_duration,expiration_date
4,4,Can a replacement card be reissued if it expires?,replacement_card_duration,expiration_date
5,5,What happens if a replacement card expires?,replacement_card_duration,expiration_date
9,9,Can I use my replacement card up until the exp...,replacement_card_duration,expiration_date
13,13,Can I renew my replacement card before it expi...,replacement_card_duration,expiration_date
...,...,...,...,...
295,295,Are there any fees associated with paying your...,bill_due,international_fees
296,296,Do you prefer to pay your bills in person or o...,bill_due,pay_bill
297,297,What happens if you overpay your bill?,bill_due,pay_bill
298,298,Have you ever had a bill dismissed due to a le...,bill_due,card_declined
