In [11]:
import os
import json
import numpy as np

In [34]:
def load_embeddings(filename: str):
    with open(os.path.join(filename, 'vocab', 'word_to_idx.json'), 'r', encoding='utf-8') as f:
        word_to_idx = json.load(f)
    with open(os.path.join(filename, 'vocab', 'idx_to_word.json'), 'r', encoding='utf-8') as f:
        data = json.load(f)
    idx_to_word = {int(k): v for k, v in data.items()}

    embeddings = np.load(os.path.join(filename, 'model', 'word_embeddings.npy'))
    return embeddings, {"w2i": word_to_idx, "i2w": idx_to_word}

In [35]:
embeddings, vocab = load_embeddings('../experiment/wikitext-2/CBOW/08-07-2023_23-58-38')

In [36]:
norms = (embeddings**2).sum(axis=1, keepdims=True) ** (1 / 2)
embeddings_norm = embeddings / norms

In [54]:
def vector_op(a, b, c, d):
    
    if a not in vocab["w2i"] or b not in vocab["w2i"] or c not in vocab["w2i"] or d not in vocab["w2i"]:
        return False

    emb1 = embeddings_norm[vocab["w2i"][a]]
    emb2 = embeddings_norm[vocab["w2i"][b]]
    emb3 = embeddings_norm[vocab["w2i"][c]]

    emb4 = emb2 - emb1 + emb3
    emb4_norm = (emb4**2).sum() ** (1 / 2)
    emb4 = emb4 / emb4_norm

    emb4 = np.reshape(emb4, (len(emb4), 1))
    dists = np.matmul(embeddings_norm, emb4).flatten()

    top5 = np.argsort(-dists)[:5]

    ret = vocab["w2i"][d] in top5

    # if ret:
    #     print(f"a - b ")
    #     for word_id in top5:
    #         print("{}: {:.3f}".format(vocab["i2w"][word_id.item()], dists[word_id]))    

    # top5 = np.argsort(-dists)[:10]
    # for word_id in top5:
    #     print("{}: {:.3f}".format(vocab["i2w"][word_id.item()], dists[word_id]))
    return ret

In [44]:
print(embeddings.shape)
print(embeddings.dtype)

print(embeddings_norm.shape)
print(embeddings_norm.dtype)

(33277, 300)
float32
(33277, 300)
float32


In [45]:
with open("../dataset/word-test.v1.txt", "r", encoding="utf-8") as f:
    word_test = f.readlines()

In [55]:
count = 0
total = 0
for line in word_test:
    line = line.strip()

    if len(line) == 0 or line.startswith("//"):
        continue
    elif line.startswith(": "):
        print(count, total)
        print(line)
        count = 0
        total = 0
    else:
        line = line.split()
        if vector_op(line[0], line[1], line[2], line[3]):
            count += 1
        total += 1


0 0
: capital-common-countries
0 506
: capital-world
0 4524
: currency
0 866
: city-in-state
0 2467
: family
35 506
: gram1-adjective-to-adverb
0 992
: gram2-opposite
0 812
: gram3-comparative
14 1332
: gram4-superlative
6 1122
: gram5-present-participle
0 1056
: gram6-nationality-adjective
2 1599
: gram7-past-tense
73 1560
: gram8-plural
11 1332
: gram9-plural-verbs
