In [1]:
import itertools
from typing import Tuple

import haiku as hk
import jax 
import jax.numpy as jnp
import jraph
import numpy as np
import optax
from jax.random import gumbel

In [7]:
# First, define the GNN encoder: a featurizer, message-passing (MP) layers, and a readout.

class GraphNetwork(hk.Module):
    """Haiku module that wraps a ``jraph.GraphNetwork``.

    Every argument except ``name`` is forwarded unchanged to
    ``jraph.GraphNetwork``; ``name`` becomes the Haiku module name, so the
    parameters created by the update functions when the network runs are
    created inside this module's call.
    """
    def __init__(self, *args, name=None, **kwargs):
        super().__init__(name=name)
        wrapped = jraph.GraphNetwork(*args, **kwargs)
        self.model = wrapped

    def __call__(self, g: jraph.GraphsTuple) -> jraph.GraphsTuple:
        """Run the wrapped network on ``g`` and return the updated graph."""
        network_fn = self.model
        updated = network_fn(g)
        return updated
    
class GraphEncoder(hk.Module):
    """
    Encodes an input Fine-Grained (FG) molecular graph and deterministically
    outputs a continuous vector embedding of that graph.

    Pipeline: a featurizer that embeds raw node/edge features, ``n_layers``
    message-passing layers, and a readout GraphNetwork whose global update
    produces the graph-level embedding returned by ``__call__``.

    Args:
        n_layers: number of message-passing layers between featurizer and readout.
        edge_embedding_size: edge-feature width after every stage.
        node_embedding_size: node-feature width after every stage.
        global_embedding_size: width of the graph-level embedding.
        name: optional Haiku module name.
    """
    def __init__(self, n_layers: int, 
                       edge_embedding_size: int,
                       node_embedding_size: int,
                       global_embedding_size: int, 
                       name=None):
        super().__init__(name=name)
        self.n_layers = n_layers
        # Output widths of the update MLPs. Empty lists would make hk.nets.MLP an
        # identity map, so concatenated inputs would widen the features every layer.
        edge_output_sizes = [edge_embedding_size]
        node_output_sizes = [node_embedding_size]
        # NOTE(review): the decoder unpacks its input as (mu, sigma); if the readout
        # is meant to emit both, this may need to be 2 * global_embedding_size.
        global_output_sizes = [global_embedding_size]

        def make_embed_edge_fn(activation=jax.nn.relu):
            def f(edge_feats): # single layer for simplicity
                return hk.nets.MLP([edge_embedding_size], activation=activation)(edge_feats)
            return f
        
        def make_embed_node_fn(activation=jax.nn.relu):
            def f(node_feats):
                return hk.nets.MLP([node_embedding_size], activation=activation)(node_feats)
            return f
        
        def make_update_edge_fn(activation=jax.nn.relu):
            # Edge update sees its own features plus sender/receiver node features.
            @jraph.concatenated_args
            def f(feats):
                return hk.nets.MLP(edge_output_sizes, activation=activation)(feats)
            return f
        
        def make_update_node_fn(activation=jax.nn.relu):
            def f(node_feats, sender_feats, receiver_feats, global_feats):
                return hk.nets.MLP(node_output_sizes, activation=activation)(
                    jnp.concatenate([node_feats, receiver_feats], axis=1) # only aggr over msgs from incoming edges
                )
            return f
        
        def make_update_global_fn(activation=jax.nn.relu):
            # None globals are dropped by concatenated_args (it flattens the args).
            @jraph.concatenated_args
            def f(feats):
                return hk.nets.MLP(global_output_sizes, activation=activation)(feats)
            return f
        
        self.featurizer = jraph.GraphMapFeatures(embed_edge_fn=make_embed_edge_fn(),
                                                 embed_node_fn=make_embed_node_fn(), 
                                                 embed_global_fn=None)
        self.mp_layers = [GraphNetwork(update_edge_fn=make_update_edge_fn(), 
                                       update_node_fn=make_update_node_fn(), 
                                       update_global_fn=None) for _ in range(n_layers)]
        # The factory must be called: passing make_update_global_fn itself would make
        # jraph call the factory with (nodes, edges, globals) as arguments.
        self.readout = GraphNetwork(update_edge_fn=make_update_edge_fn(), 
                                    update_node_fn=make_update_node_fn(), 
                                    update_global_fn=make_update_global_fn())
        
    def __call__(self, g: jraph.GraphsTuple):
        """Embed graph ``g`` and return the readout's global features."""
        g = self.featurizer(g)
        for layer in self.mp_layers:
            g = layer(g)
        g = self.readout(g)
        return g.globals # extract graph level latent feature
        
        

