In [2]:
import sys
import ast
import numpy as np


def parse_array(s):
    """Evaluate the Python literal in `s` and return it as a NumPy array."""
    value = ast.literal_eval(s)
    return np.array(value)

def read_array():
    """Read a single line from stdin and convert it with `parse_array`."""
    line = sys.stdin.readline()
    return parse_array(line)

def write_array(arr):
    """Print `arr` as the repr of its (nested) Python list form."""
    as_list = arr.tolist()
    print(repr(as_list))


def generate_w2v_sgns_samples(text, window_size, vocab_size, ns_rate):
    """
    Build skip-gram-with-negative-sampling training triples for word2vec.

    text - list of integer numbers (or 1-D NumPy array) - ids of tokens in text
    window_size - odd integer - width of window (center token included)
    vocab_size - positive integer - number of tokens in vocabulary
    ns_rate - positive integer - number of negative tokens to sample per one positive sample

    Every position of `text` is used as a center. Each token within
    (window_size - 1) // 2 positions on either side yields one positive sample
    [center, ctx, 1], and the window is clipped at the text boundaries. Each positive
    sample is immediately followed by `ns_rate` negative samples [center, random_id, 0].
    Here random_id is drawn with np.random.randint(vocab_size), i.e. from [0, vocab_size).

    returns list of training samples [CenterWord, CtxWord, Label]
    """
    # Work on a plain list: `+` on NumPy slices would add element-wise
    # instead of concatenating the left and right contexts.
    tokens = np.asarray(text).tolist()
    w_side = (window_size - 1) // 2
    train = []
    for i, center in enumerate(tokens):
        # Clip the window at both ends so edge tokens are also used as centers.
        left = tokens[max(0, i - w_side):i]
        right = tokens[i + 1:i + w_side + 1]
        for ctx_word in left + right:
            train.append([center, ctx_word, 1])
            for _ in range(ns_rate):
                train.append([center, np.random.randint(vocab_size), 0])
    return train


# Input: token ids as one literal line, then window_size, vocab_size and ns_rate,
# one integer per line (read in that order).
text = read_array()
window_size, vocab_size, ns_rate = (int(sys.stdin.readline().strip()) for _ in range(3))

result = generate_w2v_sgns_samples(text=text, window_size=window_size,
                                   vocab_size=vocab_size, ns_rate=ns_rate)

write_array(np.array(result))

SyntaxError: unexpected EOF while parsing (<unknown>, line 0)

In [3]:
import sys
import ast
import numpy as np


def parse_array(s):
    """Evaluate the Python literal in `s` and return it as a NumPy array."""
    value = ast.literal_eval(s)
    return np.array(value)

def read_array():
    """Read a single line from stdin and convert it with `parse_array`."""
    line = sys.stdin.readline()
    return parse_array(line)

def write_array(arr):
    """Print `arr` as the repr of its (nested) Python list form."""
    as_list = arr.tolist()
    print(repr(as_list))


def update_w2v_weights(center_embeddings, context_embeddings, center_word, context_word, label, learning_rate):
    """
    Apply one SGD step of skip-gram with negative sampling for a single (center, context) pair.

    center_embeddings - VocabSize x EmbSize - updated in place (row `center_word`)
    context_embeddings - VocabSize x EmbSize - updated in place (row `context_word`)
    center_word - int - identifier of center word
    context_word - int - identifier of context word
    label - 1 if context_word is real, 0 if it is negative
    learning_rate - float > 0 - size of gradient step

    returns the same (center_embeddings, context_embeddings) objects after the update
    """
    center_vec = center_embeddings[center_word]
    context_vec = context_embeddings[context_word]

    # Gradient of the logistic loss w.r.t. the score is sigmoid(score) - label.
    score = center_vec @ context_vec
    delta = 1 / (1 + np.exp(-score)) - label

    # Both gradients are computed before either row is modified.
    grad_center = delta * context_vec
    grad_context = delta * center_vec

    center_embeddings[center_word] -= learning_rate * grad_center
    context_embeddings[context_word] -= learning_rate * grad_context
    return center_embeddings, context_embeddings
    

# Input: two embedding matrices (one literal per line), then center_word, context_word
# and label as integers, then learning_rate as a float - one value per line.
center_embeddings = read_array()
context_embeddings = read_array()
center_word, context_word, label = (int(sys.stdin.readline().strip()) for _ in range(3))
learning_rate = float(sys.stdin.readline().strip())

