# In [1]:
import netCDF4 as nc
import numpy as np
import torch
from torch.utils.data import Dataset

class SoilMoistureDataset(Dataset):
    """Sliding-window dataset: precipitation history -> next-step soil moisture.

    Loads the soil-moisture and precipitation time series at the grid cell
    nearest to (``lat``, ``lon``) from a netCDF file. Sample ``i`` pairs the
    precipitation values at time steps ``i .. i + sequence_length - 1`` (input)
    with the soil-moisture value at step ``i + sequence_length`` (target).

    Args:
        nc_file: Path of the netCDF file (anything ``netCDF4.Dataset`` accepts).
        lat: Target latitude; the nearest grid latitude is used.
        lon: Target longitude; the nearest grid longitude is used. No
            wrap-around is applied, so it must use the same convention as the
            file's longitude coordinate (e.g. 0..360 vs -180..180).
        sequence_length: Number of precipitation steps per input window (>= 1).
        soil_moisture_var: Name of the soil-moisture variable in the file.
        precipitation_var: Name of the precipitation variable in the file.
        lat_var: Name of the latitude coordinate variable.
        lon_var: Name of the longitude coordinate variable.

    Raises:
        ValueError: If ``sequence_length`` is less than 1.
    """

    def __init__(self, nc_file, lat, lon, sequence_length=30, *,
                 soil_moisture_var='soil_moisture',
                 precipitation_var='precipitation',
                 lat_var='lat', lon_var='lon'):
        if sequence_length < 1:
            raise ValueError(
                f'sequence_length must be >= 1, got {sequence_length}')
        self.nc_file = nc_file
        self.lat = lat
        self.lon = lon
        self.sequence_length = sequence_length
        self.soil_moisture_var = soil_moisture_var
        self.precipitation_var = precipitation_var
        self.lat_var = lat_var
        self.lon_var = lon_var
        self.data = self.load_data()

    @staticmethod
    def _nearest_index(coords, target):
        """Return the index of the entry in ``coords`` closest to ``target``.

        NOTE(review): assumes ``coords`` is a 1-D coordinate array; for N-D
        input ``argmin`` returns a flattened index.
        """
        return int(np.abs(coords - target).argmin())

    @staticmethod
    def _read_point(variable, lat_idx, lon_idx):
        """Read the series at one grid cell as a float32 array.

        Slicing the variable directly reads only that cell rather than the
        whole grid. Masked (missing / fill-value) entries become NaN instead
        of leaking the raw fill value into the data.
        """
        series = variable[:, lat_idx, lon_idx]
        return np.ma.filled(np.ma.asarray(series, dtype=np.float32), np.nan)

    def load_data(self):
        """Load both series at the grid cell nearest to (lat, lon).

        Returns:
            float32 array of shape (T, 2). Column 0 is soil moisture and
            column 1 is precipitation. Missing values are NaN.
        """
        # The context manager guarantees the file handle is closed, even on error.
        with nc.Dataset(self.nc_file) as dataset:
            lat_idx = self._nearest_index(
                dataset.variables[self.lat_var][:], self.lat)
            lon_idx = self._nearest_index(
                dataset.variables[self.lon_var][:], self.lon)
            # NOTE(review): assumes both variables are laid out as
            # (time, lat, lon) -- confirm against the file's dimensions.
            soil_moisture = self._read_point(
                dataset.variables[self.soil_moisture_var], lat_idx, lon_idx)
            precipitation = self._read_point(
                dataset.variables[self.precipitation_var], lat_idx, lon_idx)

        # Further cleaning / preprocessing (e.g. filling NaN gaps) can go here.
        return np.stack([soil_moisture, precipitation], axis=1)

    def __len__(self):
        """Number of complete (window, target) pairs; 0 if the series is too short."""
        return max(0, len(self.data) - self.sequence_length)

    def __getitem__(self, idx):
        """Return ``(x, y)`` for sample ``idx``.

        ``x`` is a float32 tensor of ``sequence_length`` precipitation values.
        ``y`` is a 0-d float32 tensor holding the soil moisture at the step
        right after the window. Negative indices count from the end.

        Raises:
            IndexError: If ``idx`` is out of range.
        """
        n = len(self)
        i = idx + n if idx < 0 else idx
        if not 0 <= i < n:
            raise IndexError(
                f'index {idx} out of range for dataset of length {n}')
        end = i + self.sequence_length
        x = self.data[i:end, 1]  # input: precipitation window
        y = self.data[end, 0]    # target: soil moisture right after the window
        return torch.tensor(x, dtype=torch.float32), torch.tensor(y, dtype=torch.float32)
