In [None]:
import numpy as np
import seaborn as sns
from copy import deepcopy
import os
from pathlib import Path
from tqdm.auto import tqdm
import sys
import copy
import networkx as nx




def get_obj_mbs(obj, cp=True):
    """Return the size of ``obj`` in MiB as reported by ``sys.getsizeof``.

    Args:
        obj: Any Python object.
        cp: When true, the measurement is taken on a deep copy of ``obj``
            instead of on ``obj`` itself.

    Returns:
        ``sys.getsizeof`` of the measured object divided by 2**20.
    """
    bytes_per_mib = 1 << 20
    measured = copy.deepcopy(obj) if cp else obj
    return sys.getsizeof(measured) / bytes_per_mib

In [None]:
# Dataset root plus the (collection, config type) pair the notebook processes.
root = Path("npz_all") / "npz"
collection = "layout/nlp"  # alternative: "layout/xla"
ctype = "random"  # alternative: "default"

In [None]:
def expand_nodes(data, expand=1):
    """Grow the set of configurable nodes by ``expand`` hops along edges.

    Each hop keeps every edge that touches the current node set (either
    endpoint) and then adds all endpoints of those edges to the set.

    Args:
        data: Mapping with ``"node_config_ids"`` (1-D node ids) and
            ``"edge_index"`` (2-D array, one edge per row).
        expand: Number of hops. ``0`` returns the configurable nodes
            themselves together with the edges whose endpoints are all
            configurable nodes.

    Returns:
        ``(in_node_ids, in_edge_index)``: the sorted unique node ids kept and
        the rows of ``edge_index`` selected in the last hop.

    Raises:
        ValueError: If ``expand`` is negative.
    """
    if expand < 0:
        raise ValueError("expand must be >= 0, got {}".format(expand))

    edge_index = np.asarray(data["edge_index"])
    seed_ids = np.unique(np.asarray(data["node_config_ids"]))

    if expand == 0:
        # The hop loop would not run and leave the edge selection undefined;
        # return the sub-graph induced by the configurable nodes instead.
        induced = np.isin(edge_index, seed_ids).all(1)
        return seed_ids, edge_index[induced]

    in_node_ids = seed_ids
    for _ in range(expand):
        in_edge_index = edge_index[np.isin(edge_index, in_node_ids).any(1)]
        # Union with the seeds so a configurable node that has no incident
        # edge is never dropped from the node set.
        in_node_ids = np.union1d(in_edge_index, seed_ids)

    return in_node_ids, in_edge_index


def prune_graph(data, expand=1):
    """Keep only the configurable nodes and their ``expand``-hop neighbourhood.

    Node ids are renumbered to ``0..k-1`` (in ascending order of the original
    ids) and ``edge_index`` / ``node_config_ids`` are rewritten to the new
    numbering.

    Args:
        data: Mapping with at least ``"node_feat"``, ``"node_opcode"``,
            ``"edge_index"`` and ``"node_config_ids"``.
        expand: Number of hops passed to ``expand_nodes``.

    Returns:
        A new dict. The four graph entries are replaced by their pruned
        versions; every other entry is a deep copy of the input entry.
    """
    data = dict(data)
    num_nodes = data["node_feat"].shape[0]
    num_edges = data["edge_index"].shape[0]

    in_node_ids, in_edge_index = expand_nodes(data, expand=expand)
    # Make sure every configurable node survives, even one without edges;
    # otherwise its remapped id below would silently become -1.
    in_node_ids = np.union1d(in_node_ids, data["node_config_ids"]).astype(np.int64)

    # Integer old-id -> new-id table (-1 marks removed nodes). An integer
    # dtype keeps the remapped index arrays usable as indices.
    lookup = np.full(num_nodes, -1, dtype=np.int64)
    lookup[in_node_ids] = np.arange(in_node_ids.shape[0])

    pruned = {
        "node_feat": data["node_feat"][in_node_ids, :],
        "node_opcode": data["node_opcode"][in_node_ids],
        "edge_index": lookup[in_edge_index],
        "node_config_ids": lookup[data["node_config_ids"]],
    }

    # Deep-copy only the entries that are carried over unchanged; the pruned
    # entries are already fresh arrays produced by fancy indexing.
    new_data = {
        key: pruned[key] if key in pruned else deepcopy(value)
        for key, value in data.items()
    }

    # Report "new/original" counts.
    print("New graph has {}/{} nodes and {}/{} edges".format(
        new_data["node_feat"].shape[0], num_nodes,
        new_data["edge_index"].shape[0], num_edges,
    ))
    return new_data


