In [72]:
import pathlib
from sklearn import neighbors
from sklearn import metrics
import string
import numpy as np
from tqdm.auto import tqdm

import spacy
nlp = spacy.load('en_core_web_lg')

from scipy import spatial

In [None]:
# Load every document found under <directory>/<label>/ and record its label.
# docs[i] holds the raw text of the i-th file, labels[i] its class name.
docs = []
labels = []

# Built from path segments so it resolves on POSIX as well as Windows
# (a backslash-separated string is a single file name outside Windows).
directory = pathlib.Path('..', 'data', 'genderData')
label_names = ['Man', 'Woman']

for label in label_names:
    # iterdir() yields entries in arbitrary order; sorting keeps docs[i]
    # stable across runs and machines.
    for file in sorted(directory.joinpath(label).iterdir()):
        if not file.is_file():
            continue  # read_text() would raise on a sub-directory
        labels.append(label)
        docs.append(file.read_text(encoding='utf-8'))

In [53]:
# Inspect the first sample: print its label, then display the raw text.
print(labels[0])
doc = docs[0]
doc

business


'Ad sales boost Time Warner profit\n\nQuarterly profits at US media giant TimeWarner jumped 76% to $1.13bn (£600m) for the three months to December, from $639m year-earlier.\n\nThe firm, which is now one of the biggest investors in Google, benefited from sales of high-speed internet connections and higher advert sales. TimeWarner said fourth quarter sales rose 2% to $11.1bn from $10.9bn. Its profits were buoyed by one-off gains which offset a profit dip at Warner Bros, and less users for AOL.\n\nTime Warner said on Friday that it now owns 8% of search-engine Google. But its own internet business, AOL, had has mixed fortunes. It lost 464,000 subscribers in the fourth quarter profits were lower than in the preceding three quarters. However, the company said AOL\'s underlying profit before exceptional items rose 8% on the back of stronger internet advertising revenues. It hopes to increase subscribers by offering the online service free to TimeWarner internet customers and will try to sig

In [54]:
def clean_text(text):
    """Normalise raw text into lower-case, space-separated words.

    Steps, in order: lower-case the text, delete every character listed in
    ``string.punctuation`` (deleted, not replaced, so ``year-earlier``
    becomes ``yearearlier``), turn newlines into spaces, and collapse any
    run of whitespace into a single space with no leading/trailing space.
    """
    drop_punctuation = str.maketrans(dict.fromkeys(string.punctuation))
    stripped = text.lower().translate(drop_punctuation)
    words = stripped.replace('\n', ' ').split()
    return ' '.join(words)

In [55]:
# Normalise the sample document with clean_text(). This rebinds `doc`, so the
# raw text is only available again through docs[0].
doc = clean_text(doc)
doc

'ad sales boost time warner profit quarterly profits at us media giant timewarner jumped 76 to 113bn £600m for the three months to december from 639m yearearlier the firm which is now one of the biggest investors in google benefited from sales of highspeed internet connections and higher advert sales timewarner said fourth quarter sales rose 2 to 111bn from 109bn its profits were buoyed by oneoff gains which offset a profit dip at warner bros and less users for aol time warner said on friday that it now owns 8 of searchengine google but its own internet business aol had has mixed fortunes it lost 464000 subscribers in the fourth quarter profits were lower than in the preceding three quarters however the company said aols underlying profit before exceptional items rose 8 on the back of stronger internet advertising revenues it hopes to increase subscribers by offering the online service free to timewarner internet customers and will try to sign up aols existing customers for highspeed b

In [56]:
def embed(tokens, nlp):
    """Average the word vectors of the usable tokens into one centroid.

    Each token is looked up in ``nlp.vocab``. A lexeme contributes its
    vector only if it has one, is not a stop word, and its text is longer
    than one character; everything else is ignored.

    Returns the element-wise mean of the kept vectors. When no token
    qualifies, a zero vector is returned whose length is read from
    ``nlp.meta['vectors']['width']``.

    """

    kept = []
    for token in tokens:
        lexeme = nlp.vocab[token]
        if not lexeme.has_vector:
            continue  # out-of-vocabulary
        if lexeme.is_stop or len(lexeme.text) <= 1:
            continue  # stop word or single-character token
        kept.append(lexeme.vector)

    if not kept:
        # Nothing usable: fall back to the origin (width is typically 300).
        return np.zeros(nlp.meta['vectors']['width'])

    return np.asarray(kept).mean(axis=0)

