In [16]:
from __future__ import unicode_literals, print_function, division
from io import open
import unicodedata
import string
import re
import random
import pickle
import json

import numpy as np
import nltk
from nltk.translate.bleu_score import SmoothingFunction
import torch
import torch.nn as nn
from torch import optim
import torch.nn.functional as F

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [17]:
# Print floating-point tensor values with 20 digits of precision.
torch.set_printoptions(precision=20)

In [18]:
# Show which device was selected above.
print(device)

cuda


In [19]:
# Width of the embedding and of the LSTM hidden state in EncoderRNN.
# NOTE(review): load_checkpoint reads "checkpoint_LSTM_16.pth", so this
# presumably has to stay 16 for the saved weights to fit -- confirm.
hidden_size = 16

In [20]:
# Reserved vocabulary indices: Lang.index2word maps 0 to "SOS" and 1 to "EOS".
# EOS_token is appended to every index sequence in tensorFromSentence.
SOS_token = 0
EOS_token = 1

class Lang:
    """Vocabulary of one language: word <-> index maps plus occurrence counts.

    Indices 0 and 1 are pre-assigned to the "SOS" and "EOS" markers, so the
    first word that is added receives index 2.
    """

    def __init__(self, name):
        self.name = name
        self.word2index = {}
        self.word2count = {}
        self.index2word = {0: "SOS", 1: "EOS"}
        self.n_words = 2  # SOS and EOS are already counted

    def addSentence(self, sentence):
        """Add every token of ``sentence``, splitting on single spaces."""
        for token in sentence.split(' '):
            self.addWord(token)

    def addWord(self, word):
        """Count one occurrence of ``word``; give it the next index if new."""
        if word in self.word2index:
            self.word2count[word] += 1
            return

        # First sighting: the next free index is the current vocabulary size.
        fresh_index = self.n_words
        self.word2index[word] = fresh_index
        self.index2word[fresh_index] = word
        self.word2count[word] = 1
        self.n_words = fresh_index + 1

In [48]:
def default_load_files_for_model(seq_lang_path='simple_lang.pkl',
                                 comment_lang_path='comment_lang.pkl',
                                 train_pairs_path='train_pairs.pkl'):
    """Unpickle the two vocabularies and the training pairs.

    The defaults are the file names this notebook has always read from the
    working directory, so calling the function without arguments behaves as
    before; the parameters only make the locations overridable.

    Args:
        seq_lang_path: Pickle file for the object returned as ``seq_lang``.
        comment_lang_path: Pickle file for the object returned as
            ``comment_lang``.
        train_pairs_path: Pickle file for the object returned as
            ``train_pairs``.

    Returns:
        Tuple ``(seq_lang, comment_lang, train_pairs)``.

    Raises:
        OSError: If one of the files cannot be opened.

    Warning:
        ``pickle.load`` can execute arbitrary code; only point this at files
        from a trusted source.
    """
    loaded = []
    # Files are read in the same order as the returned tuple.
    for path in (seq_lang_path, comment_lang_path, train_pairs_path):
        with open(path, 'rb') as f:
            loaded.append(pickle.load(f))

    seq_lang, comment_lang, train_pairs = loaded
    return seq_lang, comment_lang, train_pairs

In [49]:
def load_data(pkl_file_name):
    """Load the default model files plus one pickled set of test pairs.

    Args:
        pkl_file_name: Path of the pickle file containing the test pairs.

    Returns:
        Tuple ``(seq_lang, comment_lang, train_pairs, test_pairs)``; the first
        three come from ``default_load_files_for_model()``.
    """
    source_lang, target_lang, training_pairs = default_load_files_for_model()

    with open(pkl_file_name, 'rb') as handle:
        evaluation_pairs = pickle.load(handle)

    return source_lang, target_lang, training_pairs, evaluation_pairs

In [50]:
# Load either bug or nobug for a single run
# seq_lang, comment_lang, train_pairs, test_pairs = load_data("test_pairs_bug.pkl")

