In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.nn.functional as F
import pytorch_lightning as pl
import numpy as np
import pickle

  from .autonotebook import tqdm as notebook_tqdm


One pair of (X,y) looks like:
- d = 3 # number of dimensions
- N # number of objects
- X: N x (2 * d + 2) # positions, velocities, mass, charge
- y: N x d # accelerations

In [8]:
batch_size = 2048 * 16
visible = 10
hidden = 0
N = visible + hidden
d = 3

scale_exp = 5

pos = torch.exp(scale_exp * torch.rand(batch_size, N, d))
# make it centered at 0
pos -= pos.mean(axis=1, keepdim = True) 

vel = torch.exp(scale_exp * torch.rand(batch_size, N, d))

# assign fixed positions, velocities??? (this shouldn't matter for now) to hidden objects (this only works for one that is put in the center for now)
pos[:,:hidden,:] *= 0
vel[:,:hidden,:] *= 0

m = torch.rand(1, N, 1)
# hidden mass:
m[0,:hidden,0] = m[0,:hidden,0] * 0 + 1

m = torch.exp(scale_exp * m)
m = m.expand(batch_size,-1,-1)

ch = torch.rand(1, N, 1)
ch_sign = torch.randint(0, 2, (1, N, 1)) * 2 - 1
# hidden mass:
ch[0,:hidden,0] = m[0,:hidden,0] * 0 + 1

ch = torch.exp(scale_exp * ch) * ch_sign
ch = ch.expand(batch_size,-1,-1)

dt = 0.01
k = 0.5

chs = ch.unsqueeze(2).expand(-1,-1,N,-1)
ch1 = chs
ch2 = chs.transpose(1,2)

X_list = []
y_list = []

for _ in range(1):
    xs = pos.unsqueeze(2).expand(-1,-1,N,-1)
    x1 = xs
    x2 = xs.transpose(1,2)

    delta_x = x1 - x2
    delta_x_norm = (delta_x ** 2).sum(dim=-1, keepdim=True)**0.5 + 1e-9
    forces = k * ch1 * ch2 / delta_x_norm ** 2

    # the delta_x_norms were offset by a small number to avoid numeric problems
    # this is fine, when multiplying by delta_x, the self-self terms are zeroed out
    force_vectors = forces * delta_x / delta_x_norm
    a = force_vectors.sum(dim=2) / m

    X_list.append(torch.cat((pos, vel, ch, m), dim=-1))
    y_list.append(a)

    # simple 1 step - could use a more intelligent integrator here.
    vel += a * dt
    pos += vel * dt

X = torch.cat(X_list)
y = torch.cat(y_list)

# remove hidden objects
X = X[:,hidden:,:]
y = y[:,hidden:,:]

# add some random noise
y += 0.1 * torch.randn(y.shape) * y.mean()


In [9]:
ch

tensor([[[ 51.6732],
         [ 44.2164],
         [ -1.4977],
         ...,
         [ -1.2271],
         [-40.4799],
         [ -4.9338]],

        [[ 51.6732],
         [ 44.2164],
         [ -1.4977],
         ...,
         [ -1.2271],
         [-40.4799],
         [ -4.9338]],

        [[ 51.6732],
         [ 44.2164],
         [ -1.4977],
         ...,
         [ -1.2271],
         [-40.4799],
         [ -4.9338]],

        ...,

        [[ 51.6732],
         [ 44.2164],
         [ -1.4977],
         ...,
         [ -1.2271],
         [-40.4799],
         [ -4.9338]],

        [[ 51.6732],
         [ 44.2164],
         [ -1.4977],
         ...,
         [ -1.2271],
         [-40.4799],
         [ -4.9338]],

        [[ 51.6732],
         [ 44.2164],
         [ -1.4977],
         ...,
         [ -1.2271],
         [-40.4799],
         [ -4.9338]]])

In [12]:
class BaseModule(pl.LightningModule):
    def __init__(self):
        super(BaseModule, self).__init__()
        self.input_size = 3 # r, ch1, ch2
        self.output_size = 1
        self.loss = F.mse_loss # torch.log(F.mrse_loss) + angle loss
        self.lr = 1e-3
        self.wd = 0 #1e-5
        # relative mean weighted error - this wasn't helpful at all
        # self.loss = lambda y_hat, y: ((y_hat - y).abs() / (y.abs() + 1e-8)).mean()
        
        self.my_loggers = {
            'r_exp': lambda s: s.formula.weight[0][0].item(),
            'ch1_exp': lambda s: s.formula.weight[0][1].item(),
            'ch2_exp': lambda s: s.formula.weight[0][2].item()
        }
        

    def training_step(self, batch, batch_idx):
        X, y = batch
        y_hat = self.forward(X)

        loss = self.loss(y_hat, y)
        self.log('train_loss', loss.item(), on_epoch=True, on_step=False)

        # log learning terms
        for name, fx in self.my_loggers.items():
            self.log(name, fx(self), on_epoch=True, on_step=False)

        return loss

    def validation_step(self, batch, batch_idx):
        X, y = batch
        y_hat = self.forward(X)

        loss = self.loss(y_hat, y)
        self.log('validation_loss', loss.item(), on_epoch=True, on_step=False)
        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr, weight_decay=self.wd)
        return optimizer 

