In [None]:
import json
import torch
from torch import Tensor
from tqdm.notebook import tqdm
import os
import re
import pickle
import numpy as np
from gensim.models import Word2Vec
import random

# BCB dataset

# Statement-level AST node types. Each name maps to a row index of the
# embedding matrix built below; 'OtherStmt' is the fallback bucket used when a
# segment contains none of the other listed types.
node_type_dict = {
    type_name: row
    for row, type_name in enumerate([
        'MethodDeclaration', 'FormalParameter', 'LocalVariableDeclaration', 'VariableDeclarator',
        'BinaryOperation', 'IfStatement', 'BlockStatement', 'StatementExpression', 'Assignment',
        'MethodInvocation', 'ForStatement', 'TryStatement', 'ClassCreator', 'CatchClause',
        'WhileStatement', 'ReturnStatement', 'OtherStmt',
    ])
}

# Keys view of node_type_dict, used only for membership tests.
stmt_list = node_type_dict.keys()
# Word2Vec model whose vectors embed the node-type tokens of the BCB data.
model = Word2Vec.load("models/word2vec_model.bin")

def find_root_name(list):
    """Return the first token of ``list`` that is a known statement type.

    ``list`` is a sequence of node-type names (the parameter name shadows the
    builtin but is kept so existing callers are unaffected). Membership is
    checked against the module-level ``stmt_list``. When no token matches,
    the fallback bucket ``'OtherStmt'`` is returned.
    """
    return next((token for token in list if token in stmt_list), 'OtherStmt')


def get_stmt_sequence(node_sequence):
    """Split a node-type sequence into statement-rooted segments.

    A segment begins at every token found in ``stmt_list`` and ends just
    before the next such token; the last segment runs to the end of the
    sequence. A segment spanning at most 2 tokens is not emitted on its own:
    its start is carried forward so it is merged into the following segment.
    The final segment is always emitted.

    With fewer than two statement tokens the whole ``node_sequence`` is
    returned as the single segment.

    NOTE(review): when two or more statement tokens exist, tokens located
    before the first one are dropped, whereas with zero or one statement
    token they are kept -- confirm this asymmetry is intended.

    Returns:
        A list of sub-sequences of ``node_sequence``.
    """
    starts = [pos for pos, node in enumerate(node_sequence) if node in stmt_list]
    if len(starts) < 2:
        return [node_sequence]

    # Segment boundaries: every statement start plus the end of the sequence.
    bounds = starts + [len(node_sequence)]
    last_seg = len(bounds) - 2
    segments = []
    seg_begin = bounds[0]
    for seg_no in range(len(bounds) - 1):
        seg_end = bounds[seg_no + 1]
        if seg_end - seg_begin <= 2 and seg_no != last_seg:
            # Too short: keep seg_begin so this span is absorbed by the next one.
            continue
        segments.append(node_sequence[seg_begin:seg_end])
        seg_begin = seg_end
    return segments


def compute_Embedding(path):
    """Build the statement-type embedding matrix for one BCB subgraph file.

    Loads a pickled dict from ``path``. For every value ``v`` the node
    sequence ``v[1]`` is split into statement segments (``get_stmt_sequence``).
    The Word2Vec vectors of each segment's tokens are averaged. The mean is
    scaled by ``v[2]`` and added to the row selected by the segment's root
    statement type (``find_root_name`` -> ``node_type_dict``).

    NOTE(review): the value layout (``v[1]`` = token sequence, ``v[2]`` =
    the community's centrality) is inferred from usage here -- confirm
    against the code that writes these pickles.

    Args:
        path: Path to the pickle file for one sample.

    Returns:
        A float tensor of shape ``(58, 100)``. Only rows 0-16 (the indices in
        ``node_type_dict``) are ever written; the remaining rows stay zero.
        Empty segments contribute nothing.
    """
    # SECURITY: pickle.load can execute arbitrary code -- only load trusted files.
    with open(path, 'rb') as file_1:
        node_dict = pickle.load(file_1)

    # NOTE(review): only 17 rows are addressable via node_type_dict; the extra
    # rows are presumably kept for a consumer expecting 58 -- confirm before
    # shrinking (the GCJ variant uses 17).
    embeddings = torch.zeros(58, 100)

    for v in node_dict.values():
        node_sequence = v[1]
        centrality = v[2]  # centrality of this community
        for sequence in get_stmt_sequence(node_sequence):
            # An empty segment (from an empty node sequence) would make
            # np.mean return a NaN scalar that broadcasts over the whole row.
            if len(sequence) == 0:
                continue
            root_name = find_root_name(sequence)  # root node type of the subtree
            index = node_type_dict[root_name]  # row index in the embedding matrix

            embedding = np.mean([model.wv[word] for word in sequence], axis=0)
            embeddings[index] += torch.tensor(embedding) * centrality

    return embeddings