In [None]:
def get_subgraph(G, nodes):
    """Build the reachability graph of ``G`` restricted to ``nodes``.

    The result contains every node in ``nodes`` and an edge ``src -> dst``
    whenever ``dst`` (different from ``src``) is reachable from ``src`` in
    ``G`` and both are in ``nodes``.

    Args:
        G: Directed networkx graph.
        nodes: Iterable of hashable node ids of interest.

    Returns:
        A new ``nx.DiGraph``.
    """
    node_list = list(nodes)
    wanted = set(node_list)  # O(1) membership tests

    H = nx.DiGraph()
    # Register every node of interest so nodes without a reachable partner
    # are still part of the result.
    H.add_nodes_from(node_list)

    for src in node_list:
        if src not in G:
            # A node that has no edges in G cannot reach anything.
            continue
        # nx.descendants returns all nodes reachable from src, excluding src.
        for dst in nx.descendants(G, src) & wanted:
            H.add_edge(src, dst)

    return H

def prune_graph_new(data, expand=1):
    """Prune the graph to the expanded configurable nodes, keeping reachability.

    Unlike ``prune_graph``, the edges of the result are not a subset of the
    original edges: an edge ``u -> v`` is emitted whenever ``v`` is reachable
    from ``u`` in the original graph and both nodes are kept (see
    ``get_subgraph``). Node ids are renumbered to ``0..k-1`` in ascending
    order of the original ids.

    Args:
        data: Mapping with at least ``"node_feat"``, ``"node_opcode"``,
            ``"edge_index"`` and ``"node_config_ids"``.
        expand: Number of hops passed to ``expand_nodes``.

    Returns:
        A new dict. The four graph entries are replaced by their pruned
        versions; every other entry is a deep copy of the input entry.
    """
    data = dict(data)
    num_nodes = data["node_feat"].shape[0]

    G = nx.DiGraph()
    # tolist() gives plain Python ints as node keys.
    G.add_edges_from(np.asarray(data["edge_index"]).tolist())

    # expand_nodes takes the whole data mapping and returns (nodes, edges).
    nodes, _ = expand_nodes(data, expand=expand)
    nodes = np.asarray(nodes, dtype=np.int64)
    H = get_subgraph(G, nodes.tolist())

    # Keep every expanded node and every configurable node even if it ended
    # up without an edge in H, so the id remapping below never yields -1.
    in_node_ids = np.unique(np.concatenate([
        np.fromiter(H.nodes, dtype=np.int64),
        nodes,
        np.asarray(data["node_config_ids"], dtype=np.int64),
    ]))
    # reshape keeps the (E, 2) layout when H has no edges at all.
    in_edge_index = np.array(list(H.edges), dtype=np.int64).reshape(-1, 2)

    # Integer old-id -> new-id table (-1 marks removed nodes).
    lookup = np.full(num_nodes, -1, dtype=np.int64)
    lookup[in_node_ids] = np.arange(in_node_ids.shape[0])

    pruned = {
        "node_feat": data["node_feat"][in_node_ids, :],
        "node_opcode": data["node_opcode"][in_node_ids],
        "edge_index": lookup[in_edge_index],
        "node_config_ids": lookup[data["node_config_ids"]],
    }

    # Deep-copy only the entries that are carried over unchanged.
    return {
        key: pruned[key] if key in pruned else deepcopy(value)
        for key, value in data.items()
    }

