In [1]:
import comet_ml
from comet_ml import Experiment
import matplotlib.pyplot as plt
import pandas as pd
from datasets import load_dataset
from transformers import pipeline
from sklearn.metrics import classification_report
from sklearn.metrics import classification_report

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Run configuration -- change as needed.
# The three data paths below are all derived from `prompt_name`, so switching
# to another prompt's generated text is a one-line change and the input CSV
# and the two prediction CSVs cannot drift out of sync.
prompt_name = 'prompt4'

# Fine-tuned classifier checkpoint used to verify the generated sentences
checkpoints_out_dir = '../checkpoints/clinc_plus/checkpoint-30500'

# Prediction outputs: generated rows the classifier labels correctly / incorrectly
correct_out_path = f'../predictions/clinc_plus_augmented_data_correct_{prompt_name}.csv'
incorrect_out_path = f'../predictions/clinc_plus_augmented_data_incorrect_{prompt_name}.csv'

# Augmented (ChatGPT-generated) data to verify
augmented_data_path = f'../prompts/generated_text/ChatGPT_{prompt_name}.csv'

In [3]:
import os

# Build the Comet experiment. The API key is read from the COMET_API_KEY
# environment variable instead of being hard-coded: a key written into a
# notebook leaks with every copy/commit of the file (any previously committed
# key should be revoked). If the variable is unset, api_key=None lets comet_ml
# fall back to its own configuration (e.g. ~/.comet.config).
experiment = Experiment(
    api_key=os.environ.get("COMET_API_KEY"),
    project_name="CLINC",
    workspace="gdhanania",
)

# Evaluation hyperparameters; test_batch_size is the batch size verify() passes
# to the classification pipeline.
hyper_params = {"test_batch_size": 16}

# Device handed to transformers.pipeline by verify()
device = 'cuda:0'

# transformers pipeline task used to load the checkpoint in verify()
pipeline_task = 'text-classification'

# Logging hyperparameters
experiment.log_parameters(hyper_params)

COMET INFO: Experiment is live on comet.com https://www.comet.com/gdhanania/clinc/e8ed4f7e97e3448f803cd3116be1fe8e



In [4]:
def _load_augmented_dataset(augmented_data_path):
    """Load the augmented CSV and return its (texts, labels) columns as lists."""
    dataset = load_dataset("csv", data_files=augmented_data_path)['train']
    # The CSV stores sentences under "Sentence" and gold intents under "Label".
    return dataset['Sentence'], dataset['Label']


def _predict_labels(texts, checkpoints_out_dir):
    """Run the classification pipeline from `checkpoints_out_dir` and return the top label per text."""
    classifier = pipeline(pipeline_task, model=checkpoints_out_dir, device=device)
    predictions = classifier(texts, batch_size=hyper_params['test_batch_size'])
    return [p['label'] for p in predictions]


def _report_metrics(true_labels, predicted_labels):
    """Print macro/weighted F1 and accuracy, and log them to the Comet experiment."""
    # zero_division=0 yields the same numbers as the default ('warn' also scores
    # undefined precision/recall as 0) without emitting one warning per call.
    # NOTE(review): with labels=None the report covers the union of true and
    # predicted labels, so intents that are only ever *predicted* (support 0,
    # F1 0) are averaged into 'macro avg'. That is likely why macro F1 is far
    # below accuracy here -- consider labels=sorted(set(true_labels)) if macro F1
    # should only cover the intents present in the augmented data (note that
    # this replaces the 'accuracy' key with 'micro avg').
    report = classification_report(true_labels, predicted_labels, output_dict=True, zero_division=0)

    # report has three root aggregates: 'accuracy', 'macro avg', 'weighted avg'
    macro_avg_f1_score = report['macro avg']['f1-score']
    weighted_avg_f1_score = report['weighted avg']['f1-score']
    accuracy = report['accuracy']

    print('Macro Average F1 score: {:.2f}'.format(macro_avg_f1_score))
    print('Weighted Average F1 score: {:.2f}'.format(weighted_avg_f1_score))
    print('Accuracy: {:.2f}%'.format(accuracy * 100))

    experiment.log_metrics({
        "Macro Average F1 score": macro_avg_f1_score,
        "Weighted Average F1 score": weighted_avg_f1_score,
        "Accuracy": accuracy * 100,
    })


def _split_and_save(texts, true_labels, predicted_labels, correct_out_path, incorrect_out_path):
    """Build the per-row result table, split it by correctness, display it and write both halves to CSV."""
    # Building from a dict (instead of zip) raises on a length mismatch rather
    # than silently truncating. reset_index() materialises the row position as
    # an 'index' column so saved rows can be traced back to the input CSV.
    result_df = pd.DataFrame({
        'text': texts,
        'label': true_labels,
        'predicted': predicted_labels,
    }).reset_index()

    is_correct = result_df['label'] == result_df['predicted']
    result_df_correct = result_df[is_correct]
    result_df_incorrect = result_df[~is_correct]

    display(result_df, result_df_correct, result_df_incorrect)

    result_df_correct.to_csv(correct_out_path, encoding='utf-8', index=False)
    result_df_incorrect.to_csv(incorrect_out_path, encoding='utf-8', index=False)

    return (result_df_correct, result_df_incorrect)


