In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import random

# Normalising flows: comparison to CNF
In this notebook we compare the performance of a continuous normalising flow (CNF) with that of a normalising flow of similar architecture. As in the CNF example, we learn two target distributions: a bimodal distribution made of two concentric circles, and a uniform distribution over a triangle.
## Hyperparameters

In [2]:
args = {
    'niters': 5000,
    'lr': 1e-3,
    'num_samples': 512,
    'seed': 0,
}

# Seed every RNG the notebook draws from (`random` for the triangle sampler,
# torch for parameter initialisation) so Restart & Run All is reproducible.
random.seed(args['seed'])
torch.manual_seed(args['seed'])

## The Model
To keep the comparison fair, we use a normalising flow whose architecture mirrors that of the CNF:
- 10 layers
- A hypernetwork with 2 hidden layers that generates the parameters of the affine transformation
- An affine transformation

As in the CNF, the hypernetwork is conditioned on the context.
    

In [None]:
class HyperNetwork(nn.Module):
    """Hyper-network allowing the flow to be conditioned on an external context.

    A small MLP (two hidden layers with ReLU activations) that maps a context
    vector to ``out_dim`` numbers, intended to serve as the parameters of a
    flow layer's affine transformation.

    Args:
        context_dim: Size of the last dimension of the context input.
        hidden_dim: Width of each of the two hidden layers.
        out_dim: Number of values produced per context vector.
    """

    def __init__(self, context_dim, hidden_dim, out_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(context_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, out_dim),
        )

    def forward(self, context):
        """Map ``context`` of shape ``(..., context_dim)`` to ``(..., out_dim)``.

        Without this method, calling the module would raise
        ``NotImplementedError`` from ``nn.Module``.
        """
        return self.net(context)

## Triangle distribution

In [4]:
def point_on_triangle(pt1, pt2, pt3):
    """Draw one point uniformly at random from the triangle (pt1, pt2, pt3).

    Two uniform draws ``a`` and ``b`` are turned into barycentric weights
    equal to ``(|a - b|, min(a, b), 1 - max(a, b))``. These are non-negative,
    sum to 1 and are uniform over the simplex, so the weighted sum of the
    vertices is uniform over the triangle.

    Args:
        pt1, pt2, pt3: Triangle vertices; only indices 0 and 1 are read.

    Returns:
        A 2-tuple ``(x, y)``.
    """
    a = random.random()
    b = random.random()
    w_first = abs(a - b)
    # Algebraically min(a, b) and 1 - max(a, b); kept in this form so the
    # floating-point arithmetic matches the original formulation exactly.
    w_second = 0.5 * (a + b - w_first)
    w_third = 1 - 0.5 * (w_first + a + b)

    def _blend(axis):
        # Barycentric combination of the three vertices along one coordinate.
        return w_first * pt1[axis] + w_second * pt2[axis] + w_third * pt3[axis]

    return (_blend(0), _blend(1))


def get_batch(num_samples, pt1=(-0.2, 0.0), pt2=(0.6, 0.0), pt3=(0.0, 0.7)):
    """
    Generate random points uniformly distributed inside a triangle.

    Args:
        num_samples: Number of points to draw.
        pt1, pt2, pt3: Triangle vertices as (x, y) pairs. The defaults are the
            triangle used throughout this notebook.

    Returns:
        x: float32 tensor of shape (num_samples, 2) holding the sampled points.
        logp_diff_t1: float32 tensor of zeros with shape (num_samples, 1).
    """
    # Generate random points inside the triangle
    points = [point_on_triangle(pt1, pt2, pt3) for _ in range(num_samples)]

    # reshape keeps the (N, 2) layout even when num_samples == 0, where
    # torch.tensor([]) would otherwise produce a 1-D tensor of shape (0,).
    x = torch.tensor(points, dtype=torch.float32).reshape(num_samples, 2)
    # All-zero log-density offsets; presumably the initial value expected by the
    # CNF-style training interface this mirrors — confirm against the caller.
    logp_diff_t1 = torch.zeros(num_samples, 1, dtype=torch.float32)

    return x, logp_diff_t1