In [1]:
import numpy as np
import seaborn as sns
from copy import deepcopy
import os
from pathlib import Path
from tqdm.auto import tqdm

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Location of the dataset and the sub-collection this notebook processes.
# The root can be overridden through the GFOS_DATA_ROOT environment variable so
# the notebook is runnable on other machines; the default is the original path.
root = Path(os.environ.get("GFOS_DATA_ROOT", "/home/edu/code/google_fast_or_slow/data/npz_all/npz"))
# Sub-directory of `root` that holds the graphs; "<collection>_pruned" is used as the output tree.
collection = "layout/xla"
# Sub-directory below the collection that contains the train/valid/test folders.
ctype = "default"

In [4]:
def prune_graph(data):
    """Restrict a graph to the edges that touch at least one configurable node.

    Every edge with an endpoint listed in ``data["node_config_ids"]`` is kept,
    only the nodes appearing in those edges are kept, and node ids are
    renumbered to ``0..K-1`` following the ascending order of the original ids.

    Args:
        data: Mapping providing at least ``"node_feat"`` (2-D, one row per
            node), ``"node_opcode"`` (indexable by node id), ``"edge_index"``
            (2-D integer array, one edge per row) and ``"node_config_ids"``
            (integer node ids).

    Returns:
        A new dict with the same keys in the same order. The four graph entries
        are replaced by their pruned, re-indexed versions (integer dtype for
        the index arrays); every other entry is a deep copy of the input value.
        The input mapping is not modified.

    Raises:
        AssertionError: if a node in ``node_config_ids`` is not an endpoint of
            any edge, so it would be lost by the pruning.
    """
    print("Pruning graph...")
    source = dict(data)
    node_feat = source["node_feat"]
    node_opcode = source["node_opcode"]
    edge_index = source["edge_index"]
    config_ids = source["node_config_ids"]
    num_nodes = node_feat.shape[0]
    print("Original graph has {} nodes and {} edges".format(num_nodes, edge_index.shape[0]))

    # Keep an edge as soon as one of its endpoints is a configurable node.
    touches_config = np.isin(edge_index, config_ids).any(1)
    kept_edges = edge_index[touches_config]

    kept_node_ids = np.unique(kept_edges)
    assert len(set(config_ids) - set(kept_node_ids)) == 0

    # Old id -> new id table. It must be an integer array: a float table
    # (e.g. np.ones(n) * -1) would turn the re-indexed edges into floats.
    # -1 marks nodes that were pruned away.
    lookup = np.full(num_nodes, -1, dtype=np.int64)
    lookup[kept_node_ids] = np.arange(kept_node_ids.shape[0])

    # Fancy indexing returns fresh arrays, so none of these alias the input.
    pruned = {
        "node_feat": node_feat[kept_node_ids, :],
        "node_opcode": node_opcode[kept_node_ids],
        "edge_index": lookup[kept_edges],
        "node_config_ids": lookup[config_ids],
    }

    # Deep-copy only the entries that are passed through unchanged; copying the
    # full-size graph arrays just to overwrite them would be wasted work.
    new_data = {
        key: pruned[key] if key in pruned else deepcopy(value)
        for key, value in source.items()
    }
    print("New graph has {} nodes and {} edges".format(new_data["node_feat"].shape[0], new_data["edge_index"].shape[0]))
    return new_data


In [5]:
def remove_dupplicated_node_configs(data):
    """Drop configurations whose features exactly duplicate an earlier one.

    Each entry along the first axis of ``data["node_config_feat"]`` is
    flattened and compared element-wise with the others. For every group of
    identical configurations the first occurrence is kept and the later ones
    are removed, together with the matching entries of
    ``data["config_runtime"]``.

    The comparison is exact and deterministic (no random hashing), and needs
    no n-by-n pairwise matrix.

    Args:
        data: Dict with ``"node_config_feat"`` (first axis indexes the
            configurations) and ``"config_runtime"`` (one entry per
            configuration). It is modified in place.

    Returns:
        The same ``data`` dict, with both entries filtered.
    """
    config_feat = data["node_config_feat"]
    num_configs = config_feat.shape[0]

    if num_configs > 1:
        flat_configs = config_feat.reshape(num_configs, -1)
        # return_index yields the first occurrence of every distinct row.
        _, first_ids = np.unique(flat_configs, axis=0, return_index=True)
        is_first = np.zeros(num_configs, dtype=bool)
        is_first[first_ids] = True
        to_remove_ids = np.flatnonzero(~is_first)
    else:
        # Zero or one configuration: nothing can be a duplicate (and
        # reshape(0, -1) is not defined for an empty array).
        to_remove_ids = np.empty(0, dtype=np.intp)

    print("Removing {} duplicated node configs out of {}".format(to_remove_ids.shape[0], num_configs))
    data["config_runtime"] = np.delete(data["config_runtime"], to_remove_ids)
    data["node_config_feat"] = np.delete(config_feat, to_remove_ids, axis=0)
    return data

