In [50]:
"""
To launch all the tasks, create tmux sessions (separately for each of the following) 
and run (for instance):

python canvi_sbibm.py --task two_moons --cuda_idx 0
python canvi_sbibm.py --task slcp --cuda_idx 1
python canvi_sbibm.py --task gaussian_linear_uniform --cuda_idx 2
python canvi_sbibm.py --task bernoulli_glm --cuda_idx 3
python canvi_sbibm.py --task gaussian_mixture --cuda_idx 4
python canvi_sbibm.py --task gaussian_linear --cuda_idx 5
python canvi_sbibm.py --task slcp_distractors --cuda_idx 6
python canvi_sbibm.py --task bernoulli_glm_raw --cuda_idx 7
"""

import pandas as pd
import numpy as np
import sbibm
import torch
import math
import torch.distributions as D
import matplotlib.pyplot as plt

from pyknos.nflows import flows, transforms
from functools import partial
from typing import Optional
from warnings import warn

from pyknos.nflows import distributions as distributions_
from pyknos.nflows import flows, transforms
from pyknos.nflows.nn import nets
from pyknos.nflows.transforms.splines import rational_quadratic
from torch import Tensor, nn, relu, tanh, tensor, uint8

import seaborn as sns
import matplotlib as mpl
import matplotlib.pyplot as plt

mpl.rcParams['text.usetex'] = True
mpl.rcParams['mathtext.fontset'] = 'stix'
mpl.rcParams['font.family'] = 'STIXGeneral'
mpl.rcParams['text.latex.preamble'] = r'\usepackage{amsfonts}'

sns.set_theme()

from sbi.utils.sbiutils import (
    standardizing_net,
    standardizing_transform,
    z_score_parser,
)
from sbi.utils.torchutils import create_alternating_binary_mask
from sbi.utils.user_input_checks import check_data_device, check_embedding_net_device

import os
import pickle
import argparse

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

In [51]:
sbibm.get_available_tasks()

['sir',
 'two_moons',
 'slcp',
 'gaussian_linear_uniform',
 'lotka_volterra',
 'bernoulli_glm',
 'gaussian_mixture',
 'gaussian_linear',
 'slcp_distractors',
 'bernoulli_glm_raw']

In [52]:
task = "two_moons"
task = sbibm.get_task(task)
prior = task.get_prior_dist()
simulator = task.get_simulator()

In [80]:
# sample just for shaping network
y = prior.sample((1,))
x = simulator(y)

class SimpleModel(nn.Module):
    def __init__(self):
        super(SimpleModel, self).__init__()
        self.fc1 = nn.Linear(x.shape[-1], 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 32)
        self.fc4 = nn.Linear(32, y.shape[-1])
        
    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = F.relu(self.fc3(x))
        x = self.fc4(x)
        return x

In [81]:
model = SimpleModel()
loss_fn = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
 
n_epochs = 1_000
sims_per_epoch = 100

In [84]:
losses = []
for epoch in range(n_epochs):
    y = prior.sample((sims_per_epoch,))
    x = simulator(y)

    y_pred = model(x)
    loss = loss_fn(y_pred, y)
    losses.append(loss.detach().numpy())
        
    # backward pass
    optimizer.zero_grad()
    loss.backward()
    # update weights
    optimizer.step()
    
    # print progress
    if epoch % 100:
        print(f"epoch {epoch} loss {loss}")

epoch 1 loss 0.1552262157201767
epoch 2 loss 0.1442476212978363
epoch 3 loss 0.16369205713272095
epoch 4 loss 0.20309671759605408
epoch 5 loss 0.15342877805233002
epoch 6 loss 0.1969963014125824
epoch 7 loss 0.19185087084770203
epoch 8 loss 0.18020126223564148
epoch 9 loss 0.17481307685375214
epoch 10 loss 0.14098061621189117
epoch 11 loss 0.15761274099349976
epoch 12 loss 0.12570105493068695
epoch 13 loss 0.15111738443374634
epoch 14 loss 0.18157269060611725
epoch 15 loss 0.19626674056053162
epoch 16 loss 0.1614682376384735
epoch 17 loss 0.1714898645877838
epoch 18 loss 0.1635822355747223
epoch 19 loss 0.17118316888809204
epoch 20 loss 0.16000030934810638
epoch 21 loss 0.16875670850276947
epoch 22 loss 0.1426895558834076
epoch 23 loss 0.14582742750644684
epoch 24 loss 0.21931053698062897
epoch 25 loss 0.19048738479614258
epoch 26 loss 0.18620361387729645
epoch 27 loss 0.17619384825229645
epoch 28 loss 0.15621906518936157
epoch 29 loss 0.1995466947555542
epoch 30 loss 0.160356119275093

In [86]:
cal_y = prior.sample((500,))
cal_x = simulator(cal_y)
cal_y_pred = model(cal_x)
cal_scores = ((cal_y - cal_y_pred) ** 2).mean(axis=1).detach().numpy()

alpha = 0.05
desired_coverage = 1 - alpha
quantile = np.quantile(cal_scores, q = desired_coverage)

In [91]:
activation = {}
def get_activation(name):
    def hook(model, input, output):
        activation[name] = output.detach()
    return hook

model.fc3.register_forward_hook(get_activation('fc3'))
gen_train_y = prior.sample((500,))
gen_train_x = simulator(gen_train_y)
output = model(gen_train_x)
print(activation['fc3'].shape)

torch.Size([500, 32])
