In [2]:
import numpy as np
import torch


def create_board(base=0):
    board = np.zeros((8, 8), dtype=int)
    num = 0
    for i in range(0, 8, 2):
        for j in range(0, 8, 2):
            board[i, j] = num
            board[i, j + 1] = num + 1
            board[i + 1, j] = num + 2
            board[i + 1, j + 1] = num + 3
            num += 4
    return np.fliplr(np.rot90(board, k=2)) + (base * 64)


def get_neighbors(face, i, j):
    neighbors = []
    if i > 0:
        neighbors.append(face[i - 1, j])
    if i < 7:
        neighbors.append(face[i + 1, j])
    if j > 0:
        neighbors.append(face[i, j - 1])
    if j < 7:
        neighbors.append(face[i, j + 1])
    return neighbors


def get_boundary_neighbors(faces):
    boundary_neighbors = {i: [] for i in range(384)}

    # 境界の接続を定義（面1、エッジ1、面2、エッジ2、順序）
    # left ↓、up→, right ↓, down →
    connections = [
        (0, "left", 3, "right", "same"),
        (0, "up", 5, "down", "same"),
        (0, "right", 1, "left", "same"),
        (0, "down", 4, "up", "same"),
        (1, "left", 0, "right", "same"),
        (1, "up", 5, "right", "reverse"),
        (1, "right", 2, "left", "same"),
        (1, "down", 4, "right", "same"),
        (2, "left", 1, "right", "same"),
        (2, "up", 5, "up", "reverse"),
        (2, "right", 3, "left", "same"),
        (2, "down", 4, "down", "reverse"),
        (3, "left", 2, "right", "same"),
        (3, "up", 5, "left", "same"),
        (3, "right", 0, "left", "same"),
        (3, "down", 4, "left", "reverse"),
        (4, "left", 3, "down", "reverse"),
        (4, "up", 0, "down", "same"),
        (4, "right", 1, "down", "same"),
        (4, "down", 2, "down", "reverse"),
        (5, "left", 3, "up", "same"),
        (5, "up", 2, "up", "reverse"),
        (5, "right", 1, "up", "reverse"),
        (5, "down", 0, "up", "same"),
    ]

    def get_edge_coord(index, edge, reverse):
        if edge == "left":
            coord = (index, 0)
        elif edge == "right":
            coord = (index, 7)
        elif edge == "up":
            coord = (0, index)
        elif edge == "down":
            coord = (7, index)
        if reverse:
            if edge in ["left", "right"]:
                coord = (7 - index, coord[1])
            elif edge in ["up", "down"]:
                coord = (coord[0], 7 - index)
        return coord

    for face1, edge1, face2, edge2, order in connections:
        for i in range(8):
            coord1 = get_edge_coord(i, edge1, False)
            coord2 = get_edge_coord(i, edge2, order == "reverse")
            boundary_neighbors[faces[face1][coord1]].append(faces[face2][coord2])
            boundary_neighbors[faces[face2][coord2]].append(faces[face1][coord1])

    return boundary_neighbors


def create_adjacency_matrix(self_loop=True, spectral_connection=False):
    # 隣接行列の作成
    N = 384
    adjacency_matrix = np.zeros((N, N), dtype=int)

    # 各面の生成
    faces = [create_board(base=i) for i in range(6)]

    # 各面内の隣接を追加
    for face in faces:
        for i in range(8):
            for j in range(8):
                idx = face[i, j]
                neighbors = get_neighbors(face, i, j)
                for neighbor in neighbors:
                    adjacency_matrix[idx, neighbor] = 1
                    adjacency_matrix[neighbor, idx] = 1

    # 境界の隣接を追加
    boundary_neighbors = get_boundary_neighbors(faces)
    for idx, neighbors in boundary_neighbors.items():
        for neighbor in neighbors:
            adjacency_matrix[idx, neighbor] = 1
            adjacency_matrix[neighbor, idx] = 1

    assert adjacency_matrix[0, 205] == 1
    assert adjacency_matrix[13, 64] == 1
    assert adjacency_matrix[77, 128] == 1
    assert adjacency_matrix[141, 192] == 1
    assert adjacency_matrix[338, 250] == 1
    assert adjacency_matrix[336, 251] == 1
    assert adjacency_matrix[0, 306] == 1
    assert adjacency_matrix[50, 320] == 1
    assert adjacency_matrix[64, 319] == 1
    assert adjacency_matrix[118, 349] == 1

    assert adjacency_matrix.sum() == 384 * 4

    # 4つずつのまとまりについては spectral element 内で結合しているとみなす
    if spectral_connection:
        for i in range(384):
            for j in range(i + 1, 384):
                if i // 4 == j // 4:
                    adjacency_matrix[i, j] = 1
                    adjacency_matrix[j, i] = 1

    for i in range(384):
        if self_loop:
            adjacency_matrix[i][i] = 1
        else:
            adjacency_matrix[i][i] = 0

    return adjacency_matrix

