In [None]:
# Later cells open the tokenizer and checkpoint through relative "data/..." paths,
# so the kernel's working directory must be the directory containing `data/`.
!pwd
# One-time setup (uncomment if needed). `%pip` installs into the environment of the
# running kernel, whereas `!pip` may target a different interpreter on PATH.
# %pip install -e .
# !npm install

In [None]:
import os
import sys

# Make the project's `models` package importable. The default is the original
# author's checkout; set REPRESENTJS_SRC to point at a checkout elsewhere.
REPRESENTJS_SRC = os.environ.get("REPRESENTJS_SRC", "/work/paras/representjs/representjs")
sys.path.append(REPRESENTJS_SRC)

import torch
from pathlib import Path
from models.code_mlm import CodeMLM
import sentencepiece as spm

sp = spm.SentencePieceProcessor()
sp.Load("data/codesearchnet_javascript/csnjs_8k_9995p_unigram_url.model")

# PieceToId returns the unknown-token id for pieces missing from the vocabulary,
# which would silently turn padding/masking into <unk>; verify they exist.
pad_id = sp.PieceToId("[PAD]")
mask_id = sp.PieceToId("[MASK]")
assert pad_id != sp.unk_id(), "[PAD] is not in the SentencePiece vocabulary"
assert mask_id != sp.unk_id(), "[MASK] is not in the SentencePiece vocabulary"

ckpt_file = Path("data/runs/22006_roberta_no_weight_decay_1590101821/ckpt_pretrain_ep0003_step0050000.pth").resolve()
# Path.resolve() does not require the file to exist; fail here with a clear message
# instead of later inside torch.load.
assert ckpt_file.exists(), f"checkpoint not found: {ckpt_file}"

In [None]:
# Build the masked LM sized to the tokenizer's vocabulary, move it to the GPU,
# restore the weights stored under the checkpoint's 'model_state_dict' key, and
# switch to evaluation mode for inference.
model = CodeMLM(sp.GetPieceSize(), pad_id=pad_id)
model = model.cuda()
checkpoint = torch.load(ckpt_file)
model.load_state_dict(checkpoint['model_state_dict'])
del checkpoint  # the weights now live in `model`; drop the loaded copy
model.eval()

In [None]:
# Sanity-check the masked LM: mask one token of a small JS snippet and list the
# model's most likely replacements for that position.
string = """const x = function (z) {
    var x = 1;
    for (var i = 0; i < 10; i++) {
        x += i;
    }
    return z + x;
}"""

mask_pos = 14  # index into `seq`, which starts with <s>
TOP_K = 10     # number of candidate pieces to print

seq = [sp.PieceToId("<s>")] + sp.EncodeAsIds(string) + [sp.PieceToId("</s>")]
# seq[0] is <s> and seq[-1] is </s>; masking those is meaningless, and any index
# past the end would fail below with an uninformative IndexError.
assert 0 < mask_pos < len(seq) - 1, f"mask_pos must be in [1, {len(seq) - 2}], got {mask_pos}"

masked_seq = seq[:]
masked_seq[mask_pos] = mask_id

print(sp.DecodeIds(seq))
print(sp.DecodeIds(masked_seq))
print(f"masked piece: {sp.IdToPiece(seq[mask_pos])}")

with torch.no_grad():
    logits = model(torch.LongTensor(masked_seq).cuda().unsqueeze(0))
print("\n")

# NOTE(review): indexing assumes logits is (batch, seq_len, vocab) — confirm
# against CodeMLM.forward.
topk_vals, topk_idx = logits[0, mask_pos].topk(TOP_K, largest=True)
probs = logits[0, mask_pos].softmax(dim=0)
for idx in topk_idx.tolist():
    print(f"{sp.IdToPiece(idx)}\t{probs[idx].item():.3f}")