# following the example here: 
- https://github.com/patrick-kidger/torchcde/blob/master/example/example.py

In [1]:
import math
import torch
import torchcde

import matplotlib.pyplot as plt


A CDE model looks like

$$z_t = z_0 + \int_0^t f_\theta(z_s) \, \mathrm{d}X_s$$

where $X$ is your data (interpolated into a continuous path) and $f_\theta$ is a neural network. So the first thing we need to do is define such an $f_\theta$.

That's what this CDEFunc class does.
Here we've built a small single-hidden-layer neural network, whose hidden layer is of width 128.

In [2]:
class CDEFunc(torch.nn.Module):
    """Vector field f_theta of the CDE  z_t = z_0 + int_0^t f_theta(z_s) dX_s.

    A single-hidden-layer network that maps a hidden state to a matrix, i.e. a
    linear map from R^input_channels to R^hidden_channels.

    Args:
        input_channels: number of channels in the data X. (Determined by the data.)
        hidden_channels: number of channels for z_t. (Determined by you!)
        hidden_width: width of the network's single hidden layer. Defaults to
            128, the value that was previously hard-coded.
    """

    def __init__(self, input_channels, hidden_channels, hidden_width=128):
        super().__init__()
        self.input_channels = input_channels
        self.hidden_channels = hidden_channels

        self.linear1 = torch.nn.Linear(hidden_channels, hidden_width)
        self.linear2 = torch.nn.Linear(hidden_width, input_channels * hidden_channels)

    ######################
    # For most purposes the t argument can probably be ignored; unless you want your CDE to behave differently at
    # different times, which would be unusual. But it's there if you need it!
    ######################
    def forward(self, t, z):
        """Evaluate the vector field at hidden state ``z``.

        Args:
            t: current time; unused by this implementation.
            z: tensor of shape (..., hidden_channels).

        Returns:
            Tensor of shape (..., hidden_channels, input_channels) with entries
            in [-1, 1] (the output of the final tanh).
        """
        z = self.linear1(z)
        z = z.relu()
        z = self.linear2(z)
        ######################
        # Easy-to-forget gotcha: Best results tend to be obtained by adding a final tanh nonlinearity.
        ######################
        z = z.tanh()
        ######################
        # Ignoring the batch dimension(s), the shape of the output tensor must be a matrix,
        # because we need it to represent a linear map from R^input_channels to R^hidden_channels.
        # All leading dimensions are kept as they are, so any number of batch dimensions works.
        ######################
        return z.view(*z.shape[:-1], self.hidden_channels, self.input_channels)

Next, we need to package CDEFunc up into a model that computes the integral.

In [3]:
class NeuralCDE(torch.nn.Module):
    """Model that integrates the CDE driven by a path and reads out a prediction.

    Args:
        input_channels: number of channels in the data path X.
        hidden_channels: number of channels of the hidden state z_t.
        output_channels: number of channels of the prediction.
    """

    def __init__(self, input_channels, hidden_channels, output_channels):
        super().__init__()

        # Vector field, initial-state map and final readout, in that order.
        self.func = CDEFunc(input_channels, hidden_channels)
        self.initial = torch.nn.Linear(input_channels, hidden_channels)
        self.readout = torch.nn.Linear(hidden_channels, output_channels)

    def forward(self, coeffs):
        """Return the readout of the terminal hidden state for the path given by ``coeffs``."""
        path = torchcde.NaturalCubicSpline(coeffs)
        interval = path.interval

        # Easy to forget gotcha: the initial hidden state should be a function
        # of the first observation, so evaluate the path at the interval start.
        first_observation = path.evaluate(interval[0])
        initial_state = self.initial(first_observation)

        # Actually solve the CDE over the path's interval.
        trajectory = torchcde.cdeint(X=path, z0=initial_state, func=self.func, t=interval)

        # cdeint returns both the initial and the terminal value; keep only the
        # terminal one (index 1 along dim 1) and apply the linear readout.
        terminal_state = trajectory[:, 1]
        return self.readout(terminal_state)

