In [1]:
import gc
import itertools
import os
import pickle
import random
import warnings
from collections.abc import Iterator
from copy import deepcopy
from dataclasses import asdict
from dataclasses import dataclass, field
from pathlib import Path

import numpy as np
import pandas as pd
import seaborn as sns
import torch
import wandb
from matplotlib import pyplot as plt
from torch import nn
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.data import Dataset
from torch_geometric.nn import GCNConv, Sequential
from tqdm.notebook import tqdm

# Notebook-wide display / logging settings.
sns.set_theme()  # seaborn's default theme for all matplotlib figures

# Ignore every warning category for the rest of the session.
warnings.simplefilter(action="ignore")

## データセットを準備


In [2]:
# Repository root: two directories above the notebook's working directory.
rootdir = Path().resolve().parent.parent

# Raw competition data and the three pre-computed feature datasets.
inputdir = rootdir.joinpath("data", "predict-ai-model-runtime")
node_feat_dir = rootdir.joinpath("data", "google-slow-vs-fast-layout-7-85")
trans_node_feat_dir = rootdir.joinpath("data", "google-slow-vs-fastlayout6-92-dataset")
trans_node_config_feat_dir = rootdir.joinpath(
    "data", "google-slow-vs-fastlayout7-81-dataset"
)

# Output directory next to the notebook (created if missing).
workdir = Path().resolve().joinpath("out")
workdir.mkdir(parents=True, exist_ok=True)

In [3]:
# Build one file-index DataFrame per split. Each row points at the raw layout
# .npz plus the three pre-computed feature files for the same graph.
dataset_dict = {}
# Raw files skipped for train/valid (names containing "mlperf" or "openai").
ignores = []
for ds in ["train", "valid", "test"]:
    records = []
    for arch, perm in itertools.product(["nlp", "xla"], ["default", "random"]):
        datadir = inputdir / "npz_all" / "npz" / "layout" / arch / perm / ds
        for filepath in sorted(datadir.glob("*.npz")):
            # Path.stem works with any path separator and strips only the final
            # suffix (splitting on "/" and replacing ".npz" did neither).
            filename = filepath.stem

            if ds != "test" and ("mlperf" in filename or "openai" in filename):
                ignores.append(filepath)
                continue
            records.append(
                {
                    "arch": arch,
                    "perm": perm,
                    "filename": filename,
                    "filepath": filepath,
                    "node_feat_filepath": str(
                        node_feat_dir / arch / perm / ds / f"{filename}.npz"
                    ),
                    "trans_node_feat_filepath": str(
                        trans_node_feat_dir
                        / "layout"
                        / arch
                        / perm
                        / ds
                        / f"{filename}.npz"
                    ),
                    "trans_node_config_filepath": str(
                        trans_node_config_feat_dir
                        / arch
                        / perm
                        / ds
                        / f"{filename}.npz"
                    ),
                }
            )
    # Explicit columns so a split with no files still has the expected schema.
    dataset_dict[ds] = pd.DataFrame(
        records,
        columns=[
            "arch",
            "perm",
            "filename",
            "filepath",
            "node_feat_filepath",
            "trans_node_feat_filepath",
            "trans_node_config_filepath",
        ],
    )

In [4]:
# for filepath in tqdm(ignores):
#     node_config_feat = np.load(filepath)["node_config_feat"]

#     for i in range(1, node_config_feat.shape[0]):
#         if not (node_config_feat[0] == node_config_feat[i]).all():
#             filepath
#             break

In [5]:
# Categorical node-feature groups: group 0 is 1 column with 19 categories,
# group 1 is 54 + 14 = 68 columns sharing 6 categories.
dfcat = pd.DataFrame(
    {
        "number": [0, 1],
        "num_dims": [1, 54 + 14],
        "num_cats": [19, 6],
        "cats": [list(range(19)), list(range(6))],
    }
)
dfcat.head()

Unnamed: 0,number,num_dims,num_cats,cats
0,0,1,19,"[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,..."
1,1,68,6,"[0, 1, 2, 3, 4, 5]"


In [6]:
# Per-config categorical features: output_layout, input_layout and
# kernel_layout give 18 columns that share 8 categories.
dfcat_config = pd.DataFrame({"number": [0], "num_dims": [18], "num_cats": [8]})
dfcat_config

Unnamed: 0,number,num_dims,num_cats
0,0,18,8


In [7]:
# Sanity check: every .npz referenced by dataset_dict must be loadable.
# Each file is opened in a `with` block so its handle is closed immediately,
# and all failures are collected so one run reports every broken path.
npz_path_columns = (
    "filepath",
    "node_feat_filepath",
    "trans_node_feat_filepath",
    "trans_node_config_filepath",
)
unloadable_files = []
for ds, split_df in dataset_dict.items():
    for row in split_df.to_dict("records"):
        for column in npz_path_columns:
            try:
                with np.load(row[column]):
                    pass
            except (OSError, ValueError) as err:
                unloadable_files.append(f"[{ds}] {row[column]}: {err!r}")
if unloadable_files:
    raise RuntimeError(
        f"{len(unloadable_files)} .npz file(s) could not be loaded:\n"
        + "\n".join(unloadable_files)
    )

# データクラスを定義


In [22]:
@dataclass
class CatStatus:
    """Column layout of the categorical feature groups described by ``dfcat``.

    Each row of ``dfcat`` is one group occupying ``num_dims`` consecutive
    columns that share ``num_cats`` categories. Groups are keyed
    ``f"{prefix}cat_feat{label + 1}"``, where ``label`` is the row's index label.

    Attributes
    ----------
    num_cat_dict: dict[str, int]
        Group key -> number of categories.
    index_dict: dict[str, list[int]]
        Group key -> column indices of that group in the feature matrix.
    """

    dfcat: pd.DataFrame
    prefix: str
    num_cat_dict: dict[str, int] = field(init=False)
    index_dict: dict[str, list[int]] = field(init=False)

    def __post_init__(self) -> None:
        self.num_cat_dict = {}
        self.index_dict = {}
        offset = 0
        for label, group in self.dfcat.iterrows():
            key = f"{self.prefix}cat_feat{label + 1}"
            width = group["num_dims"]
            self.num_cat_dict[key] = group["num_cats"]
            # Columns [offset, offset + width) belong to this group.
            self.index_dict[key] = list(range(offset, offset + width))
            offset = offset + width


# Column layouts of the per-config ("config_" keys) and node-level categorical features.
cat_config_status = CatStatus(prefix="config_", dfcat=dfcat_config)
cat_status = CatStatus(prefix="", dfcat=dfcat)


@dataclass
class Const:
    """Feature widths inferred from the data, plus two fixed constants."""

    # Width of node_flag_feat as consumed by the model.
    num_node_flag_feat_dim: int
    # Width of the concatenated continuous node features.
    num_node_cont_feat_dim: int
    # Width of the concatenated categorical node features.
    num_node_cat_feat_dim: int
    # Width of the continuous per-config node features.
    num_node_config_cont_feat_dim: int

    # Number of operator ids (rows of the opcode embedding table).
    num_operations: int = field(default=120)
    # Dimensions per layout in the categorical config features (3 layouts x 6 = 18 columns).
    num_config_dims: int = field(default=6)


# Infer the feature widths from the first training graph. Each NpzFile is
# opened in a `with` block so its file handle is released once the arrays are
# read (previously the three handles stayed open via these module-level names).
with np.load(dataset_dict["train"].iloc[0]["node_feat_filepath"]) as fileobj:
    node_flag_feat, node_cont_feat = fileobj["node_flag_feat"], fileobj["node_cont_feat"]
    node_enum_feat, node_dimension_number_feat = (
        fileobj["node_enum_feat"],
        fileobj["node_dimension_number_feat"],
    )
with np.load(dataset_dict["train"].iloc[0]["trans_node_feat_filepath"]) as trans_fileobj:
    trans_node_cont_feat, trans_node_cat_feat = (
        trans_fileobj["node_cont_feat"],
        trans_fileobj["node_cat_feat"],
    )
with np.load(
    dataset_dict["train"].iloc[0]["trans_node_config_filepath"]
) as trans_config_fileobj:
    trans_node_config_cont_feat = trans_config_fileobj["node_config_cont_feat"]

const = Const(
    # +1 for the "has config" flag column that LayoutConfigs appends.
    num_node_flag_feat_dim=node_flag_feat.shape[1] + 1,
    num_node_cont_feat_dim=node_cont_feat.shape[1] + trans_node_cont_feat.shape[1],
    num_node_cat_feat_dim=node_enum_feat.shape[1]
    + node_dimension_number_feat.shape[1]
    + trans_node_cat_feat.shape[1],
    num_node_config_cont_feat_dim=trans_node_config_cont_feat.shape[2],
)


