In [38]:
# Copyright (c) Microsoft Corporation. 
# Licensed under the MIT license.

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch
from torch.autograd import Variable
import copy
from torch import Tensor
from typing import Optional, Any, Union, Callable
from torch.nn.modules.activation import MultiheadAttention
from torch.nn.modules.linear import Linear
from torch.nn.modules.dropout import Dropout
from torch.nn.modules.normalization import LayerNorm
from torch.nn.modules.container import ModuleList



# Borrowed from https://github.com/facebookresearch/CodeGen/blob/main/codegen_sources/model/src/model/transformer.py
def Embedding(num_embeddings, embedding_dim, padding_idx=None):
    """Build an ``nn.Embedding`` with scaled-normal initialisation.

    Weights are drawn from N(0, embedding_dim ** -0.5). When ``padding_idx``
    is given, that row is reset to zeros after the random initialisation.

    Returns:
        The initialised ``nn.Embedding`` module.
    """
    layer = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx)
    init_std = embedding_dim ** -0.5
    nn.init.normal_(layer.weight, mean=0, std=init_std)
    if padding_idx is None:
        return layer
    # The normal init above overwrote the padding row; zero it again.
    nn.init.constant_(layer.weight[padding_idx], 0)
    return layer


class Seq2Seq(nn.Module):
    """
        Build Sequence-to-Sequence model with an auxiliary structure-decoding
        (data-flow link prediction) head on top of the decoder states.

        Parameters:

        * `encoder`- encoder of seq2seq model. e.g. roberta
        * `decoder`- decoder of seq2seq model. e.g. transformer
        * `config`- configuration of encoder model.
        * `beam_size`- beam size for beam search.
        * `max_length`- max length of target for beam search.
        * `sos_id`- start of symbol ids in target for beam search.
        * `eos_id`- end of symbol ids in target for beam search.
        * `mask_id`- stored on the instance; not used by the code in this class.
        * `pad_id`- stored on the instance; not used by the code in this class.
        * `n_langs`- number of rows of the language embedding table.
        * `sd_config`- size/type of the structure-decoding head:
          `'<N>div'` uses hidden_size//N dimensions, `'<N>abs'` uses N
          dimensions, `'<N>absasym'` uses N dimensions plus an asymmetric
          bilinear form with two bias projections.
    """
    def __init__(self, encoder,decoder,config,beam_size=None,max_length=None,sos_id=None,eos_id=None,mask_id=None,pad_id=None,n_langs=1, \
                sd_config='16absasym'):
        super(Seq2Seq, self).__init__()
        self.encoder = encoder
        self.decoder=decoder
        self.config=config
        # Lower-triangular ones; sliced to build causal masks for the decoder.
        self.register_buffer("bias", torch.tril(torch.ones(2048, 2048)))
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.sd_asym = False
        if sd_config.endswith('div'):
            self.sd_dim = config.hidden_size//int(sd_config[:-3])
        elif sd_config.endswith('abs'):
            self.sd_dim = int(sd_config[:-3])
        elif sd_config.endswith('absasym'):
            self.sd_dim = int(sd_config[:-7])
            self.sd_asym = True
            self.sd_weight = nn.Linear(self.sd_dim,self.sd_dim,bias=False)
            self.sd_b1 = nn.Linear(self.sd_dim,1,bias=False)
            self.sd_b2 = nn.Linear(self.sd_dim,1,bias=False)
        self.lsm = nn.LogSoftmax(dim=-1)
        self.tie_weights()

        self.beam_size=beam_size
        self.max_length=max_length
        self.sos_id=sos_id # CLS
        self.eos_id=eos_id # SEP
        self.mask_id=mask_id
        self.pad_id=pad_id

        self.lang_embeddings = Embedding(n_langs, config.hidden_size)
        self.eps=1e-10
        # Defined up front so that calling forward() before set_mode() reports
        # 'Unknown value for mode.' instead of a missing-attribute error.
        self.mode = None

    def set_mode(self, mode):
        """Select which `forward_*` method `forward` dispatches to."""
        self.mode = mode

    def _tie_or_clone_weights(self, first_module, second_module):
        """ Tie or clone module weights depending on whether we are using TorchScript or not
        """
        if self.config.torchscript:
            first_module.weight = nn.Parameter(second_module.weight.clone())
        else:
            first_module.weight = second_module.weight

    def tie_weights(self):
        """ Make sure we are sharing the input and output embeddings.
            Export to TorchScript can't handle parameter sharing so we are cloning them instead.
        """
        self._tie_or_clone_weights(self.lm_head, self.encoder.embeddings.word_embeddings)

    def forward(self, inputs):
        """Dispatch `inputs` to the `forward_*` method selected via `set_mode`.

        NOTE(review): only `forward_translation_structdec` is defined in this
        class as shown here; the other modes need methods supplied elsewhere.
        """
        if self.mode=='translation':
            return self.forward_translation(inputs)
        elif self.mode=='translation_structdec':
            return self.forward_translation_structdec(inputs)
        elif self.mode=='structdec':
            return self.forward_structdec(inputs)
        elif self.mode=='mlm':
            return self.forward_mlm(inputs)
        elif self.mode=='ep':
            return self.forward_ep(inputs)
        elif self.mode=='na':
            return self.forward_na(inputs)
        elif self.mode=='gen':
            return self.forward_gen(inputs)
        else:
            raise Exception('Unknown value for mode.')

    def _decode_batch(self, target_ids, target_lang, encoder_output, source_mask):
        """Run the decoder over a whole batch of target ids with a causal mask.

        Returns the decoder states permuted back to batch-first (b, L2, d).
        """
        tgt_len = target_ids.shape[1]
        # 0 on and below the diagonal, -1e4 above it (additive attention mask). L2,L2
        causal_mask = -1e4 *(1-self.bias[:tgt_len,:tgt_len])
        # L2, b, d
        tgt_embeddings = self.encoder.embeddings(target_ids).permute([1,0,2]).contiguous() \
                                            + self.lang_embeddings(target_lang)[None,:,:]
        # Source positions whose source_mask is 0 are excluded from cross-attention.
        out = self.decoder(tgt_embeddings,encoder_output,tgt_mask=causal_mask,
                           memory_key_padding_mask=(1-source_mask).bool())
        return out.permute([1,0,2]).contiguous()

    def _link_probabilities(self, out):
        """Pairwise link probabilities from already-sliced decoder states.

        `out` is (b, L-1, sd_dim); the result is (b, L-1, L-1) after a sigmoid.
        The asymmetric variant adds a learned bilinear form and two bias
        projections; the symmetric variant is a plain dot product.
        """
        if self.sd_asym:
            scores = torch.bmm(out, self.sd_weight(out).permute(0,2,1)) # b, L-1, L-1
            return torch.sigmoid( scores + self.sd_b1(out) + self.sd_b2(out).permute(0,2,1) )
        return torch.sigmoid(torch.bmm(out, out.permute(0,2,1))) # b, L-1, L-1

    def forward_translation_structdec(self, inputs):
        """Translation with structure decoding.

        `inputs` is the 9-tuple (source_ids, source_mask, position_idx,
        attn_mask, target_ids, target_mask, source_lang, target_lang, dfg_links).

        * With `target_ids` given: returns
          (lm_loss, lm_loss * n_active, n_active, sd_loss, stats) where `stats`
          holds the 'tp', 'tn', 'pos' and 'total' counts of the link head.
        * With `target_ids` None: runs beam search per example and returns
          (preds, pred_attn); `dfg_links` is not read in this branch.
        """
        source_ids,source_mask,position_idx,attn_mask,target_ids,target_mask,source_lang,target_lang, dfg_links = inputs
        #embedding
        nodes_mask=position_idx.eq(0)
        token_mask=position_idx.ge(2) # CLS to SEP
        inputs_embeddings=self.encoder.embeddings.word_embeddings(source_ids) + self.lang_embeddings(source_lang)[:,None,:] # b, L1, d
        # Each node position gets the average embedding of the tokens it may attend to.
        nodes_to_token_mask=nodes_mask[:,:,None]&token_mask[:,None,:]&attn_mask
        nodes_to_token_mask=nodes_to_token_mask/(nodes_to_token_mask.sum(-1)+1e-10)[:,:,None]
        avg_embeddings=torch.einsum("abc,acd->abd",nodes_to_token_mask,inputs_embeddings)
        inputs_embeddings=inputs_embeddings*(~nodes_mask)[:,:,None]+avg_embeddings*nodes_mask[:,:,None]

        outputs = self.encoder(inputs_embeds=inputs_embeddings,attention_mask=attn_mask,position_ids=position_idx)
        encoder_output = outputs[0].permute([1,0,2]).contiguous() # L1,b,d

        if target_ids is not None:
            out = self._decode_batch(target_ids, target_lang, encoder_output, source_mask) # b, L2, d

            hidden_states = torch.tanh(self.dense(out))
            lm_logits = self.lm_head(hidden_states)
            # Shift so that tokens < n predict n
            active_loss = target_mask[..., 1:].ne(0).view(-1) == 1
            shift_logits = lm_logits[..., :-1, :].contiguous()
            shift_labels = target_ids[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = nn.CrossEntropyLoss(ignore_index=-1)
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1))[active_loss],
                                shift_labels.view(-1)[active_loss])

            # Structure head: first sd_dim features of every state except the last.
            out = out[:,:-1,:self.sd_dim]
            dfg_links = dfg_links[:,1:,1:].long() # b, L-1, L-1
            pred_attn = self._link_probabilities(out)
            tp = ((pred_attn>=0.5)*(dfg_links==1)).int().sum()
            tn = ((pred_attn<0.5)*(dfg_links==0)).int().sum()

            # Two-class log-probabilities per cell: column 0 = no link, column 1 = link.
            pred_attn = pred_attn.reshape(-1, 1) #bLL
            pred_attn = torch.log(torch.hstack((1-pred_attn+self.eps, pred_attn+self.eps))) # n, 2

            total = (dfg_links!=-1).long().sum()
            num_pos = (dfg_links==1).long().sum()
            # Class weights are inverse class frequencies; cells labelled -1 are ignored.
            loss_fct = nn.NLLLoss(ignore_index=-1, weight=torch.stack((total/(total-num_pos),total/num_pos)))
            sd_loss = loss_fct(pred_attn, dfg_links.view(-1))

            outputs = loss,loss*active_loss.sum(),active_loss.sum(),sd_loss, {'tp':tp,'tn':tn,'pos':num_pos, 'total':total}
            return outputs

        else:
            #Predict
            preds=[]
            for i in range(source_ids.shape[0]):
                context=encoder_output[:,i:i+1] # L,1,d
                context_mask=source_mask[i:i+1,:] # 1,L
                beam = Beam(self.beam_size,self.sos_id,self.eos_id)
                input_ids=beam.getCurrentState() # size,1    input_ids[0]=sos
                # Pad token built from the beam's own tensor so dtype and device
                # always match the hypothesis tokens (no hard-coded CUDA tensor).
                zero=input_ids.new_zeros(1)
                context=context.repeat(1, self.beam_size,1) # L, size, d
                context_mask=context_mask.repeat(self.beam_size,1) # size, L
                for _ in range(self.max_length):
                    if beam.done():
                        break
                    attn_mask=-1e4 *(1-self.bias[:input_ids.shape[1],:input_ids.shape[1]])
                    tgt_embeddings = self.encoder.embeddings(input_ids).permute([1,0,2]).contiguous() \
                                                        + self.lang_embeddings(target_lang[i])[None,None,:]
                    out = self.decoder(tgt_embeddings,context,tgt_mask=attn_mask,memory_key_padding_mask=(1-context_mask).bool())
                    out = torch.tanh(self.dense(out))
                    hidden_states=out.permute([1,0,2]).contiguous()[:,-1,:]
                    out = self.lsm(self.lm_head(hidden_states)).data
                    beam.advance(out)
                    # Reorder the kept prefixes to follow their surviving beams, then append the new tokens.
                    input_ids.data.copy_(input_ids.data.index_select(0, beam.getCurrentOrigin()))
                    input_ids=torch.cat((input_ids,beam.getCurrentState()),-1)
                hyp= beam.getHyp(beam.getFinal())
                pred=beam.buildTargetTokens(hyp)[:self.beam_size]
                # Right-pad every hypothesis with zeros up to max_length.
                pred=[torch.cat([x.view(-1) for x in p]+[zero]*(self.max_length-len(p))).view(1,-1) for p in pred]
                preds.append(torch.cat(pred,0).unsqueeze(0))

            preds=torch.cat(preds,0)          # b, beam, L

            # Re-decode the top hypothesis (shifted right behind a leading 0) to score its links.
            # NOTE(review): the leading id is a literal 0 rather than self.sos_id - confirm they coincide.
            target_ids = torch.cat((0*preds[:,0,:1], preds[:,0,:-1]), dim=-1) # b,L
            out = self._decode_batch(target_ids, target_lang, encoder_output, source_mask) # b, L2, d
            # dfg_links is deliberately not touched here: it is not needed for prediction, and
            # slicing it failed when the caller supplied a scalar placeholder instead of a matrix.
            pred_attn = self._link_probabilities(out[:,:-1,:self.sd_dim])

            return preds, pred_attn

        

        

