# Initial setup: data, model, and dosage mispredictions

In [None]:
from ner_influence.ebm import load_datamodule
from ner_influence.ebm.experiment_utils import dosages 
# Transformer checkpoint name passed to the datamodule loader.
TRANSFORMER_NAME = "google/bigbird-roberta-base"
# Batch size for the datamodule. It is set through the private `_batch_size`
# attribute; NOTE(review): confirm there is no public setter for this.
BATCH_SIZE = 3

data = load_datamodule(transformer=TRANSFORMER_NAME)
data._batch_size = BATCH_SIZE

In [None]:
def conll_key(x):
    """Return the document id of a sentence instance.

    Assumes sentence ids of the form ``"<doc_id>_<position>"``: everything before
    the last underscore is the grouping key.
    """
    return x.id.rsplit("_", 1)[0]


def conll_order(x):
    """Return the integer after the last underscore of ``x.id`` (sentence position)."""
    return int(x.id.rsplit("_", 1)[1])

# For every split, merge sentence instances into documents (grouped by conll_key,
# ordered by conll_order). Re-tokenize the merged documents and store them
# under "<split>_docs".
for split_name in ("train", "validation", "test"):
    grouped = data.combine_to_docs(data[split_name], key=conll_key, order=conll_order)
    doc_list = list(grouped.values())
    data[f"{split_name}_docs"] = data.apply_transform(doc_list, transform=lambda x: x, retokenize=True)

In [None]:
# Checkpoint directory. The last component presumably records the run config
# (CRF disabled, seed 2021).
model_path = "outputs/ebm_docs/simple_trainer/" + ";".join(["crf:False", "seed:2021"])

In [None]:
from ner_influence.modelling.trainer import evaluate_ner_model
# Token-level evaluation of the saved checkpoint on the document-level validation
# split built above; the returned value is shown as the cell output.
evaluate_ner_model(data, model_path, "validation_docs", metrics="token")

In [None]:
from ner_influence.modelling.scaffolding import NERTransformerScaffolding

# Wrap the trained checkpoint together with the datamodule (outputs are saved).
scaffolding = NERTransformerScaffolding(data, model_path, save_outputs=True)

In [None]:
# Number of trainable parameters, in millions.
trainable_params = (param for param in scaffolding.model.parameters() if param.requires_grad)
sum(param.numel() for param in trainable_params) / 1_000_000

In [None]:
# Model outputs (with feature vectors) for the test and train documents, in that
# order; materialized as lists so later cells can iterate them repeatedly.
test_outputs, train_outputs = (
    list(scaffolding.generate_outputs(split_name, with_feature_vectors=True))
    for split_name in ("test_docs", "train_docs")
)

In [None]:
from tqdm import tqdm
w = 10  # context window: tokens on each side of a dosage searched for a gold INT label

# Training-set tally. For every dosage span with a gold INT token within `w`
# tokens, count the pair (span has a gold INT token?, span has a predicted INT token?).
info = {}
int_index = data._label_list.index("INT")
for ex in train_outputs:
    gold_labels = [1 if x == int_index else 0 for x in ex["gold_labels"]]
    pred_labels = [1 if x == int_index else 0 for x in ex["predicted_labels"]]
    for start, end in dosages(ex["tokens"]):
        w_start, w_end = max(0, start - w), min(len(gold_labels), end + w)
        if not any(gold_labels[w_start:w_end]):
            continue  # no gold intervention nearby
        outcome = (any(gold_labels[start:end]), any(pred_labels[start:end]))
        info[outcome] = info.get(outcome, 0) + 1

info

In [None]:
def first_dosage_mispredictions(w=10, outputs=None, label_index=None):
    """Yield dosage spans that the model tags as INT although gold does not.

    A dosage span (a ``(start, end)`` pair from ``dosages``) is only considered when
    a gold INT token occurs within ``w`` tokens of it. It is yielded when no gold
    token inside the span is INT but at least one predicted token is.

    Args:
        w: Number of context tokens on each side of the span searched for gold INT.
        outputs: Iterable of output dicts with "tokens", "gold_labels" and
            "predicted_labels"; defaults to the global ``test_outputs``.
        label_index: Label id treated as INT; defaults to the global ``int_index``.

    Yields:
        ``(example, start, end)``: the output dict and the span bounds, so that
        ``example["tokens"][start:end]`` is the mispredicted dosage.
    """
    if outputs is None:
        outputs = test_outputs
    if label_index is None:
        label_index = int_index
    for ex in outputs:
        tokens = ex["tokens"]
        gold_labels = [1 if x == label_index else 0 for x in ex["gold_labels"]]
        pred_labels = [1 if x == label_index else 0 for x in ex["predicted_labels"]]
        for start, end in dosages(tokens):
            w_start, w_end = max(0, start - w), min(len(tokens), end + w)
            if not any(gold_labels[w_start:w_end]):
                continue  # no true intervention nearby, so likely not an intervention dosage
            is_gold_int = any(gold_labels[start:end])
            is_pred_int = any(pred_labels[start:end])
            if is_pred_int and not is_gold_int:
                yield ex, start, end


