In [1]:
import ast

import numpy as np
import pandas as pd
import scipy.stats as stats
import matplotlib.pyplot as plt

from statsmodels.stats.contingency_tables import mcnemar                                                           

In [2]:
def compute_power(prob_table, dataset_size, alpha=0.05, r=10000):
    print("Dataset size: ", dataset_size)
    if prob_table[0, 1] == prob_table[1, 0]:
        raise RuntimeError("Power is undefined when the true effect is zero.")

    pvals = []
    diffs = []
    for i in range(r):  # number of simulations
        sample = np.random.multinomial(n=dataset_size, pvals=prob_table.reshape((4,))).reshape((2,2))
        acc_diff = (sample[0,1] - sample[1, 0]) / dataset_size
        test_results = mcnemar(sample)
        pvals.append(test_results.pvalue)
        diffs.append(acc_diff)

    true_diff = prob_table[0, 1] - prob_table[1, 0]
    true_sign = np.sign(true_diff) 
    sig_diffs = [d for i, d in enumerate(diffs) if pvals[i] <= alpha]
    power = len([d for i, d in enumerate(diffs) if pvals[i] <= alpha and np.sign(d) == true_sign]) / r
    mean_effect = np.mean(diffs)
    type_m = np.mean(np.abs(sig_diffs) / np.abs(true_diff))
    type_s = np.mean(np.sign(sig_diffs) != true_sign)
    return power, mean_effect, type_m, type_s


In [3]:
with open("../data/original_datasets/NEG-136-SIMP.txt") as f:
    raw_data = [d.strip(",\n") for d in f.readlines()]


data_orig = []
for line in raw_data:
    y = line.split(" ")[-1]
    x = line.strip(y).strip(" ")
    is_negative = True if "not" in line else False
    data_orig.append((x, y, is_negative))


with open("../data/NEG-1500-SIMP-TEMP.txt") as f:
    raw_data = [d.strip(",\n") for d in f.readlines()]


data_extended = []
for line in raw_data:
    y = line.split(" ")[-1]
    x = line.strip(y).strip(" ")
    is_negative = True if "not" in line else False
    data_extended.append((x, y, is_negative))

In [4]:
model = "bert-large-uncased"
print("Model: ", model)

with open(f"../predictions/NEG-136-SIMP/{model}.txt") as f:
    preds_orig = [ast.literal_eval(d.strip("\n"))[:5] for d in f.readlines()]

# Affirmative accuracy
def get_is_correct(label, predictions, is_negative):
    if not is_negative:
        return label in predictions
    return label not in predictions

def is_affirmative(data_point):
    return "not" not in data_point[0]

is_correct_orig = [
    get_is_correct(data_orig[i][1], d, data_orig[i][2])
    for i, d in enumerate(preds_orig) if is_affirmative(data_orig[i])
]

print("Orig arrifmative accuracy: ", np.mean(is_correct_orig))

with open(f"../predictions/NEG-1500-SIMP-TEMP/{model}.txt") as f:
    preds_extended = [d.strip("\n") for d in f.readlines()]

is_correct_extended = [
    get_is_correct(data_extended[i][1], d, data_extended[i][2])
    for i, d in enumerate(preds_extended) if is_affirmative(data_extended[i])
]

print("Extended arrifmative accuracy: ", np.mean(is_correct_extended))
print("-" * 40)

model2 = "roberta-base"
print("Model: ", model2)

with open(f"../predictions/NEG-136-SIMP/{model2}.txt") as f:
    preds_orig2 = [ast.literal_eval(d.strip("\n"))[:5] for d in f.readlines()]

is_correct_orig2 = [
    get_is_correct(data_orig[i][1], d, data_orig[i][2])
    for i, d in enumerate(preds_orig2) if is_affirmative(data_orig[i])
]

print("Orig arrifmative accuracy: ", np.mean(is_correct_orig2))

with open(f"../predictions/NEG-1500-SIMP-TEMP/{model2}.txt") as f:
    preds_extended2 = [d.strip("\n") for d in f.readlines()]

is_correct_extended2 = [
    get_is_correct(data_extended[i][1], d, data_extended[i][2])
    for i, d in enumerate(preds_extended2) if is_affirmative(data_extended[i])
]

