In [36]:
import os
from os import path
import json
import torch

from scipy.stats import spearmanr
import math
from collections import Counter

import glob

In [37]:
# Seed suffix used to select model directories, and the directory that holds
# one sub-folder per trained model.
seed = 100
base_dir = "/share/data/speech/shtoshni/research/state-probes/models"

In [38]:
# Find every multitask model directory under base_dir whose name ends with the chosen seed.
model_paths = glob.glob(path.join(base_dir, f"multitask_*seed_{seed}"))

# Reduce each path to its final component (ignoring trailing slashes) and
# order the names alphabetically.
model_names = sorted(
    path.basename(found.rstrip("/")) for found in model_paths
)

# Print the names as quoted, comma-terminated strings, one per line.
for model_name in model_names:
    print(f'"{model_name}", ')

"multitask_size_base_epochs_100_patience_10_seed_100", 
"multitask_size_base_epochs_100_patience_10_state_0.1_all_text_seed_100", 
"multitask_size_base_epochs_100_patience_10_state_0.1_random_text_seed_100", 
"multitask_size_base_epochs_100_patience_10_state_0.1_targeted_text_seed_100", 
"multitask_size_base_epochs_100_patience_10_state_0.2_all_text_seed_100", 
"multitask_size_base_epochs_100_patience_10_state_0.2_random_text_seed_100", 
"multitask_size_base_epochs_100_patience_10_state_0.2_targeted_text_seed_100", 
"multitask_size_base_epochs_100_patience_10_state_0.3_all_text_seed_100", 
"multitask_size_base_epochs_100_patience_10_state_0.3_random_text_seed_100", 
"multitask_size_base_epochs_100_patience_10_state_0.3_targeted_text_seed_100", 
"multitask_size_base_epochs_100_patience_10_state_0.4_all_text_seed_100", 
"multitask_size_base_epochs_100_patience_10_state_0.4_random_text_seed_100", 
"multitask_size_base_epochs_100_patience_10_state_0.4_targeted_text_seed_100", 
"multitask_s

In [39]:
# print(model_paths)

In [40]:
# For every model that has a dev log, tally how many outputs were marked
# correct and pair that total with the perplexity (exp of the best validation
# loss) stored in the model's best checkpoint.
total_corr_list = []
perplexity_list = []
for model_name in model_names:
    model_path = path.join(base_dir, model_name)
    log_file = path.join(model_path, 'dev.jsonl')
    if not path.exists(log_file):
        # Models without a dev log are skipped entirely.
        continue

    print(model_name)
    # map_location="cpu" lets the checkpoint be read on a machine that lacks
    # the device it was saved from; only the training metadata is used here.
    model = torch.load(path.join(model_path, "best", "model.pt"), map_location="cpu")
    best_val_loss = model["train_info"]["best_val_loss"]

    # The log is parsed as a single JSON document; the context manager makes
    # sure the file handle is closed instead of being leaked.
    with open(log_file) as log_handle:
        data = json.load(log_handle)

    # Each counter is [number marked correct, number seen].
    orig = [0, 0]  # outputs whose "same_as_init" flag is truthy
    chng = [0, 0]  # all remaining outputs
    for instance in data.values():
        for output in instance["output"]:
            bucket = orig if output["same_as_init"] else chng
            bucket[1] += 1
            if output["corr"]:
                bucket[0] += 1

    print(orig)
    print(chng)
    print(f"{best_val_loss:.3f}")

    total_corr = orig[0] + chng[0]
    perplexity = math.exp(best_val_loss)

    # Models with 'oracle' in their name are printed above but kept out of
    # the lists that feed the correlation.
    if 'oracle' not in model_name:
        total_corr_list.append(total_corr)
        perplexity_list.append(perplexity)