In [26]:
def create_edge_index(self_loop=True, spectral_connection=True):

    matrix = create_adjacency_matrix(self_loop=False, spectral_connection=False)

    base_edge_index = np.array(np.nonzero(matrix))

    node_i_list = []
    node_j_list = []

    for ci in range(384):
        for h in range(1, 60):
            node_i_list.append(ci + (h - 1) * 384)
            node_j_list.append(ci + (h) * 384)

            node_i_list.append(ci + (h) * 384)
            node_j_list.append(ci + (h - 1) * 384)

    for i in range(base_edge_index.shape[1]):
        ci = base_edge_index[0, i]
        cj = base_edge_index[1, i]
        for h in range(60):
            node_i_list.append(ci + h * 384)
            node_j_list.append(cj + h * 384)

            node_i_list.append(cj + h * 384)
            node_j_list.append(ci + h * 384)

    if self_loop:
        for ci in range(384):
            for h in range(60):
                node_i_list.append(ci + h * 384)
                node_j_list.append(ci + h * 384)
    edge_index = np.stack([node_i_list, node_j_list], axis=0)
    edge_index = torch.tensor(edge_index, dtype=torch.long)

    # self loop 無し：229632
    return edge_index

In [27]:
edge_index = create_edge_index(self_loop=True)
edge_index.shape

torch.Size([2, 252672])

In [33]:
edge_index

tensor([[    0,   384,   384,  ..., 22271, 22655, 23039],
        [  384,     0,   768,  ..., 22271, 22655, 23039]])

In [46]:
import xarray as xr


def create_edge_attr(edge_index) -> torch.tensor:
    """
    output: (len(edge), 5)
    lat,lonの差、高低差をsin,cosで表現して入力
    """
    grid_path = "/kaggle/working/misc/grid_info/ClimSim_low-res_grid-info.nc"
    grid_info = xr.open_dataset(grid_path)
    latitude = grid_info["lat"].to_numpy()
    longitude = grid_info["lon"].to_numpy()
    latitude_radian = np.radians(latitude)
    longitude_radian = np.radians(longitude)

    lat_diff = (
        latitude_radian[edge_index[0, :] % 384]
        - latitude_radian[edge_index[1, :] % 384]
    )
    lon_diff = (
        longitude_radian[edge_index[0, :] % 384]
        - longitude_radian[edge_index[1, :] % 384]
    )
    h_diff = edge_index[0, :] // 384 - edge_index[1, :] // 384

    lat_diff_sin = torch.tensor(np.sin(lat_diff))
    lat_diff_cos = torch.tensor(np.cos(lat_diff))
    lon_diff_sin = torch.tensor(np.sin(lon_diff))
    lon_diff_cos = torch.tensor(np.cos(lon_diff))
    edge_attr = torch.stack(
        [
            lat_diff_sin,
            lat_diff_cos,
            lon_diff_sin,
            lon_diff_cos,
            h_diff,
        ],
        dim=1,
    )
    return edge_attr

In [47]:
edge_attr = create_edge_attr(edge_index)

In [49]:
edge_attr[384 * 59 * 2 - 2 :]

tensor([[ 0.0000,  1.0000,  0.0000,  1.0000, -1.0000],
        [ 0.0000,  1.0000,  0.0000,  1.0000,  1.0000],
        [ 0.0594,  0.9982, -0.1953,  0.9808,  0.0000],
        ...,
        [ 0.0000,  1.0000,  0.0000,  1.0000,  0.0000],
        [ 0.0000,  1.0000,  0.0000,  1.0000,  0.0000],
        [ 0.0000,  1.0000,  0.0000,  1.0000,  0.0000]], dtype=torch.float64)

In [53]:
cnt = 0
for i in range(edge_index.shape[1]):

    if (edge_index[0, i] < 384 * 2) and edge_index[1, i] < 384 * 2:
        print(edge_index[0, i].detach().item(), edge_index[1, i].detach().item())
        cnt += 1

0 384
384 0
1 385
385 1
2 386
386 2
3 387
387 3
4 388
388 4
5 389
389 5
6 390
390 6
7 391
391 7
8 392
392 8
9 393
393 9
10 394
394 10
11 395
395 11
12 396
396 12
13 397
397 13
14 398
398 14
15 399
399 15
16 400
400 16
17 401
401 17
18 402
402 18
19 403
403 19
20 404
404 20
21 405
405 21
22 406
406 22
23 407
407 23
24 408
408 24
25 409
409 25
26 410
410 26
27 411
411 27
28 412
412 28
29 413
413 29
30 414
414 30
31 415
415 31
32 416
416 32
33 417
417 33
34 418
418 34
35 419
419 35
36 420
420 36
37 421
421 37
38 422
422 38
39 423
423 39
40 424
424 40
41 425
425 41
42 426
426 42
43 427
427 43
44 428
428 44
45 429
429 45
46 430
430 46
47 431
431 47
48 432
432 48
49 433
433 49
50 434
434 50
51 435
435 51
52 436
436 52
53 437
437 53
54 438
438 54
55 439
439 55
56 440
440 56
57 441
441 57
58 442
442 58
59 443
443 59
60 444
444 60
61 445
445 61
62 446
446 62
63 447
447 63
64 448
448 64
65 449
449 65
66 450
450 66
67 451
451 67
68 452
452 68
69 453
453 69
70 454
454 70
71 455
455 71
72 456
456 7

In [54]:
cnt

7680

In [56]:
torch.arange(384 * 60).reshape(384, 60).flatten(0, 1)

tensor([    0,     1,     2,  ..., 23037, 23038, 23039])