In [1]:
import os, sys
sys.path.insert(0, '..')

In [2]:
import gensim
import gc
from os.path import join as j
import json
from tqdm import tqdm, trange
from models import glove, custom_trained_model, word2vec
from utils.dataset import PandasDataset
from datasets.nyt import Nyt
from utils.weat import WEAT
from sklearn.decomposition import PCA
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from sklearn.neighbors import KNeighborsClassifier
import seaborn as sns, numpy as np, pandas as pd, random
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import warnings
# NOTE(review): this silences *every* warning for the whole kernel session,
# including deprecation and convergence warnings raised by the plotting and
# modelling libraries used below. Consider narrowing it to specific categories.
warnings.filterwarnings('ignore')


# Automatic garbage collection is on by default in CPython, so this call only
# has an effect if collection was disabled earlier in the same kernel.
gc.enable()

In [3]:
def get_bar_plot(y, pred):
    """Count correct and incorrect predictions per class for a grouped bar plot.

    Parameters
    ----------
    y : sequence
        True labels. Any sortable label values are accepted; they do not
        have to be the contiguous integers ``0..k-1``.
    pred : sequence
        Predicted labels, aligned index-for-index with ``y``.

    Returns
    -------
    pandas.DataFrame
        Long-format frame with ``2 * k`` rows (``k`` = number of distinct
        labels in ``y``) and three columns:

        * ``x``   - the class label (each label appears twice),
        * ``y``   - number of samples of that class,
        * ``hue`` - ``True`` for correctly predicted samples (first ``k``
          rows), ``False`` for misclassified ones (last ``k`` rows).
    """
    # `positions[i]` is the index of y[i] inside the sorted unique labels.
    # Indexing the counters by position instead of by the raw label value
    # keeps the function correct when labels are not exactly 0..k-1.
    labels, positions = np.unique(y, return_inverse=True)
    classes = len(labels)
    # First `classes` slots count hits, last `classes` slots count misses.
    match = [0] * classes * 2
    for idx, pos in enumerate(positions):
        if y[idx] == pred[idx]:
            match[pos] += 1
        else:
            match[pos + classes] += 1
    return pd.DataFrame({'x': list(labels) * 2, 'y': match, 'hue': [True] * classes + [False] * classes})

In [4]:
def plot_(x, y, colors, z=None, title="year", three=False, scale=True):
    """Show a 2-D or 3-D scatter plot on a fresh 6x6 inch figure.

    Parameters
    ----------
    x, y : array-like
        Point coordinates on the first two axes.
    colors : array-like or list of str
        Passed to ``scatter`` as ``c`` (one entry per point).
    z : array-like, optional
        Third coordinate; only read when ``three`` is True.
    title : str
        Title placed on the plot.
    three : bool
        If True draw a 3-D scatter, otherwise a 2-D scatter.
    scale : bool
        If True, pin the axis limits to ``[-1, 1]`` (3-D) or ``[-2, 2]``
        (2-D) instead of letting matplotlib autoscale.

    Returns
    -------
    None
        The figure is displayed with ``plt.show()``.
    """
    # https://stackoverflow.com/a/60621783
    sns.set_style("whitegrid", {'axes.grid' : False})
    fig = plt.figure(figsize=(6,6))
    if three:
        # Constructing Axes3D(fig) directly no longer attaches the axes to the
        # figure on current matplotlib, which yields an empty plot. Asking the
        # figure for a 3-D subplot works on old and new versions alike.
        ax = fig.add_subplot(111, projection='3d')
        g = ax.scatter(x, y, z, c=colors, marker='o', depthshade=False, cmap='Paired')
        if scale:
            ax.set_zlim(-1, 1)
            ax.set_xlim(-1, 1)
            ax.set_ylim(-1, 1)
        ax.set_zlabel('Z Label')
        ax.set_xlabel('X Label')
        ax.set_ylabel('Y Label')
        ax.set_title(title)
        # NOTE(review): legend_elements() derives entries from colour-mapped
        # values; with colour-name strings in `colors` the legend is
        # presumably empty - confirm if a legend is expected here.
        legend = ax.legend(*g.legend_elements(), loc="lower center", borderaxespad=-10, ncol=4)
        ax.add_artist(legend)
    else:
        # Draw on an explicit axes rather than the implicit pyplot state.
        ax = fig.add_subplot(111)
        ax.scatter(x, y, c=colors)
        if scale:
            ax.set_xlim(-2, 2)
            ax.set_ylim(-2, 2)
        ax.set_title(title)
    plt.show()

