In [82]:
import yaml
import glob
import torch
import numpy as np
import random
import tqdm
from credit.transforms import load_transforms

In [2]:
config = "small_multi/model.yml"

In [3]:
with open(config) as cf:
    conf = yaml.load(cf, Loader=yaml.FullLoader)

In [4]:
train_batch_size = conf["trainer"]["train_batch_size"]
valid_batch_size = conf["trainer"]["valid_batch_size"]
thread_workers = conf["trainer"]["thread_workers"]
valid_thread_workers = (
    conf["trainer"]["valid_thread_workers"]
    if "valid_thread_workers" in conf["trainer"]
    else thread_workers
)

history_len = 1
forecast_len = 1

# datasets (zarr reader)

all_ERA_files = sorted(glob.glob(conf["data"]["save_loc"]))

train_years = [str(year) for year in range(1979, 2014)]
valid_years = [
    str(year) for year in range(2014, 2018)
]  # can make CV splits if we want to later on
test_years = [
    str(year) for year in range(2018, 2022)
]  # same as graphcast -- always hold out

# Filter the files for each set

train_files = [
    file for file in all_ERA_files if any(year in file for year in train_years)
]
valid_files = [
    file for file in all_ERA_files if any(year in file for year in valid_years)
]
test_files = [
    file for file in all_ERA_files if any(year in file for year in test_years)
]

In [55]:
from typing import Optional, Callable, List
from credit.data import (
    get_forward_data,
    generate_integer_list_around,
    flatten_list,
    find_key_for_number,
)


