## edge index

In [1]:
import numpy as np
import torch


def create_board(base=0):
    """Return the 8x8 grid of node indices for one cube face.

    Indices are assigned in 2x2 patches: each patch gets four consecutive
    numbers (top-left, top-right, bottom-left, bottom-right), and the patches
    are numbered row-major over the 4x4 grid of patches. The rows are then
    flipped vertically, and every index is offset by ``base * 64`` so that
    face ``base`` owns indices ``[64 * base, 64 * base + 63]``.
    """
    rows, cols = np.indices((8, 8))
    patch = (rows // 2) * 4 + (cols // 2)  # row-major id of the 2x2 patch
    within = (rows % 2) * 2 + (cols % 2)  # position 0..3 inside the patch
    board = patch * 4 + within
    # rot90(k=2) followed by fliplr is equivalent to flipping the rows only.
    return np.flipud(board) + (base * 64)


def get_neighbors(face, i, j):
    """Return the in-face 4-neighbours of cell (i, j) on an 8x8 face.

    Neighbours are listed in the order up (i-1), down (i+1), left (j-1),
    right (j+1); those that would fall outside the 8x8 face are skipped.
    """
    candidates = (
        (i - 1, j, i > 0),
        (i + 1, j, i < 7),
        (i, j - 1, j > 0),
        (i, j + 1, j < 7),
    )
    return [face[r, c] for r, c, inside in candidates if inside]


def get_boundary_neighbors(faces):
    """Map every node index (0..383) to its neighbours across cube-face seams.

    ``faces`` is a sequence of six 8x8 index arrays. Each entry of the
    connection table glues an edge of one face to an edge of another, cell by
    cell. Both directions are appended for every table entry, and the table
    itself lists each seam from both sides, so each cross-face pair appears
    twice in the returned lists. Callers that only need set semantics (e.g. an
    adjacency matrix) are unaffected by the duplicates.

    Returns:
        dict mapping node index -> list of neighbour indices (interior cells
        map to an empty list).
    """
    # Seam table: (face A, edge of A, face B, edge of B, orientation).
    # Cells along an edge are enumerated left/right: top -> bottom,
    # up/down: left -> right. "reverse" walks face B's edge the opposite way.
    connections = [
        (0, "left", 3, "right", "same"),
        (0, "up", 5, "down", "same"),
        (0, "right", 1, "left", "same"),
        (0, "down", 4, "up", "same"),
        (1, "left", 0, "right", "same"),
        (1, "up", 5, "right", "reverse"),
        (1, "right", 2, "left", "same"),
        (1, "down", 4, "right", "same"),
        (2, "left", 1, "right", "same"),
        (2, "up", 5, "up", "reverse"),
        (2, "right", 3, "left", "same"),
        (2, "down", 4, "down", "reverse"),
        (3, "left", 2, "right", "same"),
        (3, "up", 5, "left", "same"),
        (3, "right", 0, "left", "same"),
        (3, "down", 4, "left", "reverse"),
        (4, "left", 3, "down", "reverse"),
        (4, "up", 0, "down", "same"),
        (4, "right", 1, "down", "same"),
        (4, "down", 2, "down", "reverse"),
        (5, "left", 3, "up", "same"),
        (5, "up", 2, "up", "reverse"),
        (5, "right", 1, "up", "reverse"),
        (5, "down", 0, "up", "same"),
    ]

    def edge_cell(k, edge, flip):
        """(row, col) of the k-th cell along ``edge``; ``flip`` counts from the far end."""
        pos = 7 - k if flip else k
        return {
            "left": (pos, 0),
            "right": (pos, 7),
            "up": (0, pos),
            "down": (7, pos),
        }[edge]

    neighbours = {node: [] for node in range(384)}
    for face_a, edge_a, face_b, edge_b, orientation in connections:
        flip = orientation == "reverse"
        for k in range(8):
            node_a = faces[face_a][edge_cell(k, edge_a, False)]
            node_b = faces[face_b][edge_cell(k, edge_b, flip)]
            neighbours[node_a].append(node_b)
            neighbours[node_b].append(node_a)
    return neighbours


def create_adjacency_matrix(self_loop=True, spectral_connection=True):
    """Build the dense, symmetric 384x384 adjacency matrix of the cube-face grid.

    Nodes are the 6 faces x 64 cells numbered by ``create_board``. Edges are:
      * 4-neighbour adjacency inside each face (``get_neighbors``);
      * adjacency across face seams (``get_boundary_neighbors``);
      * optionally, all pairs inside each group of 4 consecutive indices
        (``create_board`` gives every 2x2 patch 4 consecutive indices; such a
        group is treated as one spectral element).

    Args:
        self_loop: if True the diagonal is set to 1, otherwise it is forced to 0.
        spectral_connection: if True, fully connect each group of 4
            consecutive node indices.

    Returns:
        (384, 384) integer numpy array with entries 0/1.
    """
    n_nodes = 384
    adjacency_matrix = np.zeros((n_nodes, n_nodes), dtype=int)

    # Index grid of each of the 6 faces.
    faces = [create_board(base=i) for i in range(6)]

    # Neighbours inside each face.
    for face in faces:
        for i in range(8):
            for j in range(8):
                idx = face[i, j]
                for neighbor in get_neighbors(face, i, j):
                    adjacency_matrix[idx, neighbor] = 1
                    adjacency_matrix[neighbor, idx] = 1

    # Neighbours across face seams.
    for idx, neighbors in get_boundary_neighbors(faces).items():
        for neighbor in neighbors:
            adjacency_matrix[idx, neighbor] = 1
            adjacency_matrix[neighbor, idx] = 1

    # Spot checks of specific cross-face adjacencies.
    assert adjacency_matrix[0, 205] == 1
    assert adjacency_matrix[13, 64] == 1
    assert adjacency_matrix[77, 128] == 1
    assert adjacency_matrix[141, 192] == 1
    assert adjacency_matrix[338, 250] == 1
    assert adjacency_matrix[336, 251] == 1
    assert adjacency_matrix[0, 306] == 1
    assert adjacency_matrix[50, 320] == 1
    assert adjacency_matrix[64, 319] == 1
    assert adjacency_matrix[118, 349] == 1

    # Total degree must be 4 * N, consistent with every cell having 4 neighbours.
    assert adjacency_matrix.sum() == n_nodes * 4

    # Fully connect each group of 4 consecutive indices (one spectral element).
    # Slice assignment per 4x4 diagonal block replaces an O(N^2) Python loop
    # over all node pairs; the block diagonals are overwritten just below.
    if spectral_connection:
        for start in range(0, n_nodes, 4):
            adjacency_matrix[start:start + 4, start:start + 4] = 1

    np.fill_diagonal(adjacency_matrix, 1 if self_loop else 0)

    return adjacency_matrix


def create_edge_index(self_loop=True, spectral_connection=True):
    """Return the graph connectivity in COO form as a (2, E) ``torch.long`` tensor.

    Row 0 holds source nodes and row 1 target nodes. One column is emitted per
    nonzero entry of ``create_adjacency_matrix(self_loop, spectral_connection)``,
    in numpy's row-major ``nonzero`` order.
    """
    sources, targets = np.nonzero(
        create_adjacency_matrix(self_loop, spectral_connection)
    )
    return torch.tensor(np.stack([sources, targets]), dtype=torch.long)


def create_edge_attr(
    edge_index,
    is_same_spectral=True,
    grid_path="/kaggle/working/misc/grid_info/ClimSim_low-res_grid-info.nc",
) -> torch.Tensor:
    """Build per-edge features from the lat/lon offset between edge endpoints.

    The latitude and longitude differences (source minus target, in radians)
    are encoded as sin and cos.

    Args:
        edge_index: (2, E) integer array/tensor; row 0 = source node, row 1 =
            target node. Node ids index the grid file's ``lat``/``lon`` arrays.
        is_same_spectral: if True, append a fifth column that is 1.0 when both
            endpoints lie in the same group of 4 consecutive node indices (the
            same spectral element) and 0.0 otherwise.
        grid_path: path to the ClimSim grid-info NetCDF file.

    Returns:
        float64 tensor of shape (E, 5) if ``is_same_spectral`` else (E, 4):
        [sin dlat, cos dlat, sin dlon, cos dlon(, same_spectral)].
    """
    import xarray as xr  # local import: do not depend on notebook cell order

    # Close the file handle once the coordinates are loaded into memory.
    with xr.open_dataset(grid_path) as grid_info:
        latitude_radian = np.radians(grid_info["lat"].to_numpy())
        longitude_radian = np.radians(grid_info["lon"].to_numpy())

    src, dst = edge_index[0, :], edge_index[1, :]
    lat_diff = latitude_radian[src] - latitude_radian[dst]
    lon_diff = longitude_radian[src] - longitude_radian[dst]

    columns = [
        torch.tensor(np.sin(lat_diff)),
        torch.tensor(np.cos(lat_diff)),
        torch.tensor(np.sin(lon_diff)),
        torch.tensor(np.cos(lon_diff)),
    ]
    if is_same_spectral:
        same = src // 4 == dst // 4
        # Convert explicitly so numpy or torch bool masks both stack cleanly.
        columns.append(torch.as_tensor(same, dtype=columns[0].dtype))

    return torch.stack(columns, dim=1)

In [5]:
# Build the fixed graph connectivity without self-loops. The bare name on the
# last line is the cell's displayed value (a (2, E) long tensor).
edge_index = create_edge_index(self_loop=False)
edge_index

tensor([[  0,   0,   0,  ..., 383, 383, 383],
        [  1,   2,   3,  ..., 380, 381, 382]])

## edge attr

In [6]:
def create_edge_attr(
    edge_index,
    grid_path="/kaggle/working/misc/grid_info/ClimSim_low-res_grid-info.nc",
) -> torch.Tensor:
    """Build per-edge features from the lat/lon offset between edge endpoints.

    The latitude and longitude differences (source minus target, in radians)
    are encoded as sin and cos.

    Args:
        edge_index: (2, E) integer array/tensor; row 0 = source node, row 1 =
            target node. Node ids index the grid file's ``lat``/``lon`` arrays.
        grid_path: path to the ClimSim grid-info NetCDF file.

    Returns:
        float64 tensor of shape (E, 4): [sin dlat, cos dlat, sin dlon, cos dlon].
    """
    import xarray as xr  # local import: do not depend on notebook cell order

    # Close the file handle once the coordinates are loaded into memory.
    with xr.open_dataset(grid_path) as grid_info:
        latitude_radian = np.radians(grid_info["lat"].to_numpy())
        longitude_radian = np.radians(grid_info["lon"].to_numpy())

    src, dst = edge_index[0, :], edge_index[1, :]
    lat_diff = latitude_radian[src] - latitude_radian[dst]
    lon_diff = longitude_radian[src] - longitude_radian[dst]

    columns = [np.sin(lat_diff), np.cos(lat_diff), np.sin(lon_diff), np.cos(lon_diff)]
    return torch.stack([torch.tensor(col) for col in columns], dim=1)

In [7]:
# create_edge_attr reads the grid-info NetCDF file through xarray (`xr`),
# so xarray is imported before calling it.
import xarray as xr

# Displayed value: float64 edge features, one row per column of edge_index.
create_edge_attr(edge_index)

tensor([[ 0.0594,  0.9982, -0.1953,  0.9808],
        [-0.1719,  0.9851, -0.0029,  1.0000],
        [-0.1256,  0.9921, -0.1981,  0.9802],
        ...,
        [-0.2132,  0.9770,  0.0000,  1.0000],
        [-0.0862,  0.9963,  0.2032,  0.9791],
        [-0.0862,  0.9963, -0.2032,  0.9791]], dtype=torch.float64)

## GNN model

In [9]:
"""
大体同じサイズのサンプル
"""

import torch
from torch import nn
from torch.nn import Linear, Parameter
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops, degree


class Height60Conv(MessagePassing):
    """Edge-conditioned message passing over node features that have a height axis.

    Each node carries ``base_channels`` channels at each of ``n_levels`` levels
    (60 by default), flattened into one feature axis of size
    ``base_channels * n_levels``. The height axis must stay merged with the
    channel axis while going through ``propagate``; it does not work otherwise,
    presumably because MessagePassing gathers along ``node_dim=-2`` by default.
    Inside ``message`` the features are reshaped with ``view`` to
    (E, base_channels, n_levels). A kernel-size-1 Conv1d then mixes channels
    independently at every level, so the height dimension is preserved.
    Messages are summed at the target node (``aggr="add"``).
    """

    def __init__(self, base_channels, edge_channels, n_levels=60):
        super().__init__(aggr="add")
        self.edge_channels = edge_channels
        self.base_channels = base_channels
        self.n_levels = n_levels
        # Kernel size 1 = a per-level linear map from
        # [neighbour channels | edge channels] to base_channels (no bias).
        self.adjacency_conv = nn.Conv1d(
            base_channels + edge_channels,
            base_channels,
            kernel_size=1,
            padding="same",
            bias=False,
        )

    def forward(self, x, edge_attr, edge_index):
        """Aggregate edge-conditioned messages from neighbours.

        Args:
            x: (num_nodes, base_channels * n_levels) flattened node features.
            edge_attr: (E, edge_channels) per-edge features.
            edge_index: (2, E) connectivity (fixed for this model).

        Returns:
            (num_nodes, base_channels * n_levels) summed messages.
        """
        return self.propagate(edge_index, x=x, edge_attr=edge_attr)

    def message(self, x_j, edge_attr):
        """Compute one message per edge.

        Args:
            x_j: (E, base_channels * n_levels) features of the source node of each edge.
            edge_attr: (E, edge_channels).
        """
        h = x_j.view(-1, self.base_channels, self.n_levels)
        # Broadcast edge features to every level. expand() returns a view, so
        # no (E, edge_channels, n_levels) copy is made before the concat.
        edge_feat = edge_attr.unsqueeze(-1).expand(-1, -1, self.n_levels)
        h = torch.cat([h, edge_feat], dim=-2)
        h = self.adjacency_conv(h)
        return h.flatten(start_dim=-2, end_dim=-1)


# Smoke test for Height60Conv on random inputs.
base_channels = 60
edge_channels = 4

# Flattened node features (384 nodes, base_channels * 60 levels) and random
# placeholder edge features, one row per column of the global `edge_index`
# (built earlier with self_loop=False). Statement order fixes the RNG stream.
x = torch.rand((384, base_channels * 60))
edge_attr = torch.rand((edge_index.shape[1], edge_channels))

model = Height60Conv(base_channels, edge_channels)

out = model(x, edge_attr, edge_index)

In [151]:
class GNN(nn.Module):
    """Stack of ``Height60Conv`` layers over the fixed 384-node graph.

    No normalisation or nonlinearity is applied between layers, and the
    convolutions have no bias, so the stack is a linear map of its input.
    """

    def __init__(self, base_channels, n_layers=4):
        super().__init__()
        # The graph is fixed. Registering it as non-persistent buffers makes it
        # follow model.to(device / dtype) without adding state_dict entries.
        edge_index = create_edge_index(self_loop=True)
        self.register_buffer("edge_index", edge_index, persistent=False)
        self.register_buffer(
            "edge_attr", create_edge_attr(edge_index).float(), persistent=False
        )
        self.base_channels = base_channels

        edge_channels = self.edge_attr.shape[-1]
        self.conv = nn.ModuleList(
            [Height60Conv(base_channels, edge_channels) for _ in range(n_layers)]
        )

    def forward(self, x):
        """Apply every layer in sequence.

        Args:
            x: (384, base_channels, 60) node features, one row per graph node.

        Returns:
            Tensor of shape (384, base_channels, 60).
        """
        edge_index = self.edge_index.to(x.device)
        edge_attr = self.edge_attr.to(x.device)
        # Height60Conv expects the height axis merged into the feature axis.
        x = x.flatten(start_dim=-2, end_dim=-1)
        for conv in self.conv:
            x = conv(x, edge_attr, edge_index)
        return x.view(-1, self.base_channels, 60)


# Shape check for the plain GNN stack.
base_channels = 32

# Random node features: (384 nodes, base_channels, 60 levels).
x = torch.rand((384, base_channels, 60))

model = GNN(base_channels)

# Displayed value; the recorded output is torch.Size([384, 32, 60]).
model(x).shape

torch.Size([384, 32, 60])

In [155]:
class GNN(nn.Module):
    """Stack of ``Height60Conv`` layers, each followed by LayerNorm and an activation.

    Operates on the fixed 384-node graph from ``create_edge_index`` /
    ``create_edge_attr``.
    """

    def __init__(self, base_channels, n_layers=4, activation="relu"):
        super().__init__()
        # The graph is fixed. Registering it as non-persistent buffers makes it
        # follow model.to(device / dtype) without adding state_dict entries.
        edge_index = create_edge_index(self_loop=True)
        self.register_buffer("edge_index", edge_index, persistent=False)
        self.register_buffer(
            "edge_attr", create_edge_attr(edge_index).float(), persistent=False
        )
        self.base_channels = base_channels
        self.activation = activation

        edge_channels = self.edge_attr.shape[-1]
        self.conv = nn.ModuleList(
            [Height60Conv(base_channels, edge_channels) for _ in range(n_layers)]
        )
        # LayerNorm over each node's flattened (channels x 60 levels) vector.
        self.norms = nn.ModuleList(
            [nn.LayerNorm(base_channels * 60) for _ in range(n_layers)]
        )
        self.activations = nn.ModuleList(
            [self.get_activation(activation) for _ in range(n_layers)]
        )

    def get_activation(self, activation):
        """Return a new activation module for ``activation``.

        Raises:
            ValueError: if the name is not "relu", "sigmoid" or "tanh".
        """
        if activation == "relu":
            return nn.ReLU()
        elif activation == "sigmoid":
            return nn.Sigmoid()
        elif activation == "tanh":
            return nn.Tanh()
        else:
            raise ValueError(f"Unsupported activation function: {activation}")

    def forward(self, x):
        """Apply conv -> LayerNorm -> activation for every layer.

        Args:
            x: (384, base_channels, 60) node features, one row per graph node.

        Returns:
            Tensor of shape (384, base_channels, 60).
        """
        edge_index = self.edge_index.to(x.device)
        edge_attr = self.edge_attr.to(x.device)

        # Merge the height axis into the feature axis; Height60Conv splits it
        # back out internally.
        x = x.flatten(start_dim=-2, end_dim=-1)
        for conv, norm, activation in zip(self.conv, self.norms, self.activations):
            x = conv(x, edge_attr, edge_index)
            x = norm(x)
            x = activation(x)
        # (384, base_channels * 60) -> (384, base_channels, 60)
        return x.view(-1, self.base_channels, 60)


# Shape check for the GNN with LayerNorm + activation.
base_channels = 32

# Random node features: (384 nodes, base_channels, 60 levels).
x = torch.rand((384, base_channels, 60))

model = GNN(base_channels)

# Displayed value; the recorded output is torch.Size([384, 32, 60]).
model(x).shape

torch.Size([384, 32, 60])

In [99]:
"""
サンプル
"""

import torch
from torch.nn import Linear, Parameter
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops, degree


class GCNConv(MessagePassing):
    """Graph convolution with symmetric degree normalisation.

    Computes ``out = D^{-1/2} (A + I) D^{-1/2} (X W) + b``. Here ``A + I`` is the
    input graph with self-loops added, ``D`` is the in-degree of that graph,
    ``W`` is a bias-free linear map and ``b`` is a learnable bias.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__(aggr="add")  # messages are summed at the target node
        self.lin = Linear(in_channels, out_channels, bias=False)
        self.bias = Parameter(torch.empty(out_channels))
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise the linear weight and zero the bias."""
        self.lin.reset_parameters()
        with torch.no_grad():
            self.bias.zero_()

    def forward(self, x, edge_index):
        """Args: x (N, in_channels); edge_index (2, E). Returns (N, out_channels)."""
        num_nodes = x.size(0)

        # Every node also sends a message to itself.
        edge_index, _ = add_self_loops(edge_index, num_nodes=num_nodes)
        src, dst = edge_index[0], edge_index[1]

        h = self.lin(x)

        # Per-edge weight 1 / sqrt(deg(src) * deg(dst)); isolated nodes get 0.
        inv_sqrt_deg = degree(dst, num_nodes, dtype=h.dtype).pow(-0.5)
        inv_sqrt_deg.masked_fill_(inv_sqrt_deg == float("inf"), 0)
        edge_weight = inv_sqrt_deg[src] * inv_sqrt_deg[dst]

        aggregated = self.propagate(edge_index, x=h, norm=edge_weight)
        return aggregated + self.bias

    def message(self, x_j, norm):
        """Scale each source feature row (E, out_channels) by its edge weight."""
        return x_j * norm.unsqueeze(-1)


# Smoke test for GCNConv: 5 input features -> 3 output features per node.
n_feat = 5
n_out = 3


# Model construction and torch.rand both draw from the RNG; order matters.
model = GCNConv(n_feat, n_out)

x = torch.rand((384, n_feat))

# Uses the global `edge_index` (built with self_loop=False); GCNConv adds its
# own self-loops. Displayed value; the recorded output is torch.Size([384, 3]).
model(x, edge_index).shape

torch.Size([384, 3])