# Dataset

### Our Implementation

In [1]:
from cephdataset import *

In [2]:
# Single-level (surface) variables, in the order they are listed below.
single_vnames = [
    "2m_temperature",
    "10m_u_component_of_wind",
    "10m_v_component_of_wind",
    "total_cloud_cover",
    "total_precipitation",
    "toa_incident_solar_radiation",
]

# Pressure-level variables, variable-major: every physical quantity contributes
# one entry per pressure level before the next quantity starts.
level_vnames = [
    f"{pressure_level}hPa_{physics_name}"
    for physics_name in (
        "geopotential",
        "temperature",
        "specific_humidity",
        "relative_humidity",
        "u_component_of_wind",
        "v_component_of_wind",
        "vorticity",
        "potential_vorticity",
    )
    for pressure_level in (50, 100, 150, 200, 250, 300, 400, 500, 600, 700, 850, 925, 1000)
]

# Full channel list: surface variables first, then the pressure-level ones.
all_vnames = single_vnames + level_vnames

In [5]:
# Show every channel name next to its position in the channel list.
for position, variable_name in enumerate(all_vnames):
    print(f"[{position:3d}]: {variable_name}")

[  0]: 2m_temperature
[  1]: 10m_u_component_of_wind
[  2]: 10m_v_component_of_wind
[  3]: total_cloud_cover
[  4]: total_precipitation
[  5]: toa_incident_solar_radiation
[  6]: 50hPa_geopotential
[  7]: 100hPa_geopotential
[  8]: 150hPa_geopotential
[  9]: 200hPa_geopotential
[ 10]: 250hPa_geopotential
[ 11]: 300hPa_geopotential
[ 12]: 400hPa_geopotential
[ 13]: 500hPa_geopotential
[ 14]: 600hPa_geopotential
[ 15]: 700hPa_geopotential
[ 16]: 850hPa_geopotential
[ 17]: 925hPa_geopotential
[ 18]: 1000hPa_geopotential
[ 19]: 50hPa_temperature
[ 20]: 100hPa_temperature
[ 21]: 150hPa_temperature
[ 22]: 200hPa_temperature
[ 23]: 250hPa_temperature
[ 24]: 300hPa_temperature
[ 25]: 400hPa_temperature
[ 26]: 500hPa_temperature
[ 27]: 600hPa_temperature
[ 28]: 700hPa_temperature
[ 29]: 850hPa_temperature
[ 30]: 925hPa_temperature
[ 31]: 1000hPa_temperature
[ 32]: 50hPa_specific_humidity
[ 33]: 100hPa_specific_humidity
[ 34]: 150hPa_specific_humidity
[ 35]: 200hPa_specific_humidity
[ 36]: 250hP

