This notebook is used for surveying the changes of forecasts due to the lack of context. We survey on 3 dimensions:
- calm/awry conversations: which type of conversations need more context?
- positive/negative changes: The lack of context makes utterances sound more calm?
- before/after triggered: The changes may happen more after the triggering preds? We use trigger on normal settings.


The min threshold for changes is 0.2.

In [12]:
import pandas as pd
import os
import matplotlib.pyplot as plt
import numpy as np
from convokit import Corpus, download
from collections import Counter

seeds = [11, 12, 13, 14, 15, 42, 81, 93, 188, 830]

In [2]:
cmv_dir = "/reef/lyt5_cga_cmv"
corpus = Corpus(cmv_dir)
label_metadata = "has_removed_comment"

In [3]:
corpus.print_summary_stats()

Number of Speakers: 25466
Number of Utterances: 123441
Number of Conversations: 20576


In [39]:
count = 0
for convo in corpus.iter_conversations():
    if convo.meta['split'] == 'test':
        count += 1
print(count/2)

420.0


In [3]:
full_preds = {}
for seed in seeds:
    full_pred_path = os.path.join("/reef/sqt2/TraMa_Exp/Full-SinglePreds/train-cmv/roberta-large/", "seed-{}".format(seed), "test-cmv/full/test_predictions.csv")
    full_preds[seed] = pd.read_csv(full_pred_path, index_col=0)
single_preds = {}
for seed in seeds:
    single_pred_path = os.path.join("/reef/sqt2/TraMa_Exp/Full-SinglePreds/train-cmv/roberta-large/", "seed-{}".format(seed), "test-cmv/single/test_predictions.csv")
    single_preds[seed] = pd.read_csv(single_pred_path, index_col=0)

In [9]:
change_characteristics = [] #['calm', 'positive', 'before']
for seed in seeds:
    full_pred = full_preds[seed]
    single_pred = single_preds[seed]
    for convo in corpus.iter_conversations():
        if convo.meta['split'] == 'test':
            label = 'awry' if convo.meta[label_metadata] else 'calm'
            triggered = 'before'
            for utt in convo.iter_utterances():
                id = utt.id
                if full_pred['forecast'][id]:
                    triggered = 'after'
                diff = abs(single_pred['forecast_prob'][id] - full_pred['forecast_prob'][id])
                if diff > 0.2:
                    change = 'positive' if single_pred['forecast_prob'][id] >\
                        full_pred['forecast_prob'][id] else 'negative'
                    change_characteristics.append((label, triggered, change))



In [10]:
len(change_characteristics)

51140

In [13]:
frequency = Counter(change_characteristics)
for char in frequency:
    frequency[char] /= len(change_characteristics)
frequency

Counter({('awry', 'after', 'negative'): 0.3446812671098944,
         ('calm', 'after', 'negative'): 0.2117129448572546,
         ('calm', 'before', 'positive'): 0.15756746186937817,
         ('awry', 'before', 'positive'): 0.12555729370355886,
         ('awry', 'before', 'negative'): 0.0721744231521314,
         ('calm', 'before', 'negative'): 0.06689479859210012,
         ('calm', 'after', 'positive'): 0.011497849041845913,
         ('awry', 'after', 'positive'): 0.009913961673836527})

In [22]:
change_characteristics = [] #['calm', 'positive', 'before']
for seed in [11]:
    full_pred = full_preds[seed]
    single_pred = single_preds[seed]
    convo_pred_change_true, convo_pred_change = 0, 0 
    for convo in corpus.iter_conversations():
        if convo.meta['split'] == 'test':
            label = 'awry' if convo.meta[label_metadata] else 'calm'
            triggered = 'before'
            
            convo_pred = 0
            # Get convo_preds
            for utt in convo.iter_utterances():
                if full_pred['forecast'][utt.id]:
                    convo_pred = 1

            for utt in convo.iter_utterances():
                id = utt.id
                diff = single_pred['forecast_prob'][id] - full_pred['forecast_prob'][id]
                if diff > 0.2:
                    if single_pred['forecast'][id] > convo_pred:
                        convo_pred_change += 1
                        # if convo.meta[label_metadata]:
                        #     convo_pred_change_true += 1
                        break
    print(convo_pred_change, convo_pred_change_true)

346 0


# Recovery

## CMV