@dataclass
class NodeFeatExtractor:
    """Hyper-parameters of an MLP feature extractor (Linear + LeakyReLU per layer)."""

    # Output width of each Linear layer.
    dims: list[int] = field(default_factory=lambda: [64] * 2)
    # Negative slope of every LeakyReLU.
    leakyrelu_negative_slope: float = field(default=0.1)
    # Dropout probability (the model's Dropout layers are currently commented out).
    dropout_p: float = field(default=0.2)


@dataclass
class GNNExtractor:
    """Hyper-parameters of the EdgeConv stack.

    Attributes
    ----------
    dims: list[int]
        Output width of each EdgeConv layer.
    dropout_p: float
        Passed to each EdgeConv layer.
    leakyrelu_negative_slope: float
        Negative slope of the LeakyReLU after each EdgeConv layer.
    """

    dims: list[int] = field(default_factory=lambda: [64, 64])
    dropout_p: float = 0.2
    # Previously declared without a type annotation, which made it a plain class
    # attribute: it could not be set through the constructor and was missing
    # from dataclasses.asdict(). Declared last so the positional order
    # (dims, dropout_p) is unchanged.
    leakyrelu_negative_slope: float = 0.1


@dataclass
class CatEmbedding:
    """Size of one embedding table, used as ``nn.Embedding(num_cat, embedding_dim)``."""

    num_cat: int  # number of category ids (table rows)
    embedding_dim: int  # width of each embedding vector


@dataclass
class Params:
    """Experiment settings: device, model sizes and optimisation schedule.

    Attributes
    ----------
    device: str
        Torch device string, e.g. "cuda" or "cpu".
    cat_embeddings: dict[str, CatEmbedding]
        Embedding table size per categorical input name.
    random_batch_size: int
        Chunk size when iterating sampled configurations.
    batch_size: int
        Chunk size when iterating all configurations.
    node_feat_extractor, node_config_feat_extractor, subgraph_extractor
        MLP hyper-parameters.
    gnn_extractor
        EdgeConv stack hyper-parameters.
    epoch, T_max, eta_min, lr, weight_decay, grad_clip_max_norm, grad_clip_norm_type
        Training-loop settings (T_max / eta_min presumably feed CosineAnnealingLR
        — confirm in the training cell).
    """

    device: str
    cat_embeddings: dict[str, CatEmbedding]
    random_batch_size: int = 30
    batch_size: int = 30
    node_feat_extractor: NodeFeatExtractor = field(default_factory=NodeFeatExtractor)
    node_config_feat_extractor: NodeFeatExtractor = field(
        default_factory=NodeFeatExtractor
    )
    gnn_extractor: GNNExtractor = field(default_factory=GNNExtractor)
    subgraph_extractor: NodeFeatExtractor = field(default_factory=NodeFeatExtractor)
    epoch: int = 20
    T_max: int = 20
    eta_min: float = 1e-5
    lr: float = 1e-3
    weight_decay: float = 0
    grad_clip_max_norm: float = 1.0
    grad_clip_norm_type: float = 2.0


# Every categorical input gets a 16-dim embedding table: the opcode, each
# node-level group, then each per-config group (later keys win on collision).
cat_embeddings = {
    "op": CatEmbedding(num_cat=const.num_operations, embedding_dim=16),
    **{
        name: CatEmbedding(num_cat=n_cats, embedding_dim=16)
        for name, n_cats in cat_status.num_cat_dict.items()
    },
    **{
        name: CatEmbedding(num_cat=n_cats, embedding_dim=16)
        for name, n_cats in cat_config_status.num_cat_dict.items()
    },
}
params = Params(
    cat_embeddings=cat_embeddings,
    device="cuda" if torch.cuda.is_available() else "cpu",
)


@dataclass
class LayoutConfigs:
    """Arrays of one layout graph plus helpers to sample and batch its configs.

    Attributes
    ----------
    node_flag_feat: np.ndarray
        Flag features per node, (num_nodes, F). ``__post_init__`` appends one
        column that is 1 for the nodes listed in ``node_config_ids``.
    node_cont_feat: np.ndarray
        Continuous features per node, (num_nodes, F); transformed by
        ``apply_normalization`` in ``__post_init__``.
    node_cat_feat: np.ndarray
        Categorical features per node, (num_nodes, F).
    node_opcode: np.ndarray
        Operator id per node, (num_nodes,).
    edge_index: np.ndarray
        Edges, (num_edges, 2).
    node_config_feat: np.ndarray
        Categorical per-config features of the configurable nodes,
        (num_configs, num_configurable_nodes, 18); shifted by +1 in
        ``__post_init__``.
    node_config_cont_feat: np.ndarray
        Continuous per-config features of the configurable nodes,
        (num_configs, num_configurable_nodes, F).
    node_config_ids: np.ndarray
        Indices of the configurable nodes, (num_configurable_nodes,).
    config_runtime: np.ndarray
        Runtime of each configuration, (num_configs,).
    node_splits: np.ndarray
        Partition boundaries. ``__post_init__`` reads the first row and turns
        consecutive boundaries into inclusive ``[start, end]`` pairs,
        (num_boundaries - 1, 2). Not used by the model at the moment.
    cat_status, cat_config_status: CatStatus
        Column layouts of the node / per-config categorical features.
    target: np.ndarray
        ``log(runtime / min(runtime))`` per configuration.
    argsorted_indexs: list[int]
        Configuration indices sorted by ascending runtime.
    NUM_SAMPLES: int
        Maximum number of configs returned by ``get_random_config_idxs``.
    """

    node_flag_feat: np.ndarray
    node_cont_feat: np.ndarray
    node_cat_feat: np.ndarray
    node_opcode: np.ndarray
    edge_index: np.ndarray
    node_config_feat: np.ndarray
    node_config_cont_feat: np.ndarray
    node_config_ids: np.ndarray
    config_runtime: np.ndarray
    node_splits: np.ndarray

    cat_status: CatStatus
    cat_config_status: CatStatus
    target: np.ndarray = field(init=False)
    argsorted_indexs: list[int] = field(init=False)

    NUM_SAMPLES: int = 1000

    def __post_init__(self) -> None:
        # Flag column marking nodes that have a configuration.
        node_active_feat = np.zeros((self.num_nodes, 1))
        node_active_feat[self.node_config_ids, :] = 1
        self.node_flag_feat = np.concatenate(
            [self.node_flag_feat, node_active_feat], axis=1
        )
        self.node_cont_feat = self.apply_normalization(x=self.node_cont_feat)
        # Shift categories by +1 (original note: categories become 0-7).
        # Const.num_config_dims + 1 (= 7) is also the fill value used for
        # non-configurable nodes in get_filled_node_config_feat.
        self.node_config_feat = self.node_config_feat + 1
        # Boundaries [b0, b1, ..., bk] -> inclusive ranges [[b0, b1 - 1], ...].
        self.node_splits = np.array(
            [
                [self.node_splits[0][i], self.node_splits[0][i + 1] - 1]
                for i in range(self.node_splits.shape[1] - 1)
            ]
        )
        self.target = self.apply_target_normalization(x=self.config_runtime)
        self.argsorted_indexs = np.argsort(self.config_runtime).tolist()

    @property
    def num_nodes(self) -> int:
        """Number of nodes."""
        return self.node_cont_feat.shape[0]

    def get_random_config_idxs(self) -> list[int]:
        """Sample configuration indices the way the tpu_graphs baseline does.

        Takes the fastest third, the slowest third and a random draw (with
        replacement) from the middle, then shuffles the result.
        https://github.com/google-research-datasets/tpu_graphs/blob/main/tpu_graphs/baselines/layout/data.py#L352

        Returns
        -------
        list[int]
            ``min(NUM_SAMPLES, num_configs)`` configuration indices.
        """
        num_configs = self.config_runtime.shape[0]
        num_samples = min(self.NUM_SAMPLES, num_configs)
        third = num_samples // 3

        if third == 0:
            # Fewer than 3 samples: the fast/slow/middle split degenerates
            # ([third:-third] is empty and [-third:] is the whole list), so
            # return every configuration in random order instead.
            return random.sample(self.argsorted_indexs, num_samples)

        middle_samples = np.random.choice(
            self.argsorted_indexs[third:-third], num_samples - 2 * third
        ).tolist()
        samples = (
            self.argsorted_indexs[:third]
            + self.argsorted_indexs[-third:]
            + middle_samples
        )
        return random.sample(samples, len(samples))

    def get_filled_node_config_feat(
        self, index_list: list[int]
    ) -> tuple[np.ndarray, np.ndarray]:
        """Per-config node features for the selected configs, filled for every node.

        Non-configurable nodes get category ``Const.num_config_dims + 1`` in the
        categorical part and zeros in the continuous part.

        Parameters
        ----------
        index_list: list[int]
            Configuration indices.

        Returns
        -------
        tuple[np.ndarray, np.ndarray]
            (len(index_list), num_nodes, 18) categorical features and
            (len(index_list), num_nodes, F) continuous features.
        """
        node_config_feat = np.full(
            (len(index_list), self.num_nodes, Const.num_config_dims * 3),
            Const.num_config_dims + 1,
        )
        node_config_feat[:, self.node_config_ids] = self.node_config_feat[
            index_list, :, :
        ]

        node_config_cont_feat = np.zeros(
            (len(index_list), self.num_nodes, self.node_config_cont_feat.shape[2])
        )
        node_config_cont_feat[:, self.node_config_ids] = self.node_config_cont_feat[
            index_list, :, :
        ]
        return node_config_feat, node_config_cont_feat

    def get_target(self, index_list: list[int]) -> np.ndarray:
        """Runtime ranks of the selected configs, ranked within ``index_list``.

        Parameters
        ----------
        index_list: list[int]
            Configuration indices.

        Returns
        -------
        np.ndarray
            Rank per selected config; 0 is the largest runtime.
        """
        return self.apply_target_ranking(x=self.config_runtime[index_list])

    def apply_normalization(self, x: np.ndarray) -> np.ndarray:
        """Element-wise signed log transform of continuous features.

        Computes ``sign(s) * log1p(|s|)`` with ``s = x / 128 / 128``.

        Parameters
        ----------
        x: np.ndarray
            Feature array; it is not modified.

        Returns
        -------
        np.ndarray
            Transformed array with the same shape as ``x``.
        """
        # NOTE(review): the original divided by 128 twice (once in place, once
        # inside log1p), i.e. a net scale of 1/16384. That scale is kept so the
        # features are unchanged; confirm whether a single /128 was intended.
        # A new array is created instead of `x /= 128`, which mutated the
        # caller's array and failed for integer dtypes.
        scaled = x / 128 / 128
        # Same values as np.where(s >= 0, log1p(s), -log1p(-s)) without
        # evaluating log1p on the discarded branch.
        return np.sign(scaled) * np.log1p(np.abs(scaled))

    def apply_target_normalization(self, x: np.ndarray) -> np.ndarray:
        """Runtime relative to the fastest config, in log space: ``log(x / min(x))``.

        Parameters
        ----------
        x: np.ndarray
            Runtime vector (assumed strictly positive — a zero minimum divides
            by zero).

        Returns
        -------
        np.ndarray
            Normalized vector; 0 for the fastest config.
        """
        return np.log(x / x.min())

    def apply_target_ranking(self, x: np.ndarray) -> np.ndarray:
        """Rank in descending order: the largest value gets rank 0."""
        return np.argsort(np.argsort(-x))