class GnnLogLinearModel(BaseModule):
    def __init__(self):
        super(GnnLogLinearModel, self).__init__()
        self.formula = torch.nn.Linear(self.input_size, self.output_size) 

    def forward(self, X):
        N = X.shape[1]
        xs = X[:,:,:d].unsqueeze(2).expand(-1,-1,N,-1)
        chs = X[:,:,-2:-1].unsqueeze(2).expand(-1,-1,N,-1)
        m = X[:,:,-1:]

        x1 = xs
        x2 = xs.transpose(1,2)

        delta_x = x1 - x2
        delta_x_norm = (delta_x ** 2).sum(dim=-1, keepdim=True)**0.5 + 1e-9

        ch1 = chs
        ch2 = chs.transpose(1,2)

        inp = torch.cat((delta_x_norm, ch1.abs(), ch2.abs()), dim=-1)

        inp_log = torch.log(inp) 

        # one linear layer
        forces_log = self.formula(inp_log) 

        forces = torch.exp(forces_log) * ch1.sign() * ch2.sign()

        # the delta_x_norms were offset by a small number to avoid numeric problems
        # this is fine, when multiplying by delta_x, the self-self terms are zeroed out
        force_vectors = forces * delta_x / delta_x_norm

        # later learn this directionality too (the -1)
        return force_vectors.sum(dim=2) / m

In [16]:
class GnnLogLinearChargeMassModel(BaseModule):
    def __init__(self, N=10, formula_given=False):
        super(GnnLogLinearChargeMassModel, self).__init__()
        #self.formula = torch.nn.Linear(self.input_size, self.output_size) 
        self.formula = torch.nn.Linear(self.input_size, self.output_size, bias=False) 
        if formula_given:
            #self.formula.weight.requires_grad_(False)
            self.formula.weight = torch.nn.Parameter(torch.tensor([[-2.0, 1.0, 1.0]]), requires_grad=False)

        # one could be problematic when taking the log, won't set the scale
        #self.fixed_mass = torch.nn.Parameter(torch.tensor([[[10.0]]]), requires_grad=False)
        #self.other_masses = torch.nn.Parameter(torch.rand(1, N-1, 1), requires_grad=True)
        #self.masses = torch.cat((self.fixed_mass, self.other_masses), dim=1).cuda()
        self.masses = torch.nn.Parameter(torch.rand(1, N, 1))
        self.charges = torch.nn.Parameter(torch.rand(1, N, 1) * 2 - 1) 
        #self.masses[0,0,0] = 1.0

    def forward(self, X):
        N = X.shape[1]
        batch_size = X.shape[0]
        xs = X[:,:,:d].unsqueeze(2).expand(-1,-1,N,-1)
        m = self.masses.expand(batch_size,-1,-1)
        chs = self.charges.expand(batch_size,-1,-1).unsqueeze(2).expand(-1,-1,N,-1)
        x1 = xs
        x2 = xs.transpose(1,2)


        delta_x = x1 - x2
        delta_x_norm = (delta_x ** 2).sum(dim=-1, keepdim=True)**0.5 + 1e-9

        ch1 = chs
        ch2 = chs.transpose(1,2)

        # also avoiding the use of .abs by squaring (as the formula can be learnt either way)
        inp = torch.cat((delta_x_norm, ch1 ** 2, ch2 ** 2), dim=-1)

        inp_log = torch.log(inp) 

        # one linear layer
        forces_log = self.formula(inp_log) 

        # this way we avoid using the sign and 
        forces = ch1 * ch2 * torch.exp(forces_log)

        # the delta_x_norms were offset by a small number to avoid numeric problems
        # this is fine, when multiplying by delta_x, the self-self terms are zeroed out
        force_vectors = forces * delta_x / delta_x_norm

        # later learn this directionality too (the -1)
        return force_vectors.sum(dim=2) / m


In [17]:
from torch.utils.data import DataLoader, Dataset, random_split
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.loggers import TensorBoardLogger

