In [25]:
from map2map.models.uncertainty_quantification.d2d_uq import StyledVNet
from map2map.utils import load_model_state_dict, plt_slices
from map2map.data import FieldDataset
from map2map.models import narrow_cast, lag2eul
import argparse
import torch
from torch.utils.data import DataLoader
import os
from tqdm import tqdm

In [26]:
# Hard-coded run configuration for the `train-forward-GNLL` experiment
# (presumably a copy of the training script's parsed arguments — verify).
# This notebook reads the model sizes, `load_state_strict`, and the
# training-data / loader options. Note: 'load_state' is NOT used here; the
# checkpoint path is given explicitly where the model weights are loaded.
args = {
    'mode': 'train',
    # field normalization
    'in_norms': ['cosmology.dis'],
    'tgt_norms': ['cosmology.dis'],
    # cropping / padding
    'crop': 32,
    'crop_start': None,
    'crop_stop': None,
    'crop_step': None,
    'in_pad': 48,
    'tgt_pad': 48,
    'scale_factor': 1,
    # model, loss and checkpoint settings
    'model': 'model.StyledVNet',
    'criterion': 'MSELoss',
    'load_state': 'd2d-forward-gnll.pt',
    'load_state_strict': True,
    # data loading
    'batch_size': 1,
    'loader_workers': 2,
    'callback_at': '/jet/home/lianga/search/search/map2map',
    'misc_kwargs': {},
    'experiment_title': 'train-forward-GNLL',
    # training files
    'train_style_pattern': '/ocean/projects/cis230021p/lianga/quijote/LH0000/params.npy',
    'train_in_patterns': ['/ocean/projects/cis230021p/lianga/quijote/LH0000/lin.npy'],
    'train_tgt_patterns': ['/ocean/projects/cis230021p/lianga/quijote/LH0000/nonlin.npy'],
    # validation files
    'val_style_pattern': '/ocean/projects/cis230021p/lianga/quijote/LH0005/params.npy',
    'val_in_patterns': ['/ocean/projects/cis230021p/lianga/quijote/LH0005/lin.npy'],
    'val_tgt_patterns': ['/ocean/projects/cis230021p/lianga/quijote/LH0005/nonlin.npy'],
    # augmentation (disabled: augment=False)
    'augment': False,
    'aug_shift': None,
    'aug_add': None,
    'aug_mul': None,
    # optimizer / schedule
    'optimizer': 'Adam',
    'lr': 0.0001,
    'optimizer_args': {},
    'reduce_lr_on_plateau': False,
    'scheduler_args': {'verbose': True},
    'init_weight_std': None,
    'epochs': 1024,
    'seed': 10620,
    # distributed setup and logging
    'div_data': False,
    'div_shuffle_dist': 1,
    'dist_backend': 'nccl',
    'log_interval': 100,
    'detect_anomaly': False,
    'val': True,
    'nodes': 1,
    'gpus_per_node': 1,
    'world_size': 1,
    'dist_addr': 'tcp://v018.ib.bridges2.psc.edu:40579',
    # sizes used to build StyledVNet
    'style_size': 1,
    'in_chan': [3],
    'out_chan': [3],
}

In [27]:
# Rebuild the network from the configured style/channel sizes, then restore the
# weights from one specific training checkpoint. This path is set explicitly and
# is not derived from args['load_state'].
checkpoint_path = (
    "/jet/home/lianga/search/search/map2map/checkpoints/"
    "train-forward-GNLL_2023-04-09-23-01-30/state_3.pt"
)
in_channels = sum(args['in_chan'])
out_channels = sum(args['out_chan'])
model = StyledVNet(args['style_size'], in_channels, out_channels,
                   scale_factor=args['scale_factor'])

# map_location puts every stored tensor on the CPU, so loading needs no GPU.
# torch.load unpickles the file: only load checkpoints from trusted sources.
state = torch.load(checkpoint_path, map_location=torch.device('cpu'))
load_model_state_dict(model, state['model'], strict=args['load_state_strict'])

