In [1]:
import torch
import torch.nn as nn
import torch.optim as optim

In [2]:
class SkipGram(nn.Module):
    def __init__(self, vocab_size, window_size, embedding_size):
        super(SkipGram, self).__init__()

        self.embeddings = nn.Embedding(vocab_size, embedding_size)
        self.linear = nn.Linear(embedding_size, vocab_size)

    def forward(self, target):
        target_embedding = self.embeddings(target)
        res = self.linear(target_embedding)
        return res

In [3]:
window_size = 3
doc = [
    "<i am henry>",
    "<i like college>",
    "<do henry like college>",
    "<i am do i like college>",
    "<i do like henry>",
    "<do i like henry>",
]
raw_text = " ".join(doc)
tokens = raw_text.split(" ")

In [None]:
def get_new_tokens(tok):
    char_tokens = []
    for token in tok:
        if len(token) < window_size:
            char_tokens.append(token)
        for i in range(0, len(token) - window_size + 1):
            char_tokens.append(token[i : i + window_size])
    return char_tokens

In [19]:
char_tokens = get_new_tokens(tokens)
print(char_tokens)
vocab = set(char_tokens)
vocab_size = len(vocab)

['<i', 'am', 'hen', 'enr', 'nry', 'ry>', '<i', 'lik', 'ike', 'col', 'oll', 'lle', 'leg', 'ege', 'ge>', '<do', 'hen', 'enr', 'nry', 'lik', 'ike', 'col', 'oll', 'lle', 'leg', 'ege', 'ge>', '<i', 'am', 'do', 'i', 'lik', 'ike', 'col', 'oll', 'lle', 'leg', 'ege', 'ge>', '<i', 'do', 'lik', 'ike', 'hen', 'enr', 'nry', 'ry>', '<do', 'i', 'lik', 'ike', 'hen', 'enr', 'nry', 'ry>']


In [6]:
data = []
word_index = {word: i for i, word in enumerate(vocab)}

for i in range(window_size, len(char_tokens) - window_size):
    context = word_index[char_tokens[i]]
    for j in range(-window_size, window_size + 1):
        if j == 0:
            continue
        data.append((context, word_index[char_tokens[i + j]]))
print(data)

[(14, 1), (14, 3), (14, 15), (14, 5), (14, 2), (14, 1), (5, 3), (5, 15), (5, 14), (5, 2), (5, 1), (5, 6), (2, 15), (2, 14), (2, 5), (2, 1), (2, 6), (2, 4), (1, 14), (1, 5), (1, 2), (1, 6), (1, 4), (1, 0), (6, 5), (6, 2), (6, 1), (6, 4), (6, 0), (6, 7), (4, 2), (4, 1), (4, 6), (4, 0), (4, 7), (4, 8), (0, 1), (0, 6), (0, 4), (0, 7), (0, 8), (0, 16), (7, 6), (7, 4), (7, 0), (7, 8), (7, 16), (7, 11), (8, 4), (8, 0), (8, 7), (8, 16), (8, 11), (8, 13), (16, 0), (16, 7), (16, 8), (16, 11), (16, 13), (16, 9), (11, 7), (11, 8), (11, 16), (11, 13), (11, 9), (11, 15), (13, 8), (13, 16), (13, 11), (13, 9), (13, 15), (13, 14), (9, 16), (9, 11), (9, 13), (9, 15), (9, 14), (9, 5), (15, 11), (15, 13), (15, 9), (15, 14), (15, 5), (15, 6), (14, 13), (14, 9), (14, 15), (14, 5), (14, 6), (14, 4), (5, 9), (5, 15), (5, 14), (5, 6), (5, 4), (5, 0), (6, 15), (6, 14), (6, 5), (6, 4), (6, 0), (6, 7), (4, 14), (4, 5), (4, 6), (4, 0), (4, 7), (4, 8), (0, 5), (0, 6), (0, 4), (0, 7), (0, 8), (0, 16), (7, 6), (7, 4)

In [7]:
embed_size = 10
learning_rate = 0.01
epochs = 1000

model = SkipGram(vocab_size, window_size, embed_size)
lossfn = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=learning_rate)

In [8]:
for epoch in range(epochs):
    total_loss = 0
    for context, target in data:
        optimizer.zero_grad()
        output = model(torch.tensor([context]))
        loss = lossfn(output, torch.tensor([target]))
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    if epoch % 50 == 0:
        print(epoch, total_loss / len(data))

0 2.8977245055899328
50 2.3386322324778757
100 2.287954709156841
150 2.2673744076774236
200 2.256828162540384
250 2.2503388248333316
300 2.2462501558316808
350 2.243549324217297
400 2.241594126435364
450 2.2400413674562154
500 2.238715876932858
550 2.237530699798039
600 2.2364458295763754
650 2.235446527296183
700 2.234530120885291
750 2.2336968061875324
800 2.232944846964207
850 2.232268684575347
900 2.2316595356480606
950 2.231107618938498


In [29]:
word_to_lookup = "henbenry"
lookup = get_new_tokens([word_to_lookup])
res = []
for lu in lookup:
    if lu in word_index.keys():
        wi = word_index[lu]
        embedding = model.embeddings(torch.tensor([wi]))
        res.append(embedding.detach().numpy()[0])
        print(f"Embedding for '{lu}': {embedding.detach().numpy()}")

print(f"Embedding for {word_to_lookup}: {sum(res) / len(res)}")

Embedding for 'hen': [[-0.79532444 -0.96615595  0.6535394  -1.0637437  -0.02516011  1.9169487
  -0.06009363  0.67181677 -0.22680724 -0.5015251 ]]
Embedding for 'enr': [[-0.43325138 -0.20698804  0.16642267 -0.58563775  1.3359994   0.58535695
  -1.5445085  -1.1151849  -0.49175018  0.8841796 ]]
Embedding for 'nry': [[ 0.42004493  1.4199466   0.78154904 -0.18266961 -0.2868884   0.35812554
  -0.64863205  0.87368417 -0.8444467   1.1747092 ]]
Embedding for henbenry: [-0.2695103   0.08226752  0.533837   -0.6106837   0.34131697  0.95347697
 -0.75107807  0.14343868 -0.5210014   0.5191212 ]
