# Setup:

In [None]:
%load_ext autoreload
%autoreload 2
%matplotlib inline
import torch
import numpy as np
import matplotlib.pyplot as plt
import pytorch_lightning as L
import torch.nn.functional as F
import torch.multiprocessing
torch.multiprocessing.set_sharing_strategy('file_system')

from lightning_utils import *
from utils import *
from MOR_Operator import MOR_Operator
from POU_net import POU_net
import JHTDB_sim_op

In [None]:
from model_eval_backend import *

# Checkpoint directory names of the ablation study.
# The trailing number on each line is the entry's index in the list.
ablation_models = [
    'Everything-large',  # 0
    'Everything-small',  # 1
    'Control-small',  # 2
    'Control-No-Output-LN',  # 3
    'MixtureOfExperts-normalized-small',  # 4
    'MixtureOfExperts-small',  # 5
    'RecursiveSteps-small',  # 6
    'CNN-small-k4',  # 7
    'Everything-small-k4',  # 8
    'WNO-800-epochs',  # 9
    'Everything-WNO-stride-800-epochs',  # 10
]
ablation_model_index = 1
model_name = ablation_models[ablation_model_index]

# Extra keyword arguments forwarded to load_model(); only populated for
# checkpoints whose name contains 'CNN'.
optional_model_kwd_args = {}
if 'CNN' in model_name:
    print('using CNN experts')
    optional_model_kwd_args.update(make_expert=CNN, skip_connections=True)

model = load_model(
    f'./lightning_logs/paper/{model_name}/last.ckpt',
    **optional_model_kwd_args,
)

In [None]:
# Validation dataset and a shuffled loader over it (batches of 5, 16 worker processes).
val_dataset = JHTDB_sim_op.JHTDB_Channel(
    'data/turbulence_output',
    time_chunking=5,
)
val_data_loader = torch.utils.data.DataLoader(
    val_dataset,
    batch_size=5,
    num_workers=16,
    shuffle=True,
)

# Predictive Simulation Based Calibration (aka PSBC)

In [None]:
import model_agnostic_BNN
from model_agnostic_BNN import get_BNN_pred_distribution, clear_cache

# Adapted to create a pytorch mixture distribution
def get_BNN_pred_GMM_distribution(bnn_model, x_input, n_samples=100, **kwd_args):
    '''
    Build a uniform-weight Gaussian mixture out of sampled BNN predictive distributions.

    If you just want moments use get_BNN_pred_moments() instead as it is *much* more memory efficient (e.g. for large sample sizes).
    But this is still useful if you want an actual distribution.

    Args:
        bnn_model: model handed to get_BNN_pred_distribution().
        x_input: input batch handed to get_BNN_pred_distribution().
        n_samples: number of epistemic samples, i.e. the number of mixture components.
        **kwd_args: forwarded unchanged to get_BNN_pred_distribution().

    Returns:
        torch.distributions.MixtureSameFamily whose components are the sampled Normal
        distributions, equally weighted. Its batch shape is the prediction shape without
        the sample dimension.

    NOTE(review): assumes get_BNN_pred_distribution() returns (mu, sigma) tensors whose
    leading dimension (dim 0) is the epistemic-sample dimension -- confirm against
    model_agnostic_BNN.
    '''
    # get sample aleatoric distributions from sampling epistemic distribution
    pred_mu, pred_sigma = get_BNN_pred_distribution(bnn_model, x_input, n_samples=n_samples, **kwd_args)

    # MixtureSameFamily indexes the components along the *rightmost* batch dimension,
    # so the sample dimension has to be moved from the front to the back.
    pred_mu = pred_mu.movedim(0, -1)
    pred_sigma = pred_sigma.movedim(0, -1)
    components = torch.distributions.normal.Normal(pred_mu, pred_sigma, validate_args=False)

    # uniform weights; ones_like keeps the weights on the same device/dtype as the
    # components and gives the Categorical a matching batch shape (needed for sampling)
    mix = torch.distributions.Categorical(probs=torch.ones_like(pred_mu))
    return torch.distributions.MixtureSameFamily(mix, components)

