In [7]:
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from typing import Optional
import torch
import os
import sys
from omegaconf import OmegaConf
from lightning.pytorch import seed_everything, loggers as pl_loggers
from prediff.utils.download import (
    pretrained_kol_vae_name,
    pretrained_kol_earthformerunet_name 
    )

from prediff.utils.path import (

    default_pretrained_earthformerunet_dir,
    )

module_path = os.path.abspath(os.path.join('..', '..', 'prediff', 'kol'))
sys.path.append(module_path)
from train_kol_diffore import DifforekolPLModule
# TODO: these need to be generalized further

In [8]:
class Args:
    """Minimal stand-in for an ``argparse.Namespace``: every keyword becomes an attribute."""

    def __init__(self, **kwargs):
        # vars(self) is the instance __dict__, so this stores kwargs verbatim.
        vars(self).update(kwargs)


# Hard-coded run options for this notebook.
args = Args(
    save='kol_roll_0108',
    nodes=1,
    gpus=2,
    cfg='rollout_kol_v1.yaml',
)

In [9]:




# Load the run configuration. Every name defined here is consumed by the cells
# below, so fail with a clear message instead of leaving them undefined.
if args.cfg is None:
    raise ValueError("args.cfg must be the path to a YAML config file")

# Passing the path (not an open handle) lets OmegaConf open *and close* the file.
oc_from_file = OmegaConf.load(args.cfg)
dataset_cfg = OmegaConf.to_object(oc_from_file.dataset)
total_batch_size = oc_from_file.optim.total_batch_size
micro_batch_size = oc_from_file.optim.micro_batch_size
max_epochs = oc_from_file.optim.max_epochs
seed = oc_from_file.optim.seed
float32_matmul_precision = oc_from_file.optim.float32_matmul_precision


# Matmul precision and seeding come from the config's `optim` section.
torch.set_float32_matmul_precision(float32_matmul_precision)
seed_everything(seed, workers=True)

# Passed as `num_workers` to the datamodule factory.
NUM_WORKERS = 8

dm = DifforekolPLModule.get_kol_datamodule(
    dataset_cfg=dataset_cfg,
    micro_batch_size=micro_batch_size,
    num_workers=NUM_WORKERS,
)
dm.prepare_data()
dm.setup()
print('dm setup complete')

# Plain module-level names (a `global` statement is meaningless at notebook
# top level). Presumably the dataset normalisation statistics — TODO confirm.
std = dm.std
mean = dm.mean


train_dl = dm.train_dataloader()

# Devices per node: `args.gpus` may be an integer count (e.g. 2) or a
# comma-separated string of device ids (e.g. "0,1"). Splitting str(2) on ','
# would count an integer 2 as a single device, so handle both forms.
if isinstance(args.gpus, int):
    num_devices = args.gpus
else:
    num_devices = len(str(args.gpus).split(','))

# Micro-batches to accumulate so that micro_batch_size * nodes * devices * steps
# reaches total_batch_size (integer division).
accumulate_grad_batches = total_batch_size // (micro_batch_size * args.nodes * num_devices)
total_num_steps = DifforekolPLModule.get_total_num_steps(
    epoch=max_epochs,
    num_samples=dm.num_train_samples,
    total_batch_size=total_batch_size,
)



# Build the Lightning module from the same YAML config; save_dir comes from args.save.
pl_module = DifforekolPLModule(
    oc_file=args.cfg,
    save_dir=args.save,
    total_num_steps=total_num_steps,
)
print('pl setup complete')





Seed set to 0
Seed set to 0


/home/users/nus/e1333861/PreDiff/scripts/rollout/kol
/home/users/nus/e1333861/PreDiff/scripts/rollout/kol


KeyboardInterrupt: 

In [None]:

# Load the pretrained checkpoint onto the CPU and copy it into the wrapped torch module.
# NOTE(review): torch.load unpickles the file — only load checkpoints from a
# trusted source (torch>=1.13 supports weights_only=True for plain state dicts).
earthformerunet_ckpt_path = os.path.join(
    default_pretrained_earthformerunet_dir,
    pretrained_kol_earthformerunet_name,
)
state_dict = torch.load(earthformerunet_ckpt_path, map_location=torch.device("cpu"))
pl_module.torch_nn_module.load_state_dict(state_dict=state_dict)


<All keys matched successfully>

In [None]:
# NOTE(review): this cell cannot run at notebook top level — `self`, `batch`
# and `batch_idx` are not defined anywhere in this notebook. It looks like it
# was copied out of a LightningModule step method; TODO wrap it in a function
# or define those objects before running.
# NOTE(review): it also rebinds `micro_batch_size`, which the setup cell read
# from the YAML config.
micro_batch_size = batch.shape[self.batch_axis]
data_idx = int(batch_idx * micro_batch_size)

target_seq, cond, context_seq = self.get_input(batch, return_verbose=True)

In [None]:




