In [2]:
import copy
import torch.nn as nn
import torch
from FrEIA.framework import *
from FrEIA.modules import *
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
device = torch.device("cuda:0")

In [2]:
def load_data(num):
    all_data = []
    all_params = []
    all_length = []
    i = 0
    j = 2
    while(len(all_data)<num):
        name_dat = f"./data{j}/dat{i}.csv"
        name_param = f"./params{j}/params{i}.csv"
        try:
            df_dat = pd.read_csv(name_dat)
            df_param = pd.read_csv(name_param)
            param = df_param.to_numpy()[0,[0,3,4,5,6,7,8]]
            data = df_dat.iloc[:,2].to_numpy().reshape(-1,7)
            if len(data[np.isnan(data)]) == 0:
                all_data.append(torch.from_numpy(data).to(torch.float32).to(device))
                all_params.append(param)
                all_length.append(len(data))
            i += 1
        except:
            i = 0
            j += 1
        if j>8:
            print("not so much data")
            break
        if len(all_data)%1000 == 0:
            print("Loaded",len(all_data),"Entries")
    all_data = nn.utils.rnn.pad_sequence(all_data, batch_first = True).to(device).to(torch.float32)
    all_params = torch.from_numpy(np.array(all_params)).to(torch.float32).to(device)
    return all_data, all_params, all_length
data, params, length = load_data(10000)

means = params.mean(dim = 0)
stds = params.std(dim = 0)
stds[stds==0] = 1
params_normalized = (params-means)/stds

boarder = max(params_normalized.max(),-params_normalized.min())
maximal = boarder.item()+1e-6
params_pre = torch.arctanh(params_normalized/maximal) #get this from [-1.8,1.8] to (-infty, infty)

Loaded 1000 Entries
Loaded 2000 Entries
Loaded 3000 Entries
Loaded 4000 Entries
Loaded 5000 Entries
Loaded 6000 Entries
Loaded 7000 Entries
Loaded 8000 Entries
Loaded 9000 Entries
Loaded 10000 Entries


In [3]:
def get_linear_subnet(N, inp_size, hidden_size, out_size):
    layer_list = []
    layer_list.append(nn.Linear(inp_size, hidden_size))
    layer_list.append(nn.ReLU())
    for i in range(N-1):
        layer_list.append(nn.Linear(hidden_size, hidden_size))
        layer_list.append(nn.ReLU())
    layer_list.append(nn.Linear(hidden_size, out_size))
    return nn.Sequential(*layer_list)

def get_conv_subnet(N, inp_size, kernel = 3, stride = 1):
    layer_list = []
    for i in range(N-1):
        layer_list.append(nn.Conv1d(inp_size, inp_size, kernel, stride, groups = inp_size))
        layer_list.append(nn.ReLU())
    layer_list.append(nn.Conv1d(inp_size, inp_size, kernel, stride, groups = inp_size))
    return nn.Sequential(*layer_list)

class RNN(nn.Module):
    """"
    combination of convolutions and recurrent nets, used to treat the condition for our cINN beforehand.
    
    
    
    """
    def __init__(self, inp_size, hidden_size = 64, num_rnns = 5, lr = 1e-3):
        super(RNN, self).__init__()
        self.inp_size = inp_size
        #self.conv = get_conv_subnet(num_conv, inp_size).to(device)
        self.rnn = nn.LSTM(inp_size, hidden_size, num_rnns, batch_first = True, bidirectional = True).to(device)
        #self.linear = get_linear_subnet(num_linear, hidden_size*10*2, hidden_size, out_size).to(device)
        self.params_trainable = list(filter(
                lambda p: p.requires_grad, self.rnn.parameters())) 
        n_trainable = sum(p.numel() for p in self.params_trainable)
        print(f"Number of RNN parameters: {n_trainable}", flush=True)        
        self.optimizer = torch.optim.AdamW(
                self.params_trainable,
                lr = lr,
                betas =[0.9, 0.99],
                eps = 1e-6,
                weight_decay = 0
            )
        self.scheduler =  torch.optim.lr_scheduler.ReduceLROnPlateau(
                self.optimizer,
            )        
        
    def forward(self, x):
        full, (last, cn) = self.rnn(x)
        return torch.swapaxes(last, 0, 1).reshape(last.shape[1], -1)
    
    def get_dim(self):
        return self.forward(torch.randn(1,100,self.inp_size, device = device)).shape[1]
        
    