In [6]:
# Prune every graph of the collection and write the result to
# "<root>/<collection>_pruned/<ctype>/<split>/". Train graphs are additionally
# de-duplicated before saving; valid graphs are saved as-is and a de-duplicated
# copy is written to the sibling "valid_dedup" directory.
dst_dir = root / f"{collection}_pruned" / ctype
for split in ["train", "valid", "test"]:
    print("Loading {} data...".format(split))
    split_src_dir = root / collection / ctype / split
    split_dst_dir = dst_dir / split
    split_dst_dir.mkdir(parents=True, exist_ok=True)

    # sorted() makes the processing order independent of the filesystem.
    for npz_path in tqdm(sorted(split_src_dir.glob("*.npz"))):
        print(npz_path)
        # NOTE(review): allow_pickle=True lets a crafted file run arbitrary code
        # on load; only use it on trusted data and drop it if the archives
        # contain no object arrays.
        # The context manager closes the archive once all arrays are read.
        with np.load(str(npz_path), allow_pickle=True) as npz_file:
            data = dict(npz_file)
        data = prune_graph(data)
        if split == "train":
            data = remove_dupplicated_node_configs(data)
        np.savez(split_dst_dir / npz_path.name, **data)
        if split == "valid":
            # De-duplication happens after the plain save above, so the regular
            # "valid" output keeps every configuration.
            data = remove_dupplicated_node_configs(data)
            # Build the sibling directory from path components instead of a
            # string replace, which would rewrite any "valid" substring found
            # anywhere in the path.
            dedup_dst_dir = dst_dir / "valid_dedup"
            dedup_dst_dir.mkdir(parents=True, exist_ok=True)
            np.savez(dedup_dst_dir / npz_path.name, **data)

Loading valid data...


  0%|          | 0/7 [00:00<?, ?it/s]

/home/edu/code/google_fast_or_slow/data/npz_all/npz/layout/xla/default/valid/unet_3d.4x4.bf16.npz
Pruning graph...
Original graph has 3163 nodes and 5112 edges
New graph has 223 nodes and 181 edges
Removing 496 duplicated node configs out of 1965
/home/edu/code/google_fast_or_slow/data/npz_all/npz/layout/xla/default/valid/mlperf_bert_batch_24_2x2.npz
Pruning graph...
Original graph has 19541 nodes and 30520 edges
New graph has 4807 nodes and 4498 edges
Removing 547 duplicated node configs out of 6048


 29%|██▊       | 2/7 [00:03<00:07,  1.56s/it]

/home/edu/code/google_fast_or_slow/data/npz_all/npz/layout/xla/default/valid/inception_v3_batch_128_train.npz
Pruning graph...
Original graph has 12067 nodes and 21319 edges
New graph has 1059 nodes and 1084 edges


 43%|████▎     | 3/7 [00:03<00:04,  1.16s/it]

Removing 670 duplicated node configs out of 5984
/home/edu/code/google_fast_or_slow/data/npz_all/npz/layout/xla/default/valid/resnet50.4x4.fp16.npz
Pruning graph...
Original graph has 5673 nodes and 9099 edges
New graph has 493 nodes and 495 edges


 57%|█████▋    | 4/7 [00:04<00:02,  1.14it/s]

Removing 602 duplicated node configs out of 6120
/home/edu/code/google_fast_or_slow/data/npz_all/npz/layout/xla/default/valid/resnet_v1_50_official_batch_128_bf16.npz
Pruning graph...
Original graph has 6135 nodes and 10670 edges
New graph has 488 nodes and 493 edges


 71%|███████▏  | 5/7 [00:04<00:01,  1.30it/s]

Removing 764 duplicated node configs out of 8512
/home/edu/code/google_fast_or_slow/data/npz_all/npz/layout/xla/default/valid/tf2_bert_pretrain_dynamic_batch_size.npz
Pruning graph...
Original graph has 21664 nodes and 38485 edges
New graph has 3889 nodes and 3480 edges
Removing 1821 duplicated node configs out of 18160


 86%|████████▌ | 6/7 [00:11<00:02,  2.74s/it]

/home/edu/code/google_fast_or_slow/data/npz_all/npz/layout/xla/default/valid/bert_pretraining.4x4.fp16.npz
Pruning graph...
Original graph has 21335 nodes and 37236 edges
New graph has 3538 nodes and 3283 edges
Removing 1934 duplicated node configs out of 19232


100%|██████████| 7/7 [00:18<00:00,  2.62s/it]


In [None]:
# Scratch cell: display the source directory left over from the last split
# iterated by the processing loop.
split_src_dir

In [5]:
# Load one validation graph of the "default" collection for manual inspection.
# NOTE(review): allow_pickle=True lets a crafted file run arbitrary code on
# load; only use it on trusted data.
# The context manager closes the archive once all arrays have been read.
with np.load("/home/edu/code/google_fast_or_slow/data/npz_all/npz/layout/xla/default/valid/bert_pretraining.4x4.fp16.npz", allow_pickle=True) as npz_file:
    data_default = dict(npz_file)

In [4]:
# Load the same validation graph from the "random" collection for manual inspection.
# NOTE(review): allow_pickle=True lets a crafted file run arbitrary code on
# load; only use it on trusted data.
# The context manager closes the archive once all arrays have been read.
with np.load("/home/edu/code/google_fast_or_slow/data/npz_all/npz/layout/xla/random/valid/bert_pretraining.4x4.fp16.npz", allow_pickle=True) as npz_file:
    data_random = dict(npz_file)