In [1]:
import tree_sitter_java as tsjava
from tree_sitter import Language, Parser

In [2]:
# Build a tree-sitter parser for Java; `parser` is used by the parsing cells below.
JAVA_LANGUAGE = Language(tsjava.language())
parser = Parser(JAVA_LANGUAGE)

In [3]:
# Minimal Java snippet used to smoke-test the parser. The exact text matters:
# the literal starts and ends with a newline (len(java_code) is 89), while the
# class declaration extracted from it is 87 characters long.
java_code = """
public class Example {
    public int add(int a, int b) {
        return a + b;
    }
}
"""

In [7]:
def parse_all_java_code(code):
    """Parse Java source and return the text of every class declaration.

    Despite the historical naming at call sites ("tokens"), this returns whole
    ``class_declaration`` fragments (via ``extract_all_class``), not leaf
    tokens; use ``extract_all_tokens`` on the root node for the latter.

    Parameters
    ----------
    code : bytes or str
        Java source. ``str`` input is UTF-8 encoded first, because the parser
        consumes bytes and node offsets are byte offsets.

    Returns
    -------
    list[str]
        Source text of each class declaration, outer classes before the
        classes nested inside them.
    """
    if isinstance(code, str):
        code = code.encode("utf8")

    tree = parser.parse(code)
    return extract_all_class(tree.root_node, code)

def extract_all_tokens(node, code):
    """Return the source text of every leaf node under ``node``, in order.

    A leaf is any node without children, named or anonymous, i.e. an
    individual token. Comments are not filtered out. ``code`` must be the
    bytes the tree was parsed from, because node offsets are byte offsets.
    """
    if not node.children:
        # A leaf contributes exactly its own text.
        return [code[node.start_byte:node.end_byte].decode('utf8')]

    leaves = []
    for child in node.children:
        leaves += extract_all_tokens(child, code)
    return leaves

def extract_all_class(node, code, node_types=("class_declaration",)):
    """Return the source text of every class declaration under ``node``.

    The tree is walked in pre-order, so an outer class is listed before any
    class nested inside it. A matched node is still descended into, so a
    nested class's text appears both on its own and inside its enclosing
    class's text.

    Parameters
    ----------
    node : tree_sitter.Node
        Root of the subtree to search (typically ``tree.root_node``).
    code : bytes
        The exact bytes the tree was parsed from; node offsets are byte offsets.
    node_types : str or iterable of str, optional
        Node types to collect. Defaults to ``("class_declaration",)``. Pass e.g.
        ``("class_declaration", "interface_declaration", "enum_declaration")``
        to capture other type declarations too.

    Returns
    -------
    list[str]
        Decoded (UTF-8) source text of each matching node.
    """
    # A bare string would otherwise be treated as a collection of characters.
    if isinstance(node_types, str):
        node_types = (node_types,)
    wanted = frozenset(node_types)
    classes = []

    # Single shared accumulator: avoids re-copying sub-results at every level.
    def _visit(current):
        if current.type in wanted:
            classes.append(code[current.start_byte:current.end_byte].decode('utf8'))
        for child in current.children:
            _visit(child)

    _visit(node)
    return classes



In [8]:
# Parse the sample snippet and print each extracted class declaration.
tokens = parse_all_java_code(java_code.encode("utf8"))
for token in tokens:
    print(token)

public class Example {
    public int add(int a, int b) {
        return a + b;
    }
}


In [9]:
# Type of the last extracted fragment. Index it explicitly rather than relying
# on the loop variable leaking out of the previous cell.
type(tokens[-1])

str

In [10]:
# Length of the last extracted fragment. Index it explicitly rather than
# relying on the loop variable leaking out of an earlier cell.
len(tokens[-1])

87

In [11]:
# Sanity check: the sample source is a str; it is encoded to bytes before parsing.
type(java_code)

str

In [12]:
# Whole-snippet length. This is 2 more than the extracted class declaration,
# because the triple-quoted literal starts and ends with a newline.
len(java_code)

89

In [13]:
import torch
from unixcoder import UniXcoder

# NOTE: the NumPy 1.x/2.x warning this cell prints comes from a compiled
# dependency built against NumPy 1.x. Per that message, pin `numpy<2` in the
# environment or upgrade the affected module.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = UniXcoder("microsoft/unixcoder-base")
model.to(device)
# The model is only used for inference. Put every submodule (including the
# Dropout layers visible in the repr) into evaluation mode so embeddings are
# deterministic.
model.eval()

  from .autonotebook import tqdm as notebook_tqdm