In [30]:
cmv_dir = "/reef/lyt5_cga_cmv"
cmv_corpus = Corpus(cmv_dir)


In [68]:
full_preds = {}
seeds = [11, 12, 13, 14, 15, 42, 81, 93, 188, 830]
for seed in seeds:
    full_pred_path = os.path.join("/reef/sqt2/TraMa_Exp/Full-SinglePreds/train-cmv/roberta-large/", "seed-{}".format(seed), "test-cmv/full/test_predictions.csv")
    full_preds[seed] = pd.read_csv(full_pred_path, index_col=0)
single_preds = {}
for seed in seeds:
    single_pred_path = os.path.join("/reef/sqt2/TraMa_Exp/Full-SinglePreds/train-cmv/roberta-large/", "seed-{}".format(seed), "test-cmv/single/test_predictions.csv")
    single_preds[seed] = pd.read_csv(single_pred_path, index_col=0)

### Full

In [69]:
label_metadata = "has_removed_comment"
count_trigger, count_recovery = 0, 0
for seed in full_preds:
    full_pred = full_preds[seed]
    for convo in cmv_corpus.iter_conversations():
        if convo.meta['split'] == 'test' and not convo.meta[label_metadata]:
            utt_list = convo.get_chronological_utterance_list()
            triggered = False
            for utt in utt_list[:-1]:
                id = utt.id
                if full_pred['forecast'][id]:
                    triggered = True
            if triggered:
                count_trigger += 1
                if not full_pred['forecast'][utt_list[-1].id]:
                    count_recovery += 1
print(count_trigger, count_recovery, count_recovery/count_trigger, count_trigger- count_recovery)

6251 2771 0.4432890737482003 3480


In [70]:
count_trigger, count_recovery = 0, 0
for seed in full_preds:
    full_pred = full_preds[seed]
    for convo in cmv_corpus.iter_conversations():
        if convo.meta['split'] == 'test' and convo.meta[label_metadata]:
            utt_list = convo.get_chronological_utterance_list()
            triggered = False
            for utt in utt_list[:-1]:
                id = utt.id
                if full_pred['forecast'][id]:
                    triggered = True
            if triggered:
                count_trigger += 1
                if not full_pred['forecast'][utt_list[-1].id]:
                    count_recovery += 1
print(count_trigger, count_recovery, count_recovery/count_trigger, count_trigger- count_recovery)

12298 2189 0.1779964221824687 10109


In [72]:
recall = 10109/12298
precision = 10109/(10109+3480)
print(recall, precision, 2*(recall*precision)/(recall + precision))

0.8220035778175313 0.743910515858415 0.7810097732452582


### Single


In [73]:
count_trigger, count_recovery = 0, 0
for seed in single_preds:
    full_pred = single_preds[seed]
    for convo in cmv_corpus.iter_conversations():
        if convo.meta['split'] == 'test' and not convo.meta[label_metadata]:
            utt_list = convo.get_chronological_utterance_list()
            triggered = False
            for utt in utt_list[:-1]:
                id = utt.id
                if full_pred['forecast'][id]:
                    triggered = True
            if triggered:
                count_trigger += 1
                if not full_pred['forecast'][utt_list[-1].id]:
                    count_recovery += 1
print(count_trigger, count_recovery, count_recovery/count_trigger, count_trigger- count_recovery)

6685 5061 0.7570680628272252 1624


In [74]:
count_trigger, count_recovery = 0, 0
for seed in single_preds:
    full_pred = single_preds[seed]
    for convo in cmv_corpus.iter_conversations():
        if convo.meta['split'] == 'test' and convo.meta[label_metadata]:
            utt_list = convo.get_chronological_utterance_list()
            triggered = False
            for utt in utt_list[:-1]:
                id = utt.id
                if full_pred['forecast'][id]:
                    triggered = True
            if triggered:
                count_trigger += 1
                if not full_pred['forecast'][utt_list[-1].id]:
                    count_recovery += 1
print(count_trigger, count_recovery, count_recovery/count_trigger, count_trigger- count_recovery)

12803 6942 0.5422166679684449 5861


In [75]:
recall = 5861/12803
precision = 5861/(5861+1624)
print(recall, precision, 2*(recall*precision)/(recall + precision))


0.4577833320315551 0.7830327321309285 0.5777799684542587


## Wiki

In [13]:
corpus = Corpus(filename=download("conversations-gone-awry-corpus"))