class Beam(object):
    """Beam-search bookkeeping for a single source sequence.

    Keeps the running score of every beam, the token chosen at every step
    (`nextYs`) and the back-pointer to the beam it extended (`prevKs`).

    Parameters:

    * `size`- beam width.
    * `sos`- id written into slot 0 of the first step.
    * `eos`- id that terminates a hypothesis.
    * `device`- where the internal tensors live. Defaults to CUDA when it is
      available (the previous, hard-wired behaviour) and to CPU otherwise.
    """
    def __init__(self, size,sos,eos,device=None):
        self.size = size
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = torch.device(device)
        # Retained only so existing code that reads `beam.tt` keeps working;
        # tensors are now created with explicit dtype/device instead.
        self.tt = torch.cuda
        # The score for each translation on the beam.
        self.scores = torch.zeros(size, dtype=torch.float, device=self.device)
        # The backpointers at each time-step.
        self.prevKs = []
        # The outputs at each time-step.
        self.nextYs = [torch.zeros(size, dtype=torch.long, device=self.device)]
        self.nextYs[0][0] = sos
        # Has EOS topped the beam yet.
        self._eos = eos
        self.eosTop = False
        # Time and k pair for finished.
        self.finished = []

    def getCurrentState(self):
        "Get the outputs for the current timestep as a (size, 1) column."
        batch = self.nextYs[-1].view(-1, 1)
        return batch

    def getCurrentOrigin(self):
        "Get the backpointers for the current timestep."
        return self.prevKs[-1]

    def advance(self, wordLk):
        """
        Given the log-probabilities over words for every beam, extend the
        beams by one step and record scores, tokens and back-pointers.

        Parameters:

        * `wordLk`- log-probs of advancing from the last step (K x words)

        Nothing is returned; use `done()` to test for completion.
        """
        numWords = wordLk.size(1)

        # Sum the previous scores.
        if len(self.prevKs) > 0:
            beamLk = wordLk + self.scores.unsqueeze(1).expand_as(wordLk)

            # Don't let EOS have children.
            beamLk[self.nextYs[-1] == self._eos] = -1e20
        else:
            # First step: every beam holds the same prefix, so only row 0 is used.
            beamLk = wordLk[0]
        flatBeamLk = beamLk.view(-1)
        bestScores, bestScoresId = flatBeamLk.topk(self.size, 0, True, True)

        self.scores = bestScores

        # bestScoresId is flattened beam x word array, so calculate which
        # word and beam each score came from
        prevK = bestScoresId // numWords
        self.prevKs.append(prevK)
        self.nextYs.append((bestScoresId - prevK * numWords))

        # Record every beam that just produced EOS as a finished hypothesis.
        for i in range(self.nextYs[-1].size(0)):
            if self.nextYs[-1][i] == self._eos:
                s = self.scores[i]
                self.finished.append((s, len(self.nextYs) - 1, i))

        # End condition is when top-of-beam is EOS and no global score.
        if self.nextYs[-1][0] == self._eos:
            self.eosTop = True

    def done(self):
        "True once EOS tops the beam and at least `size` hypotheses finished."
        return self.eosTop and len(self.finished) >=self.size

    def getFinal(self):
        """Return up to `size` (score, timestep, beam) triples, best first.

        If nothing finished, the current top beam is used; if fewer than
        `size` finished, the best unfinished beams fill the remainder.
        """
        if len(self.finished) == 0:
            self.finished.append((self.scores[0], len(self.nextYs) - 1, 0))
        self.finished.sort(key=lambda a: -a[0])
        if len(self.finished) != self.size:
            unfinished=[]
            for i in range(self.nextYs[-1].size(0)):
                if self.nextYs[-1][i] != self._eos:
                    s = self.scores[i]
                    unfinished.append((s, len(self.nextYs) - 1, i))
            unfinished.sort(key=lambda a: -a[0])
            self.finished+=unfinished[:self.size-len(self.finished)]
        return self.finished[:self.size]

    def getHyp(self, beam_res):
        """
        Walk back through the back-pointers to construct the full hypothesis
        (a list of token tensors) for every entry of `beam_res`.
        """
        hyps=[]
        for _,timestep, k in beam_res:
            hyp = []
            for j in range(len(self.prevKs[:timestep]) - 1, -1, -1):
                hyp.append(self.nextYs[j+1][k])
                k = self.prevKs[j][k]
            hyps.append(hyp[::-1])
        return hyps

    def buildTargetTokens(self, preds):
        "Cut every hypothesis at its first EOS (the EOS itself is dropped)."
        sentence=[]
        for pred in preds:
            tokens = []
            for tok in pred:
                if tok==self._eos:
                    break
                tokens.append(tok)
            sentence.append(tokens)
        return sentence
        



        
        
        

        
        
        
        


