In [1]:
import tree_sitter
from tree_sitter import Language, Parser
import re

all the declarations as in enums.py

In [2]:
LANG_JAVA = 'java'
LANG_PYTHON = 'python'
LANG_GO = 'go'
LANG_PHP = 'php'
LANG_JAVASCRIPT = 'javascript'
LANG_RUBY = 'ruby'
LANG_C_SHARP = 'c_sharp'

In [3]:
# training mode
TRAINING_MODE_PRE_TRAIN = 'pre_train'
TRAINING_MODE_FINE_TUNE = 'fine_tune'

# pre-training task names
TASK_CODE_AST_PREDICTION = 'cap'
TASK_MASS = 'mass'
TASK_METHOD_NAME_PREDICTION = 'mng'

PRE_TRAIN_TASKS = [
    TASK_CODE_AST_PREDICTION,
    TASK_MASS,
    TASK_METHOD_NAME_PREDICTION
]

# downstream task names
TASK_SUMMARIZATION = 'summarization'
TASK_TRANSLATION = 'translation'
TASK_SEARCH = 'search'
TASK_CLONE_DETECTION = 'clone'
TASK_COMPLETION = 'completion'
TASK_BUG_FIX = 'bug_fix'

ALL_DOWNSTREAM_TASKS = [
    TASK_SUMMARIZATION,
    TASK_TRANSLATION,
    TASK_SEARCH,
    TASK_CLONE_DETECTION,
    TASK_COMPLETION,
    TASK_BUG_FIX
]

# programming language
LANG_JAVA = 'java'
LANG_PYTHON = 'python'
LANG_GO = 'go'
LANG_PHP = 'php'
LANG_JAVASCRIPT = 'javascript'
LANG_RUBY = 'ruby'
LANG_C_SHARP = 'c_sharp'

# BART model mode
MODEL_MODE_CLS = 'bart_cls'
MODEL_MODE_GEN = 'bart_gen'
MODEL_MODE_SEARCH = 'bart_search'

In [4]:
LANGUAGE = {LANG_GO: Language('/home/user1-selab3/Documents/research-shradha/CODE-SPT-Code/spt-code/sources/data/asts/build/my-languages.so', 'go'),
            LANG_JAVASCRIPT: Language('/home/user1-selab3/Documents/research-shradha/CODE-SPT-Code/spt-code/sources/data/asts/build/my-languages.so', 'javascript'),
            LANG_PYTHON: Language('/home/user1-selab3/Documents/research-shradha/CODE-SPT-Code/spt-code/sources/data/asts/build/my-languages.so', 'python'),
            LANG_JAVA: Language('/home/user1-selab3/Documents/research-shradha/CODE-SPT-Code/spt-code/sources/data/asts/build/my-languages.so', 'java'),
            LANG_PHP: Language('/home/user1-selab3/Documents/research-shradha/CODE-SPT-Code/spt-code/sources/data/asts/build/my-languages.so', 'php'),
            LANG_RUBY: Language('/home/user1-selab3/Documents/research-shradha/CODE-SPT-Code/spt-code/sources/data/asts/build/my-languages.so', 'ruby'),
            LANG_C_SHARP: Language('/home/user1-selab3/Documents/research-shradha/CODE-SPT-Code/spt-code/sources/data/asts/build/my-languages.so', 'c_sharp')}



In [5]:
parser = Parser()

In [None]:
#It contains language-specific prefix and postfix strings.

In [6]:
SOURCE_PREFIX_POSTFIX = {
    LANG_PHP: ['<?php ', ' ?>'],
    LANG_JAVA: ['class A{ ', ' }']
}

defining a dictionary in Python named PATTERNS_METHOD_ROOT. It contains language-specific patterns for identifying the root node of a method in a code AST (Abstract Syntax Tree).

In [7]:
#It will search for a method_declaration within a class_body, which in turn is within a class_declaration. @method_root seems to be a placeholder for where the root of the method should be.
PATTERNS_METHOD_ROOT = {
    LANG_JAVA: """
    (program
        (class_declaration
            body: (class_body
                (method_declaration) @method_root)
        )
    )
    """
}

This dictionary contains language-specific patterns for identifying the body of a method/function in different programming languages.

In [8]:
#This pattern looks for a method_declaration node with a block as its body, where @body is a placeholder for capturing the body of the method.
PATTERNS_METHOD_BODY = {
    LANG_JAVA: """
    (method_declaration
        body: (block) @body
    )
    """,

#This pattern searches for a function_declaration in a JavaScript program, and captures its statement_block as the body of the function.
    LANG_JAVASCRIPT: """
    (program
        (function_declaration
            body: (statement_block) @body
        )
    )
    """,
#This pattern looks for either a function_declaration or a method_declaration in a Go source file, and captures the block as the body of the function/method.
    LANG_GO: """
    (source_file
        [
        (function_declaration
            body: (block) @body)

        (method_declaration
            body: (block) @body)
        ]
    )
    """
}

it contains language-specific patterns for identifying the names of methods/functions in various programming languages.

In [9]:
#This pattern looks for a method_declaration in Java and captures its name, represented by an identifier.
PATTERNS_METHOD_NAME = {
    LANG_JAVA: """
    (method_declaration
        name: (identifier) @method_name
    )
    """,

    LANG_PYTHON: """
    (module
        (function_definition
            name: (identifier) @method_name
        )
    )
    """,

    LANG_GO: """
    [
        (source_file
            (method_declaration
                name: (field_identifier) @method_name
            )
        )
        (source_file
            (function_declaration
                name: (identifier) @method_name
            )
        )
    ]
    """,

    LANG_JAVASCRIPT: """
    (program
        (function_declaration
            name: (identifier) @method_name
        )
    )
    """,

    LANG_RUBY: """
    (program
        (method
            name: (identifier) @method_name
        )
    )
    """,

    LANG_PHP: """
    (program
        (function_definition
            name: (name) @method_name
        )
    )
    """
}

It contains language-specific patterns for identifying method invocations in various programming languages.

In [10]:
#This pattern looks for a method_invocation in Java and captures its name, represented by an identifier.
PATTERNS_METHOD_INVOCATION = {
    LANG_JAVA: """
    (method_invocation
        name: (identifier) @method_invocation
    )
    """,
#It looks for either a function call or a method call on an object (attribute access), capturing the name of the method in both cases.
    LANG_PYTHON: """
    [
        (call
            function: (identifier) @method_invocation
        )
        (call
            function: (attribute
                attribute: (identifier) @method_invocation
            )
        )
    ]
    """,

    LANG_GO: """
    [
        (call_expression
            function: (selector_expression
                field: (field_identifier) @method_invocation
            )
        )
        (call_expression
            function: (identifier) @method_invocation
        )
    ]
    """,

    LANG_JAVASCRIPT: """
    [
        (call_expression
            function: (member_expression
                property: (property_identifier) @method_invocation
            )
        )
        (call_expression
            function: (identifier) @method_invocation
        )
    ]
    """,

    LANG_RUBY: """
    (call
        method: (identifier) @method_invocation
    )
    """,

    LANG_PHP: """
    [
        (scoped_call_expression
            name: (name) @method_invocation
        )
        (function_call_expression
            (name) @method_invocation
        )
        (member_call_expression
            name: (name) @method_invocation
        )
        (object_creation_expression
            (qualified_name
                (name) @method_invocation
            )
        )
    ]
    """
}

The STATEMENT_ENDING_STRINGS dictionary defines common types of statement endings for different programming languages

In [11]:
STATEMENT_ENDING_STRINGS = {
    LANG_JAVA: ['statement', 'expression', 'declaration'],
    LANG_PYTHON: ['statement', 'assignment'],
    LANG_GO: ['statement', 'declaration', 'expression'],
    LANG_JAVASCRIPT: ['statement', 'expression'],
    LANG_RUBY: ['call', 'assignment', 'if', 'unless_modifier', 'operator_assignment', 'if_modifier', 'return',
                'rescue', 'else', 'unless', 'when', 'for', 'while_modifier', 'until'],
    LANG_PHP: ['statement', 'expression']
}

To split a camel case identifier into its constituent parts

In [12]:
def camel_split(identifier):
    matches = re.finditer('.+?(?:(?<=[a-z])(?=[A-Z])|(?<=[A-Z])(?=[A-Z][a-z])|$)', identifier)
    return [m.group(0) for m in matches]

This function takes an identifier string and splits it into a list of subtokens, eliminating tokens that are not characters or digits

In [13]:
def split_identifier(identifier):
    """
    Split identifier into a list of subtokens.
    Tokens except characters and digits will be eliminated.

    Args:
        identifier (str): given identifier

    Returns:
        list[str]: list of subtokens
    """
    words = []

    word = re.sub(r'[^a-zA-Z0-9]', ' ', identifier)
    word = re.sub(r'(\d+)', r' \1 ', word)
    split_words = word.strip().split()
    for split_word in split_words:
        camel_words = camel_split(split_word)
        for camel_word in camel_words:
            words.append(camel_word.lower())

    return words

In [14]:
import sys

To parse code into an Abstract Syntax Tree (AST) using a parser specific to the language provided

In [15]:
def parse_ast(source, lang):
    """
    Parse the given code into corresponding ast.
    Args:
        source (str): code in string
        lang (str): Set the language

    Returns:
        tree_sitter.Node: Method/Function root node

    """
    #Attempts to set the language of the parser. If it fails, it prints an error message and exits the program.
    try:
        parser.set_language(LANGUAGE[lang])
    except Exception as err:
        print(f'[ERR]: {err}')
        sys.exit(0)

    #If the language has specified prefix and postfix strings in SOURCE_PREFIX_POSTFIX, it adds them to the source code.
    if lang in SOURCE_PREFIX_POSTFIX:
        source = SOURCE_PREFIX_POSTFIX[lang][0] + source + SOURCE_PREFIX_POSTFIX[lang][1]
    #It then parses the source code using the parser for the specified language. The code is encoded and decoded to handle Unicode characters.
    tree = parser.parse(source.encode('utf-8').decode('unicode_escape').encode())
    root = tree.root_node
    # tree = parser.parse(str.encode(source))
    #If there's a pattern specified for finding the root of a method/function in PATTERNS_METHOD_ROOT, it uses that pattern to locate the root node.
    if lang in PATTERNS_METHOD_ROOT:
        query = LANGUAGE[lang].query(PATTERNS_METHOD_ROOT[lang])
        captures = query.captures(root)
        root = captures[0][0]
    return root
    #it returns the root node of the parsed AST.

The get_node_name function retrieves the name of a given node in an Abstract Syntax Tree (AST). 

In [16]:
def get_node_name(source, node, lang):
    """
    Get node name, for php is shifted by prefix.

    Args:
        source (str): Source code string
        node (tree_sitter.Node): Node instance
        lang (str): Source code language

    Returns:
        str: Name of node

    """
    #This condition checks if the node is a named node. A named node is one that represents an actual part of the code (like an identifier, a keyword, etc.).
    if node.is_named:
    # If the language has specified prefix and postfix strings in SOURCE_PREFIX_POSTFIX, it adjusts the byte positions of the node to accommodate for these strings and extracts the corresponding substring from the source code.
        if lang in SOURCE_PREFIX_POSTFIX:
            return source[node.start_byte - len(SOURCE_PREFIX_POSTFIX[lang][0]):
                          node.end_byte - len(SOURCE_PREFIX_POSTFIX[lang][0])]
        else:
        #If there's no prefix and postfix for the language, it simply extracts the substring corresponding to the node from the source code.
            return source[node.start_byte: node.end_byte]
    return ''

#This function aims to return the name of a method/function given its root node in an AST

In [17]:
def get_method_name(source, root, lang):
    """
    Return the name of method/function.

    Args:
        source (str): Source code string
        root (tree_sitter.Node): Method/Function root node
        lang (str): Source code language

    Returns:

    """
    #It queries the AST for the pattern to find the method/function name based on the language specified.
    query = LANGUAGE[lang].query(PATTERNS_METHOD_NAME[lang])
    #it captures the matches
    captures = query.captures(root)
    #If no captures are found, it returns an empty string.
    if len(captures) == 0:
        return ''
    #returns the name of the first captured node using the get_node_name function. 
    return get_node_name(source, captures[0][0], lang)

To check whether a given node in an Abstract Syntax Tree (AST) represents a statement-level construct in the code

In [18]:
def is_statement_node(node, lang):
    """
    Return whether the node is a statement level node.

    Args:
        node (tree_sitter.Node): Node to be queried
        lang (str): Source code language

    Returns:
        bool: True if given node is a statement node

    """
    #Retrieves the list of statement ending strings specific to the language from the STATEMENT_ENDING_STRINGS dictionary.
    endings = STATEMENT_ENDING_STRINGS[lang]
    #Splits the node's type string by underscores and takes the last part. 
    end = node.type.split('_')[-1]
    #if the extracted ending is found in the list of statement endings for the language, it returns True, indicating that the node represents a statement-level construct. Otherwise, it returns False
    if end in endings:
        return True
    else:
        return False

this function retrieves the type of a given node in the Abstract Syntax Tree (AST)

In [19]:
#If the language is Ruby (LANG_RUBY), it appends "_statement" to the end of the node's type. 
#This is because in the Ruby language, many node types are followed by "_statement" (e.g., if_statement, while_statement). For other languages, it returns the node's type as it is.
def get_node_type(node, lang):
    """
    Return the type of node, for ruby, add ``_statement`` to the end.

    Args:
        node (tree_sitter.Node): Node to be queried
        lang (str): Source code language

    Returns:
        str: Type of the node

    """
    return f'{node.type}_statement' if lang == LANG_RUBY else node.type

