In [1]:
from pathlib import Path
from tokenizers import ByteLevelBPETokenizer
import pandas as pd
from tokenizers.processors import BertProcessing
from tree_sitter import Language, Parser
import os
from pathlib import Path

In [2]:
def create_java_only_dataset(source_path="Data/Unified_Data_with_SHA.csv",
                             target_path="Data/Java_Unified_Data_with_SHA.csv",
                             language="Java"):
    """Write the rows of the unified dataset that belong to one language to a new CSV.

    Nothing is done when ``target_path`` already exists, so re-running the cell is cheap.

    Args:
        source_path: CSV containing all rows; it must have a ``language_name`` column.
        target_path: where the filtered CSV is written (without an index column).
        language: value of ``language_name`` to keep (``"Java"`` by default).
    """
    if os.path.isfile(target_path):
        return
    unified_df = pd.read_csv(source_path)
    # Select into a new frame instead of calling reset_index(inplace=True) on a filtered
    # slice (which can trigger SettingWithCopyWarning). The index is not written anyway
    # because of index=False.
    language_df = unified_df[unified_df["language_name"] == language]
    language_df.to_csv(target_path, index=False)
# Build Data/Java_Unified_Data_with_SHA.csv once; the function skips the work when the file already exists.
create_java_only_dataset()

In [3]:
# The previous cell already built the dataset, and repeating the guarded call was a no-op.
# Verify the output exists instead of calling the builder a second time.
assert os.path.isfile("Data/Java_Unified_Data_with_SHA.csv"), "Java-only dataset was not created"

In [4]:
def get_uuid(text):
    """Return the last '/'-separated component of ``text``, truncated at its first '.'.

    For example ``"x/y/abc.java"`` -> ``"abc"``. Used to name the per-file artefacts.
    """
    _, _, base_name = text.rpartition("/")
    stem, _, _ = base_name.partition(".")
    return stem

In [5]:
def remove_comments_and_docstrings(source):
    """Strip ``//`` and ``/* ... */`` comments (including ``/** ... */``) from Java/C-style code.

    String and character literals are matched by the same regex and returned
    unchanged, so comment-like text inside them (e.g. ``"http://x"``) survives.

    Args:
        source: source code text.

    Returns:
        The code with every comment replaced by a single space. Lines that are empty
        or whitespace-only afterwards are dropped, and the rest are joined with newlines.
    """
    import re  # not imported in the notebook's import cell; without this the call raised NameError

    def replacer(match):
        token = match.group(0)
        # Comments start with '/'; string/char literals are kept as they are.
        # Use a space, not "", so tokens around an inline /* ... */ are not glued together.
        return " " if token.startswith('/') else token

    pattern = re.compile(
        r'//.*?$|/\*.*?\*/|\'(?:\\.|[^\\\'])*\'|"(?:\\.|[^\\"])*"',
        re.DOTALL | re.MULTILINE
    )
    stripped = pattern.sub(replacer, source)
    return '\n'.join(line for line in stripped.split('\n') if line.strip())

In [6]:
def create_report_files(data_path="Data/Java_Unified_Data_with_SHA.csv",
                        output_dir="Data/Report_Files/"):
    """Write one ``<uuid>.txt`` file per row containing the report title and description.

    The uuid comes from the row's ``before_fix_uuid_file_path`` via ``get_uuid``.
    The whole step is skipped when ``output_dir`` already exists, so a partially
    completed earlier run is not resumed; delete the directory to rebuild.

    Args:
        data_path: CSV with ``before_fix_uuid_file_path``, ``title`` and ``description`` columns.
        output_dir: directory that receives the report files.
    """
    if os.path.isdir(output_dir):
        return
    Path(output_dir).mkdir(parents=True, exist_ok=True)
    reports_df = pd.read_csv(data_path)
    for _, row in reports_df.iterrows():
        # Empty CSV cells come back as NaN (a float); "str" + NaN raises TypeError,
        # so missing text is written as an empty string instead.
        title = "" if pd.isna(row['title']) else str(row['title'])
        description = "" if pd.isna(row['description']) else str(row['description'])
        report_name = "{}.txt".format(get_uuid(row['before_fix_uuid_file_path']))
        # `with` closes the file even if a write fails; explicit UTF-8 avoids
        # platform-dependent default encodings for non-ASCII report text.
        with open(os.path.join(output_dir, report_name), "w", encoding="utf-8") as report_file:
            report_file.write(title + " " + description)
# Write Data/Report_Files/<uuid>.txt for every row; skipped entirely if that directory already exists.
create_report_files()