In [40]:
from __future__ import absolute_import
import os
import sys
import pickle
import torch
import json
import random
import logging
import argparse
import numpy as np
from io import open
from itertools import cycle
import torch.nn as nn
from tqdm import tqdm, trange
from torch.utils.data import DataLoader, Dataset, SequentialSampler, RandomSampler,TensorDataset
from torch.utils.data.distributed import DistributedSampler
from transformers import (WEIGHTS_NAME, AdamW, get_linear_schedule_with_warmup,
                          RobertaConfig, RobertaModel, RobertaTokenizer)
MODEL_CLASSES = {'roberta': (RobertaConfig, RobertaModel, RobertaTokenizer)}

from parser import DFG_python,DFG_java,DFG_ruby,DFG_go,DFG_php,DFG_javascript,DFG_csharp
from parser import (remove_comments_and_docstrings,
                   tree_to_token_index,
                   index_to_code_token,
                   tree_to_variable_index)
from tree_sitter import Language, Parser
# Data-flow extraction function for every supported tree-sitter language.
dfg_function=dict(
    python=DFG_python,
    java=DFG_java,
    ruby=DFG_ruby,
    go=DFG_go,
    php=DFG_php,
    javascript=DFG_javascript,
    c_sharp=DFG_csharp,
)

# Fixed language order; a language's position is its embedding index.
langs = ['python', 'java', 'ruby', 'go', 'php', 'javascript', 'c_sharp']
lang_to_ind = {name:position for position,name in enumerate(langs)}