This function generates a representation of a source code AST in X-SBT format (Extended-Simple Binary Tree) recursively

In [20]:
def __statement_xsbt(node, lang):
    """
    Method used to generate X-SBT recursively.

    Args:
        node (tree_sitter.Node): Root node to traversal
        lang (str): Source code language

    Returns:
        list[str]: List of strings representing node types

    """
    xsbt = []
#If the current node has no children and is a statement node (determined by is_statement_node function), its type is appended to xsbt
    if len(node.children) == 0:
        if is_statement_node(node, lang):
            xsbt.append(get_node_type(node, lang))
    else:
    #If the current node has children:
    #If the current node is a statement node, its type followed by "__" is appended to xsbt.
        if is_statement_node(node, lang):
            xsbt.append(f'{get_node_type(node, lang)}__')
    #the current length of xsbt is stored in len_before.
        len_before = len(xsbt)
    #Each child node is processed recursively, and their X-SBT representations are concatenated to xsbt.
        for child in node.children:
            xsbt += __statement_xsbt(node=child, lang=lang)
    #After processing children, if the length of xsbt remains unchanged and is not zero (i.e., no new nodes were added), the last item in xsbt is replaced with the current node's type. This ensures correct representation of the parent node.
        if len_before == len(xsbt) and len_before != 0:
            xsbt[-1] = get_node_type(node, lang)
    #If new nodes were added to xsbt and the current node is a statement node, the current node's type preceded by "__" is appended to xsbt. This marks the end of the parent node's representation and the start of child nodes' representations
        elif is_statement_node(node, lang):
            xsbt.append(f'__{get_node_type(node, lang)}')

    return xsbt

This function generates an X-SBT (Extended-Simple Binary Tree) string from a given root node in the Abstract Syntax Tree (AST).

In [21]:

def generate_statement_xsbt(node, lang):
    """
    Generate X-SBT string.

    Args:
        node (tree_sitter.Node): Root node to traversal
        lang (str): Source code language

    Returns:
        str: X-SBT string

    """
#If the language has specified patterns for method bodies in PATTERNS_METHOD_BODY, it queries the AST using this pattern and retrieves the first capture. 
#This is useful for extracting the method body node if necessary.
    if lang in PATTERNS_METHOD_BODY:
        query = LANGUAGE[lang].query(PATTERNS_METHOD_BODY[lang])
        captures = query.captures(node)
        node = captures[0][0]
    #It generates the X-SBT tokens using the __statement_xsbt function with the given node and language.
    tokens = __statement_xsbt(node=node, lang=lang)
    return ' '.join(tokens)

This function aims to extract method invocation sequences from a given root node in an Abstract Syntax Tree (AST). 

In [22]:
def extract_method_invocation(source, root, lang):
    """
    Extract method invocation sequence from given root.

    Args:
        source (str): Source code string
        root (tree_sitter.Node): Node to be extracted from
        lang (str): Source code language

    Returns:
        list[str]: List of method invocation strings

    """
#It queries the AST using the method invocation pattern specified in PATTERNS_METHOD_INVOCATION for the given language. It captures all instances of method invocations.
    query = LANGUAGE[lang].query(PATTERNS_METHOD_INVOCATION[lang])
    captures = query.captures(root)
#It iterates over each captured method invocation node, retrieves its name using the get_node_name function, and collects these names into a list.
    return [get_node_name(source=source, node=capture[0], lang=lang) for capture in captures]

This function is designed to extract natural language tokens from source code, including splitting the method/function name and method invocation names if necessary. 

In [23]:
def extract_nl_from_code(source, root, lang, name=None, replace_method_name=False):
    """
    Extract nl tokens from given source code, including split name and method invocations.

    Args:
        source (str): Source code string`
        root (tree_sitter.Node): Root of code
        lang (str): Source code language
        name (str): optional, name of method/function
        replace_method_name (bool): Whether to replace method name and returns a version that without names additionally

    Returns:
        Union[(str, str), str]:
            - Nl string
            - Nl string without method name

    """
#Two lists, tokens and tokens_wo_name, are initialized to hold the NL tokens. tokens will hold all NL tokens, including method/function names and method invocation names. tokens_wo_name will hold NL tokens excluding method/function names.
    tokens = []
    tokens_wo_name = []

#If the method/function name is not provided, it's extracted using get_method_name. The name is then split into tokens using split_identifier, and these tokens are added to both tokens and tokens_wo_name.
    if name is None:
        name = get_method_name(source=source, root=root, lang=lang)
    name_tokens = split_identifier(name)
    tokens += name_tokens

#Each invocation name is split into tokens using split_identifier, and these tokens are added to both tokens and tokens_wo_name.
    invocations = extract_method_invocation(source=source, root=root, lang=lang)
    for invocation in invocations:
        subtokens = split_identifier(invocation)
        tokens += subtokens
        tokens_wo_name += subtokens

#If replace_method_name is True, it returns both the NL string with method/function names (' '.join(tokens)) and the NL string without method/function names (' '.join(tokens_wo_name)).
#If replace_method_name is False, it returns only the NL string with method/function names (' '.join(tokens)).
    if replace_method_name:
        return ' '.join(tokens), ' '.join(tokens_wo_name)
    else:
        return ' '.join(tokens)

This generate_single_ast_nl function generates AST sequences and natural language (NL) sequences for a single source code sample

In [24]:
def generate_single_ast_nl(source, lang, name=None, replace_method_name=False):
    """
    Generate AST sequence and nl sequence for a single source code sample.

    Args:
        source (str): Source code string
        lang (str): Source code language
        name (str): optional, name of method/function
        replace_method_name (bool): Whether to replace method name and returns a version that without names additionally

    Returns:
        Union[(str, str), (str, str, str)]:
            - AST sequence in string
            - Nl sequence in string

    """
#The source code is parsed into an AST (root) using the parse_ast function.
#The AST is converted into an AST sequence (ast) using the generate_statement_xsbt function.
    root = parse_ast(source=source, lang=lang)
    ast = generate_statement_xsbt(node=root, lang=lang)
#If replace_method_name is True, NL sequences are extracted with and without method/function names using the extract_nl_from_code function. 
#Both the AST sequence, NL sequence with names (nl), and NL sequence without names (nl_wo_name) are returned.
    if replace_method_name:
        nl, nl_wo_name = extract_nl_from_code(source=source,
                                              root=root,
                                              lang=lang,
                                              name=name,
                                              replace_method_name=replace_method_name)
        return ast, nl, nl_wo_name
    else:
#f replace_method_name is False, NL sequence is extracted only with method/function names.
        nl = extract_nl_from_code(source=source, root=root, lang=lang, name=name)
        return ast, nl

This function generates AST sequences and natural language (NL) sequences for a list of source code samples. It filters out any exceptions that occur during processing. 

In [25]:
def generate_asts_nls(sources, langs):
    """
    Generate AST sequence and nl sequence for a list of source code samples, exceptions will be eliminate.

    Args:
        sources (str): List of source code strings
        langs (str): List of source code languages

    Returns:
        (list[str], list[str], list[str], list[str]):
            - List of language strings
            - List of source code strings
            - List of AST sequence strings
            - List of nl sequence strings

    """
    assert len(sources) == len(langs)
    new_langs = []
    new_sources = []
    asts = []
    nls = []
    #It iterates through each pair of language and source code using zip(langs, sources)
    for lang, source in zip(langs, sources):
        try:
        #For each pair, it attempts to generate AST and NL sequences using generate_single_ast_nl
            ast, nl = generate_single_ast_nl(source=source, lang=lang)
        #If successful, it appends the language, source code, AST sequence, and NL sequence to the respective lists.
            new_langs.append(lang)
            new_sources.append(source)
            asts.append(ast)
            nls.append(nl)
        #If an exception occurs, it continues to the next pair without raising the exception.
        except Exception:
            continue
    #it returns the filtered lists of language strings, source code strings, AST sequences, and NL sequences.
    return new_langs, new_sources, asts, nls

In [26]:

def analyze_java_code(source_code):
    root = parse_ast(source_code, 'java')
    method_names = get_method_name(source_code, root, 'java')
    method_invocations = extract_method_invocation(source_code, root, 'java')
    print("Method Names:", method_names)
    print("Method Invocations:", method_invocations)

In [27]:
# Example code
java_code = "public static double Y(double x) {\r\n        if (x < 8.0) {\r\n            double y = x * x;\r\n            double ans1 = x * (-0.4900604943e13 + y * (0.1275274390e13\r\n                    + y * (-0.5153438139e11 + y * (0.7349264551e9\r\n                    + y * (-0.4237922726e7 + y * 0.8511937935e4)))));\r\n            double ans2 = 0.2499580570e14 + y * (0.4244419664e12\r\n                    + y * (0.3733650367e10 + y * (0.2245904002e8\r\n                    + y * (0.1020426050e6 + y * (0.3549632885e3 + y)))));\r\n            return (ans1 / ans2) + 0.636619772 * (J(x) * Math.log(x) - 1.0 / x);\r\n        } else {\r\n            double z = 8.0 / x;\r\n            double y = z * z;\r\n            double xx = x - 2.356194491;\r\n            double ans1 = 1.0 + y * (0.183105e-2 + y * (-0.3516396496e-4\r\n                    + y * (0.2457520174e-5 + y * (-0.240337019e-6))));\r\n            double ans2 = 0.04687499995 + y * (-0.2002690873e-3\r\n                    + y * (0.8449199096e-5 + y * (-0.88228987e-6\r\n                    + y * 0.105787412e-6)));\r\n            return Math.sqrt(0.636619772 / x) *\r\n                    (Math.sin(xx) * ans1 + z * Math.cos(xx) * ans2);\r\n        }\r\n    }"

analyze_java_code(java_code)

Method Names: Y
Method Invocations: ['J', 'log', 'sqrt', 'sin', 'cos']


It retrieves the AST sequence, method/function names, and source code for a given list of code lines in multiple languages

In [28]:

def get_ast_name(languages, code_lines):
    assert len(languages) == len(code_lines)
    langs = []
    asts = []
    names = []
    codes = []
#Iterates through each pair of language and code line using zip(languages, code_lines)
    for lang, line in zip(languages, code_lines):
        try:
            #Tries to parse the code line into an AST using parse_ast.
            tree = parse_ast(line, lang=lang)
            #Generates the AST sequence using generate_statement_xsbt.
            ast = generate_statement_xsbt(line, tree.root_node)
            #retrieves the method/function name using get_method_name
            name = get_method_name(line, lang=lang, root=tree.root_node)
            #appends the language, AST sequence (as a string), method/function name, and source code to the respective lists.
            langs.append(lang)
            asts.append(' '.join(ast))
            names.append(name)
            codes.append(line)
        except Exception:
            continue
    return langs, codes, asts, names

These functions are used to generate AST (Abstract Syntax Tree) representations for a single source code sample, optionally including the method/function name.

In [29]:

def get_single_ast(lang, source):
    tree = parse_ast(source=source, lang=lang)
    ast = generate_statement_xsbt(tree.root_node)
    return ' '.join(ast)


def get_single_ast_name(lang, source):
    tree = parse_ast(source=source, lang=lang)
    ast = generate_statement_xsbt(tree.root_node)
    name = get_method_name(source, lang=lang, root=tree.root_node)
    return ast, name


In [30]:
def lang_sample(lang):
    import random, json
    with open(f'/home/user1-selab3/Documents/research-shradha/CODE-SPT-Code/dataset/pre_train/java/traintest.jsonl') as f:
        line = f.readline()#f.readlines()[random.randint(0, 1000)]
        data = json.loads(line)#json.loads(line.strip())
        name = data['func_name']
        source = data['code']
    return source, name


In [31]:
lang = 'java'

source, name = lang_sample(lang)
print('-' * 100)
print('source:')
print(source)
print('-' * 100)
print('name in json:')
print(name)
print('-' * 100)

root = parse_ast(source, lang=lang)
print('method name:')
print(get_method_name(source=source, root=root, lang=lang))
print('-' * 100)
print('method_invocation:')
print(extract_method_invocation(source, root, lang))
print('-' * 100)
print('xsbt in statements:')
print(generate_statement_xsbt(root, lang))
print('-' * 100)
print('nl from code:')
print(extract_nl_from_code(source,root,lang,name))
print('-' * 100)
print('single ast from code:')
print(generate_single_ast_nl(source, lang, name))

----------------------------------------------------------------------------------------------------
source:
public static double Sinc(double x) {
        return Math.sin(Math.PI * x) / (Math.PI * x);
    }
----------------------------------------------------------------------------------------------------
name in json:
Tools.Sinc
----------------------------------------------------------------------------------------------------
method name:
Sinc
----------------------------------------------------------------------------------------------------
method_invocation:
['sin']
----------------------------------------------------------------------------------------------------
xsbt in statements:
return_statement__ binary_expression__ binary_expression parenthesized_expression__ binary_expression __parenthesized_expression __binary_expression __return_statement
----------------------------------------------------------------------------------------------------
nl from code:
tools sinc sin
-