# update_w2v_weights modifies both matrices in place, so its return value is not needed.
update_w2v_weights(center_embeddings=center_embeddings, context_embeddings=context_embeddings,
                   center_word=center_word, context_word=context_word,
                   label=label, learning_rate=learning_rate)

write_array(center_embeddings)
write_array(context_embeddings)

SyntaxError: unexpected EOF while parsing (<unknown>, line 0)

In [None]:
import sys
import ast
import numpy as np


def read_list():
    """Read one line from stdin and return the Python literal it holds (no NumPy conversion)."""
    line = sys.stdin.readline()
    return ast.literal_eval(line)

def parse_array(s):
    """Evaluate the Python literal in `s` and return it as a NumPy array."""
    value = ast.literal_eval(s)
    return np.array(value)

def read_array():
    """Read a single line from stdin and convert it with `parse_array`."""
    line = sys.stdin.readline()
    return parse_array(line)

def write_array(arr):
    """Print `arr` as the repr of its (nested) Python list form."""
    as_list = arr.tolist()
    print(repr(as_list))


def generate_ft_sgns_samples(text, window_size, vocab_size, ns_rate, token2subwords):
    """
    Build skip-gram-with-negative-sampling training triples for fastText.

    text - list of integer numbers (or 1-D NumPy array) - ids of tokens in text
    window_size - odd integer - width of window (center token included)
    vocab_size - positive integer - number of tokens in vocabulary
    ns_rate - positive integer - number of negative tokens to sample per one positive sample
    token2subwords - list of lists of int - i-th sublist contains list of identifiers of n-grams for token #i (list of subword units)

    Every position of `text` is used as a center. The context window of
    (window_size - 1) // 2 tokens per side is clipped at the text boundaries.
    Each context token yields a positive sample (subwords, ctx, 1), followed by
    `ns_rate` negatives (subwords, np.random.randint(vocab_size), 0).

    returns list of training samples (CenterSubwords, CtxWord, Label)
    """
    # Accept both plain lists and NumPy arrays; `.tolist()` alone fails on lists.
    tokens = np.asarray(text).tolist()
    w_side = (window_size - 1) // 2
    train = []
    for i, center in enumerate(tokens):
        left = tokens[max(0, i - w_side):i]
        right = tokens[i + 1:i + w_side + 1]
        for ctx_word in left + right:
            # list(...) so an ndarray sublist is concatenated, not broadcast-added.
            # NOTE(review): the center id is prepended to its subword list; confirm whether
            # token2subwords[center] already contains the word itself (it would appear twice).
            subwords = [center] + list(token2subwords[center])
            train.append((subwords, ctx_word, 1))
            for _ in range(ns_rate):
                train.append((subwords, np.random.randint(vocab_size), 0))
    return train


# Input: token ids (one literal line), then window_size, vocab_size and ns_rate
# (one integer per line), then token2subwords as one literal line.
text = read_array()
window_size, vocab_size, ns_rate = (int(sys.stdin.readline().strip()) for _ in range(3))
token2subwords = read_list()  # plain nested lists: read_list does no NumPy conversion

result = generate_ft_sgns_samples(text=text, window_size=window_size, vocab_size=vocab_size,
                                  ns_rate=ns_rate, token2subwords=token2subwords)

print(repr(result))

In [None]:
import sys
import ast
import numpy as np


def parse_array(s):
    """Evaluate the Python literal in `s` and return it as a NumPy array."""
    value = ast.literal_eval(s)
    return np.array(value)

def read_array():
    """Read a single line from stdin and convert it with `parse_array`."""
    line = sys.stdin.readline()
    return parse_array(line)

def write_array(arr):
    """Print `arr` as the repr of its (nested) Python list form."""
    as_list = arr.tolist()
    print(repr(as_list))


