In [None]:
import csv
import os

import numpy as np
import pandas as pd
from convokit import download, Corpus

In [None]:
def extract_dynamic_samples(all_predictions, corpus, corpus_name, non_attack_min_votes=3, attack_max_votes=7):
    """Split the test conversations into "dynamic" (hard) and "single-enough" ones.

    For every conversation whose ``meta['split']`` is ``'test'``, take the largest
    value of ``all_predictions`` over its utterances (0 if none of its utterances
    has an entry). In this notebook ``all_predictions`` is the output of
    ``get_single_utt_preds``: per-utterance predictions summed over seed runs.

    A test conversation is "dynamic" when that single-utterance signal disagrees
    with its label:
      * label == False and max value >= ``non_attack_min_votes``, or
      * any other label and max value <= ``attack_max_votes``.

    Args:
        all_predictions: dict mapping utterance id to a numeric score.
        corpus: ConvoKit corpus; conversations need ``meta['split']`` and the label key.
        corpus_name: ``"wikiconv"`` selects ``conversation_has_personal_attack``;
            any other value selects ``has_removed_comment``.
        non_attack_min_votes: threshold for hard non-attack conversations (default 3).
        attack_max_votes: threshold for hard attack conversations (default 7).

    Returns:
        ``(test_samples, dynamic_samples, single_samples)``: lists of conversation
        ids in corpus iteration order; ``single_samples`` is ``test_samples``
        without the ids in ``dynamic_samples``.
    """
    label_metadata = "conversation_has_personal_attack" if corpus_name == "wikiconv" else "has_removed_comment"
    hard_att = 0
    hard_non = 0
    test_samples = []
    dynamic_samples = []
    for convo in corpus.iter_conversations():
        if convo.meta['split'] != 'test':
            continue
        test_samples.append(convo.id)
        # The leading 0 keeps the original floor: no prediction (or only negative ones) -> 0.
        max_agreement = max([0] + [all_predictions[utt.id]
                                   for utt in convo.iter_utterances()
                                   if utt.id in all_predictions])
        # Equality (not truthiness) on purpose, to keep the original label semantics.
        if convo.meta[label_metadata] == False:  # noqa: E712
            if max_agreement >= non_attack_min_votes:
                hard_non += 1
                dynamic_samples.append(convo.id)
        elif max_agreement <= attack_max_votes:
            hard_att += 1
            dynamic_samples.append(convo.id)
    print("We have {} positive samples and {} negative dynamic samples".format(hard_att, hard_non))
    # A set gives O(1) membership instead of rescanning the dynamic list per id.
    dynamic_set = set(dynamic_samples)
    single_samples = [convo_id for convo_id in test_samples if convo_id not in dynamic_set]
    return test_samples, dynamic_samples, single_samples
def test(test_samples, pred_path, corpus, corpus_name):
    label_metadata = "conversation_has_personal_attack" if corpus_name == "wikiconv" else "has_removed_comment"
    pred_file = open(pred_path, 'r')
    pred_lines = pred_file.readlines()[1:]
    pred_dict = {}
    for line in pred_lines:
        id2pred = line.split(",")
        
        assert len(id2pred) == 3
        utt_id = id2pred[0]
        utt_pred = id2pred[1]
        pred_dict[utt_id] = int(utt_pred)
        
    for convo in corpus.iter_conversations():
        # only consider test set conversations (we did not make predictions for the other ones)
        if convo.id in test_samples:
            for utt in convo.iter_utterances():
                if utt.id in pred_dict:
                    utt.meta['forecast_score'] = pred_dict[utt.id]
    
    conversational_forecasts_df = {
            "convo_id": [],
            "label": [],
            "prediction": []
        }
    for convo in corpus.iter_conversations():
        if convo.id in test_samples:
            conversational_forecasts_df['convo_id'].append(convo.id)
            conversational_forecasts_df['label'].append(int(convo.meta[label_metadata]))
            forecast_scores = [utt.meta['forecast_score'] for utt in convo.iter_utterances() if 'forecast_score' in utt.meta]
            conversational_forecasts_df['prediction'].append(max(forecast_scores))
    conversational_forecasts_df = pd.DataFrame(conversational_forecasts_df).set_index("convo_id")
    test_labels = conversational_forecasts_df.label
    test_preds = conversational_forecasts_df.prediction
    test_acc = (test_labels == test_preds).mean()
    
    tp = ((test_labels==1)&(test_preds==1)).sum()
    fp = ((test_labels==0)&(test_preds==1)).sum()
    tn = ((test_labels==0)&(test_preds==0)).sum()
    fn = ((test_labels==1)&(test_preds==0)).sum()

    test_precision = tp / (tp + fp)
    test_recall = tp / (tp + fn)
    test_fpr = fp / (fp + tn)
    test_f1 = 2 / (((tp + fp) / tp) + ((tp + fn) / tp))
    return {"accuracy":test_acc, "precision":test_precision, "recall":test_recall, "f1":test_f1}
    

