## Installs

In [None]:
!pip install captum

In [None]:
#!pip install datasets

## Imports

In [None]:
import pandas as pd

In [None]:
from transformers import AutoTokenizer, BertForSequenceClassification
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler
import torch
from captum.attr import LayerIntegratedGradients
from captum.attr import visualization as viz

In [None]:
import torch
import torch.nn as nn

In [None]:
import re

In [None]:
from datasets import load_dataset, DatasetDict, Dataset

In [None]:
import os 
# Output folder for the attribution CSVs written later in this notebook.
directory = "gradients"
parent_dir = "../results"
path = os.path.join(parent_dir, directory)
# Create it up front so the `to_csv` calls into ../results/gradients/ do not
# fail when the folder does not exist yet.
os.makedirs(path, exist_ok=True)

## Import Model

In [None]:
import sys
# Make the project's model package importable for the two imports below.
sys.path.extend(['../model/code-bert/'])
from temporal_relation_classification import TemporalRelationClassification
from temporal_relation_classification_config import TemporalRelationClassificationConfig

In [None]:
# Tokenizer and classifier are both restored from the same saved checkpoint directory.
model_path = "../saved_models/bert-base-uncased-saved-model"
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = TemporalRelationClassification.from_pretrained(model_path)

In [None]:
# Size the model's token-embedding matrix to len(tokenizer); presumably this
# covers the added [a1]/[/a1]/[a2]/[/a2] marker tokens — confirm.
model.resize_token_embeddings(len(tokenizer))

## Initialise Model (Device Placement and Evaluation Mode)

In [None]:
# Prefer the GPU when one is visible to PyTorch, otherwise run on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
# Move the model's parameters to the device selected above.
model.to(device)

In [None]:
# Put the model in evaluation mode before computing attributions.
model.eval()

## Load Datasets (Original and Counterfactual)

In [None]:
# Original annotated instances and their counterfactual rewrites.
original, counterfactuals = (
    pd.read_csv('../data/annotated/partitions.csv'),
    pd.read_csv('../data/annotated/counterfactuals.csv'),
)

## Compute Gradients: Original Instances

In [None]:
import re
def annotate_text(row, column):
    """Wrap both event mentions in ``row[column]`` with marker tags.

    Every occurrence of ``row['eventA']`` becomes ``[a1]<eventA>[/a1]`` and
    every occurrence of ``row['eventB']`` becomes ``[a2]<eventB>[/a2]``.

    Args:
        row: record (e.g. a pandas row) with ``'eventA'``, ``'eventB'`` and
            ``column`` entries.
        column: name of the field holding the text to annotate.

    Returns:
        The annotated text.
    """
    context = row[column]
    event_a = row['eventA']
    event_b = row['eventB']

    substitutions = [
        (event_a, f"[a1]{event_a}[/a1]"),
        (event_b, f"[a2]{event_b}[/a2]"),
    ]
    # Substitute the longer (escaped) event first; on a tie eventB goes first.
    if len(re.escape(event_a)) <= len(re.escape(event_b)):
        substitutions.reverse()

    for event, tagged in substitutions:
        # A callable replacement inserts the tagged text literally, so
        # backslashes in the event text are not parsed as regex escapes.
        context = re.sub(re.escape(event), lambda _match: tagged, context)

    return context

In [None]:
# Annotate each original instance, reading the text from its 'context' column.
original['annotated_context'] = original.apply(annotate_text, axis=1, args=('context',))

In [None]:
from datasets import Dataset, DatasetDict
# Hugging Face view of the original instances, under a single "original" split.
original_split = Dataset.from_pandas(original)
dataset = DatasetDict({"original": original_split})

In [None]:
from transformers import DataCollatorWithPadding
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
# Bare class name of the tokenizer (e.g. "BertTokenizerFast"). ``__name__``
# gives it directly instead of string-parsing ``str(type(tokenizer))``.
# NOTE(review): neither name is used by the gradient cells below.
tokenizer_class = type(tokenizer).__name__

In [None]:
def preprocess_function(examples):
    """Tokenise a batch of annotated contexts, truncating each to 508 tokens."""
    return tokenizer(
        examples["annotated_context"],
        max_length=508,
        truncation=True,
    )


tokenized_datasets = dataset.map(preprocess_function, batched=True)