def update_ft_weights(center_embeddings, context_embeddings, center_subwords, context_word, label, learning_rate):
    """
    Apply one SGD step of fastText skip-gram with negative sampling for a single pair.

    The center word is represented by the mean of its subword rows in `center_embeddings`.

    center_embeddings - VocabSize x EmbSize - subword rows updated in place
    context_embeddings - VocabSize x EmbSize - row `context_word` updated in place
    center_subwords - list of ints - list of identifiers of n-grams contained in center word
    context_word - int - identifier of context word
    label - 1 if context_word is real, 0 if it is negative
    learning_rate - float > 0 - size of gradient step

    returns the same (center_embeddings, context_embeddings) objects after the update
    """
    n_sub = len(center_subwords)

    # Mean of the subword vectors, kept as a (1, EmbSize) row.
    avg = np.zeros((1, center_embeddings.shape[1]))
    for sub_id in center_subwords:
        avg += center_embeddings[sub_id]
    avg /= n_sub

    ctx_vec = context_embeddings[context_word]
    delta = 1 / (1 + np.exp(-(avg @ ctx_vec))) - label

    # Each subword occurrence receives 1/n of the gradient of the averaged vector.
    sub_grad = delta * ctx_vec / n_sub
    for sub_id in center_subwords:
        center_embeddings[sub_id] -= learning_rate * sub_grad

    ctx_grad = delta * avg
    context_embeddings[context_word] -= learning_rate * ctx_grad.squeeze()
    return center_embeddings, context_embeddings

# Input: two embedding matrices and the center subword ids (one literal per line),
# then context_word and label (integers) and learning_rate (float), one per line.
center_embeddings = read_array()
context_embeddings = read_array()
center_subwords = read_array()
context_word, label = (int(sys.stdin.readline().strip()) for _ in range(2))
learning_rate = float(sys.stdin.readline().strip())

# update_ft_weights modifies both matrices in place, so its return value is not needed.
update_ft_weights(center_embeddings=center_embeddings, context_embeddings=context_embeddings,
                  center_subwords=center_subwords, context_word=context_word,
                  label=label, learning_rate=learning_rate)

write_array(center_embeddings)
write_array(context_embeddings)

In [11]:
import sys
import ast
import numpy as np
import scipy.sparse


def read_array():
    """Read one line from stdin and return the Python literal it holds (no NumPy conversion)."""
    line = sys.stdin.readline()
    return ast.literal_eval(line)

def write_array(arr):
    """Print `arr` as the repr of its (nested) Python list form."""
    as_list = arr.tolist()
    print(repr(as_list))


def generate_coocurrence_matrix(texts, vocab_size):
    """
    Count, for every ordered pair of distinct tokens, how many documents contain both.

    texts - list of lists of ints - i-th sublist contains identifiers of tokens in i-th document
    vocab_size - int - size of vocabulary

    Entry [i, j] (i != j) is the number of documents in which tokens i and j both appear.
    A document counts at most once per pair, however often the tokens repeat in it.
    The diagonal is zero. Token ids outside [0, vocab_size) are ignored.

    returns scipy.sparse.dok_matrix of shape (vocab_size, vocab_size), float64
    """
    from collections import Counter
    from itertools import permutations

    pair_counts = Counter()
    for text in texts:
        # Presence, not frequency, matters, so deduplicate tokens per document.
        present = {tok for tok in text if 0 <= tok < vocab_size}
        # Ordered pairs of distinct tokens: both (i, j) and (j, i), never (i, i).
        pair_counts.update(permutations(present, 2))

    dok = scipy.sparse.dok_matrix((vocab_size, vocab_size))
    for (i, j), count in pair_counts.items():
        dok[i, j] = count
    return dok

# Input: a list of documents (each a list of token ids) on one line, then vocab_size.
text = read_array()
vocab_size = int(sys.stdin.readline().strip())

result = generate_coocurrence_matrix(texts=text, vocab_size=vocab_size)

# The result is sparse; densify it before printing.
write_array(result.toarray())

SyntaxError: unexpected EOF while parsing (<unknown>, line 0)

In [None]:
import sys
import ast
import numpy as np


def parse_array(s):
    """Evaluate the Python literal in `s` and return it as a NumPy array."""
    value = ast.literal_eval(s)
    return np.array(value)

def read_array():
    """Read a single line from stdin and convert it with `parse_array`."""
    line = sys.stdin.readline()
    return parse_array(line)

def write_array(arr):
    """Print `arr` as the repr of its (nested) Python list form."""
    as_list = arr.tolist()
    print(repr(as_list))


