In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.nn.functional as F
import pytorch_lightning as pl
import numpy as np
import pickle

One pair of (X,y) looks like:
- d = 3 # number of dimensions
- N # number of objects
- X: N x (2 * d + 1) # positions, velocities, mass
- y: N x d # accelerations

In [2]:
batch_size = 2048 * 16
visible = 10
hidden = 1
N = visible + hidden
d = 3

scale_exp = 5

pos = torch.exp(scale_exp * torch.rand(batch_size, N, d))
# make it centered at 0
pos -= pos.mean(axis=1, keepdim = True) 

vel = torch.exp(scale_exp * torch.rand(batch_size, N, d))

# assign fixed positions, velocities??? (this shouldn't matter for now) to hidden objects (this only works for one that is put in the center for now)
pos[:,:hidden,:] *= 0
vel[:,:hidden,:] *= 0

m = torch.rand(1, N, 1)
# hidden mass:
m[0,:hidden,0] = m[0,:hidden,0] * 0 + 1

m = torch.exp(scale_exp * m)
m = m.expand(batch_size,-1,-1)

dt = 0.01
steps = 10
g = 0.5

ms = m.unsqueeze(2).expand(-1,-1,N,-1)
m1 = ms
m2 = ms.transpose(1,2)

X_list = []
y_list = []

X_list.append(torch.cat((pos, vel, m), dim=-1))

for _ in range(steps):
    xs = pos.unsqueeze(2).expand(-1,-1,N,-1)
    x1 = xs
    x2 = xs.transpose(1,2)

    delta_x = x1 - x2
    delta_x_norm = (delta_x ** 2).sum(dim=-1, keepdim=True)**0.5 + 1e-9
    forces = -1 * g * m1 * m2 / delta_x_norm ** 2

    # the delta_x_norms were offset by a small number to avoid numeric problems
    # this is fine, when multiplying by delta_x, the self-self terms are zeroed out
    force_vectors = forces * delta_x / delta_x_norm
    a = force_vectors.sum(dim=2) / m1[:,:,0,:]

    # 
    #y_list.append(a)

    # simple 1 step - could use a more intelligent integrator here.
    vel += a * dt
    pos += vel * dt

    #y_list.append(torch.cat((pos, vel, m), dim=-1))

y_list.append(torch.cat((pos, vel, m), dim=-1))

X = torch.cat(X_list)
y = torch.cat(y_list)

# remove hidden objects
X = X[:,hidden:,:]
y = y[:,hidden:,:]

# add some random noise
#y += 1e-6 * torch.randn(y.shape) * y.mean()


In [11]:
!pip install torchdyn



In [3]:
from torchdyn.core import NeuralODE
from torchdyn.datasets import *
from torchdyn import *

In [None]:
 Modern ODE solvers provide guarantees about the growth
of approximation error, monitor the level of error, and adapt their evaluation strategy on the fly to
achieve the requested level of accuracy. This allows the cost of evaluating a model to scale with
problem complexity. After training, accuracy can be reduced for real-time or low-power applications.

In [4]:
device = torch.device("cpu") #torch.device("cuda") 

In [25]:
class ODEModule(pl.LightningModule):
    def __init__(self, diff_model):
        super(ODEModule, self).__init__()
        self.loss = F.mse_loss # torch.log(F.mrse_loss) + angle loss
        self.lr = 1e-3
        self.wd = 1e-5
        self.steps = 1 #steps
        self.dt = dt * steps
        # relative mean weighted error - this wasn't helpful at all
        # self.loss = lambda y_hat, y: ((y_hat - y).abs() / (y.abs() + 1e-8)).mean()

        self.diff_model = diff_model#.to(device)
        self.ode_model = NeuralODE(self.diff_model, return_t_eval = False, sensitivity='adjoint', solver='rk4', solver_adjoint='dopri5', atol_adjoint=1e-2, rtol_adjoint=1e-2)#.to(device)


    #def ode_forward(self, X, dt=0.01):
    #    return X + dt * self.diff_model(X)    

    # def forward(self, x):
    #     #return self.ode_model(x)[-1]
    #     #return x + dt * self.diff_model(x)    

    #     x_local = []
    #     x_local.append(x)
    #     for i in range(self.steps):
    #         x_local.append(x_local[i] + self.dt * self.diff_model(x_local[i]))
        
    #     return x_local[-1]
    
    def forward(self, x, i=None):
        #return self.ode_model(x)[-1]
        #return x + dt * self.diff_model(x)    
        if i is None:
            i = self.steps

        if i == 0:
            return x
        
        return self.forward(x + self.dt * self.diff_model(x), i-1)

    def training_step(self, batch, batch_idx):
        X, y = batch
        y_hat = self.forward(X)

        loss = self.loss(y_hat, y)
        self.log('train_loss', loss.item(), on_epoch=True, on_step=False)

        # log learning terms
        for name, fx in self.diff_model.my_loggers.items():
            self.log(name, fx(self.diff_model), on_epoch=True, on_step=False)

        return loss

    def validation_step(self, batch, batch_idx):
        X, y = batch
        y_hat = self.forward(X)

        loss = self.loss(y_hat, y)
        self.log('validation_loss', loss.item(), on_epoch=True, on_step=False)
        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr, weight_decay=self.wd)
        return optimizer 

