In [None]:
import pandas as pd
import torch
import torch.nn as nn
import torch_geometric.nn as pygnn
from torch.utils.data import Dataset

from graphene.dataloader.featurizers import MoleculeFeaturizer

In [8]:
# Featurizer that turns a SMILES string into a graph object; allow_unknown=True is
# passed straight through to MoleculeFeaturizer.
feat = MoleculeFeaturizer(allow_unknown=True)

# NOTE(review): these constants are not referenced by any later cell visible here —
# ExampleModel hard-codes the same 30 / 64 as its own defaults.
IN_CHANNELS = 30
OUT_CHANNELS = 64 # output of the second layer

# Featurize propane as a quick look at the shapes of the returned graph.
feat('CCC')

Data(edge_attr=[4, 11], edge_index=[2, 4], x=[3, 30])

In [22]:
class ExampleDataset(Dataset):
    """Map-style dataset that featurizes one SMILES string per CSV row into a graph.

    Args:
        path: CSV file loaded with ``pd.read_csv``; it must contain a ``smiles`` column.
        target: Optional name of the column whose value is attached to each graph
            as ``y``. When omitted (or empty) graphs are returned without ``y``.
        **kwargs: Forwarded unchanged to ``MoleculeFeaturizer``.

    Raises:
        KeyError: If the ``smiles`` column, or the requested ``target`` column, is
            absent from the CSV.
    """

    def __init__(self, path: str, target: str = None, **kwargs):
        self.feat = MoleculeFeaturizer(**kwargs)
        self.data = pd.read_csv(path)
        self.target = target

        # Fail fast with a readable message instead of a bare KeyError raised from
        # inside a DataLoader the first time an item is fetched.
        required = ['smiles'] + ([target] if target else [])
        missing = [column for column in required if column not in self.data.columns]
        if missing:
            raise KeyError(f"{path!r} is missing required column(s): {missing}")

    def __len__(self):
        """Return the number of rows (molecules) in the CSV."""
        return len(self.data)

    def __getitem__(self, idx):
        """Featurize row ``idx`` and, when a target is configured, attach it as ``y``."""
        row = self.data.iloc[idx]
        graph = self.feat(row['smiles'])

        if self.target:
            # One-element tensor per graph.
            graph.y = torch.tensor([row[self.target]])

        return graph

In [25]:
# Build the dataset from 'zinc_subset.csv' with the 'mwt' column as the target;
# allow_unknown=True travels through **kwargs to MoleculeFeaturizer.
dataset = ExampleDataset(path='zinc_subset.csv', target='mwt', allow_unknown=True)
dataset

<__main__.ExampleDataset at 0x1bcb23825b0>

In [29]:
# PyTorch Geometric 2.0 moved DataLoader to torch_geometric.loader and deprecated the
# torch_geometric.data alias; prefer the new location and fall back on older installs.
try:
    from torch_geometric.loader import DataLoader
except ImportError:
    from torch_geometric.data import DataLoader

# Collate 20 molecule graphs into each mini-batch.
dataloader = DataLoader(dataset, batch_size=20)

In [35]:
# Pull the first mini-batch from the loader to inspect what collation produces.
next(iter(dataloader))

Batch(batch=[413], edge_attr=[882, 11], edge_index=[2, 882], ptr=[21], x=[413, 30], y=[20])

In [50]:
from torch_geometric.nn import global_add_pool


class GlobalAddPool(nn.Module):
    """``nn.Module`` wrapper around ``global_add_pool`` so pooling can live in a model as a layer."""

    def __init__(self):
        super().__init__()
        self.func = global_add_pool

    def forward(self, x, edge_list):
        """Pool node features into one row per graph.

        Args:
            x: Node feature matrix.
            edge_list: Second argument handed to ``global_add_pool``. Despite the
                name, ``ExampleModel`` passes the batch-assignment vector
                (``batch.batch``) here, not edge indices.

        Returns:
            Whatever ``global_add_pool`` returns for ``(x, edge_list)``.
        """
        # The pooling function is stored as ``self.func``; the module has no
        # ``global_add_pool`` attribute, so looking it up on ``self`` raised
        # AttributeError on every forward pass.
        return self.func(x, edge_list)


class ExampleModel(nn.Module):
    """Graph-level prediction model: stacked GIN layers, sum pooling, then a linear head.

    Args:
        num_layers: Number of ``GINConv`` layers to stack.
        in_channels: Width of the input node-feature vectors.
        hidden_channels: Output width of every GIN layer's MLP.
        out_channels: Number of values predicted per graph.
        **kwargs: Accepted and ignored.
    """

    def __init__(self, num_layers: int = 4, in_channels: int = 30, hidden_channels: int = 64, out_channels: int = 1, **kwargs):
        super().__init__()

        self.pool_function = GlobalAddPool()

        layers = []
        for _ in range(num_layers):
            mlp = nn.Sequential(
                nn.Linear(in_channels, hidden_channels),
                nn.ReLU(),
                nn.Linear(hidden_channels, hidden_channels)
            )

            layers.append(pygnn.GINConv(mlp, train_eps=True))
            # Every layer after the first consumes the previous layer's output.
            in_channels = hidden_channels

        self.layers = nn.ModuleList(layers)
        # ``in_channels`` now equals the width of the features reaching the head
        # (``hidden_channels`` after at least one layer, the raw input width when
        # ``num_layers == 0``). The head honours ``out_channels`` instead of a
        # hard-coded 1.
        self.out = nn.Linear(in_features=in_channels, out_features=out_channels)

    def forward(self, batch):
        """Return one prediction row per graph in ``batch``.

        Args:
            batch: Object exposing ``x``, ``edge_index`` and ``batch`` attributes.

        Returns:
            Output of the linear head applied to the pooled node features.
        """
        # Use a distinct local name so the ``batch`` argument is not shadowed.
        x, edge_index, graph_index = batch.x, batch.edge_index, batch.batch

        for layer in self.layers:
            x = layer(x, edge_index)

        pooled = self.pool_function(x, graph_index)
        return self.out(pooled)

In [51]:
# Instantiate the model with its default hyper-parameters; the bare name below
# displays its module tree.
model = ExampleModel()
model

ExampleModel(
  (pool_function): GlobalAddPool()
  (layers): ModuleList(
    (0): GINConv(nn=Sequential(
      (0): Linear(in_features=30, out_features=64, bias=True)
      (1): ReLU()
      (2): Linear(in_features=64, out_features=64, bias=True)
    ))
    (1): GINConv(nn=Sequential(
      (0): Linear(in_features=64, out_features=64, bias=True)
      (1): ReLU()
      (2): Linear(in_features=64, out_features=64, bias=True)
    ))
    (2): GINConv(nn=Sequential(
      (0): Linear(in_features=64, out_features=64, bias=True)
      (1): ReLU()
      (2): Linear(in_features=64, out_features=64, bias=True)
    ))
    (3): GINConv(nn=Sequential(
      (0): Linear(in_features=64, out_features=64, bias=True)
      (1): ReLU()
      (2): Linear(in_features=64, out_features=64, bias=True)
    ))
  )
  (out): Linear(in_features=64, out_features=1, bias=True)
)