In [None]:
def remove_dupplicated_node_configs(data):
    """Drop configurations whose ``node_config_feat`` duplicates an earlier one.

    The first occurrence of each distinct configuration is kept; the matching
    entries of ``config_runtime`` are removed alongside. Rows are compared
    exactly with ``np.unique(axis=0)``, which needs neither a pairwise
    equality matrix nor a randomised hash.

    Args:
        data: Mapping with ``"node_config_feat"`` (first axis = configuration)
            and ``"config_runtime"`` (one entry per configuration).

    Returns:
        The same ``data`` mapping, updated in place.
    """
    configs = data["node_config_feat"]
    num_configs = configs.shape[0]
    if num_configs == 0:
        # Nothing to de-duplicate (and reshape(0, -1) would be ambiguous).
        return data

    flat = configs.reshape(num_configs, -1)
    # return_index yields the first occurrence of every distinct row.
    _, first_idx = np.unique(flat, axis=0, return_index=True)

    keep = np.zeros(num_configs, dtype=bool)
    keep[first_idx] = True
    to_remove_ids = np.flatnonzero(~keep)

    data["config_runtime"] = np.delete(data["config_runtime"], to_remove_ids)
    data["node_config_feat"] = np.delete(configs, to_remove_ids, axis=0)
    return data

In [None]:
def find_duplicate_rows(data):
    """Group configurations that share identical ``node_config_feat`` rows.

    Each configuration is flattened to one row and cast to ``int32`` before
    comparison.

    Args:
        data: Mapping with ``"node_config_feat"`` (first axis = configuration).

    Returns:
        ``(dup_config_dct, all_dup_idx)`` where

        * ``dup_config_dct`` maps the index of the first occurrence of every
          duplicated configuration to the ascending array of *all* indices in
          that group (first occurrence included), ordered by first occurrence;
        * ``all_dup_idx`` is the array of the redundant indices (every group
          member except its first occurrence). It is empty when there are no
          duplicates.
    """
    no_indices = np.empty(0, dtype=np.intp)

    configs = data["node_config_feat"]
    num_configs = configs.shape[0]
    if num_configs == 0:
        return {}, no_indices

    matrix = configs.reshape(num_configs, -1).astype(np.int32)

    _, inverse, counts = np.unique(
        matrix, axis=0, return_inverse=True, return_counts=True
    )
    # Some numpy versions return a 2-D inverse when axis is given.
    inverse = inverse.reshape(-1)

    # A stable sort by group id lists each group's members in ascending index
    # order, so the first entry of every group is its first occurrence. This
    # replaces a per-row scan of the whole inverse array.
    order = np.argsort(inverse, kind="stable")
    groups = np.split(order, np.cumsum(counts)[:-1])

    dup_groups = sorted(
        (members for members in groups if members.size > 1),
        key=lambda members: members[0],
    )
    dup_config_dct = {int(members[0]): members for members in dup_groups}

    if dup_groups:
        all_dup_idx = np.concatenate([members[1:] for members in dup_groups])
    else:
        all_dup_idx = no_indices

    return dup_config_dct, all_dup_idx


def dedup_configs(data):
    """Merge duplicated configurations, averaging their runtimes.

    For every group reported by ``find_duplicate_rows`` the rounded mean of
    the group's ``config_runtime`` values is stored at the group's key index;
    the remaining duplicate entries are then deleted from both
    ``config_runtime`` and ``node_config_feat``.

    Args:
        data: Mapping with ``"node_config_feat"`` and ``"config_runtime"``.

    Returns:
        The same ``data`` mapping, updated in place.
    """
    groups, redundant = find_duplicate_rows(data)

    runtimes = data["config_runtime"]
    for keep_idx, members in groups.items():
        # Written into the existing array before any deletion takes place.
        runtimes[keep_idx] = round(np.mean(runtimes[members]))

    if not len(redundant):
        return data

    for key, kwargs in (("config_runtime", {}), ("node_config_feat", {"axis": 0})):
        data[key] = np.delete(data[key], redundant, **kwargs)
    return data


