In [None]:
#pip install torchdiffeq

In [None]:
import torch
import torch.optim as optim
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from torchdiffeq import odeint_adjoint as odeint

In [None]:
"""Data Loader and Generator"""

def data_loader(num_batches=10, batch_size=64, train_test_split=.8, data_mean=5, data_variance=5):
    """Sample a synthetic 1-D Gaussian dataset and split it into train/test batches.

    Args:
        num_batches: total number of batches to draw.
        batch_size: number of samples per batch.
        train_test_split: fraction of the batches (rounded down) that go to the
            training list; the remainder form the test list.
        data_mean: mean of the target normal distribution.
        data_variance: variance of the target distribution (used as the 1x1
            covariance matrix, so the standard deviation is sqrt(data_variance)).

    Returns:
        ``(train, test)``: two lists of float tensors, each of shape (batch_size, 1).
    """
    target_distr = torch.distributions.MultivariateNormal(
        data_mean * torch.ones(1), data_variance * torch.eye(1)
    )
    # One vectorized draw per batch. Growing the batch with torch.cat copied it on
    # every sample (quadratic in batch_size) and always produced at least one row.
    data = [target_distr.sample((batch_size,)) for _ in range(num_batches)]

    split = int(train_test_split * len(data))
    return data[:split], data[split:]


def generate(num_samples, device, model=None, model_path="rnode_normal.pth",
             data_mean=5, data_variance=5):
    """Draw samples from the trained flow and plot them against the target density.

    Points are sampled from a standard normal at t=1 and integrated backward to
    t=0 through the learned dynamics.

    Args:
        num_samples: number of points to generate.
        device: torch device used for sampling and integration.
        model: optional CNF instance (assumed to already live on ``device``).
            If omitted, a ``CNF(1)`` is built on ``device`` and its weights are
            loaded from ``model_path``.
        model_path: state-dict checkpoint to load when ``model`` is None.
        data_mean, data_variance: parameters of the reference normal that
            ``show`` draws over the histogram.

    Returns:
        CPU tensor of shape (num_samples, 1) holding the generated points.
    """
    if model is None:
        model = CNF(1).to(device)
        # NOTE: torch.load unpickles the file; only load checkpoints you trust.
        model.load_state_dict(torch.load(model_path, map_location=device))

    with torch.no_grad():
        z1 = torch.randn(num_samples, 1, device=device)
        zeros = torch.zeros(num_samples, 1, device=device)
        t_reverse = torch.linspace(1, 0, steps=5, device=device)
        # CNF.forward integrates a 4-tuple state (z, log-density change, kinetic
        # energy, Jacobian norm); only the trajectory of z is needed for sampling.
        z_traj = odeint(model, (z1, zeros, zeros.clone(), zeros.clone()), t_reverse)[0]
    generated_points = z_traj[-1].detach().cpu()

    show(generated_points, data_mean, data_variance)
    return generated_points


def on_key(event):
    """Matplotlib key-press callback: pressing 'q' closes the current figure."""
    if event.key != 'q':
        return None
    plt.close()


def show(datapoints, data_mean, data_variance):
    """Plot a density histogram of ``datapoints`` with the reference normal pdf on top.

    Args:
        datapoints: samples to plot, either a tensor (any device; may require
            grad) or array-like. All values are flattened into a single dataset.
        data_mean: mean of the reference normal distribution.
        data_variance: variance (not standard deviation) of the reference
            distribution, matching its use as the covariance in ``data_loader``.

    Pressing 'q' in the figure window closes it (via ``on_key``).
    """
    if isinstance(datapoints, torch.Tensor):
        datapoints = datapoints.detach().cpu().numpy()
    # Flatten: plt.hist treats a 2-D array as one dataset per column, so an
    # (N, 1) array transposed to (1, N) would become N one-point histograms.
    values = np.asarray(datapoints, dtype=float).ravel()
    std = np.sqrt(data_variance)

    fig, ax = plt.subplots()
    ax.hist(values, bins=50, density=True, color='blue', alpha=0.7)
    ax.set_title('Histogram')
    ax.set_xlabel('value')
    ax.set_ylabel('density')
    ax.grid(False)

    # Evaluate the pdf around its own mean rather than a fixed [-3, 3] window,
    # which would miss most of a distribution centred away from zero.
    x = np.linspace(data_mean - 4 * std, data_mean + 4 * std, 200)
    y = np.exp(-0.5 * ((x - data_mean) / std) ** 2) / (std * np.sqrt(2 * np.pi))
    ax.plot(x, y, color='red', label='Normal Gaussian')
    ax.legend()

    fig.canvas.mpl_connect('key_press_event', on_key)

    plt.show()