In [5]:
def plot_model(model, model_title, man_words, woman_words, occupations=None):
    """Visualise how an embedding model separates gendered word groups.

    Produces, in order: a 3-D PCA scatter, a 2-D PCA scatter, an LDA scatter
    (2-D when ``occupations`` is non-empty, otherwise the single LDA
    component plotted against the class label) and a bar plot of 3-nearest-
    neighbour hits/misses per class.

    Parameters
    ----------
    model : object
        Must expose ``transform(list_of_words)``; its result is fed to
        scikit-learn estimators, so presumably one vector per word - confirm
        against the model classes. Also passed to ``WEAT``.
    model_title : str
        Prefix used for every plot title.
    man_words, woman_words : list of str
        Word groups labelled 0 and 1 (plotted blue and pink). They are
        concatenated with ``+``, so they must be lists.
    occupations : list of str, optional
        Optional third group labelled 2 (plotted green). Defaults to no
        occupations.

    Returns
    -------
    The value of ``WEAT(model, words_json='../weat/weat.json').get_scores()``.
    """
    # None sentinel instead of a shared mutable default list.
    occupations = [] if occupations is None else occupations

    vecs = model.transform(man_words + woman_words + occupations)
    colors = ['blue'] * len(man_words) + ['pink'] * len(woman_words) + ['green'] * len(occupations)

    # Fit PCA once; the 3-D and 2-D views are just the leading columns of the
    # same projection, so a second identical fit would be wasted work.
    projected = PCA(random_state=0).fit_transform(vecs)

    # plot 3D graph using PCA
    plot_(projected[:, 0], projected[:, 1], z=projected[:, 2], title=model_title + "_3D_PCA",
          three=True, scale=False, colors=colors)

    # plot 2 D plot using PCA
    plot_(projected[:, 0], projected[:, 1], title=model_title + "_2D_PCA", three=False, scale=False, colors=colors)

    # plot LDA
    # Class labels: 0 = man words, 1 = woman words, 2 = occupations.
    y = np.concatenate([np.zeros(shape=len(man_words), dtype=int), np.ones(dtype=int, shape=len(woman_words)),
                        np.full(dtype=int, shape=len(occupations), fill_value=2)])
    lda = LinearDiscriminantAnalysis().fit_transform(vecs, y)

    # NOTE(review): the classifier is scored on the same vectors it was
    # fitted on, so each point is among its own neighbours.
    knn = KNeighborsClassifier(n_neighbors=3)
    knn.fit(vecs, y)
    pred = knn.predict(vecs)
    df = get_bar_plot(y=y, pred=pred)
    if len(occupations):
        plot_(lda[:, 0], lda[:, 1], title=model_title + "_2D_LDA", three=False, scale=False, colors=colors)
    else:
        # Without occupations only the first LDA column is used, so it is
        # plotted against the class label itself.
        plot_(lda[:, 0], y, title=model_title + "_1D_LDA", three=False, scale=False, colors=colors)
    sns.barplot(data=df, x='x', y='y', hue='hue', ).set_title(model_title + "_KNN classification")
    plt.show()
    return WEAT(model, words_json='../weat/weat.json').get_scores()

In [6]:
# Directory holding the per-decade CSV files (file names follow CSV below).
DIR = '/tmp/temp/'
# First year of each ten-year window: 1921, 1931, ..., 2011.
YEARS = range(1921, 2021, 10)
# Output location for the embeddings trained on one window; formatted with its start year.
EMBEDDINGS_DIR = '../trained_models/word2vec/embeddings_{}/'
# Keys of the word-list JSON loaded below.
MAN, WOMAN, OCCUPATIONS = 'Man words', 'Woman words', 'Occupations with Human Stereotype Scores'
# Use a context manager so the file handle is closed deterministically, and
# pin the encoding so loading does not depend on the machine's locale.
with open('../weat/GargWordList.json', encoding='utf-8') as word_file:
    words = json.load(word_file)
# 'femen' is dropped from the woman word list; the original author was unsure what it refers to.
words[WOMAN].remove('femen')
# File-name template for one window: st = first year, end = last year.
CSV = "df_{st}_to_{end}.csv"
# One row per window in YEARS; presumably the 7 columns hold the WEAT scores - TODO confirm.
SCORES = np.zeros(shape=(len(YEARS), 7))
# Path to the GoogleNews word2vec binary.
CUSTOM_MODEL_PATH = "../trained_models/GoogleNews-vectors-negative300.bin"

# WORD2VEC MODELS

In [7]:
# Train and save one Word2Vec model per ten-year window in YEARS.
# The recorded run below took about 1h24m for the 10 windows, so avoid
# re-running this cell unless the embeddings need to be regenerated.
for y in tqdm(YEARS):
    # os.path.join (imported as `j`) gives the same path as the previous
    # string concatenation, and keeps working if DIR loses its trailing slash.
    dataset = j(DIR, CSV.format(st=y, end=y + 9))
    saved_model_path = EMBEDDINGS_DIR.format(y)
    lines = Nyt(dataset).lines
    # `m` holds whatever .save() returns; it is not used in this cell but is
    # kept as a notebook global in case a later cell reads it.
    m = word2vec.Word2Vec(load=False).fit(lines).save(saved_model_path)
    

100%|████████████████████████████████████████████████████████████████████████████████████████████| 10/10 [1:23:42<00:00, 502.21s/it]