# Number of mispredicted dosage spans; used as the denominator in later cells.
N = sum(1 for _ in first_dosage_mispredictions())
N

In [None]:
def has_dose(sent, w=10) -> tuple[list[bool], list[tuple[int, int]]]:
    """Collect dosage spans that lie near a gold intervention in ``sent``.

    Only dosage spans (from ``dosages``) with a gold INT token within ``w`` tokens
    on either side are kept.

    Args:
        sent: Instance exposing ``tokens`` and ``labels``. Labels are compared with
            the string ``"INT"``; assumes ``labels`` holds label strings (TODO confirm).
        w: Context window in tokens on each side of the span.

    Returns:
        ``(flags, spans)``: parallel lists. ``spans[i]`` is a ``(start, end)`` dosage
        span and ``flags[i]`` is True when any gold token inside it is INT.
    """
    flags = []
    spans = []
    tokens = sent.tokens
    gold_labels = [1 if x == "INT" else 0 for x in sent.labels]
    for start, end in dosages(tokens):
        w_start, w_end = max(0, start - w), min(len(tokens), end + w)
        if any(gold_labels[w_start:w_end]):
            flags.append(any(gold_labels[start:end]))
            spans.append((start, end))
    return flags, spans

# Map from training-document id to the document, used to resolve neighbor ids.
train_dict = dict((doc.id, doc) for doc in data["train_docs"])

# Instance Attribution

In [None]:
from ner_influence.instance_influence_indexing import InstanceIndexer
# Instance-level index over the training documents (queried by example id in the
# next cell). Test-document influence vectors are generated w.r.t. gold labels.
indexer = InstanceIndexer(scaffolding, normalize=True)
indexer.create_index("train_docs")
indexer.generate_influence_vectors("test_docs", label_set="gold")

In [None]:
from tqdm import tqdm

# Materialize once so the neighbor queries and the loop below walk exactly the
# same sequence, and the generator is not recomputed.
mispredictions = list(first_dosage_mispredictions())
neighbors = indexer.batched_search((ex["id"] for ex, _, _ in mispredictions), k=3, batch_size=50)

# Per misprediction, inspect the top supporting and top opposing training docs:
#   has_supp_dose       - top supporter has a near-intervention dosage NOT gold INT
#   has_opp_dose        - top opposer has a near-intervention dosage that IS gold INT
#   shows_inconsistency - both of the above for the same misprediction
has_supp_dose, has_opp_dose = 0, 0
shows_inconsistency = 0
for _misprediction in tqdm(mispredictions):
    supps, opps = next(neighbors)
    supp_flags, _ = has_dose(train_dict[supps[0][0]])
    opp_flags, _ = has_dose(train_dict[opps[0][0]])

    supp_condition = False in supp_flags
    opp_condition = True in opp_flags
    if supp_condition:
        has_supp_dose += 1
    if opp_condition:
        has_opp_dose += 1
    if supp_condition and opp_condition:
        shows_inconsistency += 1

total = len(mispredictions)
print(has_supp_dose, has_opp_dose, shows_inconsistency)
print(has_supp_dose / total, has_opp_dose / total, shows_inconsistency / total)

# Entity Attribution

In [None]:
from ner_influence.np_entity_influence_indexing import NumpyEntityIndexer
# Entity/token-level index over the training documents (queried with
# (example id, token index) pairs in the next cell). Test-document influence
# vectors are generated w.r.t. gold labels.
indexer = NumpyEntityIndexer(scaffolding, normalize=True)
indexer.create_index("train_docs")
indexer.generate_influence_vectors("test_docs", label_set="gold")

In [None]:
# Materialize once so the queries and the loop below stay aligned.
mispredictions = list(first_dosage_mispredictions())
# Each query is the last token of a mispredicted dosage span.
neighbors = indexer.batched_search(((ex["id"], end - 1) for ex, _, end in mispredictions), k=3, batch_size=50)

# Instance-level counters: does the top supporting (opposing) sentence contain any
# near-intervention dosage that is not (is) gold INT?
has_supp_dose, has_opp_dose = 0, 0
shows_inconsistency = 0

# Token-level counters: the same checks restricted to the dosage span that
# contains the retrieved neighbor token.
top_token_has_supp = 0
top_token_has_opp = 0
token_shows_inconsistency = 0

