In [1]:
cd /home/hd/hd_hd/hd_tf268/code-gen

/pfs/data5/home/hd/hd_hd/hd_tf268/code-gen


In [2]:
import os
import torch
import numpy as np
from bertviz import head_view
from pathlib import Path
from codegen_sources.model.src.model import build_model
from codegen_sources.model.src.utils import AttrDict
from codegen_sources.model.translate import Translator
from codegen_sources.model.src.data.dictionary import (
    Dictionary,
    BOS_WORD,
    EOS_WORD,
    PAD_WORD,
    UNK_WORD,
    MASK_WORD,
)
from codegen_sources.preprocessing.lang_processors.lang_processor import LangProcessor
from codegen_sources.preprocessing.lang_processors.cpp_processor import CppProcessor
from codegen_sources.preprocessing.lang_processors.java_processor import JavaProcessor
from codegen_sources.preprocessing.lang_processors.python_processor import PythonProcessor

adding to path /pfs/data5/home/hd/hd_hd/hd_tf268/code-gen


In [3]:
def _read_lines(path):
    """Return every line of the text file at ``path`` (trailing newlines kept)."""
    with open(path, "r") as handle:
        return handle.readlines()


def visualize_weights(function_id, src_lang, tgt_lang):
    """Translate one evaluation function and display the weights with bertviz.

    The source function and its reference translation are looked up in the
    hypotheses files of an evaluation run under
    ``dump/transcoder_st/eval/{src_lang}_{tgt_lang}/online_st``. The source is
    translated with the matching ``Online_ST`` checkpoint, the input, output
    and reference are printed, and the weights/tokens returned by the
    translator are handed to ``head_view``.

    Args:
        function_id: Identifier to look for in the ids file. The first line
            that contains it as a substring selects the function.
        src_lang: Source language; one of ``"cpp"``, ``"java"``, ``"python"``.
        tgt_lang: Target language; one of ``"cpp"``, ``"java"``, ``"python"``.

    Raises:
        FileNotFoundError: If the evaluation directory contains no run.
        ValueError: If ``function_id`` does not occur in the ids file.
    """
    src_lang_processor = LangProcessor.processors[src_lang](root_folder="tree-sitter")
    tgt_lang_processor = LangProcessor.processors[tgt_lang](root_folder="tree-sitter")

    # Locate the evaluation run. os.listdir returns entries in arbitrary
    # order, so sort to make the choice deterministic when several runs exist.
    eval_dir = f"dump/transcoder_st/eval/{src_lang}_{tgt_lang}/online_st"
    runs = sorted(os.listdir(eval_dir))
    if not runs:
        raise FileNotFoundError(f"No evaluation run found in {eval_dir}")
    run_path = f"{eval_dir}/{runs[0]}"

    # The ids file is named with the two languages in the fixed order below,
    # whereas the ref/src files always use the src-tgt order.
    langs = ["cpp", "java", "python"]
    if langs.index(src_lang) < langs.index(tgt_lang):
        ids_path = f"{run_path}/hypotheses/ids.{src_lang}_sa-{tgt_lang}_sa.test.txt"
    else:
        ids_path = f"{run_path}/hypotheses/ids.{tgt_lang}_sa-{src_lang}_sa.test.txt"
    ref_path = f"{run_path}/hypotheses/ref.{src_lang}_sa-{tgt_lang}_sa.test.txt"
    src_path = f"{run_path}/hypotheses/src.{src_lang}_sa-{tgt_lang}_sa.test.txt"

    ids_lines = _read_lines(ids_path)
    ref_lines = _read_lines(ref_path)
    src_lines = _read_lines(src_path)

    # NOTE(review): this is a substring match, so an id that is a prefix of
    # another id may select the wrong line -- confirm against the ids format.
    index = next(
        (i for i, line in enumerate(ids_lines) if function_id in line),
        None,
    )
    if index is None:
        raise ValueError(f"Function id {function_id!r} not found in {ids_path}")

    function = src_lang_processor.detokenize_code(src_lines[index])
    ref_function = tgt_lang_processor.detokenize_code(ref_lines[index])

    # Load the model only after the cheap lookups above have succeeded, so a
    # wrong id or missing file fails before the checkpoint is reloaded.
    translator_path = f"models/transcoder_st/Online_ST_{src_lang.title()}_{tgt_lang.title()}.pth"
    translator = Translator(translator_path.replace("Cpp", "CPP"), 'data/bpe/cpp-java-python/codes')

    # Translate function
    f_fill, weights, tokens = translator.translate(
        function,
        lang1=src_lang,
        lang2=tgt_lang,
        beam_size=1,
        return_weights=True
    )

    separator = "=" * 100
    sections = (
        ("Input", function),
        ("Output", f_fill[0]),
        ("Reference", ref_function),
    )
    for title, code in sections:
        print(separator)
        print(title)
        print(separator)
        print(code)
    print(separator)

    head_view(weights, tokens)

