In [None]:
%%capture
!pip install wandb
!apt-get install git
!apt autoremove
!pip3 install awscli

!mkdir -p /root/workspace/data/
!mkdir -p /root/workspace/out/

In [None]:
%%capture
%cd /root/workspace

!git clone https://github.com/chaitjo/geometric-gnn-dojo.git
!git clone https://github.com/Open-Catalyst-Project/ocp.git
!pip3 install -r ./steerable-v1/requirements.txt

In [None]:
%%capture
%cd /root/workspace/ocp/
!pip3 install -e .
!pip3 install lmdb
!pip3 install orjson

In [None]:
%cd /root/workspace/steerable-v1/
!git stash
!git pull

In [None]:
%cd /root/workspace/geometric-gnn-dojo/
!git stash
!git pull

In [None]:
# %%capture
%cd /root/workspace
!cp ./steerable-v1/train_utils.py ./geometric-gnn-dojo/experiments/utils/train_utils.py # remove once iclr is pulled

!cp ./steerable-v1/comenet.py ./geometric-gnn-dojo/models/ # remove once iclr is pulled
!echo "from models.comenet import ComENetModel" >> ./geometric-gnn-dojo/models/__init__.py

!cp ./steerable-v1/painn.py ./geometric-gnn-dojo/models/painn.py # remove once iclr is pulled
!echo "from models.painn import PaiNN" >> ./geometric-gnn-dojo/models/__init__.py

!cp ./steerable-v1/escn.py ./geometric-gnn-dojo/models/escn.py # remove once iclr is pulled
!echo "from models.escn import eSCN" >> ./geometric-gnn-dojo/models/__init__.py

!cp ./steerable-v1/equiformer_v2.py ./geometric-gnn-dojo/models/equiformer.py # remove once iclr is pulled
!echo "from models.equiformer import EquiformerV2_OC20" >> ./geometric-gnn-dojo/models/__init__.py

!cp ./steerable-v1/gemnet_t.py ./geometric-gnn-dojo/models/gemnet_t.py # remove once iclr is pulled
!echo "from models.gemnet_t import GemNetT" >> ./geometric-gnn-dojo/models/__init__.py

!cp ./steerable-v1/gemnet_q.py ./geometric-gnn-dojo/models/gemnet_q.py # remove once iclr is pulled
!echo "from models.gemnet_q import GemNetOC" >> ./geometric-gnn-dojo/models/__init__.py

!cp ./steerable-v1/_steerable.py ./geometric-gnn-dojo/models/ # remove once iclr is pulled
!cp ./steerable-v1/segnn.py ./geometric-gnn-dojo/models/ # remove once iclr is pulled
!echo "from models.segnn import SEGNN" >> ./geometric-gnn-dojo/models/__init__.py

# Models

# Datasets

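## Simple Chain Dataset

Each sample is a pair of $k$-chains with $k+2$ nodes:
- $k$ backbone nodes at $(0, 5i, 0)$ for $i = 0, \dots, k-1$;
- an end node at $(4, 5(k-1)+3, 0)$;
- a first node at $(-4, -3, 0)$ (label 0) or $(4, -3, 0)$ (label 1).

Both graphs in a pair are moved by the same random orthogonal matrix (from a QR decomposition) and the same translation. The two graphs therefore differ only in the position of their first node.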

In [None]:
import sys
sys.path.append('/root/workspace/geometric-gnn-dojo/')

import torch
import torch_geometric
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
from torch_geometric.utils import to_undirected
import e3nn
from functools import partial

from torch_geometric.seed import seed_everything

from experiments.utils.plot_utils import plot_3d

def _kchain_graph(k, first_x, label, Q, b, cell):
    """Build one undirected k-chain graph with k + 2 nodes, rigidly moved by x -> Q x + b.

    Node layout before the transform: a first node at (first_x, -3, 0), k backbone
    nodes at (0, 5*i, 0), and an end node at (4, 5*(k-1) + 3, 0). The two classes
    differ only in ``first_x``.
    """
    num_nodes = k + 2
    atoms = torch.zeros(num_nodes, dtype=torch.long)  # every node has atom type 0
    # Path graph 0-1-...-(k+1); made undirected below.
    edge_index = torch.stack([torch.arange(num_nodes - 1), torch.arange(1, num_nodes)])
    pos = torch.tensor(
        [[first_x, -3, 0]]
        + [[0, 5 * i, 0] for i in range(k)]
        + [[4, 5 * (k - 1) + 3, 0]],
        dtype=torch.float,
    )
    # Row-wise rigid transform: each position p becomes Q @ p + b.
    transf_pos = pos @ Q.T + b
    graph = Data(
        atoms=atoms,
        edge_index=edge_index,
        pos=transf_pos,
        y=torch.tensor([label], dtype=torch.long),
        natoms=num_nodes,
        cell=cell,
    )
    graph.edge_index = to_undirected(graph.edge_index)
    return graph


def create_kchains(k, n, seed=10):
    """Create ``n`` pairs of randomly transformed k-chain graphs.

    Args:
        k: number of backbone nodes (>= 2); every graph has k + 2 nodes.
        n: number of pairs (>= 1).
        seed: value passed to ``seed_everything`` before sampling (default 10, the
            previously hard-coded value).

    Returns:
        A list of 2 * n ``Data`` objects. Pair i sits at index 2*i (y=0, first node
        at x=-4) and index 2*i + 1 (y=1, first node at x=+4). Both graphs of a pair
        share one random orthogonal Q and translation b.
    """
    # Re-seeding on every call makes the generated dataset deterministic.
    seed_everything(seed)
    assert k >= 2
    assert n >= 1

    dataset = []
    for _ in range(n):
        # Random orthogonal matrix from the QR decomposition of a uniform 3x3 matrix.
        Q, _upper = torch.linalg.qr(torch.rand(3, 3), mode='complete')
        b = torch.rand(3)
        cell = torch.eye(3, dtype=torch.float).view(1, 3, 3)
        dataset.append(_kchain_graph(k, -4, 0, Q, b, cell))  # label 0
        dataset.append(_kchain_graph(k, 4, 1, Q, b, cell))   # label 1
    return dataset

# Create dataset: a single pair (y=0 and y=1) of k-chains, printed and plotted in 3D.
k = 4
dataset = create_kchains(k=k, n=1)
for data in dataset:
    print(data.pos)  # node coordinates after the random rigid transform
    plot_3d(data, lim=k * 2)

# Experiments

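## Simple Chain Experiment

`run(model_name, cutoff_name)` sweeps chain lengths $k \in \{2, 3, 4\}$ over several layer counts. For each setting it:
- generates 50 pairs;
- splits them by pair into train / validation / test (50% / 30% / the remaining 20%);
- calls `run_experiment` with `n_epochs=100` and `n_times=10`.

`cutoff_name` is the radius keyword the model expects; it is set to 5.1.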

In [None]:
# Create dataloaders
import random

from experiments.utils.train_utils import run_experiment
from models import SchNetModel, DimeNetPPModel, SphereNetModel, ComENetModel, GemNetT, GemNetOC, EGNNModel, GVPGNNModel, PaiNN, eSCN, EquiformerV2_OC20, MACEModel, TFNModel, SEGNN


# Fixed shuffle of pair ids; `run` slices it into train / val / test pairs.
total = 50  # number of k-chain pairs generated per (k, num_layers) setting
seed_everything(10)
permuted_g1 = list(range(total))
random.shuffle(permuted_g1)

print('split_pt1',permuted_g1)


# (chain length k, num_layers) settings swept by `run`.
_KCHAIN_SWEEP = (
    (2, 1), (2, 2), (2, 3),
    (3, 1), (3, 2), (3, 3), (3, 4),
    (4, 2), (4, 3), (4, 4), (4, 5), (4, 6),
)


def _pair_split(dataset, pair_ids):
    """Collect both graphs of every pair in ``pair_ids``: label-1 graphs first, then label-0.

    ``create_kchains`` stores pair i at indices 2*i (y=0) and 2*i + 1 (y=1), so both
    members of a pair always land in the same split.
    """
    return [dataset[2 * i + 1] for i in pair_ids] + [dataset[2 * i] for i in pair_ids]


def _print_label_counts(split_name, graphs):
    """Print a ``{label: count}`` summary for one split, as a class-balance sanity check."""
    labels = torch.cat([g.y for g in graphs])
    values, counts = labels.unique(return_counts=True)
    print(split_name, {v.item(): c.item() for v, c in zip(values, counts)})


def _model_factory(model_name):
    """Return the constructor registered under ``model_name``.

    Each entry is a model class or a ``partial`` fixing architecture hyper-parameters;
    ``run`` supplies ``num_layers``, ``in_dim``, ``out_dim`` and the optional cutoff keyword.

    Raises:
        KeyError: if ``model_name`` is not registered (the message lists valid names).
    """
    correlation = 2  # forwarded to MACEModel as `correlation`
    registry = {
        # INV
        "schnet": SchNetModel,
        "dimenet": DimeNetPPModel,
        "spherenet": SphereNetModel,
        "comenet": partial(ComENetModel, hidden_channels=128, num_output_layers=2),
        # Equiv
        "egnn": EGNNModel,
        "gvp": partial(GVPGNNModel, s_dim=32, v_dim=1),
        # Steerable
        "mace_1": partial(MACEModel, correlation=correlation, max_ell=1),
        "mace_2": partial(MACEModel, correlation=correlation, max_ell=2),
        "escn_1": partial(eSCN, lmax_list=[1], mmax_list=[1], hidden_channels=256),
        # NOTE(review): 4/9 looks like (1+1)^2 / (2+1)^2 spherical-harmonic coefficients,
        # presumably to keep size comparable with escn_1 — confirm.
        "escn_2": partial(eSCN, lmax_list=[2], mmax_list=[2], hidden_channels=(256 * 4 // 9)),
        "equiformer_0": partial(EquiformerV2_OC20, attn_hidden_channels=64, lmax_list=[0], mmax_list=[0]),
        "equiformer_1": partial(EquiformerV2_OC20, attn_hidden_channels=16, lmax_list=[1], mmax_list=[1]),
        "equiformer_2": partial(EquiformerV2_OC20, attn_hidden_channels=7, lmax_list=[2], mmax_list=[2]),
        # If Time
        "gemnet_t": GemNetT,
        "gemnet_q": GemNetOC,
        "painn": PaiNN,
        "tfn": TFNModel,
        "segnn": SEGNN,
    }
    if model_name not in registry:
        raise KeyError(f"Unknown model {model_name!r}; expected one of {sorted(registry)}")
    return registry[model_name]


def run(model_name, cutoff_name=None, cutoff=5.1, device='cuda', n_epochs=100, n_times=10,
        sweep=_KCHAIN_SWEEP):
    """Train and evaluate ``model_name`` on k-chain data for every (k, num_layers) in ``sweep``.

    Reads the notebook globals ``total`` (pairs per dataset) and ``permuted_g1``
    (shuffled pair ids). Pairs are split 50% / 30% / remainder into train / val / test.

    Args:
        model_name: registry key, e.g. 'schnet' or 'mace_2'.
        cutoff_name: the model's radius keyword ('cutoff', 'r_max', 'max_radius'); when
            None, no radius keyword is passed.
        cutoff: value passed under ``cutoff_name``. Consecutive chain nodes are exactly 5
            apart and non-adjacent nodes are at least ~8.9 apart, so 5.1 keeps only bonds.
        device, n_epochs, n_times: forwarded to ``run_experiment``.
        sweep: iterable of (k, num_layers) settings.

    Returns:
        One dict per setting with keys 'k', 'num_layers', 'best_val_acc', 'test_acc' and
        'train_time', holding the values returned by ``run_experiment``.

    Raises:
        KeyError: for an unknown ``model_name``, before any data is generated.
    """
    build_model = _model_factory(model_name)
    kwargs = {cutoff_name: cutoff} if cutoff_name else {}

    train_n = int(.5 * total)
    val_n = int(.3 * total)
    train_ids = permuted_g1[:train_n]
    val_ids = permuted_g1[train_n:train_n + val_n]
    test_ids = permuted_g1[train_n + val_n:]  # the remaining ~20%

    results = []
    for k, num_layers in sweep:
        dataset = create_kchains(k=k, n=total)

        train_data = _pair_split(dataset, train_ids)
        val_data = _pair_split(dataset, val_ids)
        test_data = _pair_split(dataset, test_ids)
        train_loader = DataLoader(train_data, batch_size=32, shuffle=True)
        val_loader = DataLoader(val_data, batch_size=32, shuffle=False)
        test_loader = DataLoader(test_data, batch_size=32, shuffle=False)

        for split_name, split in (('train', train_data), ('val', val_data), ('test', test_data)):
            _print_label_counts(split_name, split)
        print(f"\nNumber of layers: {num_layers}")
        print(f"Chain Length: {k}")

        model = build_model(num_layers=num_layers, in_dim=1, out_dim=2, **kwargs)
        best_val_acc, test_acc, train_time = run_experiment(
            model,
            train_loader,
            val_loader,
            test_loader,
            n_epochs=n_epochs,
            n_times=n_times,
            verbose=False,
            device=device,
        )
        results.append({
            'k': k,
            'num_layers': num_layers,
            'best_val_acc': best_val_acc,
            'test_acc': test_acc,
            'train_time': train_time,
        })
    return results


In [None]:
# SCHNET
run('schnet','cutoff')

In [None]:
# DIMENET
run('dimenet','cutoff')

In [None]:
# SPHERENET
run('spherenet','cutoff')

In [None]:
# COMENET
run('comenet','cutoff')

In [None]:
#EGNN
run('egnn')

In [None]:
#GVP
run('gvp','r_max')

In [None]:
#eSCN
run('escn_1','cutoff')

In [None]:
#eSCN
run('escn_2','cutoff')

In [None]:
#MACE
run('mace_1','r_max')

In [None]:
#MACE
run('mace_2','r_max')


In [None]:
#Equiformer
run('equiformer_0','max_radius')

In [None]:
#Equiformer
run('equiformer_1','max_radius')

In [None]:
#Equiformer
run('equiformer_2','max_radius')

In [None]:
#Equiformer
run('equiformer_2','max_radius')

In [None]:
#PaiNN
run('painn','cutoff')

In [None]:
#GemNetT
run('gemnet_t','cutoff')

In [None]:
#GemNetQ
run('gemnet_q','cutoff')

In [None]:
#TFN
run('tfn','r_max')

In [None]:
#SEGNN
run('segnn')