In [None]:
def replace_node_features(g: "jraph.GraphsTuple", new_nodes) -> "jraph.GraphsTuple":
    """Return a copy of ``g`` with its node features replaced by ``new_nodes``.

    Every other field (edges, senders, receivers, globals, n_node, n_edge) is
    kept as-is. The previous version rebuilt ``n_node`` as ``[len(nodes)]``,
    which silently merged a batch of several graphs into one. Callers here
    replace nodes with same-length arrays, so the existing ``n_node`` stays valid.

    Args:
        g: graph whose node features are replaced (any GraphsTuple-like namedtuple).
        new_nodes: replacement node features.
    """
    return g._replace(nodes=new_nodes)


class GraphNVPLayer(hk.Module):
    """Affine coupling layer over one node-feature column (GraphNVP-style).

    Column ``mask_dim`` of the node features is transformed as
    ``x_d -> x_d * exp(s) + t``. The scale ``s`` and shift ``t`` come from two
    message-passing networks, and those networks see only the graph with
    column ``mask_dim`` zeroed out. Because the conditioner never sees the
    values it transforms, the map is exactly invertible and
    ``log|det J| = sum(s)``.

    Args:
        dim: number of node-feature columns (assumes nodes are (n_nodes, dim) — TODO confirm).
        mask_dim: index of the column this layer transforms.
        hidden_dim: hidden width of the edge/node MLPs.
        name: optional Haiku module name.
    """
    def __init__(self, dim, mask_dim, hidden_dim = 16, name=None):
        super().__init__(name=name)
        self.dim = dim
        self.mask_dim = mask_dim
        self.hidden_dim = hidden_dim
        edge_output_sizes = [hidden_dim, hidden_dim]
        node_output_sizes_trans = [hidden_dim, hidden_dim, 1]
        node_output_sizes_scale = [hidden_dim, hidden_dim, dim]
        def make_mlp_edge_update(activation):
            @jraph.concatenated_args
            def f(feats):
                return hk.nets.MLP(edge_output_sizes, activation=activation)(feats)
            return f
        def make_mlp_node_update(activation, node_output_sizes):
            def f(node_feats, sender_feats, receiver_feats, global_feats):
                return hk.nets.MLP(node_output_sizes, activation=activation)(
                    jnp.concatenate([node_feats, receiver_feats], axis=1) # only aggr over msgs from incoming edges
                )
            return f
        self.mp_trans = GraphNetwork(update_edge_fn=make_mlp_edge_update(jax.nn.relu), 
                                     update_node_fn=make_mlp_node_update(jax.nn.relu, 
                                                                         node_output_sizes_trans), 
                                     update_global_fn=None)
        self.mp_scale = GraphNetwork(update_edge_fn=make_mlp_edge_update(jax.nn.relu), 
                                     update_node_fn=make_mlp_node_update(jax.nn.tanh,
                                                                         node_output_sizes_scale), 
                                     update_global_fn=None)

    def mask_graph(self, g):
        """Zero out node-feature column ``mask_dim``.

        Returns:
            (g_masked, mask): the graph with masked node features, and a 0/1
            array shaped like the nodes (0 in column ``mask_dim``, 1 elsewhere).
        """
        nodes = g.nodes
        # JAX arrays are immutable; item assignment raises, so use .at[].set().
        mask = jnp.ones_like(nodes).at[:, self.mask_dim].set(0)
        g_masked = replace_node_features(g, nodes * mask)
        return g_masked, mask

    def _scale_and_shift(self, g):
        """Predict (scale, trans) from the masked graph, zero outside column ``mask_dim``."""
        g_masked, mask = self.mask_graph(g)
        # Transform only the masked-out column; the conditioner saw every other column.
        update_mask = 1 - mask
        scale = self.mp_scale(g_masked).nodes * update_mask
        trans = self.mp_trans(g_masked).nodes * update_mask  # (n_nodes, 1) broadcasts
        return scale, trans

    def forward(self, g: jraph.GraphsTuple) -> Tuple[jraph.GraphsTuple, jnp.ndarray]:
        """Apply the coupling; return the new graph and log|det J| (summed over all nodes)."""
        scale, trans = self._scale_and_shift(g)
        # scale/trans are 0 outside column mask_dim, so other columns pass through.
        new_nodes = g.nodes * jnp.exp(scale) + trans
        logdetJ = jnp.sum(scale)
        g_new = replace_node_features(g, new_nodes)
        return g_new, logdetJ

    def reverse(self, g: jraph.GraphsTuple) -> Tuple[jraph.GraphsTuple, jnp.ndarray]:
        """Invert ``forward``; return the recovered graph and -log|det J|."""
        # The untouched columns equal the forward input's, so the same scale/trans are recovered.
        scale, trans = self._scale_and_shift(g)
        new_nodes = (g.nodes - trans) / jnp.exp(scale)
        logdetJ = -jnp.sum(scale)
        g_new = replace_node_features(g, new_nodes)
        return g_new, logdetJ
    
