In [1]:
import numpy as np
import seaborn as sns
from copy import deepcopy
import os
from pathlib import Path
from tqdm.auto import tqdm
from tqdm.contrib.concurrent import process_map

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Root directory of the source .npz dataset (absolute, machine-specific path).
root = Path("/home/edu/code/google_fast_or_slow/data/npz_pad")
# Sub-directory names joined onto `root` when the per-split source and
# destination paths are built in the pruning cell.
collection = "layout/xla"
ctype = "random"

In [3]:
def vec_to_int(vec: np.ndarray) -> np.ndarray:
    """Pack groups of six base-7 digits into single integers.

    The last axis of ``vec`` holds six digits, least-significant first; digit
    ``i`` is weighted by ``7**i`` (1, 7, 49, 343, 2401, 16807).

    Args:
        vec: array whose last axis has length 6.

    Returns:
        The weighted sums cast to ``np.int32``.
    """
    place_values = 7 ** np.arange(6)
    packed = vec @ place_values
    return packed.astype(np.int32)


def int_to_vec(integers: np.ndarray) -> np.ndarray:
    """Expand integers into their six least-significant base-7 digits.

    Inverse of ``vec_to_int`` for values in ``[0, 7**6)``.

    Unlike the previous implementation, the input array is left untouched:
    the old loop used ``integers //= 7``, which divided the caller's array in
    place and left it zeroed after the call.

    Args:
        integers: 1-D array of integers.

    Returns:
        ``(len(integers), 6)`` array of ``np.int32`` digits, least-significant
        digit first.
    """
    # np.array(...) always copies, so the caller's buffer is never modified.
    values = np.array(integers, dtype=np.int64)
    place_values = 7 ** np.arange(6, dtype=np.int64)
    # floor(x / 7**i) % 7 is digit i; broadcasting yields all six at once.
    digits = (values[:, None] // place_values[None, :]) % 7
    return digits.astype(np.int32)


def compress_configs(node_configs):
    """Pack every run of six config values into one base-7 integer.

    The array is flattened into rows of six, shifted by +1 and encoded with
    ``vec_to_int``. The result is reshaped to
    ``(node_configs.shape[0], node_configs.shape[1], 3)``, which implies 18
    values per (config, node) pair.

    NOTE(review): the +1 shift presumably maps a -1 padding value to digit 0 so
    all digits land in 0..6 — confirm the input value range.
    """
    n_configs, n_nodes = node_configs.shape[:2]
    shifted_digits = node_configs.reshape(-1, 6).astype(np.int32) + 1
    return vec_to_int(shifted_digits).reshape(n_configs, n_nodes, 3)


def decompress_configs(node_configs):
    """Unpack base-7 encoded configs back into 18 values per node.

    Every integer is expanded into six digits with ``int_to_vec``. The digits
    are regrouped as ``(node_configs.shape[0], -1, 18)`` and shifted by -1,
    undoing the +1 applied in ``compress_configs``.
    """
    n_configs = node_configs.shape[0]
    flat_codes = node_configs.astype(np.int32).reshape(-1)
    digits = int_to_vec(flat_codes)
    return digits.reshape(n_configs, -1, 18) - 1
    
def prune_graph(data):
    """Keep only the edges touching a configurable node, and the nodes on them.

    An edge survives when at least one endpoint is listed in
    ``data["node_config_ids"]``. The surviving nodes are renumbered
    ``0..k-1`` in ascending order of their old id, and ``node_feat``,
    ``node_opcode``, ``edge_index`` and ``node_config_ids`` are rewritten
    accordingly. All other entries are deep-copied unchanged.

    The remapped ``edge_index`` and ``node_config_ids`` are returned as
    ``int64``. The previous version built the lookup table with
    ``np.ones(...) * -1`` (float64), so both index arrays came back as floats
    and every consumer had to cast them before indexing.

    Args:
        data: mapping with at least ``node_feat`` (nodes on axis 0),
            ``node_opcode``, ``edge_index`` (one edge per row) and
            ``node_config_ids``.

    Returns:
        A new dict; ``data`` itself is not modified.

    Raises:
        AssertionError: if a configurable node appears on no surviving edge
            (i.e. it has no edges at all).
    """
    print("Pruning graph...")
    new_data = deepcopy(dict(data))
    num_nodes = data["node_feat"].shape[0]
    # Cast up front so inputs that were stored as floats can still be used as indices.
    edge_index = np.asarray(data["edge_index"]).astype(np.int64)
    config_ids = np.asarray(data["node_config_ids"]).astype(np.int64)
    print("Original graph has {} nodes and {} edges".format(num_nodes, edge_index.shape[0]))

    # Rows of edge_index where either endpoint is a configurable node.
    in_edge_index = edge_index[np.isin(edge_index, config_ids).any(1)]

    in_node_ids = np.unique(in_edge_index)
    assert np.isin(config_ids, in_node_ids).all(), "some configurable nodes have no edges"

    # old node id -> new node id; -1 marks nodes that were dropped.
    lookup = np.full(num_nodes, -1, dtype=np.int64)
    lookup[in_node_ids] = np.arange(in_node_ids.shape[0])

    new_data["node_feat"] = data["node_feat"][in_node_ids, :]
    new_data["node_opcode"] = data["node_opcode"][in_node_ids]
    new_data["edge_index"] = lookup[in_edge_index]
    new_data["node_config_ids"] = lookup[config_ids]
    print("New graph has {} nodes and {} edges".format(new_data["node_feat"].shape[0], new_data["edge_index"].shape[0]))
    return new_data


In [4]:
def remove_dupplicated_node_configs(data):
    """Drop configurations whose ``node_config_feat`` repeats an earlier one.

    Each configuration (axis 0 of ``data["node_config_feat"]``) is flattened
    to a row and compared exactly; the first occurrence of every distinct row
    is kept and later copies are removed, together with the matching entries
    of ``data["config_runtime"]``.

    This replaces the earlier random-projection hash plus N x N equality
    matrix. That approach needed O(N^2) memory, could drop a non-duplicate on
    a hash collision, and consumed global ``np.random`` state.

    Args:
        data: dict with ``node_config_feat`` (configs on axis 0) and
            ``config_runtime`` (one entry per config). Modified in place.

    Returns:
        The same ``data`` dict, for call chaining.
    """
    configs = data["node_config_feat"]
    num_configs = configs.shape[0]

    if num_configs < 2:
        # Nothing can be a duplicate; also avoids reshaping an empty array.
        to_remove_ids = np.empty(0, dtype=np.intp)
    else:
        flat_configs = configs.reshape(num_configs, -1)
        # return_index gives the first occurrence of every distinct row.
        _, first_ids = np.unique(flat_configs, axis=0, return_index=True)
        keep_mask = np.zeros(num_configs, dtype=bool)
        keep_mask[first_ids] = True
        to_remove_ids = np.where(~keep_mask)[0]

    print("Removing {} duplicated node configs out of {}".format(to_remove_ids.shape[0], num_configs))
    data["config_runtime"] = np.delete(data["config_runtime"], to_remove_ids)
    data["node_config_feat"] = np.delete(configs, to_remove_ids, axis=0)
    return data

In [5]:
# Output tree: "<root>/<collection>_pruned/<ctype>/<split>/<same file name>.npz".
dst_dir = root / f"{collection}_pruned" / ctype

for split in ["train", "valid", "test"]:
    print("Loading {} data...".format(split))
    split_src_dir = root / collection / ctype / split
    split_dst_dir = dst_dir / split
    split_dst_dir.mkdir(parents=True, exist_ok=True)

    # Redefined on every iteration. It reads `split` and `split_dst_dir` from
    # the surrounding module-level scope instead of receiving them as arguments.
    def _process_one_npz(npz_path):
        # Load one archive, prune its graph, de-duplicate configs (train split
        # only), base-7 compress node_config_feat and write the result under the
        # same file name in the destination split directory.
        # NOTE(review): allow_pickle=True lets np.load unpickle object arrays —
        # only safe for trusted files.
        data = dict(np.load(str(npz_path), allow_pickle=True))
        data = prune_graph(data)
        if split == "train":
            data = remove_dupplicated_node_configs(data)
        data["node_config_feat"] = compress_configs(data["node_config_feat"])
        np.savez_compressed(split_dst_dir / npz_path.name, **data)
    
    # Map the worker over every .npz in the split with a progress bar, using at
    # most 3 workers.
    # NOTE(review): process_map presumably runs the worker in separate
    # processes, which would explain the interleaved log lines in the cell
    # output and means the worker must be resolvable from those processes —
    # confirm on platforms that do not fork.
    process_map(_process_one_npz, list(split_src_dir.glob("*.npz")), max_workers=3)

    # for npz_path in tqdm(list(split_src_dir.glob("*.npz"))):
        # print(npz_path)
        # data = dict(np.load(str(npz_path), allow_pickle=True))
        # data = prune_graph(data)
        # if split == "train":
        #     data = remove_dupplicated_node_configs(data)
        # data["node_config_feat"] = compress_configs(data["node_config_feat"])
        # np.savez_compressed(split_dst_dir / npz_path.name, **data)
        # if split == "valid":
        #     data = remove_dupplicated_node_configs(data)
        #     dedup_dst_dir = Path(str(split_dst_dir).replace("valid", "valid_dedup"))
        #     dedup_dst_dir.mkdir(parents=True, exist_ok=True)
        #     np.savez(dedup_dst_dir / npz_path.name, **data)

Loading train data...


  0%|          | 0/69 [00:00<?, ?it/s]

Pruning graph...
Original graph has 13502 nodes and 22209 edges
New graph has 2143 nodes and 2137 edges
Pruning graph...
Original graph has 490 nodes and 729 edges
New graph has 65 nodes and 56 edges
Removing 39 duplicated node configs out of 1632
Pruning graph...
Original graph has 4185 nodes and 7209 edges
New graph has 224 nodes and 200 edges


  1%|▏         | 1/69 [00:00<00:32,  2.07it/s]

Pruning graph...
Original graph has 6328 nodes and 12014 edges
New graph has 426 nodes and 375 edges
Removing 459 duplicated node configs out of 100040
Pruning graph...
Original graph has 13342 nodes and 21709 edges
New graph has 1820 nodes and 1622 edges
Removing 39 duplicated node configs out of 1152
Pruning graph...
Original graph has 5809 nodes and 10345 edges
New graph has 594 nodes and 599 edges
Removing 39 duplicated node configs out of 7208
Pruning graph...
Original graph has 15642 nodes and 25387 edges
New graph has 3878 nodes and 3394 edges
Removing 19 duplicated node configs out of 20
Pruning graph...
Original graph has 8847 nodes and 14797 edges
New graph has 811 nodes and 806 edges
Removing 39 duplicated node configs out of 7440
Pruning graph...
Original graph has 17472 nodes and 27130 edges
New graph has 4211 nodes and 4097 edges
Removing 39 duplicated node configs out of 156
Pruning graph...
Original graph has 40332 nodes and 71912 edges
New graph has 7304 nodes and 6804

  3%|▎         | 2/69 [00:30<20:10, 18.07s/it]

Removing 40 duplicated node configs out of 9552
Pruning graph...
Original graph has 19662 nodes and 35460 edges
New graph has 2085 nodes and 2100 edges
Pruning graph...
Original graph has 13867 nodes and 22162 edges
New graph has 995 nodes and 1045 edges
Removing 39 duplicated node configs out of 1155
Removing 39 duplicated node configs out of 5502


 16%|█▌        | 11/69 [00:31<02:06,  2.19s/it]

Pruning graph...


 20%|██        | 14/69 [00:32<01:26,  1.57s/it]

Pruning graph...
Original graph has 24790 nodes and 32709 edges
New graph has 6411 nodes and 4774 edges
Original graph has 7324 nodes and 12087 edges
New graph has 1067 nodes and 1230 edges
Removing 39 duplicated node configs out of 3718
Pruning graph...
Removing 39 duplicated node configs out of 22632
Original graph has 21196 nodes and 37779 edges
New graph has 3848 nodes and 3395 edges


 23%|██▎       | 16/69 [00:33<01:15,  1.43s/it]

Pruning graph...
Original graph has 5345 nodes and 8775 edges
New graph has 599 nodes and 601 edges
Removing 42 duplicated node configs out of 10664


 25%|██▍       | 17/69 [00:34<01:09,  1.34s/it]

Pruning graph...
Original graph has 1111 nodes and 1557 edges
New graph has 215 nodes and 164 edges
Pruning graph...
Original graph has 8836 nodes and 15567 edges
New graph has 768 nodes and 782 edges
Removing 39 duplicated node configs out of 5744
Pruning graph...
Original graph has 5605 nodes and 9018 edges
New graph has 476 nodes and 483 edges
Removing 41 duplicated node configs out of 9536
Pruning graph...
Original graph has 5162 nodes and 9160 edges
New graph has 535 nodes and 550 edges
Removing 39 duplicated node configs out of 17592Removing 39 duplicated node configs out of 8368

Pruning graph...
Original graph has 650 nodes and 1100 edges
New graph has 180 nodes and 139 edges
Removing 26983 duplicated node configs out of 31584
Pruning graph...
Original graph has 5605 nodes and 9018 edges
New graph has 476 nodes and 483 edges
Removing 39 duplicated node configs out of 7560
Pruning graph...


 26%|██▌       | 18/69 [00:39<01:39,  1.94s/it]

Original graph has 21338 nodes and 37243 edges
Pruning graph...
New graph has 3538 nodes and 3283 edges
Original graph has 4748 nodes and 8547 edges
New graph has 766 nodes and 834 edges
Removing 39 duplicated node configs out of 15112
Pruning graph...
Original graph has 17135 nodes and 26568 edges
New graph has 4211 nodes and 4097 edges
Removing 39 duplicated node configs out of 148
Pruning graph...
Original graph has 15022 nodes and 27044 edges
New graph has 1589 nodes and 1604 edges
Removing 39 duplicated node configs out of 18776
Removing 39 duplicated node configs out of 6792
Pruning graph...
Original graph has 20665 nodes and 36659 edges
New graph has 3633 nodes and 3380 edges
Pruning graph...
Removing 39 duplicated node configs out of 22944
Original graph has 12062 nodes and 21268 edges
New graph has 2424 nodes and 2518 edges
Pruning graph...
Original graph has 19348 nodes and 30351 edges
New graph has 1756 nodes and 1767 edges
Removing 39 duplicated node configs out of 2374
Pru

 35%|███▍      | 24/69 [00:56<01:48,  2.41s/it]

Pruning graph...
Original graph has 1111 nodes and 1557 edges
New graph has 215 nodes and 164 edges
Pruning graph...
Original graph has 24793 nodes and 32713 edges
New graph has 6417 nodes and 4779 edges
Removing 457 duplicated node configs out of 21736
Pruning graph...


 45%|████▍     | 31/69 [00:58<00:51,  1.36s/it]

Pruning graph...
Original graph has 22385 nodes and 39976 edges
New graph has 4023 nodes and 3532 edges
Original graph has 10092 nodes and 18102 edges
New graph has 1062 nodes and 1077 edges
Removing 39 duplicated node configs out of 11240
Removing 39 duplicated node configs out of 7040
Pruning graph...
Original graph has 1277 nodes and 2063 edges
New graph has 189 nodes and 172 edges
Removing 19 duplicated node configs out of 20
Pruning graph...
Original graph has 43615 nodes and 73881 edges
New graph has 3142 nodes and 3089 edges
Removing 39 duplicated node configs out of 1871
Removing 39 duplicated node configs out of 17088
Pruning graph...
Original graph has 14680 nodes and 23604 edges
New graph has 2964 nodes and 2521 edges
Pruning graph...
Original graph has 1383 nodes and 2167 edges
New graph has 209 nodes and 188 edges
Removing 19 duplicated node configs out of 20
Pruning graph...
Original graph has 16818 nodes and 26250 edges
New graph has 4099 nodes and 4017 edges
Removing 19

 51%|█████     | 35/69 [01:02<00:43,  1.28s/it]

Pruning graph...
Original graph has 15022 nodes and 27044 edges
New graph has 1589 nodes and 1604 edges
Pruning graph...
Removing 39 duplicated node configs out of 6080
Original graph has 8636 nodes and 14661 edges
New graph has 1188 nodes and 1086 edges
Pruning graph...
Original graph has 10449 nodes and 16824 edges
New graph has 1024 nodes and 1084 edges
Removing 39 duplicated node configs out of 2118
Pruning graph...
Original graph has 9257 nodes and 15281 edges
New graph has 1188 nodes and 1086 edges
Removing 39 duplicated node configs out of 11880
Removing 39 duplicated node configs out of 11296
Pruning graph...
Original graph has 2774 nodes and 4730 edges
New graph has 168 nodes and 168 edges
Removing 39 duplicated node configs out of 380
Pruning graph...
Pruning graph...
Original graph has 5689 nodes and 9131 edges
New graph has 476 nodes and 483 edges
Original graph has 15748 nodes and 23128 edges
New graph has 3829 nodes and 3095 edges


 59%|█████▉    | 41/69 [01:04<00:25,  1.08it/s]

Pruning graph...
Original graph has 5670 nodes and 10166 edges
New graph has 933 nodes and 1030 edges
Removing 39 duplicated node configs out of 2239
Removing 39 duplicated node configs out of 2887
Removing 39 duplicated node configs out of 7464
Pruning graph...
Original graph has 15809 nodes and 25564 edges
New graph has 3884 nodes and 3400 edges
Removing 19 duplicated node configs out of 20


 71%|███████   | 49/69 [01:04<00:11,  1.79it/s]

Pruning graph...
Original graph has 12069 nodes and 21320 edges
New graph has 1059 nodes and 1084 edges
Removing 54 duplicated node configs out of 7088
Pruning graph...
Pruning graph...
Original graph has 7324 nodes and 12087 edges
New graph has 1067 nodes and 1230 edges
Original graph has 39452 nodes and 70171 edges
New graph has 6990 nodes and 6516 edges
Pruning graph...
Original graph has 5673 nodes and 9099 edges
New graph has 493 nodes and 495 edges
Removing 39 duplicated node configs out of 7384
Pruning graph...
Original graph has 13867 nodes and 22162 edges
New graph has 995 nodes and 1045 edges
Removing 39 duplicated node configs out of 1489
Removing 39 duplicated node configs out of 9760
Removing 39 duplicated node configs out of 22992
Pruning graph...


 78%|███████▊  | 54/69 [01:08<00:09,  1.60it/s]

Pruning graph...
Original graph has 5279 nodes and 8694 edges
New graph has 582 nodes and 589 edges
Removing 39 duplicated node configs out of 9600
Original graph has 25544 nodes and 33522 edges
New graph has 9294 nodes and 6926 edges


 80%|███████▉  | 55/69 [01:09<00:09,  1.55it/s]

Pruning graph...
Original graph has 372 nodes and 597 edges
New graph has 73 nodes and 73 edges
Pruning graph...
Original graph has 21126 nodes and 37368 edges
New graph has 3698 nodes and 3437 edges
Removing 39 duplicated node configs out of 10480
Removing 39 duplicated node configs out of 18368
Removing 191 duplicated node configs out of 29144
Pruning graph...
Original graph has 7768 nodes and 13121 edges
New graph has 630 nodes and 633 edges
Removing 39 duplicated node configs out of 8376
Pruning graph...
Original graph has 5358 nodes and 9631 edges
New graph has 933 nodes and 1030 edges
Removing 39 duplicated node configs out of 2994
Pruning graph...
Original graph has 21335 nodes and 37236 edges
New graph has 3538 nodes and 3283 edges
Pruning graph...
Original graph has 26234 nodes and 48094 edges
New graph has 5529 nodes and 5136 edges
Removing 39 duplicated node configs out of 15704
Removing 39 duplicated node configs out of 19128


 86%|████████▌ | 59/69 [01:17<00:10,  1.02s/it]

Pruning graph...
Original graph has 26906 nodes and 48823 edges
New graph has 5461 nodes and 5066 edges


 94%|█████████▍| 65/69 [01:20<00:03,  1.27it/s]

Pruning graph...
Original graph has 19662 nodes and 35460 edges
New graph has 2085 nodes and 2100 edges
Removing 40 duplicated node configs out of 6015
Pruning graph...
Original graph has 8836 nodes and 15567 edges
New graph has 768 nodes and 782 edges
Removing 51 duplicated node configs out of 7584
Removing 39 duplicated node configs out of 12720


100%|██████████| 69/69 [01:23<00:00,  1.21s/it]

Loading valid data...



  0%|          | 0/7 [00:00<?, ?it/s]

Pruning graph...
Original graph has 3163 nodes and 5112 edges
New graph has 223 nodes and 181 edges
Pruning graph...
Pruning graph...
Original graph has 5673 nodes and 9099 edges
New graph has 493 nodes and 495 edges
Original graph has 12067 nodes and 21319 edges
New graph has 1059 nodes and 1084 edges
Pruning graph...
Original graph has 6135 nodes and 10670 edges
New graph has 488 nodes and 493 edges
Pruning graph...
Original graph has 19541 nodes and 30520 edges
New graph has 4807 nodes and 4498 edges
Pruning graph...
Original graph has 21335 nodes and 37236 edges
New graph has 3538 nodes and 3283 edges
Pruning graph...


 29%|██▊       | 2/7 [00:01<00:04,  1.02it/s]

Original graph has 21664 nodes and 38485 edges
New graph has 3889 nodes and 3480 edges


100%|██████████| 7/7 [00:04<00:00,  1.54it/s]

Loading test data...



  0%|          | 0/8 [00:00<?, ?it/s]

Pruning graph...
Pruning graph...
Original graph has 5279 nodes and 8694 edges
Original graph has 5810 nodes and 10345 edges
New graph has 582 nodes and 589 edges
New graph has 594 nodes and 599 edges
Pruning graph...
Original graph has 490 nodes and 749 edges
New graph has 118 nodes and 116 edges
Pruning graph...
Original graph has 43615 nodes and 73881 edges
New graph has 3142 nodes and 3089 edges
Pruning graph...
Pruning graph...
Original graph has 23363 nodes and 38984 edges
New graph has 6140 nodes and 6482 edges
Original graph has 24790 nodes and 32709 edges
New graph has 6411 nodes and 4774 edges
Pruning graph...
Original graph has 41522 nodes and 72902 edges
New graph has 6959 nodes and 6466 edges


 25%|██▌       | 2/8 [00:00<00:01,  3.83it/s]

Pruning graph...
Original graph has 43615 nodes and 73881 edges
New graph has 3142 nodes and 3089 edges


100%|██████████| 8/8 [00:00<00:00, 11.33it/s]


In [None]:
import pandas as pd
# Summarise, per training file, the gaps between consecutive sorted runtimes.
# Rebinds the module-level `root`, `collection` and `ctype` set earlier.
root = Path("/home/edu/code/google_fast_or_slow/data/npz_all/npz")
collection = "layout/xla_pruned"
ctype = "default"

split_src_dir = root / collection / ctype / "train"
npz_paths = list(split_src_dir.glob("*.npz"))

df = []  # one stats dict per file; turned into a DataFrame after the loop
for path in tqdm(npz_paths):
    data = dict(np.load(str(path), allow_pickle=True))
    times = data['config_runtime']
    if len(times) < 2:
        # A gap needs at least two runtimes.
        continue
    times = np.sort(times)
    deltas = times[1:] - times[:-1]
    p10, p90 = np.percentile(deltas, [10, 90])
    stats = {
        "file": path.stem,
        "median": np.median(deltas),
        "mean": np.mean(deltas),
        "std": np.std(deltas),
        "min": np.min(deltas),
        "max": np.max(deltas),
        "p10": p10,
        "p90": p90,
    }
    df.append(stats)
    # ids = np.where(data["node_opcode"][data["node_config_ids"].astype(int)] == 75)
    # if ids:
    #     ids = ids[0]
    #     data["node_config_feat"] = decompress_configs(data["node_config_feat"])
    #     counter.extend(data["node_config_feat"][:, ids][:, :, [2, 3, 4, 5, 8, 9, 10, 11, 14, 15, 16, 17]].reshape(-1))
df_default = pd.DataFrame(df)

In [None]:
# Column-wise mean of every column after the first.
# NOTE(review): `df_random` is not assigned in any visible cell — presumably
# built by running the stats cell with ctype = "random" and a different target
# name; this cell fails on a fresh kernel.
df_random.iloc[:, 1:].mean()

In [None]:
# Column-wise mean of the per-file delta statistics; column 0 ("file") is skipped.
df_default.iloc[:, 1:].mean()

In [None]:
# Shows `stats`, which after the stats loop holds only the LAST file's dict.
# NOTE(review): the "random" label presumably refers to a run with ctype = "random" — confirm.
stats # random

In [None]:
# Shows `stats`, which after the stats loop holds only the LAST file's dict.
# NOTE(review): same expression as the "random" cell; the value depends on
# which ctype the stats cell was last run with.
stats # default

In [None]:
# Distinct values in `counter` and how often each occurs.
# NOTE(review): `counter` is never assigned in the visible cells; it is only
# mentioned in the commented-out block of the stats loop. Fails on a fresh kernel.
np.unique(counter, return_counts=True)

In [None]:
# Load one training archive from the un-pruned source tree into a plain dict
# for manual inspection (rebinds the module-level `data`).
data = dict(np.load("/home/edu/code/google_fast_or_slow/data/npz_pad/layout/xla/random/train/resnet_v2_50_batch_16.npz"))

In [None]:
# Expand node_config_feat in place inside `data`. Not idempotent: re-running
# this cell decompresses the already-decompressed array again.
# NOTE(review): the file loaded just before this comes from the source
# directory, and compress_configs is only applied when writing the "_pruned"
# outputs — confirm this array is actually in compressed form.
data["node_config_feat"] = decompress_configs(data["node_config_feat"])

In [None]:
# For every configuration (axis 0), take the feature row at index 0 on axis 1
# and store its string form so identical rows can be counted.
seqs = []
for i in data["node_config_feat"][:, 0, :]:
    seqs.append(str(i))

In [None]:
# Distinct stringified rows collected in `seqs` and their frequencies.
np.unique(seqs, return_counts=True)

In [None]:
# Opcode stored at index 497 (the choice of 497 is not explained in the notebook).
data["node_opcode"][497]

In [None]:
# (position, value) pairs of the feature row at index 497, for reading features by position.
list(enumerate(data["node_feat"][497, :]))

In [None]:
# NOTE(review): exact duplicate of the previous cell — candidate for removal.
list(enumerate(data["node_feat"][497, :]))

In [None]:
# First entry of node_config_ids.
data["node_config_ids"][0]

In [None]:
# Opcodes of the nodes referenced by the first 450 entries of node_config_ids;
# the ids are cast to int before being used as an index.
data["node_opcode"][data["node_config_ids"][0:450].astype(int)]

In [None]:
# Distinct opcodes among all nodes referenced by node_config_ids, with their counts.
np.unique(data["node_opcode"][data["node_config_ids"].astype(int)], return_counts=True)

In [None]:
# Shape of the (possibly decompressed) config feature array currently held in `data`.
data["node_config_feat"].shape

In [None]:
def inspect_node(idx):
    """Print what the module-level ``data`` dict holds at position ``idx``.

    Shows, in order, ``idx`` itself and the entries at ``idx`` of
    ``node_opcode``, ``node_config_ids``, ``node_config_feat`` and
    ``node_feat``.

    NOTE(review): the same ``idx`` indexes all four arrays; whether
    ``node_config_ids`` / ``node_config_feat`` share the node index space is not
    visible here — confirm before relying on those two lines.
    """
    def _show(template, value):
        print(template.format(value))

    _show("Node idx: {}", idx)
    _show("Node opcode: {}", data["node_opcode"][idx])
    _show("Node config id: {}", data["node_config_ids"][idx])
    _show("Node config feat: {}", data["node_config_feat"][idx])
    _show("Node feat: {}", data["node_feat"][idx])

In [8]:
import torch
# Load the saved state dict of a training checkpoint for inspection.
# NOTE(review): the displayed tensors live on cuda:0 and no map_location is
# given, so this presumably needs a CUDA device to load — confirm. torch.load
# unpickles the file, so only use it on trusted checkpoints.
x = torch.load("/home/edu/code/google_fast_or_slow/outputs_xla_default/checkpoint-12000/pytorch_model.bin")

In [9]:
# Parameter names stored in the loaded checkpoint.
x.keys()

odict_keys(['embedding_op.weight', 'embedding_layout_cfg.weight', 'linear.linear.weight', 'linear.linear.bias', 'linear.cross_attn.temperature', 'linear.attn.linear1.weight', 'linear.attn.linear1.bias', 'linear.attn.linear2.weight', 'linear.attn.linear2.bias', 'convs.0.conv.lin_l.weight', 'convs.0.conv.lin_l.bias', 'convs.0.conv.lin_r.weight', 'convs.0.attn.linear1.weight', 'convs.0.attn.linear1.bias', 'convs.0.attn.linear2.weight', 'convs.0.attn.linear2.bias', 'convs.0.cross_attn.temperature', 'convs.1.conv.lin_l.weight', 'convs.1.conv.lin_l.bias', 'convs.1.conv.lin_r.weight', 'convs.1.attn.linear1.weight', 'convs.1.attn.linear1.bias', 'convs.1.attn.linear2.weight', 'convs.1.attn.linear2.bias', 'convs.1.cross_attn.temperature', 'classifier.weight', 'classifier.bias'])

In [14]:
# (name, tensor) pairs for every checkpoint entry whose name mentions the
# cross-attention temperature.
[(name, tensor) for name, tensor in x.items() if "cross_attn.temperature" in name]

[('linear.cross_attn.temperature', tensor(0.6401, device='cuda:0')),
 ('convs.0.cross_attn.temperature', tensor(0.3068, device='cuda:0')),
 ('convs.1.cross_attn.temperature', tensor(0.3477, device='cuda:0'))]

In [6]:
import numpy as np
# Scratch check: one unseeded draw, so the displayed value is not reproducible.
np.random.randint(1, 5)

2

In [1]:
import pickle
# Load the pickled scaler saved for fold 0 of the xla/random run.
# NOTE(review): pickle.load executes arbitrary code from the file — only use
# on trusted artifacts. This also rebinds `data`, which earlier cells use for
# the dict of npz arrays.
with open("/home/edu/code/google_fast_or_slow/outputs_xla_random/fold_0/scaler.pkl", "rb") as f:
    data = pickle.load(f)