# 处理201001-201901的每月数据

In [1]:
import h5py
import numpy as np
import os
import torch
from torch.utils.data import Dataset, IterableDataset
import xarray as xr
import torch.nn.functional as F

In [2]:
def get_sub(data, latitude, longitude, lat_min, lat_max, lon_min, lon_max):
        """
        提取子区域的数据

        input:
        lat_min, lat_max, lon_min, lon_max: 子区域范围
        data: 原始数据
        latitude, longitude: 经纬度数据

        return: subset_data, subset_lat, subset_lon
        """
        
        lat_indices = np.where((latitude >= lat_min) & (latitude <= lat_max))[0]  # 找到对应的索引
        lon_indices = np.where((longitude >= lon_min) & (longitude <= lon_max))[0]
        
        subset_data = data[:, lat_indices, :][:, :, lon_indices]  # 提取子集数据
        subset_lat = latitude[lat_indices]  # 提取相应的经纬度数组
        subset_lon = longitude[lon_indices]

        return subset_data, subset_lat, subset_lon

def down_sample(data, lat_list, lon_list):
        '''
        0.25*0.25下采样到0.5*0.5

        in: (t, lat, lon)
        out: DataArray dim=(t, lat, lon)
        '''
        data = torch.tensor(data.values)  # array --> tensor
        if data.dim() == 3:
            data = data.unsqueeze(0)

        lat, lon = data.shape[-2], data.shape[-1]
        new_lat, new_lon = int(lat / 2), int(lon / 2)  # 目标尺寸
        new_size = (new_lat, new_lon)
        data = F.interpolate(data, size=new_size, mode='bilinear', align_corners=False)

        return data, lat_list[::2], lon_list[::2]

def compute_climatological_mean_and_anomalies(data):
        """
        计算每个变量的气候学平均值, 从而计算其异常值
        input: data (xarray.Dataset or xarray.DataArray): 包含多个变量的时间序列数据，维度为 (time, lat, lon)。
        return: xarray.Dataset or xarray.DataArray: 包含异常值的数据集，维度为 (time, lat, lon)。
        """
        # 时间维度名为 'time'
        # print("Dimensions of data:", data.dims)
        clim_mean = data.mean(dim='time')  # 计算气候学平均值（沿着time维度求平均）
        clim_mean_expanded = clim_mean.broadcast_like(data)
        anomalies = data - clim_mean_expanded  # 从原始数据中减去气候学平均值得到异常值
        
        return anomalies


def min_max(data):
        """
        对输入数据按变量进行归一化

        input:(var, time, lat, lon)
        output: (var, time, lat, lon)
        """
        minmax = []
        for i in range(data.shape[0]):
                var_data = data[i]
                var_min = var_data.min(dim='time')
                var_max = var_data.max(dim='time')
                normalized_var_data = (var_data - var_min) / (var_max - var_min)
                minmax.append(normalized_var_data)
                # normalized_data.loc[dict(var=var)] = normalized_var_data

        minmax = xr.concat(minmax, dim='file')
        return minmax