class MultiStepERA5(torch.utils.data.Dataset):
    def __init__(
        self,
        filenames: List[str] = [
            "/glade/derecho/scratch/wchapman/STAGING/TOTAL_2012-01-01_2012-12-31_staged.zarr",
            "/glade/derecho/scratch/wchapman/STAGING/TOTAL_2013-01-01_2013-12-31_staged.zarr",
        ],
        history_len: int = 1,
        forecast_len: int = 2,
        transform: Optional[Callable] = None,
        seed=42,
        skip_periods=None,
        one_shot=None,
        max_forecast_len=None,
        rank=0,
        world_size=1,
    ):
        self.history_len = history_len
        self.forecast_len = forecast_len
        self.transform = transform
        self.skip_periods = skip_periods
        self.one_shot = one_shot
        self.total_seq_len = self.history_len + self.forecast_len
        self.max_forecast_len = max_forecast_len
        self.rank = rank
        self.world_size = world_size
        np.random.seed(seed + rank)

        all_fils = []
        filenames = sorted(filenames)
        for fn in filenames:
            all_fils.append(get_forward_data(filename=fn))
        self.all_fils = all_fils
        self.data_array = all_fils[0]

        # Set data places
        indo = 0
        self.meta_data_dict = {}
        for ee, bb in enumerate(self.all_fils):
            self.meta_data_dict[str(ee)] = [
                len(bb["time"]),
                indo,
                indo + len(bb["time"]),
            ]
            indo += len(bb["time"]) + 1

        # Set out of bounds indexes...
        OOB = []
        for kk in self.meta_data_dict.keys():
            OOB.append(generate_integer_list_around(self.meta_data_dict[kk][2]))
        self.OOB = flatten_list(OOB)

        # Generate sequences based on rank and world_size
        self.sequence_indices = self.generate_sequences()
        self.forecast_hour = 0

    def generate_sequences(self):
        # Calculate the total length manually
        total_length = sum(
            len(bb["time"]) - self.total_seq_len + 1 for bb in self.all_fils
        )
        all_indices = list(range(total_length))

        chunk_size = len(all_indices) // self.world_size
        start_idx = self.rank * chunk_size
        end_idx = (
            start_idx + chunk_size
            if self.rank != self.world_size - 1
            else len(all_indices)
        )

        random.shuffle(all_indices)

        # Select the start times
        random_start_times = all_indices[start_idx:end_idx]
        sequence_indices = []

        for start_time in random_start_times:
            if start_time == 0:
                continue
            for i in range(self.forecast_len + 1):
                sequence_indices.append(start_time + i)

        return sequence_indices

    def __post_init__(self):
        # Total sequence length of each sample.
        self.total_seq_len = self.history_len + self.forecast_len

    def __len__(self):
        return len(self.sequence_indices)

    def is_end_of_forecast(self, index: int) -> bool:
        """
        Determine if the current index is the last index in a forecast sequence.

        Parameters:
            index (int): The current index in sequence_indices.

        Returns:
            bool: True if the current index is the last index in a forecast sequence, otherwise False.
        """
        # Get the index of the current position in the sequence_indices list
        current_pos = self.sequence_indices.index(index)

        # Check if it's the last index in the sequence
        if current_pos == len(self.sequence_indices) - 1:
            return 1

        # Determine if the next index starts a new forecast
        next_index = self.sequence_indices[current_pos + 1]
        if next_index - index != 1:
            return 1

        return 0

    def __getitem__(self, index):
        index = self.sequence_indices[index]
        # The rest of your existing __getitem__ implementation remains unchanged
        # find the result key:
        result_key = find_key_for_number(index, self.meta_data_dict)

        # get the data selection:
        true_ind = index - self.meta_data_dict[result_key][1]

        if true_ind > (
            len(self.all_fils[int(result_key)]["time"])
            - (self.history_len + self.forecast_len + 1)
        ):
            true_ind = len(self.all_fils[int(result_key)]["time"]) - (
                self.history_len + self.forecast_len + 1
            )

        datasel = self.all_fils[int(result_key)].isel(
            time=slice(true_ind, true_ind + self.history_len + self.forecast_len + 1)
        )

        historical_data = datasel.isel(time=slice(0, self.history_len)).load()
        target_data = datasel.isel(
            time=slice(self.history_len, self.history_len + 1)
        ).load()

        sample = {
            "historical_ERA5_images": historical_data,
            "target_ERA5_images": target_data,
            "datetime_index": [
                int(historical_data.time.values[0].astype("datetime64[s]").astype(int)),
                int(target_data.time.values[0].astype("datetime64[s]").astype(int)),
            ],
        }

        if self.transform:
            sample = self.transform(sample)

        sample["index"] = index
        sample["stop_forecast"] = self.is_end_of_forecast(index)
        sample["forecast_hour"] = self.forecast_hour
        sample["datetime_index"] = [
            int(historical_data.time.values[0].astype("datetime64[s]").astype(int)),
            int(target_data.time.values[0].astype("datetime64[s]").astype(int)),
        ]

        if sample["stop_forecast"]:
            self.forecast_hour = 0
        else:
            self.forecast_hour += 1

        return sample

In [62]:
dataset = MultiStepERA5(
    filenames=train_files,
    history_len=history_len,
    forecast_len=forecast_len,
    skip_periods=conf["data"]["skip_periods"],
    one_shot=conf["data"]["one_shot"],
    transform=load_transforms(conf),
    rank=0,
    world_size=2,
)

In [63]:
dataset.sequence_indices[:12]

[231353,
 231354,
 173529,
 173530,
 232329,
 232330,
 12921,
 12922,
 183689,
 183690,
 193970,
 193971]

In [25]:
# dataset.sequence_indices

In [64]:
sample = dataset.__getitem__(0)

In [65]:
print(sample["stop_forecast"], sample["datetime_index"])

0 [1116774000, 1116777600]


In [66]:
sample = dataset.__getitem__(1)

In [67]:
print(sample["stop_forecast"], sample["datetime_index"])

1 [1116777600, 1116781200]


In [68]:
sample = dataset.__getitem__(2)

In [69]:
print(sample["stop_forecast"], sample["datetime_index"])

0 [908632800, 908636400]


In [70]:
sample = dataset.__getitem__(3)

In [71]:
print(sample["stop_forecast"], sample["datetime_index"])

1 [908636400, 908640000]


In [72]:
dataset.sequence_indices[:10]

[231353, 231354, 173529, 173530, 232329, 232330, 12921, 12922, 183689, 183690]