In [32]:
from tokenizers import Tokenizer
from tokenizers.models import BPE, WordLevel
from tokenizers.trainers import BpeTrainer, WordLevelTrainer
from tokenizers.pre_tokenizers import Whitespace
from tokenizers import normalizers
from tokenizers.normalizers import Strip, Lowercase, NFD, StripAccents
from tokenizers.processors import TemplateProcessing, BertProcessing

import pickle
import os
import logging

from typing import List, Union

import sys

In [33]:
logger = logging.getLogger(__name__)

This class will manage vocabulary creation, tokenization, and encoding/decoding of sequences

In [34]:
class Vocab(object):
    # special vocabulary symbols
    PAD_TOKEN = '[PAD]'  # padding token
    SOS_TOKEN = '[SOS]'  # start of sequence, also CLS
    EOS_TOKEN = '[EOS]'  # end of sequence
    UNK_TOKEN = '[UNK]'  # unknown token
    MSK_TOKEN = '[MSK]'  # mask token
    SEP_TOKEN = '[SEP]'  # sentence separator token
    # CLS_TOKEN = '[CLS]'     # classification placeholder

    # default special symbols, if need additional symbols, use init parameter 'additional_special_symbols'
    START_VOCAB = [PAD_TOKEN, SOS_TOKEN, EOS_TOKEN, UNK_TOKEN, MSK_TOKEN, SEP_TOKEN]

    # post-processors
    # bert processor: add SOS at the beginning and SEP at the end of sequence
    # bert_processor = BertProcessing(sep=(SEP_TOKEN, START_VOCAB.index(EOS_TOKEN)),
    #                                 cls=(SOS_TOKEN, START_VOCAB.index(SOS_TOKEN)))
    # sos processor: add SOS at the beginning of sequence
    sos_processor = TemplateProcessing(single=f'{SOS_TOKEN} $', pair=f'{SOS_TOKEN} $A $B',
                                       special_tokens=[(SOS_TOKEN, START_VOCAB.index(SOS_TOKEN))])
    # eos processor: add EOS at the end of sequence
    eos_processor = TemplateProcessing(single=f'$ {EOS_TOKEN}', pair=f'$A $B {EOS_TOKEN}',
                                       special_tokens=[(EOS_TOKEN, START_VOCAB.index(EOS_TOKEN))])
    # sep processor: add SEP at the end of sequence
    sep_processor = TemplateProcessing(single=f'$ {SEP_TOKEN}', pair=f'$A $B {SEP_TOKEN}',
                                       special_tokens=[(SEP_TOKEN, START_VOCAB.index(SEP_TOKEN))])

    def __init__(
            self,
            name,
            method,
            vocab_size=None,
            datasets: Union[List[str], List[List[str]]] = None,
            additional_special_symbols=None,
            ignore_case=False,
            save_root=None,
            index_offset=None,
    ):
        """
        Initialize a vocabulary and train the tokenizer.

        Args:
            name (str): Vocabulary name
            method (str): Tokenize method
            vocab_size (int): Maximum size of the vocabulary
            datasets (Union[List[str], List[List[str]]]): List of (file paths/list of string) to train the tokenizer
            additional_special_symbols (list[str]): Optional, list of custom special symbols
            ignore_case (bool): Ignore cases if True, default False
            save_root (str): Optional, if given, save to given root
            index_offset (int): Optional, the index offset when encoding and decoding.

        """
        assert method in ['word', 'bpe'], \
            'Tokenize method not supported, given {}, expect \'word\' or \'bpe\''.format(method)

        self.name = name
        self.method = method
        #It initializes __special_symbols with the default special symbols from Vocab.START_VOCAB
        self.__special_symbols = Vocab.START_VOCAB.copy()
        if additional_special_symbols:
            self.add_special_symbols(additional_special_symbols)
        self.ignore_case = ignore_case

        #It sets index_offset to the provided value or None.
        if index_offset is not None and index_offset != 0:
            self.index_offset = index_offset
        else:
            self.index_offset = None

        # tokenizer and trainer
        if method == 'word':
            tokenize_class = WordLevel
            trainer_class = WordLevelTrainer
        else:
            if vocab_size is None:
                logger.warning('It is recommended to specific the vocabulary size for BPE tokenize method')
            tokenize_class = BPE
            trainer_class = BpeTrainer

        #The tokenizer is initialized with an unknown token.
        self.tokenizer = Tokenizer(tokenize_class(unk_token=Vocab.UNK_TOKEN))
        #If vocab_size is provided, a trainer is created with special symbols and the specified vocabulary size; otherwise, only special symbols are used.
        if vocab_size:
            trainer = trainer_class(special_tokens=self.__special_symbols, vocab_size=vocab_size)
        else:
            trainer = trainer_class(special_tokens=self.__special_symbols)
        #It sets up the pre-tokenizer to use whitespace.
        self.tokenizer.pre_tokenizer = Whitespace()

        # normalizer
        #It sets up a normalizer to normalize the input text, optionally converting it to lowercase if ignore_case is True.
        if ignore_case:
            normalizer = normalizers.Sequence([NFD(), StripAccents(), Strip(), Lowercase()])
        else:
            normalizer = normalizers.Sequence([NFD(), StripAccents(), Strip()])
        self.tokenizer.normalizer = normalizer

    #If the dataset is a list of file paths, the tokenizer is trained using files with the specified trainer.
    #If the dataset is a list of lists of strings, the tokenizer is trained from the provided iterator.
    #If the dataset type is not supported, a TypeError is raised.
        # train tokenizer
        if isinstance(datasets[0], str):
            self.tokenizer.train(files=datasets, trainer=trainer)
        elif isinstance(datasets[0], list):
            self.tokenizer.train_from_iterator(iterator=datasets, trainer=trainer)
        else:
            raise TypeError('The type of datasets is not support, expect list of paths or list of lines')

        # pad idx
        #The pad_token_id is set to the index of the padding token.
        self.pad_token_id = self.get_pad_index()

        # save
        #If save_root is provided, the vocabulary is saved to the specified directory.
        if save_root:
            self.save(vocab_root=save_root)

#The add_special_symbols method allows adding custom special symbols to the vocabulary.
    def add_special_symbols(self, symbols: list):
        assert isinstance(symbols, list)
        for symbol in symbols:
            assert isinstance(symbol, str)
    #If the symbol is not already in the list of special symbols (__special_symbols), it appends the symbol to the list.
            if symbol not in self.__special_symbols:
                self.__special_symbols.append(symbol)

# returns the index of a given word in the vocabulary
    def get_index(self, token: str) -> int:
        """
        Return the index of given word, if the given word is not in the vocabulary, return the index of UNK token.

        Args:
            token (str): Word in str

        Returns:
            int: Index of the given word, [UNK] if OOV

        """
    #It first checks if the vocabulary is case-insensitive (ignore_case is True). If so, it converts the token to lowercase.
        if self.ignore_case:
            token = token.lower()
    #Then, it uses the tokenizer.token_to_id() method to retrieve the index of the token.
#If an index_offset is specified and the token is not in the list of special symbols (__special_symbols), it adds the offset to the index.
        index = self.tokenizer.token_to_id(token)
#If the token is not in the vocabulary, it returns the index of the UNK token.
        if self.index_offset and token not in self.__special_symbols:
            index += self.index_offset
#Otherwise, it returns the obtained index.
        return index if index else self.tokenizer.token_to_id(Vocab.UNK_TOKEN)

    def get_token(self, index: int) -> str:
        """
        Return the corresponding token of the given index, if not in the vocabulary, return index of UNK.

        Args:
            index: Given index

        Returns:
            str: Token of the given index

        """
        if self.index_offset:
            if index >= (len(self.__special_symbols) + self.index_offset):
                index -= self.index_offset
            elif len(self.__special_symbols) <= index < (len(self.__special_symbols) + self.index_offset):
                index = self.get_unk_index()
        token = self.tokenizer.id_to_token(index)
        return token if token else Vocab.UNK_TOKEN

    def transfer_index(self, index):
        """
        Return the transferred index based on the index offset

        Args:
            index (int): Index to transfer

        Returns:
            int: Transferred index

        """
        if not self.index_offset or index < len(self.__special_symbols):
            return index
        return index + self.index_offset

    def restore_index(self, index):
        """
        Return the restored index based on the base index

        Args:
            index (int): Index to restore

        Returns:
            int: Restored index

        """
        if not self.index_offset or index < len(self.__special_symbols):
            return index
        if index < self.index_offset:
            return self.get_unk_index()
        return index - self.index_offset
    # Methods like encode_sequence, encode_batch, decode, and decode_batch handle encoding and decoding of sequences using the tokenizer.
    def encode_sequence(self, sequence: Union[str, List[str]], is_pre_tokenized=False):
        """
        Encode a sequence to corresponding ids.

        Args:
            sequence (Union[str, List[str]]): Sequence to be encoded,
                when is_pre_tokenized is False, the type should be str,
                when is_pre_tokenized is True, the type should be List[str]
            is_pre_tokenized (bool): Whether the input is already pre-tokenized

        Returns:
            list[int], list[int]: indices and mask for sequence

        """
        if self.ignore_case:
            sequence = sequence.lower()
        encoded = self.tokenizer.encode(sequence=sequence, is_pretokenized=is_pre_tokenized)
        ids = [self.transfer_index(index) for index in encoded.ids]
        return ids, encoded.attention_mask

    def encode_batch(self, batch: Union[List[str], List[List[str]]], is_pre_tokenized=False,
                     pad=False, max_length=None):
        """
        Encode a batch of sequences.

        Args:
            batch (Union[List[str], List[List[str]]]): batch of sequences to be encoded,
                when is_pre_tokenized is False, the type should be List[str],
                when is_pre_tokenized is True, the type should be List[List[str]]
            is_pre_tokenized (bool): Whether the input is already pre-tokenized
            pad (bool): Whether to pad each of the sequence
            max_length (int): The length to padding

        Returns:
            (list[list[int]], list[list[int]]):
                - encoded batch of indices
                - encoded batch of attention masks

        """
        if self.ignore_case:
            batch = [sequence.lower() if isinstance(sequence, str) else [token.lower() for token in sequence]
                     for sequence in batch]
        if pad:
            self.tokenizer.enable_padding(length=max_length)
        else:
            self.tokenizer.no_padding()
        encoded_batch = self.tokenizer.encode_batch(input=batch, is_pretokenized=is_pre_tokenized)
        ids = [[self.transfer_index(index) for index in encoded.ids] for encoded in encoded_batch]
        attention_mask = [encoded.attention_mask for encoded in encoded_batch]
        return ids, attention_mask

    def decode(self, ids: List[int], skip_special_tokens=True) -> str:
        """
        Decode the given list of ids back to a string.

        Args:
            ids (list[int]): The list of ids that we want to decode
            skip_special_tokens (bool): Whether the special tokens should be removed from the decoded string,
                default True

        Returns:
            str: The decoded string

        """
        if self.index_offset:
            ids = [self.restore_index(index) for index in ids]
        return self.tokenizer.decode(ids=ids, skip_special_tokens=skip_special_tokens)

    def decode_batch(self, batch: List[List[int]], skip_special_tokens=True) -> List[str]:
        """
        Decode a batch of ids back to their corresponding string.

        Args:
            batch (list[list[str]]): The batch of sequences we want to decode
            skip_special_tokens (bool): Whether the special tokens should be removed from the decoded string,
                default True

        Returns:
            list[str]: The list of decoded string

        """
        if self.index_offset:
            batch = [[self.restore_index(index) for index in seq] for seq in batch]
        return self.tokenizer.decode_batch(sequences=batch, skip_special_tokens=skip_special_tokens)
    
#Methods like get_pad_index, get_sos_index, get_eos_index, get_unk_index, get_mask_index, num_special_token, __len__, and __contains__ 
#provide various utility functionalities such as getting special token indices, checking if a token is in the vocabulary, etc.
    def get_pad_index(self):
        return self.tokenizer.token_to_id(Vocab.PAD_TOKEN)

    def get_sos_index(self):
        return self.tokenizer.token_to_id(Vocab.SOS_TOKEN)

    def get_eos_index(self):
        return self.tokenizer.token_to_id(Vocab.EOS_TOKEN)

    def get_unk_index(self):
        return self.tokenizer.token_to_id(Vocab.UNK_TOKEN)

    def get_mask_index(self):
        return self.tokenizer.token_to_id(Vocab.MSK_TOKEN)

    def save(self, vocab_root, name=None):
        """
        Save the vocabulary to the given directory.

        Args:
            vocab_root (str): Parent directory to be saved
            name (str): Vocabulary name

        """
        vocab_name = name if name else self.name
        vocab_dir = os.path.join(vocab_root, vocab_name)
        if not os.path.exists(vocab_dir) or not os.path.isdir(vocab_dir):
            os.makedirs(vocab_dir)

        # save pickle file for whole instance
        with open(os.path.join(vocab_dir, '{}.pk'.format(vocab_name)), mode='wb') as f:
            pickle.dump(self, f)
        # save tokenizer
        self.tokenizer.save(os.path.join(vocab_dir, '{}_tokenizer.json'.format(vocab_name)))
        # save token to id mapping as a txt file
        with open(os.path.join(vocab_dir, '{}_mapping.txt'.format(vocab_name)), mode='w', encoding='utf-8') as f:
            for token, index in sorted(self.tokenizer.get_vocab().items(), key=lambda item: item[1]):
                f.write('{}\t{}\n'.format(token, index))

    def save_pretrained(self, output_dir):
        return

    def num_special_token(self):
        return len(self.__special_symbols)
    
