In [1]:
import importlib
import numpy as np
import pandas as pd
import config
importlib.reload(config)
from config import *
import torch
import xml.etree.ElementTree as ET
import re

In [2]:
def extract_method_parameters(method_signature):
    """Return the parameter strings found between the parentheses of a signature.

    The text captured between an opening "(" and the last ")" that follows it
    on the same line is split on commas, and each piece is stripped of
    surrounding whitespace. Splitting is purely comma based: commas nested
    inside a parameter (for example in generic type arguments) are not
    treated specially.

    Args:
        method_signature: Signature text such as ``"pkg.Cls.name(int,T[])"``.

    Returns:
        A list of parameter strings. The list is empty when the signature has
        no parenthesised section, or when the parentheses enclose nothing but
        whitespace (a method without parameters).
    """
    match = re.search(r'\((.*)\)', method_signature)
    if match is None:
        return []
    parameters = match.group(1)
    # "".split(",") yields [""], which would report one empty parameter for a
    # signature like "foo()"; report no parameters instead.
    if not parameters.strip():
        return []
    return [param.strip() for param in parameters.split(',')]
    

def replace_generics_with_object(method_name):
    """Rewrite single-letter generic parameter types as ``java.lang.Object``.

    Each parameter returned by ``extract_method_parameters`` that is exactly
    one of the letters T, O, K, V, E, L, M, R is replaced with
    ``java.lang.Object``; the same letter followed by ``[]`` is replaced with
    ``java.lang.Object[]``. Every other parameter is kept as is.

    Args:
        method_name: Method signature text, e.g. ``"pkg.Cls.put(K,V)"``.

    Returns:
        The text before the first "(" followed by the rewritten parameters,
        joined with "," (no spaces) and wrapped in parentheses. Anything that
        followed the parameter list in the input is not carried over. A name
        that contains no "(" is returned unchanged.
    """
    generic_names = ("T", "O", "K", "V", "E", "L", "M", "R")

    paren_at = method_name.find("(")
    if paren_at == -1:
        # str.find returns -1 when there is no "("; slicing with that value
        # would silently drop the last character of the name and append "()".
        return method_name

    rewritten = []
    for param in extract_method_parameters(method_name):
        if param in generic_names:
            rewritten.append("java.lang.Object")
        elif param.endswith("[]") and param[:-2] in generic_names:
            rewritten.append("java.lang.Object[]")
        else:
            rewritten.append(param)
    return method_name[:paren_at] + f"({','.join(rewritten)})"


def extract_simple_name(full_name:str):
    """Strip the qualifier from a method signature, keeping its parameter list.

    The part of ``full_name`` before the first "(" is reduced to whatever
    follows its last "."; the text from the "(" onwards is appended unchanged.
    For example ``"pkg.Cls.name(int)"`` becomes ``"name(int)"``.
    """
    paren_at = full_name.find("(")
    qualified_name = full_name[:paren_at]
    parameter_part = full_name[paren_at:]
    # rsplit with maxsplit=1 leaves the text after the last "." in the final
    # slot, or the whole string when there is no "." at all.
    return qualified_name.rsplit(".", 1)[-1] + parameter_part


def method_name_match(name_xml:str,names_csv:list,simple_names_csv:list) -> int:
    """Find the list position of the method that corresponds to ``name_xml``.

    Matching is attempted in this order, stopping at the first hit:

    1. ``name_xml`` verbatim in ``names_csv``.
    2. ``name_xml`` with generic parameters rewritten by
       ``replace_generics_with_object`` in ``names_csv``.
    3. The simple name (``extract_simple_name``) of ``name_xml`` in
       ``simple_names_csv``.
    4. The simple name of the generics-rewritten signature in
       ``simple_names_csv``.

    A simple-name hit is only accepted when that simple name occurs exactly
    once; if it occurs several times the match is ambiguous and -1 is
    returned without trying further candidates.

    Args:
        name_xml: Method signature to look up.
        names_csv: List of fully qualified method signatures.
        simple_names_csv: List of simple method signatures. The returned
            position for a simple-name hit refers to this list, so callers
            that use the result as a row index rely on both lists being
            aligned.

    Returns:
        The position of the match, or -1 when nothing matches or the
        simple-name match is ambiguous.
    """
    name_xml_tran = replace_generics_with_object(name_xml)

    for full_candidate in (name_xml, name_xml_tran):
        if full_candidate in names_csv:
            return names_csv.index(full_candidate)

    # Both fallbacks must be simple *names*; building the second one with
    # extract_method_parameters produced a list that could never be found.
    simple_candidates = (
        extract_simple_name(name_xml),
        extract_simple_name(name_xml_tran),
    )
    for simple_candidate in simple_candidates:
        # list.index returns a single int, so ambiguity has to be detected by
        # counting occurrences rather than by taking len() of the index.
        occurrences = simple_names_csv.count(simple_candidate)
        if occurrences > 1:
            return -1
        if occurrences == 1:
            return simple_names_csv.index(simple_candidate)
    return -1