class cINN(nn.Module):
    """
    cINN baseclass, using cubic spline blocks.
    
    
    """
    
    
    def __init__(self, inp_size, cond_size, num_blocks = 15, sub_layers = 3, sub_width = 128, lr = 5e-3):
        super(cINN, self).__init__()
        constructor_fct = lambda x_in, x_out: get_linear_subnet(sub_layers, 
                                                                x_in,
                                                                x_in,
                                                                x_out)

        block_kwargs = {
                        "num_bins": 60,
                        "subnet_constructor": constructor_fct,
                        "bounds_init": 10,
                        "permute_soft": True
                           }
        inp_size = (inp_size,)        
        nodes = [InputNode(*inp_size, name='inp')]
        cond_node = ConditionNode(*(cond_size,))
        for i in range(num_blocks):
            nodes.append(Node(
                    [nodes[-1].out0],
                    CubicSplineBlock,
                    block_kwargs,
                    conditions = cond_node,
                    name = f"block_{i}",
                    
                ))
        nodes.append(OutputNode([nodes[-1].out0], name='out'))
        nodes.append(cond_node)
        self.model = GraphINN(nodes, verbose=False).to(device)
        self.params_trainable = list(filter(
                lambda p: p.requires_grad, self.model.parameters()))
        n_trainable = sum(p.numel() for p in self.params_trainable)
        print(f"Number of cINN parameters: {n_trainable}", flush=True)
        
        self.optimizer = torch.optim.AdamW(
                self.params_trainable,
                lr = lr,
                betas =[0.9, 0.99],
                eps = 1e-6,
                weight_decay = 0
            )
        self.scheduler =  torch.optim.lr_scheduler.ReduceLROnPlateau(
                self.optimizer,
                verbose = True
            )
    def forward(self, x, cond = None):
        return self.model(x, c = cond)
    
