In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import re

In [3]:
class GaussianKernelEmbedding(nn.Module):
    def __init__(self, vocab_size, embedding_dim, sigma=1.0):
        super().__init__()
        self.embedding_weights = nn.Parameter(torch.randn(vocab_size, embedding_dim, dtype = torch.float64))
        self.sigma = sigma

    def forward(self, context, center):
        context_vecs = self.embedding_weights[context] # batch_size * (winlen - 1) * embedding
        center_vec = self.embedding_weights[center] # batch_size * embedding
        diff = context_vecs - center_vec.unsqueeze(1)  # batch_size * (winlen - 1) * embedding
        dist_sq = torch.sum(diff ** 2, dim=2)  # batch_size * (winlen - 1)
        weights = torch.exp(-dist_sq / (2 * self.sigma ** 2))  # batch_size * (winlen - 1)
        weights = weights / (weights.sum(dim=1, keepdim=True) + 1e-8)  # batch_size * (winlen - 1)
        weighted_context = (weights.unsqueeze(2) * context_vecs).sum(dim=1)  # batch_size * embedding

        return weighted_context
    
    def getEmbedding(self, id):
        return self.embedding_weights[id]

In [None]:
def preprocessing(text):
    _ = re.findall(r"[A-Za-z]+", text)
    words = []
    for word in _:
        words.append(word.lower())
    word2id = {w : i for i, w in enumerate(set(words))}
    id2word = {i : w for _, (w, i) in enumerate(word2id.items())}
    return words, word2id, id2word

def generateData(words, word2id, winlen): # winlen must be odd
    vocab_size = len(word2id)
    word_size = len(words)
    batch_size = word_size - winlen + 1
    context_train = np.zeros((batch_size, winlen - 1))
    center_train = np.zeros((batch_size))
    for _ in range(winlen // 2, word_size - winlen // 2):
        fr = _ - winlen // 2
        center_train[fr] = word2id[words[_]]
        for __ in range(_ - winlen // 2, _):
            context_train[fr][__ - (_ - winlen // 2)] = word2id[words[__]]
        for __ in range(_ + 1, _ + winlen // 2 + 1):
            context_train[fr][__ - (_ - winlen // 2) - 1] = word2id[words[__]]
    return torch.tensor(context_train).int(), torch.tensor(center_train).int(), vocab_size, word_size
        
with open("wiki.train.tokens", 'r') as f:
    text = f.read()

words, word2id, id2word = preprocessing(text)
context_train, center_train, vocab_size, word_size = generateData(words, word2id, 5)
print(context_train.shape, center_train.shape, vocab_size)

torch.Size([1694556, 6]) torch.Size([1694556]) 27228


In [None]:
def train(model, optimizer, context_train, center_train):
    criterion = nn.CosineEmbeddingLoss()
    flag = torch.ones(context_train.shape[0])
    num_epoches = 1000

    best_loss = float('inf')
    for epoch in range(num_epoches):
        model.train()
        optimizer.zero_grad()
        output = model(context_train, center_train)
        target = model.embedding_weights[center_train].detach()
        loss = criterion(output, target, flag)
        loss.backward()
        optimizer.step()
        if epoch % 5 == 0:
            print(f"Epoch {epoch}, Loss: {loss.item()}")

        if loss.item() < best_loss:
            best_loss = loss.item()
        elif epoch > 100 and loss.item() > best_loss * 1.05: 
            print(f"Early stopping at epoch {epoch}")
            break
    
    return model, optimizer

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = GaussianKernelEmbedding(vocab_size, 32)
model.to(device)
optimizer = optim.Adam(model.parameters(), lr=0.01)

In [7]:
current_model = torch.load("model-32.pth")
model.load_state_dict(current_model["state_dict"])
optimizer.load_state_dict(current_model["optimizer"])

In [None]:
model, optimizer = train(model, optimizer, context_train, center_train)

In [8]:
word2embed = {}
for (word, id) in word2id.items():
    embedding = model.getEmbedding(id).detach()
    word2embed[word] = embedding

with open("output.txt", 'w') as f:
    for (word, embed) in word2embed.items():
        f.write(word)
        f.write(str(list(embed.numpy())))
        f.write('\n')

In [9]:
def getClose(target):
    sims = []

    cos = nn.CosineSimilarity(dim=0)
    for (word, embed) in word2embed.items():
        sim = cos(embed, target).item()
        sims.append((word, sim))

    res = sorted(sims, key=lambda x : x[1], reverse = True)
    return res[:5]

In [10]:
import random

correctCount = 0
totalCount = 0
with open("questions-words.txt", 'r') as f:
    qs = f.read()

qs_s = qs.split('\n')
#random.shuffle(qs_s)

for q in qs_s[:100]:
    words = q.split()
    try:
        ans = getClose(word2embed[words[0].lower()] + word2embed[words[1].lower()] - word2embed[words[2].lower()])[0][0]
        if ans == words[3].lower():
            correctCount += 1
        else:
            print(words[0].lower(), words[1].lower(), words[2].lower(), ans)
        totalCount += 1
    except KeyError:
        pass

print(correctCount, totalCount)

athens greece baghdad cassette
athens greece bangkok smart
athens greece beijing coatings
athens greece berlin greece
athens greece bern ambitious
athens greece cairo isolated
athens greece canberra humanitarian
athens greece hanoi greece
athens greece havana doom
athens greece helsinki beno
athens greece london smart
athens greece madrid pbs
athens greece moscow pentwyn
athens greece oslo incoherent
athens greece ottawa isolated
athens greece paris greece
athens greece rome greece
athens greece stockholm greece
athens greece tehran incoherent
athens greece tokyo phosphorus
baghdad iraq bangkok exercises
baghdad iraq beijing stunts
baghdad iraq berlin narragansett
baghdad iraq bern stunts
baghdad iraq cairo reforming
baghdad iraq canberra baghdad
baghdad iraq hanoi dollar
baghdad iraq havana baghdad
baghdad iraq helsinki possessed
baghdad iraq london swears
baghdad iraq madrid baghdad
baghdad iraq moscow stunts
baghdad iraq oslo stunts
baghdad iraq ottawa narragansett
baghdad iraq pari

In [8]:
state = {
    'state_dict': model.state_dict(),  # model parameters
    'optimizer': optimizer.state_dict(),  # optimizer state
}
torch.save(state, 'model-64.pth')