## Installs

In [None]:
!pip install captum

In [None]:
#!pip install datasets

## Imports

In [None]:
import pandas as pd

In [None]:
import torch
import torch.nn as nn
from torch.nn.functional import softmax

In [None]:
from captum.attr import Occlusion
from captum.attr import visualization as viz

In [None]:
from transformers import AutoTokenizer, BertForSequenceClassification
from transformers import BertTokenizer, BertModel

In [None]:
import re

In [None]:
import itertools

In [None]:
# Run on the CUDA device when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device('cpu')
if torch.cuda.is_available():
    device = torch.device('cuda')

In [None]:
import os
directory = "occlusion"
parent_dir = "results"
path = os.path.join(parent_dir, directory)
# Create results/occlusion up front so the to_csv calls further down do not
# fail on a fresh checkout; exist_ok makes re-running this cell harmless.
os.makedirs(path, exist_ok=True)

## Import Model

In [None]:
import sys
sys.path.append('model/code-bert/')
from temporal_relation_classification import TemporalRelationClassification
from temporal_relation_classification_config import TemporalRelationClassificationConfig

In [None]:
model_path = "saved_models/bert-base-uncased-saved-model"
# The tokenizer and the fine-tuned classifier are read from the same checkpoint directory.
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = TemporalRelationClassification.from_pretrained(model_path)

In [None]:
# Resize the model's token-embedding matrix to the tokenizer's length.
n_tokens = len(tokenizer)
model.resize_token_embeddings(n_tokens)

## Initialise Model

In [None]:
# Move the model's parameters to the selected device.
model.to(device=device)

In [None]:
# Inference mode (equivalent to model.eval()): disables dropout etc.
model.train(False)

## Import Dataset

In [None]:
# Load the counterfactual instances and the original instances.
counterfactual = pd.read_csv('data/annotated/counterfactuals.csv')
original = pd.read_csv('data/annotated/partitions.csv')

## Process Original

In [None]:
# Temporal-relation label -> integer class id, in list order.
label_mapping = {name: class_id for class_id, name in enumerate(['BEFORE', 'AFTER', 'EQUAL', 'VAGUE'])}

In [None]:
def tokenize(text):
    """Encode *text* with the notebook's ``tokenizer`` as PyTorch tensors, truncated to 508 tokens."""
    return tokenizer(text, return_tensors='pt', truncation=True, max_length=508)

def annotate_text(row, text_column='context'):
    """Return ``row[text_column]`` with both events wrapped in marker tokens.

    Every occurrence of ``row['eventA']`` becomes ``[a1]<eventA>[/a1]`` and every
    occurrence of ``row['eventB']`` becomes ``[a2]<eventB>[/a2]``. Event strings
    are matched literally (regex metacharacters are escaped).

    Both events are replaced in one ``re.sub`` pass whose alternation tries the
    longer event first. Replacing them one after the other lets the second
    pattern match again inside the markers just inserted whenever one event
    string contains the other, producing nested or broken markers such as
    ``[a2][a1]attack[/a1]ed[/a2]``.

    Args:
        row: mapping (e.g. a DataFrame row) with ``eventA``, ``eventB`` and
            *text_column* entries.
        text_column: name of the entry holding the text to annotate.

    Returns:
        The annotated text.
    """
    event_a, event_b = row['eventA'], row['eventB']
    wrappers = [(event_a, '[a1]', '[/a1]'), (event_b, '[a2]', '[/a2]')]
    # Longer event first; on equal length eventB is tried first, matching the
    # previous replacement order.
    if len(event_a) <= len(event_b):
        wrappers.reverse()
    pattern = '|'.join(f'({re.escape(event)})' for event, _, _ in wrappers)

    def wrap(match):
        # lastindex is the 1-based number of the alternative that matched
        # (re.escape never introduces groups of its own).
        event, open_tag, close_tag = wrappers[match.lastindex - 1]
        return f"{open_tag}{event}{close_tag}"

    return re.sub(pattern, wrap, row[text_column])

In [None]:
# Integer class id for each temporal relation (labels absent from label_mapping become NaN).
original['label'] = original.label_temp.map(label_mapping)

In [None]:
# Wrap eventA / eventB in marker tokens inside each row's context.
original['annotated_context'] = original.apply(annotate_text, axis='columns')

In [None]:
def tokenize_and_extract_ids(text):
    """Return the input ids of *text*, truncated to 508 tokens, as a plain list."""
    encoding = tokenizer(text, return_tensors="pt", truncation=True, max_length=508)
    return encoding['input_ids'][0].tolist()