Dataset already exists at /home/sqt2/.convokit/downloads/conversations-gone-awry-corpus


In [50]:
seeds = [11, 12, 13, 14, 15, 42, 81, 93, 830]
full_preds = {}
for seed in seeds:
    full_pred_path = os.path.join("/reef/sqt2/TraMa_Exp/Full-SinglePreds/train-wikiconv/roberta-large/", "seed-{}".format(seed), "test-wikiconv/full/test_predictions.csv")
    full_preds[seed] = pd.read_csv(full_pred_path, index_col=0)
single_preds = {}
for seed in seeds:
    single_pred_path = os.path.join("/reef/sqt2/TraMa_Exp/Full-SinglePreds/train-wikiconv/roberta-large/", "seed-{}".format(seed), "test-wikiconv/single/test_predictions.csv")
    single_preds[seed] = pd.read_csv(single_pred_path, index_col=0)

### Full

In [58]:
label_metadata = "conversation_has_personal_attack"
count_trigger, count_recovery = 0, 0
for seed in full_preds:
    full_pred = full_preds[seed]
    for convo in corpus.iter_conversations():
        if convo.meta['split'] == 'test' and not convo.meta[label_metadata]:
            utt_list = convo.get_chronological_utterance_list()
            triggered = False
            for utt in utt_list[:-2]:
                id = utt.id
                if full_pred['forecast'][id]:
                    triggered = True
            if triggered:
                count_trigger += 1
                if not full_pred['forecast'][utt_list[-2].id]:
                    count_recovery += 1
print(count_trigger, count_recovery, count_recovery/count_trigger, count_trigger- count_recovery)

1043 301 0.28859060402684567 742


In [60]:
count_trigger, count_recovery = 0, 0
for seed in full_preds:
    full_pred = full_preds[seed]
    for convo in corpus.iter_conversations():
        if convo.meta['split'] == 'test' and convo.meta[label_metadata]:
            utt_list = convo.get_chronological_utterance_list()
            triggered = False
            for utt in utt_list[:-2]:
                id = utt.id
                if full_pred['forecast'][id]:
                    triggered = True
            if triggered:
                count_trigger += 1
                if not full_pred['forecast'][utt_list[-2].id]:
                    count_recovery += 1
print(count_trigger, count_recovery, count_recovery/count_trigger, count_trigger- count_recovery)

2144 299 0.1394589552238806 1845


In [57]:
2144 - 299

1845

In [59]:
recall = 1845/2144
precision = 1845/(1845+742)
print(recall, precision, 2*(recall*precision)/(recall + precision))
# print((0.28859060402684567 + 1 - 0.1394589552238806)/2) 

0.8605410447761194 0.7131812910707384 0.7799619530754598


### Single

In [62]:
count_trigger, count_recovery = 0, 0
for seed in single_preds:
    full_pred = single_preds[seed]
    for convo in corpus.iter_conversations():
        if convo.meta['split'] == 'test' and not convo.meta[label_metadata]:
            utt_list = convo.get_chronological_utterance_list()
            triggered = False
            for utt in utt_list[:-2]:
                id = utt.id
                if full_pred['forecast'][id]:
                    triggered = True
            if triggered:
                count_trigger += 1
                if not full_pred['forecast'][utt_list[-2].id]:
                    count_recovery += 1
print(count_trigger, count_recovery, count_recovery/count_trigger, count_trigger- count_recovery)

862 719 0.834106728538283 143


In [61]:
count_trigger, count_recovery = 0, 0
for seed in single_preds:
    full_pred = single_preds[seed]
    for convo in corpus.iter_conversations():
        if convo.meta['split'] == 'test' and convo.meta[label_metadata]:
            utt_list = convo.get_chronological_utterance_list()
            triggered = False
            for utt in utt_list[:-2]:
                id = utt.id
                if full_pred['forecast'][id]:
                    triggered = True
            if triggered:
                count_trigger += 1
                if not full_pred['forecast'][utt_list[-2].id]:
                    count_recovery += 1
print(count_trigger, count_recovery, count_recovery/count_trigger, count_trigger- count_recovery)

1824 1076 0.5899122807017544 748


In [63]:
recall = 748/1824
precision = 748/(748+143)
print(recall, precision, 2*(recall*precision)/(recall + precision))

0.4100877192982456 0.8395061728395061 0.5510128913443831
