In [99]:
import torch
input = torch.randn(3, requires_grad=True)
target = torch.empty(3).random_(2)
print(input)
print(target)
# loss = F.binary_cross_entropy_with_logits(input, target)
# loss.backward()

tensor([-0.8773,  1.2374, -0.3002], requires_grad=True)
tensor([1., 0., 1.])


In [1]:
import pandas as pd
import numpy as np
import os
from convokit import download, Corpus
from sklearn.metrics import roc_curve
from scipy.special import softmax

In [2]:
wikicorpus = Corpus(filename=download("conversations-gone-awry-corpus"))
cmvcorpus = Corpus(filename=download("conversations-gone-awry-cmv-corpus"))

Dataset already exists at /home/sqt2/.convokit/downloads/conversations-gone-awry-corpus
Dataset already exists at /home/sqt2/.convokit/downloads/conversations-gone-awry-cmv-corpus


In [28]:
def load_logits(saved_path):
    id = []
    logit0 = []
    logit1 = []
    data = {"id": [], "logit0": [], "logit1": []}
    pred_file = open(saved_path, 'r')
    pred_lines = pred_file.readlines()[1:]
    for line in pred_lines:
        id2pred = line.split(",")
        
        assert len(id2pred) == 5
        id.append(id2pred[0])
        logit0.append(id2pred[3])
        logit1.append(id2pred[4])
    df = pd.DataFrame({"logit0": logit0, "logit1": logit1}, index=id)
    convert_dict = {'logit0': float, 'logit1': float}
 
    df = df.astype(convert_dict)
    return df
    
def calculate_counterfactual(saved_model, split, alpha):
    orig_path = os.path.join(saved_model, "{}_predictions.csv".format(split))
    orig_logits = load_logits(orig_path)

    counterfactual_path = os.path.join(saved_model, "{}_counterfactual_predictions.csv".format(split))
    counterfactual_logits = load_logits(counterfactual_path)
    counterfactual_logits["logit0"] = counterfactual_logits["logit0"] * float(alpha)
    counterfactual_logits["logit1"] = counterfactual_logits["logit1"] * float(alpha)

    final_logits = pd.DataFrame()
    final_logits["logit0"] = orig_logits["logit0"] - counterfactual_logits["logit0"]
    final_logits["logit1"] = orig_logits["logit1"] - counterfactual_logits["logit1"]
    final_logits["score"] = softmax(final_logits[['logit0', 'logit1']].to_numpy(), axis=1)[:,1]
    return final_logits
    
def acc_with_threshold(y_true, y_score, thresh):
    y_pred = (y_score > thresh).astype(int)
    return (y_pred == y_true).mean() 

def tune_model_for_val(corpus, corpus_name, saved_model):
    label_metadata = "conversation_has_personal_attack" if corpus_name == "wikiconv" else "has_removed_comment"
    utt_label_metadata = "comment_has_personal_attack" if corpus_name == "wikiconv" else None
    best_acc = 0
    best_alpha = 0
    best_threshold = 0
    for alpha in np.arange(1.5,1.51,0.1):
        val_scores = calculate_counterfactual(saved_model, "val", alpha)
        highest_convo_scores = {c.id: -1 for c in corpus.iter_conversations(lambda convo: convo.meta["split"] == "val")}
        for utt_id in val_scores.index:
            parent_convo = corpus.get_utterance(utt_id).get_conversation()
            utt_score = val_scores.loc[utt_id].score
            if parent_convo.id in highest_convo_scores:
                if utt_score > highest_convo_scores[parent_convo.id]:
                    highest_convo_scores[parent_convo.id] = utt_score
        val_convo_ids = [c.id for c in corpus.iter_conversations(lambda convo: convo.meta["split"] == "val")]
        val_labels = np.asarray([int(corpus.get_conversation(c).meta[label_metadata]) for c in val_convo_ids])
        val_scores = np.asarray([highest_convo_scores[c] for c in val_convo_ids])
        _, _, thresholds = roc_curve(val_labels, val_scores)
        accs = [acc_with_threshold(val_labels, val_scores, t) for t in thresholds]
        best_acc_idx = np.argmax(accs)
        # print("{} {} |||| Achieved Accuracy:".format(alpha, thresholds[best_acc_idx]), accs[best_acc_idx])
        if accs[best_acc_idx] > best_acc:
            best_acc = accs[best_acc_idx]
            best_alpha = alpha
            best_threshold = thresholds[best_acc_idx]
    print("{} {} |||| Achieved Accuracy:".format(best_alpha, best_threshold), best_acc)
    return best_acc, best_alpha, best_threshold