multitask_size_base_epochs_100_patience_10_state_0.1_all_text_seed_100
[4294, 4302]
[5, 2558]
1.083
multitask_size_base_epochs_100_patience_10_state_0.1_random_text_seed_100
[4049, 4302]
[1015, 2558]
1.070
multitask_size_base_epochs_100_patience_10_state_0.1_targeted_text_seed_100
[3880, 4302]
[115, 2558]
1.076
multitask_size_base_epochs_100_patience_10_state_0.2_all_text_seed_100
[4297, 4302]
[13, 2558]
1.085
multitask_size_base_epochs_100_patience_10_state_0.2_random_text_seed_100
[4095, 4302]
[1257, 2558]
1.074
multitask_size_base_epochs_100_patience_10_state_0.2_targeted_text_seed_100
[4030, 4302]
[837, 2558]
1.079
multitask_size_base_epochs_100_patience_10_state_0.3_all_text_seed_100
[4298, 4302]
[139, 2558]
1.084
multitask_size_base_epochs_100_patience_10_state_0.3_random_text_seed_100
[3724, 4302]
[1337, 2558]
1.078
multitask_size_base_epochs_100_patience_10_state_0.3_targeted_text_seed_100
[3864, 4302]
[1206, 2558]
1.066
multitask_size_base_epochs_100_patience_10_state_0.4_all_

In [41]:
# Show both series (one per line), then display the Spearman rank correlation
# between the correct-prediction totals and the perplexities.
print(total_corr_list, perplexity_list, sep="\n")
spearmanr(total_corr_list, perplexity_list)

[4299, 5064, 3995, 4310, 5352, 4867, 4437, 5061, 5070, 4712, 5337, 5275, 4854, 5705, 4838, 4647, 5742, 5369, 5660, 5449, 4830, 5820, 5415, 5506, 5760, 5087, 5926, 5857, 5064]
[2.952305592428037, 2.9148687810569305, 2.9341117577149083, 2.9587728258159367, 2.9278556493913936, 2.9408890704681667, 2.9579420984611566, 2.937825058155326, 2.903395200784267, 2.9628267951937586, 2.9340183615045063, 2.928060342579809, 2.9600942208294607, 2.896477867115539, 2.928828532393884, 2.9629714061367856, 2.917062703279541, 2.91736969827477, 2.9115276445268004, 2.900587496407705, 2.9749525936088457, 2.8996187376632387, 2.910548202184257, 2.9595443601241707, 2.9664389738700767, 2.9368290969803215, 2.9619736935424426, 2.887121319832514, 2.931524005221447]


SpearmanrResult(correlation=-0.48084739864669007, pvalue=0.008279017247201098)

In [48]:
# Pair each model's perplexity (exp of the best validation loss) with the
# cloze MRR value stored next to the checkpoint.
perplexity_list = []
mrr_list = []

# Model directory names: the run without a `state_*` part first, then every
# `state_0.1` .. `state_1.0` value combined with the all/random/targeted
# `*_text` variants, in that order. Generated rather than hand-typed so the
# 31 names cannot drift apart through copy-paste edits.
run_prefix = "multitask_size_base_epochs_100_patience_10"
model_names = [f"{run_prefix}_seed_100"] + [
    f"{run_prefix}_state_{tenths / 10:.1f}_{text_variant}_text_seed_100"
    for tenths in range(1, 11)
    for text_variant in ("all", "random", "targeted")
]
assert len(model_names) == 31

for model_name in model_names:
    model_path = path.join(base_dir, model_name)

    # map_location="cpu" lets the checkpoint be read on a machine that lacks
    # the device it was saved from; only the training metadata is used here.
    model = torch.load(path.join(model_path, "best", "model.pt"), map_location="cpu")
    best_val_loss = model["train_info"]["best_val_loss"]

    # The file is expected to contain a single number; read it inside a
    # context manager so the handle is closed rather than leaked.
    cloze_mrr_file = path.join(model_path, 'cloze_mrr.txt')
    with open(cloze_mrr_file) as mrr_handle:
        cloze_mrr = float(mrr_handle.read())
    perplexity = math.exp(best_val_loss)

    perplexity_list.append(perplexity)
    mrr_list.append(cloze_mrr)

print(mrr_list)
print(perplexity_list)

print(len(perplexity_list))
print(max(mrr_list))