#The save method saves the vocabulary to a directory, including the tokenizer and token-to-id mappings. 
    def save_pickle(self, path):
        """Save to binary pickle file"""
        with open(path, mode='wb') as f:
            pickle.dump(self, f)
        logger.info(f'Vocab saved to {path}')

    def __len__(self):
        return self.tokenizer.get_vocab_size()

    def __contains__(self, item):
        """
        Return True if the given token is in the vocab, else False.

        Args:
            item (str): Word to query

        Returns:
            bool: True if the given token is in the vocab, else False.

        """
        if self.ignore_case:
            item = item.lower()
        return True if self.tokenizer.token_to_id(item) else False

In [35]:
vocab = Vocab(
    name="example_vocab",
    method="word",
    vocab_size=100,
    datasets=["/home/user1-selab3/Documents/research-shradha/CODE-SPT-Code/dataset/pre_train/java/train/token.code"]
)

# Test add_special_symbols
additional_special_symbols = ['[CLS]', '[MASK]']
vocab.add_special_symbols(additional_special_symbols)
print("Special symbols after adding:", vocab._Vocab__special_symbols)

# Test get_index
token = 'token'
print("Index of '{}':".format(token), vocab.get_index(token))

# Test get_token
index = 5
print("Token at index {}: ".format(index), vocab.get_token(index))

# Test transfer_index
index_to_transfer = 3
print("Transferred index for {}: ".format(index_to_transfer), vocab.transfer_index(index_to_transfer))

# Test restore_index
index_to_restore = 8
print("Restored index for {}: ".format(index_to_restore), vocab.restore_index(index_to_restore))

# Test encode_sequence
sequence = "public static double Sinc(double x) {\r\n        return Math.sin(Math.PI * x) / (Math.PI * x);\r\n    }"
ids, attention_mask = vocab.encode_sequence(sequence)
print("Encoded sequence indices:", ids)
print("Attention mask:", attention_mask)

# Test encode_batch
batch = ["public static double Sinc(double x) {\r\n        return Math.sin(Math.PI * x) / (Math.PI * x);\r\n    }","public static double Y(int n, double x) {\r\n        double by, bym, byp, tox;\r\n\r\n        if (n == 0) return Y0(x);\r\n        if (n == 1) return Y(x);\r\n\r\n        tox = 2.0 / x;\r\n        by = Y(x);\r\n        bym = Y0(x);\r\n        for (int j = 1; j < n; j++) {\r\n            byp = j * tox * by - bym;\r\n            bym = by;\r\n            by = byp;\r\n        }\r\n        return by;\r\n    }"]
batch_ids, batch_attention_mask = vocab.encode_batch(batch)
print("Encoded batch indices:", batch_ids)
print("Batch attention masks:", batch_attention_mask)

# Test decode
decoded_sequence = vocab.decode(ids)
print("Decoded sequence:", decoded_sequence)

# Test decode_batch
decoded_batch = vocab.decode_batch(batch_ids)
print("Decoded batch:", decoded_batch)

# Test get_pad_index
print("PAD token index:", vocab.get_pad_index())

# Test get_sos_index
print("SOS token index:", vocab.get_sos_index())

# Test get_eos_index
print("EOS token index:", vocab.get_eos_index())

# Test get_unk_index
print("UNK token index:", vocab.get_unk_index())

# Test get_mask_index
print("MSK token index:", vocab.get_mask_index())

# Test num_special_token
print("Number of special tokens:", vocab.num_special_token())

# Test __len__
print("Vocabulary size:", len(vocab))

# Test __contains__
token_to_check = 'token'
print("'{}' is in vocabulary:".format(token_to_check), token_to_check in vocab)