# Ids of the four added event-marker tokens; presumably [a1], [/a1], [a2], [/a2]
# appended after the base vocabulary — confirm against the saved tokenizer.
required_ids = set(range(30522, 30526))
def check_required_ids(input_ids, required_ids):
    """Return True when every id in *required_ids* occurs somewhere in *input_ids*."""
    present_ids = frozenset(input_ids)
    return required_ids.issubset(present_ids)

In [None]:
# Token ids of every annotated context, one list per row.
original['input_ids'] = original['annotated_context'].map(tokenize_and_extract_ids)

In [None]:
def string_to_list(s):
    """Split a brace-wrapped, comma-separated string such as ``"{a, b}"`` into stripped items."""
    inner = s.strip('{}')
    return list(map(str.strip, inner.split(',')))

# Parse the serialised positive partition into a list of tokens per row.
original['pos_partition'] = original['pos_partition'].map(string_to_list)

In [None]:
def calculate_neg_partition(row):
    """Return the set of whitespace-separated tokens of row['context'] that are not in row['pos_partition']."""
    return set(row['context'].split()).difference(row['pos_partition'])

original['neg_partition'] = original.apply(calculate_neg_partition, axis=1)
# calculate_neg_partition returns a set; sort while converting to a list so the
# stored order is deterministic instead of depending on per-process string-hash
# randomisation.
original['neg_partition'] = original['neg_partition'].apply(sorted)

## Process Counterfactuals

In [None]:
# Same temporal-relation label -> class id mapping as for the original data.
label_mapping = dict(BEFORE=0, AFTER=1, EQUAL=2, VAGUE=3)

In [None]:
def tokenize(text):
    """Tokenise *text* with the global ``tokenizer`` (max 508 tokens, PyTorch tensors)."""
    settings = {'max_length': 508, 'truncation': True, 'return_tensors': 'pt'}
    return tokenizer(text, **settings)

def annotate_text(row, text_column='counterfactual'):
    """Return ``row[text_column]`` with both events wrapped in marker tokens.

    Every occurrence of ``row['eventA']`` becomes ``[a1]<eventA>[/a1]`` and every
    occurrence of ``row['eventB']`` becomes ``[a2]<eventB>[/a2]``. Event strings
    are matched literally (regex metacharacters are escaped).

    Both events are replaced in one ``re.sub`` pass whose alternation tries the
    longer event first. Replacing them one after the other lets the second
    pattern match again inside the markers just inserted whenever one event
    string contains the other, producing nested or broken markers.

    Args:
        row: mapping (e.g. a DataFrame row) with ``eventA``, ``eventB`` and
            *text_column* entries.
        text_column: name of the entry holding the text to annotate; defaults
            to the counterfactual text.

    Returns:
        The annotated text.
    """
    event_a, event_b = row['eventA'], row['eventB']
    wrappers = [(event_a, '[a1]', '[/a1]'), (event_b, '[a2]', '[/a2]')]
    # Longer event first; on equal length eventB is tried first, matching the
    # previous replacement order.
    if len(event_a) <= len(event_b):
        wrappers.reverse()
    pattern = '|'.join(f'({re.escape(event)})' for event, _, _ in wrappers)

    def wrap(match):
        # lastindex is the 1-based number of the alternative that matched
        # (re.escape never introduces groups of its own).
        event, open_tag, close_tag = wrappers[match.lastindex - 1]
        return f"{open_tag}{event}{close_tag}"

    return re.sub(pattern, wrap, row[text_column])

In [None]:
# Integer class id for each counterfactual temporal relation (unmapped labels become NaN).
counterfactual['label'] = counterfactual.new_temp.map(label_mapping)

In [None]:
# Wrap eventA / eventB in marker tokens inside each counterfactual text.
counterfactual['annotated_context'] = counterfactual.apply(annotate_text, axis='columns')

In [None]:
def tokenize_and_extract_ids(text):
    """Encode *text* (truncated to 508 tokens) and return its input ids as a list."""
    ids_tensor = tokenizer(text, truncation=True, max_length=508, return_tensors="pt")['input_ids']
    return ids_tensor[0].tolist()

# Token ids of every annotated counterfactual, one list per row.
counterfactual['input_ids'] = counterfactual['annotated_context'].map(tokenize_and_extract_ids)

In [None]:
def string_to_list(s):
    """Turn a ``"{a, b}"``-style string into a list of whitespace-stripped items."""
    pieces = s.strip('{}').split(',')
    return [piece.strip() for piece in pieces]

# Parse the serialised positive partition into a list of tokens per row.
counterfactual['pos_partition'] = counterfactual['pos_partition'].map(string_to_list)

## Occlusion Function