def get_input_data(folder_path, reference_file, lat_min, lat_max, lon_min, lon_max):
        """
        提取输入数据并裁剪
        folder_path, reference_file: 数据文件夹地址 及 参考数据文件地址
        
        return:  (var, time, lat, lon)
        """
        # 1、提取文件名
        nc_files = [file for file in os.listdir(folder_path) if file.endswith('.nc')]
        data_all = []  # 存储所有数据

        # 2、先加载reference data, 作为网格插值的基准
        ref_ds = xr.open_dataset(reference_file)
        ref_lat = ref_ds['lat']
        ref_lon = ref_ds['lon']
        ref_data = ref_ds['data'][0:109, ...]   # torch.Size([109, 108, 200])
        data, sub_ref_lat, sub_ref_lon = down_sample(ref_data, ref_lat, ref_lon)  # 0.25*0.25下采样到0.5*0.5
        sub_ref_data = xr.DataArray(data.squeeze(0), dims=["time", "lat", "lon"], coords={"lat": sub_ref_lat, "lon": sub_ref_lon})
        # 提取子区域
        ref_subset_data, ref_subset_lat, ref_subset_lon = get_sub(sub_ref_data, sub_ref_lat, sub_ref_lon, lat_min, lat_max, lon_min, lon_max)
        # print('ref sub: ', ref_subset_data.shape, ref_data.shape)
        # 将 -999.0 的值转换为 np.nan
        mask = np.where(ref_subset_data == -999.0, np.nan, ref_subset_data)
        ref_subset_data = xr.DataArray(mask, dims=["time", "lat", "lon"], coords={"lat": ref_subset_lat, "lon": ref_subset_lon})
        data_all.append(ref_subset_data)

        # 3、逐个加载.nc文件并进行插值
        for file in nc_files:
            file_path = os.path.join(folder_path, file)
            print(f"Processing file: {file_path}")
            ds = xr.open_dataset(file_path)
            data = ds['data'][:109, ...]  # 提取前109个时间步的数据 
            interpolated_data = data.interp(lat=ref_lat, lon=ref_lon)  # 将 'data' 插值到目标经纬度网格
            data, lat, lon = down_sample(interpolated_data, ref_lat, ref_lon)  # 0.25*0.25下采样到0.5*0.5
            data = xr.DataArray(data.squeeze(0), dims=["time", "lat", "lon"], coords={"lat": lat, "lon": lon})
            # print('after sample: ', data.shape, lat.shape, lon.shape)
            subset_data, subset_lat, subset_lon = get_sub(data, lat, lon, lat_min, lat_max, lon_min, lon_max)  # 提取子区域
            # print('sub_set: ', subset_data.shape, subset_lat.shape, subset_lon.shape)
            # 掩码处理：通过reference的nan值将所有数据相同位置的数字换为nan
            nan_mask = np.isnan(ref_subset_data)
            # print('mask: ', nan_mask.shape)
            # print(nan_mask)
            masked_data = np.where(nan_mask, np.nan, subset_data)
            masked_data = xr.DataArray(masked_data, dims=["time", "lat", "lon"], coords={"lat": ref_subset_lat, "lon": ref_subset_lon})
            
            data_all.append(masked_data)

        data_all = xr.concat(data_all, dim='file')  # 将所有插值后的数据堆叠在一起
        data_all = data_all.where(np.abs(data_all) <= 100, np.nan)  # 将数据中绝对值大于100的数值替换为NaN
        data_all = compute_climatological_mean_and_anomalies(data_all)  # 计算数据异常值 - 减去 climatological mean
        data_all = min_max(data_all)  # 最大最小归一化

        # print('shape of region:', data_all.shape)
        return data_all

def get_armor(path, key, lat_min, lat_max, lon_min, lon_max):
        '''
        提取label

        armor数据如下:
        depth (36,)
        latitude (688,)
        longitude (1439,)
        time (313,)
        mlotst (313, 688, 1439)
        so (313, 36, 688, 1439)
        to (313, 36, 688, 1439)
        
        return: (depth, time, lat, lon)
        '''
        f = xr.open_dataset(path, chunks={'time': 1})
        data = f[key][204:313, ...]
        depth = f['depth']
        lat = f['latitude']
        lon = f['longitude']

        # down_sample
        sub_data, lat, lon = down_sample(data, lat, lon)
        data = xr.DataArray(sub_data, dims=["time", "depth", "latitude", "longitude"], coords={"latitude": lat, "longitude": lon[:719]})

        # 找到对应的索引
        lat_indices = np.where((lat >= lat_min) & (lat <= lat_max))[0]
        lon_indices = np.where((lon >= lon_min) & (lon <= lon_max))[0]
        # print('lat,lon:', lat_indices.shape, lon_indices.shape)

        # 提取子集数据
        subset_data = data[:, :, lat_indices, lon_indices].transpose('depth', 'time',  'latitude', 'longitude')
        # print('end:', subset_data.shape)

        # 提取相应的经纬度数组
        subset_lat = lat[lat_indices]
        subset_lon = lon[lon_indices]

        # 计算数据异常值 - 减去 climatological mean
        # print(subset_data.dims)
        subset_data = compute_climatological_mean_and_anomalies(subset_data)
        nan_mask = np.isnan(subset_data)
        nan_mask = torch.tensor(nan_mask.values)
        # print('armor: ', subset_data)

        # minmax归一化
        subset_data = min_max(subset_data)

        # print('return:', subset_data.shape)

        return subset_data, subset_lat, subset_lon, depth, nan_mask