import numpy as np
def dumb_vmap(func):
    """
    Plain-Python stand-in for torch.vmap: useful for debugging when vmap won't work.

    Returns a function that applies `func` to every element along the first
    dimension of its argument (by iterating over it) and stacks the results.
    """
    def apply_over_leading_dim(X):
        return torch.stack(list(map(func, X)))
    return apply_over_leading_dim

## Aleatoric KDE Equations:

1. We want to take the union of the centroid pdfs, so we sum them. But then we need the pdf to integrate to 1, $\int_{x \in \Omega} p(x)\,dx=1$, so we divide by the number of pdfs $N$, resulting in an average: $$p(S_{jk})={1\over N}\sum_{i=1}^{N}p_{ik}(S_{jk}) \text{ s.t. } S_{jk}\sim p_{jk}$$
2. However, all these computations need to happen in the log domain for numerical stability: $$\log p(S_{jk})=\log\left({1\over N}\sum_{i=1}^{N}p_{ik}(S_{jk})\right)=\log\left(\sum_{i=1}^{N}p_{ik}(S_{jk})\right)-\log N$$
3. In practice this requires the log-sum-exp trick (aka LSE): $$\log p(S_{jk})=\operatorname{LSE}_{i=1}^{N}\,\log p_{ik}(S_{jk})-\log N$$
4. Then we need to get the joint pdf across all the spatial dimension(s) (indexed by $k$): $$\log p(S_j)=\sum_k \log p(S_{jk})$$

### Mixture Distribution Code Example:

```
mix = D.Categorical(torch.ones(5,))
comp = D.Normal(torch.randn(5,), torch.rand(5,))
gmm = MixtureSameFamily(mix, comp)
```