# Load one tree-sitter parser per language and pair it with its DFG function:
# parsers[lang] == [Parser, dfg_function[lang]]
parsers={}
for lang in dfg_function:
    LANGUAGE = Language('parser/my-languages.so', lang)
    parser = Parser()
    parser.set_language(LANGUAGE)
    parser = [parser,dfg_function[lang]]
    parsers[lang]= parser
    
#remove comments, tokenize code and extract dataflow     
def extract_dataflow(code, parser,lang):
    """Tokenize `code` and extract its data-flow graph (best effort).

    Parameters:

    * `code`- source text.
    * `parser`- pair `[tree_sitter_parser, dfg_function]` as stored in `parsers`.
    * `lang`- language name; passed to comment removal, and 'php' code is
      wrapped in `<?php ... ?>` before parsing.

    Returns `(code_tokens, dfg)`. Errors are swallowed on purpose: if comment
    removal fails the raw code is parsed, and if parsing or DFG extraction
    fails the DFG is empty. `code_tokens` is `[]` when the failure happened
    before tokenization (previously this case raised UnboundLocalError at the
    `return`).
    """
    #remove comments
    try:
        code=remove_comments_and_docstrings(code,lang)
    except Exception:
        # Fall back to the unmodified source.
        pass
    #obtain dataflow
    if lang=="php":
        code="<?php"+code+"?>"
    code_tokens=[]
    dfg=[]
    try:
        tree = parser[0].parse(bytes(code,'utf8'))
        root_node = tree.root_node
        tokens_index=tree_to_token_index(root_node)
        code_lines=code.split('\n')
        code_tokens=[index_to_code_token(x,code_lines) for x in tokens_index]
        # position in the source -> (token index, token text)
        index_to_code={}
        for idx,(index,token) in enumerate(zip(tokens_index,code_tokens)):
            index_to_code[index]=(idx,token)
        try:
            DFG,_=parser[1](root_node,index_to_code,{})
        except Exception:
            DFG=[]
        DFG=sorted(DFG,key=lambda x:x[1])
        # Keep only nodes that take part in at least one edge, either as the
        # node owning a non-empty source list or as one of those sources.
        indexs=set()
        for d in DFG:
            if len(d[-1])!=0:
                indexs.add(d[1])
            indexs.update(d[-1])
        dfg=[d for d in DFG if d[1] in indexs]
    except Exception:
        dfg=[]
    return code_tokens,dfg


class Example(object):
    """One parallel example: source text, target text and a language tag."""
    def __init__(self, source, target, lang):
        self.source, self.target, self.lang = source, target, lang

def read_examples(filename):
    """Read line-aligned examples from a '<source_path>,<target_path>' pair.

    Both files are read in lockstep (stopping at the shorter one) and every
    line is stripped. Each example is tagged 'c_sharp' when the source path
    ends with the letter 's', otherwise 'java'.
    """
    source,target=filename.split(',')
    lang='c_sharp' if source[-1]=='s' else 'java'

    with open(source,encoding="utf-8") as f1,open(target,encoding="utf-8") as f2:
        examples=[
            Example(source=src_line.strip(), target=tgt_line.strip(), lang=lang)
            for src_line,tgt_line in zip(f1,f2)
        ]
    print (len(examples), '='*100, source, target)
    return examples


class InputFeatures(object):
    """Model-ready features for one example (plain attribute container)."""
    def __init__(self, example_id, source_ids, position_idx, dfg_to_code, dfg_to_dfg,
                 target_ids, source_mask, target_mask, dfg_links):
        # Identity and encoder-side inputs.
        self.example_id, self.source_ids, self.position_idx = example_id, source_ids, position_idx
        # Graph structure: node -> code-token span, node -> source nodes.
        self.dfg_to_code, self.dfg_to_dfg = dfg_to_code, dfg_to_dfg
        # Decoder-side ids, padding masks and link labels.
        self.target_ids, self.source_mask, self.target_mask = target_ids, source_mask, target_mask
        self.dfg_links = dfg_links
        

# The parser table is already built by an identical loop earlier in this cell.
# Re-loading every grammar from the shared library a second time is redundant,
# so it is rebuilt here only if the table is missing or empty.
if not globals().get('parsers'):
    parsers={}
    for lang in dfg_function:
        LANGUAGE = Language('parser/my-languages.so', lang)
        parser = Parser()
        parser.set_language(LANGUAGE)
        parser = [parser,dfg_function[lang]]
        parsers[lang]= parser
    