def update_glove_weights(x, w, d, alpha, max_x, learning_rate):
    """
    Take one full-batch gradient step on a bias-free GloVe objective, updating `w` and `d` in place.

    The gradients below are those of
        J = sum_ij f(x_ij) * (log(1 + x_ij) - w_i . d_j) ** 2,
    with f(v) = (v / max_x) ** alpha for v <= max_x and 1 otherwise.

    x - square integer matrix VocabSize x VocabSize - coocurrence matrix
    w - VocabSize x EmbSize - first word vectors (overwritten in place)
    d - VocabSize x EmbSize - second word vectors (overwritten in place)
    alpha - float - power in weight smoothing function f
    max_x - int - maximum coocurrence count in weight smoothing function f
    learning_rate - positive float - size of gradient step

    returns None
    """
    weights = np.where(x <= max_x, (x / max_x) ** alpha, 1.0)
    residual = weights * (np.log1p(x) - w @ d.T)

    # Both gradients use the pre-update w and d.
    scaled = -2 * residual
    grad_w = scaled @ d
    grad_d = scaled.T @ w

    w[:] = w - learning_rate * grad_w
    d[:] = d - learning_rate * grad_d

# Input: co-occurrence matrix, w and d (one literal per line), then alpha (float),
# max_x (int) and learning_rate (float), one per line.
x, w, d = (read_array() for _ in range(3))
alpha = float(sys.stdin.readline().strip())
max_x = int(sys.stdin.readline().strip())
learning_rate = float(sys.stdin.readline().strip())

# update_glove_weights returns nothing; it overwrites w and d in place.
update_glove_weights(x=x, w=w, d=d, alpha=alpha, max_x=max_x, learning_rate=learning_rate)

write_array(w)
write_array(d)

In [None]:
import sys
import ast
import numpy as np


def parse_array(s):
    """Evaluate the Python literal in `s` and return it as a NumPy array."""
    value = ast.literal_eval(s)
    return np.array(value)

def read_array():
    """Read a single line from stdin and convert it with `parse_array`."""
    line = sys.stdin.readline()
    return parse_array(line)

def write_array(arr):
    """Print `arr` as the repr of its (nested) Python list form."""
    as_list = arr.tolist()
    print(repr(as_list))


def get_nearest(embeddings, query_word_id, get_n):
    """
    Find the words whose embeddings are most cosine-similar to a query word.

    embeddings - VocabSize x EmbSize - word embeddings (not modified)
    query_word_id - integer - id of query word to find most similar to
    get_n - integer - number of most similar words to retrieve

    The query word itself takes part in the ranking. Rows with zero norm get
    similarity 0 instead of NaN. Ties are broken by ascending word id.

    returns list of up to `get_n` tuples (word_id, similarity) with Python int/float values,
    sorted by descending order of similarity value
    """
    # Work on a float copy/view; the division below creates a new array, so the
    # caller's matrix is never mutated (the original in-place division also fails on ints).
    vectors = np.asarray(embeddings, dtype=float)
    # Normalize each ROW (axis=1), so the dot product below is cosine similarity.
    norms = np.linalg.norm(vectors, ord=2, axis=1, keepdims=True)
    norms[norms == 0] = 1.0  # avoid 0/0 for all-zero rows
    unit = vectors / norms
    similarities = unit @ unit[query_word_id]
    # Stable sort on negated scores: descending similarity, ties in ascending id order.
    order = np.argsort(-similarities, kind="stable")[:get_n]
    return [(int(idx), float(similarities[idx])) for idx in order]


# Input: embedding matrix (one literal line), then query_word_id and get_n, one integer per line.
embeddings = read_array()
query_word_id, get_n = (int(sys.stdin.readline().strip()) for _ in range(2))

result = get_nearest(embeddings=embeddings, query_word_id=query_word_id, get_n=get_n)

# NOTE(review): np.array over (id, similarity) pairs produces a float array,
# so ids are printed as e.g. 3.0 - confirm this is the expected output format.
write_array(np.array(result))

