In [6]:
import torch
import torch.nn as nn

We will begin by making the model, which will be used later within the PyTorch API. In this small example, we will be experimenting with different methods for random initialization of the embeddings, exploring how they may impact learning speed. So, think of "init_func" as a placeholder for a generic initialization function; the code below uses `nn.init.normal_`.

In [7]:
# Vocabulary size: number of rows in each embedding matrix.
V = 5000
# Embedding dimensionality: number of columns in each embedding matrix.
E = 50

class MFEmbedder(nn.Module):
    """Matrix-factorization embedder with two learnable ``(vsize, embdim)`` factors.

    ``B`` and ``A`` are both initialised from a standard normal distribution.
    The forward pass ignores its input and returns the ``(vsize, vsize)``
    product ``B @ A.T``.
    """

    def __init__(self, vsize, embdim):
        super().__init__()

        def gaussian_factor():
            # Allocate a float32 matrix and fill it in place with N(0, 1) draws.
            return nn.init.normal_(
                torch.empty(vsize, embdim, dtype=torch.float32)
            )

        # B is drawn (and registered) before A; keep this order so seeded
        # runs reproduce the same initial values.
        self.B = nn.Parameter(gaussian_factor())
        self.A = nn.Parameter(gaussian_factor())

    def forward(self, x=None):
        """Return ``B @ A.T``; ``x`` is accepted for API compatibility and unused."""
        return self.B @ self.A.t()
    

In [8]:
# Instantiate the factorization model with vocabulary size V and embedding width E.
model = MFEmbedder(V, E)

Now we need to build the special loss function that we derived last time. With it in place, we will be ready to learn SGNS vectors!

In [9]:
class SGNSLoss():
    """SGNS objective written over co-occurrence counts instead of samples.

    Args:
        Nij: ``(V, V)`` tensor of co-occurrence counts; weights the positive term.
        Ni: length-``V`` tensor; weights the columns of ``M_hat`` in the negative term.
        Nj: length-``V`` tensor; weights the rows of ``M_hat`` in the negative term.
        N: scalar total count used for normalisation.
        K: multiplier on the negative term (presumably the number of negative
            samples per positive pair -- confirm against the derivation).
    """

    def __init__(self, Nij, Ni, Nj, N, K):
        self.Nij = Nij
        self.Ni = Ni
        self.Nj = Nj
        self.N = N
        self.K = K

    def __call__(self, M_hat):
        """Return the scalar loss for the ``(V, V)`` reconstruction ``M_hat``."""
        log_sigmoid = nn.LogSigmoid()

        # Positive term: observed counts times log sigma(M_hat).
        pos_samples = (self.Nij * log_sigmoid(M_hat)).sum()

        # Negative term: sum_{a,b} Nj[a] * log sigma(-M_hat[a, b]) * Ni[b],
        # computed as (1 x V) @ (V x V) @ (V x 1), i.e. a 1 by 1 matrix.
        # The marginals must come from the instance; the original read a
        # module-level `Nj`, which silently ignored the constructor argument.
        neg_samples = (self.K / self.N) * torch.matmul(
            torch.matmul(self.Nj.view(1, -1), log_sigmoid(-M_hat)),
            self.Ni.view(-1, 1)
        ).sum()

        return -(pos_samples + neg_samples) / self.N


Now we should declare some of the variables that we want to work with. We are going to need to load one of my pre-made files that performs cooccurrence statistic extraction. Because this is a small example, we will just be using a 5000 word vocabulary so that we don't overload Google's cloud GPUs.

In [10]:
## Boring data downloading stuff.
import requests
import io
import zipfile
import pandas as pd