In [None]:
def occlusion_sensitivity(model, input_ids, attention_mask, tokenizer,
                          special_token_ids=(30522, 30523, 30524, 30525)):
    """Per-position occlusion sensitivity of the model's class probabilities.

    Each position of ``input_ids[0]`` is replaced in turn by the ``[MASK]`` id,
    and the absolute change of the softmax probabilities relative to the
    un-occluded input is recorded.

    Args:
        model: called as ``model(input_ids=..., attention_mask=...)``; element
            ``[0]`` of its output is used as the logits. Switched to eval mode.
        input_ids: id tensor; only row 0 is occluded (assumes batch size 1).
        attention_mask: passed to the model unchanged. Positions where it is 0
            (padding) are not occluded. May be None.
        tokenizer: used only to look up the ``[MASK]`` token id.
        special_token_ids: token ids that are never occluded (defaults to the
            four event-marker ids used elsewhere in this notebook).

    Returns:
        A list with exactly one entry per position of ``input_ids``, so it can be
        zipped with the tokens. Each entry is ``|p_original - p_occluded|`` as a
        nested list shaped like the logits. Position 0, special-token positions
        and padding positions are not occluded and get all-zero entries.
        (Previously skipped positions produced no entry at all, which shifted
        every later attribution onto the wrong token.)
    """
    model.eval()

    with torch.no_grad():
        original_logits = model(input_ids=input_ids, attention_mask=attention_mask)[0]
        original_probs = torch.softmax(original_logits, dim=-1)

    skip_ids = set(special_token_ids)
    mask_token_id = tokenizer.convert_tokens_to_ids('[MASK]')

    attention = []
    for index in range(input_ids.size(1)):
        token_id = input_ids[0, index].item()
        is_padding = attention_mask is not None and attention_mask[0, index].item() == 0
        if index == 0 or token_id in skip_ids or is_padding:
            # Keep the output aligned with the tokens without running the model.
            attention.append(torch.zeros_like(original_probs).cpu().tolist())
            continue

        occluded_input_ids = input_ids.clone()
        occluded_input_ids[0, index] = mask_token_id

        with torch.no_grad():
            occluded_logits = model(input_ids=occluded_input_ids, attention_mask=attention_mask)[0]
            occluded_probs = torch.softmax(occluded_logits, dim=-1)

        prob_change = torch.abs(original_probs - occluded_probs)
        attention.append(prob_change.cpu().tolist())

    return attention

In [None]:
import functools


@functools.lru_cache(maxsize=1)
def _load_marker_tokenizer():
    """Load bert-base-uncased once and add the four event-marker special tokens."""
    special_tokens_dict = {'additional_special_tokens': ['[a1]', '[/a1]', '[a2]', '[/a2]']}
    marker_tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
    marker_tokenizer.add_special_tokens(special_tokens_dict)
    return marker_tokenizer


def tokenisation(text, tokenizer, device=None, *, max_length=508, padding='max_length'):
    """Encode *text* with *tokenizer* and move the resulting tensors to *device*.

    The passed tokenizer is now actually used. Previously the argument was
    ignored and bert-base-uncased was re-downloaded and re-configured on every
    call. Pass ``tokenizer=None`` to get that bert-base-uncased + marker-token
    tokenizer (loaded once and cached).

    Args:
        text: string to encode.
        tokenizer: callable tokenizer, or None for the cached fallback.
        device: target device. None picks 'cuda' when available, otherwise 'cpu'
            (the old hard-coded 'cuda' default failed on CPU-only machines).
        max_length: truncation/padding length.
        padding: forwarded to the tokenizer (default pads to *max_length*).

    Returns:
        ``(input_ids, attention_mask, token_type_ids)``. The last two are None
        when the tokenizer does not produce them.
    """
    if tokenizer is None:
        tokenizer = _load_marker_tokenizer()
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'

    encoded_input = tokenizer(text, max_length=max_length, truncation=True,
                              padding=padding, return_tensors='pt')

    def _optional_on_device(key):
        tensor = encoded_input.get(key, None)
        return tensor.to(device) if tensor is not None else None

    input_ids = encoded_input['input_ids'].to(device)
    return input_ids, _optional_on_device('attention_mask'), _optional_on_device('token_type_ids')


## Calculate Occlusion Original

In [None]:
occlusion_details_original = []
for index, row in original_filtered.iterrows():
    print(index)
    text = row['annotated_context']
    input_ids, attention_mask, token_type_ids = tokenisation(text, tokenizer, device)
    attributions = occlusion_sensitivity(model, input_ids, attention_mask, tokenizer)
    flattened_attributions = list(itertools.chain.from_iterable(attributions))

    tokens = tokenizer.convert_ids_to_tokens(input_ids[0])
    word_attributions = list(zip(tokens, flattened_attributions))

    occlusion_details_original.append({
        "index": index,
        "text": text,
        "word_attributions": word_attributions
    })