In [5]:
a = [[0.4924379807931911, 0.4372909146117835, 0.3786715543730712, 0.08765497537657796, 0.6756358041845987, 0.21918880613708924, 0.340282751187035, 0.7564088132708989], [0.44017230951695874, 0.1521570621825502, 0.5269119216892, 0.13969272399772448, 0.8905714313092886, 0.6855990837043952, 0.7329403885550506, 0.6967890760190271], [0.08209802635148533, 0.1572554699218176, 0.32713724494890817, 0.3282258214949071, 0.18745515454533546, 0.6684444254856369, 0.3821374793787291, 0.14405318337384954], [0.4685806535328395, 0.6823014264406443, 0.45242992021735506, 0.13462511393109822, 0.37643757313276727, 0.33711083265770425, 0.953787573113227, 0.9375271756567445], [0.39126114202292495, 0.4240577502243138, 0.45596962948649294, 0.8405515982747295, -0.0035639217032158305, 0.08033917718260508, 0.7064913228196669, 0.6149593432967865], [0.47541554931084073, 0.6053107399554578, 0.7731362614524417, 0.43745637956185424, 0.06569482940157491, 0.3140376497070242, 0.17562601577455203, 0.06255425666405454], [0.4623567291387256, 0.6292685831821194, 0.35307170833144663, 0.48720150493272496, 0.27648124218553927, 0.40031870928193414, 0.02599896525632328, 0.6563730701085458], [0.7006508390687906, 0.9514452459838603, 0.2451080863945575, 0.5123482317902458, 0.0073772504296028, 0.04790725982848376, 0.7551912304218487, 0.05519748500067578], [0.7949394482871045, 0.6139101679060078, 0.9636668795699912, 0.20793187470932273, 0.5691843546570802, 0.64175464297284, 0.5316441135537809, 0.6480149378052954], [0.9602567071597025, 0.6144154154415676, 0.6354348353858741, 0.8176817682689422, 0.4365203872332607, 0.6018591595887026, 0.24588600970700025, 0.8417956150068155], [0.27828446555932296, 0.4807798024399753, 0.28161693377865327, 0.5774428203991377, 0.18228473976288317, 0.16857605985792, 0.4559900707914125, 0.9441427490521547], [0.7414879823479924, 0.8647281941270776, 0.35393371610868524, 0.2804504748200687, 0.4701667907804358, 0.8861670917859702, 0.6892073557636252, 0.1330433396753354], [0.20061898730429373, 
0.13089938905380893, 0.2180695062782767, 0.16685666837275925, 0.7485275811921536, 0.0009222052073699638, 0.8764827495937608, 0.7118160874636742], [0.4102880272659505, 0.6911504017206197, 0.674527579437148, 0.9585605434693513, 0.9184313869364793, 0.13068786160259238, 0.7817381372985575, 0.11406856860587689], [0.9637996777388024, 0.9096498590978186, 0.09897429193559693, 0.6229947983032628, 0.6103840977039744, 0.07802796500873965, 0.27955682331764764, 0.7458814824074171], [0.7543834769837153, 0.6104516650024087, 0.350083552653517, 0.23679074594453386, 0.26123401278746905, 0.21552344000610668, 0.16486681789984003, 0.5376355458748053], [0.970347909471565, 0.531785083567912, 0.6761950893288675, 0.84233111602816, 0.7717486528345849, 0.19409526576427183, 0.6650316349624878, 0.9772140892554211], [0.2447519981151074, 0.11727034778098522, 0.4645834487138182, 0.0953191957887265, 0.18973178771818744, 0.7330847257058634, 0.1697697394333707, 0.08721362050017656], [0.5991518757961101, 0.15871166110849755, 0.5804090119102789, 0.27786186874031715, 0.08935690910963434, 0.30450146770449615, 0.6783458480620492, 0.17629782226896706], [0.4724575749096397, 0.7272444260638546, 0.45727256576998476, 0.13904986666789265, 0.33443609088716, 0.7534586544539846, 0.6994487600617424, 0.06700176145670023], [0.48069977927688057, 0.28453027574930045, 0.6026090887574915, 0.5395475284054398, 0.43738833425061685, 0.44290239322523606, 0.9159159974831983, 0.5758008552741829]]