def download_extract_zip(file_name):
    """Read the first file inside a zip archive as a co-occurrence table.

    Despite the name, nothing is downloaded: ``file_name`` must be a local
    path or file-like object that ``zipfile.ZipFile`` can open.

    Args:
        file_name: path (or binary file object) of the zip archive.

    Returns:
        pandas.DataFrame with columns ``term``, ``context`` and ``Nij``,
        parsed from a space-separated file without a header row.

    Raises:
        ValueError: if the archive contains no file entries.
    """
    with zipfile.ZipFile(file_name, 'r') as thezip:
        for zipinfo in thezip.infolist():
            # Directory entries carry no data; parsing one would fail, so
            # move on to the first real file.
            if zipinfo.is_dir():
                continue
            with thezip.open(zipinfo) as thefile:
                return pd.read_csv(
                    thefile,
                    sep=' ',
                    header=None,
                    names=['term', 'context', 'Nij']
                )
    # Fail loudly instead of returning None and breaking a later cell.
    raise ValueError('no file entries found in zip archive: {!r}'.format(file_name))

# Load the co-occurrence table from the local archive; 'cooc.zip' must already be on disk.
df = download_extract_zip('cooc.zip') 

In [11]:
# Peek at the first rows to confirm the three columns parsed as expected.
df.head()

Unnamed: 0,term,context,Nij
0,chief,peace,2072.8
1,produces,venezuela,109.4
2,resources,?,411.2
3,israelis,later,196.2
4,however,list,1395.2


In [14]:
# Build the vocabulary and the dense co-occurrence matrix from the dataframe
# (quick and kind of dirty).
from collections import defaultdict
import numpy as np

# Id 0 is reserved for the unknown-word token.
vocab = {"<unk>": 0}
invvocab = ["<unk>"]

# Total co-occurrence mass per term, accumulated row by row (not super efficient).
counts = defaultdict(float)
for term, cooc in zip(df['term'], df['Nij']):
    counts[term] += cooc

# Hand out ids by decreasing total count. The sort is stable, so ties keep
# the order in which the terms first appeared in the dataframe.
for word in sorted(counts, key=counts.get, reverse=True):
    invvocab.append(word)
    vocab[word] = len(vocab)

V = len(vocab)

# Given the vocabulary mapping above, we now fill in the Nij matrix.
# NOTE - with large vocabs, this should be done using a *sparse* matrix!!!
# NOTE(review): a context word that never occurs as a term raises KeyError,
# and a repeated (term, context) pair overwrites rather than accumulates.
Nij_np = np.zeros((V, V))
for term, context, cooc in zip(df['term'], df['context'], df['Nij']):
    Nij_np[vocab[term], vocab[context]] = cooc

In [15]:
# Spot-check two entries that pair id 0 (the "<unk>" slot) with id 1.
Nij_np[0, 1], Nij_np[1, 0]

(0.0, 0.0)

In [16]:
# Move the counts into a float32 tensor and derive the marginals from it.
Nij = torch.tensor(Nij_np, dtype=torch.float32)

# Grand total of all co-occurrence counts.
N = torch.sum(Nij)

# Column sums: Ni[c] adds up column c over all rows.
Ni = torch.sum(Nij, dim=0)

# Row sums: Nj[r] adds up row r over all columns.
Nj = torch.sum(Nij, dim=1)

In [17]:
# Count-weighted log-sigmoid of the model output: the same expression SGNSLoss uses for its positive term.
(Nij * nn.LogSigmoid()(model())).sum()

tensor(-7.5978e+10, grad_fn=<SumBackward0>)

In [18]:
# Doubly weighted sum  sum_{a,b} Nj[a] * logsigmoid(M_hat[a, b]) * Ni[b],
# evaluated as (1 x V) @ (V x V) @ (V x 1) and collapsed to a scalar.
(Nj.view(1, -1) @ nn.LogSigmoid()(model()) @ Ni.view(-1, 1)).flatten().sum()

tensor(-2.0059e+21, grad_fn=<SumBackward0>)

In [19]:
# Sanity check: the count matrix is square with one row and one column per vocabulary entry.
assert(Nij.shape == torch.Size([V, V]))

In [20]:
# 'i,j->ij' is the outer product of the two marginal vectors; display its shape.
torch.einsum('i,j->ij', Ni, Nj).shape

torch.Size([5000, 5000])