def precess_data(input, label, lat, lon, depth, mask, reference_file):
        # mask: (36, 109, 54, 100)
        mask = torch.where(mask, 0, 1)  # 将True False 换为0 1，false代表非nan值处
        input = torch.from_numpy(input.values).permute(1,0,2,3)
        label = torch.from_numpy(label.values).permute(1,0,2,3)
        lat = torch.from_numpy(lat.values)
        lon = torch.from_numpy(lon.values)
        depth = torch.from_numpy(depth.values)

        # 将lat和lon合并到input中
        time = input.shape[0]
        lat = input.shape[2]
        lon = input.shape[3]
        expand_lat = lat.unsqueeze(0).unsqueeze(-1).repeat(time, 1, 1, lon)
        expand_lon = lon.unsqueeze(0).unsqueeze(0).repeat(time, 1, lat, 1)
        input = torch.cat((input, expand_lat, expand_lon), dim=1)

        # 将时间合并到input中
        ds = xr.open_dataset(reference_file)
        time = ds.variables['time'][0:109].values  # 201001 - 201901
        jd1 = torch.cos( torch.tensor(2*np.pi*(time/12)+1) )
        jd2 = torch.sin( torch.tensor(2*np.pi*(time/12)+1) )
        jd1 = jd1.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1).repeat(1, 1, lat, lon)
        jd2 = jd2.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1).repeat(1, 1, lat, lon)
        # print('jd1, jd2', jd1.shape, jd2.shape)
        input = torch.cat((input, jd1), dim=1)
        input = torch.cat((input, jd2), dim=1)
        
        # 将数据中的nan全换为0
        input = torch.where(torch.isnan(input), torch.full_like(input, 0), input)
        label = torch.where(torch.isnan(label), torch.full_like(label, 0), label)

        # 总数据：0:109, 代表201001-201901
        # 数据划分为train:201001-201803 test:201804-201901

        
        print('shape of variable: ', input.shape, label.shape, lat.shape, lon.shape, depth.shape)

In [19]:
reference_file = '/home/data2/pengguohang/My_Ocean/challenge/oisst_monthly_201001-201904.nc'
folder_path='/home/data2/pengguohang/My_Ocean/challenge'
label_path = '/home/data2/pengguohang/My_Ocean/CMEMS/armor_montly_1993_2019/armor_1993_2019.nc'

# 墨西哥湾区域
lat_min = 23
lat_max = 50
lon_min = -80   
lon_max = -30

key = 'to'  # 'so' or 'to'

inputs = get_input_data(folder_path, reference_file, lat_min, lat_max, lon_min, lon_max)
labels, lat, lon, depth, mask = get_armor(label_path, key, lat_min, lat_max, lon_min, lon_max)


Processing file: /home/data2/pengguohang/My_Ocean/challenge/vwnd_monthly_201001-201904.nc
Processing file: /home/data2/pengguohang/My_Ocean/challenge/sss_cci_monthly_201001_201912_data.nc
Processing file: /home/data2/pengguohang/My_Ocean/challenge/swh_monthly_201001_201912_data.nc
Processing file: /home/data2/pengguohang/My_Ocean/challenge/oisst_monthly_201001-201904.nc
Processing file: /home/data2/pengguohang/My_Ocean/challenge/uwnd_monthly_201001-201904.nc
Processing file: /home/data2/pengguohang/My_Ocean/challenge/sla_monthly_201001_201901.nc
Processing file: /home/data2/pengguohang/My_Ocean/challenge/adt_monthly_201001-201912.nc


In [20]:
mask_data = torch.where(mask, 0, 1)  # 将True False 换为0 1，false代表非nan值处
mask_data = mask_data[0, 0, :, :]
input_data = torch.from_numpy(inputs.values).permute(1,0,2,3)
label_data = torch.from_numpy(labels.values).permute(1,0,2,3)
lat_data = torch.from_numpy(lat.values)
lon_data = torch.from_numpy(lon.values)
depth_data = torch.from_numpy(depth.values)
print(input_data.shape)
print(label_data.shape)
print(mask_data.shape)

torch.Size([109, 8, 54, 100])
torch.Size([109, 36, 54, 100])
torch.Size([54, 100])


In [21]:
import torch
import netCDF4 as nc

# 创建一个NetCDF文件
dataset = nc.Dataset('/home/data2/pengguohang/My_Ocean/challenge/data.nc', 'w', format='NETCDF4')

# 创建维度
dim1 = dataset.createDimension('vars', 8)
dim2 = dataset.createDimension('month', 109)
dim3 = dataset.createDimension('lat', 54)
dim4 = dataset.createDimension('lon', 100)
dim5 = dataset.createDimension('depth', 36)

# 创建变量并关联到维度
input = dataset.createVariable('input', 'f4', ('month', 'vars', 'lat', 'lon'))
label = dataset.createVariable('label', 'f4', ('month', 'depth', 'lat', 'lon'))
lat = dataset.createVariable('lat', 'f4', ('lat',))
lon = dataset.createVariable('lon', 'f4', ('lon',))
depth = dataset.createVariable('depth', 'f4', ('depth',))
mask = dataset.createVariable('mask', 'f4', ('lat', 'lon'))

# 将PyTorch张量的数据复制到NetCDF变量中
input[:] = input_data.numpy()  # 需要将张量转换为NumPy数组
label[:] = label_data.numpy()
lat[:] = lat_data.numpy()
lon[:] = lon_data.numpy()
depth[:] = depth_data.numpy()
mask[:] = mask_data.numpy()