In [None]:
def predict(inputs, attention_mask=None, token_type_ids=None):
    """Captum forward function: return each example's largest logit.

    ``lig.attribute`` in this notebook is called with
    ``inputs=(input_ids, attention_mask)``, so Captum passes the attention
    mask as the second positional argument. It must be forwarded as
    ``attention_mask``. Forwarding it as ``token_type_ids`` fed the 0/1
    padding mask to the model as segment ids.

    Args:
        inputs: token-id tensor for the batch.
        attention_mask: optional mask aligned with ``inputs``.
        token_type_ids: optional segment ids, accepted by keyword.

    Returns:
        Tensor of the maxima of ``output[0]`` over dim 1, one per example.
    """
    # NOTE(review): assumes TemporalRelationClassification.forward accepts an
    # ``attention_mask`` keyword like standard BERT heads — confirm.
    output = model(inputs, attention_mask=attention_mask, token_type_ids=token_type_ids)
    logits = output[0]
    return logits.max(dim=1).values

In [None]:
def create_baseline(input_ids, attention_mask, special_token_ids=(30522, 30523, 30524, 30525)):
    """Build the integrated-gradients reference input for ``input_ids``.

    Every position is set to token id 0, except positions holding one of
    ``special_token_ids``, which keep their original id. The baseline
    attention mask is all ones.

    Args:
        input_ids: token-id tensor.
        attention_mask: mask tensor shaped like ``input_ids``; if ``None``,
            the ones-mask is shaped like ``input_ids`` instead.
        special_token_ids: ids copied unchanged into the baseline. The
            default presumably corresponds to the [a1]/[/a1]/[a2]/[/a2]
            markers appended after bert-base-uncased's vocabulary — verify
            against the tokenizer.

    Returns:
        ``(baseline_input_ids, baseline_attention_mask)`` on the same
        device as the inputs.
    """
    mask_template = input_ids if attention_mask is None else attention_mask
    baseline_attention_mask = torch.ones_like(mask_template)
    baseline_input_ids = torch.zeros_like(input_ids)

    # Build the lookup tensor on input_ids' own device/dtype so torch.isin
    # never compares tensors living on different devices.
    keep_ids = torch.as_tensor(list(special_token_ids), dtype=input_ids.dtype, device=input_ids.device)
    keep = torch.isin(input_ids, keep_ids)
    baseline_input_ids[keep] = input_ids[keep]
    return baseline_input_ids, baseline_attention_mask

In [None]:
# Attribute `predict`'s output to the outputs of the BERT embedding layer.
lig = LayerIntegratedGradients(forward_func=predict, layer=model.bert.embeddings)

In [None]:
def gradient_sensitivity(model, input_ids, attention_mask, n_steps=10):
    """Return per-token Layer Integrated Gradients scores, L2-normalised.

    Uses the module-level ``lig`` and ``create_baseline`` defined in this
    notebook.

    Args:
        model: put into eval mode before attributing. The attribution itself
            runs through ``lig``, which wraps the notebook's ``predict``.
        input_ids: token ids for a single example. ``squeeze(0)`` below
            assumes a batch of one.
        attention_mask: mask aligned with ``input_ids``.
        n_steps: number of integration steps passed to ``lig.attribute``.

    Returns:
        1-D tensor with one score per token position, scaled to unit L2
        norm. It is left unscaled when every score is zero, instead of
        producing 0/0 = NaN.
    """
    model.eval()
    baseline_input_ids, baseline_attention_mask = create_baseline(input_ids, attention_mask)

    attributions = lig.attribute(
        inputs=(input_ids, attention_mask),
        baselines=(baseline_input_ids, baseline_attention_mask),
        n_steps=n_steps,
    )

    # Collapse the last (embedding) dimension into one score per token.
    attributions = attributions.sum(dim=-1).squeeze(0)
    norm = torch.norm(attributions)
    if norm > 0:
        attributions = attributions / norm
    return attributions

In [None]:
def tokenisation(text, tokenizer, device='cuda'):
    """Encode ``text`` with ``tokenizer`` as a padded batch of one on ``device``.

    The text is truncated/padded to exactly 508 tokens. The event markers
    ``[a1]``, ``[/a1]``, ``[a2]``, ``[/a2]`` are registered as additional
    special tokens if ``tokenizer`` does not already contain them.

    Args:
        text: annotated context string.
        tokenizer: Hugging Face tokenizer. The same object should be used to
            map the returned ids back to tokens.
        device: device the returned tensors are moved to.

    Returns:
        ``(input_ids, attention_mask, token_type_ids)``. The last two are
        ``None`` when the tokenizer does not produce them.
    """
    # Use the caller's tokenizer rather than reloading 'bert-base-uncased' on
    # every call. Reloading was slow per row and could disagree with the
    # tokenizer the caller uses for convert_ids_to_tokens.
    markers = ['[a1]', '[/a1]', '[a2]', '[/a2]']
    known = tokenizer.get_vocab()
    if any(marker not in known for marker in markers):
        # NOTE(review): newly added ids must still fit the model's resized
        # embedding matrix — confirm the saved tokenizer already has them.
        tokenizer.add_special_tokens({'additional_special_tokens': markers})

    encoded = tokenizer(text, max_length=508, truncation=True, padding='max_length', return_tensors='pt')

    def _on_device(key):
        # Optional outputs: move to `device` when present, else None.
        tensor = encoded.get(key)
        return None if tensor is None else tensor.to(device)

    return encoded['input_ids'].to(device), _on_device('attention_mask'), _on_device('token_type_ids')