In [21]:
# Recompute the marginals so this cell does not depend on the earlier one.
N = torch.sum(Nij)          # grand total
Ni = torch.sum(Nij, dim=0)  # column sums
Nj = torch.sum(Nij, dim=1)  # row sums

# PMI[a, b] = log(N * Nij[a, b] / (Ni[a] * Nj[b])), computed two equivalent ways.
# Zero counts give -inf and 0/0 gives NaN; both are cleaned up in a later cell.
# NOTE(review): Ni holds column sums yet is indexed by the row position here;
# that only matches the usual PMI definition if Nij is (near-)symmetric -- confirm.

# Version 1: outer product of the marginals via einsum.
PMI_np_check = np.array(
    torch.log(N * Nij / torch.einsum('i,j->ij', Ni, Nj))
)

# Version 2: the same outer product, built by tiling each marginal into a V x V grid.
PMI_np = np.array(
    torch.log(
        N * Nij / (Ni.repeat(V).reshape(V, -1).T * Nj.repeat(V).reshape(V, -1))
    )
)

In [23]:
# Display the PMI matrix dimensions.
PMI_np.shape

(5000, 5000)

In [None]:
# The einsum outer product and the tiled version must match element for element.
assert torch.equal(
    torch.einsum('i,j->ij', Ni, Nj),
    Ni.repeat(V).view(V, -1).t() * Nj.repeat(V).view(V, -1),
)

In [None]:
# Inspect the top-left 10 x 10 corner of the PMI matrix.
PMI_np[:10, :10]

Let's do a sanity check and examine the count statistics with some cool matplotlib visualizations.

In [None]:
from matplotlib import pyplot as plt
# investigate some of the statistics.
# Count the -inf entries on the raw matrix, before the clean-up below rewrites them.
pmi_negativeinf = np.sum(PMI_np == -np.inf)
total = np.prod(PMI_np.shape)
print('Percent -infinities in the PMI matrix: {:.4f}%'.format(
    100 * pmi_negativeinf / total))

# visualize the matrix!
_ = plt.figure()
_ = plt.imshow(PMI_np, cmap="coolwarm")
_ = plt.xlabel("term index")
# _ = plt.ylabel("context index")
_ = plt.colorbar()
# _ = plt.title("Visualization of the PMI matrix")

# turn the -infinities (and any NaNs) to something easier to work with.
# NOTE: this mutates PMI_np in place, so re-running the cell reports 0% -infinities.
PMI_np[PMI_np == -np.inf] = -4
PMI_np[np.isnan(PMI_np)] = -4

# Flatten after the clean-up so the filter does not rely on reshape returning a view.
flat_PMI_np = PMI_np.reshape(-1)
hist_pmis = flat_PMI_np[flat_PMI_np != -4]
_ = plt.figure()
n, bins, patches = plt.hist(hist_pmis,
                            bins=100,
                            color="b",
                            alpha=1)
print("Statistics of PMIs: ")
print("{:0.4f} mean, {:0.4f} std".format(np.mean(hist_pmis),
                                         np.std(hist_pmis)))

# Add shading to the histogram.
# plt.get_cmap works on old and new matplotlib; plt.cm.get_cmap was removed in 3.9.
cm = plt.get_cmap("coolwarm")
bin_centers = 0.5 * (bins[:-1] + bins[1:])
# Rescale the bin centres to [0, 1] so they index the colormap.
col = bin_centers - min(bin_centers)
col /= max(col)
for c, p in zip(col, patches):
    plt.setp(p, 'facecolor', cm(c))
_ = plt.xlabel("PMI value")
# The histogram is not normalised (density is left at its default), so the bars are counts.
_ = plt.ylabel("count")
_ = plt.title("Histogram of PMIs (excluding -infinities)")

In [None]:
# Note that there was a slight bug when extracting corpus statistics:
# since the context window (w=5) is symmetric, PMI(i,j) should always equal
# PMI(j,i); however, due to improper handling of context during the first 5 words
# of a document, the statistics are ever-so-slightly distorted.

# Flat indices of every PMI entry in ascending order, so the largest come last.
sorted_ind = np.argsort(PMI_np, axis=None)