In [3]:
# Load the method table and keep the two name columns as plain lists.
method_df = pd.read_csv(method_dir)
method_names = method_df["method name"].tolist()
method_simple_names = method_df["simple name"].tolist()
# Name -> position lookup; for duplicated names the last position wins.
name2idx_dt = dict(zip(method_names, range(len(method_names))))

# The adjacency matrix starts as the identity (every method linked to itself);
# the structural feature matrix has 5 columns per method, all zero initially.
adj = torch.eye(len(method_names), dtype=torch.float)
st_embeddings = torch.zeros(len(method_names), 5, dtype=torch.float)
print(f"adj shape: {adj.shape}")
print(f"st_embeddings shape: {st_embeddings.shape}")

# Collect every <edge> element of the GraphML call graph.
namespace = {'graphml': 'http://graphml.graphdrawing.org/xmlns'}
callgraph_xml_root = ET.parse(callgraph_dir).getroot()
xml_edges = callgraph_xml_root.findall(".//graphml:graph/graphml:edge", namespace)
print(f"edge nums: {len(xml_edges)}")

adj shape: torch.Size([6706, 6706])
st_embeddings shape: torch.Size([6706, 5])
edge nums: 7960


In [4]:
# Feature column 0: number of "\n"-separated lines in each method body.
sources = method_df["method body"]
for i, source in enumerate(sources):
    # Entries that are not plain strings (presumably NaN from empty CSV
    # cells - TODO confirm) are recorded as zero lines.
    st_embeddings[i][0] = len(source.split("\n")) if type(source) is str else 0

In [5]:
# Translate every call-graph edge into matrix positions. When both endpoints
# are matched, mark the edge in `adj` and bump the per-edge counters:
# column 1 for the source method, column 2 for the target method.
# Endpoints that cannot be matched are collected for reporting.
not_found=set()
for xml_edge in xml_edges:
    source_xml=xml_edge.get("source")
    target_xml=xml_edge.get("target")
    source_idx=method_name_match(source_xml,method_names,method_simple_names)
    target_idx=method_name_match(target_xml,method_names,method_simple_names)
    if source_idx!=-1 and target_idx!=-1:
        adj[source_idx][target_idx]=1
        st_embeddings[source_idx][1]+=1
        st_embeddings[target_idx][2]+=1
    if source_idx==-1:
        not_found.add(source_xml)
    if target_idx==-1:
        not_found.add(target_xml)
print(f"not found size: {len(not_found)}")
for method in not_found:
    # Print each unmatched name once (previously the whole set was printed
    # on every iteration).
    print(method)

not found size: 0


In [6]:
# Use the GPU when one is available, otherwise stay on the CPU so the cell
# also runs on machines without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Work on a copy of the adjacency matrix with the self-loops removed.
adj_no_self = adj.clone().to(device)
adj_no_self.fill_diagonal_(0)
st_embeddings = st_embeddings.to(device)

# two_hop[i, j] > 0 exactly when some k exists with i -> k and k -> j.
two_hop = adj_no_self @ adj_no_self

# Feature column 3: row sum of the self-loop-free adjacency matrix, i.e. the
# number of distinct direct successors of each method (entries are 0 or 1).
st_embeddings[:, 3] = adj_no_self.sum(dim=1)

# Feature column 4: number of methods reachable through a two-step path that
# are neither direct successors nor the method itself.
two_hop_only = ((two_hop > 0) & (adj_no_self == 0)).float()
two_hop_only.fill_diagonal_(0)
st_embeddings[:, 4] = two_hop_only.sum(dim=1)

In [8]:
# Announce the destination, then serialise the feature tensor to it.
print(f"file save to dir: {st_embedding_pt_dir}")
torch.save(obj=st_embeddings, f=st_embedding_pt_dir)

file save to dir: .\datasets\embeddings\commons-collections\st_embeddings.pt


In [9]:
# Final summary: the project name and the feature tensor shape, one per line.
print(project_name, st_embeddings.shape, sep="\n")

commons-collections
torch.Size([6706, 5])