In [47]:
class WeathBench32x64CK(WeathBench):
    """WeathBench variant restricted to the 69-channel ``SWINRNN69`` selection.

    Only what is overridden here is documented; attributes such as ``root``,
    ``single_data_path_list``, ``channel_choice``, ``mean``, ``std``,
    ``mean_std``, ``constants`` and ``client`` are expected to be provided by
    the ``WeathBench`` base class (not visible here — TODO confirm).
    """

    default_root = 'datasets/weatherbench32x64'

    def config_pool_initial(self):
        """Return the dataset-flag -> configuration table for this dataset.

        Side effect: sets ``self.constant_index`` to the rows of
        ``self.constants`` that ``get_item`` prepends to every sample.
        """
        # Channel rows picked from each snapshot, in output order (69 in total).
        # NOTE(review): if the files follow the `all_vnames` listing above, this
        # is u10, v10, t2m, tp followed by the 13 levels of z, t, r, u, v — confirm.
        CK_order = [1, 2, 0, 4] + list(range(6, 32)) + list(range(45, 84))
        config_pool = {
            'SWINRNN69': (
                CK_order,
                'gauss_norm',
                # Per-channel statistics reshaped so they broadcast over (C, H, W).
                self.mean_std[:, CK_order].reshape(2, len(CK_order), 1, 1),
                identity,
                identity,
            ),
        }
        self.constant_index = [0, 2]
        return config_pool

    def load_numpy_from_url(self, url):
        """Load a ``.npy`` array from a local path or an ``s3://`` URL.

        The saved arrays are regular ``.npy`` files, so local paths go straight
        through ``np.load``; ``s3://`` URLs are fetched through a lazily
        created ``Client`` and decoded from memory.
        """
        if "s3://" in url:
            # Tolerate a base class that never created the attribute.
            if getattr(self, "client", None) is None:
                self.client = Client(conf_path="~/petreloss.conf")
            with io.BytesIO(self.client.get(url)) as f:
                array = np.load(f)
        else:
            array = np.load(url)
        return array

    def get_item(self, idx, reversed_part=False):
        """Return one normalised snapshot with the constant fields prepended.

        Args:
            idx: Position in ``self.single_data_path_list``.
            reversed_part: Accepted for interface compatibility; unused here.

        Returns:
            Array made of ``self.constants[self.constant_index]`` followed by
            the selected, normalised channels (concatenated on the first axis).
        """
        year, hour = self.single_data_path_list[idx]
        url = f"{self.root}/{year}/{year}-{hour:04d}.npy"
        print(url)  # debug trace; the comparison cells below show these lines
        # Go through the shared loader so that s3:// roots work as well as
        # local ones (a plain np.load cannot open remote URLs).
        odata = self.load_numpy_from_url(url)
        data = odata[self.channel_choice]
        data = (data - self.mean) / self.std
        cons = self.constants[self.constant_index]
        return np.concatenate([cons, data])

In [79]:
# Instance of our implementation used for the comparison cells further down.
# `time_intervel` is spelled the way the constructor keyword is passed here;
# NOTE(review): this uses the 'valid' split while the reference dataset below is
# built with split='test', yet both print 2016 files — presumably the two
# libraries name their splits differently; confirm.
dataset2 = WeathBench32x64CK('valid',dataset_flag='SWINRNN69',time_intervel=6)

use dataset in datasets/weatherbench32x64


### Others

In [55]:
from torch.utils.data import Dataset
from tqdm import tqdm
import numpy as np
import io
import torch
import time

# Calendar years assigned to every split name.
Years = dict(
    train=range(1979, 2016),
    valid=range(2018, 2019),
    test=range(2016, 2018),
    all=range(1979, 2019),
)

# Short names of the pressure-level variables.
multi_level_vnames = ["z", "t", "q", "r", "u", "v", "vo", "pv"]

# Short names of the single-level (surface) variables.
single_level_vnames = ["t2m", "u10", "v10", "tcc", "tp", "tisr"]

# Long variable name -> short name.
long2shortname_dict = {
    # pressure-level variables
    "geopotential": "z",
    "temperature": "t",
    "specific_humidity": "q",
    "relative_humidity": "r",
    "u_component_of_wind": "u",
    "v_component_of_wind": "v",
    "vorticity": "vo",
    "potential_vorticity": "pv",
    # single-level variables
    "2m_temperature": "t2m",
    "10m_u_component_of_wind": "u10",
    "10m_v_component_of_wind": "v10",
    "total_cloud_cover": "tcc",
    "total_precipitation": "tp",
    "toa_incident_solar_radiation": "tisr",
}

# Names of the static fields.
constants = ["lsm", "slt", "orography"]

# Pressure levels available for every multi-level variable.
height_level = [50, 100, 150, 200, 250, 300, 400, 500, 600, 700, 850, 925, 1000]

# Example variable -> levels mapping; every entry shares the `height_level` list.
multi_level_dict_param = {name: height_level for name in ("z", "t", "q", "r")}



