In [3]:
import sys
sys.path.append('/home/ajliang/search')

In [4]:
from search.map2map.map2map.data import FieldDataset
from search.map2map.map2map.models import StyledVNet
from search.map2map.map2map.utils import load_model_state_dict
from torch.utils.data import DataLoader
import torch

In [16]:
dataset = FieldDataset(
    style_pattern="/user_data/ajliang/ood_data/in_distribution/train/*/params.npy",
    in_patterns=["/user_data/ajliang/ood_data/in_distribution/train/*/lin.npy"],
    tgt_patterns=["/user_data/ajliang/ood_data/in_distribution/train/*/nonlin.npy"],
    in_norms=['cosmology.dis'],
    tgt_norms=['cosmology.dis'],
    crop=32,
    in_pad=48,
    scale_factor=1,
)
loader = DataLoader(
    dataset,
    batch_size=1,
    shuffle=False,
)

In [49]:
model = StyledVNet(
    dataset.style_size,
    sum(dataset.in_chan),
    sum(dataset.tgt_chan),
    dropout_prob=0.1,
    scale_factor=1.0,
)
device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")
model.to(device)
state = torch.load("/home/ajliang/search/model_weights/paper_fwd_d2d_weights.pt", map_location=device)
load_model_state_dict(model, state["model"], strict=True)
model.eval()