## データセットを定義


In [9]:
class LayoutDataset(Dataset):
    """Dataset over layout graphs that yields per-graph batches of configurations.

    Items are not accessed through ``__getitem__``. Use
    ``getitem_as_random_batch`` for sampled configurations and
    ``getitem_as_batch`` for all configurations in order.

    Attributes
    ----------
    rows: list[dict]
        One record per graph, taken from the rows of ``dataset``.
    params: Params
        Experiment settings (device, batch sizes).
    cat_status, cat_config_status: CatStatus
        Column layouts passed through to ``LayoutConfigs``.
    cache_idx: int | None
        Row index of ``cache_layout_config``; ``None`` until the first load.
    cache_layout_config: LayoutConfigs | None
        Most recently built graph, reused while the same ``idx`` is requested.
    """

    def __init__(
        self,
        dataset: pd.DataFrame,
        params: Params,
        cat_status: CatStatus,
        cat_config_status: CatStatus,
    ) -> None:
        self.rows = dataset.to_dict("records")
        self.params = params
        self.cat_status = cat_status
        self.cat_config_status = cat_config_status
        self.cache_idx = None
        self.cache_layout_config = None
        # Kept for compatibility; not read anywhere in this class.
        self.cache_filepath = None

    @property
    def device(self) -> str:
        return self.params.device

    def __len__(self) -> int:
        return len(self.rows)

    @staticmethod
    def _load_npz_arrays(path, keys: tuple[str, ...]) -> dict[str, np.ndarray]:
        """Read the given arrays from an .npz file and close the file immediately."""
        with np.load(path) as npz:
            return {key: npz[key] for key in keys}

    def create_layout_config(self, idx: int) -> LayoutConfigs:
        """Return the LayoutConfigs of the idx-th graph, loading it on a cache miss.

        Only the most recent graph is cached, so consecutive calls for the same
        ``idx`` are cheap; switching graphs reloads all four files.
        """
        if self.cache_idx != idx:
            row = self.rows[idx]
            raw = self._load_npz_arrays(
                row["filepath"],
                (
                    "node_opcode",
                    "edge_index",
                    "node_config_ids",
                    "config_runtime",
                    "node_splits",
                    "node_config_feat",
                ),
            )
            node_feat = self._load_npz_arrays(
                row["node_feat_filepath"],
                (
                    "node_flag_feat",
                    "node_cont_feat",
                    "node_enum_feat",
                    "node_dimension_number_feat",
                ),
            )
            trans_feat = self._load_npz_arrays(
                row["trans_node_feat_filepath"], ("node_cont_feat", "node_cat_feat")
            )
            trans_config_feat = self._load_npz_arrays(
                row["trans_node_config_filepath"], ("node_config_cont_feat",)
            )

            node_cont_feat = np.concatenate(
                [node_feat["node_cont_feat"], trans_feat["node_cont_feat"]],
                axis=1,
            )
            node_cat_feat = np.concatenate(
                [
                    node_feat["node_enum_feat"],
                    node_feat["node_dimension_number_feat"],
                    trans_feat["node_cat_feat"],
                ],
                axis=1,
            )

            layout_config = LayoutConfigs(
                node_opcode=raw["node_opcode"],
                edge_index=raw["edge_index"],
                node_config_ids=raw["node_config_ids"],
                config_runtime=raw["config_runtime"],
                node_splits=raw["node_splits"],
                node_flag_feat=node_feat["node_flag_feat"],
                node_cont_feat=node_cont_feat,
                node_cat_feat=node_cat_feat,
                node_config_feat=raw["node_config_feat"],
                node_config_cont_feat=trans_config_feat["node_config_cont_feat"],
                cat_status=self.cat_status,
                cat_config_status=self.cat_config_status,
            )
            # Update the cache only after a successful load: setting cache_idx
            # first let a failed load leave it paired with another graph's data.
            self.cache_layout_config = layout_config
            self.cache_idx = idx
        return self.cache_layout_config

    def __getitem__(self, idx: int) -> tuple[torch.Tensor, ...]:
        """Not supported; use getitem_as_random_batch or getitem_as_batch."""
        raise NotImplementedError()

    def getitem_as_random_batch(self, idx: int) -> Iterator[tuple[torch.Tensor, ...]]:
        """Yield tensors for sampled configs of graph ``idx``, in chunks.

        Configs come from ``LayoutConfigs.get_random_config_idxs`` and are split
        into chunks of ``params.random_batch_size``; each chunk yields the
        9-tuple described in ``_get_tensors``.
        """
        layout_configs = self.create_layout_config(idx=idx)

        index_list = layout_configs.get_random_config_idxs()
        for i_chunk in range(0, len(index_list), self.params.random_batch_size):
            chunk_index_list = index_list[
                i_chunk : i_chunk + self.params.random_batch_size
            ]
            yield self._get_tensors(
                layout_configs=layout_configs, index_list=chunk_index_list
            )

    def getitem_as_batch(self, idx: int) -> Iterator[tuple[torch.Tensor, ...]]:
        """Yield tensors for every config of graph ``idx``, in order, in chunks.

        Chunks have ``params.batch_size`` configs; each yields the 9-tuple
        described in ``_get_tensors``.
        """
        layout_configs = self.create_layout_config(idx=idx)

        index_list = list(range(layout_configs.config_runtime.shape[0]))
        for i_chunk in range(0, len(index_list), self.params.batch_size):
            chunk_index_list = index_list[i_chunk : i_chunk + self.params.batch_size]
            yield self._get_tensors(
                layout_configs=layout_configs, index_list=chunk_index_list
            )

    def _get_tensors(
        self, layout_configs: LayoutConfigs, index_list: list[int]
    ) -> tuple[torch.Tensor, ...]:
        """Build the model inputs and target for the given config indices.

        Parameters
        ----------
        layout_configs: LayoutConfigs
            Graph the configs belong to.
        index_list: list[int]
            Configuration indices.

        Returns
        -------
        tuple of 9 tensors on ``self.device``, in this order:
            node_opcode: int64, (num_nodes,)
            node_flag_feat: float32, (num_nodes, F)
            node_cont_feat: float32, (num_nodes, F)
            node_cat_feat: int64, (num_nodes, F)
            node_config_feat: int64, (len(index_list), num_nodes, 18)
            node_config_cont_feat: float32, (len(index_list), num_nodes, F)
            edge_index: int64, (2, num_edges)
            node_splits: int64
            target: float32, (len(index_list),), runtime ranks within index_list
        """
        device = self.device
        node_flag_feat = torch.tensor(
            layout_configs.node_flag_feat, dtype=torch.float32, device=device
        )
        node_cont_feat = torch.tensor(
            layout_configs.node_cont_feat, dtype=torch.float32, device=device
        )
        node_cat_feat = torch.tensor(
            layout_configs.node_cat_feat, dtype=torch.int64, device=device
        )
        (
            node_config_feat_np,
            node_config_cont_feat_np,
        ) = layout_configs.get_filled_node_config_feat(index_list=index_list)
        node_config_feat = torch.tensor(
            node_config_feat_np, dtype=torch.int64, device=device
        )
        node_config_cont_feat = torch.tensor(
            node_config_cont_feat_np, dtype=torch.float32, device=device
        )
        node_opcode = torch.tensor(
            layout_configs.node_opcode, dtype=torch.int64, device=device
        )
        # Stored as (num_edges, 2); the model expects (2, num_edges).
        edge_index = torch.tensor(
            np.swapaxes(layout_configs.edge_index, 0, 1),
            dtype=torch.int64,
            device=device,
        )
        node_splits = torch.tensor(
            layout_configs.node_splits, dtype=torch.int64, device=device
        )
        target = torch.tensor(
            layout_configs.get_target(index_list=index_list),
            dtype=torch.float32,
            device=device,
        )

        return (
            node_opcode,
            node_flag_feat,
            node_cont_feat,
            node_cat_feat,
            node_config_feat,
            node_config_cont_feat,
            edge_index,
            node_splits,
            target,
        )

    def get_ith_file_info(self, i: int) -> dict[str, str]:
        """arch / perm / filename of the i-th graph."""
        row = self.rows[i]
        return {
            "arch": row["arch"],
            "perm": row["perm"],
            "filename": row["filename"],
        }

    def get_ith_runtime(self, i: int) -> np.ndarray:
        """Runtimes of every configuration of the i-th graph."""
        layout_configs = self.create_layout_config(idx=i)
        return layout_configs.config_runtime