In [59]:
# Load the vocabularies and training pairs plus ONE test set per run: this run
# uses the "nobug" pairs; pass "test_pairs_bug.pkl" instead for the other set.
seq_lang, comment_lang, train_pairs, test_pairs = load_data("test_pairs_nobug.pkl")

In [51]:
class EncoderRNN(nn.Module):
    """Embedding followed by a single-layer LSTM, fed one token per call.

    Args:
        input_size: Number of rows in the embedding table (vocabulary size).
        hidden_size: Width of both the embedding vectors and the LSTM state.
    """

    def __init__(self, input_size, hidden_size):
        super(EncoderRNN, self).__init__()
        self.hidden_size = hidden_size

        # The embedding width equals the LSTM input and hidden width.
        self.embedding = nn.Embedding(input_size, hidden_size)
        self.lstm = nn.LSTM(hidden_size, hidden_size)

    def forward(self, input, hidden):
        """Advance the LSTM by one step.

        Args:
            input: Tensor holding a single token index.
            hidden: ``(h, c)`` tuple as returned by ``initHidden`` or by a
                previous call.

        Returns:
            ``(output, hidden)`` exactly as produced by ``nn.LSTM``.
        """
        # Reshape to (seq_len=1, batch=1, hidden_size): one token, one sequence.
        embedded = self.embedding(input).view(1, 1, -1)
        output, hidden = self.lstm(embedded, hidden)
        return output, hidden

    def initHidden(self):
        """Return a zero ``(h, c)`` state of shape ``(1, 1, hidden_size)``.

        The zeros are allocated with the device and dtype of the module's own
        weights, so the state always matches the module wherever it has been
        moved, instead of depending on a module-level ``device`` global.
        """
        weight = self.embedding.weight
        return (weight.new_zeros(1, 1, self.hidden_size),
                weight.new_zeros(1, 1, self.hidden_size))

In [52]:
def indexesFromSentence(lang, sentence):
    """Map the space-separated words of ``sentence`` to vocabulary indices.

    Words missing from ``lang.word2index`` are skipped, so the result can be
    shorter than the sentence (or empty).
    """
    vocab = lang.word2index
    return [vocab[token] for token in sentence.split(' ') if token in vocab]

def tensorFromSentence(lang, sentence):
    """Encode ``sentence`` as a column tensor of indices ending in EOS_token.

    Returns:
        A ``torch.long`` tensor of shape ``(n, 1)`` on ``device``, where ``n``
        is the number of known words plus one for the trailing EOS_token.
    """
    token_ids = indexesFromSentence(lang, sentence)
    token_ids.append(EOS_token)
    flat = torch.tensor(token_ids, dtype=torch.long, device=device)
    return flat.view(-1, 1)

def tensorsFromPair(code, comment):
    """Encode a (code, comment) pair of sentences as two index tensors.

    ``code`` is encoded with the ``seq_lang`` vocabulary and ``comment`` with
    the ``comment_lang`` vocabulary.

    Returns:
        Tuple ``(input_tensor, target_tensor)``.
    """
    return (
        tensorFromSentence(seq_lang, code),
        tensorFromSentence(comment_lang, comment),
    )

In [53]:
# Arbitrary placeholder value; it is only used to construct the SGD optimizers below.
learning_rate = 0.007

In [54]:
# One encoder per vocabulary, each sized to that vocabulary's word count.
encoder_ast = EncoderRNN(seq_lang.n_words, hidden_size).to(device)
encoder_comment = EncoderRNN(comment_lang.n_words, hidden_size).to(device)

# Plain SGD optimizers, one per encoder.
ast_optimizer = optim.SGD(encoder_ast.parameters(), lr=learning_rate)
comment_optimizer = optim.SGD(encoder_comment.parameters(), lr=learning_rate)