def verify(augmented_data_path, checkpoints_out_dir, correct_out_path, incorrect_out_path):
    """Classify augmented sentences with a fine-tuned checkpoint and split them by correctness.

    Reads a CSV with ``Sentence`` and ``Label`` columns and runs the
    ``pipeline_task`` pipeline loaded from ``checkpoints_out_dir`` over every
    sentence. Inside the Comet experiment's test context, it prints and logs
    macro/weighted F1 and accuracy. It then writes the correctly and incorrectly
    classified rows to two CSV files.

    Relies on the notebook-level globals ``pipeline_task``, ``device``,
    ``hyper_params`` and ``experiment`` defined in the experiment-setup cell.

    Args:
        augmented_data_path: CSV file with ``Sentence`` and ``Label`` columns.
        checkpoints_out_dir: model checkpoint passed to ``transformers.pipeline``.
        correct_out_path: destination CSV for rows whose prediction matches the label.
        incorrect_out_path: destination CSV for mispredicted rows.

    Returns:
        ``(result_df_correct, result_df_incorrect)``: DataFrames with columns
        ``index`` (row position in the input CSV), ``text``, ``label`` and
        ``predicted``.
    """
    texts, true_labels = _load_augmented_dataset(augmented_data_path)
    predicted_labels = _predict_labels(texts, checkpoints_out_dir)

    with experiment.test():
        _report_metrics(true_labels, predicted_labels)
        return _split_and_save(texts, true_labels, predicted_labels, correct_out_path, incorrect_out_path)

In [5]:
# Verify the augmented data and keep both halves of the split for inspection.
result_df_correct, result_df_incorrect = verify(
    augmented_data_path=augmented_data_path,
    checkpoints_out_dir=checkpoints_out_dir,
    correct_out_path=correct_out_path,
    incorrect_out_path=incorrect_out_path,
)

Downloading and preparing dataset csv/default to /work/pi_adrozdov_umass_edu/gdhanania_umass_edu/hf_cache/datasets/csv/default-d7d2df36536fbcf7/0.0.0/6b34fb8fcf56f7c8ba51dc895bfa2bfbe43546f190a60fcf74bb5e8afdcc2317...


Downloading data files: 100%|██████████| 1/1 [00:00<00:00, 2328.88it/s]
Extracting data files: 100%|██████████| 1/1 [00:00<00:00, 113.94it/s]
  return pd.read_csv(xopen(filepath_or_buffer, "rb", use_auth_token=use_auth_token), **kwargs)
                                                                

Dataset csv downloaded and prepared to /work/pi_adrozdov_umass_edu/gdhanania_umass_edu/hf_cache/datasets/csv/default-d7d2df36536fbcf7/0.0.0/6b34fb8fcf56f7c8ba51dc895bfa2bfbe43546f190a60fcf74bb5e8afdcc2317. Subsequent calls will reuse this data.


100%|██████████| 1/1 [00:00<00:00, 20.53it/s]


Macro Average F1 score: 0.06
Weighted Average F1 score: 0.53
Accuracy: 43.33%


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Unnamed: 0,index,text,label,predicted
0,0,How long does it usually take to get a replace...,replacement_card_duration,replacement_card_duration
1,1,Can you tell me about the process to get a rep...,replacement_card_duration,replacement_card_duration
2,2,What is the expected turnaround time for a rep...,replacement_card_duration,replacement_card_duration
3,3,"If I lose my card, how soon can I get a replac...",replacement_card_duration,replacement_card_duration
4,4,Is there an expedited option for getting a rep...,replacement_card_duration,replacement_card_duration
...,...,...,...,...
295,295,When is my credit card payment due?,bill_due,bill_due
296,296,Have I made the full payment for my hospital b...,bill_due,bill_due
297,297,"I think I forgot to pay my credit card bill, c...",bill_due,pay_bill
298,298,Could you let me know the date for the next pa...,bill_due,bill_due


Unnamed: 0,index,text,label,predicted
0,0,How long does it usually take to get a replace...,replacement_card_duration,replacement_card_duration
1,1,Can you tell me about the process to get a rep...,replacement_card_duration,replacement_card_duration
2,2,What is the expected turnaround time for a rep...,replacement_card_duration,replacement_card_duration
3,3,"If I lose my card, how soon can I get a replac...",replacement_card_duration,replacement_card_duration
4,4,Is there an expedited option for getting a rep...,replacement_card_duration,replacement_card_duration
...,...,...,...,...
287,287,When is my cell phone bill due?,bill_due,bill_due
288,288,Could you verify if my cable bill has been paid?,bill_due,bill_due
295,295,When is my credit card payment due?,bill_due,bill_due
296,296,Have I made the full payment for my hospital b...,bill_due,bill_due


Unnamed: 0,index,text,label,predicted
6,6,Are there any fees associated with getting a r...,replacement_card_duration,international_fees
10,10,Can you explain the documents required to appl...,replacement_card_duration,new_card
11,11,Do I need to deactivate my lost card before ap...,replacement_card_duration,report_lost_card
13,13,Will the new card have the same PIN as the old...,replacement_card_duration,pin_change
14,14,Will I be able to use the replacement card for...,replacement_card_duration,international_fees
...,...,...,...,...
292,292,Could you let me know when I need to renew my ...,bill_due,expiration_date
293,293,Is there a balance due on my Dish Network bill?,bill_due,bill_balance
294,294,Can you give me the status of my tuition payme...,bill_due,bill_balance
297,297,"I think I forgot to pay my credit card bill, c...",bill_due,pay_bill