def convert_examples_to_features(examples, tokenizer, args,stage=None,sd_asym=False):
    """Turn `Example` objects into padded `InputFeatures`.

    For every example this builds the encoder input (code sub-tokens followed
    by data-flow nodes, padded to `args.max_source_length`), the decoder
    input (padded to `args.max_target_length`) and, unless `stage == "test"`,
    a `max_target_length x max_target_length` matrix of data-flow links
    between target sub-tokens.

    Parameters:

    * `examples`- iterable of objects with `.source` and `.target` strings.
    * `tokenizer`- sub-word tokenizer providing `tokenize`,
      `convert_tokens_to_ids` and the cls/sep/pad/unk special tokens.
    * `args`- needs `source_lang`, `target_lang`, `max_source_length`,
      `max_target_length`.
    * `stage`- `"test"` uses the placeholder target "None" and no link matrix.
    * `sd_asym`- when False every link is also written in the mirrored cell.

    Returns the list of `InputFeatures`. The progress bar description shows
    the running fraction of target characters that could not be aligned.
    """
    features = []
    # Running counts of aligned / unaligned target characters (all examples).
    match, nomatch = 0, 0
    pbar = tqdm(examples,total=len(examples))
    for example_index, example in enumerate(pbar):
        ##extract data flow
        # NOTE(review): the source parser is c_sharp only when args.source_lang == "cs" and
        # java for every other value (including e.g. 'python') - confirm this is intended.
        code_tokens,dfg=extract_dataflow(example.source,
                                         parsers["c_sharp" if args.source_lang == "cs" else "java"],
                                         "c_sharp" if args.source_lang == "cs" else "java")
        # Sub-tokenize each code token; non-initial tokens are tokenized behind a dummy '@ '
        # prefix whose own piece is then dropped.
        code_tokens=[tokenizer.tokenize('@ '+x)[1:] if idx!=0 else tokenizer.tokenize(x) for idx,x in enumerate(code_tokens)]
        # ori2cur_pos[i] = (start, end) span of code token i in the flattened sub-token list.
        ori2cur_pos={}
        ori2cur_pos[-1]=(0,0)
        for i in range(len(code_tokens)):
            ori2cur_pos[i]=(ori2cur_pos[i-1][1],ori2cur_pos[i-1][1]+len(code_tokens[i]))
        code_tokens=[y for x in code_tokens for y in x]

        #truncating
        # Leaves room for special tokens and never exceeds 509 sub-tokens.
        code_tokens=code_tokens[:args.max_source_length-3][:512-3]
        source_tokens =[tokenizer.cls_token]+code_tokens+[tokenizer.sep_token]
        source_ids =  tokenizer.convert_tokens_to_ids(source_tokens)
        # Code positions start at pad_token_id+1; DFG nodes get position 0 and padding gets
        # pad_token_id (see below).
        position_idx = [i+tokenizer.pad_token_id + 1 for i in range(len(source_tokens))]
        # Fill the remaining source budget with DFG nodes (embedded as unk tokens).
        dfg=dfg[:args.max_source_length-len(source_tokens)]
        source_tokens+=[x[0] for x in dfg]
        position_idx+=[0 for x in dfg]
        source_ids+=[tokenizer.unk_token_id for x in dfg]
        padding_length=args.max_source_length-len(source_ids)
        position_idx+=[tokenizer.pad_token_id]*padding_length
        source_ids+=[tokenizer.pad_token_id]*padding_length
        source_mask = [1] * (len(source_tokens))
        source_mask+=[0]*padding_length

        #reindex
        # Rewrite each node's source list from original token indices to positions inside
        # the (truncated) dfg list; sources that were truncated away are dropped.
        reverse_index={}
        for idx,x in enumerate(dfg):
            reverse_index[x[1]]=idx
        for idx,x in enumerate(dfg):
            dfg[idx]=x[:-1]+([reverse_index[i] for i in x[-1] if i in reverse_index],)
        dfg_to_dfg=[x[-1] for x in dfg]
        dfg_to_code=[ori2cur_pos[x[1]] for x in dfg]
        # Shift spans by one to account for the leading cls token.
        length=len([tokenizer.cls_token])
        dfg_to_code=[(x[0]+length,x[1]+length) for x in dfg_to_code]

        #target
        if stage=="test":
            target_tokens = tokenizer.tokenize("None")
            dfg_links = None
        else:
            target_tokens = tokenizer.tokenize(example.target)[:args.max_target_length-2]

            # Parse the target as well; its DFG is defined over parser tokens, which are
            # aligned to the tokenizer's target_tokens character by character below.
            code_tokens_, target_dfg = extract_dataflow(example.target, parsers[args.target_lang], args.target_lang)
            code_tokens_=[tokenizer.tokenize('@ '+x)[1:] if idx!=0 else tokenizer.tokenize(x) for \
                             idx,x in enumerate(code_tokens_)]
            target_ori2cur_pos={}
            target_ori2cur_pos[-1]=(0,0)
            for i in range(len(code_tokens_)):
                target_ori2cur_pos[i]=(target_ori2cur_pos[i-1][1],target_ori2cur_pos[i-1][1]+len(code_tokens_[i]))
            code_tokens=[y for x in code_tokens_ for y in x]
            # Drop the leading 'Ġ' marker so only the characters themselves are compared.
            code_tokens = [x[1:] if (x[0]=='Ġ') else x for x in code_tokens]

            # scode: characters of the parser-derived sub-tokens, in order.
            scode = []
            for c in code_tokens:
                scode += list(c)

            # tcode: characters of the tokenizer's target tokens, in order.
            tcode = []
            for c in target_tokens:
                if c.startswith('Ġ'):
                    c=c[1:]
                tcode += list(c)

            # Greedy in-order alignment: tcode_to_scode[i] is the matching scode index, or -1.
            # The scode cursor j only advances on a match.
            tcode_to_scode = []
            j = 0
            for i in range(len(tcode)):
                if j<len(scode):
                    if tcode[i]==scode[j]:
                        tcode_to_scode.append(j)
                        j += 1
                        match += 1
                    else:
                        tcode_to_scode.append(-1)
                        nomatch += 1
                else:
                    tcode_to_scode.append(-1)
                    nomatch += 1

            # Character index -> index of the target token / code sub-token that owns it.
            tcode_to_target = []
            for i in range(len(target_tokens)):
                if target_tokens[i].startswith('Ġ'):
                    tcode_to_target += [i]*(len(target_tokens[i])-1)
                else:
                    tcode_to_target += [i]*len(target_tokens[i])
            scode_to_code = []
            for i in range(len(code_tokens)):
                scode_to_code += [i]*len(code_tokens[i])

            # target token -> code sub-tokens sharing a character with it, then the inverse map.
            target_to_code = [[] for i in range(len(target_tokens))]
            for i in range(len(tcode)):
                if tcode_to_scode[i]>=0:
                    target_to_code[tcode_to_target[i]].append( scode_to_code[tcode_to_scode[i]] )
            code_to_target = [[] for i in range(len(code_tokens))]
            for i in range(len(target_to_code)):
                for c in set(target_to_code[i]):
                    code_to_target[c].append(i)

            # DFG edges over parser tokens, expanded to every pair of their sub-tokens.
            ct_to_ct = []
            for x in target_dfg:
                for i in x[4]:
                    ct_to_ct.append([x[1], i]) # left ct_ comes from right ct_
            code_to_code = []
            for left, right in ct_to_ct:
                for x1 in range(target_ori2cur_pos[left][0], target_ori2cur_pos[left][1]):
                    for x2 in range(target_ori2cur_pos[right][0], target_ori2cur_pos[right][1]):
                        code_to_code.append([x1, x2])

            # Link labels: -1 = ignore (row/column 0, everything past the real tokens, and the
            # diagonal), 1 = link, 0 = no link. The +1 offsets skip the cls token that is
            # prepended to target_tokens below.
            dfg_links = np.zeros((args.max_target_length, args.max_target_length))
            dfg_links[0,:] = -1
            dfg_links[:,0] = -1
            dfg_links[len(target_tokens)+1:, :] = -1
            dfg_links[:, len(target_tokens)+1:] = -1
            np.fill_diagonal(dfg_links, -1)
            for left, right in  code_to_code:
                for x1 in code_to_target[left]:
                    for x2 in code_to_target[right]:
                        dfg_links[x1+1,x2+1]=1
                        if sd_asym==False:
                            dfg_links[x2+1,x1+1]=1


        target_tokens = [tokenizer.cls_token]+target_tokens+[tokenizer.sep_token]
        target_ids = tokenizer.convert_tokens_to_ids(target_tokens)
        target_mask = [1] *len(target_ids)
        padding_length = args.max_target_length - len(target_ids)
        target_ids+=[tokenizer.pad_token_id]*padding_length
        target_mask+=[0]*padding_length

        features.append(
            InputFeatures(
                 example_index,
                 source_ids,
                 position_idx,
                 dfg_to_code,
                 dfg_to_dfg,
                 target_ids,
                 source_mask,
                 target_mask,
                 dfg_links
            )
        )
        # Show the running unaligned-character ratio (epsilon avoids division by zero).
        pbar.set_description(str(nomatch/(match+nomatch+0.0000001)))

    return features

