In [None]:
# Importing Libraries
from collections import defaultdict
from pathlib import Path
import numpy as np
import pandas as pd
from collections import namedtuple
import matplotlib.pyplot as plt 
from adversarial_debiasing import AdversarialDebiasing
from load_data import load_data, transform_data, Datapoint

from load_vectors import load_pretrained_vectors, load_vectors
import config
import utility_functions

In [None]:
# Enable IPython's autoreload so edits to imported .py modules (e.g. load_data,
# adversarial_debiasing) are picked up without restarting the kernel.
# Mode 2 reloads all modules before executing each cell.
%load_ext autoreload
%autoreload 2

In [None]:
# Build the word -> vector lookup from the pretrained embedding files named in config.
# NOTE(review): config.use_glove presumably selects GloVe vs. the wiki vectors — confirm in load_vectors.
word_vectors = load_pretrained_vectors(
    config.wiki_embedding_data_path,
    config.save_dir,
    config.wiki_save_file,
    config.use_glove,
)

In [None]:
# Sanity-check the lookup: fetch the vectors for two known words and show the
# shape of the result. Uses a descriptive name instead of a throwaway `temp`
# global that would linger in the kernel namespace.
sample_word_vectors = word_vectors[['athens', 'greece']]
print(sample_word_vectors.shape)

In [None]:
# Load the Google analogies training dataset and preview every 10th entry
# among the first 100.
analogy_dataset = load_data()
analogy_dataset[:100:10]

In [None]:
# Attach word embeddings to every analogy by looking its words up in `word_vectors`.
transformed_analogy_dataset = transform_data(
    word_vectors,
    analogy_dataset,
)

In [None]:
# Sanity-check the transformed data: print the shape of each embedding field
# on the first datapoint.
for field_name in ('analogy_embeddings', 'gt_embedding', 'protected_embedding'):
    print(getattr(transformed_analogy_dataset[0], field_name).shape)

In [None]:
# Smoke-test AdversarialDebiasing.fit on synthetic random data.
# NOTE(review): this fits random vectors, NOT `transformed_analogy_dataset`
# from above — presumably a sanity check of the training loop; swap in the
# real data once that is confirmed.
# Stored under its own name so the real `analogy_dataset` loaded earlier is
# not clobbered (re-running the transform cell would otherwise silently
# operate on these random points).
embedding_dim = 100
n_synthetic_points = 1000
# Fixed seed so Restart & Run All reproduces the same synthetic data
# (the model's own internal randomness, if any, is not covered by this).
rng = np.random.default_rng(seed=0)

synthetic_dataset = [
    Datapoint(
        # Length 3 * embedding_dim: presumably the three analogy words'
        # embeddings concatenated — TODO confirm against transform_data.
        analogy_embeddings=rng.normal(0, 1, size=3 * embedding_dim),
        gt_embedding=rng.normal(0, 1, size=embedding_dim),
        # Previously hard-coded to 100; tied to embedding_dim so changing the
        # dimension keeps all three fields consistent.
        protected_embedding=rng.uniform(0, 1, size=embedding_dim),
    )
    for _ in range(n_synthetic_points)
]

model = AdversarialDebiasing()
model.fit(dataset=synthetic_dataset)


In [None]:
def plot_loss(loss_values, label, title):
    """Plot one recorded loss curve against its position in the loss history.

    Args:
        loss_values: sequence of loss values (one entry of ``model.losses``).
        label: legend label for the curve.
        title: figure title.

    Returns:
        The matplotlib Axes the curve was drawn on.
    """
    fig, ax = plt.subplots()
    steps = np.arange(len(loss_values))
    ax.plot(steps, loss_values, 'go--', label=label)
    # x is the index into the recorded loss list; whether that is a batch or
    # an epoch depends on how AdversarialDebiasing logs losses — TODO confirm.
    ax.set_xlabel('Step')
    ax.set_ylabel('Loss')
    ax.set_title(title)
    ax.legend()
    plt.show()
    return ax


# One figure per loss curve recorded during training; replaces two
# copy-pasted plotting snippets with a single helper.
losses = model.losses
for loss_key, label, title in [
    ('adversary', 'Adversary loss', 'Adversary Loss'),
    ('predictor', 'Predictor loss', 'Predictor Loss'),
]:
    plot_loss(losses[loss_key], label, title)