[0.146, 0.134, 0.123, 0.132, 0.138, 0.135, 0.145, 0.135, 0.172, 0.165, 0.134, 0.166, 0.146, 0.128, 0.184, 0.134, 0.14, 0.158, 0.129, 0.136, 0.17, 0.138, 0.128, 0.191, 0.15, 0.176, 0.158, 0.123, 0.184, 0.188, 0.16]
[2.9540370418473385, 2.952305592428037, 2.9148687810569305, 2.9341117577149083, 2.9587728258159367, 2.9278556493913936, 2.9408890704681667, 2.9579420984611566, 2.937825058155326, 2.903395200784267, 2.9628267951937586, 2.9340183615045063, 2.928060342579809, 2.9600942208294607, 2.896477867115539, 2.928828532393884, 2.9629714061367856, 2.917062703279541, 2.91736969827477, 2.956082782074605, 2.9115276445268004, 2.900587496407705, 2.9749525936088457, 2.8996187376632387, 2.910548202184257, 2.9595443601241707, 2.9664389738700767, 2.9368290969803215, 2.9619736935424426, 2.887121319832514, 2.931524005221447]
31
0.191


In [49]:
# Report, one per line: the Spearman rank correlation between perplexity and
# cloze MRR, the lowest perplexity, and the highest MRR.
print(
    spearmanr(perplexity_list, mrr_list),
    min(perplexity_list),
    max(mrr_list),
    sep="\n",
)


SpearmanrResult(correlation=-0.31708567322498643, pvalue=0.08219697379057851)
2.887121319832514
0.191


In [50]:
# Print the collected perplexity values in ascending order.
print(sorted(perplexity_list))

[2.887121319832514, 2.896477867115539, 2.8996187376632387, 2.900587496407705, 2.903395200784267, 2.910548202184257, 2.9115276445268004, 2.9148687810569305, 2.917062703279541, 2.91736969827477, 2.9278556493913936, 2.928060342579809, 2.928828532393884, 2.931524005221447, 2.9340183615045063, 2.9341117577149083, 2.9368290969803215, 2.937825058155326, 2.9408890704681667, 2.952305592428037, 2.9540370418473385, 2.956082782074605, 2.9579420984611566, 2.9587728258159367, 2.9595443601241707, 2.9600942208294607, 2.9619736935424426, 2.9628267951937586, 2.9629714061367856, 2.9664389738700767, 2.9749525936088457]


In [51]:
# For every model that has a dev log, collect three aligned values: the number
# of outputs marked correct, the cloze MRR, and the perplexity (exp of the
# best validation loss). All three lists are appended together so their
# indices always refer to the same model.
mrr_list = []
total_corr_list = []
perplexity_list = []

# Model directory names: the run without a `state_*` part first, then every
# `state_0.1` .. `state_1.0` value combined with the all/random/targeted
# `*_text` variants, in that order. Generated rather than hand-typed so the
# 31 names cannot drift apart through copy-paste edits.
run_prefix = "multitask_size_base_epochs_100_patience_10"
model_names = [f"{run_prefix}_seed_100"] + [
    f"{run_prefix}_state_{tenths / 10:.1f}_{text_variant}_text_seed_100"
    for tenths in range(1, 11)
    for text_variant in ("all", "random", "targeted")
]
assert len(model_names) == 31

