In [41]:
import numpy as np
import torch
import xarray as xr


def create_board(base=0):
    """Return the 8x8 table of node indices for one face.

    Nodes are numbered in consecutive 2x2 groups (four nodes per group);
    the groups themselves are laid out left-to-right, then top-to-bottom.
    The table is finally mirrored top-to-bottom, so index 0 ends up in the
    bottom-left corner, and every entry is shifted by ``base * 64``.

    Args:
        base: Face number; offsets all indices by ``base * 64``.

    Returns:
        An (8, 8) integer ndarray of node indices.
    """
    rows, cols = np.indices((8, 8))
    # Which 2x2 group a cell belongs to (four groups per row of groups).
    group = (rows // 2) * 4 + cols // 2
    # Position inside the group: 0 1 on the first row, 2 3 on the second.
    within = (rows % 2) * 2 + cols % 2
    layout = group * 4 + within
    # A 180-degree rotation followed by a left-right flip is a vertical flip.
    return np.flipud(layout) + (base * 64)


def get_neighbors(face, i, j):
    """Return the values of the cells directly adjacent to ``face[i, j]``.

    Only neighbours that lie inside the 8x8 face are reported, in the order
    up, down, left, right.

    Args:
        face: 2-D array of node indices (indexed as ``face[row, col]``).
        i: Row of the cell.
        j: Column of the cell.

    Returns:
        A list with between two and four neighbouring values.
    """
    # (is the neighbour inside the face?, its coordinates)
    candidates = (
        (i > 0, (i - 1, j)),
        (i < 7, (i + 1, j)),
        (j > 0, (i, j - 1)),
        (j < 7, (i, j + 1)),
    )
    return [face[cell] for inside, cell in candidates if inside]


def get_boundary_neighbors(faces):
    """Collect, for every node, the nodes it touches across face boundaries.

    Args:
        faces: Sequence of six (8, 8) arrays of node indices, one per face.

    Returns:
        A dict mapping each node index in ``range(384)`` to a list of the
        node indices facing it on an adjoining face. Interior nodes map to
        an empty list. Every seam is listed from both of its faces, so each
        partner appears twice in a node's list.
    """
    # Seam table: (face A, edge of A, face B, edge of B, order).
    # "left"/"right" edges are walked top to bottom, "up"/"down" edges left
    # to right; "reverse" means face B's edge is walked the other way round.
    seams = [
        (0, "left", 3, "right", "same"),
        (0, "up", 5, "down", "same"),
        (0, "right", 1, "left", "same"),
        (0, "down", 4, "up", "same"),
        (1, "left", 0, "right", "same"),
        (1, "up", 5, "right", "reverse"),
        (1, "right", 2, "left", "same"),
        (1, "down", 4, "right", "same"),
        (2, "left", 1, "right", "same"),
        (2, "up", 5, "up", "reverse"),
        (2, "right", 3, "left", "same"),
        (2, "down", 4, "down", "reverse"),
        (3, "left", 2, "right", "same"),
        (3, "up", 5, "left", "same"),
        (3, "right", 0, "left", "same"),
        (3, "down", 4, "left", "reverse"),
        (4, "left", 3, "down", "reverse"),
        (4, "up", 0, "down", "same"),
        (4, "right", 1, "down", "same"),
        (4, "down", 2, "down", "reverse"),
        (5, "left", 3, "up", "same"),
        (5, "up", 2, "up", "reverse"),
        (5, "right", 1, "up", "reverse"),
        (5, "down", 0, "up", "same"),
    ]

    # Position along an edge -> (row, col) of the corresponding cell.
    cell_on_edge = {
        "left": lambda pos: (pos, 0),
        "right": lambda pos: (pos, 7),
        "up": lambda pos: (0, pos),
        "down": lambda pos: (7, pos),
    }

    across = dict((node, []) for node in range(384))

    for face_a, edge_a, face_b, edge_b, order in seams:
        flipped = order == "reverse"
        for pos in range(8):
            # Only face B's edge is ever traversed backwards.
            pos_b = 7 - pos if flipped else pos
            node_a = faces[face_a][cell_on_edge[edge_a](pos)]
            node_b = faces[face_b][cell_on_edge[edge_b](pos_b)]
            across[node_a].append(node_b)
            across[node_b].append(node_a)

    return across


def create_adjacency_matrix(self_loop=True, spectral_connection=True):
    """Build the 384x384 node adjacency matrix of the six-face grid.

    Two nodes are adjacent when they are direct row/column neighbours inside
    one 8x8 face, or when they face each other across a seam between two
    faces. Optionally, all four nodes of each spectral element (indices
    ``4k .. 4k+3``) are linked to one another.

    Args:
        self_loop: If True the diagonal is set to 1, otherwise to 0.
        spectral_connection: If True, fully connect the four nodes of each
            spectral element.

    Returns:
        A symmetric (384, 384) integer ndarray containing 0/1 entries.

    Raises:
        AssertionError: If the assembled grid graph fails its built-in
            sanity checks (known cross-face links, four neighbours per node).
    """
    N = 384
    adjacency_matrix = np.zeros((N, N), dtype=int)

    # One index table per face.
    faces = [create_board(base=face_no) for face_no in range(6)]

    # Links between neighbouring cells of the same face.
    for face in faces:
        for i in range(8):
            for j in range(8):
                idx = face[i, j]
                neighbors = get_neighbors(face, i, j)
                adjacency_matrix[idx, neighbors] = 1
                adjacency_matrix[neighbors, idx] = 1

    # Links across the seams between faces.
    for idx, neighbors in get_boundary_neighbors(faces).items():
        for neighbor in neighbors:
            adjacency_matrix[idx, neighbor] = 1
            adjacency_matrix[neighbor, idx] = 1

    # Spot-check a handful of links that are known to exist.
    expected_links = [
        (0, 205),
        (13, 64),
        (77, 128),
        (141, 192),
        (338, 250),
        (336, 251),
        (0, 306),
        (50, 320),
        (64, 319),
        (118, 349),
    ]
    for node_a, node_b in expected_links:
        assert (
            adjacency_matrix[node_a, node_b] == 1
        ), f"nodes {node_a} and {node_b} should be adjacent"

    # Every node must have exactly four grid neighbours at this point.
    assert adjacency_matrix.sum() == N * 4, "each node should have 4 neighbours"

    # Treat each group of four consecutive indices as one spectral element
    # whose nodes are all mutually connected. The block write also touches
    # the diagonal, which is overwritten unconditionally just below.
    if spectral_connection:
        for start in range(0, N, 4):
            adjacency_matrix[start:start + 4, start:start + 4] = 1

    np.fill_diagonal(adjacency_matrix, 1 if self_loop else 0)

    return adjacency_matrix


def create_edge_index(self_loop=True, spectral_connection=True):
    """Return the graph's edges as a ``(2, num_edges)`` long tensor.

    Row 0 holds the row indices and row 1 the column indices of the non-zero
    entries of the adjacency matrix built with the given options.
    """
    adjacency = create_adjacency_matrix(self_loop, spectral_connection)
    row_nodes, col_nodes = np.nonzero(adjacency)
    return torch.tensor(np.stack([row_nodes, col_nodes]), dtype=torch.long)


def create_edge_attr(
    edge_index,
    is_same_spectral=True,
    grid_path="/kaggle/working/misc/grid_info/ClimSim_low-res_grid-info.nc",
) -> torch.Tensor:
    """Compute per-edge features from latitude/longitude differences.

    For each edge ``(edge_index[0, k], edge_index[1, k])`` the latitude and
    longitude differences (first node minus second node, in radians) are
    encoded as sine and cosine.

    Args:
        edge_index: ``(2, num_edges)`` integer tensor (or array-like) of
            node indices into the grid file's ``lat`` / ``lon`` variables.
        is_same_spectral: If True (default), append a fifth column that is
            1.0 when both end points belong to the same spectral element
            (equal ``index // 4``) and 0.0 otherwise. If False the column is
            omitted.
        grid_path: Path of the grid-info dataset providing ``lat`` and
            ``lon`` in degrees.

    Returns:
        A ``(num_edges, 5)`` tensor with columns
        ``[sin(dlat), cos(dlat), sin(dlon), cos(dlon), same_element]``, or a
        ``(num_edges, 4)`` tensor when ``is_same_spectral`` is False.
    """
    edge_index = torch.as_tensor(edge_index)
    # Plain numpy index arrays, so indexing the coordinate arrays does not
    # depend on implicit tensor-to-array conversion.
    source = edge_index[0].detach().cpu().numpy()
    target = edge_index[1].detach().cpu().numpy()

    # The context manager closes the file once the coordinates are in memory.
    with xr.open_dataset(grid_path) as grid_info:
        latitude_radian = np.radians(grid_info["lat"].to_numpy())
        longitude_radian = np.radians(grid_info["lon"].to_numpy())

    lat_diff = latitude_radian[source] - latitude_radian[target]
    lon_diff = longitude_radian[source] - longitude_radian[target]

    columns = [
        np.sin(lat_diff),
        np.cos(lat_diff),
        np.sin(lon_diff),
        np.cos(lon_diff),
    ]
    if is_same_spectral:
        same_element = (source // 4) == (target // 4)
        # Cast explicitly so every column shares one floating dtype.
        columns.append(same_element.astype(lat_diff.dtype))

    return torch.from_numpy(np.stack(columns, axis=1))

In [42]:
# Build the adjacency matrix with the default options (self-loops and
# spectral-element links enabled); this also runs the function's built-in
# sanity assertions.
a = create_adjacency_matrix()

In [43]:
# Edge list as a (2, num_edges) long tensor: row 0 holds the row indices and
# row 1 the column indices of the non-zero adjacency entries.
edge = create_edge_index(spectral_connection=True)

In [44]:
# Print every edge as "<first node> <second node>" and count the edges.
cnt = 0
for i, e in enumerate(edge.numpy().T):
    print(e[0], e[1])
    cnt += 1

0 0
0 1
0 2
0 3
0 205
0 306
1 0
1 1
1 2
1 3
1 4
1 307
2 0
2 1
2 2
2 3
2 16
2 207
3 0
3 1
3 2
3 3
3 6
3 17
4 1
4 4
4 5
4 6
4 7
4 310
5 4
5 5
5 6
5 7
5 8
5 311
6 3
6 4
6 5
6 6
6 7
6 20
7 4
7 5
7 6
7 7
7 10
7 21
8 5
8 8
8 9
8 10
8 11
8 314
9 8
9 9
9 10
9 11
9 12
9 315
10 7
10 8
10 9
10 10
10 11
10 24
11 8
11 9
11 10
11 11
11 14
11 25
12 9
12 12
12 13
12 14
12 15
12 318
13 12
13 13
13 14
13 15
13 64
13 319
14 11
14 12
14 13
14 14
14 15
14 28
15 12
15 13
15 14
15 15
15 29
15 66
16 2
16 16
16 17
16 18
16 19
16 221
17 3
17 16
17 17
17 18
17 19
17 20
18 16
18 17
18 18
18 19
18 32
18 223
19 16
19 17
19 18
19 19
19 22
19 33
20 6
20 17
20 20
20 21
20 22
20 23
21 7
21 20
21 21
21 22
21 23
21 24
22 19
22 20
22 21
22 22
22 23
22 36
23 20
23 21
23 22
23 23
23 26
23 37
24 10
24 21
24 24
24 25
24 26
24 27
25 11
25 24
25 25
25 26
25 27
25 28
26 23
26 24
26 25
26 26
26 27
26 40
27 24
27 25
27 26
27 27
27 30
27 41
28 14
28 25
28 28
28 29
28 30
28 31
29 15
29 28
29 29
29 30
29 31
29 80
30 27
30 28
30 29
30

In [37]:
# Number of edges counted by the printing loop.
# NOTE(review): this cell's execution counter (37) is lower than the loop
# cell's (44), so the value displayed below may come from an earlier run.
cnt

2304