In [6]:
b = [[0.4924379807931911, 0.4372909146117835, 0.3786715543730712, 0.08765497537657796, 0.6756358041845987, 0.21918880613708924, 0.340282751187035, 0.7564088132708989], [0.44017230951695874, 0.1521570621825502, 0.5269119216892, 0.13969272399772448, 0.8905714313092886, 0.6855990837043952, 0.7329403885550506, 0.6967890760190271], [0.08209802635148533, 0.1572554699218176, 0.32713724494890817, 0.3282258214949071, 0.18745515454533546, 0.6684444254856369, 0.3821374793787291, 0.14405318337384954], [0.4685806535328395, 0.6823014264406443, 0.45242992021735506, 0.13462511393109822, 0.37643757313276727, 0.33711083265770425, 0.953787573113227, 0.9375271756567445], [0.3747200281624562, 0.2957023888332825, 0.4268682014731547, 0.7845482742542396, -0.20545128867001977, -0.10572156027398755, 0.556371523841753, 0.5072352412236689], [0.458874435450372, 0.4769553785644265, 0.7440348334391035, 0.3814530555413644, -0.13619253756522903, 0.12797691225043156, 0.025506216796638226, -0.045169845409062964], [0.4623567291387256, 0.6292685831821194, 0.35307170833144663, 0.48720150493272496, 0.27648124218553927, 0.40031870928193414, 0.02599896525632328, 0.6563730701085458], [0.7006508390687906, 0.9514452459838603, 0.2451080863945575, 0.5123482317902458, 0.0073772504296028, 0.04790725982848376, 0.7551912304218487, 0.05519748500067578], [0.7949394482871045, 0.6139101679060078, 0.9636668795699912, 0.20793187470932273, 0.5691843546570802, 0.64175464297284, 0.5316441135537809, 0.6480149378052954], [0.9602567071597025, 0.6144154154415676, 0.6354348353858741, 0.8176817682689422, 0.4365203872332607, 0.6018591595887026, 0.24588600970700025, 0.8417956150068155], [0.27828446555932296, 0.4807798024399753, 0.28161693377865327, 0.5774428203991377, 0.18228473976288317, 0.16857605985792, 0.4559900707914125, 0.9441427490521547], [0.7249468684875237, 0.7363728327360463, 0.324832288095347, 0.22444715079957883, 0.26827942381363185, 0.7001063543293775, 0.5390875567857114, 0.02531923760221791], [0.20061898730429373, 
0.13089938905380893, 0.2180695062782767, 0.16685666837275925, 0.7485275811921536, 0.0009222052073699638, 0.8764827495937608, 0.7118160874636742], [0.4102880272659505, 0.6911504017206197, 0.674527579437148, 0.9585605434693513, 0.9184313869364793, 0.13068786160259238, 0.7817381372985575, 0.11406856860587689], [0.9637996777388024, 0.9096498590978186, 0.09897429193559693, 0.6229947983032628, 0.6103840977039744, 0.07802796500873965, 0.27955682331764764, 0.7458814824074171], [0.7543834769837153, 0.6104516650024087, 0.350083552653517, 0.23679074594453386, 0.26123401278746905, 0.21552344000610668, 0.16486681789984003, 0.5376355458748053], [0.970347909471565, 0.531785083567912, 0.6761950893288675, 0.84233111602816, 0.7717486528345849, 0.19409526576427183, 0.6650316349624878, 0.9772140892554211], [0.2447519981151074, 0.11727034778098522, 0.4645834487138182, 0.0953191957887265, 0.18973178771818744, 0.7330847257058634, 0.1697697394333707, 0.08721362050017656], [0.5991518757961101, 0.15871166110849755, 0.5804090119102789, 0.27786186874031715, 0.08935690910963434, 0.30450146770449615, 0.6783458480620492, 0.17629782226896706], [0.4724575749096397, 0.7272444260638546, 0.45727256576998476, 0.13904986666789265, 0.33443609088716, 0.7534586544539846, 0.6994487600617424, 0.06700176145670023], [0.48069977927688057, 0.28453027574930045, 0.6026090887574915, 0.5395475284054398, 0.43738833425061685, 0.44290239322523606, 0.9159159974831983, 0.5758008552741829]]

In [10]:
np.logical_not(np.isclose(a, b))  # True where a and b differ beyond np.isclose's default tolerances

array([[False, False, False, False, False, False, False, False],
       [False, False, False, False, False, False, False, False],
       [False, False, False, False, False, False, False, False],
       [False, False, False, False, False, False, False, False],
       [ True,  True,  True,  True,  True,  True,  True,  True],
       [ True,  True,  True,  True,  True,  True,  True,  True],
       [False, False, False, False, False, False, False, False],
       [False, False, False, False, False, False, False, False],
       [False, False, False, False, False, False, False, False],
       [False, False, False, False, False, False, False, False],
       [False, False, False, False, False, False, False, False],
       [ True,  True,  True,  True,  True,  True,  True,  True],
       [False, False, False, False, False, False, False, False],
       [False, False, False, False, False, False, False, False],
       [False, False, False, False, False, False, False, False],
       [False, False, Fal