def tune_model_for_dynamic(corpus, corpus_name, saved_model):
    label_metadata = "conversation_has_personal_attack" if corpus_name == "wikiconv" else "has_removed_comment"
    utt_label_metadata = "comment_has_personal_attack" if corpus_name == "wikiconv" else None
    best_acc = 0
    best_alpha = 0
    best_threshold = 0
    for alpha in np.arange(1.5,1.51,0.1):
        val_scores = calculate_counterfactual(saved_model, "val", alpha)
        highest_convo_scores = {c.id: -1 for c in corpus.iter_conversations(lambda convo: convo.meta["split"] == "val")}
        for utt_id in val_scores.index:
            parent_convo = corpus.get_utterance(utt_id).get_conversation()
            utt_score = val_scores.loc[utt_id].score
            if parent_convo.id in highest_convo_scores:
                if utt_score > highest_convo_scores[parent_convo.id]:
                    highest_convo_scores[parent_convo.id] = utt_score
        val_convo_ids = [c.id for c in corpus.iter_conversations(lambda convo: convo.meta["split"] == "val")]
        val_labels = np.asarray([int(corpus.get_conversation(c).meta[label_metadata]) for c in val_convo_ids])
        val_scores = np.asarray([highest_convo_scores[c] for c in val_convo_ids])
        _, _, thresholds = roc_curve(val_labels, val_scores)
        accs = [acc_with_threshold(val_labels, val_scores, t) for t in thresholds]
        best_acc_idx = np.argmax(accs)
        # print("{} {} |||| Achieved Accuracy:".format(alpha, thresholds[best_acc_idx]), accs[best_acc_idx])
        if accs[best_acc_idx] > best_acc:
            best_acc = accs[best_acc_idx]
            best_alpha = alpha
            best_threshold = thresholds[best_acc_idx]
    print("{} {} |||| Achieved Accuracy:".format(best_alpha, best_threshold), best_acc)
    return best_acc, best_alpha, best_threshold
    
def counterfactual_evaluate(corpus, corpus_name, saved_model):
    label_metadata = "conversation_has_personal_attack" if corpus_name == "wikiconv" else "has_removed_comment"
    utt_label_metadata = "comment_has_personal_attack" if corpus_name == "wikiconv" else None
    
    _, best_alpha, best_threshold = tune_model_for_val(corpus, corpus_name, saved_model)
    test_scores = calculate_counterfactual(saved_model, "test", best_alpha)
    
    test_scores["prediction"] = (test_scores["score"] > best_threshold).astype(int)
    prediction_file = os.path.join(saved_model, "counterfactual_final.csv")
    test_scores.to_csv(prediction_file)
    
    highest_convo_scores = {c.id: -1 for c in corpus.iter_conversations(lambda convo: convo.meta['split']=="test")}
    for utt_id in test_scores.index:
        parent_convo = corpus.get_utterance(utt_id).get_conversation()
        utt_score = test_scores.loc[utt_id].score
        if utt_score > highest_convo_scores[parent_convo.id]:
            highest_convo_scores[parent_convo.id] = utt_score
    test_convo_ids = [c.id for c in corpus.iter_conversations(lambda convo: convo.meta['split'] == 'test')]
    test_labels = np.asarray([int(corpus.get_conversation(c).meta[label_metadata]) for c in test_convo_ids])
    test_scores = np.asarray([highest_convo_scores[c] for c in test_convo_ids])
    test_pred = (test_scores > best_threshold).astype(int)
    print((test_pred == test_labels).mean())
    return 

In [29]:
saved_model = "/reef/sqt2/BERTCRAFT_counterfactual/wikiconv/roberta-base/seed-1"
counterfactual_evaluate(wikicorpus, "wikiconv", saved_model)

1.5 0.8211821152695813 |||| Achieved Accuracy: 0.6214285714285714
0.6035714285714285


In [30]:
saved_model = "/reef/sqt2/BERTCRAFT_counterfactual/wikiconv/roberta-base/seed-2"
counterfactual_evaluate(wikicorpus, "wikiconv", saved_model)