In [None]:
# GOTCHA: It seems the problem is that it always gives more likelihood to the true value than the samples? Maybe just sample more?
# That wouldn't be all bad if it were actually true it would imply that the true value always has the highest likelihood but it can't be literally true...
def _find_batch_BCI_truth_quantiles(model, inputs, outputs, n_samples_per_batch, chunk_size=25, fake_ideal=False, verbose=True):
    '''
    For one batch, find where the truth's joint log-density ranks among sampled predictions.

    Draws `n_samples_per_batch` predictive Normal distributions via get_BNN_pred_distribution(),
    samples one prediction from each, and scores every sample (and the truth) under the
    equally weighted average of all those distributions, summed in log space over all
    non-batch dimensions.

    Args:
        model: model handed to get_BNN_pred_distribution().
        inputs: input batch; inputs.shape[0] is taken as the batch size.
        outputs: ground-truth targets for `inputs` (ignored when fake_ideal=True).
        n_samples_per_batch: number of predictive distributions / samples drawn.
        chunk_size: forwarded to torch.vmap as `chunk_size`.
        fake_ideal: replace `outputs` by one of the sampled predictions, to simulate the
            ideal case where the truth is drawn from the predictive distribution.
        verbose: print shapes/comparisons and plot one histogram of the sample log-densities.

    Returns:
        1-D CPU tensor of length inputs.shape[0]: per batch element, the fraction of sampled
        predictions whose joint log-density is <= the joint log-density of the truth.
    '''
    import random # function-scope import; used for fake_ideal and for picking which batch element to plot

    with torch.inference_mode():
        # get sample aleatoric distributions from sampling epistemic distribution
        pred_distribution = get_BNN_pred_distribution(model, inputs, n_samples=n_samples_per_batch)
        pred_distribution = torch.distributions.normal.Normal(*pred_distribution, validate_args=False)
        pred_samples = pred_distribution.sample() # then sample the actual predictions from the sampled aleatoric distributions
        if verbose: print('pred_samples.shape:', pred_samples.shape) # shape==[aleatoric_sample, batch, ...]

        def get_log_density(sample_datum): # does this still work with the mixture?
            ''' gets log density of a single sample given the aleatoric distributions sampled from the epistemic weights '''
            # averaged_pdfs = p(S_jk) = (1/N)∑_i(p_ik(S_jk)) s.t. S_jk ~ p_jk := KDE-style pdf (derived from aleatoric distributions)
            averaged_pdfs = torch.logsumexp(pred_distribution.log_prob(sample_datum), dim=0) - np.log(n_samples_per_batch) # average across epistemic dimension
            joint_pdfs = torch.vmap(torch.sum)(averaged_pdfs) # joint pdf across all non-batch dims := log(∏_ip_i(S_i))
            return joint_pdfs

        # vmap supports additional input dimension: j (aka aleatoric sample dimension)
        #vget_log_density = dumb_vmap(get_log_density)
        vget_log_density = torch.vmap(get_log_density, chunk_size=chunk_size)
        pred_samples = pred_samples[:, None] # move the aleatoric sample dimension out of the epistemic sample distribution dimension
        assert tuple(pred_samples.shape)[:3]==(n_samples_per_batch, 1, inputs.shape[0])
        pred_joint_log_pdfs = vget_log_density(pred_samples) # shape==[aleatoric_sample, batch]
        assert tuple(pred_joint_log_pdfs.shape)==(n_samples_per_batch, inputs.shape[0])

        if fake_ideal: # artificially simulate the ideal case where outputs are sampled from prediction distribution
            # Each element of pred_samples is [1, batch, ...] because of the singleton dim added
            # above. Drop it so `outputs` has the same [batch, ...] layout as real targets;
            # otherwise outputs[None] broadcasts to [1, n, batch, ...] and the logsumexp below
            # would reduce the singleton instead of the epistemic dimension.
            # NOTE(review): the earlier "verified 10/10/24" check presumably predates that singleton dim.
            outputs = random.choice(pred_samples)[0] # GOTCHA: isn't realized with small number of batches & sample sizes!

        truth_joint_log_pdf = get_log_density(outputs[None])
        assert tuple(truth_joint_log_pdf.shape)==(inputs.shape[0],)

        if verbose:
            # plot the pdf quantile distribution(s)
            batch_display_ids = random.choices(range(inputs.shape[0]), k=1)
            for i in batch_display_ids:
                plt.hist(pred_joint_log_pdfs[:,i].cpu(), color='blue')
                plt.axvline(truth_joint_log_pdf[i].item(), color='red')
                plt.title(f'{i}th joint-log-pdf Distribution')
                plt.show()

        # empirical quantile of the truth: share of samples that are no more likely than the truth
        pdf_comparison = pred_joint_log_pdfs<=truth_joint_log_pdf
        if verbose: print(f'pdf_comparison={list(pdf_comparison.ravel().cpu().numpy())}')
        truth_quantiles = torch.sum(pdf_comparison, dim=0)/pdf_comparison.shape[0]
        assert pred_joint_log_pdfs.shape[0]==pdf_comparison.shape[0]==n_samples_per_batch
        return truth_quantiles.cpu().detach()