def main(input_path="processed_data/node_pkl_subgraph/",
         output_path="processed_data/embeddings_3/"):
    """Embed every BCB subgraph pickle that has no saved ``.pt`` output yet.

    Each regular file in ``input_path`` is mapped to an id: the first match of
    the digit-prefixed regex used below. Its embedding is saved as
    ``<id>.pt`` in ``output_path``. Ids that already have an output file are
    skipped, so re-running resumes an interrupted job. Files whose name
    yields no id are reported and skipped.

    Args:
        input_path: Directory holding the input pickles (trailing slash ok).
        output_path: Directory for the ``.pt`` outputs; created if missing.
    """
    # Previous output location: "processed_data/embeddings/"
    os.makedirs(output_path, exist_ok=True)  # listdir/save below need it to exist

    files = sorted(
        name for name in os.listdir(input_path)
        if os.path.isfile(os.path.join(input_path, name))
    )
    already_done = set(os.listdir(output_path))  # set: O(1) membership tests
    print(len(files))

    todo = []
    for name in files:
        ids = re.findall(r'\d+\w', name)
        if not ids:
            print(f"skipping {name!r}: no id found in file name")
            continue
        if ids[0] + '.pt' not in already_done:
            todo.append((name, ids[0]))

    for name, file_id in tqdm(todo, desc="Processing"):
        embeddings = compute_Embedding(os.path.join(input_path, name))
        torch.save(embeddings, os.path.join(output_path, file_id + '.pt'))

    print(len(todo))


# Run the BCB embedding job as soon as this cell executes.
main()

In [None]:
import json
import torch
from torch import Tensor
from tqdm.notebook import tqdm
import os
import re
import pickle
import numpy as np
from gensim.models import Word2Vec
import random

# GCJ dataset

# Statement-level AST node types. Each name maps to a row index of the
# embedding matrix built below; 'OtherStmt' is the fallback bucket used when a
# segment contains none of the other listed types.
node_type_dict = {
    type_name: row
    for row, type_name in enumerate([
        'MethodDeclaration', 'FormalParameter', 'LocalVariableDeclaration', 'VariableDeclarator',
        'BinaryOperation', 'IfStatement', 'BlockStatement', 'StatementExpression', 'Assignment',
        'MethodInvocation', 'ForStatement', 'TryStatement', 'ClassCreator', 'CatchClause',
        'WhileStatement', 'ReturnStatement', 'OtherStmt',
    ])
}

# Keys view of node_type_dict, used only for membership tests.
stmt_list = node_type_dict.keys()
# Word2Vec model whose vectors embed the node-type tokens of the GCJ data.
model = Word2Vec.load("models/word2vec_model_gcj.bin")

def find_root_name(list):
    """Return the first token of ``list`` that is a known statement type.

    ``list`` is a sequence of node-type names (the parameter name shadows the
    builtin but is kept so existing callers are unaffected). Membership is
    checked against the module-level ``stmt_list``. When no token matches,
    the fallback bucket ``'OtherStmt'`` is returned.
    """
    return next((token for token in list if token in stmt_list), 'OtherStmt')