class TextDataset(Dataset):
    """Dataset over `InputFeatures` that builds the graph-guided attention mask per item.

    `args` must provide `source_lang`, `target_lang` and `max_source_length`.
    Language names are mapped to indices through `lang_to_ind`.
    """
    def __init__(self, examples, args):
        self.examples = examples
        self.args=args
        # NOTE(review): any language name starting with 'c' (e.g. 'cs', 'c_sharp') is mapped
        # to the index of 'java' - confirm this sharing is intended.
        if self.args.source_lang.startswith('c'):
            self.source_lang = lang_to_ind['java']
        else:
            self.source_lang = lang_to_ind[self.args.source_lang]
        if self.args.target_lang.startswith('c'):
            self.target_lang = lang_to_ind['java']
        else:
            self.target_lang = lang_to_ind[self.args.target_lang]

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, item):
        """Return the 9-tuple of tensors for one example:

        (source_ids, source_mask, position_idx, attn_mask, target_ids,
        target_mask, source_lang, target_lang, dfg_links). When the example
        has no link matrix, `dfg_links` is the scalar placeholder tensor(1).
        """
        example = self.examples[item]
        #calculate graph-guided masked function
        # Builtin `bool` is used as the dtype: the `np.bool` alias was removed from NumPy.
        attn_mask=np.zeros((self.args.max_source_length,self.args.max_source_length),dtype=bool)
        #calculate begin index of node and max length of input
        node_index=sum([i>1 for i in example.position_idx])
        max_length=sum([i!=1 for i in example.position_idx])
        #sequence can attend to sequence
        attn_mask[:node_index,:node_index]=True
        #special tokens attend to all tokens
        # NOTE(review): ids 0 and 2 are hard-coded here; presumably the tokenizer's cls/sep ids.
        for idx,i in enumerate(example.source_ids):
            if i in [0,2]:
                attn_mask[idx,:max_length]=True
        #nodes attend to code tokens that are identified from
        for idx,(a,b) in enumerate(example.dfg_to_code):
            if a<node_index and b<node_index:
                attn_mask[idx+node_index,a:b]=True
                attn_mask[a:b,idx+node_index]=True
        #nodes attend to adjacent nodes
        for idx,nodes in enumerate(example.dfg_to_dfg):
            attn_mask[idx+node_index,idx+node_index]=True
            for a in nodes:
                if a+node_index<len(example.position_idx):
                    attn_mask[idx+node_index,a+node_index]=True

        if example.dfg_links is not None:
            dfg_links = torch.tensor(example.dfg_links)
        else:
            dfg_links = torch.tensor(1)

        return (torch.tensor(example.source_ids),
                    torch.tensor(example.source_mask),
                    torch.tensor(example.position_idx),
                    torch.tensor(attn_mask),
                    torch.tensor(example.target_ids),
                    torch.tensor(example.target_mask),
                    torch.tensor(self.source_lang),
                    torch.tensor(self.target_lang),
                    dfg_links
                   )
    
def set_seed(seed=42):
    """Seed Python, NumPy and PyTorch (CPU and all CUDA devices).

    Also exports PYTHONHASHSEED and switches cuDNN to deterministic mode.
    """
    random.seed(seed)
    # The variable name was previously misspelled ('PYHTONHASHSEED'), so it was never picked up.
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Covers every visible GPU; silently ignored when CUDA is unavailable.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    
class Args():
    """Hard-coded run configuration for the python -> java structure-decoding model.

    Parameters:

    * `data_dir`- directory (with trailing separator) holding the
      `<split>-Java-Python-tok.{py,java}` files. Defaults to the path
      originally hard-coded in the notebook; pass another directory to run
      elsewhere.
    """
    # Original location of the tokenized pair data on the author's machine.
    DEFAULT_DATA_DIR = '/home/mingzhu/CodeModel/g4g/old_data/pair_data_tok_full/Java-Python/'

    def __init__(self, data_dir=None):
        if data_dir is None:
            data_dir = self.DEFAULT_DATA_DIR

        def pair_files(split):
            # '<source file>,<target file>' - the format read_examples() splits on ','.
            return data_dir+split+'-Java-Python-tok.py,'+data_dir+split+'-Java-Python-tok.java'

        self.model_type = 'roberta'
        self.model_name_or_path='roberta-base'
        self.sd_config='24absasym'
        self.output_dir="../saved_models/translation/python-java_60000/"
        self.source_lang = 'python'
        self.target_lang="java"

        # Built per split instead of str.replace('train', ...) on the full path, so a
        # directory name containing 'train' is never rewritten.
        self.train_filename = pair_files('train')
        self.dev_filename = pair_files('val')
        self.test_filename = pair_files('test')

        self.config_name="microsoft/graphcodebert-base"
        self.tokenizer_name="microsoft/graphcodebert-base"
        self.max_source_length=400
        self.max_target_length=400
        self.beam_size=10
        