class WeatherBench(Dataset):
    """Sliding-window dataset over per-hour WeatherBench ``.npy`` snapshots.

    Snapshots are read from ``{data_dir}/{year}/{year}-{hour:04d}.npy``, the
    normalisation statistics from ``{data_dir}/mean_std.npy`` and the optional
    static fields from ``{data_dir}/constants.npy``.  One sample is a list of
    ``length`` tensors whose source files lie ``sample_stride`` entries apart
    in the file list.

    Args:
        data_dir: Local directory or ``s3://`` prefix containing the data.
        split: Key of the module-level ``Years`` table selecting the years.
        **kwargs: Optional settings:
            use_mem (bool): load and normalise every file up front.
            length (int): number of frames per sample.
            file_stride (int): step, in hourly files, between file-list entries.
            sample_stride (int): step, in file-list entries, between frames.
            output_meanstd (bool): stored on the instance; unused by this class.
            vnames (dict): channel selection with the optional keys
                ``constants``, ``single_level_vnames``,
                ``multi_level_vnames_dict`` (variable -> levels) and
                ``height_level`` (level -> variables; only consulted when
                ``multi_level_vnames_dict`` is explicitly ``None``).
    """

    def __init__(self, data_dir='dataset/weatherbench32x64', split='train', **kwargs) -> None:
        super().__init__()
        self.use_mem = kwargs.get('use_mem', False)
        self.length = kwargs.get('length', 1)
        self.file_stride = kwargs.get('file_stride', 1)
        self.sample_stride = kwargs.get('sample_stride', 1)
        self.output_meanstd = kwargs.get("output_meanstd", False)
        # Storage client, created lazily on the first s3:// read.  It has to
        # exist as an attribute because load_numpy_from_url tests it for None.
        self.client = None

        vnames_type = kwargs.get("vnames", {})
        self.constants_types = vnames_type.get('constants', [])
        self.single_level_vnames = vnames_type.get('single_level_vnames', single_level_vnames)
        self.multi_level_dict = multi_level_dict = vnames_type.get(
            'multi_level_vnames_dict', {"z": height_level, "v": height_level})
        multi_height_dict = vnames_type.get("height_level", None)

        self.single_level_vnames_index = [
            single_level_vnames.index(name) for name in self.single_level_vnames
        ]

        # Ordered (variable, level) pairs of the selected pressure-level channels.
        if multi_level_dict is not None:
            pairs = [(name, level) for name in multi_level_dict for level in multi_level_dict[name]]
        elif multi_height_dict is not None:
            pairs = [(name, level) for level in multi_height_dict for name in multi_height_dict[level]]
        else:
            # No pressure-level selection at all: keep the single-level channels only.
            pairs = []
        self._multi_level_pairs = pairs

        # Row of a pressure-level channel in a snapshot: the single-level rows
        # come first, then one group of len(height_level) rows per variable.
        first_level_row = len(single_level_vnames)
        rows_per_variable = len(height_level)
        self.total_data_index = self.single_level_vnames_index + [
            first_level_row + height_level.index(level) + rows_per_variable * multi_level_vnames.index(name)
            for name, level in pairs
        ]

        constants_index = [constants.index(constant) for constant in self.constants_types]

        print(f"constants_index={constants_index}")
        print(f"total_data_index={self.total_data_index}")

        self.split = split
        self.data_dir = data_dir
        years = Years[split]
        self.file_list = self.init_file_list(years)

        if len(constants_index) > 0:
            self.constants_data = self.get_constants_data(constants_index)
        else:
            self.constants_data = None

        dataset_meanstd = self._get_meanstd()
        self.data_mean = dataset_meanstd[0]
        self.data_std = dataset_meanstd[1]
        if self.use_mem:
            self.data = self.preload_era5_ceph()

    def init_file_list(self, years):
        """Return the ``[year, hour]`` pairs to read, every ``file_stride`` hours."""
        file_list = []
        for year in years:
            if year == 1979:
                # 1979 has only 8753 records: the first 7 hours of the first day
                # are missing, so counting starts from the second day.
                hours = range(17, 8753, self.file_stride)
            else:
                # `% 4` is an exact leap-year test for every year from 1901 to 2099.
                max_item = 8784 if year % 4 == 0 else 8760
                hours = range(0, max_item, self.file_stride)
            file_list.extend([year, hour] for hour in hours)
        return file_list

    def _get_meanstd(self):
        """Load ``mean_std.npy``; row 0 is used as the mean and row 1 as the std."""
        return self.load_numpy_from_url(f"{self.data_dir}/mean_std.npy")

    def get_constants_data(self, constants_index):
        """Load ``constants.npy`` and keep only the rows in ``constants_index``."""
        return self.load_numpy_from_url(f"{self.data_dir}/constants.npy")[constants_index]

    def preload_era5_ceph(self):
        """Load every file of the file list and return them stacked and normalised.

        All channels are kept; ``__getitem__`` applies the channel selection.
        """
        arrays = torch.stack(
            [self._load_array(position) for position in range(len(self.file_list))], dim=0)
        arrays = arrays - self.data_mean.unsqueeze(-1).unsqueeze(-1)
        arrays = arrays / self.data_std.unsqueeze(-1).unsqueeze(-1)
        return arrays

    def load_numpy_from_url(self, url):
        """Load a ``.npy`` file from a local path or ``s3://`` URL as a tensor.

        The saved arrays are regular ``.npy`` files, so local paths go straight
        through ``np.load``; ``s3://`` URLs are fetched through a lazily
        created ``Client`` and decoded from memory.
        """
        if "s3://" in url:
            # getattr keeps this usable on instances built without __init__.
            if getattr(self, "client", None) is None:
                self.client = Client(conf_path="~/petreloss.conf")
            with io.BytesIO(self.client.get(url)) as f:
                array = np.load(f)
        else:
            array = np.load(url)
        return torch.from_numpy(array)

    def _load_array(self, index):
        """Load the raw (un-normalised, all-channel) snapshot at file-list position ``index``."""
        year, hour = self.file_list[index]
        url = f"{self.data_dir}/{year}/{year}-{hour:04d}.npy"
        print(url)  # debug trace; the comparison cells below show these lines
        return self.load_numpy_from_url(url)

    def __len__(self):
        """Number of complete windows; 0 when the file list is shorter than one window."""
        # A negative value would make len() raise ValueError, so clamp at zero.
        return max(0, len(self.file_list) - (self.length - 1) * self.sample_stride)

    def get_meanstd(self):
        """Return the (mean, std) statistics of the selected channels."""
        return self.data_mean[self.total_data_index], self.data_std[self.total_data_index]

    def get_clim_daily(self):
        """Return the normalised climatology of the selected channels.

        Reads ``time_means_daily.npy``, keeps every ``file_stride``-th of its
        first 8760 entries and the selected channels (second axis).
        """
        array = self.load_numpy_from_url(f"{self.data_dir}/time_means_daily.npy")
        # Subset before normalising so the arithmetic only touches the entries
        # that are returned; the result is element-wise the same.
        array = array[0:8760:self.file_stride][:, self.total_data_index]
        mean = self.data_mean[self.total_data_index].unsqueeze(-1).unsqueeze(-1)
        std = self.data_std[self.total_data_index].unsqueeze(-1).unsqueeze(-1)
        return (array - mean) / std

    def feature_name(self):
        """Return the short names of the selected channels in output order.

        Constant fields are not included even though ``__getitem__`` prepends them.
        """
        pairs = getattr(self, "_multi_level_pairs", None)
        if pairs is None:
            # Instance not set up through __init__: derive the pairs on the fly.
            levels_by_name = self.multi_level_dict or {}
            pairs = [(name, level) for name in levels_by_name for level in levels_by_name[name]]
        return self.single_level_vnames + [f"{name}{level}" for name, level in pairs]

    def __getitem__(self, index):
        """Return the window starting at ``index`` as a list of ``length`` tensors.

        Indices past the last complete window are clamped to it.  Every tensor
        holds the selected channels of one normalised snapshot, preceded by the
        constant fields when some were requested.

        Raises:
            IndexError: if the dataset holds no complete window.
        """
        n_windows = len(self)
        if n_windows == 0:
            raise IndexError("dataset is too short to hold a single sample window")
        index = min(index, n_windows - 1)

        array_seq = []
        for step in range(self.length):
            position = index + step * self.sample_stride
            if self.use_mem:
                # Preloaded data was already normalised in preload_era5_ceph.
                frame = self.data[position]
            else:
                mean = self.data_mean.unsqueeze(-1).unsqueeze(-1)
                std = self.data_std.unsqueeze(-1).unsqueeze(-1)
                frame = (self._load_array(position) - mean) / std
            frame = frame[self.total_data_index]
            if self.constants_data is not None:
                frame = torch.cat((self.constants_data, frame), dim=0)
            array_seq.append(frame)
        return array_seq