class GnnLogLinearModel(torch.nn.Module):
    def __init__(self):
        super(GnnLogLinearModel, self).__init__()
        self.input_size = 3 # r, m1, m2
        self.output_size = 1
        self.formula = torch.nn.Linear(self.input_size, self.output_size) 
        self.my_loggers = {
            'r_exp': lambda s: s.formula.weight[0][0].item(),
            'm1_exp': lambda s: s.formula.weight[0][1].item(),
            'm2_exp': lambda s: s.formula.weight[0][2].item()
        }

    def forward(self, X):
        N = X.shape[1]
        xs = X[:,:,:d].unsqueeze(2).expand(-1,-1,N,-1)
        v = X[:,:,d:2*d]
        m = X[:,:,-1:]

        # safe-guarding for neuralode.
        #m = torch.max(m,torch.tensor([1e-8]))
        #m = m * 0 + 1
        ms = m.unsqueeze(2).expand(-1,-1,N,-1)

        x1 = xs
        x2 = xs.transpose(1,2)

        delta_x = x1 - x2
        delta_x_norm = (delta_x ** 2).sum(dim=-1, keepdim=True)**0.5 + 1e-9

        m1 = ms
        m2 = ms.transpose(1,2)

        inp = torch.cat((delta_x_norm, m1, m2), dim=-1)

        inp_log = torch.log(inp)

        # one linear layer
        forces_log = self.formula(inp_log)

        forces = torch.exp(forces_log)

        # the delta_x_norms were offset by a small number to avoid numeric problems
        # this is fine, when multiplying by delta_x, the self-self terms are zeroed out
        force_vectors = forces * delta_x / delta_x_norm

        a = -1 * force_vectors.sum(dim=2) / m1[:,:,0,:]
        
        dX = torch.cat((v, a, X[:,:,-1:]*0), dim=-1)

        # later learn this directionality too (the -1)
        return dX

In [20]:
X = X.to(device)
y = y.to(device)

In [18]:
train_set = list(zip(X, y))
train_set_size = int(len(train_set) * 0.8)
valid_set_size = len(train_set) - train_set_size
train_set, valid_set = random_split(train_set, [train_set_size, valid_set_size], generator=torch.Generator().manual_seed(42))

best_model = None
best_score = 1e15
times = 2

for _ in range(times):
    diff_model = GnnLogLinearModel()#.to(device)
    model = ODEModule(diff_model)#.to(device)
    y_hat = model.forward(X)
    loss = model.loss(y_hat, y)
    if loss < best_score:
        print(loss)
        best_score = loss
        best_model = model


model = best_model

Your vector field callable (nn.Module) should have both time `t` and state `x` as arguments, we've wrapped it for you.
tensor(0.0003, grad_fn=<MseLossBackward0>)
Your vector field callable (nn.Module) should have both time `t` and state `x` as arguments, we've wrapped it for you.


In [14]:
# forward pass and loss calculation runs okay, without issues
y_hat = model.forward(X)
loss = model.loss(y_hat, y)
# backward pass takes an eternity (even on a batch of 16)
loss.backward()

In [27]:
early_stop_callback = EarlyStopping(monitor="validation_loss", patience=300, verbose=False, mode="min")

train_set = DataLoader(train_set, shuffle=True, batch_size=128)
valid_set = DataLoader(valid_set, shuffle=True, batch_size=1000)

logger = TensorBoardLogger("lightning_logs", name=f'gnn_log_linear_ode') # _masses, hidden_multiple

# train with both splits
trainer = pl.Trainer(gpus=1, max_epochs=10000, #gpus=1
                            #gradient_clip_val=0.5,
                            callbacks=[early_stop_callback],
                            logger=logger,
                            enable_progress_bar=False)