In [7]:
def convert_file_to_ast(file_path, parser):
    """Parse a source file and return its syntax tree as an S-expression string.

    Args:
        file_path: path of the source file to parse (read as UTF-8 text).
        parser: a tree-sitter ``Parser`` with its language already set.

    Returns:
        ``tree.root_node.sexp()`` for the parsed file.
    """
    # `with` closes the handle; the original left it open for every file processed.
    with open(file_path, "r", encoding="utf-8") as source_file:
        file_content = source_file.read()
    tree = parser.parse(bytes(file_content, "utf-8"))
    return tree.root_node.sexp()

In [8]:
def create_ast_files(data_path="Data/Java_Unified_Data_with_SHA.csv",
                     output_dir="Data/AST_Files/",
                     library_path="build/my-languages.so"):
    """Store the tree-sitter S-expression AST of every before-fix and after-fix file.

    For each row, the files in ``before_fix_uuid_file_path`` and
    ``after_fix_uuid_file_path`` are parsed (in that order) and written to
    ``<output_dir>/<uuid>.txt``. The uuid comes from ``get_uuid``. The whole step is
    skipped when ``output_dir`` already exists.

    Args:
        data_path: CSV listing the before/after file paths.
        output_dir: directory that receives the AST text files.
        library_path: compiled tree-sitter grammar library containing ``java``.
    """
    if os.path.isdir(output_dir):
        return
    Path(output_dir).mkdir(parents=True, exist_ok=True)
    pairs_df = pd.read_csv(data_path)
    parser = Parser()
    parser.set_language(Language(library_path, 'java'))
    for _, row in pairs_df.iterrows():
        # Same handling for both versions of the file; replaces two copy-pasted branches.
        for column in ('before_fix_uuid_file_path', 'after_fix_uuid_file_path'):
            source_path = row[column]
            # Parse before opening the output, so a parse failure leaves no empty file behind.
            ast_text = convert_file_to_ast(source_path, parser)
            ast_name = "{}.txt".format(get_uuid(source_path))
            with open(os.path.join(output_dir, ast_name), "w", encoding="utf-8") as ast_file:
                ast_file.write(ast_text)
# Write Data/AST_Files/<uuid>.txt for every before/after file; skipped entirely if that directory already exists.
create_ast_files()

In [9]:
# Rows whose generated AST/report files feed the tokenizer training below.
# NOTE(review): Data/Java_Train_Data.csv is not written by any of the cells above —
# presumably produced by a separate train/test split step; confirm its provenance.
df = pd.read_csv("Data/Java_Train_Data.csv")

In [10]:
# Generated artefacts are named after the uuid of the source file; report files
# are keyed by the before-fix path, matching create_report_files.
before_fix_ast_paths = ["Data/AST_Files/{}.txt".format(get_uuid(path)) for path in df['before_fix_uuid_file_path']]
after_fix_ast_paths = ["Data/AST_Files/{}.txt".format(get_uuid(path)) for path in df['after_fix_uuid_file_path']]
report_files = ["Data/Report_Files/{}.txt".format(get_uuid(path)) for path in df['before_fix_uuid_file_path']]

In [11]:
# Tokenizer training corpus: before-fix AST files followed by report text files.
# NOTE(review): after_fix_ast_paths is not included here — confirm this is intentional.
all_file_path = [*before_fix_ast_paths, *report_files]

In [12]:
# Special tokens reserved in the vocabulary; <s> and </s> are reused by the
# post-processor configured after the model is reloaded.
SPECIAL_TOKENS = ["<s>", "<pad>", "</s>", "<unk>", "<mask>"]
tokenizer = ByteLevelBPETokenizer()
tokenizer.train(
    files=all_file_path,
    min_frequency=2,
    special_tokens=SPECIAL_TOKENS,
)

In [13]:
# Writes ./aster-vocab.json and ./aster-merges.txt, which the next cell reloads.
# The prefix is joined onto the directory, so it should be a bare name ("./aster"
# produced the redundant "././aster-..." paths).
tokenizer.save_model(".", "aster")

['././aster-vocab.json', '././aster-merges.txt']

In [14]:
# Reload the saved vocabulary/merges and configure special-token wrapping and truncation.
tokenizer = ByteLevelBPETokenizer("aster-vocab.json", "aster-merges.txt")
eos_token_id = tokenizer.token_to_id("</s>")
bos_token_id = tokenizer.token_to_id("<s>")
# BertProcessing takes the (sep, cls) pairs in that order.
tokenizer._tokenizer.post_processor = BertProcessing(
    ("</s>", eos_token_id),
    ("<s>", bos_token_id),
)
MAX_SEQ_LENGTH = 3000  # NOTE(review): confirm this matches the downstream model's maximum input length
tokenizer.enable_truncation(max_length=MAX_SEQ_LENGTH)

In [15]:
# Sanity check on a Java snippet; no padding is enabled above, so every mask entry is 1.
sample_encoding = tokenizer.encode("public static void main(){System.out.println(\"hello world\")}")
sample_encoding.attention_mask

[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]