In [4]:
######################
# Now we need some data.
# Here we have a simple example which generates some spirals, some going clockwise, some going anticlockwise.
######################
def get_data(num_samples=128, num_timepoints=100):
    """Generate noisy 2-D spirals, half going clockwise and half anticlockwise.

    Args:
        num_samples: number of spirals to generate. The first
            ``num_samples // 2`` (before shuffling) are mirrored in x and
            labelled 1. Defaults to 128.
        num_timepoints: number of regularly spaced observation times on
            [0, 4*pi]. Defaults to 100.

    Returns:
        A tuple ``(X, y)`` where
        X is a tensor of observations, of shape
        (batch=num_samples, sequence=num_timepoints, channels=3), with channels
        (time, x, y); and
        y is a tensor of labels, of shape (batch=num_samples,), either 0 or 1
        corresponding to anticlockwise or clockwise respectively.
    """
    t = torch.linspace(0., 4 * math.pi, num_timepoints)
    num_clockwise = num_samples // 2

    # Random starting angle per spiral; the radius shrinks as time increases.
    start = torch.rand(num_samples) * 2 * math.pi
    angle = start.unsqueeze(1) + t.unsqueeze(0)
    decay = 1 + 0.5 * t
    x_pos = torch.cos(angle) / decay
    # Mirroring x reverses the direction of rotation for these samples.
    x_pos[:num_clockwise] *= -1
    y_pos = torch.sin(angle) / decay
    x_pos += 0.01 * torch.randn_like(x_pos)
    y_pos += 0.01 * torch.randn_like(y_pos)
    ######################
    # Easy to forget gotcha: time should be included as a channel; Neural CDEs need to be explicitly told the
    # rate at which time passes. Here, we have a regularly sampled dataset, so appending time is pretty simple.
    ######################
    X = torch.stack([t.unsqueeze(0).repeat(num_samples, 1), x_pos, y_pos], dim=2)
    y = torch.zeros(num_samples)
    y[:num_clockwise] = 1

    # Shuffle so that the two classes are interleaved.
    perm = torch.randperm(num_samples)
    X = X[perm]
    y = y[perm]

    return X, y

# Main script: generate data, build the model, train, and evaluate

In [6]:
# Generate the training set with get_data(): observations train_X and labels train_y.
train_X, train_y = get_data()

In [7]:
######################
# input_channels=3 because each observation carries time plus the horizontal and vertical position of a point
# in the spiral.
# hidden_channels=8 is the number of hidden channels for the evolving z_t, which we get to choose.
# output_channels=1 because we're doing binary classification; the single output is used as a logit by the
# training loss.
######################
model = NeuralCDE(input_channels=3, hidden_channels=8, output_channels=1)
# Adam over all model parameters; no optimizer hyperparameters are overridden here.
optimizer = torch.optim.Adam(model.parameters())

In [8]:
######################
# Now we turn our dataset into a continuous path. We do this here via natural cubic spline interpolation.
# The resulting `train_coeffs` is a tensor describing the path; it is computed once, up front, rather than
# inside the training loop.
# For most problems, it's probably easiest to save this tensor and treat it as the dataset.
######################
train_coeffs = torchcde.natural_cubic_coeffs(train_X)

In [9]:
# Train with mini-batches of interpolation coefficients paired with labels.
num_epochs = 15
train_dataset = torch.utils.data.TensorDataset(train_coeffs, train_y)
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=32)
for epoch in range(num_epochs):
    for batch_coeffs, batch_y in train_dataloader:
        # The model output has a trailing channel dimension of size 1; drop it
        # so the logits line up with the labels.
        pred_y = model(batch_coeffs).squeeze(-1)
        loss = torch.nn.functional.binary_cross_entropy_with_logits(pred_y, batch_y)
        loss.backward()
        optimizer.step()
        # Clear the gradients so they do not accumulate into the next batch.
        optimizer.zero_grad()
    # Note: this reports the loss of the last batch of the epoch only.
    print('Epoch: {}   Training loss: {}'.format(epoch, loss.item()))

# Evaluate on a freshly generated test set, interpolated the same way as the training data.
test_X, test_y = get_data()
test_coeffs = torchcde.natural_cubic_coeffs(test_X)
# Inference only: disable autograd so that no computation graph is recorded for the forward pass.
with torch.no_grad():
    pred_y = model(test_coeffs).squeeze(-1)