In [9]:
import yaml

In [80]:
# Read the experiment configuration used by the reference implementation.
# NOTE(review): the path is relative to a sibling checkout, so this cell only
# works from the directory the notebook was originally run in.
# NOTE(review): yaml.load with FullLoader can construct more than plain data;
# if this config only holds scalars, lists and mappings, yaml.safe_load would be
# the safer call — confirm before switching.
with open("../wpredict-wp32x64_2/configs/swinvrnn/lgnet.yaml", 'r') as cfg_file:
    cfg_params = yaml.load(cfg_file, Loader = yaml.FullLoader)

In [81]:
# Point the test-split config at the local data directory (this mutates
# `cfg_params` in place) and build the reference dataset from it.
cfg_params['dataset']['test']['data_dir']='datasets/weatherbench32x64'
datasets = WeatherBench(split='test',**cfg_params['dataset']['test'])

constants_index=[0, 2]
total_data_index=[1, 2, 0, 4, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83]


In [82]:
# Number of sample windows in the reference dataset.
len(datasets)

17538

In [83]:
# Number of samples in our implementation.
# NOTE(review): the recorded values differ (17538 for `datasets`, 17533 here);
# presumably the two implementations trim the end of the sequence differently —
# confirm before relying on index-for-index equality near the tail.
len(dataset2)