for mult in [1]: # [1,3,10,20]:

    train_set = list(zip(X, y))
    train_set_size = int(len(train_set) * 0.8)
    valid_set_size = len(train_set) - train_set_size
    train_set, valid_set = random_split(train_set, [train_set_size, valid_set_size], generator=torch.Generator().manual_seed(42))


    best_model = None
    best_score = 1e15
    times = 5

    for _ in range(times):
        #model = GnnLogLinearModelMult(mult=mult)
        #model = GnnLogLinearModel()
        model = GnnLogLinearChargeMassModel(10)
        #model = GnnLogLinearMassModel(10)
        #model = GnnLogLinearHiddenMassModel(10, e=10, formula_given=False)
        y_hat = model.forward(X)
        loss = model.loss(y_hat, y)
        if loss < best_score:
            print(loss)
            best_score = loss
            best_model = model


    model = best_model
    early_stop_callback = EarlyStopping(monitor="validation_loss", patience=30, verbose=False, mode="min")

    train_set = DataLoader(train_set, shuffle=True, batch_size=128)
    valid_set = DataLoader(valid_set, shuffle=True, batch_size=1000)

    logger = TensorBoardLogger("lightning_logs", name=f'coulomb_log_linear_charge_mass') # _masses, hidden_multiple

    # train with both splits
    trainer = pl.Trainer(gpus=1, max_epochs=10000,
                                #gradient_clip_val=0.5,
                                callbacks=[early_stop_callback],
                                logger=logger,
                                enable_progress_bar=False)

    trainer.fit(model, train_set, valid_set)

tensor(33055.9883, grad_fn=<MseLossBackward0>)
tensor(2.7744, grad_fn=<MseLossBackward0>)


GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Missing logger folder: lightning_logs\coulomb_log_linear_charge_mass
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name    | Type   | Params
-----------------------------------
0 | formula | Linear | 3     
-----------------------------------
23        Trainable params
0         Non-trainable params
23        Total params
0.000     Total estimated model params size (MB)
  rank_zero_warn(


In [None]:
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.loggers import TensorBoardLogger
early_stop_callback = EarlyStopping(monitor="validation_loss", patience=300, verbose=False, mode="min")

train_set = DataLoader(train_set, shuffle=True, batch_size=128)
valid_set = DataLoader(valid_set, shuffle=True, batch_size=1000)

logger = TensorBoardLogger("lightning_logs", name='gnn_log_linear_mult1') # _masses, hidden_multiple

# train with both splits
trainer = pl.Trainer(gpus=1, max_epochs=10000,
                            #gradient_clip_val=0.5,
                            callbacks=[early_stop_callback],
                            logger=logger)

trainer.fit(model, train_set, valid_set)


In [20]:
# it's almost impossible to disentangle these now, how should we now if it is close? (what if we ensure symmtery - exponent should be the same for the two charges?)
k * ch[0], model.charges ** 2 ** model.formula.weight[0][1]

(tensor([[ 25.8366],
         [ 22.1082],
         [ -0.7488],
         [ 21.0288],
         [ -0.7313],
         [  3.0027],
         [ 10.9575],
         [ -0.6136],
         [-20.2400],
         [ -2.4669]]),
 Parameter containing:
 tensor([[[-1.2047],
          [-1.1463],
          [ 0.1734],
          [-1.0718],
          [ 0.1755],
          [-0.3780],
          [-0.7624],
          [ 0.1659],
          [ 1.0736],
          [ 0.3332]]], requires_grad=True))

In [22]:
m[0], model.masses

(tensor([[ 62.9773],
         [ 74.8648],
         [ 88.6603],
         [ 14.6456],
         [ 60.0125],
         [ 10.1003],
         [ 48.7612],
         [  1.9424],
         [101.5216],
         [ 61.0661]]),
 Parameter containing:
 tensor([[[0.0697],
          [0.0886],
          [2.5736],
          [0.0202],
          [2.2849],
          [0.1049],
          [0.1239],
          [0.0823],
          [0.1410],
          [0.6899]]], requires_grad=True))

In [None]:
# compare fitted masses to ground truth
g * m[0],  model.masses ** model.formula.weight[0][2]

(tensor([[40.4088],
         [ 2.9725],
         [ 5.6881],
         [ 8.0561],
         [32.0705],
         [ 2.0293],
         [ 1.0138],
         [59.0399],
         [60.7772],
         [ 3.3891]]),
 tensor([[[39.0646],
          [ 2.9494],
          [ 5.6877],
          [ 8.0410],
          [31.9989],
          [ 2.0037],
          [ 0.9688],
          [58.7323],
          [60.7730],
          [ 3.3401]]], device='cuda:0', grad_fn=<PowBackward1>))

In [None]:
# to run tensorboard in the terminal:
# tensorboard --logdir lightning_logs