# Threshold the predicted probabilities at 0.5 to get hard 0/1 predictions.
binary_prediction = (torch.sigmoid(pred_y) > 0.5).to(test_y.dtype)
prediction_matches = (binary_prediction == test_y).to(test_y.dtype)
proportion_correct = prediction_matches.sum() / test_y.size(0)
print('Test Accuracy: {}'.format(proportion_correct))

Epoch: 0   Training loss: 0.9556078314781189
Epoch: 1   Training loss: 1.0249571800231934
Epoch: 2   Training loss: 0.6855239272117615
Epoch: 3   Training loss: 0.8307387828826904
Epoch: 4   Training loss: 0.745247483253479
Epoch: 5   Training loss: 0.6234510540962219
Epoch: 6   Training loss: 0.6001538038253784
Epoch: 7   Training loss: 0.5706719160079956
Epoch: 8   Training loss: 0.5286034941673279
Epoch: 9   Training loss: 0.4988443851470947
Epoch: 10   Training loss: 0.43296709656715393
Epoch: 11   Training loss: 0.3257259726524353
Epoch: 12   Training loss: 0.21124131977558136
Epoch: 13   Training loss: 0.12150489538908005
Epoch: 14   Training loss: 0.05994701758027077
Test Accuracy: 1.0


In [11]:
# Display the thresholded 0/1 predictions for the test set.
binary_prediction

tensor([0., 1., 0., 0., 1., 0., 1., 0., 1., 1., 0., 0., 0., 1., 0., 1., 0., 1.,
        1., 1., 0., 1., 0., 0., 1., 0., 1., 0., 1., 0., 1., 1., 1., 0., 0., 0.,
        0., 1., 1., 0., 0., 0., 1., 0., 1., 1., 0., 0., 1., 1., 0., 1., 0., 0.,
        0., 1., 0., 1., 0., 0., 1., 1., 1., 0., 0., 0., 1., 0., 1., 1., 1., 1.,
        1., 0., 1., 1., 1., 1., 0., 0., 0., 1., 1., 0., 1., 1., 1., 0., 1., 1.,
        0., 0., 1., 0., 0., 1., 1., 0., 0., 0., 1., 1., 0., 0., 0., 1., 0., 0.,
        1., 1., 0., 1., 0., 1., 0., 1., 0., 0., 1., 0., 0., 1., 0., 1., 1., 0.,
        1., 1.])

In [10]:
# Display the raw model outputs (logits, before the sigmoid) for the test set.
pred_y

tensor([-2.8642,  4.7660, -3.3683, -3.4361,  4.8082, -3.0356,  4.6428, -3.1727,
         2.8118,  3.6808, -3.1978, -3.2700, -3.1142,  4.8418, -3.4221,  2.9090,
        -2.3949,  2.2152,  3.4612,  2.7512, -3.0920,  2.4418, -3.1267, -3.2749,
         4.4941, -3.3270,  2.7463, -2.9806,  4.4395, -2.5967,  2.7151,  2.7531,
         2.7058, -3.0503, -3.0278, -2.4734, -3.6836,  4.8273,  2.4549, -2.4488,
        -3.0870, -2.4042,  3.4958, -2.4246,  2.7067,  4.5567, -2.4310, -2.9452,
         2.9359,  3.8290, -3.6689,  2.7041, -2.9999, -3.7177, -3.6549,  2.7057,
        -2.4493,  3.4272, -3.0260, -3.6310,  2.5254,  4.8286,  4.1304, -2.4858,
        -2.4195, -2.7480,  4.4479, -3.0528,  3.4386,  2.5928,  2.3664,  3.6087,
         2.8040, -3.2621,  4.6705,  2.7597,  2.7862,  2.8325, -2.5907, -2.9079,
        -2.6772,  2.7817,  4.5088, -3.1706,  4.5548,  2.7352,  4.6248, -3.5122,
         4.6695,  2.8016, -3.5389, -2.6833,  3.9014, -3.0424, -2.3798,  2.8218,
         3.6908, -3.3098, -3.2880, -3.34