print('--- Word pairs with the notable PMIs ---\n')
# Walk backwards from the end of the ordering to list the 17 largest entries.
for rank in range(1, 18):
    row, col = np.unravel_index(sorted_ind[-rank], PMI_np.shape)
    print('{:8} {:14} (PMI = {:0.4f})'.format(
        invvocab[row], invvocab[col], float(PMI_np[row, col])))

In [None]:
# Create the object representing the loss function.
# The keyword must be the upper-case `K` declared by SGNSLoss.__init__;
# passing `k=1` raises TypeError (unexpected keyword argument).
criterion = SGNSLoss(
    Nij=Nij,
    Ni=Ni,
    Nj=Nj,
    N=N,
    K=1
)

# desired embedding dimensionality
E = 50

# Create the MF model! (This replaces the model built earlier with fresh random weights.)
model = MFEmbedder(V, E)

# Instantiate optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=0.1)


In [None]:
# Sanity check: the loss is a scalar, so this should display an empty shape.
criterion(model()).shape

In [None]:
def train_step(model, optimizer, n=1000, loss_fn=None):
    """Run one full-batch optimisation step and return the loss as a float.

    Args:
        model: callable taking no required arguments that returns the current
            reconstruction (``B @ A.t()`` for ``MFEmbedder``).
        optimizer: torch optimizer holding the model's parameters.
        n: unused; kept so existing calls keep working.
        loss_fn: callable mapping the reconstruction to a scalar loss tensor.
            Defaults to the notebook-level ``criterion``.

    Returns:
        float: the loss evaluated before the parameter update is applied.
    """
    if loss_fn is None:
        # Fall back to the loss object defined at notebook level.
        loss_fn = criterion
    optimizer.zero_grad()
    # This is B @ A.t().
    M_hat = model()
    loss = loss_fn(M_hat)
    loss.backward()
    optimizer.step()
    return loss.item()

In [None]:
# Grabs the most similar words to the current term. We scale the vectors and then do an inner product.
def most_similar(term, topn=5):
    """Return the ``topn`` vocabulary words closest to ``term`` by cosine similarity.

    Similarity is measured between rows of ``model.B``. The query word is
    removed by its index rather than by assuming it ranks first, so a tie
    with another vector cannot drop a real neighbour or leak the query.

    Args:
        term: a word present in the notebook-level ``vocab``.
        topn: number of neighbours to return (default 5, as before).

    Returns:
        list of str: neighbouring words, most similar first.

    Raises:
        KeyError: if ``term`` is not in ``vocab``.
    """
    i = vocab[term]
    # Lookup only: no autograd graph is needed.
    with torch.no_grad():
        embs = model.B / torch.linalg.norm(model.B, dim=1, keepdim=True)
        cossims = torch.matmul(embs, embs[i])
        # Plain Python ints index the vocabulary list cleanly.
        ranked = torch.argsort(cossims, descending=True).tolist()
    neighbours = [idx for idx in ranked if idx != i][:topn]
    return [invvocab[idx] for idx in neighbours]

In [None]:
# Now run a training loop! Note that these embeddings will be trained in less
# than 15 minutes - much much much faster than it would take for the original 
# implementation. But, the speed of this is dependent on the vocabulary size.
import time

# Use wall-clock time. time.process_time() reports CPU time summed over the
# process, which is not the elapsed "seconds" printed below.
start = time.perf_counter()
estart = start
results = []

n_iters = 500 + 1
print_every = 50

for i in range(n_iters):
    loss = train_step(model, optimizer)
    results.append(loss)
    if i % print_every == 0:
        print('\nstep {:4} - loss: {} ({:0.4f} seconds)'.format(
            i, results[-1], time.perf_counter() - estart)
             )
        print('\t similar to \"money\": '+" ".join(most_similar("money")))
        print('\t similar to \"peace\": '+" ".join(most_similar("peace")))

        # Restart the interval timer after the report so the neighbour
        # lookups are not charged to the next block of steps.
        estart = time.perf_counter()