In [55]:
def load_checkpoint(checkpoint_path="checkpoint_LSTM_16.pth"):
    """Rebuild both encoders from a saved checkpoint.

    Args:
        checkpoint_path: File to read. Defaults to the file name this notebook
            has always used, so calling without arguments behaves as before.

    Returns:
        Tuple ``(encoder_ast, ast_optimizer, encoder_comment,
        comment_optimizer)``. The encoders are fresh ``EncoderRNN`` instances
        on ``device`` with the saved weights loaded; the optimizers are the
        values stored under ``'ast_optimizer'`` / ``'comment_optimizer'``.

    Raises:
        KeyError: If the checkpoint lacks one of the expected entries.
        RuntimeError: If the saved weights do not fit encoders built with the
            current vocabulary sizes and ``hidden_size``.
    """
    # map_location keeps a checkpoint that was saved on a GPU loadable when
    # `device` is the CPU; without it torch.load tries to restore the tensors
    # onto the device they were saved from.
    checkpoint = torch.load(checkpoint_path, map_location=device)

    encoder_ast = EncoderRNN(seq_lang.n_words, hidden_size).to(device)
    encoder_ast.load_state_dict(checkpoint['encoder_ast_state_dict'])
    # NOTE(review): this value is returned as stored. If it is a pickled
    # optimizer object it presumably still references the parameters of the
    # encoder that was saved, not of `encoder_ast` above -- confirm before
    # using it for further training.
    ast_optimizer = checkpoint['ast_optimizer']

    encoder_comment = EncoderRNN(comment_lang.n_words, hidden_size).to(device)
    encoder_comment.load_state_dict(checkpoint['encoder_comment_state_dict'])
    # NOTE(review): same caveat as for ast_optimizer.
    comment_optimizer = checkpoint['comment_optimizer']

    return encoder_ast, ast_optimizer, encoder_comment, comment_optimizer

In [56]:
# Rebind the encoders and optimizers to the objects restored from the checkpoint.
encoder_ast,ast_optimizer,encoder_comment,comment_optimizer = load_checkpoint()

In [57]:
def result(code, comment):
    """Score how close a code sentence and a comment sentence are.

    Each sentence is fed token by token through its own encoder
    (``encoder_ast`` for ``code``, ``encoder_comment`` for ``comment``). The
    score is ``exp(-d)``, where ``d`` is the summed pairwise distance between
    the first element of each encoder's final hidden tuple.

    Returns:
        The similarity as a Python float; gradients are not tracked.
    """
    with torch.no_grad():
        code_state = encoder_ast.initHidden()
        text_state = encoder_comment.initHidden()

        code_tokens, text_tokens = tensorsFromPair(code, comment)

        # Only the running state is needed; the per-step outputs are dropped.
        for token in code_tokens:
            _, code_state = encoder_ast(token, code_state)

        for token in text_tokens:
            _, text_state = encoder_comment(token, text_state)

        code_h, text_h = code_state[0], text_state[0]
        gap = F.pairwise_distance(code_h, text_h).sum()

        return torch.exp(-gap).item()

In [60]:
# Print one similarity score per test pair (element 0 = code, element 1 = comment).
for pair in test_pairs:
    code_text, comment_text = pair[0], pair[1]
    print(result(code_text, comment_text))

0.20636115968227386
0.35180291533470154
0.3796302378177643
0.7856696248054504
0.2378607988357544
0.8512691259384155
0.6956254243850708
0.8506895899772644
0.046975892037153244
0.8261533379554749
0.8459100127220154
0.0393519252538681
0.752209484577179
0.24426069855690002
0.3816138505935669
0.219221830368042
0.2946721017360687
0.023530835285782814
0.3536733388900757
0.24426069855690002
0.47069844603538513
0.19840958714485168
0.2946721017360687
0.8554349541664124
0.8123858571052551
0.46946626901626587
0.14699143171310425
0.7895798087120056
0.2331356406211853
0.447043240070343
0.01882416568696499
0.7525233626365662
0.2331356406211853
0.7807722687721252
0.5560671091079712
0.794459342956543
0.7137780785560608
0.5310855507850647
0.8512382507324219
0.018832921981811523
0.7558750510215759
0.2944118082523346
0.5786951184272766
0.9142813086509705
0.5506317019462585
0.5887378454208374
0.9558722376823425
0.47184041142463684
0.5524036884307861
0.4193935990333557
0.13008731603622437
0.568858802318573