## モデルを定義


In [10]:
from torch_geometric.nn import MessagePassing


class EdgeConv(MessagePassing):
    """EdgeConv-style message passing layer.

    For every edge (j -> i) the message is ``mlp([x_i, x_i - x_j])`` and the
    messages arriving at a node are reduced with ``aggr`` ("max" by default).
    Reference: https://pytorch-geometric.readthedocs.io/en/latest/tutorial/create_gnn.html#implementing-the-edge-convolution
    PyG aggregates over dim -2 by default, so inputs with a leading config
    axis, (num_configs, num_nodes, channels), work without reshaping:
    https://github.com/pyg-team/pytorch_geometric/blob/1e12d41c28b1fb9793f17646b018071b508864d7/torch_geometric/nn/aggr/basic.py#L38

    Parameters
    ----------
    x_input_dim: int
        Input channels per node.
    x_output_dim: int
        Output channels per node.
    dropout_p: float
        Currently unused (the Dropout layers in ``mlp`` are disabled); kept
        for interface compatibility.
    aggr: str
        Aggregation scheme passed to ``MessagePassing``. Defaults to "max";
        the old comment saying "Add" did not match the code.
    """

    def __init__(
        self, x_input_dim: int, x_output_dim: int, dropout_p: float, aggr: str = "max"
    ):
        super().__init__(aggr=aggr)
        self.mlp = nn.Sequential(
            nn.Linear(x_input_dim * 2, x_output_dim),
            nn.ReLU(),
            nn.Linear(x_output_dim, x_output_dim),
        )

    def forward(self, x, edge_index):
        """x: (num_configs, num_nodes, in_channels); edge_index: (2, num_edges)."""
        return self.propagate(edge_index, x=x)

    def message(self, x_i, x_j):
        """Compute one message per edge.

        ``propagate`` gathers x_i (target) and x_j (source) per edge, each of
        shape (num_configs, num_edges, in_channels).
        """
        # (num_configs, num_edges, 2 * in_channels)
        x_cat = torch.cat([x_i, x_i - x_j], dim=2)
        return self.mlp(x_cat)