In [4]:
# C++ -> Java: translate one evaluation function and show its weights.
function_id = "PROGRAM_PRINT_SUM_GIVEN_NTH_TERM_1"
src_lang, tgt_lang = "cpp", "java"

visualize_weights(function_id=function_id, src_lang=src_lang, tgt_lang=tgt_lang)

INFO - 05/31/22 09:08:15 - 0:00:07 - Reloading encoder from models/transcoder_st/Online_ST_CPP_Java.pth ...
INFO - 05/31/22 09:08:39 - 0:00:30 - Reloading decoders from models/transcoder_st/Online_ST_CPP_Java.pth ...
INFO - 05/31/22 09:08:41 - 0:00:32 - Number of parameters (encoder): 143279616
INFO - 05/31/22 09:08:41 - 0:00:32 - Number of parameters (decoders): 168482304
INFO - 05/31/22 09:08:41 - 0:00:32 - Number of decoders: 1



Tokenized cpp_sa function:
['int', 'summingSeries', '(', 'long', 'n', ')', '{', 'return', 'pow', '(', 'n', ',', '2', ')', ';', '}']


Loading codes from /pfs/data5/home/hd/hd_hd/hd_tf268/code-gen/data/bpe/cpp-java-python/codes ...
Read 50000 codes from the codes file.


Input
int summingSeries ( long n ) {
  return pow ( n , 2 ) ;
}

Output
public static int summingSeries ( long n ) {
  return Math . pow ( n , 2 ) ;
}
@ @
Reference
static int summingSeries ( long n ) {
  return ( int ) Math . pow ( n , 2 ) ;
}



<IPython.core.display.Javascript object>

In [5]:
# C++ -> Python: translate one evaluation function and show its weights.
function_id = "PROGRAM_PRINT_SUM_GIVEN_NTH_TERM_1"
src_lang, tgt_lang = "cpp", "python"

visualize_weights(function_id=function_id, src_lang=src_lang, tgt_lang=tgt_lang)

INFO - 05/31/22 09:08:50 - 0:00:41 - Reloading encoder from models/transcoder_st/Online_ST_CPP_Python.pth ...
INFO - 05/31/22 09:08:51 - 0:00:43 - Reloading decoders from models/transcoder_st/Online_ST_CPP_Python.pth ...
INFO - 05/31/22 09:08:53 - 0:00:44 - Number of parameters (encoder): 143279616
INFO - 05/31/22 09:08:53 - 0:00:44 - Number of parameters (decoders): 168482304
INFO - 05/31/22 09:08:53 - 0:00:44 - Number of decoders: 1



Tokenized cpp_sa function:
['int', 'summingSeries', '(', 'long', 'n', ')', '{', 'return', 'pow', '(', 'n', ',', '2', ')', ';', '}']
Input
int summingSeries ( long n ) {
  return pow ( n , 2 ) ;
}

Output
def summing_series ( n ) :
    return math.pow ( n , 2 )

Reference
def summingSeries ( n ) :
    return math.pow ( n , 2 )



Loading codes from /pfs/data5/home/hd/hd_hd/hd_tf268/code-gen/data/bpe/cpp-java-python/codes ...
Read 50000 codes from the codes file.