In [None]:
"""RNODE model for 1D normal distribution"""

class ODEfunc(nn.Module):
    """Velocity field f(t, z) of the continuous normalizing flow.

    A three-layer MLP with softplus activations on the hidden layers. The time
    argument ``t`` is accepted to match the ODE-solver calling convention but is
    not used, so the dynamics are autonomous.

    Args:
        dim: dimensionality of the state z.
        hidden_dim: width of both hidden layers (default 64, the original width).
    """

    def __init__(self, dim, hidden_dim=64):
        super(ODEfunc, self).__init__()
        self.linear1 = nn.Linear(dim, hidden_dim)
        self.linear2 = nn.Linear(hidden_dim, hidden_dim)
        self.linear3 = nn.Linear(hidden_dim, dim)
        self.softplus = nn.Softplus()

    def forward(self, t, z):
        """Return dz/dt for ``z`` of shape (..., dim); ``t`` is ignored."""
        hidden = self.softplus(self.linear1(z))
        hidden = self.softplus(self.linear2(hidden))
        # No activation on the output layer: a softplus here forces dz/dt > 0,
        # so the flow could only push points toward larger values and could not,
        # e.g., move data centred at +5 onto a zero-mean base distribution.
        return self.linear3(hidden)
    

class CNF(nn.Module):
    """Right-hand side of an RNODE-regularized continuous normalizing flow.

    ``forward`` returns time derivatives for the augmented state
    ``(z, delta_logp, kinetic_energy, jacobian_frobenius)``. One ODE solve
    therefore yields the transformed points, the accumulated change in
    log-density, and the two RNODE regularization integrals.

    Args:
        dim: dimensionality of z.
        lambda_k: weight for the kinetic-energy regularizer. It is stored as an
            attribute only; forward() does not apply it, so the loss must.
        lambda_j: weight for the Jacobian Frobenius-norm regularizer. It is
            stored as an attribute only; forward() does not apply it.
    """

    def __init__(self, dim, lambda_k=0.01, lambda_j=0.01):
        super(CNF, self).__init__()
        self.odefunc = ODEfunc(dim)
        # Not referenced by forward(). Note its mean is ones, not zeros.
        self.distr = torch.distributions.MultivariateNormal(torch.ones(dim), torch.eye(dim))
        self.lambda_k = lambda_k
        self.lambda_j = lambda_j

    def forward(self, t, states):
        """Time derivatives of the augmented state.

        Args:
            t: current time, forwarded to ``odefunc``.
            states: tuple ``(z, delta_logp, E, n)``. Only ``z`` (shape
                (batch, dim)) is read, because none of the derivatives depend
                on the accumulators.

        Returns:
            ``(dz/dt, -tr(J), ||dz/dt||^2, ||J||_F^2)``, where J = d(dz/dt)/dz
            per sample and the last three entries have shape (batch, 1).
            J is computed exactly with one backward pass per state dimension,
            which is cheap for the 1-D case. Exact terms keep the dynamics
            deterministic. A Hutchinson estimate with fresh noise on every call
            would instead make adaptive step control and the adjoint's
            backward re-integration see a different ODE each time.
        """
        z = states[0]
        with torch.set_grad_enabled(True):
            if not z.requires_grad:
                z = z.detach().requires_grad_(True)
            dz_dt = self.odefunc(t, z)

            divergence = z.new_zeros(z.shape[0], 1)
            frobenius_sq = z.new_zeros(z.shape[0], 1)
            for i in range(z.shape[1]):
                # Summing over the batch before differentiating yields each
                # sample's own Jacobian row i. This assumes odefunc has no
                # cross-sample interaction, which holds for a row-wise MLP.
                # create_graph=True keeps these terms differentiable for training.
                row = torch.autograd.grad(
                    dz_dt[:, i].sum(), z, create_graph=True, allow_unused=True
                )[0]
                if row is None:
                    continue
                divergence = divergence + row[:, i:i + 1]
                frobenius_sq = frobenius_sq + (row ** 2).sum(dim=1, keepdim=True)

            # Instantaneous change of variables: d log p(z(t)) / dt = -tr(J).
            dlogp_z_dt = -divergence
            dE_dt = (dz_dt ** 2).sum(dim=1, keepdim=True)

        return (dz_dt, dlogp_z_dt, dE_dt, frobenius_sq)


