In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.nn.functional as F
import pytorch_lightning as pl
import numpy as np
import pickle

  from .autonotebook import tqdm as notebook_tqdm


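One (X, y) pair has the following layout:
- d = 3: number of spatial dimensions
- N: number of objects
- X: N x (2 * d + 1), i.e. for each object its position (d values), velocity (d values) and mass (1 value), in that order
- y: N x d, the acceleration of each object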

In [189]:
# Synthetic N-body dataset: random positions / velocities / masses and the
# Newtonian accelerations they imply.  Produces `X` (positions, velocities,
# masses) and `y` (noisy accelerations) for the visible objects only.

batch_size = 2048 * 16
visible = 10
hidden = 1
N = visible + hidden
d = 3

scale_exp = 5

dt = 0.01
g = 0.5

# Number of integration steps to record; every step contributes another
# `batch_size` samples.
n_steps = 1

# All random draws go through one seeded generator so that the dataset is the
# same after every "Restart & Run All".  The global torch RNG is not touched.
seed = 42
data_rng = torch.Generator().manual_seed(seed)

# Log-uniform positions, re-centred so that each system's mean position is 0.
pos = torch.exp(scale_exp * torch.rand(batch_size, N, d, generator=data_rng))
pos -= pos.mean(dim=1, keepdim=True)

vel = torch.exp(scale_exp * torch.rand(batch_size, N, d, generator=data_rng))

# The hidden objects occupy the first `hidden` slots and are pinned to the
# origin with zero velocity (this only makes sense for a single hidden object
# placed in the centre).
pos[:, :hidden, :] *= 0
vel[:, :hidden, :] *= 0

# One mass per object, shared by every sample in the batch.  The hidden objects
# get the largest possible exponent (1) before the exponential is applied.
m = torch.rand(1, N, 1, generator=data_rng)
m[0, :hidden, 0] = 1.0

m = torch.exp(scale_exp * m)
m = m.expand(batch_size, -1, -1)

# Pairwise mass grids: m1[b, i, j] is the mass of object i, m2[b, i, j] of object j.
ms = m.unsqueeze(2).expand(-1, -1, N, -1)
m1 = ms
m2 = ms.transpose(1, 2)

X_list = []
y_list = []

for _ in range(n_steps):
    xs = pos.unsqueeze(2).expand(-1, -1, N, -1)
    x1 = xs
    x2 = xs.transpose(1, 2)

    delta_x = x1 - x2
    delta_x_norm = (delta_x ** 2).sum(dim=-1, keepdim=True) ** 0.5 + 1e-9
    forces = -1 * g * m1 * m2 / delta_x_norm ** 2

    # The norms are offset by a small number to avoid dividing by zero.  That is
    # harmless: delta_x is exactly 0 for the self-self pairs, so multiplying by
    # it zeroes those terms out.
    force_vectors = forces * delta_x / delta_x_norm
    a = force_vectors.sum(dim=2) / m1[:, :, 0, :]

    # torch.cat copies, so the in-place integration below does not alter the
    # recorded sample.
    X_list.append(torch.cat((pos, vel, m), dim=-1))
    y_list.append(a)

    # Simple one-step update - a more careful integrator could be used here.
    vel += a * dt
    pos += vel * dt

X = torch.cat(X_list)
y = torch.cat(y_list)

# Remove the hidden objects so that only the visible ones are observed.
X = X[:, hidden:, :]
y = y[:, hidden:, :]

# Add some random noise.
# NOTE(review): the noise amplitude is tied to the *signed* mean of y, which can
# be close to zero; y.std() or y.abs().mean() may be what was intended - confirm.
y += 0.1 * torch.randn(y.shape, generator=data_rng) * y.mean()


In [18]:
class BaseModule(pl.LightningModule):
    """Common Lightning scaffolding for the pairwise force models.

    Holds the shared hyper-parameters, the loss, the train / validation steps
    and the optimizer.  Subclasses are expected to provide `forward` and a
    `formula` layer, which the default loggers read their values from.
    """

    def __init__(self):
        super().__init__()
        # Pairwise feature vector: r, m1, m2
        self.input_size = 3
        self.output_size = 1
        # Alternatives that were considered: torch.log(F.mrse_loss) + angle loss,
        # and a relative mean weighted error, which wasn't helpful at all:
        # self.loss = lambda y_hat, y: ((y_hat - y).abs() / (y.abs() + 1e-8)).mean()
        self.loss = F.mse_loss
        self.lr = 1e-3
        self.wd = 1e-5

        def exponent_reader(column):
            # Reads one entry of the first row of the `formula` weight matrix.
            return lambda s: s.formula.weight[0][column].item()

        # name -> callable(module) producing a scalar that is logged every
        # training step (aggregated per epoch).
        self.my_loggers = {
            name: exponent_reader(column)
            for column, name in enumerate(('r_exp', 'm1_exp', 'm2_exp'))
        }

    def training_step(self, batch, batch_idx):
        """Compute the training loss for one batch and log it with the extra terms."""
        inputs, targets = batch
        loss = self.loss(self.forward(inputs), targets)
        self.log('train_loss', loss.item(), on_epoch=True, on_step=False)

        # log learning terms
        for name, read_value in self.my_loggers.items():
            self.log(name, read_value(self), on_epoch=True, on_step=False)

        return loss

    def validation_step(self, batch, batch_idx):
        """Compute and log the validation loss for one batch."""
        inputs, targets = batch
        loss = self.loss(self.forward(inputs), targets)
        self.log('validation_loss', loss.item(), on_epoch=True, on_step=False)
        return loss

    def configure_optimizers(self):
        """Adam over all parameters, using the module's `lr` and `wd`."""
        return torch.optim.Adam(self.parameters(), lr=self.lr, weight_decay=self.wd)

class GnnLogLinearModel(BaseModule):
    """Pairwise force law that is linear in log-space, using the masses stored in X.

    For every ordered pair of objects the features (distance, m1, m2) are
    log-transformed, sent through a single linear layer and exponentiated to
    give a force magnitude.
    """

    def __init__(self):
        super().__init__()
        self.formula = torch.nn.Linear(self.input_size, self.output_size)

    def forward(self, X):
        """Return the summed pairwise force vector of every object.

        `X` is indexed as (batch, object, feature).  The first `d` features
        (`d` is the notebook-level global) are read as the position and the last
        feature as the mass.  The sum is not divided by the object's own mass.
        """
        n_objects = X.shape[1]

        # Pairwise grids: index 1 runs over object i, index 2 over object j.
        pos_i = X[..., :d].unsqueeze(2).expand(-1, -1, n_objects, -1)
        pos_j = pos_i.transpose(1, 2)
        mass_i = X[..., -1:].unsqueeze(2).expand(-1, -1, n_objects, -1)
        mass_j = mass_i.transpose(1, 2)

        offset = pos_i - pos_j
        # Small offset keeps the log and the division finite for the i == j pairs.
        distance = offset.pow(2).sum(dim=-1, keepdim=True).pow(0.5) + 1e-9

        # one linear layer, applied in log-space
        log_features = torch.log(torch.cat((distance, mass_i, mass_j), dim=-1))
        magnitude = torch.exp(self.formula(log_features))

        # `offset` is exactly zero for the self-self pairs, so they drop out of
        # the sum despite the offset added to `distance`.
        pair_forces = magnitude * offset / distance

        # later learn this directionality too (the -1)
        return -1 * pair_forces.sum(dim=2)

In [29]:
class GnnLogLinearMassModel(BaseModule):
    """Log-linear pairwise force law whose object masses are learned parameters.

    The mass column of the input is ignored; one trainable mass per object is
    used instead.  The exponents of the force law are either learned or, with
    `formula_given=True`, frozen to (-2, 1, 1) for (r, m1, m2).
    """

    def __init__(self, N=10, formula_given=False):
        super().__init__()
        # No bias: the overall scale has to be absorbed by the masses.
        self.formula = torch.nn.Linear(self.input_size, self.output_size, bias=False)
        if formula_given:
            known_exponents = torch.tensor([[-2.0, 1.0, 1.0]])
            self.formula.weight = torch.nn.Parameter(known_exponents, requires_grad=False)

        # All N masses are free parameters.  (A variant with one fixed mass was
        # tried before; it does not set the scale and may be problematic when
        # taking the log.)
        self.masses = torch.nn.Parameter(torch.rand(1, N, 1))

    def forward(self, X):
        """Return the predicted acceleration of every object.

        `X` is indexed as (batch, object, feature); only its first `d` features
        (`d` is the notebook-level global) are read, as positions.  The number
        of objects in `X` has to match the number of learned masses.
        """
        n_samples = X.shape[0]
        n_objects = X.shape[1]

        # Pairwise grids: index 1 runs over object i, index 2 over object j.
        pos_i = X[..., :d].unsqueeze(2).expand(-1, -1, n_objects, -1)
        pos_j = pos_i.transpose(1, 2)
        mass_i = self.masses.expand(n_samples, -1, -1).unsqueeze(2).expand(-1, -1, n_objects, -1)
        mass_j = mass_i.transpose(1, 2)

        offset = pos_i - pos_j
        # Small offset keeps the log and the division finite for the i == j pairs.
        distance = offset.pow(2).sum(dim=-1, keepdim=True).pow(0.5) + 1e-9

        # one linear layer, applied in log-space
        log_features = torch.log(torch.cat((distance, mass_i, mass_j), dim=-1))
        magnitude = torch.exp(self.formula(log_features))

        # `offset` is exactly zero for the self-self pairs, so they drop out of
        # the sum despite the offset added to `distance`.
        pair_forces = magnitude * offset / distance
        net_force = pair_forces.sum(dim=2)

        # later learn this directionality too (the -1)
        return -1 * net_force / mass_i[:, :, 0, :]

In [13]:
class GnnLogLinearHiddenMassModel(BaseModule):
    """Log-linear pairwise force law with learned masses and `e` unobserved objects.

    The `e` extra objects have learned positions and are prepended to the
    observed ones; their contribution enters the sum, but only the rows of the
    observed objects are returned.
    """

    def __init__(self, N=10, e=1, formula_given=False):
        super().__init__()
        self.e = e
        self.formula = torch.nn.Linear(self.input_size, self.output_size, bias=False)
        if formula_given:
            known_exponents = torch.tensor([[-2.0, 1.0, 1.0]])
            self.formula.weight = torch.nn.Parameter(known_exponents, requires_grad=False)

        # Masses of the e extra objects followed by the N observed ones, and the
        # positions of the extra objects.  Both are initialised log-uniformly
        # using the notebook-level `scale_exp` (and `d` for the positions).
        self.masses = torch.nn.Parameter(torch.exp(scale_exp * torch.rand(1, N + e, 1)))
        self.position = torch.nn.Parameter(torch.exp(scale_exp * torch.rand(1, e, d)))

        # Extra quantities logged on top of the base exponents.
        self.my_loggers.update({
            'pos_norm': lambda s: (s.position ** 2).sum() ** 0.5,
            'hidden_mass': lambda s: s.masses[0][0][0],
        })

    def forward(self, X_o):
        """Return the summed pairwise force vectors of the observed objects.

        `X_o` is indexed as (batch, object, feature); only its first `d`
        features are read, as positions.  The sum is not divided by the
        object's own mass.
        """
        n_samples = X_o.shape[0]

        # Learned positions first, observed positions after them.
        all_pos = torch.cat((self.position.expand(n_samples, -1, -1), X_o[..., :d]), dim=1)
        n_objects = all_pos.shape[1]

        # Pairwise grids: index 1 runs over object i, index 2 over object j.
        pos_i = all_pos.unsqueeze(2).expand(-1, -1, n_objects, -1)
        pos_j = pos_i.transpose(1, 2)
        mass_i = self.masses.expand(n_samples, -1, -1).unsqueeze(2).expand(-1, -1, n_objects, -1)
        mass_j = mass_i.transpose(1, 2)

        # The tiny extra term makes the i == j offsets not exactly zero.
        # NOTE(review): presumably this keeps the gradient of the square root
        # finite now that positions are trainable - confirm.
        offset = pos_i - pos_j + 1e-20 * (pos_i + pos_j)
        distance = offset.pow(2).sum(dim=-1, keepdim=True).pow(0.5) + 1e-9

        # one linear layer, applied in log-space
        log_features = torch.log(torch.cat((distance, mass_i, mass_j), dim=-1))
        magnitude = torch.exp(self.formula(log_features))

        pair_forces = magnitude * offset / distance

        # later learn this directionality too (the -1)
        summed = -1 * pair_forces.sum(dim=2)

        # Drop the rows that belong to the e learned objects.
        return summed[:, self.e:, :]

In [19]:
class GnnLogLinearModelMult(BaseModule):
    """Log-linear pairwise force law whose log-features are repeated `mult` times.

    The linear layer therefore has `mult` weights for each of the three
    features (distance, m1, m2).
    """

    def __init__(self, mult = 3):
        super().__init__()
        self.mult = mult
        self.formula = torch.nn.Linear(self.input_size * self.mult, self.output_size)

    def forward(self, X):
        """Return the predicted acceleration of every object.

        `X` is indexed as (batch, object, feature).  The first `d` features
        (`d` is the notebook-level global) are read as the position and the last
        feature as the mass.
        """
        n_objects = X.shape[1]

        # Pairwise grids: index 1 runs over object i, index 2 over object j.
        pos_i = X[..., :d].unsqueeze(2).expand(-1, -1, n_objects, -1)
        pos_j = pos_i.transpose(1, 2)
        mass_i = X[..., -1:].unsqueeze(2).expand(-1, -1, n_objects, -1)
        mass_j = mass_i.transpose(1, 2)

        offset = pos_i - pos_j
        # Small offset keeps the log and the division finite for the i == j pairs.
        distance = offset.pow(2).sum(dim=-1, keepdim=True).pow(0.5) + 1e-9

        # Tile the three log-features `mult` times along the feature axis.
        pair_features = torch.cat((distance, mass_i, mass_j), dim=-1)
        log_features = torch.log(pair_features).repeat(1, 1, 1, self.mult)

        # one linear layer, applied in log-space
        magnitude = torch.exp(self.formula(log_features))

        # `offset` is exactly zero for the self-self pairs, so they drop out of
        # the sum despite the offset added to `distance`.
        pair_forces = magnitude * offset / distance
        net_force = pair_forces.sum(dim=2)

        # later learn this directionality too (the -1)
        return -1 * net_force / mass_i[:, :, 0, :]

In [22]:
class GnnLogLinearModelNonLin(BaseModule):
    """Log-linear pairwise force law plus a small non-linear correction.

    The log of the force magnitude is the sum of a single linear layer and a
    one-hidden-layer ReLU network, both fed with the log-features
    (distance, m1, m2).
    """

    def __init__(self, hidden_size):
        super().__init__()
        self.hidden_size = hidden_size
        self.formula = torch.nn.Linear(self.input_size, self.output_size)
        self.formula_2 = torch.nn.Sequential(
            torch.nn.Linear(self.input_size, self.hidden_size),
            torch.nn.ReLU(),
            torch.nn.Linear(self.hidden_size, self.output_size),
        )

    def forward(self, X):
        """Return the predicted acceleration of every object.

        `X` is indexed as (batch, object, feature).  The first `d` features
        (`d` is the notebook-level global) are read as the position and the last
        feature as the mass.
        """
        n_objects = X.shape[1]

        # Pairwise grids: index 1 runs over object i, index 2 over object j.
        pos_i = X[..., :d].unsqueeze(2).expand(-1, -1, n_objects, -1)
        pos_j = pos_i.transpose(1, 2)
        mass_i = X[..., -1:].unsqueeze(2).expand(-1, -1, n_objects, -1)
        mass_j = mass_i.transpose(1, 2)

        offset = pos_i - pos_j
        # Small offset keeps the log and the division finite for the i == j pairs.
        distance = offset.pow(2).sum(dim=-1, keepdim=True).pow(0.5) + 1e-9

        log_features = torch.log(torch.cat((distance, mass_i, mass_j), dim=-1))

        # Linear term and non-linear correction are added in log-space.
        log_linear = self.formula(log_features)
        log_correction = self.formula_2(log_features)
        magnitude = torch.exp(log_linear + log_correction)

        # `offset` is exactly zero for the self-self pairs, so they drop out of
        # the sum despite the offset added to `distance`.
        pair_forces = magnitude * offset / distance
        net_force = pair_forces.sum(dim=2)

        # later learn this directionality too (the -1)
        return -1 * net_force / mass_i[:, :, 0, :]

    # see get_parameters for more complex configurations: https://stackoverflow.com/questions/69217682/what-is-the-best-way-to-define-adam-optimizer-in-pytorch
    def configure_optimizers(self):
        """Adam with two parameter groups; only the correction network gets weight decay."""
        param_groups = [
            {'params': self.formula.parameters()},
            {'params': self.formula_2.parameters(), 'weight_decay': 1e-4},
        ]
        return torch.optim.Adam(param_groups, lr=self.lr)

In [None]:
from torch.utils.data import DataLoader, Dataset, random_split
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.loggers import TensorBoardLogger

# Sweep over the hidden size of the non-linear correction.  For every setting,
# several random initialisations are probed and the one with the lowest initial
# loss on the full data set is trained.

# The samples and their 80/20 split do not depend on `mult`, so they are built
# once.  The seeded generator yields the same split on every run.
samples = list(zip(X, y))
train_set_size = int(len(samples) * 0.8)
valid_set_size = len(samples) - train_set_size
train_set, valid_set = random_split(samples, [train_set_size, valid_set_size], generator=torch.Generator().manual_seed(42))

# The loaders get their own names so that `train_set` / `valid_set` remain plain
# datasets that can be wrapped again later.  Validation data is not shuffled.
train_loader = DataLoader(train_set, shuffle=True, batch_size=128)
valid_loader = DataLoader(valid_set, shuffle=False, batch_size=1000)

# Number of random initialisations probed per setting.
times = 5

for mult in [1,3,10,20]:

    best_model = None
    best_score = float("inf")

    for _ in range(times):
        # Alternative models that were tried here:
        #model = GnnLogLinearModelMult(mult=mult)
        #model = GnnLogLinearMassModel(10)
        #model = GnnLogLinearHiddenMassModel(10, e=10, formula_given=False)
        model = GnnLogLinearModelNonLin(hidden_size=mult)

        # Only the loss value is needed here: skip autograd and keep a plain
        # float, so that no computation graph stays alive.
        with torch.no_grad():
            score = model.loss(model.forward(X), y).item()
        if score != score:
            # NaN never compares as smaller; rank it last instead.
            score = float("inf")

        # The first candidate is always accepted, so `best_model` cannot end up
        # as None even when every initial loss is huge or not finite.
        if best_model is None or score < best_score:
            print(score)
            best_score = score
            best_model = model

    model = best_model
    early_stop_callback = EarlyStopping(monitor="validation_loss", patience=30, verbose=False, mode="min")

    logger = TensorBoardLogger("lightning_logs", name=f'gnn_log_linear_dec_hidden{mult}') # _masses, hidden_multiple

    # train with both splits
    # NOTE(review): newer Lightning releases replace `gpus=1` with
    # `accelerator="gpu", devices=1` - confirm against the installed version.
    trainer = pl.Trainer(gpus=1, max_epochs=10000,
                                #gradient_clip_val=0.5,
                                callbacks=[early_stop_callback],
                                logger=logger,
                                enable_progress_bar=False)

    trainer.fit(model, train_loader, valid_loader)

In [None]:
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.loggers import TensorBoardLogger
# Continue training the current `model` with a much longer early-stopping
# patience.
early_stop_callback = EarlyStopping(monitor="validation_loss", patience=300, verbose=False, mode="min")

# `train_set` / `valid_set` may already be DataLoaders (for instance when this
# cell is run a second time, or when an earlier cell wrapped them).  A
# DataLoader is not a valid dataset for another DataLoader, so it is reused as
# is; only plain datasets get wrapped.  The loaders get their own names, which
# keeps this cell safe to re-run.
if isinstance(train_set, DataLoader):
    train_loader = train_set
else:
    train_loader = DataLoader(train_set, shuffle=True, batch_size=128)

if isinstance(valid_set, DataLoader):
    valid_loader = valid_set
else:
    # Validation data does not need shuffling.
    valid_loader = DataLoader(valid_set, shuffle=False, batch_size=1000)

logger = TensorBoardLogger("lightning_logs", name='gnn_log_linear_mult1') # _masses, hidden_multiple

# train with both splits
# NOTE(review): newer Lightning releases replace `gpus=1` with
# `accelerator="gpu", devices=1` - confirm against the installed version.
trainer = pl.Trainer(gpus=1, max_epochs=10000,
                            #gradient_clip_val=0.5,
                            callbacks=[early_stop_callback],
                            logger=logger)

trainer.fit(model, train_loader, valid_loader)


In [None]:
# Compare the fitted masses to the ground truth.
# Left:  g times the true masses of the first sample.
# Right: the learned masses raised to the learned m2 exponent.
# Requires `model` to be one of the variants that has a `masses` parameter.
(
    g * m[0],
    model.masses ** model.formula.weight[0][2],
)

(tensor([[40.4088],
         [ 2.9725],
         [ 5.6881],
         [ 8.0561],
         [32.0705],
         [ 2.0293],
         [ 1.0138],
         [59.0399],
         [60.7772],
         [ 3.3891]]),
 tensor([[[39.0646],
          [ 2.9494],
          [ 5.6877],
          [ 8.0410],
          [31.9989],
          [ 2.0037],
          [ 0.9688],
          [58.7323],
          [60.7730],
          [ 3.3401]]], device='cuda:0', grad_fn=<PowBackward1>))

In [None]:
# to run tensorboard in the terminal:
# tensorboard --logdir lightning_logs