In [None]:
# Attribute every annotated original instance, one row at a time, keeping the
# (token, score) pairs alongside the row index and text.
gradients_details = []
for index, text in original['annotated_context'].items():
    print(index)
    input_ids, attention_mask, token_type_ids = tokenisation(text, tokenizer, device)
    scores = (
        gradient_sensitivity(model, input_ids, attention_mask)
        .detach()
        .cpu()
        .numpy()
        .tolist()
    )
    tokens = tokenizer.convert_ids_to_tokens(input_ids[0])
    gradients_details.append(
        {"index": index, "text": text, "word_attributions": list(zip(tokens, scores))}
    )


In [None]:
# One row per attributed original instance.
gradients_df = pd.DataFrame.from_records(gradients_details)
print(gradients_df)

In [None]:
# Record each original instance's row label, then renumber it 0..n-1.
original['index'] = original.index
gradients_df['index'] = gradients_df.index
original = original.reset_index(drop=True)
# `df_filtered` is not defined in this notebook (NameError). The gradients
# were computed row-by-row over `original`, and both frames now carry a
# 0..n-1 RangeIndex, so the labels align row-for-row.
# NOTE(review): assumes partitions.csv has a 'label' column — confirm.
gradients_df['label'] = original['label']

In [None]:
# Persist the original-instance attributions, without the DataFrame index.
gradients_df.to_csv(path_or_buf='../results/gradients/gradients-og-bert-base.csv', index=False)

## Compute Gradients: Counterfactual Instances

In [None]:
import re
def annotate_text(row, column):
    """Wrap both event mentions in ``row[column]`` with marker tags.

    Every occurrence of ``row['eventA']`` becomes ``[a1]<eventA>[/a1]`` and
    every occurrence of ``row['eventB']`` becomes ``[a2]<eventB>[/a2]``.

    Args:
        row: record (e.g. a pandas row) with ``'eventA'``, ``'eventB'`` and
            ``column`` entries.
        column: name of the field holding the text to annotate.

    Returns:
        The annotated text.
    """
    context = row[column]
    event_a = row['eventA']
    event_b = row['eventB']

    substitutions = [
        (event_a, f"[a1]{event_a}[/a1]"),
        (event_b, f"[a2]{event_b}[/a2]"),
    ]
    # Substitute the longer (escaped) event first; on a tie eventB goes first.
    if len(re.escape(event_a)) <= len(re.escape(event_b)):
        substitutions.reverse()

    for event, tagged in substitutions:
        # A callable replacement inserts the tagged text literally, so
        # backslashes in the event text are not parsed as regex escapes.
        context = re.sub(re.escape(event), lambda _match: tagged, context)

    return context

In [None]:
# Annotate each counterfactual, reading the text from its 'counterfactual' column.
counterfactuals['annotated_context'] = counterfactuals.apply(annotate_text, axis=1, args=('counterfactual',))

In [None]:
from datasets import Dataset, DatasetDict
# Hugging Face view of the counterfactuals, under a "counterfactual" split.
# `df_filtered_counterfactuals` is not defined in this notebook (NameError);
# the annotated counterfactuals live in `counterfactuals`.
dataset = DatasetDict({
    "counterfactual": Dataset.from_pandas(counterfactuals),
})

In [None]:
from transformers import DataCollatorWithPadding
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
# Bare class name of the tokenizer (e.g. "BertTokenizerFast"). ``__name__``
# gives it directly instead of string-parsing ``str(type(tokenizer))``.
# NOTE(review): neither name is used by the gradient cells below.
tokenizer_class = type(tokenizer).__name__

In [None]:
def preprocess_function(examples):
    """Tokenise a batch of annotated contexts, truncating each to 508 tokens."""
    return tokenizer(
        examples["annotated_context"],
        max_length=508,
        truncation=True,
    )


tokenized_datasets = dataset.map(preprocess_function, batched=True)