class Estimator():
    """
    Wrapper for training cINN and RNN at the same time
    """
    def __init__(self, rnn, cinn):
        self.cinn = cinn
        self.rnn = rnn
    def train(self, epochs, xtrain, ytrain, length, batch_size):
        loss_curve = []
        for epoch in range(epochs):
            epoch_index = np.random.permutation(len(xtrain))
            epoch_losses = 0
            for i in range(len(xtrain)//batch_size):
                ysamps = nn.utils.rnn.pack_padded_sequence(ytrain[i*batch_size:(i+1)*batch_size], length[i*batch_size:(i+1)*batch_size], batch_first = True, enforce_sorted = False)
                xsamps = xtrain[i*batch_size:(i+1)*batch_size]
                self.cinn.optimizer.zero_grad()
                self.rnn.optimizer.zero_grad()
                cond = self.rnn(ysamps)
                gauss, jac = self.cinn(xsamps, cond)
                latent_loss = torch.mean(gauss**2/2) - torch.mean(jac)/gauss.shape[1]
                if latent_loss < 1e30:
                    latent_loss.backward()
                else:
                    print(f"loss is {latent_loss}")
                    return
                self.cinn.optimizer.step()
                self.rnn.optimizer.step()
                epoch_losses += latent_loss.item()/(len(xtrain)//batch_size)
            loss_curve.append(epoch_losses)
            self.cinn.scheduler.step(epoch_losses)
            self.rnn.scheduler.step(epoch_losses)
            print("Epoch:", epoch + 1)
            print("Loss:", epoch_losses)
        plt.plot(np.arange(len(loss_curve)),np.array(loss_curve))
        
    def inference(self, data_point, true_param):  #plot parameter estimates for a time series and the true param
        outputs = []
        for i in range(200):
            gauss = torch.randn(100,7).to(device)
            cond = self.rnn(data_point.repeat(100,1,1))
            output, _ = self.cinn(gauss, cond)
            outputs.append(output.detach().cpu())
        output = torch.cat(outputs, dim = 0)
        #output = torch.tanh(output)*maximal*stds.cpu()+means.cpu()
        fig, axis = plt.subplots(4,2, figsize = (10,15))
        axis[0,0].hist(output.numpy()[:,0], bins = 100, density = True)
        get_line(axis[0,0], true_param[0])
        axis[0,0].set_title("initial infected")
        axis[0,1].hist(output.numpy()[:,1], bins = 100, density = True)
        get_line(axis[0,1], true_param[1])
        axis[0,1].set_title("populations size")
        axis[1,0].hist(output.numpy()[:,2], bins = 100, density = True)
        get_line(axis[1,0], true_param[2])
        axis[1,0].set_title("contagion distance")
        axis[1,1].hist(output.numpy()[:,3], bins = 100, density = True)
        get_line(axis[1,1], true_param[3])
        axis[1,1].set_title("critical_limit")
        axis[2,0].hist(output.numpy()[:,4], bins = 100, density = True)
        get_line(axis[2,0], true_param[4])
        axis[2,0].set_title("amp_susc")
        axis[2,1].hist(output.numpy()[:,5], bins = 100, density = True)
        get_line(axis[2,1], true_param[5])
        axis[2,1].set_title("amp_rec")
        axis[3,0].hist(output.numpy()[:,6], bins = 100, density = True)
        get_line(axis[3,0], true_param[6])
        axis[3,0].set_title("amp_inf")
        
def get_line(ax, x):
    ax.axvline(x, color = "r")

In [4]:
rnn = RNN(7)
cinn = cINN(7, rnn.get_dim())

network = Estimator(rnn, cinn)
cond = network.train(1000, params_pre, nn.utils.rnn.pad_sequence(data, batch_first = True), length, batch_size = 200)

Number of RNN parameters: 434688
Number of cINN parameters: 22233150


RuntimeError: Node 'block_2': [(7,)] -> CubicSplineBlock -> [(7,)] encountered an error.

In [None]:
i = 15
network = Estimator(rnn, cinn)
network.inference(data[i], params[i])

FP stuff

In [None]:
counts = np.array([0, 881994, 623379, 474861, 340332, 1209080])
eff = np.array([95, 94.9, 95., 94.3, 92.8, 94.5])* 0.01
fact_up = np.zeros(6)
for i in range(6):
    fact_up[i] = counts[i]/np.sum(counts[i+1:])
    
fact_down = np.zeros(6)
for i in range(6):
    fact_down[i] = counts[i]*(1-eff[i])/eff[i]/np.sum(counts[i+1:])
    
print(fact_up)
print(fact_down)

In [6]:
counts = np.array([0, 1282720, 914352, 698344, 502520, 1802110])
eff = np.array([95, 94.9, 95., 94.3, 92.8, 94.5])* 0.01
fact_up = np.zeros(6)
for i in range(6):
    fact_up[i] = counts[i]/np.sum(counts[i+1:])
    
fact_down = np.zeros(6)
for i in range(6):
    fact_down[i] = counts[i]*(1-eff[i])/eff[i]/np.sum(counts[i+1:])
    
print(fact_up)
print(fact_down)

[0.         0.32744786 0.30448216 0.30301784 0.2788509         inf]
[0.         0.0175973  0.01602538 0.01831603 0.02163498        inf]


  fact_up[i] = counts[i]/np.sum(counts[i+1:])
  fact_down[i] = counts[i]*(1-eff[i])/eff[i]/np.sum(counts[i+1:])


In [4]:
counts = np.array([0, 1282720, 914352, 698344, 502520, 1802110])
eff = np.array([95, 94.9, 95., 94.3, 92.8, 94.5])* 0.01
fact_up = np.zeros(6)
for i in range(6):
    fact_up[i] = (counts[i]+np.sqrt(counts[i]))/(np.sum(counts[i+1:])-np.sqrt(np.sum(counts[i+1:])))
    
fact_down = np.zeros(6)
for i in range(6):
    fact_down[i] = counts[i]*(1-eff[i])/eff[i]/np.sum(counts[i+1:])
    
print(fact_up)
print(fact_down)

[0.         0.32790265 0.30497657 0.30358042 0.27945244        inf]
[0.         0.0175973  0.01602538 0.01831603 0.02163498        inf]


  fact_up[i] = (counts[i]+np.sqrt(counts[i]))/(np.sum(counts[i+1:])-np.sqrt(np.sum(counts[i+1:])))
  fact_down[i] = counts[i]*(1-eff[i])/eff[i]/np.sum(counts[i+1:])