1.5 0.801236981197876 |||| Achieved Accuracy: 0.6261904761904762
0.6095238095238096


In [31]:
saved_model = "/reef/sqt2/BERTCRAFT_counterfactual/wikiconv/roberta-base/seed-3"
counterfactual_evaluate(wikicorpus, "wikiconv", saved_model)

1.5 0.7643885192192834 |||| Achieved Accuracy: 0.5845238095238096
0.6035714285714285


In [32]:
saved_model = "/reef/sqt2/BERTCRAFT_counterfactual/wikiconv/roberta-base/seed-4"
counterfactual_evaluate(wikicorpus, "wikiconv", saved_model)

1.5 0.6740890164667562 |||| Achieved Accuracy: 0.6047619047619047
0.5964285714285714


## CMV

In [88]:
saved_model = "/reef/sqt2/BERTCRAFT_counterfactual/cmv/roberta-base/seed-1"
counterfactual_evaluate(cmvcorpus, "cmv", saved_model)

0.0 0.6330996735001478 |||| Achieved Accuracy: 0.6769005847953217
0.6652046783625731


In [89]:
saved_model = "/reef/sqt2/BERTCRAFT_counterfactual/cmv/roberta-base/seed-2"
counterfactual_evaluate(cmvcorpus, "cmv", saved_model)

0.0 0.4992547301506855 |||| Achieved Accuracy: 0.6732456140350878
0.6622807017543859


In [90]:
saved_model = "/reef/sqt2/BERTCRAFT_counterfactual/cmv/roberta-base/seed-3"
counterfactual_evaluate(cmvcorpus, "cmv", saved_model)

0.0 0.328870954988892 |||| Achieved Accuracy: 0.6776315789473685
0.6622807017543859


In [91]:
saved_model = "/reef/sqt2/BERTCRAFT_counterfactual/cmv/roberta-base/seed-4"
counterfactual_evaluate(cmvcorpus, "cmv", saved_model)

0.0 0.43241085007452523 |||| Achieved Accuracy: 0.679093567251462
0.6754385964912281


# Evaluate

In [8]:
def extract_dynamic_samples(all_predictions, corpus, corpus_name):
    label_metadata = "conversation_has_personal_attack" if corpus_name == "wikiconv" else "has_removed_comment"
    num_convo = 0
    hard_pos, hard_neg = 0, 0
    all_pos, all_neg = 0, 0
    dynamic_samples = []
    test_samples = []
    for convo in corpus.iter_conversations():
        if convo.meta['split'] == 'test':
            test_samples.append(convo.id)
            max_agreement = 0
            for utterance in convo.iter_utterances():
                id = utterance.id
                if id in all_predictions:
                    if all_predictions[id] > max_agreement:
                        max_agreement = all_predictions[id]
            if convo.meta[label_metadata] == False:
                all_neg += 1
                if max_agreement >= 3:
                    hard_neg += 1
                    dynamic_samples.append(convo.id)
            else:
                all_pos += 1
                if max_agreement <= 8:
                    hard_pos += 1
                    dynamic_samples.append(convo.id)
    print("We have {} positive samples and {} negative samples in the test set".format(all_pos, all_neg))
    # print("We have {} positive samples and {} negative dynamic samples".format(hard_pos, hard_neg))

    single_samples = [id for id in test_samples if id not in dynamic_samples]
    print(len(dynamic_samples))
    print(len(single_samples))
    return test_samples, dynamic_samples, single_samples