print('\nTotal time: {:0.4f} seconds.'.format(time.perf_counter() - start))

In [None]:
# Let's look at the loss over time.
from matplotlib import pyplot as plt
# One x position per recorded training step, paired with its loss value.
x = np.arange(len(results))
y = np.asarray(results)

# Explicit figure/axes interface; the final assignment keeps the legend repr
# from being printed to colab.
fig, ax = plt.subplots()
ax.plot(x, y, '--r', label="loss")
ax.set_xlabel("Iteration")
ax.set_ylabel("Loss value")
# ax.set_title("MF-SGNS loss over time")
_ = ax.legend()

Training is finished, and it looks like the model is learning to produce vectors with desirable semantic qualities! The loss is decreasing as well. As a final sanity check, let's manually inspect a few more words.

In [None]:
# manual qualitative inspection
# Print each probe word next to the neighbours returned by most_similar.
for w in ["drive", "america", "east", "soviet", "belgium", "brussels", "1914"]:
    print("{:10}: {}".format(w, " ".join(most_similar(w))))

As we see above, all of the outcomes are quite reasonable! E.g., 1914 (the year World War 1 started) is most similar to the next year of WW1, but is also highly similar to 1939 (the year WW2 started) -- very cool! We also observe that "brussels" is similar to other capitals, while "belgium" is similar to other countries - exactly what we would expect.

In [None]:
# The model output still tracks gradients, and matplotlib cannot convert such
# a tensor to NumPy, so detach it before plotting.
Mhat = model(None).detach().cpu().numpy()
# Shared colour scale so the two panels are directly comparable.
_min, _max = -10, 10

# Compare the original PMI matrix to our model's reproduction.
f, axes = plt.subplots(1, 2, sharey=True, figsize=(12, 16))
im1 = axes[0].imshow(PMI_np, cmap="coolwarm", vmin = _min, vmax = _max)
axes[0].set_title(r"Original, PMI")

im2 = axes[1].imshow(Mhat, cmap="coolwarm", vmin = _min, vmax = _max)
axes[1].set_title(r"Mhat (V @ W)")
f.colorbar(im2, ax=axes, orientation='horizontal', anchor=(0,2))

In [None]:
# Compare the original PMI matrix to our model's reproduction.
_min, _max = 0, 1
f, axes = plt.subplots(1, 2, sharey=True, figsize=(12, 16))
# `tf` is not imported anywhere in the visible notebook, so `tf.sigmoid`
# raises NameError; torch.sigmoid is the equivalent.
axes[0].imshow(torch.sigmoid(torch.as_tensor(PMI_np)).numpy(),
               cmap="coolwarm", vmin = _min, vmax = _max)
axes[0].set_title(r"Original, $\sigma($PMI$)$")

# Detach first: matplotlib cannot convert a tensor that requires grad.
Mhat = model(None).detach()
im2 = axes[1].imshow(torch.sigmoid(Mhat).cpu().numpy(),
                     cmap="coolwarm", vmin = _min, vmax = _max)
axes[1].set_title(r"Mhat $\sigma($V @ W$)$")
f.colorbar(im2, ax=axes, orientation='horizontal', anchor=(0,2))

In [None]:
f, axes = plt.subplots(1, 2, sharey=False, figsize=(12, 16))

# Flattened sigma(B @ A.T). torch replaces the `tf` calls, which fail because
# `tf` is never imported; detach so the values convert to NumPy.
hist_dots = torch.sigmoid(model(None)).detach().reshape(-1).cpu().numpy()
# `label` (not `key`) is the keyword that names the series for legend().
n, bins, patches = axes[0].hist(hist_dots,
                                bins=100,
                                color="b",
                                alpha=1,
                                label="Dot products")
axes[0].legend()

# Same histogram for sigma(PMI), computed from the cleaned PMI matrix.
n, bins, patches = axes[1].hist(torch.sigmoid(torch.as_tensor(PMI_np)).reshape(-1).numpy(),
                                bins=100,
                                color="b",
                                alpha=1,
                                label="PMIs")
axes[1].legend()