# 添加一些属性
input.units = 'none'
input.long_name = 'inputs'

# 关闭文件
dataset.close()

In [28]:
import netCDF4 as nc
import matplotlib.pyplot as plt

path = '/home/data2/pengguohang/My_Ocean/challenge/data/data.nc'
with nc.Dataset(path) as ds:
    print(ds.variables.keys())
    input = ds['input'][:]
    print('input: ', input.shape)
    label = ds['label'][:]
    print('label: ', label.shape)
    lat = ds['lat'][:]
    print('lat: ', lat.shape)
    lon = ds['lon'][:]
    print('lon: ', lon.shape)
    depth = ds['depth'][:]
    print('depth: ', depth.shape)
    mask = ds['mask'][:]
    print('mask: ', mask.shape)


dict_keys(['input', 'label', 'lat', 'lon', 'depth', 'mask'])
input:  (109, 8, 54, 100)
label:  (109, 36, 54, 100)
lat:  (54,)
lon:  (100,)
depth:  (36,)
mask:  (54, 100)


# 查看CORA数据

In [2]:
path = '/home/data2/pengguohang/My_Ocean/CMEMS/CORA_1993_2019_P1M/CORA_199301_201901_PSAL_P1M.nc'

ds = xr.open_dataset(path)
print(ds)

<xarray.Dataset>
Dimensions:    (depth: 102, latitude: 1671, longitude: 720, time: 313)
Coordinates:
  * depth      (depth) float32 1.0 3.0 5.0 10.0 15.0 ... 940.0 960.0 980.0 1e+03
  * latitude   (latitude) float64 -77.0 -76.9 -76.8 -76.7 ... 89.8 89.9 90.0
  * longitude  (longitude) float64 -180.0 -179.5 -179.0 ... 178.5 179.0 179.5
  * time       (time) datetime64[ns] 1993-01-01 1993-02-01 ... 2019-01-01
Data variables:
    PSAL       (time, depth, latitude, longitude) float32 ...
Attributes: (12/22)
    Conventions:               CF-1.4
    analysis_name:             OA_CORA5.2_
    comment:                   V8.0 reference climatology and analysis parame...
    creation_date:             20230926T220219L
    data_manager:              Tanguy Szekely
    easternmost_longitude:     179.5
    ...                        ...
    southernmost_latitude:     -77.0105
    start_date:                2022-12-15
    stop_date:                 2022-12-15
    title:                     Global O

# 处理CCMP风速数据

In [5]:
folder_path = '/home/data2/pengguohang/My_Ocean/CCMP/ccmp_1993_201904_M'

file_name = sorted(
        [f for f in os.listdir(folder_path) if f.endswith('.nc') and f != 'CCMP_Wind_Analysis_climatology_V02.0_L3.5_RSS.nc']
    )
print(len(file_name))
print(file_name)

315
['CCMP_Wind_Analysis_199301_V02.0_L3.5_RSS.nc', 'CCMP_Wind_Analysis_199302_V02.0_L3.5_RSS.nc', 'CCMP_Wind_Analysis_199303_V02.0_L3.5_RSS.nc', 'CCMP_Wind_Analysis_199304_V02.0_L3.5_RSS.nc', 'CCMP_Wind_Analysis_199305_V02.0_L3.5_RSS.nc', 'CCMP_Wind_Analysis_199306_V02.0_L3.5_RSS.nc', 'CCMP_Wind_Analysis_199307_V02.0_L3.5_RSS.nc', 'CCMP_Wind_Analysis_199308_V02.0_L3.5_RSS.nc', 'CCMP_Wind_Analysis_199309_V02.0_L3.5_RSS.nc', 'CCMP_Wind_Analysis_199310_V02.0_L3.5_RSS.nc', 'CCMP_Wind_Analysis_199311_V02.0_L3.5_RSS.nc', 'CCMP_Wind_Analysis_199312_V02.0_L3.5_RSS.nc', 'CCMP_Wind_Analysis_199401_V02.0_L3.5_RSS.nc', 'CCMP_Wind_Analysis_199402_V02.0_L3.5_RSS.nc', 'CCMP_Wind_Analysis_199403_V02.0_L3.5_RSS.nc', 'CCMP_Wind_Analysis_199404_V02.0_L3.5_RSS.nc', 'CCMP_Wind_Analysis_199405_V02.0_L3.5_RSS.nc', 'CCMP_Wind_Analysis_199406_V02.0_L3.5_RSS.nc', 'CCMP_Wind_Analysis_199407_V02.0_L3.5_RSS.nc', 'CCMP_Wind_Analysis_199408_V02.0_L3.5_RSS.nc', 'CCMP_Wind_Analysis_199409_V02.0_L3.5_RSS.nc', 'CCMP_Wi