In [None]:
def get_single_utt_preds(saved_path):
    """Sum per-utterance predictions over every seed run stored in ``saved_path``.

    Every sub-directory of ``saved_path`` (one per seed, e.g. ``seed-1``) must
    contain a ``predictions.csv`` with a header row and rows
    ``utt_id,prediction,<third column>``; only the first two columns are used.
    Entries of ``saved_path`` that are not directories are ignored.

    Returns:
        dict mapping utterance id to the sum of its ``int`` predictions across seeds.

    Raises:
        ValueError: a prediction row does not have exactly 3 columns.
    """
    single_utt_predictions = {}
    # sorted() makes the read order (and dict insertion order) deterministic.
    for seed in sorted(os.listdir(saved_path)):
        seed_dir = os.path.join(saved_path, seed)
        if not os.path.isdir(seed_dir):
            continue  # stray files are not seed runs
        pred_path = os.path.join(seed_dir, "predictions.csv")
        with open(pred_path, 'r', newline='') as pred_file:
            reader = csv.reader(pred_file)
            next(reader, None)  # skip the header row
            for row in reader:
                if not row:
                    continue  # tolerate blank lines such as a trailing newline
                if len(row) != 3:
                    raise ValueError("Expected 3 columns in {}, got {}: {!r}".format(pred_path, len(row), row))
                utt_id = row[0]
                single_utt_predictions[utt_id] = single_utt_predictions.get(utt_id, 0) + int(row[1])
    return single_utt_predictions
def full_evaluate(model_name, full_model_path, single_model_path, corpus, corpus_name, seeds=range(1, 11)):
    """Evaluate one model on the full test split and on its dynamic / single-enough subsets.

    The subsets come from ``extract_dynamic_samples`` applied to the summed
    single-utterance predictions under ``single_model_path/corpus_name/model_name``.
    The ids of the dynamic conversations are written, one per line, to
    ``<corpus_name>.txt`` in the current working directory.

    Each subset is scored with ``test`` against
    ``full_model_path/corpus_name/model_name/seed-<k>/predictions.csv`` for every
    ``k`` in ``seeds``. Each metric is then averaged over seeds with ``np.mean``,
    so a NaN from any seed makes that average NaN.

    Args:
        model_name: model sub-directory name (e.g. ``"roberta-base"``).
        full_model_path: root directory of the models being evaluated.
        single_model_path: root directory of the single-utterance models used to
            define the subsets.
        corpus: ConvoKit corpus passed through to the helpers.
        corpus_name: corpus sub-directory name and label selector.
        seeds: seed numbers whose prediction files are evaluated (default 1..10).

    Returns:
        ``{"full_test" | "dynamic_only" | "single_enough": {metric: mean over seeds}}``
    """
    single_dir = os.path.join(single_model_path, corpus_name, model_name)
    full_dir = os.path.join(full_model_path, corpus_name, model_name)

    single_utt_predictions = get_single_utt_preds(single_dir)
    test_samples, dynamic_samples, single_samples = extract_dynamic_samples(single_utt_predictions, corpus, corpus_name)

    with open('{}.txt'.format(corpus_name), 'w') as f:
        f.writelines("%s\n" % convo_id for convo_id in dynamic_samples)

    # Insertion order fixes both the evaluation order per seed and the key order of the result.
    subsets = {"full_test": test_samples,
               "dynamic_only": dynamic_samples,
               "single_enough": single_samples}
    metric_names = ("accuracy", "precision", "recall", "f1")
    per_seed = {subset: {metric: [] for metric in metric_names} for subset in subsets}

    for seed in seeds:
        pred_path = os.path.join(full_dir, "seed-{}".format(seed), "predictions.csv")
        for subset, sample_ids in subsets.items():
            scores = test(sample_ids, pred_path, corpus, corpus_name)
            for metric, value in scores.items():
                per_seed[subset][metric].append(value)

    return {subset: {metric: np.mean(values) for metric, values in metrics.items()}
            for subset, metrics in per_seed.items()}