def get_stmt_sequence(node_sequence):
    """Split a node-type sequence into statement-rooted segments.

    A segment begins at every token found in ``stmt_list`` and ends just
    before the next such token; the last segment runs to the end of the
    sequence. A segment spanning at most 2 tokens is not emitted on its own:
    its start is carried forward so it is merged into the following segment.
    The final segment is always emitted.

    With fewer than two statement tokens the whole ``node_sequence`` is
    returned as the single segment.

    NOTE(review): when two or more statement tokens exist, tokens located
    before the first one are dropped, whereas with zero or one statement
    token they are kept -- confirm this asymmetry is intended.

    Returns:
        A list of sub-sequences of ``node_sequence``.
    """
    starts = [pos for pos, node in enumerate(node_sequence) if node in stmt_list]
    if len(starts) < 2:
        return [node_sequence]

    # Segment boundaries: every statement start plus the end of the sequence.
    bounds = starts + [len(node_sequence)]
    last_seg = len(bounds) - 2
    segments = []
    seg_begin = bounds[0]
    for seg_no in range(len(bounds) - 1):
        seg_end = bounds[seg_no + 1]
        if seg_end - seg_begin <= 2 and seg_no != last_seg:
            # Too short: keep seg_begin so this span is absorbed by the next one.
            continue
        segments.append(node_sequence[seg_begin:seg_end])
        seg_begin = seg_end
    return segments


def compute_Embedding(path):
    """Build the statement-type embedding matrix for one GCJ sample file.

    Loads a pickled dict from ``path``. For every value, the node sequence at
    index 1 is split into statement segments (``get_stmt_sequence``). The
    Word2Vec vectors of each segment's tokens are summed. The sum is scaled
    by the value at index 2 and added to the row selected by the segment's
    root statement type.

    NOTE(review): index 1 = token sequence and index 2 = community centrality
    is inferred from usage -- confirm against the producer of these pickles.

    Returns:
        A float tensor of shape ``(17, 100)``; one row per entry of
        ``node_type_dict``.
    """
    # SECURITY: pickle.load can execute arbitrary code -- only load trusted files.
    with open(path, 'rb') as handle:
        node_dict = pickle.load(handle)

    embeddings = torch.zeros(17, 100)

    for record in node_dict.values():
        token_seq, weight = record[1], record[2]
        for segment in get_stmt_sequence(token_seq):
            row = node_type_dict[find_root_name(segment)]
            # Sum (not mean) of the token vectors; the BCB cell uses the mean.
            summed = np.sum([model.wv[token] for token in segment], axis=0)
            embeddings[row] = embeddings[row] + torch.tensor(summed) * weight

    return embeddings


def main(input_root="processed_data/GCJ/GCJ_json/",
         output_root="processed_data/GCJ/embeddings_1/googlejam4_src/",
         num_problems=12):
    """Embed every GCJ sample in problem folders ``1..num_problems``.

    For problem ``p``, each regular file in ``<input_root>/p`` is embedded
    with ``compute_Embedding``. The result is saved as
    ``<output_root>/p/<file name>.pt``; output folders are created if
    missing. Non-file entries (e.g. checkpoint folders) are ignored.

    Args:
        input_root: Root folder containing one sub-folder per problem.
        output_root: Root folder for the ``.pt`` outputs.
        num_problems: Number of numbered problem folders, starting at 1.
    """
    # Previous output location: "processed_data/embeddings/"
    for problem in range(1, num_problems + 1):
        input_path = os.path.join(input_root, str(problem))
        output_path = os.path.join(output_root, str(problem))
        os.makedirs(output_path, exist_ok=True)  # torch.save needs the folder

        files = sorted(
            name for name in os.listdir(input_path)
            if os.path.isfile(os.path.join(input_path, name))
        )
        for name in tqdm(files, desc="Processing"):
            embeddings = compute_Embedding(os.path.join(input_path, name))
            # The output keeps the full input file name plus '.pt'.
            torch.save(embeddings, os.path.join(output_path, name + '.pt'))

        print(len(files))


# Run the GCJ embedding job as soon as this cell executes.
main()