def vjp(f, z, v=None):
    """Vector-Jacobian product of ``f`` evaluated at the input tuple ``z``.

    Args:
        f: callable invoked as ``f(*z)``.
        z: tuple of inputs to ``f``, e.g. ``(t, state)`` for an ODE right-hand
            side. The default cotangent is shaped like ``z[1]``.
        v: cotangent vector with the shape of ``f``'s output. Defaults to fresh
            standard-normal noise shaped like ``z[1]``. Pass it explicitly when
            the caller needs to reuse it, e.g. to form a trace estimate vᵀJv.

    Returns:
        ``(f(*z), grads)`` as returned by ``torch.autograd.functional.vjp``.
        ``grads`` holds one entry per input, and inputs that ``f`` ignores get
        zeros. The graph is kept (``create_graph=True``) so the result stays
        differentiable.
    """
    if v is None:
        v = torch.randn_like(z[1])
    return torch.autograd.functional.vjp(f, z, v=v, create_graph=True, strict=False)

In [None]:
"""Parameters and Initialization"""

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Seed once so the sampled dataset and the initial weights are reproducible
# under Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

num_batches = 10
# NOTE(review): one sample per batch makes training very noisy; data_loader's default is 64.
batch_size = 1
train_test_split = .8
data_mean = 5
data_variance = 5

train_data, test_data = data_loader(num_batches, batch_size, train_test_split, data_mean, data_variance)

model = CNF(1).to(device)
# Integration interval: data points start at t0; the base density q0 is evaluated at t1.
t0 = 0
t1 = 1
# Standard-normal base distribution, built on `device` so that log_prob accepts
# tensors that live there.
q0 = torch.distributions.MultivariateNormal(torch.zeros(1, device=device), torch.eye(1, device=device))
optimizer = optim.Adam(model.parameters(), lr=0.001)

num_epochs = 10
model.train()



In [None]:
"""Training the Model"""

# Integrate each data batch from t0 to t1 and minimise the RNODE objective:
# the negative log-likelihood plus weighted kinetic-energy and Jacobian penalties.
time_span = torch.tensor([t0, t1], dtype=torch.float32, device=device)

for epoch in range(num_epochs):
    epoch_loss = 0.0
    for x in train_data:
        optimizer.zero_grad()

        x = x.to(device)
        # The accumulators for log-density change, kinetic energy and Jacobian norm
        # all start at zero, with one column per sample.
        zeros = torch.zeros(x.shape[0], 1, device=device)
        initial_values = (x, zeros, zeros.clone(), zeros.clone())

        z_t, delta_logp_t, E_t, n_t = odeint(model, initial_values, time_span)
        # Index -1 is the state at t1, the end of time_span.
        z_t1, delta_logp_t1, E_t1, n_t1 = z_t[-1], delta_logp_t[-1], E_t[-1], n_t[-1]

        # q0.log_prob returns shape (batch,). Squeeze the (batch, 1) accumulators so the
        # terms combine element-wise instead of broadcasting to (batch, batch).
        # Moving z to q0's device keeps this working wherever q0 was constructed.
        log_q0 = q0.log_prob(z_t1.to(q0.loc.device)).to(device)
        logp_x = log_q0 - delta_logp_t1.squeeze(1)
        regularization = model.lambda_k * E_t1.squeeze(1) + model.lambda_j * n_t1.squeeze(1)
        loss = (regularization - logp_x).mean()

        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()

    print(f'Epoch {epoch + 1}, Loss: {epoch_loss / max(len(train_data), 1)}')

torch.save(model.state_dict(), "rnode_normal.pth")