In [53]:
import numpy as np
import tensorflow as tf
import tqdm.notebook as tqdm
import os
import json
from gensim.models import Word2Vec, FastText

In [54]:
def read(path_with_everything: str):
    """Load flattened statement-level ASTs and tags for every submission.

    Directory layout read by this function::

        <path_with_everything>/<contest>/meta.json
        <path_with_everything>/<contest>/java/<Id>.java.ast.stm.flat

    ``meta.json`` must contain a ``"Submissions"`` mapping whose values carry
    ``"Id"`` and ``"Tags"``. Submissions without a ``.flat`` file are skipped,
    and non-directory entries directly under ``path_with_everything`` are
    ignored.

    Each ``.flat`` file holds statements separated by blank lines. Each line of
    a statement is ``token<TAB>space-separated child indices`` (the child list
    may be empty).

    Args:
        path_with_everything: root directory containing one folder per contest.

    Returns:
        ``(codes, tags)``. ``codes[k]`` is the list of statements of the k-th
        loaded submission; each statement is a list of ``(token, children)``
        pairs where ``children`` is a tuple of ints. ``tags[k]`` is the
        ``"Tags"`` value of that same submission.

    Lines that fail to parse are printed (file path, statement, line) and
    skipped; the rest of the statement is still kept.
    """
    codes = []
    tags = []
    # os.listdir returns entries in arbitrary order; sorting makes the order of
    # codes/tags (and everything derived from them) reproducible.
    for contest_folder in sorted(os.listdir(path_with_everything)):
        contest_path = os.path.join(path_with_everything, contest_folder)
        if not os.path.isdir(contest_path):
            continue
        with open(os.path.join(contest_path, 'meta.json'), 'r', encoding='utf-8') as meta_file:
            meta_json = json.load(meta_file)
        submissions = meta_json['Submissions']
        for s in submissions.values():
            flattened_tree_path = os.path.join(contest_path, 'java', f'{s["Id"]}.java.ast.stm.flat')
            if not os.path.exists(flattened_tree_path):
                continue

            with open(flattened_tree_path, 'r', encoding='utf-8') as flat_file:
                statements = flat_file.read().strip('\n').split('\n\n')
            # Append the tag only once the file was read, so codes and tags stay aligned.
            tags.append(s['Tags'])
            codes.append([])
            for stm in statements:
                codes[-1].append([])
                for i in stm.split('\n'):
                    try:
                        token, children = i.split('\t')
                        children = tuple(map(int, children.split())) if children else ()
                    except ValueError:
                        # Wrong number of tab-separated fields or a non-integer
                        # child index: report the offending line and skip it.
                        print(flattened_tree_path, stm, i, sep='\n\n')
                        continue
                    codes[-1][-1].append((token, children))
    return codes, tags

In [55]:
# Parse every contest's flattened ASTs (codes) and the matching problem tags.
codes, tags = read(path_with_everything='utils/samples/Codeforces/Yoink-Data')

In [56]:
# Every distinct token appearing in any statement of any submission.
vocab = {
    token
    for code in codes
    for statement in code
    for token, _ in statement
}
vocab