for model_name in model_names:
    model_path = path.join(base_dir, model_name)
    log_file = path.join(model_path, 'dev.jsonl')
    if not path.exists(log_file):
        # Skip before touching the checkpoint: models without a dev log
        # contribute nothing, so loading their weights would be wasted work.
        continue

    # map_location="cpu" lets the checkpoint be read on a machine that lacks
    # the device it was saved from; only the training metadata is used here.
    model = torch.load(path.join(model_path, "best", "model.pt"), map_location="cpu")
    best_val_loss = model["train_info"]["best_val_loss"]

    # The log is parsed as a single JSON document; the context manager makes
    # sure the file handle is closed instead of being leaked.
    with open(log_file) as log_handle:
        data = json.load(log_handle)

    # Each counter is [number marked correct, number seen].
    orig = [0, 0]  # outputs whose "same_as_init" flag is truthy
    chng = [0, 0]  # all remaining outputs
    for instance in data.values():
        for output in instance["output"]:
            bucket = orig if output["same_as_init"] else chng
            bucket[1] += 1
            if output["corr"]:
                bucket[0] += 1

    total_corr = orig[0] + chng[0]
    total_corr_list.append(total_corr)

    # The file is expected to contain a single number.
    cloze_mrr_file = path.join(model_path, 'cloze_mrr.txt')
    with open(cloze_mrr_file) as mrr_handle:
        cloze_mrr = float(mrr_handle.read())
    perplexity = math.exp(best_val_loss)

    # Store the perplexity itself; previously the raw validation loss was
    # appended here, leaving `perplexity` unused and the list mislabelled.
    perplexity_list.append(perplexity)
    mrr_list.append(cloze_mrr)

print(total_corr_list)
print(mrr_list)
print(sorted(total_corr_list))

[4299, 5064, 3995, 4310, 5352, 4867, 4437, 5061, 5070, 4712, 5337, 5275, 4854, 5705, 4838, 4647, 5742, 5369, 5660, 5449, 4830, 5820, 5415, 5506, 5760, 5087, 5926, 5857, 5064]
[0.134, 0.123, 0.132, 0.138, 0.135, 0.145, 0.135, 0.172, 0.165, 0.134, 0.166, 0.146, 0.128, 0.184, 0.134, 0.14, 0.158, 0.129, 0.17, 0.138, 0.128, 0.191, 0.15, 0.176, 0.158, 0.123, 0.184, 0.188, 0.16]
[3995, 4299, 4310, 4437, 4647, 4712, 4830, 4838, 4854, 4867, 5061, 5064, 5064, 5070, 5087, 5275, 5337, 5352, 5369, 5415, 5449, 5506, 5660, 5705, 5742, 5760, 5820, 5857, 5926]


In [45]:
# Spearman rank correlation between the per-model correct-prediction totals
# and the cloze MRR values; both lists must be index-aligned.
# NOTE(review): this cell's execution count (45) is lower than that of the
# cell above it (51) — re-run it after that cell so both lists are current.
spearmanr(total_corr_list, mrr_list)

SpearmanrResult(correlation=0.653225214268353, pvalue=0.0001221771248509123)

## Egregious Errors

In [46]:
# # model_names = [
# # #     'size_base_epochs_100_patience_10_state_0.1_targeted_text_seed_60',
# #     'size_base_epochs_100_patience_10_state_0.5_targeted_text_seed_70',

# #     'size_base_epochs_100_patience_10_state_0.1_all_text_seed_70',
# #     'size_base_epochs_100_patience_10_state_0.25_all_text_seed_70',
# #     'size_base_epochs_100_patience_10_state_0.5_all_text_seed_70',
# #     'size_base_epochs_100_patience_10_state_0.75_all_text_seed_70',

# # #     'size_base_epochs_100_patience_10_state_0.1_random_text_seed_60',
# # #     'size_base_epochs_100_patience_10_state_0.25_random_text_seed_60',
# # #     'size_base_epochs_100_patience_10_state_0.5_random_text_seed_60',
# # #     'size_base_epochs_100_patience_10_state_0.75_random_text_seed_60',
# # ]


# for model_name in model_names:
#     model_path = path.join(base_dir, model_name)
#     log_file = path.join(model_path, 'dev.jsonl')
#     if path.exists(log_file):        
# #         print(model_name)
#         model = torch.load(path.join(path.join(model_path, "best"), "model.pt"))
#         best_val_loss = model["train_info"]["best_val_loss"]
        
#         data = json.load(open(log_file))
        
#         error_count_dict = Counter()
#         for key in data:
#             instance = data[key]
#             error_count = 0
#             for output in instance["output"]:
#                 if not output["corr"]:
#                     error_count += 1
            
# #             if error_count == 6:
# #                 print(instance)
#             error_count_dict[error_count] += 1
            
# #         print(error_count_dict)
# #         print()