In [35]:
# Define DataLoader parameters
batch_size = 2  # Adjust the batch size as needed
num_workers = 0  # Adjust the number of workers as needed for parallel data loading
shuffle = False  # Must be false! We will let the rank + seed determine randomness

# Create DataLoader
data_loader = torch.utils.data.DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=shuffle,
    num_workers=num_workers,
    pin_memory=True,  # Set to True if using CUDA to speed up the transfer of data to GPU
)

In [36]:
for sample in data_loader:
    print(sample["datetime"], sample["stop_forecast"])
    if sample["stop_forecast"]:
        break

tensor([288936000000000000, 969469200000000000]) tensor([1, 1])


RuntimeError: Boolean value of Tensor with more than one value is ambiguous

In [73]:
sample["datetime"]

908636400000000000

In [74]:
sample["stop_forecast"]

1

In [75]:
sample["forecast_hour"]

1

# Sequential dataset

In [116]:
class DistributedSequentialDataset(torch.utils.data.IterableDataset):
    # https://colab.research.google.com/drive/1OFLZnX9y5QUFNONuvFsxOizq4M-tFvk-?usp=sharing#scrollTo=CxSCQPOMHgwo

    def __init__(
        self,
        filenames,
        history_len,
        forecast_len,
        skip_periods,
        rank,
        world_size,
        shuffle=False,
        transform=None,
        rollout_p=0.0,
    ):
        self.dataset = ERA5Dataset(
            filenames=filenames,
            history_len=history_len,
            forecast_len=forecast_len,
            skip_periods=skip_periods,
            transform=transform,
        )
        self.meta_data_dict = self.dataset.meta_data_dict
        self.all_fils = self.dataset.all_fils
        self.history_len = history_len
        self.forecast_len = forecast_len
        self.filenames = filenames
        self.transform = transform
        self.rank = rank
        self.world_size = world_size
        self.shuffle = shuffle
        self.skip_periods = skip_periods
        self.current_epoch = 0
        self.rollout_p = rollout_p

    def __len__(self):
        tlen = 0
        for bb in self.all_fils:
            tlen += len(bb["time"]) - self.forecast_len
        return tlen

    def set_epoch(self, epoch):
        self.current_epoch = epoch

    def __iter__(self):
        worker_info = get_worker_info()
        num_workers = worker_info.num_workers if worker_info is not None else 1
        worker_id = worker_info.id if worker_info is not None else 0
        sampler = DistributedSampler(
            self,
            num_replicas=num_workers * self.world_size,
            rank=self.rank * num_workers + worker_id,
            shuffle=self.shuffle,
        )
        sampler.set_epoch(self.current_epoch)

        for index in iter(sampler):
            result_key = find_key_for_number(index, self.meta_data_dict)
            true_ind = index - self.meta_data_dict[result_key][1]

            if true_ind > (
                len(self.all_fils[int(result_key)]["time"])
                - (self.history_len + self.forecast_len + 1)
            ):
                true_ind = len(self.all_fils[int(result_key)]["time"]) - (
                    self.history_len + self.forecast_len + 3
                )

            indices = list(
                range(true_ind, true_ind + self.history_len + self.forecast_len)
            )
            stop_forecast = False

            for k, ind in enumerate(indices):
                concatenated_samples = {
                    "x": [],
                    "x_surf": [],
                    "y": [],
                    "y_surf": [],
                    "static": [],
                    "TOA": [],
                }
                sliced = xr.open_zarr(
                    self.filenames[int(result_key)]
                ).isel(
                    time=slice(
                        ind,
                        ind + self.history_len + self.forecast_len + 1,
                        self.skip_periods,
                    )
                )

                historical_data = sliced.isel(time=slice(0, self.history_len)).load()
                target_data = sliced.isel(
                    time=slice(self.history_len, self.history_len + 1)
                ).load()

                sample = {
                    "x": historical_data,
                    "y": target_data,
                    "t": [
                        int(
                            historical_data.time.values[0]
                            .astype("datetime64[s]")
                            .astype(int)
                        ),
                        int(
                            target_data.time.values[0]
                            .astype("datetime64[s]")
                            .astype(int)
                        ),
                    ],
                }

                if self.transform:
                    sample = self.transform(sample)

                for key in concatenated_samples.keys():
                    concatenated_samples[key] = sample[key].squeeze()

                stop_forecast = k == self.forecast_len

                concatenated_samples["forecast_hour"] = k
                concatenated_samples["index"] = index
                concatenated_samples["stop_forecast"] = stop_forecast
                concatenated_samples["datetime"] = [
                    int(
                        historical_data.time.values[0]
                        .astype("datetime64[s]")
                        .astype(int)
                    ),
                    int(target_data.time.values[0].astype("datetime64[s]").astype(int)),
                ]

                yield concatenated_samples

                if stop_forecast:
                    break

                if k == self.forecast_len:
                    break