# pandas is needed for the summary frame at the end of this cell; imported here so the
# cell does not depend on a later cell having been executed first.
import pandas as pd

args = Args()
sd_asym = args.sd_config.endswith('asym')

# Setup CUDA, GPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
args.n_gpu = torch.cuda.device_count()
args.device = device

# make dir if output_dir not exist
os.makedirs(args.output_dir, exist_ok=True)

config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
config = config_class.from_pretrained(args.config_name)
tokenizer = tokenizer_class.from_pretrained(args.tokenizer_name )

# build model
encoder = model_class.from_pretrained(args.model_name_or_path,config=config)
decoder_layer = nn.TransformerDecoderLayer(d_model=config.hidden_size, nhead=config.num_attention_heads)
decoder = nn.TransformerDecoder(decoder_layer, num_layers=6)
model=Seq2Seq(encoder=encoder,decoder=decoder,config=config,
              beam_size=args.beam_size,max_length=args.max_target_length,
              sos_id=tokenizer.cls_token_id,eos_id=tokenizer.sep_token_id,
              mask_id=tokenizer.mask_token_id, n_langs=len(lang_to_ind), sd_config=args.sd_config)
model.set_mode('translation_structdec')
model.to(device)

p = []      # decoded top-beam prediction per example
dfg = []    # predicted link-probability matrix per example
load_model_path = args.output_dir+"checkpoint-best-bleu/pytorch_model.bin"
# map_location lets a checkpoint saved on GPU be loaded on a CPU-only machine.
state_dict = torch.load(load_model_path, map_location=device)
if hasattr(model, 'module'):
    model.module.load_state_dict(state_dict)
else:
    model.load_state_dict(state_dict)

files=[]
if args.test_filename is not None:
    files.append(args.test_filename)
for idx,file in enumerate(files):
    eval_examples = read_examples(file)[:5]
    print (len(eval_examples))
    # stage='train' (not 'test') so that the real targets are tokenized and the gold
    # link matrix is built for every example.
    eval_features = convert_examples_to_features(eval_examples, tokenizer, args,stage='train', sd_asym=sd_asym)
    eval_data = TextDataset(eval_features,args)

    eval_sampler = SequentialSampler(eval_data)
    eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=32,num_workers=4)

    model.eval()
    for batch in tqdm(eval_dataloader,total=len(eval_dataloader)):
        batch = tuple(t.to(device) for t in batch)
        source_ids,source_mask,position_idx,att_mask,target_ids,target_mask,source_lang,target_lang,dfg_links = batch
        with torch.no_grad():
            # Targets are passed as None to take the beam-search branch of the model.
            preds, pred_dfg = model((source_ids,source_mask,position_idx,att_mask,None,None,source_lang,target_lang,dfg_links))
        # The original appended an undefined name (`pred_links`) here.
        # NOTE(review): storing each example's own slice of pred_dfg is the assumed intent - confirm.
        for pred, links in zip(preds, pred_dfg):
            t=list(pred[0].cpu().numpy())
            # Predictions are zero-padded; cut at the first 0.
            if 0 in t:
                t=t[:t.index(0)]
            text = tokenizer.decode(t,clean_up_tokenization_spaces=False)
            p.append(text)
            dfg.append(links.cpu().numpy())

df = pd.DataFrame()
df['dp'] = p
df['ref'] = [gold.target for gold in eval_examples]


0.0011607661055623468: 100%|██████████| 5/5 [00:00<00:00, 61.83it/s]
  0%|          | 0/1 [00:00<?, ?it/s]

5


100%|██████████| 1/1 [01:02<00:00, 62.23s/it]


In [42]:
# Shape of the link-probability tensor returned by the model for the last evaluated batch.
pred_dfg.shape

torch.Size([5, 399, 399])

In [22]:
# NOTE(review): the recorded output below (69) comes from an earlier execution (In [22]);
# the inference cell above now keeps only the first 5 examples, so re-running prints 5.
len(eval_examples)

69

In [18]:
import bleu
import weighted_ngram_match
import syntax_match
import dataflow_match
import pandas as pd
pd.set_option('display.max_rows', 100)

# Gold references and the two hypothesis files that are compared line by line.
# NOTE(review): "nodp" / "dp" presumably name the two model variants (the "nodp"
# run directory ends in `alpha0`) -- confirm against the training configuration.
ref_path = '../saved_models/translation/python-java_60000/test_0.gold'
hyp_nodp_path = '../saved_models/translation/python-java_60000_bs32_ep50_alpha0/test_0.output'
hyp_dp_path = '../saved_models/translation/python-java_60000/test_0.output'

# Read every file inside a context manager so its handle is closed immediately
# instead of being left open until garbage collection.
with open(ref_path, 'r', encoding='utf-8') as ref_file:
    ref_lines = ref_file.readlines()
with open(hyp_nodp_path, 'r', encoding='utf-8') as hyp_nodp_file:
    hyp_nodp_lines = hyp_nodp_file.readlines()
with open(hyp_dp_path, 'r', encoding='utf-8') as hyp_dp_file:
    hyp_dp_lines = hyp_dp_file.readlines()

lang = 'java'

# Stripped copies are what the metrics consume; the raw lines (trailing newline
# included) are kept for the text columns of the result table.
ref = [x.strip() for x in ref_lines]
hyp_nodp = [x.strip() for x in hyp_nodp_lines]
hyp_dp = [x.strip() for x in hyp_dp_lines]

# Whitespace tokenisation. Each reference is wrapped in an extra list because
# it is passed to `bleu.corpus_bleu` as a list of references for one hypothesis.
tokenized_refs = [[x.split()] for x in ref]
tokenized_hyps_nodp = [x.split() for x in hyp_nodp]
tokenized_hyps_dp = [x.split() for x in hyp_dp]