class GraphNVPBlock(hk.Module):
    """One GraphNVPLayer per node-feature column, applied in sequence.

    Each column ``0..dim-1`` is transformed once, conditioned on the others.
    ``forward`` and ``reverse`` return the transformed graph and the summed
    log-determinant of all layers.
    """
    def __init__(self, dim, hidden_dim = 16, name=None):
        super().__init__(name=name)
        self.dim = dim
        self.layers = [GraphNVPLayer(dim, mask_dim, hidden_dim) for mask_dim in range(dim)]
        
    def forward(self, g: jraph.GraphsTuple) -> Tuple[jraph.GraphsTuple, jnp.ndarray]:
        """Apply the layers in order; return (graph, total log|det J|)."""
        ldj_sum = 0
        for layer in self.layers:
            g, ldj = layer.forward(g)
            ldj_sum += ldj  # previously dropped, so the block always reported 0
        return g, ldj_sum
        
    def reverse(self, g: jraph.GraphsTuple) -> Tuple[jraph.GraphsTuple, jnp.ndarray]:
        """Invert the layers in reverse order; return (graph, total -log|det J|)."""
        ldj_sum = 0
        for layer in self.layers[::-1]:
            g, ldj = layer.reverse(g)
            ldj_sum += ldj
        return g, ldj_sum
    
class GraphNVP(hk.Module):
    """Normalizing flow over node features: a stack of ``n_layers`` GraphNVPBlocks.

    Author's rationale for preferring GraphNVP over GRevNet:
    1. Its message passing also uses and updates edge features, which GRevNet's did not.
    2. GRevNet splits node features into two halves, which makes 3D coordinates
       awkward as node features. GraphNVP instead transforms one dimension at a
       time conditioned on the rest (see the GraphNVP paper for details).

    Args:
        n_layers: number of GraphNVPBlocks.
        dim: node-feature width.
        hidden_dim: hidden width of each layer's MLPs.
        name: optional Haiku module name.
    """
    def __init__(self, n_layers, dim, hidden_dim=16, name=None):
        super().__init__(name=name)
        # GRevLayer is not defined in this notebook; the GraphNVP building block is GraphNVPBlock.
        self.layers = [GraphNVPBlock(dim, hidden_dim) for _ in range(n_layers)]
    
    def forward(self, g: jraph.GraphsTuple) -> Tuple[jraph.GraphsTuple, jnp.ndarray]:
        """Apply the blocks in order; return (graph, total log|det J|)."""
        ldj_sum = 0
        for layer in self.layers:
            g, ldj = layer.forward(g)
            ldj_sum += ldj
        return g, ldj_sum
        
    def reverse(self, g: jraph.GraphsTuple) -> Tuple[jraph.GraphsTuple, jnp.ndarray]:
        """Invert the blocks in reverse order; return (graph, total -log|det J|)."""
        ldj_sum = 0
        for layer in self.layers[::-1]:
            g, ldj = layer.reverse(g)
            ldj_sum += ldj
        return g, ldj_sum
    
    