17533

In [86]:
# Compare the second frame of the same sample from both implementations:
# one norm per channel (taken over the two trailing axes); all zeros means
# the frames are identical.
index= 3
np.linalg.norm(datasets[index][1].numpy()-dataset2[index][1],axis=(1,2))

datasets/weatherbench32x64/2016/2016-0003.npy
datasets/weatherbench32x64/2016/2016-0009.npy
datasets/weatherbench32x64/2016/2016-0003.npy
datasets/weatherbench32x64/2016/2016-0009.npy


array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0.])

In [43]:
# Manually replay the loading steps of both datasets for one index so the raw
# and normalised arrays can be inspected outside the classes.
# NOTE(review): in the recorded output the two printed paths differ
# (2016-0012 vs 2016-0002), so at that point the two datasets did not map
# index 2 to the same file — presumably a different stride was configured; confirm.
index = 2
# Same clamping as WeatherBench.__getitem__.
index = min(index, len(datasets.file_list) - (datasets.length-1) * datasets.sample_stride - 1)
year, hour = datasets.file_list[index]
url = f"{datasets.data_dir}/{year}/{year}-{hour:04d}.npy"
print(url)
array = datasets.load_numpy_from_url(url)
data_after_norm = (array-datasets.data_mean.unsqueeze(-1).unsqueeze(-1)) / datasets.data_std.unsqueeze(-1).unsqueeze(-1)

# Raw load as done by WeathBench32x64CK.get_item.
year, hour = dataset2.single_data_path_list[index]
url  = f"{dataset2.root}/{year}/{year}-{hour:04d}.npy"
print(url)
odata = np.load(url)
# Remaining get_item steps, left disabled so that `odata` stays raw:
# data = odata[self.channel_choice]
# data = (data - self.mean)/self.std

datasets/weatherbench32x64/2016/2016-0012.npy
datasets/weatherbench32x64/2016/2016-0002.npy


In [35]:
# Per-channel norm of the difference between the two raw arrays.
# NOTE(review): `data_origin` is not defined in any visible cell (hidden kernel
# state) — presumably the raw array loaded through
# `datasets.load_numpy_from_url`, i.e. something like `array.numpy()`; define it
# explicitly so the notebook survives Restart & Run All.
np.linalg.norm(data_origin-odata,axis=(1,2))

array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32)

# Run

In [1]:
import torch

In [2]:
from train.pretrain import *

In [4]:
# Rebuild the argument namespace from the config.json saved with a checkpoint.
args= get_args("checkpoints/WeathBench32x64CK/SWIN_Feature-LoRAGraphCastDGLSym/ts_22_fourcast-2D706N_per_6_step/02_19_18_00_34835-seed_73001/config.json")