<IPython.core.display.Javascript object>

In [6]:
# Java -> C++: translate one evaluation function and show its weights.
function_id = "PROGRAM_PRINT_SUM_GIVEN_NTH_TERM_1"
src_lang, tgt_lang = "java", "cpp"

visualize_weights(function_id=function_id, src_lang=src_lang, tgt_lang=tgt_lang)

INFO - 05/31/22 09:08:59 - 0:00:50 - Reloading encoder from models/transcoder_st/Online_ST_Java_CPP.pth ...
INFO - 05/31/22 09:09:01 - 0:00:52 - Reloading decoders from models/transcoder_st/Online_ST_Java_CPP.pth ...
INFO - 05/31/22 09:09:02 - 0:00:53 - Number of parameters (encoder): 143279616
INFO - 05/31/22 09:09:02 - 0:00:53 - Number of parameters (decoders): 168482304
INFO - 05/31/22 09:09:02 - 0:00:53 - Number of decoders: 1



Tokenized java_sa function:
['static', 'int', 'summingSeries', '(', 'long', 'n', ')', '{', 'return', '(', 'int', ')', 'Math', '.', 'pow', '(', 'n', ',', '2', ')', ';', '}']
Input
static int summingSeries ( long n ) {
  return ( int ) Math . pow ( n , 2 ) ;
}

Output
int summing_series ( long n ) {
  return ( int ) pow ( n , 2 ) ;
}

Reference
int summingSeries ( long n ) {
  return pow ( n , 2 ) ;
}



Loading codes from /pfs/data5/home/hd/hd_hd/hd_tf268/code-gen/data/bpe/cpp-java-python/codes ...
Read 50000 codes from the codes file.


<IPython.core.display.Javascript object>

In [7]:
# Java -> Python: translate one evaluation function and show its weights.
function_id = "PROGRAM_PRINT_SUM_GIVEN_NTH_TERM_1"
src_lang, tgt_lang = "java", "python"

visualize_weights(function_id=function_id, src_lang=src_lang, tgt_lang=tgt_lang)

INFO - 05/31/22 09:09:08 - 0:00:59 - Reloading encoder from models/transcoder_st/Online_ST_Java_Python.pth ...
INFO - 05/31/22 09:09:09 - 0:01:01 - Reloading decoders from models/transcoder_st/Online_ST_Java_Python.pth ...
INFO - 05/31/22 09:09:11 - 0:01:02 - Number of parameters (encoder): 143279616
INFO - 05/31/22 09:09:11 - 0:01:02 - Number of parameters (decoders): 168482304
INFO - 05/31/22 09:09:11 - 0:01:02 - Number of decoders: 1



Tokenized java_sa function:
['static', 'int', 'summingSeries', '(', 'long', 'n', ')', '{', 'return', '(', 'int', ')', 'Math', '.', 'pow', '(', 'n', ',', '2', ')', ';', '}']
Input
static int summingSeries ( long n ) {
  return ( int ) Math . pow ( n , 2 ) ;
}

Output
def summing_series ( n ) :
    return int ( math.pow ( n , 2 ) )

Reference
def summingSeries ( n ) :
    return math.pow ( n , 2 )



Loading codes from /pfs/data5/home/hd/hd_hd/hd_tf268/code-gen/data/bpe/cpp-java-python/codes ...
Read 50000 codes from the codes file.


<IPython.core.display.Javascript object>

In [None]:
# Python -> C++: translate one evaluation function and show its weights.
function_id = "PROGRAM_PRINT_SUM_GIVEN_NTH_TERM_1"
src_lang, tgt_lang = "python", "cpp"

visualize_weights(function_id=function_id, src_lang=src_lang, tgt_lang=tgt_lang)

In [None]:
# Python -> Java: translate one evaluation function and show its weights.
function_id = "PROGRAM_PRINT_SUM_GIVEN_NTH_TERM_1"
src_lang, tgt_lang = "python", "java"

visualize_weights(function_id=function_id, src_lang=src_lang, tgt_lang=tgt_lang)