class SimpleLayoutModel(torch.nn.Module):
    """GNN that scores every configuration of one layout graph.

    Pipeline: categorical embeddings -> per-node MLP and per-(config, node) MLP
    -> concatenation -> EdgeConv stack -> concat with the pre-GNN features
    (skip connection) -> mean over nodes -> linear layer.

    Attributes
    ----------
    params: Params
        Experiment settings.
    cat_status, cat_config_status: CatStatus
        Column layouts of the node / per-config categorical features.
    embeddings: nn.ModuleDict
        One nn.Embedding per categorical input ("op", node groups, config groups).
    node_feat_extractor: nn.Sequential
        MLP over per-node features.
    node_config_feat_extractor: nn.Sequential
        MLP over per-(config, node) features.
    gnn_extractor: torch_geometric.nn.Sequential
        EdgeConv + LeakyReLU layers.
    fc: nn.Sequential
        Final linear layer producing one score per configuration.
    """

    def __init__(
        self,
        params: Params,
        const: Const,
        cat_status: CatStatus,
        cat_config_status: CatStatus,
    ) -> None:
        super().__init__()
        self.params = params
        self.cat_status = cat_status
        self.cat_config_status = cat_config_status

        # Embedding tables for all categorical inputs.
        self.embeddings = nn.ModuleDict(
            {
                k: torch.nn.Embedding(v.num_cat, v.embedding_dim)
                for k, v in self.params.cat_embeddings.items()
            }
        )

        # Per-node MLP. Built once after all layers are collected (it used to be
        # rebuilt inside the loop and was never created when `dims` was empty).
        node_feat_extractor_dims = [
            const.num_node_flag_feat_dim
            + const.num_node_cont_feat_dim
            + self.num_node_feat_embedding_dims
        ] + self.params.node_feat_extractor.dims
        self.node_feat_extractor = self._build_mlp(
            dims=node_feat_extractor_dims,
            negative_slope=params.node_feat_extractor.leakyrelu_negative_slope,
        )

        # Per-(config, node) MLP. Its depth now follows
        # node_config_feat_extractor.dims; it used to loop over the length of
        # node_feat_extractor's dims, which dropped layers or raised IndexError
        # when the two extractors had different depths.
        node_config_feat_extractor_dims = [
            self.num_node_config_feat_embedding_dims
            + const.num_node_config_cont_feat_dim
        ] + self.params.node_config_feat_extractor.dims
        self.node_config_feat_extractor = self._build_mlp(
            dims=node_config_feat_extractor_dims,
            negative_slope=params.node_config_feat_extractor.leakyrelu_negative_slope,
        )

        # EdgeConv stack over the concatenated node + config features.
        num_gnn_extractor_input_dim = (
            node_feat_extractor_dims[-1] + node_config_feat_extractor_dims[-1]
        )
        gnn_extractor_dims = [num_gnn_extractor_input_dim] + self.params.gnn_extractor.dims
        gnn_extractor_layer = []
        for in_dim, out_dim in zip(gnn_extractor_dims[:-1], gnn_extractor_dims[1:]):
            gnn_extractor_layer += [
                (
                    EdgeConv(
                        x_input_dim=in_dim,
                        x_output_dim=out_dim,
                        dropout_p=params.gnn_extractor.dropout_p,
                    ),
                    "x, edge_index -> x",
                ),
                nn.LeakyReLU(params.gnn_extractor.leakyrelu_negative_slope),
            ]
        self.gnn_extractor = Sequential("x, edge_index", gnn_extractor_layer)

        # Head over the mean-pooled [pre-GNN features, GNN features].
        self.fc = nn.Sequential(
            nn.Linear(
                in_features=gnn_extractor_dims[-1] + num_gnn_extractor_input_dim,
                out_features=1,
            ),
        )
        self.to(self.params.device)

    @staticmethod
    def _build_mlp(dims: list[int], negative_slope: float) -> nn.Sequential:
        """Linear + LeakyReLU for each consecutive (in, out) pair in ``dims``."""
        layers = []
        for in_dim, out_dim in zip(dims[:-1], dims[1:]):
            layers += [
                nn.Linear(in_features=in_dim, out_features=out_dim),
                nn.LeakyReLU(negative_slope),
            ]
        return nn.Sequential(*layers)

    @property
    def num_node_feat_embedding_dims(self) -> int:
        """Total width of the opcode + node-level categorical embeddings."""
        num_embedding_dims = 0
        num_embedding_dims += 1 * self.params.cat_embeddings["op"].embedding_dim
        for cat_name, cat_index in self.cat_status.index_dict.items():
            num_embedding_dims += (
                len(cat_index) * self.params.cat_embeddings[cat_name].embedding_dim
            )
        return num_embedding_dims

    @property
    def num_node_config_feat_embedding_dims(self) -> int:
        """Total width of the per-config categorical embeddings."""
        num_embedding_dims = 0
        for cat_name, cat_index in self.cat_config_status.index_dict.items():
            num_embedding_dims += (
                len(cat_index) * self.params.cat_embeddings[cat_name].embedding_dim
            )
        return num_embedding_dims

    def forward(
        self,
        node_opcode: torch.Tensor,
        node_flag_feat: torch.Tensor,
        node_cont_feat: torch.Tensor,
        node_cat_feat: torch.Tensor,
        node_config_feat: torch.Tensor,
        node_config_cont_feat: torch.Tensor,
        edge_index: torch.Tensor,
        node_splits: torch.Tensor,
    ) -> torch.Tensor:
        """Score each configuration.

        Parameters
        ----------
        node_opcode:
            Operator id per node, (num_nodes,).
        node_flag_feat:
            (num_nodes, flag_dims).
        node_cont_feat:
            (num_nodes, cont_dims).
        node_cat_feat:
            Integer categories, (num_nodes, cat_dims); embedded inside the model.
        node_config_feat:
            Integer per-config categories, (num_configs, num_nodes, config_cat_dims).
        node_config_cont_feat:
            (num_configs, num_nodes, config_cont_dims).
        edge_index:
            (2, num_edges).
        node_splits:
            (num_subgraphs, 2). Accepted for interface compatibility; not used.

        Returns
        -------
        torch.Tensor
            (num_configs,) scores. ``squeeze(-1)`` keeps the config axis even
            for a single-config batch (``torch.squeeze`` made it 0-d).
        """
        # (num_nodes, features)
        node_feat = self._join_node_feature(
            node_opcode=node_opcode,
            node_flag_feat=node_flag_feat,
            node_cont_feat=node_cont_feat,
            node_cat_feat=node_cat_feat,
        )

        # (num_configs, num_nodes, features)
        node_config_feat = self._join_node_config_feature(
            node_config_feat=node_config_feat,
            node_config_cont_feat=node_config_cont_feat,
        )

        extracted_node_feat = self.node_feat_extractor(node_feat)
        extracted_node_config_feat = self.node_config_feat_extractor(node_config_feat)

        # Per-config node features: (num_configs, num_nodes, features)
        extracted_feat = self._join_entire_node_config_feat(
            node_feat=extracted_node_feat,
            node_config_feat=extracted_node_config_feat,
        )

        conved_extracted_feat = self.gnn_extractor(
            x=extracted_feat,
            edge_index=edge_index,
        )

        # Skip connection: keep the pre-GNN features next to the GNN output.
        concat_feat = torch.concat([extracted_feat, conved_extracted_feat], 2)

        # Global mean pooling over nodes: (num_configs, features)
        global_pool_feat = torch.mean(concat_feat, dim=1)

        return self.fc(global_pool_feat).squeeze(-1)

    def _join_node_feature(
        self,
        node_opcode: torch.Tensor,
        node_flag_feat: torch.Tensor,
        node_cont_feat: torch.Tensor,
        node_cat_feat: torch.Tensor,
    ) -> torch.Tensor:
        """Concatenate flags, continuous features and embeddings per node."""
        node_embeddings_list = []
        node_embeddings_list.append(self.embeddings["op"](node_opcode))
        for cat_name, cat_index in self.cat_status.index_dict.items():
            # (num_nodes, len(cat_index), emb) -> (num_nodes, len(cat_index) * emb)
            node_embeddings = self.embeddings[cat_name](node_cat_feat[:, cat_index])
            node_embeddings = torch.reshape(
                node_embeddings,
                (-1, node_embeddings.shape[-2] * node_embeddings.shape[-1]),
            )
            node_embeddings_list.append(node_embeddings)

        # (num_nodes, features)
        node_embedding_feat = torch.concat(node_embeddings_list, 1)
        node_feat = torch.concat(
            [node_flag_feat, node_cont_feat, node_embedding_feat], 1
        )
        return node_feat

    def _join_node_config_feature(
        self, node_config_feat: torch.Tensor, node_config_cont_feat: torch.Tensor
    ) -> torch.Tensor:
        """Embed per-config categories and append the continuous config features."""
        node_config_embeddings_list = []
        for cat_name, cat_index in self.cat_config_status.index_dict.items():
            # (configs, nodes, len(cat_index), emb) -> (configs, nodes, len * emb)
            node_embeddings = self.embeddings[cat_name](
                node_config_feat[:, :, cat_index]
            )
            node_embeddings = torch.reshape(
                node_embeddings,
                (
                    node_embeddings.shape[0],
                    -1,
                    node_embeddings.shape[-2] * node_embeddings.shape[-1],
                ),
            )
            node_config_embeddings_list.append(node_embeddings)
        node_config_feat = torch.concat(
            node_config_embeddings_list + [node_config_cont_feat], 2
        )
        return node_config_feat

    def _join_entire_node_config_feat(
        self, node_feat: torch.Tensor, node_config_feat: torch.Tensor
    ) -> torch.Tensor:
        """Broadcast node features over configs and concatenate with config features."""
        # expand() is a view, unlike tile(); the concat below copies once anyway.
        node_tiled_feat = node_feat.unsqueeze(0).expand(
            node_config_feat.shape[0], -1, -1
        )
        return torch.concat([node_tiled_feat, node_config_feat], 2)

## 学習