def test_dedup_configs(data):
    """Check that both de-duplication routines keep the same number of configs.

    Each routine runs on its own deep copy of ``data``; the resulting
    ``node_config_feat`` shapes must match.

    Returns:
        ``True`` when the shapes agree (an ``AssertionError`` is raised otherwise).
    """
    reference_shape = remove_dupplicated_node_configs(copy.deepcopy(data))["node_config_feat"].shape
    candidate_shape = dedup_configs(copy.deepcopy(data))["node_config_feat"].shape

    res = reference_shape == candidate_shape
    assert res
    return res


In [None]:
def vec_to_int(vec: np.ndarray) -> np.ndarray:
    """Pack groups of six base-7 digits into single integers.

    The last axis of ``vec`` (length 6) is read as digits with the least
    significant digit first, i.e. weighted by 7**0 .. 7**5.

    Returns:
        The packed values as ``int32``.
    """
    place_values = np.array([1, 7, 49, 343, 2401, 16807])
    packed = np.dot(vec, place_values)
    return packed.astype(np.int32)


def int_to_vec(integers: np.ndarray) -> np.ndarray:
    """Unpack integers into their six base-7 digits (inverse of ``vec_to_int``).

    Args:
        integers: Array-like of N integers. The input is never modified.

    Returns:
        ``int32`` array of shape ``(N, 6)``; column ``i`` holds the digit
        weighted by ``7**i`` (least significant digit first).
    """
    # Work on a private int64 copy so the caller's array is left untouched.
    values = np.array(integers, dtype=np.int64).reshape(-1)
    powers_of_7 = 7 ** np.arange(6, dtype=np.int64)

    # digit_i = (value // 7**i) % 7, computed for all values and digits at once.
    digits = (values[:, None] // powers_of_7[None, :]) % 7
    return digits.astype(np.int32)


def compress_configs(node_configs):
    """Pack config features six-at-a-time into base-7 integers.

    The features are regrouped into rows of six, cast to ``int32``, shifted
    by +1 and encoded with ``vec_to_int``. The packed values are returned
    with shape ``(node_configs.shape[0], node_configs.shape[1], 3)``.
    """
    num_configs = node_configs.shape[0]
    num_nodes = node_configs.shape[1]

    shifted_digits = node_configs.reshape(-1, 6).astype(np.int32) + 1
    packed = vec_to_int(shifted_digits)
    return packed.reshape(num_configs, num_nodes, 3)


def decompress_configs(node_configs):
    """Undo ``compress_configs``.

    Every packed integer is expanded with ``int_to_vec``; the digits are
    regrouped to ``(node_configs.shape[0], -1, 18)`` and shifted by -1.
    """
    flat_codes = node_configs.astype(np.int32).ravel()
    digits = int_to_vec(flat_codes)
    restored = digits.reshape(node_configs.shape[0], -1, 18)
    return restored - 1


def test_compression(data, db=False):
    """Round-trip ``data["node_config_feat"]`` through compress/decompress.

    Asserts that decompression reproduces the ``int32`` features exactly and
    that the size ratio reported by ``get_obj_mbs`` (original over compressed)
    rounds to 6.

    Args:
        data: Mapping with ``"node_config_feat"``.
        db: When true, print shapes, a small sample and sizes of the three arrays.

    Returns:
        The result of the element-wise equality check.
    """
    original = data["node_config_feat"].astype(np.int32)
    packed = compress_configs(data["node_config_feat"])
    restored = decompress_configs(packed)

    if db:
        stages = (original, packed, restored)
        print(*(stage.shape for stage in stages))
        print(*(stage[0, :2] for stage in stages), sep="\n")
        print(*(get_obj_mbs(stage) for stage in stages))

    res = (original == restored).all()
    assert res

    size_ratio = get_obj_mbs(original) / get_obj_mbs(packed)
    assert round(size_ratio) == 6

    return res


In [None]:
# Prune, de-duplicate and compress every graph of every split, writing the
# result under "<collection>_pruned_new/<ctype>/<split>/".
dst_dir = root / f"{collection}_pruned_new" / ctype
for split in ["train", "valid", "test"]:
    print("Loading {} data...".format(split))
    split_src_dir = root / collection / ctype / split
    split_dst_dir = dst_dir / split
    split_dst_dir.mkdir(parents=True, exist_ok=True)

    # Sorted for a deterministic processing order.
    for npz_path in tqdm(sorted(split_src_dir.glob("*.npz"))):
        out_p = split_dst_dir / npz_path.name

        # Already processed in a previous run.
        if out_p.exists():
            continue

        # NOTE(review): allow_pickle=True can execute arbitrary code when the
        # file is untrusted; only load .npz files from a trusted source.
        # The context manager closes the archive once the arrays are read.
        with np.load(str(npz_path), allow_pickle=True) as npz_file:
            data = dict(npz_file)

        # prune_graph_new is called without extra keywords so the call is
        # valid for its signature; alternative: prune_graph(data, expand=3).
        data = prune_graph_new(data)

        # Duplicated configurations are merged for every split except "test".
        if split != "test":
            data = dedup_configs(data)

        data["node_config_feat"] = compress_configs(data["node_config_feat"])

        # Write to a temporary file first and rename afterwards, so an
        # interrupted run never leaves a truncated file that the exists()
        # check above would skip on the next run.
        tmp_p = out_p.with_name(out_p.name + ".tmp")
        with open(tmp_p, "wb") as tmp_file:
            np.savez_compressed(tmp_file, **data)
        tmp_p.replace(out_p)



In [None]:
%matplotlib inline
import networkx as nx
import numpy as np
import matplotlib.pyplot as plt

def contains_cycle(edge_matrix, directed=False):
    """Report whether the graph described by ``edge_matrix`` contains a cycle.

    Args:
        edge_matrix: Iterable of ``(u, v)`` edges.
        directed: When false (the default) edge direction is ignored, so any
            two distinct paths between the same pair of nodes count as a
            cycle. Pass ``True`` to look for directed cycles only.

    Returns:
        ``True`` if ``nx.find_cycle`` finds a cycle, ``False`` otherwise.
    """
    graph = nx.DiGraph() if directed else nx.Graph()
    graph.add_edges_from(edge_matrix)
    try:
        nx.find_cycle(graph)
    except nx.NetworkXNoCycle:
        return False
    return True

def plot_graph(edge_matrix, node_names):
    """Draw the undirected graph given by ``edge_matrix`` on a 12x12 figure.

    Args:
        edge_matrix: Iterable of ``(u, v)`` edges.
        node_names: Accepted for the caller's convenience but not used;
            nodes are drawn without labels.
    """
    graph = nx.Graph()
    graph.add_edges_from(edge_matrix)

    # Kamada-Kawai node positions, computed before the figure is created.
    layout = nx.kamada_kawai_layout(graph)

    plt.figure(figsize=(12,12))

    # Small markers and hairline edges.
    draw_options = dict(with_labels=False, node_size=10, width=0.1)
    nx.draw(graph, layout, **draw_options)

    plt.axis('off')
    plt.show()


# Plot the graph currently held in `data`.
# NOTE(review): `data` is not defined in this cell; it is whatever the
# processing loop above left behind (the last file it actually processed).
# If that loop skipped every file, `data` is undefined here — confirm which
# graph is meant to be plotted and load it explicitly.
plot_graph(data["edge_index"], data["node_opcode"])