for _misprediction in tqdm(mispredictions):
    supps, opps = next(neighbors)
    # assumes each hit is (train sentence id, token index, ...) - TODO confirm
    top_supp, top_supp_token = supps[0][0], supps[0][1]
    top_opp, top_opp_token = opps[0][0], opps[0][1]

    supp_flags, supp_spans = has_dose(train_dict[top_supp])
    opp_flags, opp_spans = has_dose(train_dict[top_opp])

    # Indices of the dosage spans covering the retrieved token; the token-level
    # checks below require exactly one covering span.
    supp_hits = [i for i, (s, e) in enumerate(supp_spans) if s <= top_supp_token < e]
    opp_hits = [i for i, (s, e) in enumerate(opp_spans) if s <= top_opp_token < e]

    instance_supp_condition = False in supp_flags
    instance_opp_condition = True in opp_flags
    token_supp_condition = len(supp_hits) == 1 and not supp_flags[supp_hits[0]]
    token_opp_condition = len(opp_hits) == 1 and opp_flags[opp_hits[0]]

    if token_supp_condition:
        top_token_has_supp += 1
    if token_opp_condition:
        top_token_has_opp += 1
    if instance_supp_condition:
        has_supp_dose += 1
    if instance_opp_condition:
        has_opp_dose += 1
    if instance_supp_condition and instance_opp_condition:
        shows_inconsistency += 1
    if token_supp_condition and token_opp_condition:
        token_shows_inconsistency += 1

total = len(mispredictions)
print(has_supp_dose, has_opp_dose, shows_inconsistency, top_token_has_supp, top_token_has_opp, token_shows_inconsistency)
print(has_supp_dose / total, has_opp_dose / total, shows_inconsistency / total, top_token_has_supp / total, top_token_has_opp / total, token_shows_inconsistency / total)

## Nearest Neighbor Attribution

In [None]:
from ner_influence.nearest_neighbor_indexing import NNIndexer
# Nearest-neighbor index over the training documents (queried with
# (example id, token index) pairs in the next cell). No label_set argument is
# passed when generating the test-document vectors.
indexer = NNIndexer(scaffolding, normalize=True)
indexer.create_index("train_docs")
indexer.generate_influence_vectors("test_docs")

In [None]:
# Materialize once so the queries and the loop below stay aligned.
mispredictions = list(first_dosage_mispredictions())
# Each query is the last token of a mispredicted dosage span.
neighbors = indexer.batched_search(((ex["id"], end - 1) for ex, _, end in mispredictions), k=3, batch_size=50)

# Instance-level counters (whole neighbor sentence) and token-level counters
# (only the dosage span containing the neighbor token).
has_supp_dose, has_opp_dose = 0, 0
shows_inconsistency = 0

top_token_has_supp = 0
top_token_has_opp = 0
token_shows_inconsistency = 0

for _misprediction in tqdm(mispredictions):
    all_n = next(neighbors)
    # First hit of each result group. Assumes hits look like
    # (sentence id, token index, ranking value, ...) - TODO confirm.
    top_n = [x[0] for x in all_n]

    # The first hit is treated as "supporting". Among the rest, the hit with the
    # largest third field is treated as "opposing" (computed once).
    top_supp, top_supp_token = top_n[0][0], top_n[0][1]
    opp_hit = max(top_n[1:], key=lambda x: x[2])
    top_opp, top_opp_token = opp_hit[0], opp_hit[1]

    supp_flags, supp_spans = has_dose(train_dict[top_supp])
    opp_flags, opp_spans = has_dose(train_dict[top_opp])

    # Indices of the dosage spans covering the retrieved token; the token-level
    # checks below require exactly one covering span.
    supp_hits = [i for i, (s, e) in enumerate(supp_spans) if s <= top_supp_token < e]
    opp_hits = [i for i, (s, e) in enumerate(opp_spans) if s <= top_opp_token < e]

    instance_supp_condition = False in supp_flags
    instance_opp_condition = True in opp_flags
    token_supp_condition = len(supp_hits) == 1 and not supp_flags[supp_hits[0]]
    token_opp_condition = len(opp_hits) == 1 and opp_flags[opp_hits[0]]

    if token_supp_condition:
        top_token_has_supp += 1
    if token_opp_condition:
        top_token_has_opp += 1
    if instance_supp_condition:
        has_supp_dose += 1
    if instance_opp_condition:
        has_opp_dose += 1
    if instance_supp_condition and instance_opp_condition:
        shows_inconsistency += 1
    if token_supp_condition and token_opp_condition:
        token_shows_inconsistency += 1

total = len(mispredictions)
print(has_supp_dose, has_opp_dose, shows_inconsistency, top_token_has_supp, top_token_has_opp, token_shows_inconsistency)
print(has_supp_dose / total, has_opp_dose / total, shows_inconsistency / total, top_token_has_supp / total, top_token_has_opp / total, token_shows_inconsistency / total)