In [None]:
def predict(inputs, attention_mask=None, token_type_ids=None):
    """Captum forward function: return each example's largest logit.

    ``lig.attribute`` in this notebook is called with
    ``inputs=(input_ids, attention_mask)``, so Captum passes the attention
    mask as the second positional argument. It must be forwarded as
    ``attention_mask``. Forwarding it as ``token_type_ids`` fed the 0/1
    padding mask to the model as segment ids.

    Args:
        inputs: token-id tensor for the batch.
        attention_mask: optional mask aligned with ``inputs``.
        token_type_ids: optional segment ids, accepted by keyword.

    Returns:
        Tensor of the maxima of ``output[0]`` over dim 1, one per example.
    """
    # NOTE(review): assumes TemporalRelationClassification.forward accepts an
    # ``attention_mask`` keyword like standard BERT heads — confirm.
    output = model(inputs, attention_mask=attention_mask, token_type_ids=token_type_ids)
    logits = output[0]
    return logits.max(dim=1).values

In [None]:
def create_baseline(input_ids, attention_mask, special_token_ids=(30522, 30523, 30524, 30525)):
    """Build the integrated-gradients reference input for ``input_ids``.

    Every position is set to token id 0, except positions holding one of
    ``special_token_ids``, which keep their original id. The baseline
    attention mask is all ones.

    Args:
        input_ids: token-id tensor.
        attention_mask: mask tensor shaped like ``input_ids``; if ``None``,
            the ones-mask is shaped like ``input_ids`` instead.
        special_token_ids: ids copied unchanged into the baseline. The
            default presumably corresponds to the [a1]/[/a1]/[a2]/[/a2]
            markers appended after bert-base-uncased's vocabulary — verify
            against the tokenizer.

    Returns:
        ``(baseline_input_ids, baseline_attention_mask)`` on the same
        device as the inputs.
    """
    mask_template = input_ids if attention_mask is None else attention_mask
    baseline_attention_mask = torch.ones_like(mask_template)
    baseline_input_ids = torch.zeros_like(input_ids)

    # Build the lookup tensor on input_ids' own device/dtype so torch.isin
    # never compares tensors living on different devices.
    keep_ids = torch.as_tensor(list(special_token_ids), dtype=input_ids.dtype, device=input_ids.device)
    keep = torch.isin(input_ids, keep_ids)
    baseline_input_ids[keep] = input_ids[keep]
    return baseline_input_ids, baseline_attention_mask

In [None]:
# Attribute `predict`'s output to the outputs of the BERT embedding layer.
lig = LayerIntegratedGradients(forward_func=predict, layer=model.bert.embeddings)

In [None]:
def gradient_sensitivity(model, input_ids, attention_mask, n_steps=10):
    """Return per-token Layer Integrated Gradients scores, L2-normalised.

    Uses the module-level ``lig`` and ``create_baseline`` defined in this
    notebook.

    Args:
        model: put into eval mode before attributing. The attribution itself
            runs through ``lig``, which wraps the notebook's ``predict``.
        input_ids: token ids for a single example. ``squeeze(0)`` below
            assumes a batch of one.
        attention_mask: mask aligned with ``input_ids``.
        n_steps: number of integration steps passed to ``lig.attribute``.
            Larger values give a finer approximation at higher cost.

    Returns:
        1-D tensor (not a list) with one score per token position, scaled to
        unit L2 norm. It is left unscaled when every score is zero, instead
        of producing 0/0 = NaN.
    """
    model.eval()
    baseline_input_ids, baseline_attention_mask = create_baseline(input_ids, attention_mask)

    attributions = lig.attribute(
        inputs=(input_ids, attention_mask),
        baselines=(baseline_input_ids, baseline_attention_mask),
        n_steps=n_steps,
    )

    # Collapse the last (embedding) dimension into one score per token.
    attributions = attributions.sum(dim=-1).squeeze(0)
    norm = torch.norm(attributions)
    if norm > 0:
        attributions = attributions / norm
    return attributions