StyledVNet(
  (conv_l00): ResStyledBlock(
    (convs): ModuleList(
      (0): ConvStyled3d()
      (1): LeakyReLUStyled(negative_slope=0.01)
      (2): ConvStyled3d()
    )
    (act): LeakyReLUStyled(negative_slope=0.01)
    (skip): ConvStyled3d()
  )
  (conv_l01): ResStyledBlock(
    (convs): ModuleList(
      (0): ConvStyled3d()
      (1): LeakyReLUStyled(negative_slope=0.01)
      (2): ConvStyled3d()
    )
    (act): LeakyReLUStyled(negative_slope=0.01)
    (skip): ConvStyled3d()
  )
  (down_l0): ConvStyledBlock(
    (convs): ModuleList(
      (0): ConvStyled3d()
      (1): LeakyReLUStyled(negative_slope=0.01)
    )
  )
  (conv_l1): ResStyledBlock(
    (convs): ModuleList(
      (0): ConvStyled3d()
      (1): LeakyReLUStyled(negative_slope=0.01)
      (2): ConvStyled3d()
    )
    (act): LeakyReLUStyled(negative_slope=0.01)
    (skip): ConvStyled3d()
  )
  (down_l1): ConvStyledBlock(
    (convs): ModuleList(
      (0): ConvStyled3d()
      (1): LeakyReLUStyled(negative_slope=0.01)
 

In [50]:
def estimate_uncertainty(
    model: torch.nn.Module, style: torch.Tensor, input: torch.Tensor,
    sample_size: int,
) -> torch.Tensor:
    outputs = None
    for s in range(sample_size):
        out = model(input, style)
        if outputs is None:
            outputs = torch.zeros((sample_size, *out.shape), device=out.device)
        outputs[s] = out

    variance = torch.var(outputs, dim=0)
    return variance.sum()

In [51]:
sample_size = 10
variances = torch.zeros(len(dataset))
with torch.no_grad():
    model.eval()
    for i, data in enumerate(loader):
        style, input = data["style"].to(device), data["input"].to(device)
        variances[i] = estimate_uncertainty(model, style, input, sample_size)
        print(f"Sample {i}: variance = {variances[i]}")

print(f"Mean variance: {variances.mean()}")

Sample 0: variance = 25378.0234375
Sample 1: variance = 23937.65625
Sample 2: variance = 25636.259765625


KeyboardInterrupt: 

: 

In [26]:
output1 = model(input, style)

In [27]:
output2 = model(input, style)

In [28]:
output1

tensor([[[[[ 1.1536e+00,  1.1919e+00,  1.1468e+00,  ..., -2.1987e+00,
            -2.3011e+00, -2.2412e+00],
           [ 7.4561e-01,  9.9035e-01,  8.0885e-01,  ..., -1.3515e+00,
            -1.8411e+00, -1.6818e+00],
           [ 9.7809e-01,  4.2845e-01,  1.5020e+00,  ..., -1.7707e+00,
            -1.2909e+00, -2.0853e+00],
           ...,
           [ 4.7567e-01, -1.0833e+00,  1.6250e-01,  ..., -1.2901e+00,
            -2.7419e-01, -6.2789e-01],
           [ 9.3112e-01,  9.6601e-01,  1.4697e+00,  ..., -2.1677e-01,
             4.6267e-01, -3.7097e-01],
           [ 2.5325e-01,  7.2509e-01,  9.3584e-01,  ..., -5.5512e-01,
             6.6332e-02,  3.7641e-01]],

          [[ 9.1398e-01,  2.2814e+00,  8.3624e-01,  ..., -1.9893e+00,
            -1.2379e+00, -2.2314e+00],
           [ 8.2152e-01,  1.5147e+00,  1.0039e+00,  ..., -2.0413e+00,
            -2.0080e+00, -2.8232e+00],
           [ 6.2034e-02,  1.3301e+00,  7.6198e-01,  ..., -1.7996e+00,
            -3.7082e+00, -2.4617e+00],
 

In [29]:
output2

tensor([[[[[ 6.2782e-01,  3.6844e-01,  9.9919e-01,  ..., -1.7745e+00,
            -1.0183e+00, -8.9470e-01],
           [ 1.2842e+00,  9.2313e-01,  7.2574e-01,  ..., -1.1970e+00,
            -1.8518e+00, -9.7558e-01],
           [ 6.9994e-01,  1.5541e+00,  7.6533e-01,  ..., -2.2134e+00,
            -1.7936e+00, -1.5716e+00],
           ...,
           [-7.5672e-01,  2.7392e-01,  7.0801e-01,  ...,  7.5141e-01,
            -4.0582e-01,  6.3688e-01],
           [-4.0818e-01, -2.4078e-01,  6.9521e-01,  ..., -4.9713e-01,
            -6.9364e-01,  4.4039e-01],
           [ 1.0248e+00,  1.0572e+00, -8.2437e-02,  ...,  8.6173e-01,
            -6.5210e-02,  6.3113e-01]],

          [[ 1.6156e+00,  9.3211e-01,  1.3409e+00,  ..., -1.2184e+00,
            -1.2129e+00, -1.8498e+00],
           [ 1.5945e+00,  1.2969e+00,  1.2960e+00,  ..., -4.4677e-01,
            -1.0468e+00, -2.1378e+00],
           [ 1.4205e+00,  9.5841e-01,  2.1208e-01,  ..., -2.0980e+00,
            -1.8503e+00, -2.1983e+00],
 

In [31]:
output_stacked = torch.stack((output1, output2))

In [33]:
output_stacked.shape

torch.Size([2, 1, 3, 64, 64, 64])

In [35]:
torch.var(output_stacked, dim=0).shape

torch.Size([1, 3, 64, 64, 64])

In [39]:
torch.var(output_stacked, dim=0).sum()

tensor(237333.5781, grad_fn=<SumBackward0>)

In [38]:
from statistics import variance

variance([1.1536e+00, 6.2782e-01]), variance([1.1919e+00, 3.6844e-01])

(0.13822230419999995, 0.3390431858)

In [47]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [5]:
import os

list(os.listdir('/user_data/ajliang/ood_data/in_distribution/train'))

['LH0010',
 'LH0022',
 'LH0004',
 'LH0000',
 'LH0021',
 'LH0014',
 'LH0042',
 'LH0035',
 'LH0023',
 'LH0002',
 'LH0030',
 'LH0043',
 'LH0012',
 'LH0018']