Special symbols after adding: ['[PAD]', '[SOS]', '[EOS]', '[UNK]', '[MSK]', '[SEP]', '[CLS]', '[MASK]']
Index of 'token': 3
Token at index 5:  [SEP]
Transferred index for 3:  3
Restored index for 8:  8
Encoded sequence indices: [22, 37, 67, 3, 6, 67, 3, 7, 11, 16, 3, 8, 3, 6, 3, 8, 3, 82, 3, 7, 81, 6, 3, 8, 3, 82, 3, 3, 12]
Attention mask: [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
Encoded batch indices: [[22, 37, 67, 3, 6, 67, 3, 7, 11, 16, 3, 8, 3, 6, 3, 8, 3, 82, 3, 7, 81, 6, 3, 8, 3, 82, 3, 3, 12], [22, 37, 67, 3, 6, 26, 87, 10, 67, 3, 7, 11, 67, 3, 10, 3, 10, 3, 10, 3, 9, 15, 6, 87, 31, 27, 7, 16, 3, 6, 3, 3, 15, 6, 87, 31, 39, 7, 16, 3, 6, 3, 3, 3, 13, 97, 8, 27, 81, 3, 9, 3, 13, 3, 6, 3, 3, 3, 13, 3, 6, 3, 3, 35, 6, 26, 3, 13, 39, 9, 3, 17, 87, 9, 3, 3, 11, 3, 13, 3, 82, 3, 82, 3, 40, 3, 9, 3, 13, 3, 9, 3, 13, 3, 9, 12, 16, 3, 9, 12]]
Batch attention masks: [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1

This code will load the vocabulary instance.

In [36]:
def load_vocab(vocab_root, name) -> Vocab:
    """
    Load vocabulary instance from pickle file, which should locates in the sub-directory named given vocabulary name.

    Args:
        vocab_root (BytesPath): Root of the vocabulary
        name (str): Name of the vocabulary

    Returns:
        Vocab: Loaded vocab instance

    """
    vocab_full_path = os.path.abspath(os.path.join(vocab_root, name, '{}.pk'.format(name)))
    if not os.path.exists(vocab_full_path):
        logger.info('-' * 100)
        logger.info('Not Exist: %s', vocab_full_path)
        logger.info('-' * 100)
        sys.exit        

    with open(os.path.join(vocab_root, name, '{}.pk'.format(name)), mode='rb') as f:
        obj = pickle.load(f)
    assert isinstance(obj, Vocab)
    return obj


This code will initialize the vocabulary and save the path.

In [37]:
def init_vocab(vocab_save_dir,
               name,
               method='word',
               vocab_size=None,
               datasets: Union[List[str], List[List[str]]] = None,
               additional_special_symbols=None,
               ignore_case=False,
               save_root=None,
               index_offset=None,
               load_if_saved=True) -> Vocab:
    vocab_name = '.'.join(
        [sub_name for sub_name in [name, method, str(vocab_size), str(index_offset)] if sub_name is not None])
    path = os.path.join(vocab_save_dir, f'{vocab_name}.pk')
    if load_if_saved:
        if os.path.exists(path) and os.path.isfile(path):
            logger.info(f'Trying to load saved binary pickle file from: {path}')
            with open(path, mode='rb') as f:
                obj = pickle.load(f)
            assert isinstance(obj, Vocab)
            if save_root:
                obj.save(save_root)
            return obj
    vocab = Vocab(name=name,
                  method=method,
                  vocab_size=vocab_size,
                  datasets=datasets,
                  additional_special_symbols=additional_special_symbols,
                  ignore_case=ignore_case,
                  save_root=save_root,
                  index_offset=index_offset)
    vocab.save_pickle(path)
    return vocab

Datasets:

In [38]:
import sys
sys.path.append('/home/user1-selab3/Documents/research-shradha/CODE-SPT-Code/spt-code/sources')

In [39]:
from data.antlr_parsers.java.Java8Lexer import Java8Lexer

In [40]:
from tqdm import tqdm
from antlr4 import InputStream
import nltk

In [41]:
STRING_MATCHING_PATTERN = re.compile(r'([bruf]*)(\"\"\"|\'\'\'|\"|\')(?:(?!\2)(?:\\.|[^\\]))*\2')
NON_SPACE_MATCHING_PATTERN = re.compile(r'\S')

In [42]:
MAPPING_LANG_LEXER = {
    LANG_JAVA: Java8Lexer
}


In [43]:

main_args = None

def set_args(args):
    global main_args
    main_args = args

In [44]:
def load_eval_lines(path):
    """
    Load and eval lines from given path.

    Args:
        path (str): Dataset file path

    Returns:
        list: List of lines

    """
    with open(path, encoding='utf-8') as f:
        lines = [eval(line.strip()) for line in f]
    return lines


In [45]:

def load_eval_list_lines(path):
    """
    Load and eval lines from given path, each line is a list that will be convert into a string.

    Args:
        path (str): Dataset file path

    Returns:
        list: List of lines

    """
    lines = []
    with open(path, encoding='utf-8') as f:
        for line in f.readlines():
            tokens = eval(line.strip())
            string = ' '.join(tokens)
            string = re.sub(r'\s+', ' ', string)
            lines.append(string)
    return lines

In [46]:

def load_lines(path):
    """
    Load lines from given path.

    Args:
        path (str): Dataset file path

    Returns:
        list: List of lines

    """
    with open(path, encoding='utf-8') as f:
        lines = [line.strip() for line in f]
    return lines

In [47]:
def trim_method_name(full_name):
    """
    Extract method/function name from its full name,
    e.g., RpcResponseResolver.resolveResponseObject -> resolveResponseObject

    Args:
        full_name (str): Full name

    Returns:
        str: Method/Function name

    """
    point_pos = full_name.rfind('.')
    if point_pos != -1:
        return full_name[point_pos + 1:]
    else:
        return full_name

In [48]:

def replace_string_literal(source):
    """
    Replace the string literal in source code with ``<STR>``.

    Args:
        source (str): Source code in string

    Returns:
        str: Code after replaced

    """
    return re.sub(pattern=STRING_MATCHING_PATTERN, repl='___STR', string=source)

In [49]:

import tokenize
from io import StringIO

def remove_comments_and_docstrings(source, lang):
    """
    Remove docs and comments from source string.
    Thanks to authors of GraphCodeBERT
    from: https://github.com/microsoft/CodeBERT/blob/master/GraphCodeBERT/codesearch/parser/utils.py#L4

    Args:
        source (str): Source code string
        lang (str): Source code language

    Returns:
        str: Source string

    """
    if lang == LANG_PYTHON:
        try:
            io_obj = StringIO(source)
            out = ""
            prev_token_type = tokenize.INDENT
            last_lineno = -1
            last_col = 0
            for tok in tokenize.generate_tokens(io_obj.readline):
                token_type = tok[0]
                token_string = tok[1]
                start_line, start_col = tok[2]
                end_line, end_col = tok[3]
                # l_text = tok[4]
                if start_line > last_lineno:
                    last_col = 0
                if start_col > last_col:
                    out += (" " * (start_col - last_col))
                # Remove comments:
                if token_type == tokenize.COMMENT:
                    pass
                # This series of conditionals removes docstrings:
                elif token_type == tokenize.STRING:
                    if prev_token_type != tokenize.INDENT:
                        # This is likely a docstring; double-check we're not inside an operator:
                        if prev_token_type != tokenize.NEWLINE:
                            if start_col > 0:
                                out += token_string
                else:
                    out += token_string
                prev_token_type = token_type
                last_col = end_col
                last_lineno = end_line
            temp = []
            for x in out.split('\n'):
                if x.strip() != "":
                    temp.append(x)
            return '\n'.join(temp)
        except Exception:
            return source
    elif lang in [LANG_RUBY]:
        return source
    else:
        def replacer(match):
            s = match.group(0)
            if s.startswith('/'):
                return " "  # note: a space and not an empty string
            else:
                return s

        pattern = re.compile(
            r'//.*?$|/\*.*?\*/|\'(?:\\.|[^\\\'])*\'|"(?:\\.|[^\\"])*"',
            re.DOTALL | re.MULTILINE
        )
        temp = []
        for x in re.sub(pattern, replacer, source).split('\n'):
            if x.strip() != "":
                temp.append(x)
        return '\n'.join(temp)


In [50]:
import json
def parse_json_file(file, lang):
    """
    Parse a dataset file where each line is a json string representing a sample.

    Args:
        file (str): The file path
        lang (str): Source code language

    Returns:
        (list[str], list[str], list[str], list[str], List[str]):
            - List of source codes
            - List of tokenized codes
            - List of split method names
            - List of tokenized codes with method name replaced with ``f``
            - List of docstring strings, not every sample has it

    """
    sources = []
    codes = []
    names = []
    codes_wo_name = [] 
    docs = []

    # #######################################################################
    # Updated to reduce the time to parse, myoungkyu song, 03/23/2024
    if main_args.parse_subset_ratio:
        lines_to_extract = 0
        line_counter = 0
        total_lines = 0

        with open(file, encoding='utf-8') as f:
            total_lines = sum(1 for _ in f)
            lines_to_extract = int(total_lines * main_args.parse_subset_ratio)

        # if total_lines > 10_000:
        #     lines_to_extract = int(lines_to_extract * main_args.parse_subset_ratio)
        if total_lines > 100_000:
            lines_to_extract = int(lines_to_extract * main_args.parse_subset_ratio)

        logger.info('*' * 100)
        logger.info(f'{lang} => The size of trimmed / original pre_train set to parse: {lines_to_extract} / {total_lines}')
    # #######################################################################

    with open(file, encoding='utf-8') as f:
        for line in f.readlines():
            if main_args.parse_subset_ratio:
                if line_counter > lines_to_extract:
                    break
                line_counter += 1

            data = json.loads(line.strip())
            name = trim_method_name(data['func_name'])
            source = data['code'].strip()
            source = remove_comments_and_docstrings(source, lang)
            source = replace_string_literal(source)
            code = replace_string_literal(' '.join(data['code_tokens']))

            sources.append(source)
            codes.append(code)

            code_wo_name = code.replace(name, 'f', 1)
            codes_wo_name.append(code_wo_name)

            name = ' '.join(split_identifier(name))
            names.append(name)

            if 'docstring' in data:
                doc = clean_doc(data['docstring'])
                if doc:
                    docs.append(doc)

    return sources, codes, names, codes_wo_name, docs

In [51]:
def iter_all_files(base):
    """
    Iterator for all file paths in the given base path.

    Args:
        base (str): Path like string

    Returns:
        str: Path of each file
    """
    for root, ds, fs in os.walk(base):
        for f in fs:
            yield os.path.join(root, f)

In [52]:

def iter_pre_train_dataset_files(lang_dir, lang):
    """
    Get files for pre-training, all files with extension ``jsonl`` will be included.

    Args:
        lang_dir (str): Path of language dir
        lang (str): Source code language

    Returns:
        list[str]: List of paths of files

    """
    # if lang in [enums.LANG_PYTHON]:
    #     for file in iter_all_files(base=lang_dir):
    #         if file.endswith('.jsonl'):
    #             return [file]
    # if lang in [enums.LANG_PYTHON]:
    #     return [file for file in iter_all_files(base=lang_dir) if file.endswith('.jsonl')]
    if lang in [LANG_GO, LANG_JAVA, LANG_PYTHON, LANG_JAVASCRIPT, LANG_PHP,
                LANG_RUBY]:
        return [file for file in iter_all_files(base=lang_dir) if file.endswith('.jsonl')]
    return []

In [53]:
def load_pre_train_dataset(file, lang):
    """
    Load json dataset from given file.

    Args:
        file (str): Path of dataset file
        lang (str): Source code language

    Returns:
        (list[str], list[str], list[str], list[str], list[str]):
            - List of source code strings
            - List of tokenized code strings
            - List of nl strings
            - List of tokenized code strings with method names replaced
            - List of doc strings, not every sample has it

    """
    if lang in [LANG_JAVA, LANG_PYTHON, LANG_GO,
                LANG_JAVASCRIPT, LANG_PHP, LANG_RUBY]:
        sources, codes, names, codes_wo_name, docs = parse_json_file(file, lang=lang)
        return sources, codes, names, codes_wo_name, docs


In [54]:

def load_dataset_from_dir(dataset_dir):
    """
    Load all files in the given dir, only for pre-training.

    Args:
        dataset_dir (str): Root directory

    Returns:
        (dict, list[str], list[str], list[str], List[str], list[str], list[str], list[str], list[str], list[str]):
            - Dict of paths: key is the dataset group, value is the path
            - List of str: languages for each line
            - List of str: source code
            - List of str: tokenized code string
            - List of ast: linearized ast string
            - List of str: split method name string
            - List of str:
            - List of str:
            - List of str:
            - List of str: List of docs

    """
    paths = {}
    languages = []
    all_sources = []
    all_asts = []
    all_codes = []
    all_codes_wo_name = []
    all_names = []
    all_names_wo_name = []
    all_only_names = []
    all_docs = []

    if not os.path.exists(dataset_dir):
        logger.info('-' * 100)
        full_path_dataset_dir = os.path.abspath(dataset_dir)
        check_exist = os.path.exists(full_path_dataset_dir)
        logger.info('Directory Not Exist: %s', dataset_dir)
        logger.info('-' * 100)
        sys.exit()

    for file in os.listdir(dataset_dir):

        path = os.path.join(dataset_dir, file)
        if os.path.isfile(path):
            continue

        lang = file
        dataset_files = iter_pre_train_dataset_files(path, lang=lang)
        if len(dataset_files) > 0:
            logger.info(f'  Language: {lang}')
            paths[lang] = dataset_files
            n_sample = 0
            for dataset_file_path in dataset_files:
                sources, codes, names, codes_wo_name, docs = load_pre_train_dataset(file=dataset_file_path,
                                                                                    lang=lang)

                new_sources = []
                new_codes = []
                new_codes_wo_name = []
                new_names = []
                new_names_wo_name = []
                only_names = []
                asts = []
                for source, code, name, code_wo_name in tqdm(zip(sources, codes, names, codes_wo_name),
                                                             desc=f'Parsing {os.path.basename(dataset_file_path)}',
                                                             leave=False,
                                                             total=len(sources)):
                    try:
                        ast, nl, nl_wo_name = generate_single_ast_nl(source=source,
                                                                     lang=lang,
                                                                     name=name,
                                                                     replace_method_name=True)
                        new_sources.append(source)
                        new_codes.append(code)
                        new_codes_wo_name.append(code_wo_name)
                        new_names.append(nl)
                        new_names_wo_name.append(nl_wo_name)
                        asts.append(ast)
                        only_names.append(name)
                    except Exception:
                        continue

                all_sources += new_sources
                all_codes += new_codes
                all_codes_wo_name += new_codes_wo_name
                all_names += new_names
                all_names_wo_name += new_names_wo_name
                all_only_names += only_names
                all_asts += asts
                all_docs += docs

                n_line = len(new_sources)
                languages += [lang for _ in range(n_line)]
                n_sample += n_line

                logger.info(f'    File: {dataset_file_path}, {n_line} samples')

            logger.info(f'  {lang} dataset size: {n_sample}')

    assert len(languages) == len(all_sources) == len(all_codes) == len(all_codes_wo_name) == len(all_asts) == \
           len(all_names) == len(all_names_wo_name) == len(all_only_names)
    return paths, languages, all_sources, all_codes, all_asts, all_names, all_codes_wo_name, all_names_wo_name, \
           all_only_names, all_docs


In [55]:
# Load and print the lines from a file
lines = load_lines("/home/user1-selab3/Documents/research-shradha/CODE-SPT-Code/dataset/pre_train/java/train/sourcetest1.code")
print("Loaded lines:", lines)

# Trim the method name and print
full_name = "ClassLoader getBootstrapClassLoader"
trimmed_name = trim_method_name(full_name)
print("Trimmed method name:", trimmed_name)

# Replace string literals and print
source_code = "void pushBaseIndentifier(String baseID)\n  {\n\n    if (null != baseID)\n    {\n      int posOfHash = baseID.indexOf('#');\n\n      if (posOfHash > -1)\n      {\n        m_fragmentIDString = baseID.substring(posOfHash + 1);\n        m_shouldProcess = false;\n      }\n      else\n        m_shouldProcess = true;\n    }\n    else\n      m_shouldProcess = true;\n\n    m_baseIdentifiers.push(baseID);\n  }"
replaced_source = replace_string_literal(source_code)
print("Replaced string literals:", replaced_source)

# Remove comments and docstrings from source code and print
cleaned_source = remove_comments_and_docstrings(source_code, lang="java")
print("Cleaned source code:", cleaned_source)

# # Parse a JSON file and print
# parsed_data = parse_json_file("/home/user1-selab3/Documents/research-shradha/CODE-SPT-Code/dataset/pre_train/java/train.jsonl", lang="java")
# print("Parsed data:", parsed_data)


Loaded lines: ['public static double Sinc(double x) {\\r\\n        return Math.sin(Math.PI * x) / (Math.PI * x);\\r\\n    }', 'protected void modify(Transaction t) {\\n        try {\\n            this.lock.writeLock().lock();\\n            t.perform();\\n        } finally {\\n            this.lock.writeLock().unlock();\\n        }\\n    }', 'protected <E> E read(Supplier<E> sup) {\\n        try {\\n            this.lock.readLock().lock();\\n            return sup.get();\\n        } finally {\\n            this.lock.readLock().unlock();\\n        }\\n    }']
Trimmed method name: ClassLoader getBootstrapClassLoader
Replaced string literals: void pushBaseIndentifier(String baseID)
  {

    if (null != baseID)
    {
      int posOfHash = baseID.indexOf(___STR);

      if (posOfHash > -1)
      {
        m_fragmentIDString = baseID.substring(posOfHash + 1);
        m_shouldProcess = false;
      }
      else
        m_shouldProcess = true;
    }
    else
      m_shouldProcess = true;

    m

In [56]:

def trim_spaces(string):
    """
    Replace consecutive spaces with a single whitespace.

    Args:
        string (str): String

    Returns:
        str: Replaced string
    """
    return re.sub(r'\s+', ' ', string).strip()

In [57]:

def tokenize_python(source):
    """
    Python lib to tokenize python source code.

    Args:
        source (str): Source code string

    Returns:
        str: Tokenized code string

    """
    tokens = tokenize.generate_tokens(StringIO(source).readline)
    return ' '.join([token.string for token in tokens if token.string.strip() != ''])

In [58]:

def tokenize_source(source, lang, use_regular=False):
    """
    Tokenize the source code into tokens.

    Args:
        source (str): Source in string
        lang (str): Language of source code
        use_regular (bool): Whether to use regular tokenize method, default to False

    Returns:
        str: Tokenized code, delimited by whitespace, string literal will be replaced by ``___STR``

    """
    if use_regular:
        code = replace_string_literal(regular_tokenize(source))
        return trim_spaces(code)
    if lang == LANG_PYTHON:
        tokens = tokenize.generate_tokens(StringIO(source).readline)
        code = ' '.join([token.string for token in tokens])
        code = replace_string_literal(code)
        return trim_spaces(code)
    if lang in [LANG_JAVA, LANG_JAVASCRIPT, LANG_PHP, LANG_GO]:
        input_stream = InputStream(source)
        lexer = MAPPING_LANG_LEXER[lang](input_stream)
        tokens = [token.text for token in lexer.getAllTokens()]
        code = replace_string_literal(' '.join(tokens))
        return trim_spaces(code)
    elif lang == LANG_RUBY:
        tokens = MAPPING_LANG_LEXER[lang].get_pure_tokens(source)
        code = replace_string_literal(' '.join([token[0] for token in tokens]))
        return trim_spaces(code)
    else:
        # TODO: c# tokenize
        code = replace_string_literal(regular_tokenize(source))
        return trim_spaces(code)

In [59]:

def count_non_space_chars(s):
    """
    Count the non-space characters.

    Args:
        s (str): String to be counted

    Returns:
        int: Number of non-space characters

    """
    matches = re.findall(NON_SPACE_MATCHING_PATTERN, s)
    return len(matches)


In [60]:

def align_source_code(former_source, code):
    """
    Align former source to code token string and split code into former one and latter one.

    Args:
        former_source (str): Former part of the source
        code (str): Tokenized source code

    Returns:
        (str, str):
            - Former part of tokenized code
            - Latter part of tokenized code

    """
    former_count = count_non_space_chars(former_source)
    total = 0
    code_tokens = code.split(' ')
    token_index = 0
    while total < former_count:
        total += count_non_space_chars(code_tokens[token_index])
        token_index += 1
    former_code = ' '.join(code_tokens[:token_index])
    latter_code = ' '.join(code_tokens[token_index:])
    return former_code, latter_code


In [61]:

def regular_tokenize(source: str):
    """
    NLTK word tokenize with simple adoptions for source code.

    Args:
        source (str): Source code string.

    Returns:
        str: Tokenized code string
    """
    source = re.sub(r'(\S)[.=](\S)', r'\1 . \2', source)
    return ' '.join(nltk.word_tokenize(source))



In [62]:

def clean_doc(s):
    """
    Clean docstring.

    Args:
        s (str): Raw docstring

    Returns:
        str: Cleaned docstring

    """
    # // Create an instance of  {@link RepresentationBaseType } and {@link RepresentationBaseType }.
    # // Create an instance of RepresentationBaseType and RepresentationBaseType
    # // Public setter for the  {@code rowMapper}.
    # // Public setter for the rowMapper
    # comment = comment.replaceAll("\\{@link|code(.*?)}", "$1");
    # comment = comment.replaceAll("@see", "");

    s = re.sub(r'{@link|code(.*?)}', r'\1', s)
    s = re.sub(r'@see', '', s)

    # // Implementation of the <a href="http://www.tarsnap.com/scrypt/scrypt.pdf"/>scrypt KDF</a>.
    # // Implementation of the scrypt KDF
    # comment = comment.replaceAll("<a.*?>(.*?)a>", "$1");
    s = re.sub(r'<a.*?>(.*?)a>', r'\1', s)

    # // remove all tags like <p>, </b>
    # comment = comment.replaceAll("</?[A-Za-z0-9]+>", "");
    s = re.sub(r'</?[A-Za-z0-9]+>', '', s)

    # // Set the list of the watchable objects (meta data).
    # // Set the list of the watchable objects
    # comment = comment.replaceAll("\\(.*?\\)", "");
    s = re.sub(r'\(.*?\)', '', s)

    # // #dispatchMessage dispatchMessage
    # // dispatchMessage
    # comment = comment.replaceAll("#([\\w]+)\\s+\\1", "$1");
    s = re.sub(r'#([\w]+)\s+\1', r'\1', s)

    # // remove http url
    # comment = comment.replaceAll("http\\S*", "");
    s = re.sub(r'http\S*', '', s)

    # // characters except english and number are ignored.
    # comment = comment.replaceAll("[^a-zA-Z0-9_]", " ");
    s = re.sub(r'[^a-zA-Z0-9_]', ' ', s)

    # // delete empty symbols
    # comment = comment.replaceAll("[ \f\n\r\t]", " ").trim();
    # comment = comment.replaceAll(" +", " ");
    s = re.sub(r'[ \f\n\r\t]', ' ', s).strip()
    s = re.sub(r' +', ' ', s).strip()

    if len(s) == 0 or len(s.split()) < 3:
        return None
    else:
        return s

In [63]:

def convert_python_source_classical_summarization(source: str):
    source = re.sub(r' *DCNL *', '\n', source)
    source = re.sub(r' *DCSP *', '\t', source)
    return source

In [64]:
# Trim spaces in a string and print
trimmed_string = trim_spaces("public void put(String hostname, int netId, InetAddress[] addresses ) {\n        cache.put(new AddressCacheKey(hostname, netId), new AddressCacheEntry(addresses));\n    }")
print("Trimmed string:", trimmed_string)

# Tokenize Python source code and print
source_code = "public void put(String hostname, int netId, InetAddress[] addresses ) {\n        cache.put(new AddressCacheKey(hostname, netId), new AddressCacheEntry(addresses));\n    }"
tokenized_code = tokenize_python(source_code)
print("Tokenized Python code:", tokenized_code)

# Tokenize source code and print
source_code = "public void put(String hostname, int netId, InetAddress[] addresses ) {\n        cache.put(new AddressCacheKey(hostname, netId), new AddressCacheEntry(addresses));\n    }"
tokenized_code = tokenize_source(source_code, lang="python")
print("Tokenized code:", tokenized_code)

# Count non-space characters in a string and print
count = count_non_space_chars("public void put(String hostname, int netId, InetAddress[] addresses ) {\n        cache.put(new AddressCacheKey(hostname, netId), new AddressCacheEntry(addresses));\n    }")
print("Count of non-space characters:", count)

# Align source code and print
former_source = "public static double Sinc(double x) {\r\n        return Math.sin(Math.PI * x) / (Math.PI * x);\r\n    }"
code = "public static double Sinc ( double x ) { return Math . sin ( Math . PI * x ) / ( Math . PI * x ) ; }"
aligned_former_code, latter_code = align_source_code(former_source, code)
print("Aligned former code:", aligned_former_code)
print("Latter code:", latter_code)

# Tokenize source code using NLTK and print
source_code = "public void put(String hostname, int netId, InetAddress[] addresses ) {\n        cache.put(new AddressCacheKey(hostname, netId), new AddressCacheEntry(addresses));\n    }"
tokenized_code = regular_tokenize(source_code)
print("Regular tokenized code:", tokenized_code)

# Clean docstring and print
docstring = "Gets the proper modulus operation.\n\n@param x Integer.\n@param m Modulo.\n@return Modulus."
cleaned_docstring = clean_doc(docstring)
print("Cleaned docstring:", cleaned_docstring)

# Convert Python source code for classical summarization and print
source_code = "DCNLdef hello_world():DCSPDCNLprint('Hello, world!')DCNL"
converted_source = convert_python_source_classical_summarization(source_code)
print("Converted source code:", converted_source)



Trimmed string: public void put(String hostname, int netId, InetAddress[] addresses ) { cache.put(new AddressCacheKey(hostname, netId), new AddressCacheEntry(addresses)); }
Tokenized Python code: public void put ( String hostname , int netId , InetAddress [ ] addresses ) { cache . put ( new AddressCacheKey ( hostname , netId ) , new AddressCacheEntry ( addresses ) ) ; }
Tokenized code: public void put ( String hostname , int netId , InetAddress [ ] addresses ) { cache . put ( new AddressCacheKey ( hostname , netId ) , new AddressCacheEntry ( addresses ) ) ; }
Count of non-space characters: 141
Aligned former code: public static double Sinc ( double x ) { return Math . sin ( Math . PI * x ) / ( Math . PI * x ) ; }
Latter code: 
Regular tokenized code: public void put ( String hostname , int netId , InetAddress [ ] addresses ) { cache . put ( new AddressCacheKey ( hostname , netId ) , new AddressCacheEntry ( addresses ) ) ; }
Cleaned docstring: Gets the proper modulus operation param x I

In [65]:

def parse_for_summarization(source_path, code_path, nl_path, lang):
    """
    Load and parse dataset for code summarization.

    Args:
        source_path (str): Path of source code dataset
        code_path (str): Path of tokenized code dataset, if not file not exist, tokenize on the fly
        nl_path (str): Path of comment dataset
        lang (str): Source code language

    Returns:
        (Dict, list[str], list[str], list[str], list[str]):
            - Dict mapping dataset groups to paths
            - List of tokenized code strings
            - List of linearized AST strings
            - List of name and API sequence strings
            - List of comment strings

    """
    paths = {'source': source_path}
    logger.info(f'    Source code file: {source_path}')
    sources = load_lines(source_path)
    # if lang == enums.LANG_PYTHON:
    #     sources = [convert_python_source_classical_summarization(source) for source in sources]

    if not os.path.isfile(code_path):
        paths['code'] = source_path
        logger.info('    Tokenize source code')
        codes = [tokenize_source(source, lang=lang) for source in sources]
    else:
        paths['code'] = code_path
        logger.info(f'    Tokenized code file: {code_path}')
        codes = load_lines(code_path)
    paths['nl'] = nl_path
    logger.info(f'    Summarization file: {nl_path}')
    nls = load_lines(nl_path)
    # sources, codes, nls = sources[:1000], codes[:1000], nls[:1000]
    assert len(sources) == len(codes) == len(nls)

    new_codes = []
    new_nls = []
    names = []
    asts = []
    for source, code, nl in tqdm(zip(sources, codes, nls), desc='Parsing', leave=False, total=len(sources)):
        try:
            source = remove_comments_and_docstrings(source, lang=lang)
            ast, name = generate_single_ast_nl(source=source, lang=lang) #It generates the AST and method name for each source code using the generate_single_ast_nl function.
            new_codes.append(code)
            new_nls.append(nl)
            names.append(name)
            asts.append(ast)
        except Exception:
            continue

    return paths, new_codes, asts, names, new_nls

In [66]:
# Parse for summarization and print
paths, codes, asts, names, nls = parse_for_summarization("/home/user1-selab3/Documents/research-shradha/CODE-SPT-Code/dataset/pre_train/java/train/sourcetest1.code", "/home/user1-selab3/Documents/research-shradha/CODE-SPT-Code/dataset/pre_train/java/train/token1.code", "/home/user1-selab3/Documents/research-shradha/CODE-SPT-Code/dataset/pre_train/java/train/raw1.docstring", lang="java")
print("Parsed paths:", paths)
print("Parsed codes:", codes)
print("Parsed asts:", asts)
print("Parsed names:", names)
print("Parsed nls:", nls)

                                              

Parsed paths: {'source': '/home/user1-selab3/Documents/research-shradha/CODE-SPT-Code/dataset/pre_train/java/train/sourcetest1.code', 'code': '/home/user1-selab3/Documents/research-shradha/CODE-SPT-Code/dataset/pre_train/java/train/token1.code', 'nl': '/home/user1-selab3/Documents/research-shradha/CODE-SPT-Code/dataset/pre_train/java/train/raw1.docstring'}
Parsed codes: ['public static double Sinc ( double x ) { return Math . sin ( Math . PI * x ) / ( Math . PI * x ) ; }', 'protected void modify ( Transaction t ) { try { this . lock . writeLock ( ) . lock ( ) ; t . perform ( ) ; } finally { this . lock . writeLock ( ) . unlock ( ) ; } }', 'protected < E > E read ( Supplier < E > sup ) { try { this . lock . readLock ( ) . lock ( ) ; return sup . get ( ) ; } finally { this . lock . readLock ( ) . unlock ( ) ; } }']
Parsed asts: ['return_statement__ binary_expression__ binary_expression parenthesized_expression__ binary_expression __parenthesized_expression __binary_expression __return_st



In [67]:
def parse_for_translation(source_path, source_lang, target_path, target_lang):
    """
    Load and parse for code translation.

    Args:
        source_path (str): Path of source dataset
        source_lang (str): Source language
        target_path (str): Path of target dataset
        target_lang (str): Target language

    Returns:
        (list[str], list[str], list[str], list[str]):
            - List of tokenized code strings
            - List of linearized AST strings
            - List of name and API sequence strings
            - List of tokenized target code strings

    """
    logger.info(f'    Source file: {source_path}')
    sources = load_lines(source_path)
    print(sources)
    logger.info(f'    Target file: {target_path}')
    targets = load_lines(target_path)
    print(targets)

    new_sources = []
    new_targets = []
    asts = []
    names = []
    for source, target in tqdm(zip(sources, targets), desc='Parsing', leave=False, total=len(sources)):
        try:
            source = remove_comments_and_docstrings(source, lang=source_lang)
            source = replace_string_literal(source)
            target = remove_comments_and_docstrings(target, lang=target_lang)
            target = replace_string_literal(target)

            ast, name = generate_single_ast_nl(source=source, lang=source_lang)
            code = tokenize_source(source=source, lang=source_lang, use_regular=True)
            tokenized_target = tokenize_source(source=target, lang=target_lang, use_regular=True)

            new_sources.append(code)
            asts.append(ast)
            names.append(name)
            new_targets.append(tokenized_target)
        except Exception:
            continue
    print(asts)
    return new_sources, asts, names, new_targets


In [68]:
# Parse for translation and print
source_path = "/home/user1-selab3/Documents/research-shradha/CODE-SPT-Code/dataset/pre_train/java/train/sourcetest1.code"
source_lang = "java"
target_path = "/home/user1-selab3/Documents/research-shradha/CODE-SPT-Code/dataset/targetpath.code"
target_lang = "python"
sources, asts, names, targets = parse_for_translation(source_path, source_lang, target_path, target_lang)
print("Parsed translation sources:", sources)
print("Parsed translation ASTs:", asts)
print("Parsed translation names:", names)
print("Parsed translation targets:", targets)

['public static double Sinc(double x) {\\r\\n        return Math.sin(Math.PI * x) / (Math.PI * x);\\r\\n    }', 'protected void modify(Transaction t) {\\n        try {\\n            this.lock.writeLock().lock();\\n            t.perform();\\n        } finally {\\n            this.lock.writeLock().unlock();\\n        }\\n    }', 'protected <E> E read(Supplier<E> sup) {\\n        try {\\n            this.lock.readLock().lock();\\n            return sup.get();\\n        } finally {\\n            this.lock.readLock().unlock();\\n        }\\n    }']
['import math \\n \\nSinc = lambda x: 1.0 if x == 0 else math.sin(math.pi * x) / (math.pi * x)']


                                              

['return_statement__ binary_expression__ binary_expression parenthesized_expression__ binary_expression __parenthesized_expression __binary_expression __return_statement']
Parsed translation sources: ['public static double Sinc ( double x ) { \\r\\n return Math . sin ( Math . PI * x ) / ( Math . PI * x ) ; \\r\\n }']
Parsed translation ASTs: ['return_statement__ binary_expression__ binary_expression parenthesized_expression__ binary_expression __parenthesized_expression __binary_expression __return_statement']
Parsed translation names: ['sinc h s']
Parsed translation targets: ['import math \\n \\nSinc = lambda x : 1 . 0 if x == 0 else math . sin ( math . pi * x ) / ( math . pi * x )']




In [69]:
import json
def parse_for_search(dataset_dir, lang, split):
    """
    Load and parse for code search.

    Args:
        dataset_dir (str): Directory of the dataset
        lang (str): Source code language
        split (str): Split set of the dataset, support `train`, `valid`, `test`, `codebase`

    Returns:
        (list[str], list[str], list[str], list[str])
            - List of tokenized code strings
            - List of AST sequences
            - List of name strings
            - List of nl strings

    """
    urls = []
    codes = []
    asts = []
    names = []
    nls = []

    print(f"Parsing {split} dataset")
    path = os.path.join(dataset_dir, f'{split}.jsonl')

    # #######################################################################
    # Updated to reduce the time to parse, myoungkyu song, 03/23/2024
    if main_args.parse_subset_ratio:
        lines_to_extract = 0
        line_counter = 0
        total_lines = 0

        with open(path, encoding='utf-8') as f:
            total_lines = sum(1 for _ in f)
            lines_to_extract = int(total_lines * main_args.parse_subset_ratio)

        if total_lines > 10_000:
            lines_to_extract = int(lines_to_extract * main_args.parse_subset_ratio)
        if total_lines > 100_000:
            lines_to_extract = int(lines_to_extract * main_args.parse_subset_ratio)

        logger.info('*' * 100)
        logger.info(f'{lang} => The size of trimmed / original fine tunning {split} set to parse: {lines_to_extract} / {total_lines}')
    # #######################################################################

    with open(path, encoding='utf-8') as f:
        logger.info(f'  File: {path}')
        for line in tqdm(f.readlines()):
            if main_args.parse_subset_ratio: # myoungkyu song, 03/23/2024
                if line_counter > lines_to_extract:
                    break
                line_counter += 1

            data = json.loads(line.strip())
            if split in ['train', 'valid', 'test']:
                if 'docstring' not in data:
                    continue
                nl = clean_doc(data['docstring'])
                if nl is None:
                    continue
            try:
                if split in ['codebase', 'train']:
                    code = replace_string_literal(' '.join(data['code_tokens']))
                    name = trim_method_name(data['func_name'])

                    source = data['code'].strip()
                    source = remove_comments_and_docstrings(source, lang)
                    source = replace_string_literal(source)
                    ast, name = generate_single_ast_nl(source=source, lang=lang, name=name)

                    codes.append(code)
                    asts.append(ast)
                    names.append(name)

                if split in ['train', 'valid', 'test']:
                    nls.append(nl)

                if split != 'train':
                    url = data['url']
                    urls.append(url)
            except Exception:
                continue

    if split == 'codebase':
        return urls, codes, asts, names
    elif split == 'train':
        return codes, asts, names, nls
    elif split in ['valid', 'test']:
        return urls, nls

In [70]:

def load_clone_mapping(dataset_root):
    """
    Load json file and transfer to a mapping from code id to source code.

    Args:
        dataset_root (str): Root of the dataset

    Returns:
        dict: Mapping from code id to source code

    """
    path = os.path.join(dataset_root, 'fine_tune', TASK_CLONE_DETECTION, 'data.jsonl')
    if not os.path.exists(path):
        return None
    mapping = dict()
    with open(path, encoding='utf-8') as f:
        for line in f.readlines():
            data = json.loads(line.strip())
            code_id = data['idx']
            source = data['func'].strip()
            mapping[code_id] = source
    return mapping


In [71]:

def parse_for_clone(path, mapping):
    """
    Load and parse for code clone detection.

    Args:
        path (str): Dataset path
        mapping (dict[int, str]): Mapping from code id to source code

    Returns:
        list[str], list[str], list[str], list[str], list[str], list[str], list[int]:
            - List of source code 1 strings
            - List of ast 1 strings
            - List of name 1 strings
            - List of source code 2 strings
            - List of ast 2 strings
            - List of name 2 strings
            - List of label integers

    """
    codes_1 = []
    asts_1 = []
    names_1 = []
    codes_2 = []
    asts_2 = []
    names_2 = []
    labels = []
    with open(path, encoding='utf-8') as f:
        for line in tqdm(f.readlines()):
            id_1, id_2, label = line.split('\t')
            try:
                source_1 = mapping[id_1]
                source_1 = remove_comments_and_docstrings(source_1, lang=LANG_JAVA)
                source_1 = replace_string_literal(source_1)
                ast_1, name_1 = generate_single_ast_nl(source=source_1, lang=LANG_JAVA)
                code_1 = tokenize_source(source=source_1, lang=LANG_JAVA)

                source_2 = mapping[id_2]
                source_2 = remove_comments_and_docstrings(source_2, lang=LANG_JAVA)
                source_2 = replace_string_literal(source_2)
                ast_2, name_2 = generate_single_ast_nl(source=source_2, lang=LANG_JAVA)
                code_2 = tokenize_source(source=source_2, lang=LANG_JAVA)

                label = int(label)

                codes_1.append(code_1)
                asts_1.append(ast_1)
                names_1.append(name_1)
                codes_2.append(code_2)
                asts_2.append(ast_2)
                names_2.append(name_2)
                labels.append(label)
            except Exception:
                continue
    return codes_1, asts_1, names_1, codes_2, asts_2, names_2, labels


In [72]:

# def parse_for_completion(source_path, target_path):
#     """
#     Load and parse for code completion.

#     Args:
#         source_path (str): Path of source
#         target_path (str): Path of target

#     Returns:
#         (list[str], list[str], list[str], list[str]):
#             - List of strings: source code
#             - List of strings: AST sequence
#             - List of strings: name sequence
#             - List of strings: target code

#     """
#     def restore_source(sub_source):
#         """
#         Transfer split source to source code, which can be parsed into AST.

#         Args:
#             sub_source (str): Split code

#         Returns:
#             str: Source code that can be parsed

#         """
#         tokens = sub_source.split()
#         is_subtoken = False
#         restored_source = ''
#         for token in tokens:
#             if token == '_':
#                 is_subtoken = True
#                 continue
#             if token == 'PRED':
#                 token = Vocab.MSK_TOKEN
#             if is_subtoken:
#                 restored_source += token.capitalize()
#             else:
#                 restored_source += f' {token}'
#             is_subtoken = False
#         return restored_source.strip()

#     source_lines = load_lines(source_path)
#     target_lines = load_lines(target_path)
#     assert len(source_lines) == len(target_lines)

#     # #######################################################################
#     # Updated to reduce the time to parse, myoungkyu song, 03/31/2024
#     if main_args.parse_subset_ratio:
#         print(main_args.parse_subset_ratio)
#         line_counter = 0
#         lines_to_extract = int(len(source_lines) * main_args.parse_subset_ratio)

#         if len(source_lines) > 10_000:
#             lines_to_extract = int(lines_to_extract * main_args.parse_subset_ratio)
#         if len(source_lines) > 100_000:
#             lines_to_extract = int(lines_to_extract * main_args.parse_subset_ratio)

#         logger.info('*' * 100)
#         logger.info(f'The size of trimmed / original fine tunning completion set to parse: {lines_to_extract} / {len(source_lines)}')
#     # #######################################################################

#     codes = []
#     asts = []
#     names = []
#     targets = []
#     for source, target in tqdm(zip(source_lines, target_lines), desc='Parsing', total=len(source_lines)):
#         try:
#             if main_args.parse_subset_ratio: # myoungkyu song, 03/31/2024
#                 if line_counter > lines_to_extract:
#                     break
#                 line_counter += 1

#             source = restore_source(source)
#             target = restore_source(target)
#             ast, name = generate_single_ast_nl(source=source, lang=LANG_JAVA)
#             codes.append(source)
#             asts.append(ast)
#             names.append(name)
#             targets.append(target)
#         except Exception:
#             continue
#     return codes, asts, names, targets

In [73]:
def parse_for_completion(source_code, target_code):
    """
    Parse source and target code for code completion.

    Args:
        source_code (str): Source code
        target_code (str): Target code

    Returns:
        (list[str], list[str], list[str], list[str]):
            - List of strings: source code
            - List of strings: AST sequence
            - List of strings: name sequence
            - List of strings: target code

    """
    def restore_source(sub_source):
        tokens = sub_source.split()
        is_subtoken = False
        restored_source = ''
        for token in tokens:
            if token == '_':
                is_subtoken = True
                continue
            if token == 'PRED':
                token = Vocab.MSK_TOKEN
            if is_subtoken:
                restored_source += token.capitalize()
            else:
                restored_source += f' {token}'
            is_subtoken = False
        return restored_source.strip()

    try:
        ast, name = generate_single_ast_nl(source=source_code, lang=LANG_JAVA)
        return [source_code], [ast], [name], [target_code]
    except Exception as e:
        # If there's any exception, return empty lists
        return [], [], [], []


In [74]:
# Parse for completion and print
source_path = "public static double Sinc(double x) {\\r\\n        return Math.sin(Math.PI * x) /"
target_path = "public static double Sinc(double x) {\\r\\n        return Math.sin(Math.PI * x) / (Math.PI * x);\\r\\n    }"
codes, asts, names, targets = parse_for_completion(source_path, target_path)
print("Parsed completion codes:", codes)
print("Parsed completion ASTs:", asts)
print("Parsed completion names:", names)
print("Parsed completion targets:", targets)

Parsed completion codes: ['public static double Sinc(double x) {\\r\\n        return Math.sin(Math.PI * x) /']
Parsed completion ASTs: ['binary_expression']
Parsed completion names: ['sinc h s']
Parsed completion targets: ['public static double Sinc(double x) {\\r\\n        return Math.sin(Math.PI * x) / (Math.PI * x);\\r\\n    }']


In [75]:

def parse_for_bug_fix(buggy_path, fixed_path):
    """
    Load and parse for bug fix.

    Args:
        buggy_path (str): Path of buggy code
        fixed_path (str): Path of fixed code

    Returns:
        (list[str], list[str], list[str], list[str]):
            - List of strings: source code
            - List of strings: AST sequence
            - List of strings: name sequence
            - List of strings: target code

    """
    buggy_lines = load_lines(buggy_path)
    fixed_lines = load_lines(fixed_path)
    assert len(buggy_lines) == len(fixed_lines)
    codes = []
    asts = []
    names = []
    targets = []
    for buggy, fixed in tqdm(zip(buggy_lines, fixed_lines), desc='Parsing', total=len(buggy_lines)):
        try:
            ast, name = generate_single_ast_nl(source=buggy, lang=LANG_JAVA)
            codes.append(buggy.lower())
            asts.append(ast)
            names.append(name.lower())
            targets.append(fixed.lower())
        except Exception:
            continue
    return codes, asts, names, targets


In [76]:
# # Parse for search and print
# dataset_dir = "/home/user1-selab3/Documents/research-shradha/CODE-SPT-Code/dataset/pre_train/java"
# lang = "java"
# split = "train"
# codes, asts, names, nls = parse_for_search(dataset_dir, lang, split)
# print("Parsed search codes:", codes)
# print("Parsed search ASTs:", asts)
# print("Parsed search names:", names)
# print("Parsed search nls:", nls)

In [81]:
# # Load clone mapping and print
# dataset_root = "/home/user1-selab3/Documents/research-shradha/CODE-SPT-Code/dataset"
# mapping = load_clone_mapping(dataset_root)
# print("Loaded clone mapping:", mapping)

Loaded clone mapping: None


In [None]:
# # Parse for clone and print
# path = "clone_dataset_path"
# codes_1, asts_1, names_1, codes_2, asts_2, names_2, labels = parse_for_clone(path, mapping)
# print("Parsed clone codes_1:", codes_1)
# print("Parsed clone ASTs_1:", asts_1)
# print("Parsed clone names_1:", names_1)
# print("Parsed clone codes_2:", codes_2)
# print("Parsed clone ASTs_2:", asts_2)
# print("Parsed clone names_2:", names_2)
# print("Parsed clone labels:", labels)

In [None]:

# # Parse for bug fix and print
# buggy_path = "buggy_path"
# fixed_path = "fixed_path"
# codes, asts, names, targets = parse_for_bug_fix(buggy_path, fixed_path)
# print("Parsed bug fix codes:", codes)
# print("Parsed bug fix ASTs:", asts)
# print("Parsed bug fix names:", names)
# print("Parsed bug fix targets:", targets)

In [71]:
import torch.utils.data
from torch.utils.data.dataset import Dataset

import os
import random
import logging
import pickle
import shutil

class CodeDataset(Dataset):

    def __init__(self, args, dataset_name, mode, task=None, language=None, split=None, clone_mapping=None):
        """
        Initialization definition.

        Args:
            args (argparse.Namespace): Arguments
            dataset_name (str): Name of the dataset
            mode (str): Training mode, ``pre_train`` or ``fine_tune``
            task (str): Dataset mode, support pre-training tasks: ['cap', 'mass', 'mnp'],
                and downstream fine-tuning task: ['summarization', 'translation'],
                future support ['completion', 'search', 'clone', 'summarization', 'translation']
            language (str): Only for downstream fine-tuning
            split (str): Only for downstream fine-tuning, support ['train', 'valid', 'test', 'codebase']
            clone_mapping (dict[int, str]): Mapping from code id to source code string, use only for clone detection

        """
        super(CodeDataset, self).__init__()
        self.args = args
        self.dataset_name = dataset_name
        self.task = task
        self.mode = mode
        self.split = split
        self.paths = {}

        # dataset dir for files, all files in this dir meeting the filename will be used as dataset files
        self.dataset_dir = os.path.join(args.dataset_root, self.mode)

        # load pre-training dataset
        if self.mode == 'pre_train':
            set_args(args=args)
            self.paths, self.languages, self.sources, self.codes, self.asts, self.names, self.codes_wo_name, \
                self.names_wo_name, self.only_names, self.docs = load_dataset_from_dir(dataset_dir=self.dataset_dir)
            self.size = len(self.codes)
        # load fine-tuning dataset
        else:
            assert split
            logger.info(f'  Loading {split} set')
            self.dataset_dir = os.path.join(self.dataset_dir, task)
            # code summarization
            if task == TASK_SUMMARIZATION:
                assert language, '\'Language\' must be specific if downstream task is code summarization'
                assert split in ['train', 'valid', 'test']
                self.dataset_dir = os.path.join(self.dataset_dir, language, split)

                self.source_path = os.path.join(self.dataset_dir, 'source.code')
                self.code_path = os.path.join(self.dataset_dir, 'token.code')
                self.nl_path = os.path.join(self.dataset_dir, 'token.docstring')

                self.paths, self.codes, self.asts, self.names, self.nls = parse_for_summarization(
                    source_path=self.source_path,
                    code_path=self.code_path,
                    nl_path=self.nl_path,
                    lang=language)
                assert len(self.codes) == len(self.asts) == len(self.names) == len(self.nls)
                self.size = len(self.codes)
            # code translation
            elif task == TASK_TRANSLATION:
                assert split in ['train', 'valid', 'test']
                assert language in ['java-c_sharp', 'c_sharp-java']
                source_lang, target_lang = language.split('-')
                java_path = f'{split}.java-cs.txt.java'
                c_sharp_path = f'{split}.java-cs.txt.cs'
                source_path = os.path.join(self.dataset_dir,
                                           c_sharp_path if source_lang == 'c_sharp' else java_path)
                target_path = os.path.join(self.dataset_dir,
                                           c_sharp_path if target_lang == 'c_sharp' else java_path)
                self.paths['source'] = source_path
                self.paths['target'] = target_path
                self.codes, self.asts, self.names, self.targets = parse_for_translation(
                    source_path=source_path,
                    source_lang=args.translation_source_language,
                    target_path=target_path,
                    target_lang=args.translation_target_language)

                assert len(self.codes) == len(self.asts) == len(self.names) == len(self.targets)
                self.size = len(self.codes)
            # code search
            elif task == TASK_SEARCH:
                assert language, '``Language`` must be specific if downstream task is code search'
                assert split in ['codebase', 'train', 'valid', 'test']
                self.dataset_dir = os.path.join(self.dataset_dir, language)
                self.paths['file'] = os.path.join(self.dataset_dir, f'{split}.jsonl')

                if split == 'codebase':
                    set_args(args=args) # Added to pass args, myoungkyu song, 03/24/2024
                    self.urls, self.codes, self.asts, self.names = parse_for_search(dataset_dir=self.dataset_dir,
                                                                                    lang=language,
                                                                                    split=split)
                    assert len(self.urls), len(self.codes) == len(self.asts) == len(self.names)
                    self.size = len(self.urls)
                elif split == 'train':
                    self.codes, self.asts, self.names, self.nls = parse_for_search(dataset_dir=self.dataset_dir,
                                                                                   lang=language,
                                                                                   split=split)
                    assert len(self.codes) == len(self.asts) == len(self.names) == len(self.nls)
                    self.size = len(self.codes)
                else:
                    self.urls, self.nls = parse_for_search(dataset_dir=self.dataset_dir, lang=language, split=split)
                    self.size = len(self.urls)
            # code clone detection
            elif task == TASK_CLONE_DETECTION:
                assert split in ['train', 'valid', 'test']
                assert clone_mapping
                path = os.path.join(self.dataset_dir, f'{split}.txt')
                self.paths['file'] = path
                self.codes_1, self.asts_1, self.names_1, \
                    self.codes_2, self.asts_2, self.names_2, self.labels = parse_for_clone(path=path,
                                                                                           mapping=clone_mapping)
                assert len(self.codes_1) == len(self.asts_1) == len(self.names_1) \
                       == len(self.codes_2) == len(self.asts_2) == len(self.names_2) == len(self.labels)
                self.size = len(self.codes_1)
            # completion
            elif task == TASK_COMPLETION:
                assert split in ['train', 'valid', 'test']
                source_path = os.path.join(self.dataset_dir, f'data.TargetType.seq.{split}.source.txt')
                target_path = os.path.join(self.dataset_dir, f'data.TargetType.seq.{split}.target.txt')
                self.paths['source'] = source_path
                self.paths['target'] = target_path
                set_args(args=args) # Added to pass args, myoungkyu song, 03/31/2024
                self.codes, self.asts, self.names, self.targets = parse_for_completion(source_path=source_path,
                                                                                       target_path=target_path)
                assert len(self.codes) == len(self.asts) == len(self.names) == len(self.targets)
                self.size = len(self.codes)
            # bug fix
            elif task == TASK_BUG_FIX:
                assert split in ['train', 'valid', 'test']
                # language here stands for dataset scale
                assert language in ['small', 'medium']
                self.dataset_dir = os.path.join(self.dataset_dir, language)
                buggy_path = os.path.join(self.dataset_dir, f'{split}.buggy-fixed.buggy')
                fixed_path = os.path.join(self.dataset_dir, f'{split}.buggy-fixed.fixed')
                self.paths['buggy'] = buggy_path
                self.paths['fixed'] = fixed_path
                self.codes, self.asts, self.names, self.targets = parse_for_bug_fix(buggy_path=buggy_path,
                                                                                    fixed_path=fixed_path)
                assert len(self.codes) == len(self.asts) == len(self.names) == len(self.targets)
                self.size = len(self.codes)
    
    def __getitem__(self, index):
        # cap
        
#Randomly decides whether to return an original AST or a different one along with the corresponding code and name.
#If an original AST is chosen, it returns the code, the original AST, the name, and a label indicating 1 (original).
#If a different AST is chosen, it returns the code, a randomly selected AST (excluding the original one), the name, and a label indicating 0 (different).
        if self.task == TASK_CODE_AST_PREDICTION:
            # print(f'[DBG] index: {index}')
            # org_ast = self.asts[index]
            # if index == 184215:
            #     self.keep_ast_index = 184215
            # elif index == 59442:
            #     print(f'[DBG] returned index: {self.keep_ast_index}')
            #     return self.codes[self.keep_ast_index], self.asts[self.keep_ast_index], self.names[self.keep_ast_index], 1

            is_ast = random.random() < 0.5
            if is_ast:
                return self.codes[index], self.asts[index], self.names[index], 1
            else:
                other_ast = self.asts[random.randint(0, self.size - 1)]
                while other_ast == self.asts[index]:
                    other_ast = self.asts[random.randint(0, self.size - 1)]
                return self.codes[index], other_ast, self.names[index], 0
        # mass
#Randomly selects a portion of the code tokens to mask.
#Constructs input tokens with a masked section.
#Returns the modified code, the AST, the name, and the masked tokens.
        elif self.task == TASK_MASS:
            # print(f'[DBG] index: {index}')

            code_tokens = self.codes[index].split()
            mask_len = int(self.args.mass_mask_ratio * len(code_tokens))
            mask_start = random.randint(0, len(code_tokens) - mask_len)
            mask_tokens = code_tokens[mask_start: mask_start + mask_len]
            input_tokens = code_tokens[:mask_start] + [Vocab.MSK_TOKEN] + code_tokens[mask_start + mask_len:]
            # print(f'[DBG] code {code_tokens}')
            # print(f'[DBG] input {input_tokens}')
            # print(f'[DBG] mask {mask_tokens}')
            return ' '.join(input_tokens), self.asts[index], self.names[index], ' '.join(mask_tokens)
        # mnp
#Returns the code without method names, the AST, the name without method names, and the original name
        elif self.task == TASK_METHOD_NAME_PREDICTION:
            return self.codes_wo_name[index], self.asts[index], self.names_wo_name[index], self.names[index]
        # summarization
        elif self.task == TASK_SUMMARIZATION:
            return self.codes[index], self.asts[index], self.names[index], self.nls[index]
            # return self.codes[index], None, None, self.nls[index]
        # translation
        elif self.task == TASK_TRANSLATION:
            return self.codes[index], self.asts[index], self.names[index], self.targets[index]
        # search
        elif self.task == TASK_SEARCH:
            if self.split == 'codebase':
                return self.split, self.urls[index], self.codes[index], self.asts[index], self.names[index]
            elif self.split == 'train':
                pos_nl = self.nls[index]
                # while True:
                #     neg_index = random.randint(0, self.size - 1)
                #     neg_nl = self.nls[neg_index]
                #     if avg_bleu(references=[pos_nl.split()], candidates=[neg_nl.split()]) < 0.5:
                #         break
                # return self.split, self.codes[index], self.asts[index], self.names[index], pos_nl, neg_nl
                return self.split, self.codes[index], self.asts[index], self.names[index], pos_nl
            else:
                return self.split, self.urls[index], self.nls[index]
        # clone detection
        elif self.task == TASK_CLONE_DETECTION:
            return self.codes_1[index], self.asts_1[index], self.names_1[index], \
                   self.codes_2[index], self.asts_2[index], self.names_2[index], self.labels[index]
        # code completion
        elif self.task == TASK_COMPLETION:
            return self.codes[index], self.asts[index], self.names[index], self.targets[index]
        # bug fix
        elif self.task == TASK_BUG_FIX:
            return self.codes[index], self.asts[index], self.names[index], self.targets[index]
#For tasks like code summarization, translation, search, etc., it returns source code, AST, name/API sequence, and target code.
#For pre-training tasks like CAP (Code-AST Prediction), MASS (Masked Sequence Prediction), and MNP (Method Name Prediction), it returns different combinations of source code, AST, name/API sequence, and masked tokens.

    def __len__(self):
        return self.size

    def set_task(self, task):
        self.task = task

    def save(self):
        """Save to binary pickle file"""
        path = os.path.join(self.args.dataset_save_dir, f'{self.dataset_name}.pk')
        with open(path, mode='wb') as f:
            pickle.dump(self, f)
        logger.info(f'Dataset saved to {path}')

    def subset(self, ratio):
        """
        Return a subset of self.

        Args:
            ratio (float): The ratio of size, must greater than 0 and less than/equal to 1

        Returns:
            Dataset: the subset

        """
        assert 0 < ratio <= 1, f'The subset ratio supposed to be 0 < ratio <= 1, but got ratio={ratio}'
        if ratio == 1:
            return self
        indices = random.sample(range(self.size), int(self.size * ratio))
        return torch.utils.data.Subset(self, indices)

In [72]:
#If a saved instance exists and load_if_saved is True, it loads and returns the saved dataset.
#If no saved instance exists or load_if_saved is False, it initializes a new dataset using CodeDataset class and saves it.
def init_dataset(args, mode, task=None, language=None, split=None, clone_mapping=None,
                 load_if_saved=True) -> CodeDataset:
    """
    Find dataset, if the dataset is saved, load and return, else initialize and return.

    Args:
        args (argparse.Namespace): Arguments
        mode (str): Training mode, ``pre_train`` or ``fine_tune``
        task (str): Dataset mode, support pre-training tasks: ['cap', 'mass', 'mnp'],
            and downstream fine-tuning task: ['summarization', 'translation'],
            future support ['completion', 'search', 'clone', 'summarization', 'translation']
        language (str): Only for downstream fine-tuning
        split (str): Only for downstream fine-tuning, support ['train', 'valid', 'test', 'codebase(only for search)']
        clone_mapping (dict[int, str]): Mapping from code id to source code string, use only for clone detection
        load_if_saved (bool): Whether to load the saved instance if it exists, default to True

    Returns:
        CodeDataset: Loaded or initialized dataset

    """
    name = '.'.join([sub_name for sub_name in [mode, task, language, split] if sub_name is not None])
    if load_if_saved:
        path = os.path.join(args.dataset_save_dir, f'{name}.pk') # '../../dataset/dataset_saved/pre_train.pk'
        path_org = os.path.join(args.dataset_save_dir, f'{name}_org.pk') # '../../dataset/dataset_saved/pre_train_org.pk'

        # #######################################################################
        # Updated it with an argument `remove_existing_saved_file`, myoungkyu song, 03/23/2024
        if os.path.exists(path) and \
                (args.remove_existing_saved_file is not None and 'fine_tune' in args.remove_existing_saved_file) and \
                ('fine_tune' in path):
            logger.info(f'Removing the existing file: {path}')
            os.remove(path)
        if os.path.exists(path) and \
                (args.remove_existing_saved_file is not None and 'pre_train' in args.remove_existing_saved_file) and \
                ('pre_train' in path):
            logger.info(f'Removing the existing file: {path}')
            os.remove(path)
        if os.path.exists(path_org) and (args.copy_existing_saved_file is not None and 'pre_train_org' in args.copy_existing_saved_file) and ('pre_train' in path):
            logger.info(f'Copying the existing file: {path_org}')
            shutil.copy(path_org, path)
            
        # #######################################################################

        if os.path.exists(path) and os.path.isfile(path):
            logger.info(f'Trying to load saved binary pickle file from: {path}')
            with open(path, mode='rb') as f:
                obj = pickle.load(f)
            assert isinstance(obj, CodeDataset)
            obj.args = args
            logger.info(f'Dataset instance loaded from: {path}')
            print_paths(obj.paths)
            return obj
    dataset = CodeDataset(args=args,
                          dataset_name=name,
                          mode=mode,
                          task=task,
                          language=language,
                          split=split,
                          clone_mapping=clone_mapping)
    dataset.save()
    return dataset

In [73]:

def print_paths(paths):
    """
    Print paths.

    Args:
        paths (dict): Dict mapping path group to path string or list of path strings.

    """
    logger.info('Dataset loaded from these files:')
    for key, value in paths.items():
        if isinstance(value, list):
            for v in value:
                logger.info(f'  {key}: {v}')
        else:
            logger.info(f'  {key}: {value}')

In [1]:

def save_all_datasets(args):
    # logger.info('*' * 100)
    # logger.info('Pre-training dataset')
    # _ = init_dataset(args=args,
    #                  mode=enums.TRAINING_MODE_PRE_TRAIN,
    #                  load_if_saved=False)
    # summarization
    for lang in [LANG_JAVA, LANG_GO, LANG_PHP, LANG_PYTHON, LANG_RUBY,
                 LANG_JAVASCRIPT]:
        for split in ['train', 'valid', 'test']:
            logger.info('*' * 100)
            logger.info(f'Summarization - {lang} - {split}')
            _ = init_dataset(args=args,
                             mode= TRAINING_MODE_FINE_TUNE,
                             task= TASK_SUMMARIZATION,
                             language=lang,
                             split=split,
                             load_if_saved=False)

In [2]:
import sys
sys.path.append('/home/user1-selab3/Documents/research-shradha/CODE-SPT-Code/spt-code/sources')

In [3]:
import torch
from torch.nn import CrossEntropyLoss, MSELoss
import torch.nn.functional as f

from transformers import BartForConditionalGeneration, BartConfig
from transformers.models.bart.modeling_bart import BartClassificationHead, shift_tokens_right
from transformers.modeling_outputs import Seq2SeqLMOutput, Seq2SeqSequenceClassifierOutput

Initialized Configuration:


RuntimeError: Failed to import transformers.models.bart.modeling_bart because of the following error (look up to see its traceback):
Object of type type is not JSON serializable