In [28]:
# Build the training-set FieldDataset and its DataLoader from the run configuration.
# The options shared with training are forwarded by name from `args`; only the
# file patterns are specific to the training split.
_dataset_option_names = (
    'in_norms', 'tgt_norms', 'callback_at',
    'augment', 'aug_shift', 'aug_add', 'aug_mul',
    'crop', 'crop_start', 'crop_stop', 'crop_step',
    'in_pad', 'tgt_pad', 'scale_factor',
)
train_dataset = FieldDataset(
    style_pattern=args['train_style_pattern'],
    in_patterns=args['train_in_patterns'],
    tgt_patterns=args['train_tgt_patterns'],
    **{name: args[name] for name in _dataset_option_names},
)
# shuffle=False keeps the iteration order fixed, so figure index i always
# refers to the same batch across runs.
# Pinned host memory only speeds up copies to a CUDA device; requesting it on a
# CPU-only machine just triggers a DataLoader warning, so enable it conditionally.
train_loader = DataLoader(
    train_dataset,
    batch_size=args['batch_size'],
    shuffle=False,
    num_workers=args['loader_workers'],
    pin_memory=torch.cuda.is_available(),
)

In [29]:
# Visualize model predictions on the training set. For each batch:
# - run the model;
# - convert the Lagrangian outputs/targets to Eulerian fields with lag2eul;
# - save a panel of slices to figs/slice_{i}.png.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# eval() switches the network to inference behavior (matters for layers such as
# dropout / batch norm, if the model has any); gradients are disabled below.
model = model.to(device).eval()

# Stop after this many batches; None renders every batch in train_loader.
max_batches = None
# Which sample of each batch to plot (-1 = the last one).
slice_idx = -1

os.makedirs('figs', exist_ok=True)

# Plotting needs no autograd graph, so don't build one for each forward pass.
with torch.no_grad():
    # Wrap the loader itself (not enumerate) so tqdm can read len(train_loader)
    # and display a total / ETA.
    for i, data in enumerate(tqdm(train_loader)):
        if max_batches is not None and i >= max_batches:
            break

        style = data['style'].to(device)
        input = data['input'].to(device)
        target = data['target'].to(device)

        # The model returns two tensors, used here as the predicted mean and variance.
        mean, var = model(input, style)
        # NOTE(review): narrow_cast presumably crops all tensors to a common
        # spatial size (input/target padding differ from the output) — confirm.
        input, mean, target, var = narrow_cast(input, mean, target, var)

        lag_mean, lag_tgt, lag_var = mean, target, var
        eul_out, eul_tgt = lag2eul([lag_mean, lag_tgt])

        fig = plt_slices(
            input[slice_idx], lag_mean[slice_idx], lag_tgt[slice_idx],
            lag_mean[slice_idx] - lag_tgt[slice_idx],
            eul_out[slice_idx], eul_tgt[slice_idx],
            eul_out[slice_idx] - eul_tgt[slice_idx],
            lag_var[slice_idx],
            title=['in', 'lag_out', 'lag_tgt', 'lag_out - lag_tgt',
                   'eul_out', 'eul_tgt', 'eul_out - eul_tgt',
                   'lag_var'],
        )
        fig.savefig(os.path.join('figs', f'slice_{i}.png'))
        # TODO(review): plt_slices is defined elsewhere. If it creates the figure
        # through matplotlib.pyplot, call matplotlib.pyplot.close(fig) here so open
        # figures do not accumulate over the loop (matplotlib isn't imported here).

66it [01:01,  1.07it/s]


KeyboardInterrupt: 

In [34]:
# Inspect the `input` tensor left over from the visualization loop: it holds
# whatever value the loop last assigned before it was interrupted (it also
# shadows the builtin `input`). Depends on that cell's state.
input