In [8]:
# Single-process debug setup: no wandb, everything on device/rank 0.
args.use_wandb=0
args.gpu = args.local_rank = gpu  = local_rank = 0
##### parse args: dataset_kargs / model_kargs / train_kargs  ###########
args= parse_default_args(args)
# The computed checkpoint path is discarded right away in favour of "debug";
# NOTE(review): get_ckpt_path is still called — drop it if it has no needed side effect.
SAVE_PATH = get_ckpt_path(args)
SAVE_PATH = "debug"
args.SAVE_PATH = str(SAVE_PATH)
#args.pretrain_weight = os.path.join(args.SAVE_PATH,'pretrain_latest.pt')
########## inital log ###################
logsys = create_logsys(args,False)
args.distributed = False

# NOTE(review): unreachable as written, because `args.distributed` was set to
# False on the line above.  If it is ever enabled, `ngpus_per_node` is not
# defined in any visible cell and `os` / `dist` must come from the star import.
if args.distributed:
    if args.dist_url == "env://" and args.rank == -1:
        args.rank = int(os.environ["RANK"])
    if args.multiprocessing_distributed:
        # For multiprocessing distributed training, rank needs to be the
        # global rank among all the processes
        args.rank = args.rank * ngpus_per_node + local_rank
    logsys.info(f"start init_process_group,backend={args.dist_backend}, init_method={args.dist_url},world_size={args.world_size}, rank={args.rank}")
    dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url,world_size=args.world_size, rank=args.rank)

# Build the model, then the optimizer / LR scheduler / loss for it.
model           = build_model(args)
#param_groups    = timm.optim.optim_factory.add_weight_decay(model, args.weight_decay)
optimizer,lr_scheduler,criterion = build_optimizer(args,model)
# Gradient scaler for mixed-precision training, always enabled here.
loss_scaler     = torch.cuda.amp.GradScaler(enabled=True)
logsys.info(f'use lr_scheduler:{lr_scheduler}')

2023-02-21 14:39:26,122 model args: img_size= (32, 64)
2023-02-21 14:39:26,124 model args: patch_size= 2


log at debug
wandb id: None
wandb is off, the recorder list is  ['tensorboard'], we pass wandb

            This is ===> GraphCast Model(DGL) <===
            Information: 
                total mesh node: 2562 total unique mesh edge:10230*2=20460 
                total grid node 2048+2 = 2050 but activated grid 1928 
                from activated grid to mesh, create 4*2562 - 6 = 10242 edges. (north and south pole repeat 4 times) 
                there are 122 unactivated grid node
                when mapping node to grid, 
                from node to activated grid, there are 10032 edges
                from node to unactivated grid, there are 976 edges
                thus, totally have 11008 edge. 
                #notice some grid only have 1-2 linked node but some grid may have 30 lined node
            


2023-02-21 14:39:26,957 use model ==> SWIN_Feature
2023-02-21 14:39:26,959 Rank: 0, Local_rank: 0 | Number of Parameters: 26296320, Number of Buffers: 0, Size of Model: 100.3125 MB

2023-02-21 14:39:29,023 use lr_scheduler:<timm.scheduler.cosine_lr.CosineLRScheduler object at 0x7f2234c45fa0>


In [12]:
# Inspect the `subweight` option of the loaded config (empty string in the recorded run).
args.subweight

''

In [None]:
# Path of the best backbone checkpoint.
# NOTE(review): this cell was never executed and `ckpt_path` is not defined in
# any visible cell; `os` is also not imported explicitly in this notebook.
pretrain_path = os.path.join(ckpt_path,"backbone.best.pt")

In [7]:
# NOTE(review): the recorded run of this cell (execution count 7) failed with
# "AttributeError: 'str' object has no attribute 'info'".  It ran before the
# cell that creates `logsys` (execution count 8), so presumably build_model was
# handed a plain string where it expects a logger — confirm, then delete this
# stale cell since the setup cell above already builds the model.
model  = build_model(args)

AttributeError: 'str' object has no attribute 'info'