In [117]:
from credit.data import ERA5Dataset
from torch.utils.data import get_worker_info
from torch.utils.data.distributed import DistributedSampler
import xarray as xr

In [118]:
test_dataset = DistributedSequentialDataset(
    filenames=train_files,
    history_len=1,
    forecast_len=1,
    skip_periods=conf["data"]["skip_periods"],
    transform=load_transforms(conf),
    rank=0,
    world_size=1,
    shuffle=False,
)

In [119]:
loader = torch.utils.data.DataLoader(test_dataset, batch_size=1, num_workers=0)
# if the batch size > 1, need to loop over data below, as batch size ~ batched length of sequence
# batch size with different samples is handled by distributed DDP or FSDP training

k = 0
for batch in tqdm.tqdm(loader):
    # y = concat_and_reshape(data["y"].squeeze(1), data["y_surf"].squeeze(1))
    # Plotting y
    # plt.figure(figsize=(10, 4))
    # plt.subplot(1, 2, 1)
    # plt.pcolor(y[0, 66, :, :].to('cpu').numpy(), cmap='RdBu', vmin=-3, vmax=3)
    # plt.colorbar()
    # plt.gca().invert_yaxis()
    # plt.title(f'Forecast hour {data["forecast_hour"].item()}')
    # plt.show()

    # print(f'A {batch["forecast_hour"].item()}')

    # if batch['stop_forecast']:
    #     stop_forecast = True
    #     break

    # if conf["data"]["history_len"] == i:
    #     stop_forecast = True
    #     break
    print(batch["stop_forecast"], batch["datetime"])

    k += 1

    if k == 10:
        break

  0%|          | 1/306781 [00:09<849:10:58,  9.96s/it]

tensor([False]) [tensor([283996800]), tensor([284000400])]


  0%|          | 2/306781 [00:19<844:03:20,  9.90s/it]

tensor([True]) [tensor([284000400]), tensor([284004000])]


  0%|          | 3/306781 [00:29<841:46:37,  9.88s/it]

tensor([False]) [tensor([284000400]), tensor([284004000])]


  0%|          | 4/306781 [00:39<843:10:46,  9.89s/it]

tensor([True]) [tensor([284004000]), tensor([284007600])]


  0%|          | 5/306781 [00:49<844:35:50,  9.91s/it]

tensor([False]) [tensor([284004000]), tensor([284007600])]


  0%|          | 6/306781 [00:59<844:01:47,  9.90s/it]

tensor([True]) [tensor([284007600]), tensor([284011200])]


  0%|          | 7/306781 [01:09<842:54:09,  9.89s/it]

tensor([False]) [tensor([284007600]), tensor([284011200])]


  0%|          | 8/306781 [01:19<843:11:34,  9.89s/it]

tensor([True]) [tensor([284011200]), tensor([284014800])]


  0%|          | 9/306781 [01:29<842:30:50,  9.89s/it]

tensor([False]) [tensor([284011200]), tensor([284014800])]


  0%|          | 9/306781 [01:38<936:53:52, 10.99s/it]

tensor([True]) [tensor([284014800]), tensor([284018400])]





In [101]:
batch["stop_forecast"]

tensor([False])