A module that was compiled using NumPy 1.x cannot be run in
NumPy 2.1.3 as it may crash. To support both 1.x and 2.x
versions of NumPy, modules must be compiled with NumPy 2.0.
Some module may need to rebuild instead e.g. with 'pybind11>=2.12'.

If you are a user of the module, the easiest solution will be to
downgrade to 'numpy<2' or try to upgrade the affected module.
We expect that some modules will need time to support NumPy 2.

Traceback (most recent call last):  File "/home/zzhang30/.conda/envs/tda/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/home/zzhang30/.conda/envs/tda/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/home/zzhang30/.conda/envs/tda/lib/python3.10/site-packages/ipykernel_launcher.py", line 18, in <module>
    app.launch_new_instance()
  File "/home/zzhang30/.conda/envs/tda/lib/python3.10/site-packages/traitle

UniXcoder(
  (model): RobertaModel(
    (embeddings): RobertaEmbeddings(
      (word_embeddings): Embedding(51416, 768, padding_idx=1)
      (position_embeddings): Embedding(1026, 768, padding_idx=1)
      (token_type_embeddings): Embedding(10, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): RobertaEncoder(
      (layer): ModuleList(
        (0-11): 12 x RobertaLayer(
          (attention): RobertaAttention(
            (self): RobertaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): RobertaSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm

In [14]:
# Embed the sample class with UniXcoder (encoder-only mode, max_length=512).
# The fragment is indexed explicitly instead of reusing a leaked loop variable.
tokens_ids = model.tokenize([tokens[-1]], max_length=512, mode="<encoder-only>")
source_ids = torch.tensor(tokens_ids).to(device)
# Pure inference: skip building the autograd graph to save memory and time.
with torch.no_grad():
    tokens_embeddings, max_func_embedding = model(source_ids)

In [15]:
# Show the pooled embedding's shape, then its values, on separate lines.
print(max_func_embedding.shape, max_func_embedding, sep="\n")

torch.Size([1, 768])
tensor([[ 1.2696e+00, -2.3301e-01,  1.8232e+00, -4.4726e-01, -9.9912e-01,
         -5.2362e-01,  2.3983e+00, -1.2501e+00,  1.3665e+00,  9.0616e-01,
          1.8832e-01,  1.6184e+00, -6.1772e-01,  9.0358e-01, -2.1050e+00,
          4.5995e+00, -1.7061e+00, -4.5903e-01,  7.2220e-01, -5.7143e-01,
         -2.2592e+00, -4.4027e-01,  9.4130e-01, -9.7455e-01,  1.4747e-01,
         -9.1937e-01, -3.7102e-01,  1.4971e+00, -3.8413e-01,  1.3043e+00,
         -1.4114e+00, -1.7055e-01, -1.8694e-01, -8.6831e-01, -4.1051e-02,
          2.2829e+00, -8.0621e-01, -1.6337e+00,  6.9517e-01, -1.6069e-02,
          5.5674e-01, -1.5100e-01,  1.2971e+00,  1.5624e+00, -5.0009e-01,
          6.8947e-01,  1.1770e+00,  1.8670e+00,  5.7725e-01, -1.2897e+00,
         -1.6898e+00, -2.8330e-01,  2.3026e-01, -3.9383e-01, -4.6569e-01,
         -1.1339e+00, -1.7961e+00, -5.9333e-01,  1.4201e+00, -5.4858e-01,
          1.4784e+00,  2.3811e-02, -1.7851e+00,  2.1208e+00, -2.0515e+00,
         -1.2704e

In [16]:
# Load a real-world Java file (Apache Ivy) to extract fragments from.
# NOTE: this rebinds `java_code`, replacing the small sample defined earlier.
JAVA_SOURCE_PATH = 'ant-ivy/src/java/org/apache/ivy/Ivy.java'

# Decode explicitly as UTF-8 so the result does not depend on the platform locale.
with open(JAVA_SOURCE_PATH, 'r', encoding='utf-8') as f:
    java_code = f.read()

# Parse the code to get the syntax tree (tree-sitter consumes bytes).
tree = parser.parse(bytes(java_code, "utf8"))
root_node = tree.root_node

In [17]:
def extract_code_fragments(node, code, level: str):
    """Collect source fragments of one granularity from a tree-sitter Java tree.

    Parameters
    ----------
    node : tree_sitter.Node
        Root of the subtree to walk (typically ``tree.root_node``).
    code : bytes
        The exact bytes the tree was parsed from. Node offsets are byte offsets,
        so slicing a ``str`` instead would return the wrong characters.
    level : {"class", "method", "token"}
        * ``"class"``  - every node of type ``class_declaration``.
        * ``"method"`` - every node of type ``method_declaration``.
        * ``"token"``  - every leaf (childless) node except ``block_comment``
          and ``line_comment`` nodes.

    Returns
    -------
    list[str]
        Decoded fragments in pre-order. Matched nodes are still descended into,
        so, for example, a nested class is returned both on its own and as part
        of its enclosing class.

    Raises
    ------
    ValueError
        If ``level`` is not one of the supported values. Previously a typo such
        as ``"methods"`` silently produced an empty list.
    """
    matchers = {
        "class": lambda n: n.type == "class_declaration",
        "method": lambda n: n.type == "method_declaration",
        # Comments are leaves too, but they are not code tokens.
        "token": lambda n: not n.children and n.type not in ("block_comment", "line_comment"),
    }
    try:
        is_match = matchers[level]
    except KeyError:
        raise ValueError(
            f"unsupported level {level!r}; expected one of {sorted(matchers)}"
        ) from None

    fragments = []

    # Single shared accumulator: avoids re-copying sub-results at every level.
    def _visit(current):
        if is_match(current):
            fragments.append(code[current.start_byte:current.end_byte].decode('utf8'))
        for child in current.children:
            _visit(child)

    _visit(node)
    return fragments

In [19]:
# Class-level fragments of the loaded file, sliced from the same UTF-8 bytes that were parsed.
frag = extract_code_fragments(root_node, java_code.encode("utf8"), level="class")

In [20]:
# Number of class-level fragments extracted from the file.
len(frag)

1

In [21]:
# Embed every class-level fragment with UniXcoder.
# NOTE(review): tokenize() is called with max_length=512, so a long class (the
# whole Ivy class here) is presumably truncated and its embedding reflects only
# the beginning of the class. TODO confirm in unixcoder.tokenize; chunk if needed.
# padding=True pads each sequence to max_length so that several fragments of
# different lengths can be stacked into one tensor (torch.tensor rejects ragged
# nested lists). Presumably UniXcoder's forward masks the pad tokens out of
# the pooled embedding; verify against unixcoder.py.
tokens_ids = model.tokenize(frag, max_length=512, mode="<encoder-only>", padding=True)
source_ids = torch.tensor(tokens_ids).to(device)
# Pure inference: skip building the autograd graph.
with torch.no_grad():
    tokens_embeddings, max_func_embedding = model(source_ids)
print(max_func_embedding.shape)
print(max_func_embedding)

torch.Size([1, 768])
tensor([[-2.4892e-01,  3.1219e+00,  2.7004e+00,  1.2355e+00,  2.1337e+00,
          4.7657e-01, -1.4245e+00,  4.8459e-01,  8.9200e-01,  6.8004e-02,
         -1.5642e-01,  1.7327e-01, -6.2941e-01, -1.3353e+00, -3.0683e+00,
          3.7050e-01,  2.2509e+00,  1.9055e+00, -6.7900e-01,  1.8871e+00,
          4.9967e-01,  1.1181e+00,  7.5263e-01, -1.1676e+00, -1.0094e+00,
          1.7055e+00, -4.0090e-01,  1.1298e+00, -7.9223e-01, -8.5194e-01,
          5.8504e-01, -1.6198e+00,  1.0759e+00, -1.6877e+00,  1.1285e-01,
          1.1941e+00,  7.3064e-01, -3.4761e-01, -1.8249e+00,  1.0853e+00,
          1.8725e+00,  3.1203e-02, -8.3108e-01,  1.0319e+00,  8.9844e-01,
         -2.3439e-01,  2.5451e-01, -2.0299e+00, -6.8937e-01, -5.2779e-01,
         -1.9567e+00,  1.0081e+00,  4.1833e-01,  4.7369e-01, -2.6472e-01,
          3.1131e-01, -3.6007e-01, -3.5514e+00,  3.9939e+00,  1.1234e+00,
         -1.4195e+00, -4.0785e-01,  3.6380e-01, -6.5640e-01,  1.0852e+00,
          9.4433e

In [22]:
# Dump the extracted class fragment(s). NOTE(review): this prints the full class
# source, which is very long for a real file. Consider frag[0][:500] for a preview.
print(frag)

