In [2]:
import torch

In [6]:
# One-element tensor holding 1e-7, built with the torch.tensor factory
# (the legacy torch.Tensor constructor is discouraged for creating tensors from data).
a = torch.tensor([1e-7])

In [10]:
# Cast `a` to half precision (float16); the recorded output shows the stored
# value becomes 1.1921e-07.
a.half()

tensor([1.1921e-07], dtype=torch.float16)

In [5]:
# Machine epsilon of float32.
torch.finfo(torch.float32).eps

1.1920928955078125e-07

# Introduction

GraphCast 最大的数据结构是构建了一个「均匀」的三角形mesh网络, 网络的每个node都可以对应到球面的一个点, 所以可以用$(\theta,\phi)$来简单的表示每个node。
当然在数据结构中，我们直接用 $node_0,node_1,\dots,node_N$ 来顺序存储所有 node
有几个难点:
- 我们需要记录每个 node 的近邻关系，在 GraphCast 中定义了 $M_0,M_1,\dots,M_6$ 七个 level，且是 5 近邻结构。所以每个 level 都是一个 $(N,5)$ 的 table，每个元素表示的是紧邻的点（由于我们最后计算的是综合的 aggregate，所以这里的顺序没有关系）。注意：细分之后只有最初的 12 个顶点是 5 近邻，其余顶点有 6 个近邻，所以后面的代码会用 -1 把较短的行补齐。
- mesh 的方式是 icosahedral （二十面体）以及后续的三角划分，这种方法叫做 Geodesic_polyhedron [https://en.wikipedia.org/wiki/Geodesic_polyhedron]

所以第一步就是建立球面上的 二十面体 $M_0$ 以及接下来所有 $M_i$ 的 node， 用$(\theta,\phi)$来标记。

- 我们可以用 `anti_lib_progs` 这个包来生成坐标，但是它不提供邻近关系。
- 我们也可以用 Mathematica 13.0 来做
```
(* GetNearByPoint[polygons] builds an Association that maps every vertex index
   appearing in `polygons` (a list of faces, each a list of vertex indices) to
   the sorted, duplicate-free list of vertices that share a face with it.
   All working symbols are localized so a call leaves no global definitions
   behind, and faces may have any number of vertices. *)
GetNearByPoint[polygons_] := Module[
  {pool = Association[], polygon, keyvalue, key, val},
  Do[
   polygon = polygons[[i]];
   Do[
    (* Split the face into its j-th vertex and the remaining vertices. *)
    keyvalue = TakeDrop[polygon, {j}];
    key = keyvalue[[1]][[1]];
    val = keyvalue[[2]];
    If[! KeyExistsQ[pool, key], pool[key] = {}];
    (* Union both merges and sorts, so each neighbour is stored once. *)
    pool[key] = Union[Join[pool[key], val]],
    {j, Length[polygon]}],
   {i, Length[polygons]}];
  pool
  ]
(* GetNearByPoint[polygons, position] is the coordinate-keyed variant: vertex
   indices found in `polygons` are looked up in `position`, so the returned
   Association maps each vertex coordinate to the sorted, duplicate-free list
   of coordinates of the vertices sharing a face with it.
   All working symbols are localized so a call leaves no global definitions
   behind, and faces may have any number of vertices. *)
GetNearByPoint[polygons_, position_] :=
 Module[{pool = Association[], polygon, keyvalue, key, val},
  Do[
   polygon = polygons[[i]];
   Do[
    (* Split the face into its j-th vertex and the remaining vertices,
       then translate the indices into coordinates. *)
    keyvalue = TakeDrop[polygon, {j}];
    key = position[[keyvalue[[1]][[1]]]];
    val = position[[keyvalue[[2]]]];
    If[! KeyExistsQ[pool, key], pool[key] = {}];
    pool[key] = Union[Join[pool[key], val]],
    {j, Length[polygon]}],
   {i, Length[polygons]}];
  pool
  ]
(* Build the geodesic polyhedron derived from the icosahedron with n = 1. *)
polygons = GeodesicPolyhedron["Icosahedron", 1];
(* Neighbour table keyed by coordinates: part 2 of the result is passed as the
   face list and part 1 as the vertex coordinates, rounded to 4 decimals. *)
GetNearByPoint[polygons[[2]], Round[polygons[[1]], 0.0001]]
(* Render the polyhedron. *)
Graphics3D[polygons]
```
这里我们可以发现， 论文中使用的setting是
`GeodesicPolyhedron["Icosahedron", n];` $n\in\{1,2,4,8,16,32,64\}$

对应的 node数为 $[12,42,162,642,2562,10242,40962]$

所以还是相当大的一个Mesh, 但他实际上还是取决于分辨率的。
- 比如 32x64 那么对应的 node 为 2048 对于 M5
- 比如 64x128 那么对应的 node 为 8192 对应于 M6
- 比如 720x1440 那么对应的 node 为 1036800 这是 25 倍的 M7

### Create Mesh-Node

我们这里直接用 Mathematica 中生成需要的所有 node

In [1]:
import csv

In [11]:
import numpy as np

In [None]:
# Scratch cell: a bare string literal naming one of the per-level CSV files;
# evaluating it has no effect beyond echoing the string.
'GraphCastStructure/M1.csv'

In [13]:
def readMx(path):
    """Load one mesh level's neighbour table from a two-column CSV file.

    Each row holds a key and a value written as brace-delimited lists
    (e.g. ``{0.1, 0.2, 0.3}``). Braces are rewritten to parentheses and the
    result is parsed as a Python literal, so lists become (nested) tuples.

    Args:
        path: Path of the CSV file to read.

    Returns:
        dict mapping each parsed key to its parsed value, in file order.

    Raises:
        ValueError: if a non-empty row does not have exactly two columns, or
            a cell is not a plain literal.
        SyntaxError: if a cell is not syntactically valid after the brace
            rewrite.
    """
    # Local import keeps this cell self-contained.
    import ast

    braces_to_parens = str.maketrans("{}", "()")
    pool = {}
    # newline='' is what the csv module expects for files it reads.
    with open(path, 'r', newline='') as csvfile:
        for row in csv.reader(csvfile):
            if not row:
                # csv.reader yields [] for blank lines; nothing to parse.
                continue
            key, val = row
            # literal_eval only accepts literals, so file contents can never
            # execute code the way eval() would.
            key = ast.literal_eval(key.translate(braces_to_parens))
            val = ast.literal_eval(val.translate(braces_to_parens))
            pool[key] = val
    return pool

In [20]:
# Load the seven per-level tables M1.csv ... M7.csv (i+1 for i in 0..6);
# each entry is the dict returned by readMx for that level.
Mpoolist = [readMx(f'GraphCastStructure/M{i+1}.csv') for i in range(7)]

In [1]:
import torch

In [45]:

# Scratch experiment: per the recorded traceback, torch.linalg.norm rejects a
# `dim` of length 3 (it must be of length 1 or 2), so this cell raises.
# NOTE(review): torch.linalg.vector_norm may be the intended call — confirm,
# or drop this cell.
torch.linalg.norm(a,dim=(1,2,3))

RuntimeError: linalg.norm: If dim is specified, it must be of length 1 or 2. Got [1, 2, 3]

In [58]:
# Scratch data: random tensor of shape (2, 3, 4, 5); this rebinds `a`.
a = torch.randn(2,3,4,5)

In [60]:
# Inspect the shape of `a`.
a.shape

torch.Size([2, 3, 4, 5])

In [61]:
# Sum of squared entries over the last three axes, one value per leading (B) index.
# NOTE(review): the recorded output is exactly 1.0 for both rows, which plain
# randn data would not be expected to give — the output looks stale; re-run to confirm.
torch.einsum('Bijk,Bijk->B',a,a)

tensor([1.0000, 1.0000])

In [21]:
# Give every distinct vertex position a consecutive integer id, in first-seen
# order while walking the levels of Mpoolist from first to last.
position2node = {}
for level_pool in Mpoolist:
    for position in level_pool:
        # setdefault leaves ids that were already assigned untouched.
        position2node.setdefault(position, len(position2node))

In [75]:
# Leftover scratch cell: evaluates the built-in `enumerate` and does nothing else.
enumerate

In [76]:
def save_json(path,_dict):
    """Serialize `_dict` as JSON into the file at `path`, overwriting it.

    Args:
        path: Destination file path.
        _dict: JSON-serializable object to write.

    Raises:
        TypeError: if `_dict` contains something json cannot encode
            (for example tuple dictionary keys).
    """
    # Imported here because no cell of the notebook imports json.
    import json

    with open(path,'w') as f:
        json.dump(_dict,f)

In [78]:
import torch

In [79]:
# Persist the position -> node-id mapping.
# NOTE: despite the ".json" in the file name, this is written with torch.save, not the json module.
torch.save(position2node,"GraphCastStructure/position2node.json.pt")

In [25]:
# Inverse lookup table: row `node_id` holds the 3-D position of that node.
# The row count follows position2node instead of a hard-coded 40962, so the
# cell also works when a different number of positions was loaded.
node2position = np.zeros((len(position2node), 3))
for key, val in position2node.items():
    node2position[val] = np.array(key)

In [81]:
# Persist the node-id -> position table as a .npy file.
np.save("GraphCastStructure/node2position.npy",node2position)

In [40]:
# For every level, translate the position-keyed neighbour table into a dict
# {node_id: [neighbour node_ids]} using the ids stored in position2node.
node2nearby_por_level = []
for level_pool in Mpoolist:
    level_table = {}
    for position, neighbour_positions in level_pool.items():
        center_id = position2node[position]
        neighbour_ids = [position2node[p] for p in neighbour_positions]
        # Create the list on first sight of the node, then extend it in place.
        level_table.setdefault(center_id, []).extend(neighbour_ids)
    node2nearby_por_level.append(level_table)

In [84]:
# Persist the per-level {node_id: [neighbour ids]} dictionaries with torch.save.
torch.save(node2nearby_por_level,"GraphCastStructure/node2nearby_per_level.json.pt")

In [67]:
# For every level build a pair [key_nodes, nearby_table]:
#   key_nodes    - 1-D array of node ids, in the dict's iteration order
#   nearby_table - 2-D array whose i-th row lists the neighbours of key_nodes[i]
# Rows shorter than the longest one are right-padded with -1 up to the full
# width, however many entries are missing, so the table is always rectangular.
key_nearby_pair_per_level = []
for node2nearby in node2nearby_por_level:
    max_length = max(len(neighbours) for neighbours in node2nearby.values())
    key_nodes = np.array(list(node2nearby.keys()))
    nearby_table = np.array([
        neighbours + [-1] * (max_length - len(neighbours))
        for neighbours in node2nearby.values()
    ])
    key_nearby_pair_per_level.append([key_nodes, nearby_table])
    

In [85]:
# Persist the per-level [key_nodes, neighbour table] pairs with torch.save.
torch.save(key_nearby_pair_per_level,"GraphCastStructure/keynodes2nearbynodes_pair_per_level.json.pt")

In [64]:
# Number every undirected edge (smaller id, larger id) in first-seen order,
# walking the levels and their neighbour tables; -1 entries are padding and
# never form an edge.
edge2id = {}
for key_nodes, node2nearby in key_nearby_pair_per_level:
    for key_node, nearby_nodes in zip(key_nodes, node2nearby):
        if key_node == -1:
            continue
        for nearby_node in nearby_nodes:
            if nearby_node == -1:
                continue
            edge = tuple(sorted((key_node, nearby_node)))
            # setdefault keeps the id of an edge that was already numbered.
            edge2id.setdefault(edge, len(edge2id))
                

In [86]:
# Persist the (node, node) -> edge-id mapping with torch.save.
torch.save(edge2id,"GraphCastStructure/edge2id.json.pt")

In [65]:
# Inverse of edge2id: row `edge_index` holds the two node ids of that edge.
edgeid2pair = np.zeros((len(edge2id), 2), dtype='int')
for node_pair, edge_index in edge2id.items():
    edgeid2pair[edge_index] = node_pair

In [88]:
# Persist the edge-id -> node-pair table as a .npy file.
np.save("GraphCastStructure/edgeid2pair.npy",edgeid2pair)

In [71]:
# Same layout as key_nearby_pair_per_level, but every neighbour entry is
# replaced by the id of the edge joining it to the key node; padding (-1) on
# either side stays -1.
key_nearbyedge_pair_per_level = []
for key_nodes, node2nearby in key_nearby_pair_per_level:
    edge_rows = []
    for key_node, nearby_nodes in zip(key_nodes, node2nearby):
        edge_rows.append([
            -1 if (key_node == -1 or nearby_node == -1)
            else edge2id[tuple(sorted((key_node, nearby_node)))]
            for nearby_node in nearby_nodes
        ])
    key_nearbyedge_pair_per_level.append([key_nodes, np.array(edge_rows)])

In [89]:
# Persist the per-level [key_nodes, neighbour-edge table] pairs with torch.save.
torch.save(key_nearbyedge_pair_per_level,"GraphCastStructure/keynodes2nearbyedges_pair_per_level.json.pt")

然后我们计算 grid 和 node 之间的聚合关系

我们用 721x1440 作为例子, 先用 (721, 1440, 3) 来表示出每个位置的 pos

In [90]:
# Grid axes in degrees for the 721 x 1440 grid.
# Latitude runs from +90 to -90 inclusive: 721 points, 0.25 degree apart.
latitude   = np.linspace(90,-90,721)
# Longitude is periodic: 0 and 360 are the same meridian, so the end point is
# excluded. This yields 1440 distinct meridians exactly 0.25 degree apart
# (0 ... 359.75) instead of a duplicated meridian and a 360/1439 spacing.
longitude  = np.linspace(0,360,1440,endpoint=False)

In [100]:
# NOTE: np.meshgrid defaults to 'xy' indexing, so both outputs have shape
# (len(longitude), len(latitude)); `x` carries the latitude values and `y`
# the longitude values.
x, y  = np.meshgrid(latitude, longitude)

In [131]:
# Stack (longitude, latitude) pairs, swap the first two axes so latitude
# indexes the rows, and convert degrees to radians.
# Afterwards [..., 0] is longitude and [..., 1] is latitude.
LaLotude = np.stack([y,x],-1).transpose(1,0,2)/180*np.pi # shape (n_lat, n_lon, 2)

In [184]:
# Convert every (longitude, latitude) pair, in radians, into a 3-D unit
# vector: (cos(lat)cos(lon), cos(lat)sin(lon), sin(lat)).
_lon = LaLotude[..., 0]
_lat = LaLotude[..., 1]
_cos_lat = np.cos(_lat)
LaLotudeVector = np.stack(
    [_cos_lat * np.cos(_lon), _cos_lat * np.sin(_lon), np.sin(_lat)],
    2,
)

In [200]:
# Drop the first and the last 180 rows along the leading (latitude) axis.
# NOTE(review): this rebinds the same name, so re-running the cell trims the
# array again — run it exactly once after the cell that builds the vectors.
LaLotudeVector = LaLotudeVector[180:-180]

In [201]:
# Flatten to a 2-D array with one 3-vector per row.
LaLotudeVector = LaLotudeVector.reshape(-1,3)

In [157]:
# Length scale used by the grid-to-node lookups as the half-width of the
# coarse, axis-aligned search box.
# NOTE(review): the origin of this constant is not shown in the notebook —
# presumably a typical mesh edge length; confirm and document.
unit = 0.014101772938180504
# Fine search radius: 60% of `unit`.
radius = 0.6*unit

In [202]:
# Probe one grid point: indices of the mesh nodes that lie strictly inside
# the axis-aligned cube of half-width `unit` centred on it.
grid_pos = LaLotudeVector[10000]
_below_upper = (node2position < (grid_pos + unit)).all(axis=1)
_above_lower = (node2position > (grid_pos - unit)).all(axis=1)
np.nonzero(_below_upper & _above_lower)

(array([ 4146, 16482, 16483, 16545]),)

In [205]:
from tqdm.notebook import tqdm

In [206]:
# For every grid point, collect the ids of the mesh nodes whose Euclidean
# distance to it is below `radius` (= 0.6 * unit).
# tqdm wraps the array itself (not an enumerate object) so it knows the
# length and can show a real progress bar with a total.
hard_indexes = []
for grid_pos in tqdm(LaLotudeVector):
    # Cheap pre-filter: nodes strictly inside the axis-aligned cube of
    # half-width `unit` around the grid point.
    in_box = np.logical_and((node2position < (grid_pos + unit)).all(axis=1),
                            (node2position > (grid_pos - unit)).all(axis=1))
    rough_range_index = np.where(in_box)[0]
    rough_range = node2position[rough_range_index]
    # Exact test on the few candidates that survived the pre-filter.
    hard_index = np.where(np.linalg.norm(rough_range - grid_pos, axis=1) < radius)
    hard_indexes.append(rough_range_index[hard_index])
    

0it [00:00, ?it/s]

In [207]:
# Number of mesh nodes matched to each grid point.
aware = list(map(len, hard_indexes))

In [208]:
from mltool.visualization import *