# One row per reference line: each metric is computed on a single-example
# "corpus". The loop is driven by the reference file, so both hypothesis files
# must contain at least as many lines as the reference file.
score_rows = []
for i in range(len(tokenized_refs)):
    score_rows.append([bleu.corpus_bleu([tokenized_refs[i]], [tokenized_hyps_nodp[i]]),
                       bleu.corpus_bleu([tokenized_refs[i]], [tokenized_hyps_dp[i]]),
                       dataflow_match.corpus_dataflow_match([[ref[i]]], [hyp_nodp[i]], lang),
                       dataflow_match.corpus_dataflow_match([[ref[i]]], [hyp_dp[i]], lang),
                       ref_lines[i], hyp_nodp_lines[i], hyp_dp_lines[i]])

scores = pd.DataFrame(score_rows, columns=['nodp_bleu', 'dp_bleu', 'nodp_dfgm', 'dp_dfgm', 'ref', 'nodp', 'dp'])
scores

Unnamed: 0,nodp_bleu,dp_bleu,nodp_dfgm,dp_dfgm,ref,nodp,dp
0,0.010587,0.007831,0.236842,0.157895,import java . util . * ; import java . util . ...,public class GFG { static int getMissingNo ( i...,class GFG { static int getMissingNo ( int A [ ...
1,0.530937,0.530937,0.72973,0.72973,class MaximumDiffrence { int maxDiff ( int arr...,"class GFG { static int maxDiff ( int arr [ ] ,...","class GFG { static int maxDiff ( int arr [ ] ,..."
2,0.510628,0.504457,0.931034,0.931034,class Segregate { void segregate0and1 ( int ar...,class segregation { static void segregate0and1...,class segregation { static void segregate0and1...
3,0.866889,0.879259,0.358974,1.0,import java . io . * ; class NextGreatest { st...,class GFG { static void nextGreatest ( int arr...,class GFG { static void nextGreatest ( int [ ]...
4,0.182071,0.640493,0.3125,0.828125,class Main { static int findMaximum ( int arr ...,"static int findMaximum ( int arr [ ] , int low...",import java . io . * ; class GFG { static int ...
5,0.567744,0.565884,0.149425,0.16092,import java . io . * ; class GFG { static void...,import java . io . * ; class GFG { static void...,import java . io . * ; class GFG { static void...
6,0.810087,0.849924,0.76,0.76,import java . io . * ; class PairDifference { ...,class GFG { static int findPair ( int arr [ ] ...,import java . io . * ; class GFG { static int ...
7,0.777773,0.946491,0.896552,0.965517,class BinarySearch { int binarySearch ( int ar...,import java . util . * ; class GFG { static in...,class BinarySearch { int binarySearch ( int ar...
8,0.522012,0.549785,0.571429,0.714286,class CountingSort { void sort ( char arr [ ] ...,import java . io . * ; class GFG { static void...,import java . io . * ; class GFG { static void...
9,0.862872,0.663147,1.0,1.0,import java . io . * ; import java . util . * ...,import java . io . * ; import java . util . * ...,import java . io . * ; import java . util . * ...


In [2]:
# Print the reference and both hypotheses of one scored example, each followed
# by a blank line.
i = 7
row = scores.iloc[i]
for column in ('ref', 'nodp', 'dp'):
    print(row[column], '\n')

class BinarySearch { int binarySearch ( int arr [ ] , int l , int r , int x ) { if ( r >= l ) { int mid = l + ( r - l ) / 2 ; if ( arr [ mid ] == x ) return mid ; if ( arr [ mid ] > x ) return binarySearch ( arr , l , mid - 1 , x ) ; return binarySearch ( arr , mid + 1 , r , x ) ; } return - 1 ; }
 

import java . util . * ; class GFG { static int binarySearch ( int arr [ ] , int l , int r , int x ) { if ( r >= l ) { int mid = l + ( r - l ) / 2 ; if ( arr [ mid ] == x ) return mid ; if ( arr [ l ] < x ) return binarySearch ( arr , l , mid - 1 , x ) ; elsereturn binarySearch ( arr , l , mid - 1 , x ) ; } return - 1 ; }
 

class BinarySearch { int binarySearch ( int arr [ ] , int l , int r , int x ) { if ( r >= l ) { int mid = l + ( r - l ) / 2 ; if ( arr [ mid ] == x ) return mid ; if ( arr [ mid ] > x ) return binarySearch ( arr , l , mid - 1 , x ) ; elsereturn binarySearch ( arr , mid + 1 , r , x ) ; } elsereturn - 1 ; }
 



In [3]:
# Print the reference and both hypotheses of one scored example, each followed
# by a blank line.
i = 56
row = scores.iloc[i]
for column in ('ref', 'nodp', 'dp'):
    print(row[column], '\n')

class Main { static int ceilSearch ( int arr [ ] , int low , int high , int x ) { int mid ; if ( x <= arr [ low ] ) return low ; if ( x > arr [ high ] ) return - 1 ; mid = ( low + high ) / 2 ; if ( arr [ mid ] == x ) return mid ; else if ( arr [ mid ] < x ) { if ( mid + 1 <= high && x <= arr [ mid + 1 ] ) return mid + 1 ; elsereturn ceilSearch ( arr , mid + 1 , high , x ) ; } else { if ( mid - 1 >= low && x > arr [ mid - 1 ] ) return mid ; elsereturn ceilSearch ( arr , low , mid - 1 , x ) ; } }
 

class Main { int ceilSearch ( int arr [ ] , int low , int high , int x ) { if ( x <= arr [ low ] ) return low ; if ( x > arr [ high ] ) return - 1 ; int mid = ( low + high ) / 2 ; if ( arr [ mid ] == x ) return mid ; else if ( arr [ mid ] < x ) { if ( mid + 1 <= high && x <= arr [ mid + 1 ] ) return mid + 1 ; elsereturn CeilSearch ( arr , mid + 1 , high , x ) ; } else { if ( mid - 1 >= low && x > arr [ mid - 1 ] ) return mid ; elsereturn CeilSearch ( arr , low , mid - 1 , x ) ; } } public sta