In [57]:
# Embed the cleaned sample document and show the centroid's shape and head.
tokens = doc.split(' ')
centroid = embed(tokens, nlp)
print(centroid.shape, centroid[:10], sep='\n')

(300,)
[-0.7090112  -0.47271743 -1.482876    1.0192236   2.6123836  -0.17018092
  0.8778576   3.4173107  -0.86535937 -1.4310907 ]


In [58]:
# Embed each label name through the same pipeline predict() applies to
# documents: clean_text() first (lower-casing, punctuation removal), then
# embed(). Without the cleaning step, capitalised names such as 'Man' are
# looked up under a different vocabulary key than the lower-cased document
# tokens they are compared against.
label_vectors = np.asarray([
    embed(clean_text(name).split(' '), nlp)
    for name in label_names
])
label_vectors.shape

(5, 300)

In [62]:
# Index the label embeddings with the default metric and look up the label
# whose vector is nearest to the sample document's centroid.
neigh = neighbors.NearestNeighbors(n_neighbors=1).fit(label_vectors)
nearest = neigh.kneighbors([centroid], return_distance=False)
closest_label = nearest[0, 0]
label_names[closest_label]

'entertainment'

In [59]:
def predict(doc, nlp, neigh, max_tokens=50):
    """Return the index of the label embedding nearest to ``doc``.

    The document is normalised with ``clean_text``, split on single spaces,
    truncated to its first ``max_tokens`` tokens and reduced to a single
    centroid by ``embed``. ``neigh`` is then queried for the nearest
    neighbour of that centroid.

    Parameters
    ----------
    doc : str
        Raw document text.
    nlp
        Forwarded unchanged to ``embed``.
    neigh
        Fitted object exposing ``kneighbors(X, return_distance=False)``.
    max_tokens : int or None, default 50
        Number of leading tokens used for the embedding. ``None`` keeps
        every token. The default reproduces the previous fixed cut-off.

    Returns
    -------
    The first neighbour index reported for the single query, i.e. a row
    index into the data ``neigh`` was fitted on.
    """
    tokens = clean_text(doc).split(' ')[:max_tokens]
    centroid = embed(tokens, nlp)
    # One query was submitted, so take the first entry of the first result row.
    return neigh.kneighbors([centroid], return_distance=False)[0][0]

# Map each document's predicted index back to its label name, using the
# index currently bound to `neigh`.
preds = [label_names[predict(text, nlp, neigh)] for text in docs]

In [63]:
# Refit the label index with cosine distance and score it on every document.
# The built-in 'cosine' metric string is used instead of passing
# scipy.spatial.distance.cosine as a callable: scikit-learn documents
# callable metrics as less efficient than its named metrics, and the
# nearest label it selects is the same.
neigh = neighbors.NearestNeighbors(
    n_neighbors=1,
    metric='cosine'
)

neigh.fit(label_vectors)

# predict() returns an index into label_vectors; map it back to the name.
preds = [label_names[predict(doc, nlp, neigh)] for doc in docs]

# Per-class precision/recall/F1, with rows in label_names order.
report = metrics.classification_report(
    y_true=labels,
    y_pred=preds,
    labels=label_names
)

print(report)

               precision    recall  f1-score   support

     business       0.58      0.90      0.70       510
entertainment       0.48      0.91      0.63       386
     politics       0.81      0.45      0.57       417
        sport       0.94      0.64      0.76       511
   technology       0.86      0.26      0.40       401

     accuracy                           0.64      2225
    macro avg       0.73      0.63      0.61      2225
 weighted avg       0.74      0.64      0.62      2225