In [None]:
def tokenisation(text, tokenizer, device='cuda'):
    """Encode ``text`` with ``tokenizer`` as a padded batch of one on ``device``.

    The text is truncated/padded to exactly 508 tokens. The event markers
    ``[a1]``, ``[/a1]``, ``[a2]``, ``[/a2]`` are registered as additional
    special tokens if ``tokenizer`` does not already contain them.

    Args:
        text: annotated context string.
        tokenizer: Hugging Face tokenizer. The same object should be used to
            map the returned ids back to tokens.
        device: device the returned tensors are moved to.

    Returns:
        ``(input_ids, attention_mask, token_type_ids)``. The last two are
        ``None`` when the tokenizer does not produce them.
    """
    # Use the caller's tokenizer rather than reloading 'bert-base-uncased' on
    # every call. Reloading was slow per row and could disagree with the
    # tokenizer the caller uses for convert_ids_to_tokens.
    markers = ['[a1]', '[/a1]', '[a2]', '[/a2]']
    known = tokenizer.get_vocab()
    if any(marker not in known for marker in markers):
        # NOTE(review): newly added ids must still fit the model's resized
        # embedding matrix — confirm the saved tokenizer already has them.
        tokenizer.add_special_tokens({'additional_special_tokens': markers})

    encoded = tokenizer(text, max_length=508, truncation=True, padding='max_length', return_tensors='pt')

    def _on_device(key):
        # Optional outputs: move to `device` when present, else None.
        tensor = encoded.get(key)
        return None if tensor is None else tensor.to(device)

    return encoded['input_ids'].to(device), _on_device('attention_mask'), _on_device('token_type_ids')

In [None]:
# Same attribution pass as for the originals, now over the annotated
# counterfactuals: one (token, score) list per row.
gradients_details_counterfactuals = []
for index, text in counterfactuals['annotated_context'].items():
    print(index)
    input_ids, attention_mask, token_type_ids = tokenisation(text, tokenizer, device)
    scores = (
        gradient_sensitivity(model, input_ids, attention_mask)
        .detach()
        .cpu()
        .numpy()
        .tolist()
    )
    tokens = tokenizer.convert_ids_to_tokens(input_ids[0])
    gradients_details_counterfactuals.append(
        {"index": index, "text": text, "word_attributions": list(zip(tokens, scores))}
    )


In [None]:
# One row per attributed counterfactual.
gradients_counterfactual_df = pd.DataFrame.from_records(gradients_details_counterfactuals)
print(gradients_counterfactual_df)

In [None]:
# Positional id for each attributed row (the table was built with a fresh RangeIndex).
gradients_counterfactual_df['index'] = gradients_counterfactual_df.index
# Remember each counterfactual's row label, then renumber the frame 0..n-1.
counterfactuals['index'] = counterfactuals.index
counterfactuals = counterfactuals.reset_index(drop=True)

In [None]:
# `df_counter` is not defined in this notebook (NameError). The gradients were
# computed row-by-row over `counterfactuals`, so take the labels from there.
# NOTE(review): assumes counterfactuals.csv has a 'label' column — confirm.
gradients_counterfactual_df['label'] = counterfactuals['label']

In [None]:
# Persist the counterfactual attributions, without the DataFrame index.
gradients_counterfactual_df.to_csv(path_or_buf='../results/gradients/gradients-cf-bert-base.csv', index=False)

## Visualise Instance

In [None]:
from IPython.display import HTML, display
import pandas as pd

import builtins
import numbers

# (token, attribution) pairs for the row labelled 1.
instance = gradients_df.loc[1, 'word_attributions']
tokens, raw_attributions = zip(*instance)
# The gradient cells store one float per token, and summing a float raises
# TypeError. Only reduce entries that are themselves sequences of scores.
# `builtins.sum` keeps the original intent of bypassing any shadowed `sum`,
# and works outside __main__ (where `__builtins__` is a dict).
attributions = [
    float(attr) if isinstance(attr, numbers.Real) else builtins.sum(attr)
    for attr in raw_attributions
]

def visualize_attributions(tokens, attributions):
    """Display tokens as inline HTML, shaded red by min-max-scaled attribution.

    Scores are rescaled to [0, 1]; higher scores get a deeper red background.
    Tokens are HTML-escaped so characters such as ``<`` or ``&`` render
    literally.

    Args:
        tokens: token strings, one per score.
        attributions: numeric scores aligned with ``tokens``.
    """
    import html

    scores = [float(a) for a in attributions]
    low, high = (min(scores), max(scores)) if scores else (0.0, 0.0)
    span = high - low
    # Constant scores (or a single token) would divide by zero; leave them unshaded.
    scaled = [(s - low) / span if span else 0.0 for s in scores]

    parts = ["<p><b>Attributions:</b><br>"]
    for token, weight in zip(tokens, scaled):
        # Green/blue channels fall from 255 (white) to 0 (pure red) as weight
        # rises. Capping at 255 keeps the value a valid RGB component.
        channel = int(255 * (1 - weight))
        parts.append(
            f"<span style='background-color:rgb(255,{channel},{channel})'>{html.escape(token)}</span> "
        )
    parts.append("</p>")
    display(HTML("".join(parts)))

# Render the selected instance.
visualize_attributions(tokens=tokens, attributions=attributions)