In [2]:
# Importing Libraries
from collections import defaultdict
from pathlib import Path
import numpy as np
import pandas as pd
from collections import namedtuple

from adversarial_debiasing import AdversarialDebiasing
from load_data import load_data, transform_data, Datapoint

from load_vectors import load_pretrained_vectors, load_vectors
import config
import utility_functions

In [3]:
# For autoreloading changes made in other python scripts
%load_ext autoreload
%autoreload 2

In [4]:
# Loading the word vectors dictionary
word_vectors = load_pretrained_vectors(config.wiki_embedding_data_path, config.save_dir, config.wiki_save_file, \
                                       config.use_glove)

Loading from saved file.


In [5]:
# Testing the word vectors dictionary
temp = word_vectors['john', 'mary', 'hello']
print(temp.shape)
temp[0]

(3, 100)


array([-0.0104,  0.2052, -0.0433,  0.3336, -0.4085, -0.1087,  0.1706,
        0.2941, -0.1128, -0.1487,  0.0042,  0.1413,  0.0191, -0.63  ,
       -0.0495, -0.0478,  0.2947, -0.0873, -0.0122,  0.2309,  0.2741,
       -0.3123,  0.0905, -0.0017, -0.4458, -0.0592,  0.5358, -0.1962,
        0.2812,  0.4794,  0.1337,  0.1756, -0.1852, -0.0167,  0.2623,
        0.0054, -0.0672,  0.3179,  0.27  , -0.0625, -0.2528, -0.221 ,
       -0.3646,  0.3363, -0.4909,  0.2077, -0.0641,  0.5532, -0.2084,
        0.4772,  0.3867,  0.0567, -0.1839, -0.0369, -0.3371, -0.5641,
       -0.1561, -0.3864, -0.0726, -0.3733,  0.2073, -0.0353, -0.0133,
        0.386 ,  0.9638, -0.049 , -0.1288, -0.1577,  0.0498, -0.1378,
        0.3356,  0.2502,  0.2932,  0.4169, -1.0372,  0.1025,  0.0938,
       -0.4552,  0.1343, -0.1765,  0.0039, -0.51  , -0.3901,  0.0278,
       -0.0169,  0.6304,  0.2066, -0.0479, -0.2074,  0.2342, -0.2059,
       -0.2411, -0.0708, -0.1348,  0.2447, -0.1961, -0.3007,  0.2797,
       -0.3301,  0.1

In [6]:
# Load the google analogies training dataset:
analogy_dataset = load_data()
analogy_dataset

[Raw_Datapoint(x1='Athens', x2='Greece', x3='Baghdad', y='Iraq', task='capital-common-countries'),
 Raw_Datapoint(x1='Athens', x2='Greece', x3='Bangkok', y='Thailand', task='capital-common-countries'),
 Raw_Datapoint(x1='Athens', x2='Greece', x3='Beijing', y='China', task='capital-common-countries'),
 Raw_Datapoint(x1='Athens', x2='Greece', x3='Berlin', y='Germany', task='capital-common-countries'),
 Raw_Datapoint(x1='Athens', x2='Greece', x3='Bern', y='Switzerland', task='capital-common-countries'),
 Raw_Datapoint(x1='Athens', x2='Greece', x3='Cairo', y='Egypt', task='capital-common-countries'),
 Raw_Datapoint(x1='Athens', x2='Greece', x3='Canberra', y='Australia', task='capital-common-countries'),
 Raw_Datapoint(x1='Athens', x2='Greece', x3='Hanoi', y='Vietnam', task='capital-common-countries'),
 Raw_Datapoint(x1='Athens', x2='Greece', x3='Havana', y='Cuba', task='capital-common-countries'),
 Raw_Datapoint(x1='Athens', x2='Greece', x3='Helsinki', y='Finland', task='capital-common-cou

In [7]:
x = utility_functions.obtain_gender_pairs(word_vectors)
print(len(x[0][0]))
for y in x:
    print(y)

100
[[-0.02879999950528145, 0.7487999796867371, 0.38269999623298645, 0.03310000151395798, 0.1006999984383583, -0.20149999856948853, 0.49810001254081726, 0.09239999949932098, -0.1817999929189682, -0.21199999749660492, -0.20559999346733093, -0.13600000739097595, 0.1371999979019165, 0.08269999921321869, 0.07970000058412552, -0.2653999924659729, -0.00570000009611249, 0.038100000470876694, -0.05380000174045563, 0.2849000096321106, 0.006899999920278788, 0.21610000729560852, -0.21789999306201935, 0.09929999709129333, -0.1777999997138977, 0.06930000334978104, 0.04190000146627426, -0.33550000190734863, -0.28600001335144043, 0.271699994802475, -0.03500000014901161, -0.2037999927997589, -0.28949999809265137, 0.20409999787807465, 0.1526000052690506, -0.22930000722408295, -0.06289999932050705, 0.23659999668598175, 0.3723999857902527, -0.2816999852657318, -0.02879999950528145, -0.09950000047683716, -0.12020000070333481, 0.4316999912261963, -0.13050000369548798, -0.04230000078678131, -0.3483000099658

[[-0.04960000142455101, 0.5554999709129333, 0.2547999918460846, -0.16040000319480896, -0.10010000318288803, -0.19920000433921814, 0.2387000024318695, 0.11110000312328339, 0.14270000159740448, -0.243599995970726, 0.1898999959230423, -0.2093999981880188, -0.05480000004172325, -0.011900000274181366, 0.08449999988079071, -0.009100000374019146, -0.01759999990463257, -0.16609999537467957, -0.10930000245571136, 0.07769999653100967, 0.094200000166893, 0.13619999587535858, -0.40470001101493835, 0.19429999589920044, -0.4860999882221222, -0.11620000004768372, -0.04529999941587448, -0.30070000886917114, 0.19349999725818634, 0.19930000603199005, -0.011500000022351742, -0.2214999943971634, -0.195700004696846, -0.07970000058412552, -0.0034000000450760126, -0.08630000054836273, -0.36809998750686646, 0.03689999878406525, 0.18129999935626984, -0.44279998540878296, -0.0778999999165535, -0.21480000019073486, -0.4505000114440918, -0.1559000015258789, -0.0778999999165535, 0.34459999203681946, -0.24750000238

[[-0.23770000040531158, 0.2685999870300293, -0.09619999676942825, 0.27070000767707825, -0.2240999937057495, -0.24889999628067017, 0.10649999976158142, 0.041200000792741776, -0.5349000096321106, -0.1445000022649765, -0.08699999749660492, -0.18770000338554382, 0.19850000739097595, -0.16429999470710754, 0.10209999978542328, -0.17829999327659607, -0.0551999993622303, 0.021900000050663948, -0.21799999475479126, 0.15690000355243683, -0.28349998593330383, -0.32989999651908875, -0.06780000030994415, 0.3504999876022339, -0.32409998774528503, -0.0008999999845400453, -0.1234000027179718, -0.3452000021934509, -0.4523000121116638, 0.7448999881744385, 0.1469999998807907, -0.1257999986410141, -0.10729999840259552, 0.4018999934196472, 0.1120000034570694, 0.022299999371170998, -0.3720000088214874, 0.20260000228881836, 0.031599998474121094, 0.029100000858306885, -0.24060000479221344, 0.13680000603199005, -0.017500000074505806, 0.10199999809265137, 0.08340000361204147, 0.5012000203132629, -0.397300004959

[[-0.22439999878406525, 0.22050000727176666, -0.03790000081062317, 0.032999999821186066, 0.2087000012397766, -0.3749000132083893, -0.014999999664723873, 0.1388999968767166, -0.3224000036716461, -0.3005000054836273, 0.1882999986410141, -0.19040000438690186, 0.2752000093460083, -0.19280000030994415, 0.23199999332427979, -0.12280000001192093, -0.5169000029563904, -0.07760000228881836, -0.10379999876022339, 0.07109999656677246, -0.044599998742341995, -0.07289999723434448, 0.026200000196695328, 0.28769999742507935, -0.4146000146865845, 0.00279999990016222, -0.09319999814033508, -0.47540000081062317, -0.011900000274181366, 0.3732999861240387, 0.1687999963760376, -0.07090000063180923, -0.027799999341368675, 0.10170000046491623, 0.41940000653266907, -0.14399999380111694, -0.6105999946594238, 0.4611999988555908, -0.23520000278949738, -0.20829999446868896, -0.3490000069141388, -0.17749999463558197, -0.15639999508857727, 0.17020000517368317, -0.15629999339580536, -0.07240000367164612, -0.45640000

[[0.11169999837875366, -0.04340000078082085, 0.5677000284194946, 0.039400000125169754, -0.25110000371932983, 0.13950000703334808, 0.16920000314712524, 0.3163999915122986, -0.4999000132083893, 0.2996000051498413, -0.013799999840557575, -0.22789999842643738, 0.06930000334978104, -0.3197000026702881, -0.2540999948978424, -0.27489998936653137, -0.3864000141620636, -0.1923000067472458, 0.32030001282691956, 0.33719998598098755, -0.019899999722838402, 0.14669999480247498, -0.13740000128746033, -0.12710000574588776, -0.5350000262260437, -0.12290000170469284, 0.17159999907016754, -0.6608999967575073, 0.13439999520778656, 0.5932000279426575, -0.053599998354911804, -0.07530000060796738, -0.16060000658035278, 0.13079999387264252, 0.3921999931335449, 0.20100000500679016, -0.5278000235557556, 0.3431999981403351, 0.029899999499320984, -0.2587999999523163, -0.21870000660419464, -0.18729999661445618, -0.26840001344680786, 0.1639000028371811, 0.2409999966621399, 0.6403999924659729, -0.6424999833106995, 

[[0.1371999979019165, 0.25699999928474426, 0.47909998893737793, -0.07440000027418137, -0.10180000215768814, -0.07079999893903732, 0.0982000008225441, 0.4018999934196472, -0.4269999861717224, 0.17960000038146973, 0.08900000154972076, -0.1251000016927719, 0.13950000703334808, -0.13490000367164612, -0.024299999698996544, -0.1996999979019165, -0.29510000348091125, -0.10459999740123749, 0.2524000108242035, 0.35030001401901245, 0.1168999969959259, 0.18140000104904175, -0.15729999542236328, 0.06449999660253525, -0.4530999958515167, -0.10689999908208847, 0.07349999994039536, -0.6837999820709229, -0.08470000326633453, 0.5175999999046326, -0.10769999772310257, 0.010700000450015068, -0.17139999568462372, 0.1972000002861023, 0.2727999985218048, 0.22519999742507935, -0.23340000212192535, 0.35019999742507935, 0.12439999729394913, -0.37599998712539673, -0.07519999891519547, -0.027499999850988388, -0.4262000024318695, 0.020500000566244125, 0.1662999987602234, 0.5019000172615051, -0.5787000060081482, 0

[[0.14710000157356262, 0.46639999747276306, -0.16519999504089355, 0.34860000014305115, -0.16619999706745148, -0.14399999380111694, 0.0658000037074089, 0.2648000121116638, -0.17890000343322754, -0.21279999613761902, -0.13699999451637268, -0.04600000008940697, 0.2069000005722046, -0.3752000033855438, -0.12039999663829803, -0.32330000400543213, 0.09549999982118607, -0.4124999940395355, -0.28859999775886536, 0.34119999408721924, 0.016499999910593033, -0.17470000684261322, 0.03869999945163727, 0.5206999778747559, -0.26260000467300415, 0.22910000383853912, 0.27559998631477356, 0.1995999962091446, -0.26109999418258667, 0.5297999978065491, 0.5029000043869019, 0.11490000039339066, -0.03720000013709068, -0.014299999922513962, 0.053700000047683716, 0.12960000336170197, -0.24779999256134033, 0.35690000653266907, 0.013899999670684338, -0.2808000147342682, -0.2094999998807907, -0.011099999770522118, -0.1868000030517578, 0.38960000872612, 0.10080000013113022, -0.06030000001192093, 0.08889999985694885

[[-0.35010001063346863, 0.5210999846458435, 0.06069999933242798, -0.06419999897480011, -0.4433000087738037, 0.15449999272823334, -0.37049999833106995, 0.23960000276565552, -0.40369999408721924, -0.21699999272823334, -0.11110000312328339, -0.3427000045776367, 0.30469998717308044, -0.1298999935388565, 0.06239999830722809, -0.20069999992847443, 0.06790000200271606, -0.12530000507831573, -0.31150001287460327, 0.16769999265670776, -0.1712000072002411, -0.0142000000923872, -0.23409999907016754, 0.45179998874664307, 0.00930000003427267, 0.20239999890327454, -0.35510000586509705, -0.14309999346733093, -0.504800021648407, 0.7692000269889832, 0.08150000125169754, 0.03720000013709068, -0.3837999999523163, 0.7763000130653381, -0.1565999984741211, 0.034699998795986176, -0.38999998569488525, 0.2964000105857849, 0.37070000171661377, 0.11569999903440475, -0.17870000004768372, 0.12200000137090683, -0.47540000081062317, -0.23669999837875366, 0.11680000275373459, -0.10130000114440918, -0.8190000057220459

[[-0.5432999730110168, 0.3637000024318695, 0.18729999661445618, -0.9387000203132629, -0.18799999356269836, 0.40619999170303345, 0.03359999880194664, 0.06599999964237213, -0.08049999922513962, -0.19089999794960022, -0.32690000534057617, -0.210099995136261, 0.42399999499320984, -0.3280999958515167, -0.36169999837875366, -0.321399986743927, 0.3880000114440918, -0.4575999975204468, 0.5741999745368958, 0.009999999776482582, 0.11330000311136246, -0.07840000092983246, -0.2021999955177307, 0.39629998803138733, -0.4108000099658966, -0.8657000064849854, -0.09149999916553497, 0.26409998536109924, 0.039000000804662704, 0.5325000286102295, 0.742900013923645, 0.5013999938964844, -0.2542000114917755, 0.3871999979019165, 0.12409999966621399, 0.19050000607967377, -0.040300000458955765, -0.3813000023365021, -0.03959999978542328, -0.20270000398159027, -0.05999999865889549, 0.5722000002861023, -0.31029999256134033, -0.15459999442100525, -0.34779998660087585, -0.026599999517202377, 0.013899999670684338, -0

[[-0.010400000028312206, 0.20520000159740448, -0.043299999088048935, 0.3336000144481659, -0.40849998593330383, -0.10869999974966049, 0.17059999704360962, 0.29409998655319214, -0.1128000020980835, -0.14869999885559082, 0.00419999985024333, 0.1412999927997589, 0.019099999219179153, -0.6299999952316284, -0.0494999997317791, -0.04780000075697899, 0.2946999967098236, -0.08730000257492065, -0.012199999764561653, 0.23090000450611115, 0.27410000562667847, -0.3122999966144562, 0.09049999713897705, -0.0017000000225380063, -0.4458000063896179, -0.05920000001788139, 0.5357999801635742, -0.19619999825954437, 0.28119999170303345, 0.47940000891685486, 0.13369999825954437, 0.17560000717639923, -0.18520000576972961, -0.016699999570846558, 0.2623000144958496, 0.005400000140070915, -0.06719999760389328, 0.31790000200271606, 0.27000001072883606, -0.0625, -0.25279998779296875, -0.22100000083446503, -0.3646000027656555, 0.33629998564720154, -0.4909000098705292, 0.2076999992132187, -0.0640999972820282, 0.553

In [11]:
y = utility_functions.obtain_gender_subspace(x)
print(y)

[[ 0.06708962  0.00850507  0.0044131  -0.14337213 -0.01198367 -0.08900737
   0.09184579  0.10018828  0.05593755  0.04945152 -0.11566466 -0.03031982
   0.10042564  0.13032164 -0.016117   -0.08038898 -0.14472522  0.14728229
   0.02381136  0.26320039  0.20735399 -0.03331795 -0.1025313   0.14477854
   0.09007349 -0.16344649  0.00850054  0.10354929 -0.04904761 -0.05710212
   0.02936866  0.18065918  0.05961868  0.03630323  0.0970019   0.09347218
  -0.00974479 -0.23595012 -0.00717322 -0.0838382  -0.02938197  0.02917068
  -0.13478935 -0.11199915 -0.01242035 -0.0355581  -0.09277333  0.05909432
   0.06699538 -0.13369367 -0.15025215  0.02223417  0.08594777 -0.08936922
   0.10508251  0.25848451 -0.0108665   0.04493326  0.01115858  0.09997015
   0.07796242 -0.04719058  0.05199704  0.07475872 -0.02360925 -0.11951299
  -0.04243528 -0.14035094 -0.00608627 -0.00966634  0.05627689 -0.02892361
  -0.15278741 -0.1024916   0.1072672   0.02598628  0.08407235  0.05773501
  -0.02055145  0.14775123 -0.06889477 

In [7]:
# Transform the data such that it includes the embeddings
analogy_dataset = transform_data(word_vectors, analogy_dataset)
analogy_dataset

TypeError: unsupported operand type(s) for +: 'int' and 'list'

In [None]:
# Now we fit a dataset.

embedding_dim = 100
analogy_dataset = [
    Datapoint(
    analogy_embeddings=np.random.normal(0, 1, size=(3 * embedding_dim, 1)), 
    gt_embedding=np.random.normal(0, 1, size=(embedding_dim, 1)),
    protected_embedding=np.random.uniform(0, 1, size=(1))) for n in range(0, 1000)
]



model = AdversarialDebiasing()
model.fit(dataset=analogy_dataset)
