# Imports and Constants

In [None]:
import pandas as pd
import json
import random
import torch
import numpy as np
from itertools import chain

from scipy import spatial

from plotly.tools import FigureFactory as ff
import plotly.graph_objects as go

from sklearn.dummy import DummyClassifier

from transformers import BertTokenizer, BertModel, RobertaTokenizer, RobertaModel, XLMRobertaModel, AutoModelForMaskedLM
from transformers import AutoTokenizer, AutoModel

import torch.nn.functional as F


from tqdm.notebook import tqdm as tq
from torch.utils.data import Dataset, DataLoader

In [None]:
# Label values for the gender and number tasks; they are compared against the
# dataset's label columns (see define_task_parameters / create_balance_df).
MASCULINE, FEMININE = "Masc", "Fem"

PLURAL, SINGULAR = "Plur", "Sing"

# Plots setup

In [None]:
# Nine greys from white (255) down to 135 in steps of 15, used as a heatmap colour scale.
muted_greys = ['rgb({0},{0},{0})'.format(level) for level in range(255, 134, -15)]


In [None]:
# Shared plotly layout applied to every figure via fig.update_layout(plot_layout_dict).
plot_layout_dict = dict(
    autosize=False,
    plot_bgcolor='rgba(255,255,255,0.5)',
    width=672,
    height=672,
    font_family="DejaVu Serif",
    font_size=30,
    legend_font_size=28,
    legend=dict(orientation="h", yanchor="top", xanchor="right"),
    xaxis=dict(showgrid=False),
    yaxis=dict(showgrid=True),
)

In [None]:
# Plot colours: label_1 (masculine/singular), label_2 (feminine/plural), accent.
masc_color, fem_color, dark_blue_color = '#E1BE6A', '#40B0A6', '#332288'

# Initialize tasks' parameters (gender/number)

In [None]:
def define_task_parameters(task, prefix="../data/processed/spa/"):
    """Return the column names, labels, sampling rate and CSV paths for `task`.

    Args:
        task: "gender", "number", or "test" (gender columns combined with the
            number task's long label names and M/F short labels).
        prefix: directory holding the processed ``{split}-{corpus}-{task}.csv``
            files.

    Returns:
        Tuple ``(before, after, extra, label_1, label_2, label_1_short,
        label_2_short, sample_rate, ds_names)``. The tuple holds:
        - the label columns of the original and counterfactual sentences;
        - the column of the other morphological feature;
        - the long and short names of the two labels;
        - the fraction of rows to sample (0.5 for "number", otherwise 1.0);
        - a dict mapping ``'{train,dev,test}_ds_1'`` to AnCora paths and
          ``'{train,dev,test}_ds_2'`` to GSD paths.

    Raises:
        ValueError: if `task` is unknown. Without this check, empty column names
            would be returned and only fail later as KeyErrors.
    """
    ds_names = {}
    for ds_id, corpus in ((1, "ancora"), (2, "gsd")):
        for split in ("train", "dev", "test"):
            ds_names["{}_ds_{}".format(split, ds_id)] = prefix + "{}-{}-{}.csv".format(split, corpus, task)

    # (before, after, extra, label_1, label_2, label_1_short, label_2_short, sample_rate)
    settings = {
        "gender": ("before gender", "after gender", "number", MASCULINE, FEMININE, "M", "F", 1.),
        "number": ("before number", "after number", "gender", SINGULAR, PLURAL, "S", "P", .5),
        "test": ("before gender", "after gender", "number", SINGULAR, PLURAL, "M", "F", 1.),
    }
    if task not in settings:
        raise ValueError("unknown task {!r}; expected one of {}".format(task, sorted(settings)))

    return (*settings[task], ds_names)

In [None]:
# Task to analyse: "gender" or "number" (or "test"); see define_task_parameters.
task = "gender"

In [None]:
# Unpack the task configuration into the notebook-level globals used by later cells.
before, after, extra, label_1, label_2, label_1_short, label_2_short, sample_rate, ds_names = define_task_parameters(task)

Set up the device (GPU if available, otherwise CPU).

In [None]:
# Run on the first CUDA GPU when present, otherwise on the CPU.
has_cuda = torch.cuda.is_available()
device = torch.device("cuda:0" if has_cuda else "cpu")
print(device)

# Loading The Dataset

Let's look at the paired dataset

In [None]:
# Concatenate all six AnCora/GSD splits. DataFrame.append was removed in
# pandas 2.0; a single pd.concat gives the same frame (original indices kept).
df = pd.concat([pd.read_csv(ds_names[key]) for key in
                ('train_ds_1', 'dev_ds_1', 'test_ds_1', 'train_ds_2', 'dev_ds_2', 'test_ds_2')])

# Rows without an adjective/determiner get index -1 so the columns can be int.
df['adj idx'] = df['adj idx'].fillna(-1).astype(int)
df['det idx'] = df['det idx'].fillna(-1).astype(int)

df_target = df.sample(frac=sample_rate, replace=False, random_state=1)
df_target.head()

Also check what the template-based dataset looks like.

In [None]:
# Manually written sentence templates, paired before/after the intervention.
templates_path = "../data/manual/paired-templates.csv"
templates_df = pd.read_csv(templates_path)
templates_df.head()

In [None]:
class MorphoDataset(Dataset):
    """Paired (before/after intervention) sentences for one morphological task.

    Each item holds the original sentence and its counterfactual, with their
    labels, dependency relation, focus-word indices and `extra` feature values.
    Label/extra column names come from the notebook-level globals ``before``,
    ``after`` and ``extra``.

    Args:
        csv_files: CSV paths to concatenate, or an existing DataFrame when
            `cached_df` is True.
        root_dir: stored as ``self.root_dir`` (not otherwise used here).
        cached_df: use `csv_files` directly as the DataFrame, without sampling.
        frac: fraction of rows to keep after concatenating the files;
            defaults to the notebook-level ``sample_rate``.
    """

    def __init__(self, csv_files, root_dir, cached_df=False, frac=None):
        self.root_dir = root_dir
        if cached_df:
            self.morpho_df = csv_files
            return

        # Concatenate every file first and sample once. Sampling inside the loop
        # would sample earlier files again on every later iteration.
        # (pd.concat replaces DataFrame.append, which was removed in pandas 2.0.)
        self.morpho_df = pd.concat([pd.read_csv(csv_file) for csv_file in csv_files])
        if frac is None:
            frac = sample_rate
        if frac < 1.:
            self.morpho_df = self.morpho_df.sample(frac=frac, replace=False, random_state=1)

    def get_label_sent(self, row, idx):
        """Split one row into [before, after] lists of labels, sentences, deprels,
        focus-word indices, dataset indices and `extra` values."""
        labels = [row[before], row[after]]
        sents = [row['before'], row['after']]
        deprels = [row['deprel'], row['deprel']]
        extras = [row[extra], row[extra]]
        focus_ids = [row['focus ID'], row['focus ID after']]
        indices = [idx, idx]

        return labels, sents, deprels, focus_ids, indices, extras


    def __len__(self):
        return len(self.morpho_df)

    def __getitem__(self, idx):
        """Return the paired sample at positional index `idx` as a dict."""
        if torch.is_tensor(idx):
            idx = idx.tolist()

        row = self.morpho_df.iloc[idx]

        labels, sents, deprels, focus_ids, indices, extras = self.get_label_sent(row, idx)
        sample = {'idx': indices, 'sents': sents, 'labels': labels, 'deprels': deprels, 'focus_ids': focus_ids, extra: extras}

        return sample

# Extract Contextual Representations

As a baseline, the representations are compared to representations extracted from a randomly initialized model.

In [None]:
import transformers

In [None]:
def randomize_model(model):
    """Re-initialise `model`'s weights in place and return the same model.

    The weights are reset as follows:
    - Linear, Embedding and ``transformers`` Conv1D weights are redrawn from
      N(0, model.config.initializer_range).
    - LayerNorm layers are reset to weight 1 and bias 0.
    - Linear/Conv1D biases, when present, are zeroed.
    """
    conv1d = transformers.modeling_utils.Conv1D
    for _, layer in model.named_modules():
        if isinstance(layer, torch.nn.LayerNorm):
            layer.bias.data.zero_()
            layer.weight.data.fill_(1.0)
            continue
        if not isinstance(layer, (torch.nn.Linear, torch.nn.Embedding, conv1d)):
            continue
        layer.weight.data.normal_(mean=0.0, std=model.config.initializer_range)
        if isinstance(layer, (torch.nn.Linear, conv1d)) and layer.bias is not None:
            layer.bias.data.zero_()
    return model

## mBERT

In [None]:
from transformers import BertModel

# Multilingual BERT (cased); the random baseline shares its config but not its weights.
mbert_checkpoint = 'bert-base-multilingual-cased'
tokenizer = AutoTokenizer.from_pretrained(mbert_checkpoint)
model = BertModel.from_pretrained(mbert_checkpoint, output_hidden_states=True).to(device)
random_model = BertModel(model.config).to(device)
embed_len = 768

## XLM-RoBERTa

In [None]:
from transformers import XLMRobertaModel
# XLM-RoBERTa base; the random baseline shares its config but not its weights.
xlmr_checkpoint = 'xlm-roberta-base'
tokenizer = AutoTokenizer.from_pretrained(xlmr_checkpoint)
model = XLMRobertaModel.from_pretrained(xlmr_checkpoint, output_hidden_states=True).to(device)
random_model = XLMRobertaModel(model.config).to(device)
embed_len = 768

In [None]:
# Hidden size used to allocate embedding arrays; mBERT and XLM-R base set 768 above.
# NOTE(review): assumed to also hold for the GPT-2 checkpoint below — confirm when switching models.
embed_len = 768

## GPT-2

In [None]:
from transformers import AutoTokenizer, GPT2Model
  
# add_prefix_space=True is needed to tokenize pre-split words (is_split_into_words=True).
tokenizer = AutoTokenizer.from_pretrained("datificate/gpt2-small-spanish", add_prefix_space=True)
# GPT-2 tokenizers typically define no padding token. extract_features pads
# batches (padding=True), and its is_gpt2 path reads tokenizer.pad_token_id,
# so reuse the end-of-sequence token for padding when none is set.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# output_hidden_states=True matches the mBERT/XLM-R cells, whose outputs expose
# outputs.hidden_states as used by calc_other_ates.
model = GPT2Model.from_pretrained("datificate/gpt2-small-spanish", output_hidden_states=True).to(device)

In [None]:
# Random baseline: the GPT-2 architecture/config of `model` with freshly
# initialised weights (an alternative would be randomize_model on a loaded copy).
random_model = GPT2Model(model.config).to(device)

In [None]:
def extract_features(tokenizer, model, gender_dataloader, batch_size, total_len, paired=True, token_level=True, cls=False, is_gpt2=False):
  """Run `model` over a dataloader and collect one vector per sentence.

  Args:
    tokenizer, model: a Hugging Face tokenizer/model pair; inputs are moved to
      the notebook-level `device`.
    gender_dataloader: DataLoader whose batches are dicts with 'sents',
      'labels', 'deprels', 'focus_ids', 'idx' and the global `extra` key
      (as produced by MorphoDataset).
    batch_size: the dataloader's batch size (used to place each batch).
    total_len: number of dataset items.
    paired: each item is a (before, after) sentence pair; outputs get an extra
      axis of size 2.
    token_level: take the focus word's vector (mean over its sub-tokens), or
      a sentence vector when `cls` is True. Otherwise mean-pool all
      non-padding tokens of each sentence.
    cls: use the pooler output (or, with `is_gpt2`, the last non-padding token)
      instead of the focus word.
    is_gpt2: select the GPT-2 variant of the `cls` vector.

  Returns:
    embeddings: (total_len, 2, embed_len) if paired else (total_len, embed_len).
    labels: (total_len, 2) / (total_len,) array; dtype=str keeps only the first
      character of each label (e.g. "Masc" -> "M").
    feats: dict with the following entries:
      - 'deprels': a list with one entry per item;
      - 'focus_ids', 'idx' and the `extra` feature: arrays shaped like `labels`;
      - 'multi_token': True where the focus word spans more than one sub-token.
  """
  if paired:
    vec_shape = (total_len, 2)
    embeddings = np.zeros((total_len, 2, embed_len))
  else:
    vec_shape = (total_len,)
    embeddings = np.zeros((total_len, embed_len))

  labels = np.empty(vec_shape, dtype=str)
  feats = {"deprels": [], "focus_ids": np.empty(vec_shape, dtype=int),
           "idx": np.empty(vec_shape, dtype=int), extra: np.empty(vec_shape, dtype=str), "multi_token": np.full(vec_shape, False)}

  def _pair(batch_field):
    # Stack the collated "before"/"after" halves of a batch column into (B, 2).
    return np.hstack((np.asarray(batch_field[0]).reshape(-1, 1),
                      np.asarray(batch_field[1]).reshape(-1, 1)))

  n_batches = int(np.ceil(total_len / batch_size))
  for i_batch, sample_batched in tq(enumerate(gender_dataloader), total=n_batches):
    start = i_batch * batch_size
    stop = start + batch_size

    if paired:
      # Interleave so that model rows 2k / 2k+1 are the before/after of item k.
      batched_sents = list(chain.from_iterable(zip(sample_batched['sents'][0], sample_batched['sents'][1])))
      labels[start:stop] = _pair(sample_batched['labels'])
    else:
      batched_sents = sample_batched['sents']
      labels[start:stop] = sample_batched['labels']
    batched_sents = [tuple(s.split(" ")) for s in batched_sents]

    for k in feats:
      if k == 'multi_token':
        continue
      values = _pair(sample_batched[k]) if paired else np.asarray(sample_batched[k])
      if k == "deprels":
        # deprels is a growing list: one (2,) row per item when paired.
        feats[k].extend(values)
      else:
        feats[k][start:stop] = values

    inputs = tokenizer(batched_sents, padding=True, is_split_into_words=True, return_tensors="pt").to(device)

    with torch.no_grad():
      outputs = model(**inputs)
    last_hidden_states = outputs.last_hidden_state.cpu().numpy()

    if cls:
      if is_gpt2:
        # Last non-padding token of each sentence.
        cls_idx = (torch.ne(inputs['input_ids'], tokenizer.pad_token_id).sum(-1) - 1).cpu().numpy()
        cls_idx = np.repeat(cls_idx, embed_len).reshape(-1, embed_len)
        cls_idx = np.expand_dims(cls_idx, axis=1)
        cls_embeddings = np.take_along_axis(last_hidden_states, cls_idx, axis=1)
        cls_embeddings = np.squeeze(cls_embeddings, axis=1)
      else:
        cls_embeddings = outputs.pooler_output.cpu().numpy()

    if token_level:
      # Paired batches hold two model rows per item; unpaired batches hold one.
      n_items = last_hidden_states.shape[0] // 2 if paired else last_hidden_states.shape[0]
      for offset in range(n_items):
        i = start + offset
        if cls:
          if paired:
            embeddings[i, 0] = cls_embeddings[2 * offset, :]
            embeddings[i, 1] = cls_embeddings[2 * offset + 1, :]
          else:
            embeddings[i] = cls_embeddings[offset, :]
          continue

        if paired:
          for side in (0, 1):
            row = 2 * offset + side
            tok_start, tok_end = inputs.word_to_tokens(row, feats['focus_ids'][i, side])
            if tok_start + 1 != tok_end:
              feats['multi_token'][i, side] = True
            embeddings[i, side] = np.mean(last_hidden_states[row, tok_start:tok_end, :], axis=0)
        else:
          tok_start, tok_end = inputs.word_to_tokens(offset, feats['focus_ids'][i])
          if tok_start + 1 != tok_end:
            feats['multi_token'][i] = True
          embeddings[i] = np.mean(last_hidden_states[offset, tok_start:tok_end], axis=0)

    else:
      # Mean over the non-padding tokens of every sentence in the batch.
      mask = inputs['attention_mask'].cpu().numpy()[..., None].astype(last_hidden_states.dtype)
      pooled = (last_hidden_states * mask).sum(axis=1) / np.clip(mask.sum(axis=1), 1e-9, None)
      if paired:
        pooled = pooled.reshape(-1, 2, pooled.shape[-1])
      embeddings[start:start + pooled.shape[0]] = pooled

  return embeddings, labels, feats

In [None]:
def create_balance_df(csv_list, random_state=None):
    """Build a class-balanced, half-size sample of the concatenated CSVs.

    Used as a sanity check that label imbalance in the data does not drive
    the estimators' results. Rows are grouped by their original label
    (global `before` column) into `label_2` and `label_1`. The larger group
    is down-sampled to the size of the smaller one, and a random half of the
    balanced frame is returned.

    Args:
        csv_list: CSV paths to concatenate.
        random_state: optional seed forwarded to every ``DataFrame.sample``
            call (default: unseeded).
    """
    combined = pd.concat([pd.read_csv(path) for path in csv_list])

    fem_df = combined[combined[before] == label_2]
    # Both groups are selected on the *original* label column. Filtering on
    # `after` would select the counterfactuals of label_2 rows, i.e. the same rows.
    masc_df = combined[combined[before] == label_1]

    n = min(len(fem_df), len(masc_df))
    if len(fem_df) > n:
        fem_df = fem_df.sample(n=n, random_state=random_state)
    masc_df = masc_df.sample(n=n, random_state=random_state)
    balanced_df = pd.concat([fem_df, masc_df])

    return balanced_df.sample(frac=.5, random_state=random_state)

## For the paired estimator
Create a matrix of paired representations (before/after the intervention).

In [None]:
batch_size = 8
# AnCora train+dev only. Append ds_names['test_ds_1'], ds_names['train_ds_2'],
# ds_names['dev_ds_2'] and ds_names['test_ds_2'] to use the whole data.
ancora_files = [ds_names['train_ds_1'], ds_names['dev_ds_1']]
morpho_dataset = MorphoDataset(ancora_files, '/', cached_df=False)
morpho_dataloader = DataLoader(morpho_dataset, batch_size=batch_size)

total_len = len(morpho_dataset)
print(total_len)
# Same data through the pretrained model and through the random baseline.
embeddings, labels, feats = extract_features(tokenizer, model, morpho_dataloader, batch_size, total_len, cls=False, is_gpt2=False)
r_embeddings, r_labels, r_feats = extract_features(tokenizer, random_model, morpho_dataloader, batch_size, total_len, cls=False, is_gpt2=False)




In [None]:
batch_size = 8
# GSD train+dev, encoded with the pretrained model only.
gsd_files = [ds_names['train_ds_2'], ds_names['dev_ds_2']]
morpho_dataset = MorphoDataset(gsd_files, '/', cached_df=False)
morpho_dataloader = DataLoader(morpho_dataset, batch_size=batch_size)

total_len = len(morpho_dataset)
print(total_len)
gsd_embeddings, gsd_labels, gsd_feats = extract_features(tokenizer, model, morpho_dataloader, batch_size, total_len, cls=False, is_gpt2=False)



In [None]:
# extract_features collects deprels as a list of per-item rows; stack them into an array.
feats['deprels'] = np.asarray(feats['deprels'])

## For the template-based estimator

In [None]:
# MorphoDataset reads 'deprel' and the task's `extra` column from every row.
# Add empty placeholders when the templates lack them; this handles the task's
# actual `extra` column rather than a hard-coded "number", and never overwrites
# an existing column. Then save the file for the loader cell below.
for placeholder_col in ("deprel", extra):
    if placeholder_col not in templates_df.columns:
        templates_df[placeholder_col] = ""
templates_df.to_csv("../data/manual/paired-templates.csv", index=False)

In [None]:
batch_size = 8
# Template-based pairs (the CSV prepared in the previous cell).
template_files = ["../data/manual/paired-templates.csv"]
morpho_dataset = MorphoDataset(template_files, '/')
morpho_dataloader = DataLoader(morpho_dataset, batch_size=batch_size)

total_len = len(morpho_dataset)
print(total_len)
t_embeddings, t_labels, t_feats = extract_features(tokenizer, model, morpho_dataloader, batch_size, total_len, cls=False, is_gpt2=False)

**balanced dataset**

In [None]:
batch_size = 8
# Class-balanced GSD sample, passed to MorphoDataset as a ready-made DataFrame.
balanced_dataset = create_balance_df([ds_names['train_ds_2'], ds_names['dev_ds_2']])
morpho_dataset = MorphoDataset(balanced_dataset, '/', cached_df=True)
morpho_dataloader = DataLoader(morpho_dataset, batch_size=batch_size)

total_len = len(morpho_dataset)
print(total_len)
b_embeddings, b_labels, b_feats = extract_features(tokenizer, model, morpho_dataloader, batch_size, total_len, cls=False, is_gpt2=False)

# ATE Calculation 

In [None]:
# Quick look at the template features: label shape, embedding shape, labels.
print(t_labels.shape, t_embeddings.shape, t_labels, sep="\n")

In [None]:
def calc_diff_mat(embeddings, labels):
    """Return the per-pair signed differences between paired representations.

    Args:
        embeddings: array of shape (n, 2, d) holding each sentence's
            representation before (index 0) and after (index 1) the intervention.
        labels: array of shape (n, 2) with the label of each of those
            sentences; exactly two distinct values are expected.

    Returns:
        Array of shape (n, d). Each row is the representation of the sentence
        with the lexicographically larger label minus the one with the smaller
        label (e.g. "M" - "F", "S" - "P").

    Raises:
        ValueError: if `labels` does not contain exactly two distinct values.
    """
    classes, codes = np.unique(labels, return_inverse=True)
    if len(classes) != 2:
        raise ValueError("expected exactly two label values, got {}".format(list(classes)))

    # Smaller label -> -1, larger label -> +1, one sign per sentence.
    signs = 2 * codes.reshape(-1, 2, 1) - 1
    signed = signs * embeddings
    return signed[:, 0, :] + signed[:, 1, :]

In [None]:
def filter_embeddings_roles(embeddings, lables, target, deprels=None):
    """Keep only the sentence pairs whose focus noun has dependency relation `target`.

    Args:
        embeddings: (n, 2, d) paired representations.
        lables: (n, 2) labels aligned with `embeddings`. The parameter name is
            kept for existing keyword callers.
        target: the dependency relation to keep (a value of the 'deprel' column).
        deprels: (n, 2) dependency relations aligned with `embeddings`;
            defaults to the notebook-level ``feats['deprels']``.

    Returns:
        (embeddings, labels) restricted to matching pairs, shaped (m, 2, d) and (m, 2).
    """
    if deprels is None:
        deprels = feats['deprels']
    # Use the caller's arrays rather than the globals `labels`/`feats`.
    mask = np.asarray(deprels) == target
    embed_new = embeddings[mask].reshape(-1, 2, embeddings.shape[-1])
    label_new = np.asarray(lables)[mask].reshape(-1, 2)
    return embed_new, label_new

In [None]:
def calculate_causal_effect_paired(embeddings, labels):
    """Paired estimator: the mean over pairs of calc_diff_mat's signed differences."""
    return calc_diff_mat(embeddings, labels).mean(axis=0)


def calculate_causal_effect_naive(embeddings, labels):
    """Naive estimator: mean of the label_1_short vectors minus mean of the label_2_short vectors."""
    def group_mean(short_label):
        return np.mean(embeddings[labels == short_label], axis=0)

    return group_mean(label_1_short) - group_mean(label_2_short)


In [None]:
# Per-pair signed differences for the pretrained model and for the random baseline.
diff_mat, diff_mat_random = (calc_diff_mat(e, l) for e, l in ((embeddings, labels), (r_embeddings, r_labels)))

**effect size**

In [None]:
def calc_effect_size(embeddings, labels, pos_label=None, neg_label=None):
    """Return a Cohen's-d-like effect size between two label groups.

    Computes ``||mean_pos - mean_neg|| / ||s||``, where ``s`` is the
    per-dimension pooled standard deviation ``sqrt(0.5*var_pos + 0.5*var_neg)``
    built from sample variances (ddof=1).

    Args:
        embeddings: representations such that ``embeddings[labels == x]``
            selects the vectors of group x (e.g. (n, 2, d) with (n, 2) labels).
        labels: the label of each representation.
        pos_label, neg_label: the two groups' labels; default to the
            notebook-level ``label_1_short`` / ``label_2_short``.

    Returns:
        The effect size as a float. The value is returned rather than printed,
        so ``print(calc_effect_size(...))`` shows it.
    """
    if pos_label is None:
        pos_label = label_1_short
    if neg_label is None:
        neg_label = label_2_short

    group_pos = embeddings[labels == pos_label]
    group_neg = embeddings[labels == neg_label]

    mean_diff = group_pos.mean(axis=0) - group_neg.mean(axis=0)
    pooled_sd = np.sqrt(0.5 * group_pos.var(axis=0, ddof=1) + 0.5 * group_neg.var(axis=0, ddof=1))
    return float(np.linalg.norm(mean_diff) / np.linalg.norm(pooled_sd))

In [None]:
# Effect size for the pretrained model, then for the randomly initialised one.
for emb_, lab_ in ((embeddings, labels), (r_embeddings, r_labels)):
    print(calc_effect_size(emb_, lab_))

## Plot ATE estimators similarity

In [None]:
# Rows: paired AnCora, paired GSD, naive AnCora, naive GSD, paired templates.
# (A paired estimate on b_embeddings/b_labels could replace the last row.)
ate_vectors = np.zeros((5, embed_len))
ate_estimates = [
    calculate_causal_effect_paired(embeddings, labels),
    calculate_causal_effect_paired(gsd_embeddings, gsd_labels),
    calculate_causal_effect_naive(embeddings[:, 0, :], labels[:, 0]),
    calculate_causal_effect_naive(gsd_embeddings[:, 0, :], gsd_labels[:, 0]),
    calculate_causal_effect_paired(t_embeddings, t_labels),
]
for slot, estimate in enumerate(ate_estimates):
    ate_vectors[slot, :] = estimate

In [None]:
import plotly.express as px
# Print plotly's sequential Greys scale for reference; the heatmaps below use the custom `muted_greys` instead.
print(px.colors.sequential.Greys)

In [None]:
def visualize_cos_matrix(column_names, ate_vectors, plot_name):
    """Show and save (to `plot_name`) a heatmap of pairwise cosine similarities.

    Cell (r, c) is the cosine similarity of ``ate_vectors[r]`` and
    ``ate_vectors[c]``, for the first ``len(column_names)`` vectors. Cells are
    labelled with `column_names` and the values are rounded to two decimals.
    """
    size = len(column_names)
    similarity = np.zeros((size, size))
    for r in range(size):
        for c in range(size):
            similarity[r, c] = 1 - spatial.distance.cosine(ate_vectors[r], ate_vectors[c])

    fig = px.imshow(
        np.around(similarity, decimals=2),
        color_continuous_scale=muted_greys,
        x=column_names,
        y=column_names,
        text_auto=True,
        aspect='auto',
    )
    fig.layout.coloraxis.showscale = False

    fig.update_layout(plot_layout_dict)
    fig.update_layout(width=550, height=550)
    fig.show()
    fig.write_image(plot_name)
    

In [None]:
import plotly.express as px

# Labels match the rows of ate_vectors: P = paired, B = naive, T = templates.
# ("P.Balanced" would go before "T" if the balanced estimate were included.)
column_names = ["P.AnCora", "P.GSD", "B.AnCora", "B.GSD", "T"]

visualize_cos_matrix(column_names, ate_vectors, "cos-ate-{}.pdf".format(task))




## Calculate ATE on Adjectives and Determiners 

In [None]:
def calc_other_ates(df, num_rows, idx_col_name, idx_col_after=None):
    """Collect last-layer vectors of one word in each before/after sentence pair.

    Each row's "before" and "after" sentences are split on spaces, encoded
    together with the global `tokenizer`/`model` on `device`, and the vector of
    the word's *last* sub-token is kept.

    Args:
        df: rows with 'before', 'after', 'focus ID', 'focus ID after', the label
            columns named by the globals `before`/`after`, and `idx_col_name`.
        num_rows: number of rows to allocate (normally ``len(df)``).
        idx_col_name: column with the word's index in the "before" sentence.
        idx_col_after: column with its index in the "after" sentence. When
            omitted, the "before" index is shifted by the focus word's change
            of position (adjust_index).

    Returns:
        embeddings (k, 2, embed_len) and labels (k, 2). Labels keep only the
        first character of each label. k counts the rows processed; rows whose
        word maps to no token are printed and skipped.
    """
    det_embeddings = np.zeros((num_rows, 2, embed_len))
    det_labels = np.zeros((num_rows, 2), dtype='<U1')

    count = 0
    for _, row in tq(df.iterrows(), total=len(df)):
        idx_before = int(row[idx_col_name])
        if idx_col_after is None:
            # The "after" index was previously computed but never used.
            idx_after = int(adjust_index(row["focus ID"], row["focus ID after"], idx_before))
        else:
            idx_after = int(row[idx_col_after])

        # The stored indices are positions in the space-split sentence, so
        # tokenize pre-split words, as extract_features does.
        inputs = tokenizer([row["before"].split(" "), row["after"].split(" ")],
                           is_split_into_words=True, return_tensors="pt", padding=True).to(device)
        span_before = inputs.word_to_tokens(0, idx_before)
        span_after = inputs.word_to_tokens(1, idx_after)
        if span_before is None or span_after is None:
            print(row)
            print(count)
            continue

        with torch.no_grad():
            outputs = model(**inputs)

        # last_hidden_state is the final layer's output (hidden_states[-1]). It is
        # available even when the model was loaded without output_hidden_states.
        last_hidden_state = outputs.last_hidden_state.cpu().numpy()
        det_embeddings[count, 0, :] = last_hidden_state[0][span_before[1] - 1]
        det_embeddings[count, 1, :] = last_hidden_state[1][span_after[1] - 1]
        det_labels[count, 0] = row[before]
        det_labels[count, 1] = row[after]
        count += 1

    print(count)
    # Drop the unused rows left by skipped entries.
    return det_embeddings[:count], det_labels[:count]

In [None]:
def adjust_index(focus_idx, focus_idx_after, adj_idx):
    """Map a word index in the "before" sentence to the "after" sentence.

    If the focus word's index changed between the two sentences, `adj_idx` is
    shifted by the same offset; otherwise it is returned unchanged.
    """
    moved = focus_idx != focus_idx_after
    return adj_idx + focus_idx_after - focus_idx if moved else adj_idx

# Rows that have a determiner / adjective (-1 marks "none" after loading).
df_det = df_target.loc[df_target["det idx"] != -1]
df_adj = df_target.loc[df_target["adj idx"] != -1]
det_embeddings, det_labels = calc_other_ates(df_det, len(df_det), "det idx")
adj_embeddings, adj_labels = calc_other_ates(df_adj, len(df_adj), "adj idx")
# The focus nouns themselves, using the explicit "after" index column.
embeddings, labels = calc_other_ates(df_target, len(df_target), "focus ID", "focus ID after")

In [None]:
# Rows 0-2: paired estimates for focus/adj/det; rows 3-5: naive estimates in the same order.
ate_vectors = torch.zeros((6, embed_len))

role_data = [(embeddings, labels), (adj_embeddings, adj_labels), (det_embeddings, det_labels)]
for slot, (role_emb, role_lab) in enumerate(role_data):
    ate_vectors[slot, :] = torch.tensor(calculate_causal_effect_paired(role_emb, role_lab)).float().to(device)
    ate_vectors[slot + 3, :] = torch.tensor(calculate_causal_effect_naive(role_emb[:, 0, :], role_lab[:, 0])).float().to(device)

column_names = ["P.Focus", "P.Adj", "P.Det", "N.Focus", "N.Adj", "N.Det"]

plot_name = "cos-roles-{}.pdf".format(task)


In [None]:
# Heatmap of cosine similarities between the focus/adjective/determiner ATE vectors, saved to plot_name.
visualize_cos_matrix(column_names, ate_vectors, plot_name)

# PCA Analysis

In [None]:
# Sanity check: averaging each before/after pair drops axis 1 of `embeddings`.
np.mean(embeddings, axis=1).shape

In [None]:
from sklearn.decomposition import PCA

def get_pca(embeddings):
    """Fit a 50-component PCA on pair-centred representations.

    Each (before, after) pair in `embeddings` (n, 2, d) is centred on its own
    mean. The 2n centred vectors are scaled by 0.7071 (about 1/sqrt(2)) and
    PCA is fitted on them.

    Returns:
        The fitted PCA object and its singular values.
    """
    centred = embeddings - embeddings.mean(axis=1, keepdims=True)
    flat = centred.reshape(-1, centred.shape[-1])
    pca = PCA(n_components=50)
    pca.fit(0.7071 * flat)
    return pca, pca.singular_values_

In [None]:
import plotly.graph_objects as go

def plot_pca_variance(embeddings, r_embeddings, exp_ratio, n_iter, output_file, n_axes=10):
    """Bar-plot the leading explained-variance ratios of three PCAs and save the figure.

    Compares three series:
    - the pair-centred PCA (get_pca) of `embeddings` (pretrained model);
    - the same PCA of `r_embeddings` (random model);
    - a random-matrix baseline averaged over `n_iter` runs, ``exp_ratio / n_iter``.

    Only the first `n_axes` components are plotted. The figure is shown and
    written to `output_file`.

    Returns:
        Singular values of the PCA fitted on `embeddings`.
    """
    pca, pca_s_values = get_pca(embeddings)
    r_pca, _ = get_pca(r_embeddings)

    axes = list(range(1, n_axes + 1))
    # Slice each series to n_axes so x and y have matching lengths, instead of
    # passing y arrays that are longer than x.
    random_model_ratio = r_pca.explained_variance_ratio_[:n_axes]
    model_ratio = pca.explained_variance_ratio_[:n_axes]
    random_matrix_ratio = (np.asarray(exp_ratio) / n_iter)[:n_axes]

    fig = go.Figure(data=[
        go.Bar(x=axes, y=random_model_ratio, marker_color="grey", opacity=0.7, name=r'Random Model'),
        go.Bar(x=axes, y=model_ratio, marker_color=masc_color, opacity=0.7, name=r'$\Large\Delta_{\text{AnCora}}$'),
        go.Bar(x=axes, y=random_matrix_ratio, marker_color=fem_color, name='Random Matrix'),
    ])
    fig.update_layout(xaxis_title="Axis number", yaxis_title="Explained Variance", barmode='overlay')
    fig.update_layout(plot_layout_dict)
    fig.show()
    fig.write_image(output_file)
    return pca_s_values

In [None]:
# Accumulator for the random-matrix explained-variance baseline, and its number of runs.
exp_ratio, n_iter = np.zeros(embed_len), 100

In [None]:
import plotly.graph_objects as go
# Baseline: explained-variance ratios of PCA on uniform random matrices shaped
# like diff_mat, summed over n_iter runs (plot_pca_variance divides by n_iter).
exp_ratio = np.zeros(embed_len)
n_iter = 100
for _ in tq(range(n_iter)):
    pca = PCA()
    pca.fit(np.random.random(diff_mat.shape))
    ratio = pca.explained_variance_ratio_
    # PCA() keeps min(n_samples, n_features) components. This can be fewer
    # than embed_len when diff_mat has fewer rows than dimensions.
    exp_ratio[:len(ratio)] += ratio

In [None]:
# Explained variance: pretrained vs random model vs random-matrix baseline.
# NOTE(review): the file name hard-codes "mbert"; update it when another model is loaded.
pca_s_values = plot_pca_variance(embeddings, r_embeddings, exp_ratio, n_iter, "pca-var-{}-mbert.pdf".format(task))

## Projection Test

In [None]:
pca, _ = get_pca(embeddings)
ate = calculate_causal_effect_paired(embeddings, labels)
# Cosine similarity between the first principal axis and the paired ATE vector.
1 - spatial.distance.cosine(pca.components_[0], ate)

In [None]:
def calc_projections(embeddings, pca, labels, feats, sample_size):
  """Project a random subset of paired representations onto the first two PCA axes.

  Args:
    embeddings: (n, 2, d) paired representations.
    pca: fitted PCA with at least two components (see get_pca).
    labels: (n, 2) labels aligned with `embeddings`.
    feats: dict of per-item features aligned with `embeddings` (as returned by
      extract_features); every entry is subset with the same indices.
    sample_size: number of distinct pairs to draw (capped at n).

  Returns:
    pca_proj_x, pca_proj_y: (k, 2) projections of the sampled pairs on axes 0 and 1.
    ate_proj_x, ate_proj_y: projections of the paired ATE computed over all n pairs.
    limited_labels: the sampled labels.
    new_feats: the sampled feats.
  """
  dev_ate = calculate_causal_effect_paired(embeddings, labels)
  n_pairs = embeddings.shape[0]
  # Sample without replacement so no pair is plotted twice.
  sample_idx = np.random.choice(n_pairs, size=min(sample_size, n_pairs), replace=False)

  pca_proj_x = np.dot(embeddings[sample_idx, :], pca.components_[0])
  pca_proj_y = np.dot(embeddings[sample_idx, :], pca.components_[1])
  ate_proj_x = np.dot(dev_ate, pca.components_[0])
  ate_proj_y = np.dot(dev_ate, pca.components_[1])

  # np.asarray lets list-valued entries (e.g. feats['deprels'] before it is
  # converted to an array) be fancy-indexed too.
  new_feats = {k: np.asarray(v)[sample_idx] for k, v in feats.items()}

  limited_labels = labels[sample_idx]
  return pca_proj_x, pca_proj_y, ate_proj_x, ate_proj_y, limited_labels, new_feats

In [None]:
def get_row_info(idx, i, j, label, dev_df, pca_proj_x, pca_proj_y):
  """Describe the sentence of row `idx` of `dev_df` that carries `label`.

  If `label` equals the first character of the row's original label (global
  `before` column), the "before" sentence is described; otherwise the "after"
  (counterfactual) sentence is. `i`, `j`, `pca_proj_x` and `pca_proj_y` are unused.

  Returns:
    (row ID, focus word, full sentence, is_counterfactual)
  """
  row = dev_df.iloc[idx]

  original_label = row[before]
  row_id = row["ID"]
  before_words = row['before'].split(" ")
  after_words = row['after'].split(" ")
  focus_before = int(row['focus ID'])
  focus_after = int(row['focus ID after'])

  if label != original_label[0]:
    return row_id, after_words[focus_after], row['after'], True
  return row_id, before_words[focus_before], row['before'], False

In [None]:
def visualize_swap_gender_vecs(pca_proj_x, pca_proj_y, ate_proj_x, ate_proj_y, limited_labels, hover_text_masc, hover_text_fem, title, output_file):
  """Plot first-PCA-axis projections as a 1-D strip, coloured by label.

  Points labelled ``label_1_short`` are drawn as ``label_1`` circles; all others
  are drawn as ``label_2`` crosses. Hover texts must follow the row-major order
  of the points they describe (as get_hover_texts produces them).
  `pca_proj_y`, `ate_proj_x` and `ate_proj_y` are accepted but not plotted.
  `title` is applied when non-empty. The figure is shown and written to
  `output_file`.
  """
  pca_proj_x = np.asarray(pca_proj_x)
  # Compare with the label explicitly instead of relying on np.unique's
  # alphabetical ordering, which also misbehaves if only one label is sampled.
  label_mask = np.asarray(limited_labels) == label_1_short
  label_1_x = pca_proj_x[label_mask]
  label_2_x = pca_proj_x[np.logical_not(label_mask)]

  fig = go.Figure()
  fig.add_trace(go.Scatter(
      x=label_1_x,
      y=[0] * len(label_1_x),
      mode='markers',
      name=label_1,
      hovertemplate='<b>%{text}</b>',
      text=hover_text_masc,
      marker_size=15,
      marker_color=masc_color
  ))
  fig.add_trace(go.Scatter(
      x=label_2_x,
      y=[0] * len(label_2_x),
      mode='markers',
      name=label_2,
      hovertemplate='<b>%{text}</b>',
      text=hover_text_fem,
      marker_size=15,
      marker_symbol='x',
      marker_color=fem_color
  ))

  fig.update_xaxes(showgrid=False)
  fig.update_yaxes(showgrid=False,
                   zeroline=True, zerolinecolor='black', zerolinewidth=3,
                   showticklabels=False)
  fig.update_layout(plot_layout_dict)
  fig.update_layout(height=200)
  fig.update_layout(width=800)
  if title:
    fig.update_layout(title=title)

  fig.show()
  fig.write_image(output_file)

In [None]:
def get_hover_texts(dev_indices, dev_df, dev_labels, pca_proj_x, pca_proj_y):
  """Build hover strings for each sampled sentence, split by label.

  Pairs are walked row-major (pair, then before/after), so each list lines up
  with a boolean-mask selection of the projection arrays.
  NOTE(review): `dev_indices` are positional indices into the extraction
  dataset; they only address the right rows of `dev_df` if the two share row
  order. TODO confirm.

  Returns:
    (texts for label_1_short, texts for every other label)
  """
  masc_texts, fem_texts = [], []

  for pair_no in range(dev_labels.shape[0]):
    for side in (0, 1):
      lab = dev_labels[pair_no, side]
      row_id, focus_text, _, _ = get_row_info(dev_indices[pair_no, side], pair_no, side, lab, dev_df, pca_proj_x, pca_proj_y)
      bucket = masc_texts if lab == label_1_short else fem_texts
      bucket.append("ID: {}, Noun: {}, {} {}".format(row_id, focus_text, task, lab))

  return masc_texts, fem_texts

In [None]:
sample_size = 20
# All six splits, in the same order as the first data-loading cell.
# DataFrame.append was removed in pandas 2.0; pd.concat gives the same frame.
df = pd.concat([pd.read_csv(ds_names[key]) for key in
                ('train_ds_1', 'dev_ds_1', 'test_ds_1', 'train_ds_2', 'dev_ds_2', 'test_ds_2')])


output_file = "proj-{}-mbert.pdf".format(task)

# NOTE(review): limited_feats['idx'] holds positional indices into the dataset the
# features were extracted from, and `feats` must be aligned with `embeddings`.
# TODO confirm both hold after calc_other_ates reassigned `embeddings`.
pca, _ = get_pca(embeddings)
pca_proj_x, pca_proj_y, ate_proj_x, ate_proj_y, limited_labels, limited_feats = calc_projections(embeddings, pca, labels, feats, sample_size)
hover_text_masc, hover_text_fem = get_hover_texts(limited_feats['idx'], df, limited_labels, pca_proj_x, pca_proj_y)
visualize_swap_gender_vecs(pca_proj_x, pca_proj_y, ate_proj_x, ate_proj_y, limited_labels, hover_text_masc, hover_text_fem, "", output_file)

# Train Gender/Number Probes

In [None]:
# Accumulators for probe results (rows for DataFrames) and baseline scores.
df_list, baselines = [], []

orig_all, orig_all_focus = [], []
aug_all, aug_all_focus = [], []

In [None]:
from sklearn.utils import shuffle
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import classification_report
from sklearn.metrics import accuracy_score
from sklearn.model_selection import cross_validate
from sklearn.metrics import make_scorer
from sklearn.svm import SVC
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler

def train_models(embeddings, labels):
  """Fit the probing classifiers on paired representations.

  Two training sets of size n = embeddings.shape[0] are built:
  * original: the "before" vector of every pair, shuffled;
  * counterfactual: the first n vectors of the shuffled pool of all 2n
    before/after vectors.
  On each set, a class-balanced logistic regression and a standardised SVC
  (gamma='auto') are fitted. A most-frequent dummy baseline is fitted on the
  original set.

  Returns:
    x, y, x_counter, y_counter, clf, clf_counter, clf_dummy, clf_svm, clf_svm_counter
  """
  n_train = embeddings.shape[0]
  x, y = shuffle(embeddings[:, 0, :], labels[:, 0], random_state=0)
  pool_x, pool_y = shuffle(embeddings.reshape(-1, embed_len), labels.reshape(-1,), random_state=0)
  x_counter, y_counter = pool_x[:n_train, :], pool_y[:n_train]

  def new_logreg():
    return LogisticRegression(max_iter=1000, class_weight="balanced", random_state=0)

  def new_svm():
    return make_pipeline(StandardScaler(), SVC(gamma='auto'))

  clf_counter = new_logreg().fit(x_counter, y_counter)
  clf = new_logreg().fit(x, y)
  clf_svm = new_svm().fit(x, y)
  clf_svm_counter = new_svm().fit(x_counter, y_counter)
  clf_dummy = DummyClassifier(strategy='most_frequent', random_state=0).fit(x, y)

  return x, y, x_counter, y_counter, clf, clf_counter, clf_dummy, clf_svm, clf_svm_counter

Performance of the naive classifier

In [None]:
def log_train_test_report(df_list, baselines, x, y, x_counter, y_counter, clf, clf_counter, clf_dummy, clf_svm, clf_svm_counter, embeddings, labels, dev_embeddings, dev_labels, model_name, number, rep):
  """Evaluate trained probes on held-out paired data and record the accuracies.

  Appends to `baselines` the majority, causal-LogReg and causal-SVM accuracies
  on all 2n dev vectors. Appends to `df_list` each probe's accuracy on the
  "before" dev vectors (Original Dataset) and on all before+after vectors
  (Augmented Dataset). `x`, `y`, `x_counter`, `y_counter`, `embeddings` and
  `labels` are unused.

  Returns:
    (orig, aug): [LogReg, SVM] accuracies on the original and augmented dev sets.
  """
  flat_x = dev_embeddings.reshape(-1, embed_len)
  flat_y = dev_labels.reshape(-1,)
  first_x, first_y = dev_embeddings[:, 0, :], dev_labels[:, 0]
  common = {"Model": model_name, "Number": number, "Representation": rep}
  orig, aug = [], []

  baselines.append({"Name": "Majority", **common, "Accuracy": accuracy_score(flat_y, clf_dummy.predict(flat_x))})

  for probe_name, probe, causal_name, causal_probe in (
      ("LogRegProbe", clf, "Causal-LogReg", clf_counter),
      ("SVMProbe", clf_svm, "Causal-SVM", clf_svm_counter),
  ):
    acc = accuracy_score(first_y, probe.predict(first_x))
    df_list.append({'Probe': probe_name, "Dataset": "Original Dataset", **common, "Accuracy": acc})
    orig.append(acc)

    acc = accuracy_score(flat_y, probe.predict(flat_x))
    df_list.append({'Probe': probe_name, "Dataset": "Augmented Dataset", **common, "Accuracy": acc})
    aug.append(acc)

    # Probe trained on the counterfactually augmented set, evaluated on all dev vectors.
    baselines.append({"Name": causal_name, **common, "Accuracy": accuracy_score(flat_y, causal_probe.predict(flat_x))})

  return orig, aug

In [None]:
# NOTE(review): dev_embeddings / dev_labels are not defined in any cell shown here.
# TODO confirm which held-out split they come from, and that 'GPT-2'/'CLS' match
# the model and representation actually stored in `embeddings`.
trained = train_models(embeddings, labels)
x, y, x_counter, y_counter, clf, clf_counter, clf_dummy, clf_svm, clf_svm_counter = trained
o_a, a_a = log_train_test_report(df_list, baselines, *trained, embeddings, labels,
                                 dev_embeddings, dev_labels, 'GPT-2', 'All', 'CLS')
print(df_list)
print(baselines)

In [None]:
# Probe accuracies recorded by log_train_test_report, one row per probe/dataset.
df_df = pd.DataFrame(df_list)
df_df

In [None]:
# Majority and causal-probe baseline accuracies.
df_base = pd.DataFrame(baselines)
df_base

In [None]:
import pickle

# Persist both result lists, one pickle file per list, keyed by the current task.
for _stem, _records in (('df-list', df_list), ('baselines', baselines)):
    with open('{}-{}.pickle'.format(_stem, task), 'wb') as handle:
        pickle.dump(_records, handle, protocol=pickle.HIGHEST_PROTOCOL)


In [None]:
# Reload the result lists written by the dump cell above.
# NOTE: pickle.load can execute arbitrary code; only load files this notebook wrote.
def _read_pickle(path):
    with open(path, 'rb') as fh:
        return pickle.load(fh)


df_list = _read_pickle('df-list-{}.pickle'.format(task))
baselines = _read_pickle('baselines-{}.pickle'.format(task))

In [None]:
# Probe accuracies split by evaluation dataset, in the row order of df_df.
_is_original = df_df["Dataset"] == "Original Dataset"
_is_augmented = df_df["Dataset"] == "Augmented Dataset"
orig_mix = list(df_df.loc[_is_original, "Accuracy"].values)
aug_mix = list(df_df.loc[_is_augmented, "Accuracy"].values)

In [None]:
# Baseline accuracies, excluding rows named "Majority".
causal = list(df_base.loc[df_base["Name"] != "Majority", "Accuracy"].values)

## Plot probe accuracies

In [None]:
def plot_probe_acc(orig_focus, aug_focus, orig, aug, output_file):
    """Plot probe accuracies per (model, probe-representation) and save the figure.

    Three marker traces share one two-level categorical y-axis (outer level:
    model, inner level: probe/representation). They show, in order:
    original-dataset accuracy, augmented-dataset accuracy and the causal
    baseline. Thick horizontal connectors link original->augmented and
    causal->original values at each row position. The figure is written to
    `output_file` and then shown.

    NOTE(review): `orig_focus`, `aug_focus`, `orig` and `aug` are never read;
    the plotted values come from the module-level `orig_mix`, `aug_mix` and
    `causal` lists.
    """
    import plotly.graph_objects as go

    models = ['mBERT', 'XLM-RoBERTa', 'GPT-2']
    probes = ["LogRegProbe-Focus", "SVMProbe-Focus", "LogRegProbe-CLS", "SVMProbe-CLS"]
    # Multicategory axis: each model label repeated once per probe, probes cycled per model.
    y_levels = [
        [model for model in models for _ in probes],
        [probe for _ in models for probe in probes],
    ]

    fig = go.Figure()
    for x_values, colour, label in (
        (orig_mix, fem_color, "Original Dataset"),
        (aug_mix, masc_color, "Augmented Dataset"),
        (causal, "grey", "Causal"),
    ):
        fig.add_trace(go.Scatter(
            y=[list(level) for level in y_levels],
            x=x_values,
            marker_size=15,
            mode='markers',
            marker_color=colour,
            name=label,
        ))

    def _connect(row, x_start, x_end, colour):
        # Horizontal bar drawn beneath the markers at numeric row position `row`.
        fig.add_shape(
            type="line",
            y0=row, x0=x_start,
            y1=row, x1=x_end,
            layer="below",
            line=dict(color=colour, width=10),
        )

    for row, start in enumerate(orig_mix):
        _connect(row, start, aug_mix[row], "silver")
    for row, start in enumerate(causal):
        _connect(row, start, orig_mix[row], "gray")

    fig.update_xaxes(title_text="Accuracy")
    fig.update_layout(plot_layout_dict)
    fig.update_layout({"width": 1700, "height": 1000})
    fig.write_image(output_file)
    fig.show()


In [None]:
# NOTE(review): plot_probe_acc does not read its first four arguments; it plots the
# module-level orig_mix / aug_mix / causal lists.
plot_probe_acc(orig_all_focus, aug_all_focus, orig_all, aug_all, f"probes-{task}.pdf")

# Causal Effect on Token Predictions [6.2 and 6.3 in the paper]

Initialize the model

In [None]:
# Load XLM-RoBERTa base with its masked-LM head. Hidden states are requested because
# the intervention below operates on the last hidden layer.
_checkpoint = "xlm-roberta-base"
model = AutoModelForMaskedLM.from_pretrained(_checkpoint, output_hidden_states=True).to(device)
tokenizer = AutoTokenizer.from_pretrained(_checkpoint)

In [None]:
# Inspect parameter names (used to locate the lm_head weights in the next cell).
list(model.state_dict())

In [None]:
# Extract the masked-LM head parameters so the head can be re-applied by hand to
# intervened hidden states (see intervene_with_ate).
_state = model.state_dict()

wt1, bt1 = _state['lm_head.dense.weight'], _state['lm_head.dense.bias']
wt2, bt2 = _state['lm_head.layer_norm.weight'], _state['lm_head.layer_norm.bias']
B = _state['lm_head.bias']
w, b = _state['lm_head.decoder.weight'], _state['lm_head.decoder.bias']


# Alternative key names for a model whose MLM head is exposed as `cls.predictions.*`.
# The original note said "Uncomment if GPT-2", but these names look like a BERT-style
# MLM head (e.g. mBERT) -- NOTE(review): confirm which model they belong to.
# wt1 = model.state_dict()['cls.predictions.transform.dense.weight']
# bt1 = model.state_dict()['cls.predictions.transform.dense.bias']

# wt2 = model.state_dict()['cls.predictions.transform.LayerNorm.weight']
# bt2 = model.state_dict()['cls.predictions.transform.LayerNorm.bias']

# B = model.state_dict()['cls.predictions.bias']

# w = model.state_dict()['cls.predictions.decoder.weight']
# b = model.state_dict()['cls.predictions.decoder.bias']

In [None]:
# Display the configuration of the loaded model.
model.config

In [None]:
def KL(a, b):
    """Return the Kullback-Leibler divergence KL(a || b) of two discrete distributions.

    Both inputs are converted to float64 arrays. Entries where `a` is zero
    contribute 0 (the 0 * log 0 = 0 convention; NumPy may still emit a
    RuntimeWarning for them). A zero in `b` where `a` is non-zero yields inf.
    """
    # `np.float` was only an alias of the builtin `float` and was removed in
    # NumPy 1.24 (AttributeError); `np.float64` is the same dtype.
    a = np.asarray(a, dtype=np.float64)
    b = np.asarray(b, dtype=np.float64)

    return np.sum(np.where(a != 0, a * np.log(a / b), 0))


def JS(a, b):
    """Return the symmetrised KL divergence 0.5 * (KL(a || b) + KL(b || a)).

    NOTE: despite its name this is not the Jensen-Shannon divergence, which
    compares each distribution with the mixture 0.5 * (a + b). The formula is
    kept as-is so reported numbers stay comparable.
    """
    return 0.5*KL(a, b) + 0.5*KL(b, a)

In [None]:
def intervene_with_ate(hidden_states, ate, start_idx, gender):
    """Shift one last-layer hidden vector by an ATE vector and decode it to vocabulary logits.

    The vector used is `hidden_states[-1][1, start_idx, :]`: batch entry 1 at token
    position `start_idx` (assumes hidden_states[-1] is (batch, seq, hidden) -- TODO confirm).
    The ATE is subtracted when `gender == label_1` (module-level) and added otherwise.
    The shifted vector then goes through the MLM head rebuilt from the module-level
    weights: dense (wt1, bt1) -> GELU -> layer norm (wt2, bt2, eps=1e-05) -> decoder (w, b).
    """
    token_state = hidden_states[-1][1, start_idx, :]
    shifted = token_state - ate if gender == label_1 else token_state + ate

    projected = F.gelu(torch.matmul(wt1, shifted) + bt1)
    normalized = F.layer_norm(projected, projected.shape, wt2, bt2, eps=1e-05)
    return torch.matmul(w, normalized) + b

    # Alternative head for the commented-out `cls.predictions.*` weights (originally
    # labelled "Uncomment if GPT-2"; NOTE(review): confirm which model it targets):
    #     h_transformed = torch.matmul(h, wt1.T) + bt1
    #     h_transformed = F.gelu(h_transformed)
    #     h_transformed = F.layer_norm(h_transformed, (h_transformed.shape[1],), wt2, bt2, eps=1e-5)
    #     h_transformed = torch.matmul(h_transformed, w.T) + b

In [None]:
def get_token_preds(h_transformed, h_transformed_naive):
    """Decode the highest-scoring vocabulary id of two logit vectors.

    Returns a pair (after_ate_pred, after_ate_naive_pred). Each element is the list
    produced by the module-level `tokenizer.convert_ids_to_tokens` for a single id.
    """
    return tuple(
        tokenizer.convert_ids_to_tokens([logits.argmax()])
        for logits in (h_transformed, h_transformed_naive)
    )

In [None]:
def compute_kls(probs_before, probs_after, ate_prob, naive_ate_prob):
    """Compare predicted distributions with the module-level `JS` divergence.

    Each tensor argument is moved to CPU and converted to NumPy before comparison.

    Returns a 3-tuple: (before vs. after, after vs. naive ATE, after vs. ATE).
    """
    before_np, after_np, ate_np, naive_np = (
        t.cpu().numpy() for t in (probs_before, probs_after, ate_prob, naive_ate_prob)
    )
    return (
        JS(before_np, after_np),
        JS(after_np, naive_np),
        JS(after_np, ate_np),
    )

***Define the list of stereotypical adjectives***

In [None]:
# Adjective pairs (masculine form, feminine form) written as tokenizer tokens and later
# mapped to ids with tokenizer.convert_tokens_to_ids. The leading "▁" presumably marks a
# word-initial sentencepiece token. Gender-invariant adjectives repeat the same form.
# Adjectives stereotypically associated with men.
MASC_ADJS = [("▁inteligente", "▁inteligente"), ("▁profesional", "▁profesional"), 
             ("▁independiente", "▁independiente"), ("▁racional", "▁racional"),
             ("▁rico", "▁rica"), ("▁rápido", "▁rápida"),
             ("▁brutal", "▁brutal"), ("▁duro", "▁dura"),
             ("▁fuerte", "▁fuerte"), ("▁serio", "▁seria")]

# Adjectives stereotypically associated with women.
# NOTE(review): "▁molest" looks like a truncation of "▁molesto"; confirm it is the
# intended vocabulary token (an unknown token would map to the unk id).
FEM_ADJS = [("▁bonito", "▁bonita"), ("▁amable","▁amable"),
            ("▁sensible","▁sensible"), ("▁hermoso", "▁hermosa"), 
            ("▁delicado", "▁delicada"), ("▁molest", "▁molesta"),
             ("▁protegido", "▁protegida"), ("▁sexy", "▁sexy"),
             ("▁emocional", "▁emocional"), ("▁alegre", "▁alegre")]

# Control adjectives without a stereotype association.
NEUT_ADJS = [("▁bueno", "▁buena"), ("▁malo", "▁mala"),
            ("▁triste", "▁triste"), ("▁tranquilo", "▁tranquila"),
            ("▁divertido", "▁divertida"), ("▁joven", "▁joven"),
            ("▁excelente", "▁excelente"), ("▁fantástico", "▁fantástica"),
            ("▁horrible", "▁horrible"), ("▁nuevo", "▁nueva")]

# 30 pairs in total; per-adjective results (delta_probs, ALL_ADJS_IDX) are indexed by position here.
ALL_ADJS = MASC_ADJS + FEM_ADJS + NEUT_ADJS

In [None]:
# Vocabulary ids for every (masculine, feminine) pair in ALL_ADJS. Each element is the
# one-element id list returned by convert_tokens_to_ids, so indexing a probability array
# with it yields a length-1 array (hence the trailing [0] in calc_adj_bias).
ALL_ADJS_IDX = [
    (tokenizer.convert_tokens_to_ids([masc]), tokenizer.convert_tokens_to_ids([fem]))
    for masc, fem in ALL_ADJS
]

In [None]:
def calc_adj_bias(prob_before, prob_after, gender, before_adj, after_adj, adj):
    """Append, for every adjective pair, a probability-mass ratio to the global `delta_probs`.

    For pair k in ALL_ADJS_IDX, the summed probability of both forms is taken from
    `prob_before` and from `prob_after`. When `gender == MASCULINE` the ratio is
    before / after; otherwise it is after / before. The ratio is appended to
    `delta_probs[k]`. Pairs whose |log ratio| exceeds 4 are printed together with
    their probabilities, the original adjective `adj` and the two masked sentences.
    """
    for pos, (masc_ids, fem_ids) in enumerate(ALL_ADJS_IDX):
        mass_before = prob_before[masc_ids] + prob_before[fem_ids]
        mass_after = prob_after[masc_ids] + prob_after[fem_ids]
        if gender == MASCULINE:
            ratio = (mass_before / mass_after)[0]
        else:
            ratio = (mass_after / mass_before)[0]

        if abs(np.log(ratio)) > 4:
            masc_form, fem_form = ALL_ADJS[pos]
            print("diff for {}, {} is {}, orig adj {}".format(masc_form, fem_form, ratio, adj))
            print("prob masc-before {} fem-before {} masc-after {} fem-after {}".format(
                prob_before[masc_ids], prob_before[fem_ids],
                prob_after[masc_ids], prob_after[fem_ids]))
            print(before_adj)
            print(after_adj)
            print("=" * 100)
        delta_probs[pos].append(ratio)

Loop over every pair of sentences in the dataset. For each pair, compute the masked-token predictions at the adjective position in both sentences, and the predictions obtained after intervening on the hidden representation with the ATE vectors.

Prerequisite: `ate_vectors` must already contain the paired estimators for adjectives, determiners and nouns.

In [None]:
import math


# For each sentence pair: mask the adjective in "before" and the aligned word in
# "after", run the MLM once on [before, masked_before, after, masked_after], then
# compare the masked predictions with those obtained by shifting the masked_before
# hidden state by the ATE vectors (intervene_with_ate).
predss = []
klss = []
prob_dists = []  # not filled in this cell
more_than_one_sub_word = 0  # not updated in this cell
idx_column_name = "adj idx"
# One list of probability ratios per adjective pair in ALL_ADJS (filled by calc_adj_bias).
delta_probs = [[] for _ in range(len(ALL_ADJS))]

# The intervention vectors are the same for every row: convert them once, not per row.
# Row 1 is used as the ATE and row 4 as the "naive" estimate.
ate_tensor = torch.tensor(ate_vectors[1, :]).float().to(device)
naive_ate_tensor = torch.tensor(ate_vectors[4, :]).float().to(device)


for idx, row in tq(df_target.iterrows(), total=len(df_target)):
    idx_before = row[idx_column_name]
    if idx_before == -1:  # -1 marks rows without a usable adjective index
        continue

    before_row = row["before"].split(" ")
    after_row = row["after"].split(" ")
    # Word index of the aligned word in the "after" sentence (adjust_index is defined elsewhere).
    idx_after = adjust_index(row["focus ID"], row["focus ID after"], idx_before)

    masked_before = " ".join([s if i != idx_before else "<mask>" for i, s in enumerate(before_row)])
    masked_after = " ".join([s if i != idx_after else "<mask>" for i, s in enumerate(after_row)])

    inputs = tokenizer([row["before"], masked_before, row["after"], masked_after], return_tensors="pt", padding=True).to(device)
    # First sub-token of the masked word in batch entry 1 (masked_before) and 3 (masked_after).
    start_idx = inputs.word_to_tokens(1, idx_before)[0]
    start_idx_after = inputs.word_to_tokens(3, idx_after)[0]

    with torch.no_grad():
        outputs = model(**inputs)

    h_transformed = intervene_with_ate(outputs.hidden_states, ate_tensor, start_idx, row[before])
    h_transformed_naive = intervene_with_ate(outputs.hidden_states, naive_ate_tensor, start_idx, row[before])

    preds_before = outputs['logits'][1, start_idx, :].argmax()
    # The mask in masked_after sits at start_idx_after (as used for probs_after below);
    # reading start_idx here took the prediction from the wrong position.
    preds_after = outputs['logits'][3, start_idx_after, :].argmax()

    # Order: before, after, naive ATE, ATE, then the gold word. [1:] drops the first
    # character of each predicted token (presumably the "▁" word-start marker).
    pred_tokens = [
        tokenizer.convert_ids_to_tokens([token_id])[0][1:]
        for token_id in (preds_before, preds_after, h_transformed_naive.argmax(), h_transformed.argmax())
    ]
    pred_tokens.append(after_row[idx_after])

    probs_before = F.softmax(outputs.logits[1, start_idx, :], dim=-1)
    probs_after = F.softmax(outputs.logits[3, start_idx_after, :], dim=-1)

    ate_prob = F.softmax(h_transformed, dim=-1)
    naive_ate_prob = F.softmax(h_transformed_naive, dim=-1)

    adj = row["before"].split()[idx_before]
    calc_adj_bias(probs_before.cpu().numpy(), probs_after.cpu().numpy(), row[before], masked_before, masked_after, adj)

    kls = compute_kls(probs_before, probs_after, ate_prob, naive_ate_prob)
    print(kls)
    klss.append(kls)
    pred_tuple = tuple(pred_tokens)
    print(pred_tuple)
    predss.append(pred_tuple)

In [None]:
# Overlayed histograms of the per-sentence log ratios for every adjective pair.
# Also records each pair's mean log ratio in adj_bias (key: display label).
fig = go.Figure()
adj_bias = {}
for i, pair in enumerate(ALL_ADJS):
    masc_form, fem_form = pair
    print(pair)
    log_ratios = np.log(delta_probs[i])
    val = np.mean(log_ratios)
    print(val)
    # Drop the leading marker character; show both forms only when they differ.
    adj_str = masc_form[1:] if masc_form == fem_form else masc_form[1:] + "<br>" + fem_form[1:]
    adj_bias[adj_str] = val
    print("===")
    fig.add_trace(go.Histogram(x=log_ratios, histnorm='probability', name=masc_form))
fig.update_layout(barmode='overlay')
fig.update_traces(opacity=0.75)

fig.show()

In [None]:
# Adjective labels ordered by ascending mean log ratio.
sorted_adjs = dict(sorted(adj_bias.items(), key=lambda kv: kv[1]))

In [None]:
# Labels for the strip plot: blank out values in the crowded band (-0.5, 0.1).
vals = list(sorted_adjs.values())
keys = list(sorted_adjs.keys())
adj_texts = ['' if -0.5 < v < 0.1 else k for k, v in zip(keys, vals)]

In [None]:
# Strip plot: each adjective pair's mean log ratio as a marker on one horizontal line,
# coloured by value and labelled with the (thinned) adj_texts.
fig = go.Figure()


fig.add_trace(go.Scatter(
    x=vals,
    y=[0] * len(vals),
    # 'markers+text' so adj_texts is actually drawn; with mode='markers' the
    # textposition / textfont settings and the label thinning had no effect.
    mode='markers+text',
    textposition="top center",
    name=label_1,
    hovertemplate='<b>%{text}</b>',
    text=adj_texts,
    textfont=dict(size=12),
    marker_size=15,
    marker_color=vals,
))

fig.update_xaxes(showgrid=False)
fig.update_yaxes(showgrid=False,
                 zeroline=True, zerolinecolor='black', zerolinewidth=3,
                 showticklabels=False)
fig.update_layout(plot_layout_dict)
fig.update_layout(height=200)
fig.update_layout(width=1200)

fig.show()
fig.write_image("adj-bias.pdf")

## KLs

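Let's compare the divergence between the counterfactual prediction obtained with the ATE intervention and the gold counterfactual obtained from the data (also for the naive ATE, and between the original and gold counterfactual predictions). The divergence used is the symmetrized KL computed by the `JS` helper; we report its mean and standard deviation.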

In [None]:
# Split the per-sentence triples from compute_kls into three series:
# before vs. after, after vs. naive ATE, after vs. ATE.
before_after, after_after_naive, after_after_ate = (
    [triple[j] for triple in klss] for j in range(3)
)

In [None]:
# Mean divergence for each series (before/after, naive ATE, ATE).
for _series in (before_after, after_after_naive, after_after_ate):
    print(np.mean(_series))

In [None]:
# Standard deviation of the divergence for each series (before/after, naive ATE, ATE).
for _series in (before_after, after_after_naive, after_after_ate):
    print(np.std(_series))