In [1]:
import os
os.environ["KERAS_BACKEND"] = "torch"
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"

import keras
import sys
sys.path.append("../")

# import bayesflow as bf

import torch
from torch import Tensor
import torch.nn as nn

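Note: when `bayesflow` is imported (the commented-out line above), the computational graph is no longer built correctly, so the import is left disabled in this notebook.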

In [2]:
class DeepAdaptiveDesign(nn.Module):
    """Design policy that proposes the next experimental design.

    With no history it returns a learnable initial design. Otherwise it feeds a
    history embedding through ``decoder_net`` and returns the decoder output.

    NOTE(review): the history encoder is still stubbed out. The embedding is a
    random ``[1, 10]`` tensor, so ``history`` only selects the branch and its
    contents are not read.
    """

    def __init__(
        self,
        # encoder_net: nn.Module, # same summary for bf and dad or different?
        decoder_net: nn.Module,
        design_shape: torch.Size,  # [xi_dim]
        summary_variables: "list[str] | None" = None,  # in case of using summary_net from bf
    ) -> None:
        """
        Args:
            decoder_net: module mapping a history embedding to the next design.
            design_shape: shape of the learnable initial design, e.g. ``torch.Size([xi_dim])``.
            summary_variables: keys intended for selecting history entries when a
                BayesFlow summary network is used as encoder; stored but not used yet.
        """
        super().__init__()
        self.design_shape = design_shape
        # Learnable leaf parameter used as the first design, initialised to 0.1 everywhere.
        self.register_parameter(
            "initial_design",
            nn.Parameter(0.1 * torch.ones(design_shape, dtype=torch.float32)),
        )
        # self.encoder_net = encoder_net
        self.decoder_net = decoder_net
        self.summary_variables = summary_variables

    def forward(self, history, batch_size: int) -> Tensor:
        """Return the next design.

        Args:
            history: ``None`` for the first step; any other value selects the
                adaptive branch (its contents are not read yet).
            batch_size: accepted for interface compatibility; currently unused.

        Returns:
            The ``initial_design`` parameter itself (shape ``design_shape``) when
            ``history`` is ``None``, otherwise the output of ``decoder_net``.
        """
        if history is None:
            return self.initial_design
        # embed design-outcome pairs
        # embeddings = self.encoder_net(filter_concatenate(history, keys=self.summary_variables)).to('cpu').requires_grad_(True)  # in case of using summary_net from bf. [B, summary_dim]
        # Placeholder embedding until the encoder is wired in. Allocate it on the
        # module's device/dtype so the decoder still works after .to()/.double().
        embeddings = torch.rand(
            [1, 10],
            device=self.initial_design.device,
            dtype=self.initial_design.dtype,
        )
        # get next design
        return self.decoder_net(embeddings)

In [3]:
class EmitterNetwork(nn.Module):
    """MLP decoder mapping a history embedding to a design.

    Architecture: ``input_layer -> activation -> middle -> output_layer``, where
    ``middle`` holds ``n_hidden_layers - 1`` extra ``Linear + activation`` blocks,
    or is an identity when ``n_hidden_layers <= 1``.

    While ``verbose`` is true, ``forward`` prints ``requires_grad`` for each
    intermediate tensor. This notebook uses that output to check that the
    autograd graph is being built.
    """

    def __init__(
        self,
        input_dim,  # summary_dim
        hidden_dim,
        output_dim,  # xi_dim
        n_hidden_layers=2,
        activation=nn.Softplus,
        verbose: bool = True,
    ):
        """
        Args:
            input_dim: size of the incoming embedding.
            hidden_dim: width of the hidden layers.
            output_dim: size of the produced design.
            n_hidden_layers: number of hidden layers (including the first one).
            activation: activation class, instantiated once per use.
            verbose: print per-stage ``requires_grad`` diagnostics in ``forward``.
        """
        super().__init__()
        self.verbose = verbose
        self.activation_layer = activation()
        self.input_layer = nn.Linear(input_dim, hidden_dim)
        if n_hidden_layers > 1:
            self.middle = nn.Sequential(
                *[
                    nn.Sequential(nn.Linear(hidden_dim, hidden_dim), activation())
                    for _ in range(n_hidden_layers - 1)
                ]
            )
        else:
            self.middle = nn.Identity()
        self.output_layer = nn.Linear(hidden_dim, output_dim)

    def _trace(self, label: str, tensor: Tensor) -> None:
        """Print whether ``tensor`` requires grad, tagged with the stage ``label``."""
        if self.verbose:
            print(f"{label} requires grad:", tensor.requires_grad)

    def forward(self, r):
        """Map embedding ``r`` of shape ``[B, input_dim]`` to a design of shape ``[B, 1, output_dim]``."""
        self._trace("Input", r)
        x = self.input_layer(r)
        self._trace("After input_layer", x)
        x = self.activation_layer(x)
        self._trace("After activation", x)
        x = self.middle(x)
        self._trace("After middle", x)
        x = self.output_layer(x)
        return x.unsqueeze(1)  # [B, xi_dim] -> [B, 1, xi_dim]

In [4]:
class Simulator(nn.Module):
    """Roll out ``design_net`` for a fixed number of steps and stack the designs.

    Used in this notebook to check that the stacked designs stay connected to
    the autograd graph of ``design_net``.
    """

    def __init__(self, design_net, n_steps: int = 5):
        """
        Args:
            design_net: module called as ``design_net(history, batch_size)``.
            n_steps: number of designs to generate; must be at least 1.

        Raises:
            ValueError: if ``n_steps < 1`` (there would be nothing to concatenate).
        """
        super().__init__()
        if n_steps < 1:
            raise ValueError(f"n_steps must be >= 1, got {n_steps}")
        self.design_net = design_net
        self.n_steps = n_steps

    def forward(self):
        """Return the ``n_steps`` designs concatenated along dim 0.

        Designs from the adaptive branch are expected to be ``[1, 1, xi_dim]``.
        The initial design (returned for ``history=None``) is reshaped to that
        layout so all designs can be concatenated.
        """
        designs = []
        for step in range(self.n_steps):
            # Placeholder history: None marks the first step; any other value
            # makes design_net take its adaptive branch.
            history = None if step == 0 else 0
            # The second argument is the batch_size placeholder; the
            # DeepAdaptiveDesign in this notebook does not use it.
            xi = self.design_net(history, 0)
            if history is None:
                # Lift the initial design (shape design_shape) to [1, 1, xi_dim];
                # identical to the former expand(1, 1, 1) when xi_dim == 1.
                xi = xi.reshape(1, 1, -1)
            designs.append(xi)
        return torch.cat(designs, dim=0)

In [5]:
# Decoder: 10-dim embedding -> 24 hidden units -> 1-dim design.
decoder = EmitterNetwork(input_dim=10, hidden_dim=24, output_dim=1)

In [6]:
# Design policy with a single scalar design dimension, decoding via `decoder`.
design_net = DeepAdaptiveDesign(
    design_shape=torch.Size((1,)),
    decoder_net=decoder,
)

In [7]:
# Stand-in embedding ([1, 10], matching the decoder's input_dim) for testing only.
# requires_grad=True so we can watch gradient tracking propagate through the decoder.
# NOTE: this name shadows the built-in input(); it is kept because later cells use it.
input = torch.rand(1, 10, requires_grad=True)

We expect `True` to be printed four times, which shows that the computational graph is built properly.

In [8]:
# Forward pass through the decoder alone; it prints requires_grad at each stage.
output = decoder(input)

Input requires grad: True
Input requires grad: True
Input requires grad: True
Input requires grad: True


For the initial design, we expect `grad_fn` to be `None`, since `initial_design` is a leaf node.

Its `requires_grad` should still be `True`, because it is a trainable parameter.

In [9]:
# First step: history=None returns the learnable initial design.
# batch_size is required by the signature but is not used by DeepAdaptiveDesign.forward.
next_design = design_net(None, batch_size=0)
for attr_name in ("grad_fn", "requires_grad"):
    print(getattr(next_design, attr_name))

None
True


If we pass anything *other than* `None` as `history`, the network computes the next design from the observed data, so the result should have a `grad_fn`.

In [10]:
# Adaptive step: any non-None history routes through the decoder network.
next_design = design_net(0, batch_size=0)
print(next_design.grad_fn, next_design.requires_grad, sep="\n")

Input requires grad: False
Input requires grad: True
Input requires grad: True
Input requires grad: True
<UnsqueezeBackward0 object at 0x10bf29750>
True


In [11]:
# Wrap the design policy in a rollout that collects one design per step.
simulator = Simulator(design_net)

In [12]:
# Roll out the simulator and check that the stacked designs are still attached to the graph.
out = simulator()
for attr_name in ("grad_fn", "requires_grad"):
    print(getattr(out, attr_name))

Input requires grad: False
Input requires grad: True
Input requires grad: True
Input requires grad: True
Input requires grad: False
Input requires grad: True
Input requires grad: True
Input requires grad: True
Input requires grad: False
Input requires grad: True
Input requires grad: True
Input requires grad: True
Input requires grad: False
Input requires grad: True
Input requires grad: True
Input requires grad: True
<CatBackward0 object at 0x10bf2a530>
True


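Every intermediate tensor has `requires_grad = True`; only the input embedding does not.
This is natural PyTorch behavior: the placeholder embedding is created by `torch.rand` without `requires_grad`, while every layer output depends on trainable parameters and therefore requires grad.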