def test(test_samples, pred_path, corpus, corpus_name):
    print(pred_path)
    label_metadata = "conversation_has_personal_attack" if corpus_name == "wikiconv" else "has_removed_comment"
    pred_file = open(pred_path, 'r')
    pred_lines = pred_file.readlines()[1:]
    pred_dict = {}
    for line in pred_lines:
        id2pred = line.split(",")
        
        assert len(id2pred) == 5
        utt_id = id2pred[0]
        utt_pred = id2pred[4]
        pred_dict[utt_id] = int(utt_pred)
        
    for convo in corpus.iter_conversations():
        # only consider test set conversations (we did not make predictions for the other ones)
        if convo.id in test_samples:
            for utt in convo.iter_utterances():
                if utt.id in pred_dict:
                    utt.meta['forecast_score'] = pred_dict[utt.id]
    
    conversational_forecasts_df = {
            "convo_id": [],
            "label": [],
            "prediction": []
        }
    for convo in corpus.iter_conversations():
        if convo.id in test_samples:
            conversational_forecasts_df['convo_id'].append(convo.id)
            conversational_forecasts_df['label'].append(int(convo.meta[label_metadata]))
            forecast_scores = [utt.meta['forecast_score'] for utt in convo.iter_utterances() if 'forecast_score' in utt.meta]
            conversational_forecasts_df['prediction'].append(max(forecast_scores))
    conversational_forecasts_df = pd.DataFrame(conversational_forecasts_df).set_index("convo_id")
    test_labels = conversational_forecasts_df.label
    test_preds = conversational_forecasts_df.prediction
    test_acc = (test_labels == test_preds).mean()
    
    tp = ((test_labels==1)&(test_preds==1)).sum()
    fp = ((test_labels==0)&(test_preds==1)).sum()
    tn = ((test_labels==0)&(test_preds==0)).sum()
    fn = ((test_labels==1)&(test_preds==0)).sum()

    test_precision = tp / (tp + fp)
    test_recall = tp / (tp + fn)
    test_fpr = fp / (fp + tn)
    test_f1 = 2 / (((tp + fp) / tp) + ((tp + fn) / tp))
    return {"accuracy":test_acc, "precision":test_precision, "recall":test_recall, "f1":test_f1}
    

In [9]:
def get_single_utt_preds(saved_path):
    single_utt_predictions = {}
    all_seeds = os.listdir(saved_path)
    for seed in all_seeds:
        pred_path = os.path.join(saved_path, seed, "predictions.csv")
        pred_file = open(pred_path, 'r')
        pred_lines = pred_file.readlines()[1:]
        for line in pred_lines:
            id2pred = line.split(",")
            
            assert len(id2pred) == 3
            utt_id = id2pred[0]
            utt_pred = id2pred[1]
            if utt_id not in single_utt_predictions:
                single_utt_predictions[utt_id] = int(utt_pred)
            else:
                single_utt_predictions[utt_id] += int(utt_pred)
    return single_utt_predictions
    
def full_evaluate(full_model_name, full_model_path, single_model_name, single_model_path, corpus, corpus_name):
    single_model_path = os.path.join(single_model_path, corpus_name, single_model_name)
    full_model_path = os.path.join(full_model_path, corpus_name, full_model_name)
    
    single_utt_predictions = get_single_utt_preds(single_model_path)
    test_samples, dynamic_samples, single_samples = extract_dynamic_samples(single_utt_predictions, corpus, corpus_name)

    with open('{}.txt'.format(corpus_name), 'w') as f:
        for id in dynamic_samples:
            f.write("%s\n" % id)

    result_dict = {"full_test": {"accuracy":[], "precision":[], "recall":[], "f1":[]},
                  "dynamic_only": {"accuracy":[], "precision":[], "recall":[], "f1":[]},
                  "single_enough": {"accuracy":[], "precision":[], "recall":[], "f1":[]}}

    for seed in range(1,5):        
        try:
            pred_path = os.path.join(full_model_path, "seed-{}".format(seed), "counterfactual_final.csv")
            full_test = test(test_samples, pred_path, corpus, corpus_name)
            for metric in full_test:
                result_dict['full_test'][metric].append(full_test[metric])
            dynamic_only = test(dynamic_samples, pred_path, corpus, corpus_name)
            for metric in dynamic_only:
                result_dict['dynamic_only'][metric].append(dynamic_only[metric])
            single_enough = test(single_samples, pred_path, corpus, corpus_name)
            for metric in single_enough:
                result_dict['single_enough'][metric].append(single_enough[metric])
        except:
            continue
    for metric in result_dict['full_test']:
        result_dict['full_test'][metric] = np.mean(result_dict['full_test'][metric])
        result_dict['dynamic_only'][metric] = np.mean(result_dict['dynamic_only'][metric])
        result_dict['single_enough'][metric] = np.mean(result_dict['single_enough'][metric])
    return result_dict

In [10]:
single_model_path = "/reef/sqt2/SINGLE_UTT"
full_model_path = "/reef/sqt2/BERTCRAFT_counterfactual"