In [None]:
occlusion_original_df = pd.DataFrame(occlusion_details_original)
print(occlusion_original_df)

In [None]:
original['index'] = original.index
occlusion_original_df['index'] =  occlusion_original_df.index
original = original.reset_index(drop=True)

In [None]:
occlusion_original_df['label'] = original['label']
print(occlusion_original_df)

In [None]:
# Persist the occlusion results for the original instances (no DataFrame index column).
occlusion_original_df.to_csv(path_or_buf='results/occlusion/occlusion-og-bert-base.csv', index=False)

## Calculate Occlusion Counterfactuals

In [None]:
occlusion_details_counterfactuals = []
for index, row in counterfactuals.iterrows():
    print(index)
    text = row['annotated_context']
    input_ids, attention_mask, token_type_ids = tokenisation(text, tokenizer, device)
    attributions = occlusion_sensitivity(model, input_ids, attention_mask, tokenizer)
    flattened_attributions = list(itertools.chain.from_iterable(attributions))

    tokens = tokenizer.convert_ids_to_tokens(input_ids[0])
    word_attributions = list(zip(tokens, flattened_attributions))

    occlusion_details_original.append({
        "index": index,
        "text": text,
        "word_attributions": word_attributions
    })

In [None]:
occlusion_counter_df = pd.DataFrame(occlusion_details_original)
print(occlusion_counter_df)

In [None]:
counterfactuals['index'] = df_counter_fintered.index
occlusion_counter_df['index'] =  occlusion_counter_df.index
counterfactuals = counterfactuals.reset_index(drop=True)

In [None]:
occlusion_counter_df['label'] = counterfactuals['label']
print(occlusion_counter_df)

In [None]:
# Persist the occlusion results for the counterfactual instances (no DataFrame index column).
occlusion_counter_df.to_csv(path_or_buf='results/occlusion/occlusion-cf-bert-base.csv', index=False)

In [None]:
import ast


def _parse_word_attributions(value):
    """Turn a CSV-serialised word_attributions string back into Python objects.

    Values that are not strings are returned unchanged. Straight after the
    occlusion loop the column already holds lists of (token, scores) tuples,
    and calling ``.replace`` on those raised AttributeError.
    """
    if not isinstance(value, str):
        return value
    # Re-quote bracketed tokens with double quotes before literal_eval
    # (the same textual rewrites as before).
    fixed = value.replace("('[", "(\"[").replace("]',", "]\",").replace(")']", ")]")
    return ast.literal_eval(fixed)


occlusion_counter_df['word_attributions'] = occlusion_counter_df['word_attributions'].apply(_parse_word_attributions)

## Visualise Instance

In [None]:
from IPython.display import HTML, display
import pandas as pd

# Counterfactual instance to inspect (row label 159 of occlusion_counter_df).
instance = occlusion_counter_df.loc[159, 'word_attributions']
tokens, raw_attributions = zip(*instance)
# Collapse each token's per-class probability changes into their mean.
attributions = [sum(per_class) / len(per_class) for per_class in raw_attributions]

def visualize_attributions(tokens, attributions, scale=1500):
    """Display *tokens* as inline HTML, shading each one by its attribution.

    Each attribution is divided by the largest absolute attribution and
    multiplied by *scale*. That intensity is subtracted from 255 and clamped to
    [0, 255] to give the colour channel. With the default scale of 1500, any
    token whose attribution is at least ~17% (255/1500) of the maximum is fully
    saturated, which keeps small occlusion effects visible. Positive
    attributions shade green and negative ones red.

    Fixed from the previous version:
    - no ZeroDivisionError when all attributions are zero or the list is empty;
    - the negative branch produced ``rgb(255,1500,1)``-style values, i.e. yellow
      instead of red;
    - token text is HTML-escaped.

    Args:
        tokens: token strings, one per attribution.
        attributions: numbers (anything ``float()`` accepts).
        scale: colour amplification factor (see above).
    """
    import html

    values = [float(a) for a in attributions]
    # Largest magnitude; 0.0 for an empty or all-zero list, in which case every
    # token is drawn unshaded.
    peak = max((abs(v) for v in values), default=0.0)

    html_string = "<p><b>Attributions:</b><br>"
    for token, attr in zip(tokens, values):
        intensity = int(scale * abs(attr) / peak) if peak else 0
        shade = max(0, 255 - intensity)
        color = f"rgb(255,{shade},{shade})" if attr < 0 else f"rgb({shade},255,{shade})"
        html_string += (
            f"<span style='background-color:{color}; padding: 0 2px;'>"
            f"{html.escape(str(token))}</span> "
        )

    html_string += "</p>"
    display(HTML(html_string))

# Render the selected instance.
visualize_attributions(tokens=tokens, attributions=attributions)