# Actually we don't need KDE! We can use the aleatoric uncertainty to get density directly!
def find_BCI_truth_quantiles(model, data_loader, n_batches=100, n_samples_per_batch=25, chunk_size=25, fake_ideal=False, verbose=True):
    """
    Collect per-example truth quantiles over `n_batches` batches and plot their histogram.

    These quantiles should follow q~U(0,1) in order for BCI theory to be satisfied.
    You can simulate the ideal case as a sanity check with fake_ideal=True.
    GOTCHA: In practice with small batch and/or sample sizes even with fake_ideal=True, the distribution will not look uniform!

    Args:
        model: model to evaluate; must live on a CUDA device (asserted).
        data_loader: iterable of (inputs, outputs) batches; re-iterated from the start
            as often as needed to reach `n_batches`.
        n_batches: number of batches to process.
        n_samples_per_batch, chunk_size, fake_ideal, verbose: forwarded to
            _find_batch_BCI_truth_quantiles().

    Returns:
        1-D tensor with the concatenated truth quantiles of all processed batches.

    Raises:
        ValueError: if a full pass over `data_loader` yields no batch (the original
            code would loop forever in that case).
    """
    import matplotlib.pyplot as plt
    assert model.device.type=='cuda' # TOO slow on cpu
    truth_quantiles = []
    while len(truth_quantiles)<n_batches:
        n_done_before_pass = len(truth_quantiles)
        for inputs, outputs in data_loader:
            if len(truth_quantiles)==n_batches: break
            if not truth_quantiles: print(f'{inputs.shape=}, {outputs.shape=}')
            print(f'processing batch: {len(truth_quantiles)}')
            truth_quantiles.append(_find_batch_BCI_truth_quantiles(model, inputs.to(model.device), outputs.to(model.device),
                                                                   n_samples_per_batch=n_samples_per_batch, fake_ideal=fake_ideal,
                                                                   chunk_size=chunk_size, verbose=verbose))
        if len(truth_quantiles)==n_done_before_pass:
            # no progress in a whole pass => the loader is empty; bail out instead of spinning forever
            raise ValueError('data_loader yielded no batches; cannot compute truth quantiles')

    truth_quantiles = torch.cat(truth_quantiles)
    # linspace(0, 1, steps=5) evaluates the 0%, 25%, 50%, 75% and 100% quantiles
    display_quantiles = list(torch.quantile(truth_quantiles, q=torch.linspace(0.0,1.0, steps=5)).numpy())
    print('1/5th quantiles of truth quantile distribution: ', display_quantiles)
    plt.hist(truth_quantiles)
    plt.title('Truth Quantiles (Should follow q~U(0,1))')
    plt.show()
    return truth_quantiles

In [None]:
%pdb on
truth_quantiles = find_BCI_truth_quantiles(model, val_data_loader, n_batches=10, n_samples_per_batch=25, fake_ideal=False, verbose=True)

In [None]:
# Persist the computed truth quantiles as plain text.
np.savetxt(fname='truth_quantiles.txt', X=truth_quantiles.numpy())

## Posterior-Predictive Check

$PPC_{LL}=\log p(D_{new}|D)=\log\int p(D_{new}|\theta)\,p(\theta|D)\,d\theta$

$=\log E_{\theta\sim p(\theta|D)}[p(D_{new}|\theta)]$

In [None]:
from model_agnostic_BNN import log_posterior_predictive_check

## Scratch Space for Developing and Debugging:

In [None]:
# Pull a single (inputs, outputs) batch from the validation loader for the experiments below.
data_iter = iter(val_data_loader)
inputs, outputs = next(data_iter)
print(f'{inputs.shape=}, {outputs.shape=}')

In [None]:
# Random placeholder moments (5 elements each) standing in for the real call:
#get_BNN_pred_distribution(model, inputs.to(model.device), n_samples=10)
pred_mu = torch.randn(5)
pred_sigma = torch.rand(5)
pred_distribution = torch.distributions.normal.Normal(loc=pred_mu, scale=pred_sigma)
print(f'{pred_distribution=}')
# bare expression so the notebook also displays the object
pred_distribution

In [None]:
# Draw one sample per element of the predictive distribution defined above.
samples = pred_distribution.sample()

In [None]:
# Insert a singleton dimension right after the leading (sample) dimension.
samples = samples.unsqueeze(1)

In [None]:
# Compare the shapes of the targets, the drawn samples and the distribution's mean.
print(f'{outputs.shape=}')
print(f'{samples.shape=}')
print(f'{pred_distribution.loc.shape=}')

In [None]:
# Elementwise log-likelihood of the targets under the predictive distribution.
# NOTE(review): if pred_distribution is still the CPU placeholder built from torch.randn(5)/torch.rand(5),
# this presumably only works when model.device is the CPU and outputs broadcasts against shape (5,) -- confirm.
ll = pred_distribution.log_prob(outputs.to(model.device))
print(f'{ll.shape=}')

In [None]:
# Rebuild the Normal without argument validation. A Distribution object cannot be
# star-unpacked (it is not iterable), so pass its loc/scale tensors explicitly.
torch.distributions.normal.Normal(pred_distribution.loc, pred_distribution.scale, validate_args=False)