trainer.fit(model, train_set, valid_set)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name       | Type              | Params
-------------------------------------------------
0 | diff_model | GnnLogLinearModel | 4     
1 | ode_model  | NeuralODE         | 4     
-------------------------------------------------
4         Trainable params
0         Non-trainable params
4         Total params
0.000     Total estimated model params size (MB)
  rank_zero_warn(
  rank_zero_warn(
  rank_zero_warn(
  rank_zero_warn(


In [26]:
from torch.utils.data import DataLoader, Dataset, random_split
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.loggers import TensorBoardLogger

In [28]:
# not a single epoch in 1 hour



for mult in [1]:

    train_set = list(zip(X, y))
    train_set_size = int(len(train_set) * 0.8)
    valid_set_size = len(train_set) - train_set_size
    train_set, valid_set = random_split(train_set, [train_set_size, valid_set_size], generator=torch.Generator().manual_seed(42))

    best_model = None
    best_score = 1e15
    times = 3

    for _ in range(times):
        diff_model = GnnLogLinearModel()#.to(device)
        model = ODEModule(diff_model)#.to(device)
        y_hat = model.forward(X)
        loss = model.loss(y_hat, y)
        if loss < best_score:
            print(loss)
            best_score = loss
            best_model = model


    model = best_model
    model.steps = 2
    early_stop_callback = EarlyStopping(monitor="validation_loss", patience=300, verbose=False, mode="min")

    train_set = DataLoader(train_set, shuffle=True, batch_size=128)
    valid_set = DataLoader(valid_set, shuffle=True, batch_size=1000)

    logger = TensorBoardLogger("lightning_logs", name=f'gnn_log_linear_ode_{model.steps}') # _masses, hidden_multiple

    # train with both splits
    trainer = pl.Trainer(gpus=1, max_epochs=10000, #gpus=1
                                #gradient_clip_val=0.5,
                                callbacks=[early_stop_callback],
                                logger=logger,
                                enable_progress_bar=False)

    trainer.fit(model, train_set, valid_set)

Your vector field callable (nn.Module) should have both time `t` and state `x` as arguments, we've wrapped it for you.
tensor(2.1634, grad_fn=<MseLossBackward0>)
Your vector field callable (nn.Module) should have both time `t` and state `x` as arguments, we've wrapped it for you.
tensor(0.0333, grad_fn=<MseLossBackward0>)
Your vector field callable (nn.Module) should have both time `t` and state `x` as arguments, we've wrapped it for you.


GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name       | Type              | Params
-------------------------------------------------
0 | diff_model | GnnLogLinearModel | 4     
1 | ode_model  | NeuralODE         | 4     
-------------------------------------------------
4         Trainable params
0         Non-trainable params
4         Total params
0.000     Total estimated model params size (MB)
  rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")


In [11]:
diff_model = GnnLogLinearModel()#.to(device)
model = ODEModule(diff_model)

Your vector field callable (nn.Module) should have both time `t` and state `x` as arguments, we've wrapped it for you.


In [18]:
x_local = []
x = X[:128]
model.steps = 5
x_local.append(x)
for i in range(model.steps):
    x_local.append(x_local[i] + model.dt * model.diff_model(x_local[i]))

In [None]:
x_local[0]

In [19]:
x_local[0].min(), x_local[0].max()

(tensor(-60.7288), tensor(148.3710))

In [20]:
x_local[1].min(), x_local[1].max()

(tensor(-60.3059, grad_fn=<MinBackward1>),
 tensor(156.5332, grad_fn=<MaxBackward1>))

In [21]:
x_local[2].min(), x_local[2].max()

(tensor(-74.1496, grad_fn=<MinBackward1>),
 tensor(163.1830, grad_fn=<MaxBackward1>))

In [22]:
x_local[3].min(), x_local[3].max()

(tensor(-110.7478, grad_fn=<MinBackward1>),
 tensor(173.4416, grad_fn=<MaxBackward1>))

In [28]:
for mult in [1]:

    train_set = list(zip(X, y))
    train_set_size = int(len(train_set) * 0.8)
    valid_set_size = len(train_set) - train_set_size
    train_set, valid_set = random_split(train_set, [train_set_size, valid_set_size], generator=torch.Generator().manual_seed(42))

    model.steps += 1
    model.dt = dt * steps / model.steps
    print(model.steps, model.dt)

    early_stop_callback = EarlyStopping(monitor="validation_loss", patience=300, verbose=False, mode="min")

    train_set = DataLoader(train_set, shuffle=True, batch_size=128)
    valid_set = DataLoader(valid_set, shuffle=True, batch_size=1000)

    logger = TensorBoardLogger("lightning_logs", name=f'gnn_log_linear_ode_{model.steps}') # _masses, hidden_multiple

    # train with both splits
    trainer = pl.Trainer(gpus=1, max_epochs=10000, #gpus=1
                                #gradient_clip_val=0.5,
                                callbacks=[early_stop_callback],
                                logger=logger,
                                enable_progress_bar=False)

    trainer.fit(model, train_set, valid_set)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Missing logger folder: lightning_logs\gnn_log_linear_ode_4


4 0.025


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name       | Type              | Params
-------------------------------------------------
0 | diff_model | GnnLogLinearModel | 4     
1 | ode_model  | NeuralODE         | 4     
-------------------------------------------------
4         Trainable params
0         Non-trainable params
4         Total params
0.000     Total estimated model params size (MB)


In [29]:
for mult in [1]:

    train_set = list(zip(X, y))
    train_set_size = int(len(train_set) * 0.8)
    valid_set_size = len(train_set) - train_set_size
    train_set, valid_set = random_split(train_set, [train_set_size, valid_set_size], generator=torch.Generator().manual_seed(42))

    model.steps = 1
    model.dt = dt * steps / model.steps
    print(model.steps, model.dt)

    early_stop_callback = EarlyStopping(monitor="validation_loss", patience=300, verbose=False, mode="min")

    train_set = DataLoader(train_set, shuffle=True, batch_size=128)
    valid_set = DataLoader(valid_set, shuffle=True, batch_size=1000)

    logger = TensorBoardLogger("lightning_logs", name=f'gnn_log_linear_ode_{model.steps}') # _masses, hidden_multiple

    # train with both splits
    trainer = pl.Trainer(gpus=1, max_epochs=10000, #gpus=1
                                #gradient_clip_val=0.5,
                                callbacks=[early_stop_callback],
                                logger=logger,
                                enable_progress_bar=False)

    trainer.fit(model, train_set, valid_set)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


1 0.1


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name       | Type              | Params
-------------------------------------------------
0 | diff_model | GnnLogLinearModel | 4     
1 | ode_model  | NeuralODE         | 4     
-------------------------------------------------
4         Trainable params
0         Non-trainable params
4         Total params
0.000     Total estimated model params size (MB)


In [31]:
model.diff_model.formula.weight

Parameter containing:
tensor([[nan, nan, nan]], requires_grad=True)

In [15]:
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.loggers import TensorBoardLogger
early_stop_callback = EarlyStopping(monitor="validation_loss", patience=300, verbose=False, mode="min")

train_set = list(zip(X, y))
train_set = list(zip(X, y))
train_set_size = int(len(train_set) * 0.8)
valid_set_size = len(train_set) - train_set_size
train_set, valid_set = random_split(train_set, [train_set_size, valid_set_size], generator=torch.Generator().manual_seed(42))

train_set = DataLoader(train_set, shuffle=True, batch_size=128)
valid_set = DataLoader(valid_set, shuffle=True, batch_size=1000)

logger = TensorBoardLogger("lightning_logs", name='gnn_log_linear_mult1') # _masses, hidden_multiple

# train with both splits
trainer = pl.Trainer(gpus=1, max_epochs=10000,
                            #gradient_clip_val=0.5,
                            callbacks=[early_stop_callback],
                            logger=logger)

trainer.fit(model, train_set, valid_set)


GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name       | Type              | Params
-------------------------------------------------
0 | diff_model | GnnLogLinearModel | 4     
1 | ode_model  | NeuralODE         | 4     
-------------------------------------------------
4         Trainable params
0         Non-trainable params
4         Total params
0.000     Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

  rank_zero_warn(
  rank_zero_warn(
  rank_zero_warn(


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

  rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")


In [None]:
# compare fitted masses to ground truth
g * m[0],  model.masses ** model.formula.weight[0][2]

(tensor([[40.4088],
         [ 2.9725],
         [ 5.6881],
         [ 8.0561],
         [32.0705],
         [ 2.0293],
         [ 1.0138],
         [59.0399],
         [60.7772],
         [ 3.3891]]),
 tensor([[[39.0646],
          [ 2.9494],
          [ 5.6877],
          [ 8.0410],
          [31.9989],
          [ 2.0037],
          [ 0.9688],
          [58.7323],
          [60.7730],
          [ 3.3401]]], device='cuda:0', grad_fn=<PowBackward1>))

In [None]:
# to run tensorboard in the terminal:
# tensorboard --logdir lightning_logs