In [None]:
# Model roots; override via environment variables, defaulting to the original locations.
full_model_path = os.environ.get("BERTCRAFT_DIR", "/reef/sqt2/BERTCRAFT")
single_model_path = os.environ.get("SINGLE_UTT_DIR", "/reef/sqt2/SINGLE_UTT")
corpus_name = "wikiconv"
# ConvoKit dataset downloaded for each supported corpus_name.
CORPUS_DATASETS = {
    "wikiconv": "conversations-gone-awry-corpus",
    "cmv": "conversations-gone-awry-cmv-corpus",
}
if corpus_name not in CORPUS_DATASETS:
    # Implicit literal concatenation: the old backslash continuation embedded indentation in the message.
    raise ValueError("Sorry, no corpus_name matched the input {}. "
                     "Please input a valid corpus_name [wikiconv, cmv]".format(corpus_name))
corpus = Corpus(filename=download(CORPUS_DATASETS[corpus_name]))

In [None]:
# Evaluate the roberta-base models under full_model_path on the currently loaded corpus.
print(full_evaluate(
    model_name="roberta-base",
    full_model_path=full_model_path,
    single_model_path=single_model_path,
    corpus=corpus,
    corpus_name=corpus_name,
))

In [None]:
# Performance of the single-utterance models: reuse the same harness, passing
# single_model_path as the path of the models under evaluation.
print(full_evaluate(
    model_name="roberta-base",
    full_model_path=single_model_path,
    single_model_path=single_model_path,
    corpus=corpus,
    corpus_name=corpus_name,
))

In [None]:
# Model roots; override via environment variables, defaulting to the original locations.
full_model_path = os.environ.get("BERTCRAFT_DIR", "/reef/sqt2/BERTCRAFT")
single_model_path = os.environ.get("SINGLE_UTT_DIR", "/reef/sqt2/SINGLE_UTT")
corpus_name = "cmv"
# ConvoKit dataset downloaded for each supported corpus_name.
CORPUS_DATASETS = {
    "wikiconv": "conversations-gone-awry-corpus",
    "cmv": "conversations-gone-awry-cmv-corpus",
}
if corpus_name not in CORPUS_DATASETS:
    # Implicit literal concatenation: the old backslash continuation embedded indentation in the message.
    raise ValueError("Sorry, no corpus_name matched the input {}. "
                     "Please input a valid corpus_name [wikiconv, cmv]".format(corpus_name))
corpus = Corpus(filename=download(CORPUS_DATASETS[corpus_name]))

In [None]:
# Evaluate the roberta-base models under full_model_path on the currently loaded corpus (cmv when run in order).
print(full_evaluate(
    model_name="roberta-base",
    full_model_path=full_model_path,
    single_model_path=single_model_path,
    corpus=corpus,
    corpus_name=corpus_name,
))

In [None]:
# Performance of single_utterance models
# Reuse the harness with single_model_path as the evaluated model path.
print(full_evaluate(
    model_name="roberta-base",
    full_model_path=single_model_path,
    single_model_path=single_model_path,
    corpus=corpus,
    corpus_name=corpus_name,
))