tensor([[[[[ 0.9745,  1.0556,  1.1568,  ...,  0.2805,  0.1288,  0.0566],
           [ 0.7747,  0.9762,  1.1138,  ...,  0.2679,  0.0455, -0.0561],
           [ 0.4493,  0.8132,  0.9239,  ...,  0.0270, -0.0784, -0.0927],
           ...,
           [ 0.7639,  0.7409,  0.6902,  ..., -0.4463, -0.5132, -0.5167],
           [ 0.8334,  0.8130,  0.7648,  ..., -0.3403, -0.3742, -0.3760],
           [ 0.7380,  0.6615,  0.7230,  ..., -0.2895, -0.3181, -0.2887]],

          [[ 1.0615,  1.0478,  1.0943,  ...,  0.6389,  0.4537,  0.3259],
           [ 0.9073,  0.9003,  0.8473,  ...,  0.2549,  0.1904,  0.1382],
           [ 0.7531,  0.7893,  0.8416,  ..., -0.0422,  0.0461,  0.1228],
           ...,
           [ 0.5612,  0.5868,  0.5552,  ..., -0.3459, -0.3676, -0.2528],
           [ 0.4375,  0.5975,  0.5861,  ..., -0.3232, -0.2904, -0.2389],
           [ 0.4469,  0.6588,  0.7622,  ..., -0.3905, -0.2693, -0.2456]],

          [[ 1.1826,  1.0566,  1.0026,  ...,  0.8822,  0.7875,  0.5969],
           [ 1.

In [35]:
# Inspect the model's predicted `mean` left over from the visualization loop:
# whatever value the loop last assigned before it was interrupted.
mean

tensor([[[[[ 9.8377e-01,  1.0645e+00,  1.1656e+00,  ...,  2.9304e-01,
             1.4153e-01,  6.9151e-02],
           [ 7.8357e-01,  9.8468e-01,  1.1222e+00,  ...,  2.8201e-01,
             5.9711e-02, -4.2398e-02],
           [ 4.5764e-01,  8.2063e-01,  9.3122e-01,  ...,  4.2322e-02,
            -6.2269e-02, -7.7903e-02],
           ...,
           [ 7.7036e-01,  7.4713e-01,  6.9663e-01,  ..., -3.9782e-01,
            -4.6711e-01, -4.7283e-01],
           [ 8.4021e-01,  8.1931e-01,  7.7036e-01,  ..., -2.9392e-01,
            -3.2929e-01, -3.3282e-01],
           [ 7.4479e-01,  6.6715e-01,  7.2833e-01,  ..., -2.4441e-01,
            -2.7389e-01, -2.4557e-01]],

          [[ 1.0702e+00,  1.0569e+00,  1.1029e+00,  ...,  6.5137e-01,
             4.6638e-01,  3.3879e-01],
           [ 9.1515e-01,  9.0869e-01,  8.5553e-01,  ...,  2.7046e-01,
             2.0616e-01,  1.5377e-01],
           [ 7.6102e-01,  7.9675e-01,  8.4870e-01,  ..., -2.2699e-02,
             6.5557e-02,  1.4106e-01],
 

In [42]:
# Inspect the `target` tensor left over from the visualization loop: whatever
# value the loop last assigned before it was interrupted.
target

tensor([[[[[ 1.0991e+00,  1.1633e+00,  1.2068e+00,  ...,  2.6758e-01,
             1.5770e-01,  8.4953e-02],
           [ 9.0403e-01,  1.0565e+00,  1.1087e+00,  ...,  1.6051e-01,
             4.1335e-02, -4.7587e-02],
           [ 6.6649e-01,  8.4853e-01,  9.3343e-01,  ...,  8.1100e-03,
            -7.2683e-02, -1.1300e-01],
           ...,
           [ 9.1705e-01,  7.8958e-01,  7.4200e-01,  ..., -3.7739e-01,
            -4.2950e-01, -4.3948e-01],
           [ 8.9365e-01,  8.4924e-01,  7.9606e-01,  ..., -2.9836e-01,
            -3.2393e-01, -3.0652e-01],
           [ 7.9895e-01,  7.3849e-01,  7.5548e-01,  ..., -2.6547e-01,
            -2.6906e-01, -1.7374e-01]],

          [[ 1.1338e+00,  1.1311e+00,  1.1464e+00,  ...,  5.0970e-01,
             3.8933e-01,  2.7259e-01],
           [ 9.9863e-01,  9.3900e-01,  9.6366e-01,  ...,  2.4459e-01,
             1.5288e-01,  6.9242e-02],
           [ 8.6493e-01,  8.8209e-01,  8.6387e-01,  ..., -3.4679e-03,
             2.0701e-03, -3.8903e-02],
 