In [11]:
class ListMLE(nn.Module):
    """ListMLE listwise ranking loss.

    For every row, items are ordered by descending label. The loss is the
    negative log-likelihood of that ordering under a Plackett-Luce model whose
    item scores are the logits.
    """

    def __init__(self) -> None:
        super().__init__()

    def forward(self, logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        """Mean ListMLE loss over the batch.

        Parameters
        ----------
        logits: torch.Tensor
            Predicted scores, shape (batch, n_items). A higher score means the
            item should be ranked earlier.
        labels: torch.Tensor
            Targets, shape (batch, n_items). A higher label means the item is
            ranked earlier.

        Returns
        -------
        torch.Tensor
            Scalar: the per-row negative log-likelihood averaged over the batch.
        """
        # Sort by the true labels (descending) to get the target permutation.
        _, labels_sorted_indice = labels.sort(descending=True, dim=1)
        # Rearrange the predicted scores into that true order.
        logits_sorted_by_true = torch.gather(logits, dim=1, index=labels_sorted_indice)
        # log(sum_{j >= i} exp(s_j)) for every position i: reverse, cumulative
        # log-sum-exp, reverse back. logcumsumexp stays finite where
        # log(cumsum(exp(.))) underflowed to log(0) = -inf for large score gaps.
        log_tail_sums = torch.logcumsumexp(
            logits_sorted_by_true.flip(dims=[1]), dim=1
        ).flip(dims=[1])
        negative_log_likelihood = torch.sum(
            log_tail_sums - logits_sorted_by_true, dim=1
        )
        return torch.mean(negative_log_likelihood)


def rankNet(y_pred, y_true):
    """
    RankNet loss introduced in "Learning to Rank using Gradient Descent".

    Every ordered pair (i, j) in a row with y_true[i] > y_true[j] contributes
    one binary cross-entropy term. The logit is y_pred[i] - y_pred[j] and the
    target is 1. Pairs whose label difference is infinite (padded items) are
    ignored.

    :param y_pred: predictions from the model, shape [batch_size, slate_length]
    :param y_true: ground truth labels, shape [batch_size, slate_length]
    :return: scalar loss, the mean over all valid pairs in the batch. It is 0
        (still attached to y_pred's graph) when there are no valid pairs, e.g.
        when all labels are tied, instead of NaN.
    """
    # All pairwise differences via broadcasting: [b, i, j] = v[b, i] - v[b, j].
    # This covers the same pairs (same order) as itertools.product(range(n),
    # repeat=2) without building an n**2-element Python list of index tuples.
    true_diffs = y_true[:, :, None] - y_true[:, None, :]
    pred_diffs = y_pred[:, :, None] - y_pred[:, None, :]

    # Keep only 'positive' pairs that do not involve a padded (inf) label. The
    # symmetric negative pairs carry the same information.
    the_mask = (true_diffs > 0) & (~torch.isinf(true_diffs))

    pred_diffs = pred_diffs[the_mask]
    # Binarize the relevance difference: a pairwise loss only needs to know
    # which item is better, not by how much.
    true_diffs = (true_diffs > 0).type(torch.float32)[the_mask]

    if pred_diffs.numel() == 0:
        # BCE over an empty tensor is NaN (0 / 0). Return a zero that is still
        # connected to y_pred, so .backward() works and yields zero gradients.
        return pred_diffs.sum() * 0.0
    return nn.BCEWithLogitsLoss()(pred_diffs, true_diffs)


def to_cpu_numpy(
    params: Params, pred: torch.Tensor, truth: torch.Tensor
) -> tuple[np.ndarray, np.ndarray]:
    """Detach ``pred`` and ``truth`` and return them as NumPy arrays.

    When ``params.device == "cuda"`` the tensors are first moved to the CPU,
    and the CUDA caching allocator is emptied afterwards.
    """
    on_cuda = params.device == "cuda"
    arrays = tuple(
        (tensor.cpu() if on_cuda else tensor).detach().numpy()
        for tensor in (pred, truth)
    )
    if on_cuda:
        torch.cuda.empty_cache()
    return arrays

In [13]:
from scipy.stats import kendalltau


def evaluate_score(dataset: LayoutDataset, model: torch.nn.Module) -> pd.DataFrame:
    """Compute per-graph validation metrics over every graph in ``dataset``.

    For each graph, the batches yielded by ``dataset.getitem_as_random_batch``
    are scored. Predictions and targets from all batches are concatenated and
    used for:

    * ``graph_loss``: the RankNet loss over all collected configs;
    * ``score``: Kendall's tau between targets and predictions, the competition
      metric (https://www.kaggle.com/competitions/predict-ai-model-runtime/overview).

    Returns
    -------
    pd.DataFrame
        One row per graph: the dict from ``dataset.get_ith_file_info`` plus
        ``graph_loss`` and ``score``.

    NOTE(review): device handling goes through the notebook-global ``params``
    (via ``to_cpu_numpy``), not through an argument of this function.
    """
    model.eval()

    records = []
    # Inference only: no autograd graph is needed, which saves memory.
    with torch.no_grad():
        for graph_index in range(len(dataset)):
            # Collect the graph's sampled configs batch by batch.
            preds, truths = [], []
            for (
                node_opcode,
                node_flag_feat,
                node_cont_feat,
                node_cat_feat,
                node_config_feat,
                node_config_cont_feat,
                edge_index,
                node_splits,
                target,
            ) in dataset.getitem_as_random_batch(graph_index):
                pred = model(
                    node_opcode=node_opcode,
                    node_flag_feat=node_flag_feat,
                    node_cont_feat=node_cont_feat,
                    node_cat_feat=node_cat_feat,
                    node_config_feat=node_config_feat,
                    node_config_cont_feat=node_config_cont_feat,
                    edge_index=edge_index,
                    node_splits=node_splits,
                )
                pred, truth = to_cpu_numpy(params, pred, target)
                preds.append(pred)
                truths.append(truth)

            preds, truths = np.hstack(preds), np.hstack(truths)

            graph_loss = rankNet(
                torch.tensor(preds.reshape(1, -1)),
                torch.tensor(truths.reshape(1, -1)),
            ).item()
            # Score the whole graph (all batches), not just the last batch.
            score = kendalltau(truths, preds).correlation

            record = dataset.get_ith_file_info(graph_index)
            record.update(
                {
                    "graph_loss": graph_loss,
                    "score": score,
                }
            )
            records.append(record)
    return pd.DataFrame(records)

### 学習ループ


In [14]:
def seed_everything(seed=1234):
    """Seed every random number generator used in this notebook.

    Seeds Python's ``random``, NumPy's global RNG and PyTorch (CPU and every
    CUDA device). It also puts cuDNN into deterministic, non-benchmarking mode.

    Note: setting ``PYTHONHASHSEED`` here only affects subprocesses started
    afterwards. Hash randomization of the running interpreter is fixed at
    interpreter start-up.
    """
    random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seed all visible GPUs, not only the current one. This is silently
    # ignored when CUDA is unavailable.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # Benchmark mode selects kernels by timing, which can differ between runs.
    torch.backends.cudnn.benchmark = False


def train_model(
    dftrain: pd.DataFrame,
    dfvalid: pd.DataFrame,
    params: Params,
    const: Const,
    cat_status: CatStatus,
    cat_config_status: CatStatus,
    savedir: Path,
    checkpoint_dir: Path = None,
) -> None:
    """Train ``SimpleLayoutModel`` with the RankNet loss, validating every epoch.

    Each epoch visits the training graphs in random order. For each graph the
    batches from ``getitem_as_random_batch`` are back-propagated one by one
    while gradients accumulate; then the gradients are clipped and a single
    optimizer / scheduler step is taken. After every epoch the model is scored
    with ``evaluate_score`` on the validation set.

    Parameters
    ----------
    dftrain, dfvalid : pd.DataFrame
        File tables used to build the training / validation ``LayoutDataset``.
    params, const, cat_status, cat_config_status
        Hyper-parameters and feature metadata forwarded to the dataset and the
        model.
    savedir : Path
        Receives ``best_model.pt`` (best mean validation score),
        ``epoch{N}_model.pt``, ``final_model.pt`` and ``log.csv``.
    checkpoint_dir : Path, optional
        If given, the weights are initialised from
        ``checkpoint_dir / "final_model.pt"``.

    Every record is also sent to ``wandb.log``; this presumably requires an
    active run (``wandb.init``) — TODO confirm before calling.
    """
    train_layout_dataset = LayoutDataset(
        dataset=dftrain,
        params=params,
        cat_status=cat_status,
        cat_config_status=cat_config_status,
    )
    valid_layout_dataset = LayoutDataset(
        dataset=dfvalid,
        params=params,
        cat_status=cat_status,
        cat_config_status=cat_config_status,
    )

    model = SimpleLayoutModel(
        params=params,
        const=const,
        cat_status=cat_status,
        cat_config_status=cat_config_status,
    )
    if checkpoint_dir is not None:
        print("学習済みモデルを読み込みます")
        # map_location lets a checkpoint written on GPU be restored on CPU.
        model.load_state_dict(
            torch.load(checkpoint_dir / "final_model.pt", map_location=params.device)
        )

    optimizer = torch.optim.Adam(
        model.parameters(), lr=params.lr, weight_decay=params.weight_decay
    )
    scheduler = CosineAnnealingLR(
        optimizer=optimizer, T_max=params.T_max, eta_min=params.eta_min
    )

    best_score = -np.inf
    records = []
    # Defined up front so the cleanup at the end works even if params.epoch == 0.
    dfscore = None
    for epoch in range(params.epoch):
        model.train()

        num_graph = len(train_layout_dataset)
        # Visit the training graphs in a fresh random order every epoch.
        graph_indexes = random.sample(list(range(num_graph)), num_graph)

        epoch_loss = 0  # running sum of graph losses within this epoch

        with tqdm(total=num_graph) as pbar:
            for i_graph, graph_index in enumerate(graph_indexes):
                graph_info = train_layout_dataset.get_ith_file_info(graph_index)
                graph_arch, graph_perm = graph_info["arch"], graph_info["perm"]
                # Fetch this graph's configs in batches. Gradients of all
                # batches accumulate and are applied with one optimizer step.
                preds, truths = [], []
                graph_loss = 0  # sum of batch losses for this graph
                num_batch_count = 0
                for (
                    node_opcode,
                    node_flag_feat,
                    node_cont_feat,
                    node_cat_feat,
                    node_config_feat,
                    node_config_cont_feat,
                    edge_index,
                    node_splits,
                    target,
                ) in train_layout_dataset.getitem_as_random_batch(graph_index):
                    out = model(
                        node_opcode=node_opcode,
                        node_flag_feat=node_flag_feat,
                        node_cont_feat=node_cont_feat,
                        node_cat_feat=node_cat_feat,
                        node_config_feat=node_config_feat,
                        node_config_cont_feat=node_config_cont_feat,
                        edge_index=edge_index,
                        node_splits=node_splits,
                    )
                    loss = rankNet(
                        torch.reshape(out, (1, out.shape[0])),
                        torch.reshape(target, (1, target.shape[0])),
                    )
                    loss.backward()
                    graph_loss += loss.item()

                    pred, truth = to_cpu_numpy(params, out, target)
                    preds.append(pred)
                    truths.append(truth)
                    num_batch_count += 1

                # One (clipped) gradient step per graph.
                nn.utils.clip_grad_norm_(
                    model.parameters(),
                    max_norm=params.grad_clip_max_norm,
                    norm_type=params.grad_clip_norm_type,
                )
                optimizer.step()
                # A fractional epoch makes the cosine schedule advance per graph.
                scheduler.step(epoch + i_graph / num_graph)
                optimizer.zero_grad()

                preds, truths = np.hstack(preds), np.hstack(truths)
                # Score every batch of the graph, not only the last one.
                score = kendalltau(truths, preds).correlation
                graph_loss /= num_batch_count  # graph loss = mean batch loss
                epoch_loss += graph_loss

                record = {
                    "epoch": epoch,
                    "i_graph": i_graph,
                    f"train-{graph_arch}-{graph_perm}/epoch_loss": epoch_loss
                    / (i_graph + 1),
                    f"train-{graph_arch}-{graph_perm}/graph_loss": graph_loss,
                    f"train-{graph_arch}-{graph_perm}/score": score,
                    "lr": scheduler.get_last_lr()[0],
                }
                record.update(graph_info)
                records.append(record)

                wandb.log(record)
                pbar.set_description(
                    f"running loss: {epoch_loss / (i_graph + 1):.5f}, graph loss: {graph_loss:.5f} score: {score:.3f}"
                )
                pbar.update(1)

        model.eval()
        torch.cuda.empty_cache()

        dfscore = evaluate_score(dataset=valid_layout_dataset, model=model)
        avg_loss = dfscore["graph_loss"].mean()
        avg_score = dfscore["score"].mean()
        for _, row_score in dfscore.iterrows():
            graph_arch, graph_perm = row_score["arch"], row_score["perm"]
            record = {
                "epoch": epoch,
                "i_graph": -1,
                "arch": graph_arch,
                "perm": graph_perm,
                "filename": row_score["filename"],
                f"valid-{graph_arch}-{graph_perm}/epoch_loss": avg_loss,
                f"valid-{graph_arch}-{graph_perm}/graph_loss": row_score["graph_loss"],
                f"valid-{graph_arch}-{graph_perm}/score": row_score["score"],
                "lr": scheduler.get_last_lr()[0],
            }
            records.append(record)
            wandb.log(record)

        print(f"[valid] current loss: {avg_loss:.5f} score: {avg_score:.3f}")

        # Keep the checkpoint with the best mean validation Kendall tau.
        if best_score < avg_score:
            best_score = avg_score
            torch.save(model.state_dict(), savedir / "best_model.pt")
        torch.save(model.state_dict(), savedir / f"epoch{epoch + 1}_model.pt")

    dflog = pd.DataFrame(records)
    dflog.to_csv(savedir / "log.csv", index=False)

    torch.save(model.state_dict(), savedir / "final_model.pt")

    # Release datasets / model so later cells in the same kernel get the memory.
    del (
        train_layout_dataset,
        valid_layout_dataset,
        model,
        optimizer,
        dfscore,
        dflog,
        records,
    )
    gc.collect()
    torch.cuda.empty_cache()

In [15]:
# exptname = str(Path().resolve()).split("/")[-1]

# wandb.init(
#     # set the wandb project where this run will be logged
#     project="predict-ai-model-runtime-for-sun-scan-clan",
#     # track hyperparameters and run metadata
#     config={
#         "params": asdict(params),
#         "const": asdict(const),
#         "validation": "hold-out",
#     },
#     name=exptname,
#     tags=["all"],
# )

In [16]:
# seed_everything(43)
# dftrain = dataset_dict["train"]
# dfvalid = dataset_dict["valid"]

# train_model(
#     dftrain=dftrain,
#     dfvalid=dfvalid,
#     params=params,
#     const=const,
#     cat_status=cat_status,
#     cat_config_status=cat_config_status,
#     savedir=workdir,
#     checkpoint_dir=None,
# )
# wandb.alert(title=exptname, text=f"Train End")

In [17]:
# savedir = workdir / f"{arch}-{perm}"
# dflog = pd.read_csv(savedir / "log.csv")

# fig, axes = plt.subplots(2, 2, figsize=(12, 8))
# for i, ds in enumerate(["train", "valid"]):
#     dflog_ = dflog.query("(phase == @ds)").groupby("epoch")
#     axes[i][0].plot(dflog_["current_loss"].mean(), label="total")
#     axes[i][1].plot(dflog_["score"].mean(), label="taotal")
#     if i == 0:
#         axes[i][0].legend()
# fig.show()

## 推論


In [18]:
# savedir = workdir

# records = []

# dftest = dataset_dict["test"]

# test_layout_dataset = LayoutDataset(
#     dataset=dftest,
#     params=params,
#     cat_status=cat_status,
#     cat_config_status=cat_config_status,
# )
# model = SimpleLayoutModel(
#     params=params,
#     const=const,
#     cat_status=cat_status,
#     cat_config_status=cat_config_status,
# )
# model.load_state_dict(torch.load(workdir / "final_model.pt"))
# model.eval()

# with tqdm(range(len(test_layout_dataset))) as pbar:
#     for i in pbar:
#         file_info = test_layout_dataset.get_ith_file_info(i)

#         pred_list = []
#         for (
#             node_opcode,
#             node_flag_feat,
#             node_cont_feat,
#             node_cat_feat,
#             node_config_feat,
#             node_config_cont_feat,
#             edge_index,
#             node_splits,
#             target,
#         ) in test_layout_dataset.getitem_as_batch(i):
#             pred_batch = model(
#                 node_opcode=node_opcode,
#                 node_flag_feat=node_flag_feat,
#                 node_cont_feat=node_cont_feat,
#                 node_cat_feat=node_cat_feat,
#                 node_config_feat=node_config_feat,
#                 node_config_cont_feat=node_config_cont_feat,
#                 edge_index=edge_index,
#                 node_splits=node_splits,
#             )
#             if params.device == "cuda":
#                 pred_batch = pred_batch.cpu().detach().numpy()
#             else:
#                 pred_batch = pred_batch.detach().numpy()
#             # pred_batchは高いものほどよい
#             pred_batch = -pred_batch
#             pred_list.append(pred_batch)

#             del (
#                 node_opcode,
#                 node_flag_feat,
#                 node_cont_feat,
#                 node_cat_feat,
#                 node_config_feat,
#                 node_config_cont_feat,
#                 edge_index,
#                 node_splits,
#                 target,
#             )
#             gc.collect()
#             torch.cuda.empty_cache()

#         pred = np.hstack(pred_list)

#         ID = f"layout:{file_info['arch']}:{file_info['perm']}:{file_info['filename']}"
#         records.append({"ID": ID, "pred": ";".join(list(map(str, pred.argsort())))})

# del test_layout_dataset, model
# gc.collect()
# torch.cuda.empty_cache()

# dfpred = pd.DataFrame(records)
# dfsub = pd.read_csv(inputdir / "sample_submission.csv")
# dfsub = dfsub.merge(dfpred, on="ID", how="left")
# dfsub["TopConfigs"] = np.where(
#     dfsub["pred"].isnull(), dfsub["TopConfigs"], dfsub["pred"]
# )
# dfsub[["ID", "TopConfigs"]].to_csv(savedir / f"submission_final_model.csv", index=False)

In [19]:
# wandb.alert(title=exptname, text=f"Inference End")
# wandb.finish()

In [27]:
valid_layout_dataset = LayoutDataset(
    dataset=dataset_dict["valid"],
    params=params,
    cat_status=cat_status,
    cat_config_status=cat_config_status,
)
model = SimpleLayoutModel(
    params=params,
    const=const,
    cat_status=cat_status,
    cat_config_status=cat_config_status,
)
# map_location lets weights saved on GPU be loaded on a CPU-only machine.
model.load_state_dict(
    torch.load(workdir / "final_model.pt", map_location=params.device)
)
model.eval()

# Number of trainable parameters.
num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
num_params

147857

In [30]:
model.eval()

eval_dict = {}
# Collect predictions and targets for every validation graph. This is
# inference only, so autograd is disabled to save memory.
with torch.no_grad():
    for graph_index in tqdm(range(len(valid_layout_dataset))):
        graph_info = valid_layout_dataset.get_ith_file_info(graph_index)
        arch, perm, filename = (
            graph_info["arch"],
            graph_info["perm"],
            graph_info["filename"],
        )
        # getitem_as_batch yields the graph's configs batch by batch (later
        # cells index these predictions jointly with the full runtime array).
        preds, truths = [], []
        for (
            node_opcode,
            node_flag_feat,
            node_cont_feat,
            node_cat_feat,
            node_config_feat,
            node_config_cont_feat,
            edge_index,
            node_splits,
            target,
        ) in valid_layout_dataset.getitem_as_batch(graph_index):
            pred = model(
                node_opcode=node_opcode,
                node_flag_feat=node_flag_feat,
                node_cont_feat=node_cont_feat,
                node_cat_feat=node_cat_feat,
                node_config_feat=node_config_feat,
                node_config_cont_feat=node_config_cont_feat,
                edge_index=edge_index,
                node_splits=node_splits,
            )
            pred, truth = to_cpu_numpy(params, pred, target)
            preds.append(pred)
            truths.append(truth)

        preds, truths = np.hstack(preds), np.hstack(truths)
        eval_dict[f"{arch}-{perm}-{filename}"] = {
            "pred": preds,
            "truth": truths,
        }

  0%|          | 0/52 [00:00<?, ?it/s]

In [33]:
import pickle as pkl

# Cache predictions/targets so later cells can reload them without re-running inference.
(workdir / "evaluation.pkl").write_bytes(pkl.dumps(eval_dict))

In [42]:
# Reload the cached predictions. This is a trusted file written by this
# notebook; never unpickle files from untrusted sources.
eval_dict = pkl.loads((workdir / "evaluation.pkl").read_bytes())

In [43]:
# Attach the measured runtimes, and their ascending argsort (fastest config
# first), to each graph's evaluation entry.
for graph_i in range(len(valid_layout_dataset)):
    row = valid_layout_dataset.rows[graph_i]

    filepath, filename = row["filepath"], row["filename"]
    arch, perm = row["arch"], row["perm"]

    # np.load on an .npz returns an NpzFile holding an open file handle; the
    # context manager closes it once the array has been read into memory.
    with np.load(filepath) as npz:
        runtime = npz["config_runtime"]
    key = f"{arch}-{perm}-{filename}"
    argsorted = np.argsort(runtime)

    eval_dict[key]["runtime"] = runtime
    eval_dict[key]["sorted"] = argsorted

In [46]:
# Spot-check one validation graph: predictions, targets, runtimes and their argsort.
row = eval_dict[
    "nlp-default-albert_en_xlarge_batch_size_16_test"
]
row

{'pred': array([84.65656 , 84.65656 , 83.9334  , ..., 84.674095, 84.560555,
        84.674736], dtype=float32),
 'truth': array([ 9.,  6., 10., ...,  9.,  4.,  6.], dtype=float32),
 'runtime': array([57941591, 57946389, 57941475, ..., 57240847, 57260677, 57243437]),
 'sorted': array([16826, 17568, 36924, ..., 43862, 43871, 43793])}

In [58]:
def evaluate(
    pred: np.ndarray,
    runtime: np.ndarray,
    top_k: int = 1000,
    sample_size: int = 1000,
    num_rounds: int = 100,
) -> dict[str, float]:
    """Kendall's tau between measured runtimes and predictions, several ways.

    ``pred`` is a score where higher is better, so it is negated before being
    compared with ``runtime``, where lower is faster. Uses Python's global
    ``random`` state.

    Parameters
    ----------
    pred, runtime : np.ndarray
        Per-config predictions and measured runtimes (same length).
    top_k : int
        Number of fastest configs (by runtime) used for ``"top"``.
    sample_size : int
        Size of each random subset (capped at ``len(pred)``).
    num_rounds : int
        Number of random subsets for the ``sample_*`` statistics.

    Returns
    -------
    dict[str, float]
        ``entire``: tau over all configs.
        ``top``: tau over the ``top_k`` fastest configs.
        ``sample_min`` / ``sample_mean`` / ``sample_max``: min / mean / max tau
        over the random subsets.
        ``tpu_graph``: tau over the first and last third of the configs (in
        storage order) plus a random sample from the middle.
        NOTE(review): the thirds are taken in storage order, not runtime
        order. If the intent is the best/worst configs by runtime, slice
        ``truth_sorted`` instead — TODO confirm.
    """
    truth_sorted = np.argsort(runtime)
    top_indices = truth_sorted[:top_k]

    n = pred.shape[0]
    numbers = list(range(n))  # loop-invariant, built once
    num_samples = min(n, sample_size)

    score_samples = []
    for _ in range(num_rounds):
        samples = random.sample(numbers, num_samples)
        score_samples.append(
            kendalltau(runtime[samples], -pred[samples]).correlation
        )

    third = num_samples // 3
    # Use numbers[n - third:] rather than numbers[-third:]: when third == 0,
    # -0 selects the whole list (and the middle slice would be empty), which
    # made tiny graphs crash in random.sample.
    paper_samples = (
        numbers[:third]
        + numbers[n - third:]
        + random.sample(numbers[third : n - third], num_samples - 2 * third)
    )
    paper_score = kendalltau(runtime[paper_samples], -pred[paper_samples]).correlation

    score = kendalltau(runtime, -pred).correlation
    top_score = kendalltau(runtime[top_indices], -pred[top_indices]).correlation
    return {
        "entire": score,
        "top": top_score,
        "sample_min": np.min(score_samples),
        "sample_mean": np.mean(score_samples),
        "sample_max": np.max(score_samples),
        "tpu_graph": paper_score,
    }

{'entire': 0.2977103817573796,
 'top': 0.027264926979169702,
 'sample_min': 0.2435410908853032,
 'sample_mean': 0.2939847372078976,
 'sample_max': 0.34074091871167567,
 'tpu_graph': 0.2742359213237993}

In [62]:
records = []
for key, row in eval_dict.items():
    record = evaluate(row["pred"], row["runtime"])
    # Keys are "{arch}-{perm}-{filename}". Filenames may themselves contain
    # "-", so split at most twice; str.replace would also strip any later
    # occurrence of the prefix inside the filename.
    arch, perm, filename = key.split("-", 2)
    record.update({"arch": arch, "perm": perm, "filename": filename})
    records.append(record)
dfeval = pd.DataFrame(records)
dfeval

Unnamed: 0,entire,top,sample_min,sample_mean,sample_max,tpu_graph,arch,perm,filename
0,0.29771,0.027265,0.239272,0.3018,0.354712,0.261422,nlp,default,albert_en_xlarge_batch_size_16_test
1,0.520705,0.011956,0.458884,0.520852,0.568973,0.493009,nlp,default,bert_en_cased_L-12_H-768_A-12_batch_size_16_test
2,0.464523,0.0295,0.410444,0.463333,0.504307,0.48204,nlp,default,bert_multi_cased_L-12_H-768_A-12_batch_size_16...
3,0.571434,0.043426,0.53476,0.568467,0.614591,0.587233,nlp,default,small_bert_bert_en_uncased_L-10_H-128_A-2_batc...
4,0.25986,0.002121,0.205376,0.263365,0.319668,0.08038,nlp,default,small_bert_bert_en_uncased_L-10_H-128_A-2_batc...
5,0.460203,0.090241,0.42144,0.461021,0.51586,0.543119,nlp,default,small_bert_bert_en_uncased_L-10_H-256_A-4_batc...
6,0.318574,0.06287,0.266653,0.317277,0.37255,0.220994,nlp,default,small_bert_bert_en_uncased_L-10_H-256_A-4_batc...
7,0.35207,0.06148,0.285631,0.350225,0.40219,0.232781,nlp,default,small_bert_bert_en_uncased_L-10_H-512_A-8_batc...
8,0.479221,0.000266,0.426352,0.483275,0.533377,0.514883,nlp,default,small_bert_bert_en_uncased_L-10_H-768_A-12_bat...
9,0.47638,-0.02786,0.43125,0.477242,0.516558,0.509227,nlp,default,small_bert_bert_en_uncased_L-10_H-768_A-12_bat...


In [67]:
# Persist the per-graph metric table for inspection outside the notebook.
dfeval.to_excel(
    workdir / "evaluate.xlsx",
    index=False,
)