print("Extended arrifmative accuracy: ", np.mean(is_correct_extended2))
print("-" * 40)
# Model agreement

agrees_orig = [
    is_correct_orig[i] == is_correct_orig2[i]
    for i in range(len(is_correct_orig))
]
agrees_extended = [
    is_correct_extended[i] == is_correct_extended2[i]
    for i in range(len(is_correct_extended))
]

print("Orig agreement: ", np.mean(agrees_orig))
print("Extended agreement: ", np.mean(agrees_extended))

print("-" * 40)
print("Original dataset powers:")

# p_both_correct, p_only_1_correct, p_only_2_correct, p_both_incorrect
p_both_correct = np.mean([a and b for a, b in zip(is_correct_orig, is_correct_orig2)])
p_only_1_correct = np.mean([a and not b for a, b in zip(is_correct_orig, is_correct_orig2)])
p_only_2_correct = np.mean([not a and b for a, b in zip(is_correct_orig, is_correct_orig2)])
p_both_incorrect = np.mean([not a and not b for a, b in zip(is_correct_orig, is_correct_orig2)])

for p in [p_both_correct, p_only_1_correct, p_only_2_correct, p_both_incorrect]:
    assert p >= 0

prob_table = np.array([[p_both_incorrect, p_only_2_correct], [p_only_1_correct, p_both_correct]]) 
print("Probability table:")
print(prob_table)
print("acc1 = {:.3f}".format(prob_table[1, :].sum()))
print("acc2 = {:.3f}".format(prob_table[:, 1].sum()))

power, mean_effect, type_m, type_s = compute_power(prob_table, len(is_correct_orig))

print("Approx power = {:.3f}".format(power))
print("Approx Type-M error = {:.3f}".format(type_m))
print("Approx Type-S error = {:.3f}".format(type_s))

print("-" * 40)
print("Extended dataset powers:")

# p_both_correct, p_only_1_correct, p_only_2_correct, p_both_incorrect
p_both_correct = np.mean([a and b for a, b in zip(is_correct_extended, is_correct_extended2)])
p_only_1_correct = np.mean([a and not b for a, b in zip(is_correct_extended, is_correct_extended2)])
p_only_2_correct = np.mean([not a and b for a, b in zip(is_correct_extended, is_correct_extended2)])
p_both_incorrect = np.mean([not a and not b for a, b in zip(is_correct_extended, is_correct_extended2)])

for p in [p_both_correct, p_only_1_correct, p_only_2_correct, p_both_incorrect]:
    assert p >= 0

prob_table = np.array([[p_both_incorrect, p_only_2_correct], [p_only_1_correct, p_both_correct]])
print("Probability table:")
print(prob_table)
print("acc1 = {:.3f}".format(prob_table[1, :].sum()))
print("acc2 = {:.3f}".format(prob_table[:, 1].sum()))

power, mean_effect, type_m, type_s = compute_power(prob_table, len(is_correct_extended))
print("Approx power = {:.3f}".format(power))
print("Approx Type-M error = {:.3f}".format(type_m))
print("Approx Type-S error = {:.3f}".format(type_s))

Model:  bert-large-uncased
Orig arrifmative accuracy:  1.0
Extended arrifmative accuracy:  0.8
----------------------------------------
Model:  roberta-base
Orig arrifmative accuracy:  0.9444444444444444
Extended arrifmative accuracy:  0.7389610389610389
----------------------------------------
Orig agreement:  0.9444444444444444
Extended agreement:  0.8636363636363636
----------------------------------------
Original dataset powers:
Probability table:
[[0.         0.        ]
 [0.05555556 0.94444444]]
acc1 = 1.000
acc2 = 0.944
Dataset size:  36
Approx power = 0.013
Approx Type-M error = 3.152
Approx Type-S error = 0.000
----------------------------------------
Extended dataset powers:
Probability table:
[[0.16233766 0.03766234]
 [0.0987013  0.7012987 ]]
acc1 = 0.800
acc2 = 0.739
Dataset size:  770
Approx power = 0.995
Approx Type-M error = 1.002
Approx Type-S error = 0.000