def rollout2d(
    model: torch.nn.Module,
    initial_u: torch.Tensor,
    pde: "PDEDataConfig",  # quoted: PDEDataConfig is not imported in this notebook
    time_history: int,
    num_steps: int,
):
    """Autoregressively roll out ``model`` from the start of ``initial_u``.

    The first model input is ``initial_u[:, :time_history]``. After each call,
    the prediction is appended to the input window along dim 1 and the window
    is truncated to its last ``time_history`` entries along that dim.

    Args:
        model: callable mapping an input window to the next prediction(s).
        initial_u: initial trajectory; only ``initial_u[:, :time_history]`` is used.
        pde: unused here; kept so the signature stays stable.
        time_history: length (along dim 1) of the model input window.
        num_steps: number of model calls; must be >= 1.

    Returns:
        All predictions concatenated along dim 1.

    Raises:
        ValueError: if ``num_steps < 1`` (there would be nothing to concatenate).
    """
    if num_steps < 1:
        raise ValueError(f"num_steps must be >= 1, got {num_steps}")

    window = initial_u[:, :time_history]
    predictions = []
    for step in range(num_steps):
        if step > 0:
            # Slide the window: append the newest prediction, keep the last
            # `time_history` entries.
            window = torch.cat((window, predictions[-1]), dim=1)[:, -time_history:]
        predictions.append(model(window))

    return torch.cat(predictions, dim=1)


def cond_rollout2d(
    model: torch.nn.Module,
    initial_u: torch.Tensor,
    initial_v: torch.Tensor,
    delta_t: Optional[torch.Tensor],
    cond: Optional[torch.Tensor],
    grid: Optional[torch.Tensor],
    pde: "PDEDataConfig",  # quoted: PDEDataConfig is not imported in this notebook
    time_history: int,
    num_steps: int,
):
    """Autoregressively roll out a conditioned ``model``.

    The first input window concatenates, along dim 2, whichever of
    ``initial_u[:, :time_history]`` / ``initial_v[:, :time_history]`` the PDE
    has (per ``pde.n_scalar_components`` / ``pde.n_vector_components``). After
    each call the prediction is appended along dim 1 and the window truncated
    to its last ``time_history`` entries. If ``grid`` is given it is appended
    along dim 1 to the model *input* only, never to the rolling window, so it
    cannot be slid into the history on later steps.

    Args:
        model: called as ``model(x, delta_t, cond)`` when ``delta_t`` is not
            None, otherwise ``model(x, cond)``.
        initial_u: scalar-component trajectory (also supplies the device).
        initial_v: vector-component trajectory; read only if the PDE has
            vector components.
        delta_t: optional extra positional argument forwarded to ``model``.
        cond: forwarded to ``model`` as its last positional argument.
        grid: optional tensor appended along dim 1 to every model input.
        pde: object exposing ``n_scalar_components`` and ``n_vector_components``.
        time_history: length (along dim 1) of the rolling window.
        num_steps: number of model calls; must be >= 1.

    Returns:
        All predictions concatenated along dim 1.

    Raises:
        ValueError: if ``num_steps < 1``.
    """
    if num_steps < 1:
        raise ValueError(f"num_steps must be >= 1, got {num_steps}")

    device = initial_u.device
    # Empty 1-D placeholders: torch.cat skips tensors of size (0,), so a
    # component the PDE lacks contributes nothing (and is never unbound).
    data_scalar = torch.Tensor().to(device=device)
    data_vector = torch.Tensor().to(device=device)
    if pde.n_scalar_components > 0:
        data_scalar = initial_u[:, :time_history]
    if pde.n_vector_components > 0:
        data_vector = initial_v[:, :time_history]
    window = torch.cat((data_scalar, data_vector), dim=2)

    predictions = []
    for step in range(num_steps):
        if step > 0:
            window = torch.cat((window, predictions[-1]), dim=1)[:, -time_history:]

        model_input = window if grid is None else torch.cat((window, grid), dim=1)

        if delta_t is not None:
            pred = model(model_input, delta_t, cond)
        else:
            pred = model(model_input, cond)
        predictions.append(pred)

    return torch.cat(predictions, dim=1)


def rollout3d_maxwell(
    model: torch.nn.Module,
    initial_d: torch.Tensor,
    initial_h: torch.Tensor,
    time_history: int,
    num_steps: int,
):
    """Autoregressively roll out ``model`` on paired ``initial_d`` / ``initial_h`` fields.

    The first model input is ``initial_d`` and ``initial_h`` concatenated along
    dim 2. Every later input is the previous input with the newest prediction
    appended along dim 1 (time), truncated to its last ``time_history`` steps.

    NOTE(review): unlike ``rollout2d``, the initial fields are not sliced to
    ``time_history`` before the first call; callers presumably pass exactly
    ``time_history`` steps — confirm.

    Returns:
        All ``num_steps`` predictions concatenated along dim 1.
    """
    predictions = []
    window = None
    for step in range(num_steps):
        window = (
            torch.cat((initial_d, initial_h), dim=2)
            if step == 0
            else torch.cat((window, predictions[-1]), dim=1)[:, -time_history:]
        )
        predictions.append(model(window))
    return torch.cat(predictions, dim=1)