# Adaptations of COATGIN and GNNS for Normal, 3d, and Line graph dual

In [1]:
__dataset_root_repo = '/data/pcqm4mv2_kpgt/'  # the kpgt content is there too
__codes_repo = '/home/shayan/phoenix/graphite/'
__conformers_filepath = '/data/conformers.np'

In [5]:
import torch
import torch.nn
import sys
sys.path.insert(0, __codes_repo)
from tqdm import tqdm
from torch.utils.data.dataloader import DataLoader
from torch_geometric.loader.dataloader import Collater
from graphite.data.pcqm4mv2.pyg import PCQM4Mv2DatasetFull
from graphite.data.pcqm4mv2.pyg.collator import collate_fn, default_collate_fn
from graphite.data.utilities.sequence_collate.utilities import pad_sequence, pad_sequence_2d
from graphite.data.pcqm4mv2.pyg.transforms import ComenetEdgeFeatures, LineGraphTransform
from torchvision.transforms import Compose
from graphite.cortex.model.model.gnn import CoAtGINGeneralPipeline
from graphite.utilities.miscellaneous import count_parameters

## Usual Graph

In [6]:
%%time
dataset = PCQM4Mv2DatasetFull(
    root=__dataset_root_repo,
    descriptor=True,
    fingerprint=True,
    conformers_memmap=__conformers_filepath,
    num_conformers_to_return=2,
    transform=Compose([
        ComenetEdgeFeatures(cutoff=3.0, edge_index_key='edge_index', concatenate_with_edge_attr=True)
    ]),
)

CPU times: user 1.43 s, sys: 6.5 s, total: 7.92 s
Wall time: 7.92 s


In [7]:
def collate_fn(batch):
    fingerprint = torch.stack([g.fingerprint for g in batch])
    molecule_descriptor = torch.stack([g.molecule_descriptor for g in batch])
    
    g = default_collate_fn(batch)
    del g.fingerprint
    del g.molecule_descriptor
    g.fingerprint = fingerprint
    g.molecule_descriptor = molecule_descriptor[:, 1:]
    return g

In [8]:
split_idx = dataset.get_idx_split()
train_dataset = dataset[split_idx["train"]]
valid_dataset = dataset[split_idx["valid"]]

test_dataset = dataset[split_idx["test-dev"]]

train_sampler = torch.utils.data.RandomSampler(train_dataset)
valid_sampler = torch.utils.data.SequentialSampler(valid_dataset)
test_sampler = torch.utils.data.SequentialSampler(test_dataset)

dataloader_args=dict(batch_size=32, collate_fn=collate_fn)

train_dataloader = DataLoader(train_dataset, sampler=train_sampler, **dataloader_args)
val_dataloader = DataLoader(valid_dataset, sampler=valid_sampler, **dataloader_args)
test_dataloader = DataLoader(test_dataset, sampler=test_sampler, **dataloader_args)

In [9]:
g = next(iter(train_dataloader))
g

DataBatch(edge_index=[2, 952], edge_attr=[952, 21], x=[461, 9], y=[32], batch=[461], ptr=[33], fingerprint=[32, 512], molecule_descriptor=[32, 200])

In [10]:
model = CoAtGINGeneralPipeline(
    node_encoder_config=dict(
        type='CoAtGIN',
        args=dict(
            num_layers=5,
            model_dim=256,
            conv_hop=2,
            conv_kernel=2,
            use_virt=True,
            use_att=True,
            line_graph=False,
            pos_features=18
        )
    ),
    graph_pooling="sum"
)
print(f"# of parameters: {count_parameters(model)}")

# of parameters: 6649088


In [11]:
graph_reps = model(g)
graph_reps.shape

torch.Size([32, 256])

In [12]:
model = CoAtGINGeneralPipeline(
    node_encoder_config=dict(
        type='GNNWithVirtualNode',
        args=dict(
            num_layers=5,
            model_dim=256,
            drop_ratio=0.5,
            JK="last",
            residual=True,
            gnn_type='gin',
            line_graph=False,
            pos_features=18
        )
    ),
    graph_pooling="sum"
)
print(f"# of parameters: {count_parameters(model)}")

# of parameters: 1281541


In [13]:
graph_reps = model(g)
graph_reps.shape

torch.Size([32, 256])

In [14]:
model = CoAtGINGeneralPipeline(
    node_encoder_config=dict(
        type='GNNWithVirtualNode',
        args=dict(
            num_layers=5,
            model_dim=256,
            drop_ratio=0.5,
            JK="last",
            residual=True,
            gnn_type='gcn',
            line_graph=False,
            pos_features=18
        )
    ),
    graph_pooling="sum"
)
print(f"# of parameters: {count_parameters(model)}")

# of parameters: 951296


In [15]:
graph_reps = model(g)
graph_reps.shape

torch.Size([32, 256])

In [16]:
model = CoAtGINGeneralPipeline(
    node_encoder_config=dict(
        type='GNN',
        args=dict(
            num_layers=5,
            model_dim=256,
            drop_ratio=0.5,
            JK="last",
            residual=True,
            gnn_type='gin',
            line_graph=False,
            pos_features=18
        )
    ),
    graph_pooling="sum"
)
print(f"# of parameters: {count_parameters(model)}")

# of parameters: 750853


In [17]:
graph_reps = model(g)
graph_reps.shape

torch.Size([32, 256])

In [18]:
model = CoAtGINGeneralPipeline(
    node_encoder_config=dict(
        type='GNN',
        args=dict(
            num_layers=5,
            model_dim=256,
            drop_ratio=0.5,
            JK="last",
            residual=True,
            gnn_type='gcn',
            line_graph=False,
            pos_features=18
        )
    ),
    graph_pooling="sum"
)
print(f"# of parameters: {count_parameters(model)}")