In [None]:
class CoarseGrainingDecoder(hk.Module):
    """
    Probabilistically decode the embedding from the encoder into a
    Coarse-Grained (CG) molecular graph with ``max_nodes`` nodes.

    Node features are sampled with the reparameterization trick,
    ``V = mu + exp(log_sigma) * eps`` with ``eps ~ N(0, I)``. Edge features
    are decoded from every ordered node pair (self-pairs included), so the
    graph is dense with ``max_nodes**2`` edges.

    Args:
        flow_dim: node-feature width of the decoded graph.
        max_nodes: number of nodes in the decoded graph.
        n_hidden: hidden width of the decoder MLPs.
        n_edge_feats: edge-feature width of the decoded graph.
        name: optional Haiku module name.
    """
    def __init__(self, flow_dim=2, max_nodes=100, n_hidden=16, n_edge_feats=8, name=None):
        super().__init__(name=name)
        self.max_nodes = max_nodes
        # Every ordered pair (i, j) is an edge, matching senders/receivers below
        # (max_nodes*(max_nodes-1)//2 disagreed with the edges actually built).
        self.max_edges = max_nodes * max_nodes
        self.flow_dim = flow_dim
        self.node_feat_shape = (max_nodes, flow_dim)
        # Same i-major order as itertools.product(range(N), range(N)).
        node_ids = jnp.arange(max_nodes)
        self.senders = jnp.repeat(node_ids, max_nodes)
        self.receivers = jnp.tile(node_ids, max_nodes)
        self.mu_decoder = hk.nets.MLP([n_hidden, max_nodes*flow_dim], activation=jax.nn.relu)
        self.sig_decoder = hk.nets.MLP([n_hidden, max_nodes*flow_dim], activation=jax.nn.relu) # outputs log(sigma)
        self.edge_decoder = hk.nets.MLP([n_hidden, n_edge_feats], activation=jax.nn.relu)

    def __call__(self, h):
        """Sample a CG graph from the embedding ``h``.

        Args:
            h: pair ``(h_mu, h_sig)`` of embeddings for the node-feature mean and
               log-std. Each decoder output must have exactly max_nodes*flow_dim
               elements to reshape (e.g. a 1-D or (1, d) input) — TODO confirm
               against what the encoder emits.
        Returns:
            (G, ll): the decoded ``jraph.GraphsTuple`` and the log-likelihood of
            the sampled node features, without the -0.5*log(2*pi) per-entry constant.
        """
        h_mu, h_sig = h
        V_mu = jnp.reshape(self.mu_decoder(h_mu), self.node_feat_shape)
        V_sig = jnp.reshape(self.sig_decoder(h_sig), self.node_feat_shape)  # log(sigma)
        eps = jax.random.normal(hk.next_rng_key(), shape=self.node_feat_shape)
        V = V_mu + jnp.exp(V_sig) * eps
        # log N(V; mu, sigma) = -log(sigma) - eps**2 / 2 - log(2*pi) / 2, summed over entries.
        ll = -(jnp.sum(V_sig) + 0.5 * jnp.sum(eps ** 2))
        # Edge k input = [V[senders[k]], V[receivers[k]]]; vectorized, no Python loop.
        E_in = jnp.concatenate([V[self.senders], V[self.receivers]], axis=1)  # |V|**2 x 2*flow_dim
        E = self.edge_decoder(E_in)  # |V|**2 x n_edge_feats

        # jraph expects per-graph node/edge counts as arrays of shape [n_graphs].
        G = jraph.GraphsTuple(nodes=V, edges=E,
                              senders=self.senders, receivers=self.receivers,
                              globals=None,
                              n_node=jnp.array([self.max_nodes]),
                              n_edge=jnp.array([self.max_edges]))
        return G, ll

In [6]:
import itertools

N = 100
# Split the N*N ordered (i, j) pairs into a tuple of i's and a tuple of j's.
a, b = zip(*itertools.product(range(N), range(N)))


In [15]:
def f_():
    """Apply an MLP with layer sizes 8 -> 16 -> 8 to a (1, 2) array of ones."""
    layer_sizes = [8, 16, 8]
    return hk.nets.MLP(layer_sizes)(jnp.ones([1, 2]))

# Transform `f_` (defined above). Transforming `f` only worked from leftover
# kernel state: on a fresh kernel `f` is not yet defined here.
f = hk.without_apply_rng(hk.transform(f_))
params = f.init(rng=jax.random.PRNGKey(42))

In [17]:
# Output shape of the demo MLP: batch of 1, final layer width 8.
np.shape(f.apply(params))

(1, 8)

In [27]:
# Two (5, 2) arrays whose rows are [k, k]: k = 0..4 for `a`, k = 5..9 for `b`.
a = np.repeat(np.arange(5)[:, np.newaxis], 2, axis=1)
b = np.repeat(np.arange(5, 10)[:, np.newaxis], 2, axis=1)

In [28]:
# One row per (row of a, row of b) pair, in itertools.product (a-major) order.
ab = np.array([np.hstack(pair) for pair in itertools.product(a, b)])

In [29]:
# Each row concatenates one row of `a` with one row of `b`, enumerated in
# itertools.product order (all rows of `b` for a[0], then for a[1], ...).
ab

array([[0, 0, 5, 5],
       [0, 0, 6, 6],
       [0, 0, 7, 7],
       [0, 0, 8, 8],
       [0, 0, 9, 9],
       [1, 1, 5, 5],
       [1, 1, 6, 6],
       [1, 1, 7, 7],
       [1, 1, 8, 8],
       [1, 1, 9, 9],
       [2, 2, 5, 5],
       [2, 2, 6, 6],
       [2, 2, 7, 7],
       [2, 2, 8, 8],
       [2, 2, 9, 9],
       [3, 3, 5, 5],
       [3, 3, 6, 6],
       [3, 3, 7, 7],
       [3, 3, 8, 8],
       [3, 3, 9, 9],
       [4, 4, 5, 5],
       [4, 4, 6, 6],
       [4, 4, 7, 7],
       [4, 4, 8, 8],
       [4, 4, 9, 9]])

In [30]:
# A (5, 2) all-ones mask used to demonstrate functional updates below.
mask = jnp.ones(shape=(5, 2))

In [32]:
# JAX arrays are immutable: .at[...].set(...) returns a new array; `mask` is unchanged.
mask.at[:, 0].set(0.0)

DeviceArray([[0., 1.],
             [0., 1.],
             [0., 1.],
             [0., 1.],
             [0., 1.]], dtype=float32)