In [17]:
print(full_evaluate("roberta-base", full_model_path, "roberta-base", single_model_path, wikicorpus, 'wikiconv'))

We have 420 positive samples and 420 negative samples in the test set
420
420
/reef/sqt2/BERTCRAFT_counterfactual/wikiconv/roberta-base/seed-1/counterfactual_final.csv
/reef/sqt2/BERTCRAFT_counterfactual/wikiconv/roberta-base/seed-1/counterfactual_final.csv
/reef/sqt2/BERTCRAFT_counterfactual/wikiconv/roberta-base/seed-1/counterfactual_final.csv
/reef/sqt2/BERTCRAFT_counterfactual/wikiconv/roberta-base/seed-2/counterfactual_final.csv
/reef/sqt2/BERTCRAFT_counterfactual/wikiconv/roberta-base/seed-2/counterfactual_final.csv
/reef/sqt2/BERTCRAFT_counterfactual/wikiconv/roberta-base/seed-2/counterfactual_final.csv
/reef/sqt2/BERTCRAFT_counterfactual/wikiconv/roberta-base/seed-3/counterfactual_final.csv
/reef/sqt2/BERTCRAFT_counterfactual/wikiconv/roberta-base/seed-3/counterfactual_final.csv
/reef/sqt2/BERTCRAFT_counterfactual/wikiconv/roberta-base/seed-3/counterfactual_final.csv
/reef/sqt2/BERTCRAFT_counterfactual/wikiconv/roberta-base/seed-4/counterfactual_final.csv
/reef/sqt2/BERTCRAFT_c

In [95]:
print(full_evaluate("roberta-base", full_model_path, "roberta-base", single_model_path, cmvcorpus, 'cmv'))

We have 684 positive samples and 684 negative samples in the test set
643
725
/reef/sqt2/BERTCRAFT_counterfactual/cmv/roberta-base/seed-1/counterfactual_final.csv
/reef/sqt2/BERTCRAFT_counterfactual/cmv/roberta-base/seed-1/counterfactual_final.csv
/reef/sqt2/BERTCRAFT_counterfactual/cmv/roberta-base/seed-1/counterfactual_final.csv
/reef/sqt2/BERTCRAFT_counterfactual/cmv/roberta-base/seed-2/counterfactual_final.csv
/reef/sqt2/BERTCRAFT_counterfactual/cmv/roberta-base/seed-2/counterfactual_final.csv
/reef/sqt2/BERTCRAFT_counterfactual/cmv/roberta-base/seed-2/counterfactual_final.csv
/reef/sqt2/BERTCRAFT_counterfactual/cmv/roberta-base/seed-3/counterfactual_final.csv
/reef/sqt2/BERTCRAFT_counterfactual/cmv/roberta-base/seed-3/counterfactual_final.csv
/reef/sqt2/BERTCRAFT_counterfactual/cmv/roberta-base/seed-3/counterfactual_final.csv
/reef/sqt2/BERTCRAFT_counterfactual/cmv/roberta-base/seed-4/counterfactual_final.csv
/reef/sqt2/BERTCRAFT_counterfactual/cmv/roberta-base/seed-4/counterfactu

In [80]:
single_model_path = "/reef/sqt2/SINGLE_UTT"
full_model_path = "/reef/sqt2/BERTCRAFT"

In [81]:
print(full_evaluate("roberta-base", full_model_path, "roberta-base", single_model_path, cmvcorpus, 'cmv'))

We have 684 positive samples and 684 negative samples in the test set
643
725
/reef/sqt2/BERTCRAFT/cmv/roberta-base/seed-1/counterfactual_final.csv
/reef/sqt2/BERTCRAFT/cmv/roberta-base/seed-2/counterfactual_final.csv
/reef/sqt2/BERTCRAFT/cmv/roberta-base/seed-3/counterfactual_final.csv
/reef/sqt2/BERTCRAFT/cmv/roberta-base/seed-4/counterfactual_final.csv
{'full_test': {'accuracy': nan, 'precision': nan, 'recall': nan, 'f1': nan}, 'dynamic_only': {'accuracy': nan, 'precision': nan, 'recall': nan, 'f1': nan}, 'single_enough': {'accuracy': nan, 'precision': nan, 'recall': nan, 'f1': nan}}


  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)