# of parameters: 420608


In [19]:
graph_reps = model(g)
graph_reps.shape

torch.Size([32, 256])

## Line Graph

In [20]:
%%time
dataset = PCQM4Mv2DatasetFull(
    root=__dataset_root_repo,
    descriptor=True,
    fingerprint=True,
    conformers_memmap=__conformers_filepath,
    num_conformers_to_return=2,
    transform=Compose([
        ComenetEdgeFeatures(cutoff=3.0, edge_index_key='edge_index', concatenate_with_edge_attr=True),
        LineGraphTransform(bring_in_adjacent_nodes=True, keep_as_is=['fingerprint', 'molecule_descriptor', 'y'])
    ]),
)

CPU times: user 1.04 s, sys: 6.32 s, total: 7.36 s
Wall time: 7.27 s


In [21]:
def collate_fn(batch):
    fingerprint = torch.stack([g.fingerprint for g in batch])
    molecule_descriptor = torch.stack([g.molecule_descriptor for g in batch])
    
    g = default_collate_fn(batch)
    del g.fingerprint
    del g.molecule_descriptor
    g.fingerprint = fingerprint
    g.molecule_descriptor = molecule_descriptor[:, 1:]
    return g

In [22]:
split_idx = dataset.get_idx_split()
train_dataset = dataset[split_idx["train"]]
valid_dataset = dataset[split_idx["valid"]]

test_dataset = dataset[split_idx["test-dev"]]

train_sampler = torch.utils.data.RandomSampler(train_dataset)
valid_sampler = torch.utils.data.SequentialSampler(valid_dataset)
test_sampler = torch.utils.data.SequentialSampler(test_dataset)

dataloader_args=dict(batch_size=32, collate_fn=collate_fn)

train_dataloader = DataLoader(train_dataset, sampler=train_sampler, **dataloader_args)
val_dataloader = DataLoader(valid_dataset, sampler=valid_sampler, **dataloader_args)
test_dataloader = DataLoader(test_dataset, sampler=test_sampler, **dataloader_args)

In [23]:
g_line = next(iter(train_dataloader))
g_line

DataBatch(x=[970, 39], edge_index=[2, 2300], num_nodes=970, y=[32], batch=[970], ptr=[33], fingerprint=[32, 512], molecule_descriptor=[32, 200])

In [24]:
model = CoAtGINGeneralPipeline(
    node_encoder_config=dict(
        type='CoAtGIN',
        args=dict(
            num_layers=5,
            model_dim=256,
            conv_hop=2,
            conv_kernel=2,
            use_virt=True,
            use_att=True,
            line_graph=True
        )
    ),
    graph_pooling="sum"
)
print(f"# of parameters: {count_parameters(model)}")

# of parameters: 6755620


In [25]:
graph_reps = model(g_line)
graph_reps.shape

torch.Size([32, 256])

In [26]:
model = CoAtGINGeneralPipeline(
    node_encoder_config=dict(
        type='CoAtGIN',
        args=dict(
            num_layers=5,
            model_dim=256,
            conv_hop=2,
            conv_kernel=2,
            use_virt=True,
            use_att=True,
            line_graph=True
        )
    ),
    graph_pooling="sum"
)
print(f"# of parameters: {count_parameters(model)}")

# of parameters: 6755620


In [27]:
graph_reps = model(g_line)
graph_reps.shape

torch.Size([32, 256])

In [28]:
model = CoAtGINGeneralPipeline(
    node_encoder_config=dict(
        type='GNNWithVirtualNode',
        args=dict(
            num_layers=5,
            model_dim=256,
            drop_ratio=0.5,
            JK="last",
            residual=True,
            gnn_type='gin',
            line_graph=True
        )
    ),
    graph_pooling="sum"
)
print(f"# of parameters: {count_parameters(model)}")

# of parameters: 1510953


In [29]:
graph_reps = model(g_line)
graph_reps.shape

torch.Size([32, 256])

In [30]:
model = CoAtGINGeneralPipeline(
    node_encoder_config=dict(
        type='GNNWithVirtualNode',
        args=dict(
            num_layers=5,
            model_dim=256,
            drop_ratio=0.5,
            JK="last",
            residual=True,
            gnn_type='gin',
            line_graph=True
        )
    ),
    graph_pooling="sum"
)
print(f"# of parameters: {count_parameters(model)}")

# of parameters: 1510953


In [31]:
graph_reps = model(g_line)
graph_reps.shape

torch.Size([32, 256])

In [32]:
model = CoAtGINGeneralPipeline(
    node_encoder_config=dict(
        type='GNN',
        args=dict(
            num_layers=5,
            model_dim=256,
            drop_ratio=0.5,
            JK="last",
            residual=True,
            gnn_type='gin',
            line_graph=True
        )
    ),
    graph_pooling="sum"
)
print(f"# of parameters: {count_parameters(model)}")

# of parameters: 980265


In [33]:
graph_reps = model(g_line)
graph_reps.shape

torch.Size([32, 256])

In [34]:
model = CoAtGINGeneralPipeline(
    node_encoder_config=dict(
        type='GNN',
        args=dict(
            num_layers=5,
            model_dim=256,
            drop_ratio=0.5,
            JK="last",
            residual=True,
            gnn_type='gcn',
            line_graph=True
        )
    ),
    graph_pooling="sum"
)
print(f"# of parameters: {count_parameters(model)}")

# of parameters: 650020


In [35]:
graph_reps = model(g_line)
graph_reps.shape

torch.Size([32, 256])