{'isStart',
 'testHash',
 'resa',
 'sTree',
 'dpNext',
 '17L',
 'B_493',
 '.5',
 '"fir="',
 'D1005',
 'intIn',
 'A_493',
 'nextDouble',
 'mx_ind',
 '2833',
 'salida',
 'diez',
 'nastiaandagoodarray',
 'mostTimes',
 'currBest',
 'cdf720a',
 'stack_r',
 'tipolimite',
 'FenvikMax',
 '100003L',
 '"1 3 4"',
 'perf',
 'STD',
 'ADD_ARRAY_LIST',
 'cutCosts',
 'YaponskiiKrosvord',
 '1000000411',
 'Larr',
 'bit_pos',
 'TIME_END',
 'scanString',
 'hz',
 'firstapp',
 'alphabet',
 'CYetAnotherCardDeck',
 'upVote',
 'afteroff',
 'Perfectly_Imperfect_Array',
 'finalCuts',
 'Task_1520',
 '13456789',
 '44444444',
 'nb1',
 'pomXor',
 'nel',
 'idm',
 'baki',
 'prYN',
 '50000',
 'msub',
 'itemCount',
 'arri',
 'pts',
 'bestAdvance',
 'removeDuplicates',
 'pre422',
 'subString1',
 'LinkedList',
 'TaskE1',
 '" r:"',
 'RANGE',
 'xDec',
 '0.000000000005d',
 'rightsheeps',
 'toD',
 'readTree',
 '521',
 '109',
 'outputReader',
 '0.000000000000000009d',
 'swapArray',
 'r1',
 'Shirts',
 'A_TanyaAndStairways',
 'l

In [57]:
# Freeze the vocabulary into a list with a deterministic order. Iteration order
# of a set of str changes between interpreter runs (str hashing is randomized),
# so `list(vocab)` would reshuffle token ids and embedding rows on every restart.
vocab = sorted(vocab)
# Row index of each token in the embedding matrices built from `vocab`.
token_to_id = {token: idx for idx, token in enumerate(vocab)}

In [58]:
# Embedding width; also the length of the zero vector used when a token has no embedding.
dim: int = 192

In [59]:
# Load the pretrained Word2Vec model and build a matrix whose row i is the
# vector for vocab[i] (the same indexing as token_to_id).
w2v_model = Word2Vec.load('utils/java.w2v')
w2v_wv = w2v_model.wv
# Fail fast if the configured width disagrees with the saved model; otherwise
# np.array below would receive rows of different lengths.
assert w2v_wv.vector_size == dim, f'w2v vector_size {w2v_wv.vector_size} != dim {dim}'
# Zero fallback in the model's own dtype, so the matrix dtype does not depend on
# whether some token happens to be missing from the model.
w2v_fallback = np.zeros((dim,), dtype=w2v_wv.vectors.dtype)
w2v_embeddings = np.array([w2v_wv[token] if token in w2v_wv else w2v_fallback for token in vocab])

In [60]:
# Load the pretrained FastText model and build a matrix whose row i is the
# vector for vocab[i] (the same indexing as token_to_id).
ft_model = FastText.load('utils/java.ft')
ft_wv = ft_model.wv
# Fail fast if the configured width disagrees with the saved model; otherwise
# np.array below would receive rows of different lengths.
assert ft_wv.vector_size == dim, f'ft vector_size {ft_wv.vector_size} != dim {dim}'
# Zero fallback in the model's own dtype, so the matrix dtype does not depend on
# whether some token happens to be missing from the model.
ft_fallback = np.zeros((dim,), dtype=ft_wv.vectors.dtype)
ft_embeddings = np.array([ft_wv[token] if token in ft_wv else ft_fallback for token in vocab])

In [61]:
# Show the first ten rows of the Word2Vec matrix, then of the FastText matrix.
for _embedding_matrix in (w2v_embeddings, ft_embeddings):
    print(_embedding_matrix[:10])

[[ 0.30367     0.10270511 -0.4131966  ... -0.0845805  -0.01419335
  -0.17824554]
 [ 0.05798344 -0.05268424  0.10425927 ...  0.09329583  0.12661745
  -0.04027711]
 [-0.12143497 -0.05300783  0.02445679 ...  0.06512076 -0.01052139
  -0.04133365]
 ...
 [ 0.07665545  0.00831272 -0.0460017  ...  0.08507466  0.03281435
   0.19442406]
 [ 0.02235229 -0.09181054 -0.01255795 ...  0.01226202  0.03656536
   0.10051543]
 [-0.06773315 -0.04321453  0.0452197  ...  0.05534829 -0.02168257
   0.03712261]]
[[ 0.2969489   0.16750643  1.1966213  ... -0.5208342  -0.38895652
  -1.9641588 ]
 [-0.2575132   0.23734853 -0.12299175 ...  0.34764686 -0.925342
  -1.0036569 ]
 [-0.7809004   0.5879315   0.43735573 ...  1.0293046  -2.060124
  -1.3344779 ]
 ...
 [ 0.01664163 -0.00615533 -0.00461163 ...  0.19570808 -0.02923601
   0.0089868 ]
 [-0.09287308 -0.29611906  0.0054104  ... -0.16896923 -0.06315529
  -0.05832468]
 [ 0.19434837  0.7019075   0.585249   ... -0.12767608 -0.05076283
  -0.6632441 ]]


In [62]:
# For each model: how many vocab tokens pass the `in` check used when building
# the embedding matrices, next to the vocab size.
# NOTE(review): for FastText, `in` may also be true for tokens reachable only via
# character n-grams, so this may overstate true in-vocabulary coverage — confirm.
for _kv in (w2v_wv, ft_wv):
    _covered = sum(1 for token in vocab if token in _kv)
    print(_covered, len(vocab))

21777 21777
21777 21777


In [63]:
# Number of keys stored in each trained model's vocabulary (Word2Vec first, then FastText).
for _gensim_model in (w2v_model, ft_model):
    print(len(_gensim_model.wv.key_to_index))

21777
21777
