import pickle

import pandas as pd

from Datas import dataloader

# Name(s) of the target column(s) and the mini-batch size handed to the loader.
target_list = ['Result0']
batch_size = 64

# Raw table read from esol.csv in the current working directory.
df = pd.read_csv('esol.csv')

# dataloader returns four values; only the first one (the training loader)
# is kept, the remaining three are discarded.
train_loader, _, _, _ = dataloader(
    df,
    batch_size,
    target_list,
    shuffle=True,
    drop_last=False,
    dim=1,
)

In [122]:
# Sanity check: show the first mini-batch and the shape of its node-feature
# matrix, then stop iterating.
for data in train_loader:
    print(data)
    node_feature_shape = data.x.shape
    print(node_feature_shape)
    break


Batch(atom_feature_dim=[0], batch=[877], edge_feat=[1810, 10], edge_index=[2, 1810], x=[877, 49], y0=[64])
torch.Size([877, 49])


In [123]:
import pickle

# Persist the training loader to disk so later cells can reload it.
trainloader = {"train_loader": train_loader}
# Use a context manager so the file is flushed and closed deterministically
# (the bare open() call leaked the handle).
with open("trainloader.p", "wb") as handle:
    pickle.dump(trainloader, handle)

In [124]:
import torch
from torch import Tensor
from torch.nn import Linear, BatchNorm1d, Dropout
from torch.nn import Parameter as Param
import torch.nn.functional as F
from torch_geometric.nn import global_add_pool, EdgePooling
from torch_sparse import matmul
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.utils import softmax
from torch_geometric.typing import PairTensor, Adj, OptTensor, Size
from torch_scatter import scatter_add

from typing import Union

class AttentionAtomEmbedding(MessagePassing):
    """Attention-weighted atom (node) embedding layer.

    This layer performs only the atom embedding step; the molecule-level
    embedding is not done here.

    Args:
        atom_in_channels: Size of each input atom feature vector.
        bond_in_channels: Size of each input bond (edge) feature vector.
        fingerprint_dim: Size of the hidden and output embedding.
        dropout: Dropout probability applied to the projected neighbor
            features before the ``attend`` projection.
        bias: Whether the linear layers use an additive bias.
        debug: When True, print the shapes of intermediate tensors.
        **kwargs: Extra keyword arguments forwarded to ``MessagePassing``
            (for example ``aggr`` or ``flow``). ``aggr`` defaults to
            ``'add'``.
    """
    def __init__(self, atom_in_channels: int, bond_in_channels: int,  fingerprint_dim: int, dropout: float, bias: bool = True, debug: bool = False,  **kwargs):
        # Forward the extra options to the base class instead of silently
        # dropping them; 'add' stays the default aggregation.
        kwargs.setdefault('aggr', 'add')
        super(AttentionAtomEmbedding, self).__init__(**kwargs)

        self.atom_in_channels = atom_in_channels
        self.bond_in_channels = bond_in_channels
        self.fingerprint_dim = fingerprint_dim

        # Projects the central atom features alone to fingerprint_dim.
        self.atom_fc = Linear(atom_in_channels, fingerprint_dim, bias=bias)
        # Projects the concatenated neighbor atom + bond features.
        self.neighbor_fc = Linear(atom_in_channels + bond_in_channels, fingerprint_dim, bias=bias)
        # Maps a (central, neighbor) feature pair to one alignment score.
        self.align = Linear(2*fingerprint_dim,1, bias=bias)

        self.attend = Linear(fingerprint_dim, fingerprint_dim, bias=bias)
        self.debug = debug
        self.rnn =  torch.nn.GRUCell(fingerprint_dim, fingerprint_dim)
        self.dropout = Dropout(p=dropout)

    def forward(self, x: Union[Tensor,PairTensor], edge_index: Adj,
                edge_attr: OptTensor = None, size: Size = None) -> Tensor:
        """Run one round of message passing and return the aggregated result.

        Args:
            x: Node (atom) features.
            edge_index: Graph connectivity.
            edge_attr: Edge (bond) features. Annotated as optional, but
                ``message`` concatenates it with the neighbor features, so
                it has to be provided.
            size: Optional size hint forwarded to ``propagate``.

        Returns:
            The output of ``propagate``, i.e. the aggregated messages.
        """
        out = self.propagate(edge_index, x = x, edge_attr=edge_attr, size=size)
        return out

    def message(self, x_i, x_j, edge_index: Adj, edge_attr: OptTensor, size) -> Tensor:
        """Compute the per-edge message.

        The ``# line NN`` remarks are kept from the original author and
        refer to line numbers in the implementation this code was ported
        from.
        """
        # Up-project the central atom features to fingerprint_dim.
        atom_feature = F.leaky_relu(self.atom_fc(x_i)) # line 36
        if self.debug:
            print('a x_j:',x_j.shape,'x_i:',x_i.shape,'edge_attr:',edge_attr.shape)

        # Neighbor description = neighbor atom features + bond features.
        neighbor_feature = torch.cat([x_j, edge_attr], dim=-1) # line 43
        if self.debug:
            print('b neighbor_feature', neighbor_feature.shape)

        # Up-project the neighbor description to fingerprint_dim.
        neighbor_feature = F.leaky_relu(self.neighbor_fc(neighbor_feature)) # line 44
        if self.debug:
            print('c neighbor_feature', neighbor_feature.shape)

        # Pair every central atom with its neighbor for the alignment layer.
        feature_align = torch.cat([atom_feature, neighbor_feature], dim=-1) # line 61
        if self.debug:
            print('d feature_align',feature_align.shape)

        # One unnormalized alignment score per edge.
        align_score = F.leaky_relu(self.align(feature_align)) # line 63
        if self.debug:
            print('e align_score:', align_score.shape)

        # Normalize the scores with EdgePooling's softmax helper. The node
        # count is derived from the largest index present in edge_index.
        # NOTE(review): the grouping used for the softmax is defined by that
        # helper (presumably per target node) - confirm against its docs.
        attention_weight = EdgePooling.compute_edge_score_softmax(align_score, edge_index, edge_index.max().item() + 1)
        if self.debug:
            print('f attention_weight:', attention_weight.shape)

        neighbor_feature_transform = self.attend(self.dropout(neighbor_feature))
        if self.debug:
            print('g neighbor_feature_transform',neighbor_feature_transform.shape)

        # NOTE: the context is not scatter-summed per node here (line 74 of
        # the original); the returned messages are aggregated by propagate().
        context = torch.mul(attention_weight, neighbor_feature_transform)
        if self.debug:
            print('h context',context.shape)
        context = F.elu(context) # line 74

        # NOTE: the original code expands dimensions and applies a mask
        # before the GRU; this port does not.
        atom_feature = self.rnn(context, atom_feature) # line 77

        if self.debug:
            print('i atom_feature end message end:',atom_feature.shape)
        return atom_feature

class AttentiveMolEmbedding(torch.nn.Module):
    """Molecule-level embedding built on top of ``AttentionAtomEmbedding``.

    Args:
        radius: Accepted for API compatibility; currently unused (a single
            atom-embedding layer is always applied).
        T: Accepted for API compatibility; currently unused (a single
            molecule-level step is always applied).
        fingerprint_dim: Size of the hidden embedding, shared by the atom
            embedding layer and the molecule-level layers.
        dropout: Dropout probability for the molecule-level layers.
        debug: When True, print the shapes of intermediate tensors.
        atom_in_channels: Size of each input atom feature vector.
        bond_in_channels: Size of each input bond feature vector.
        atom_dropout: Dropout probability inside the atom embedding layer.
    """
    def __init__(self, radius:int, T: int, fingerprint_dim: int, dropout: float, debug: bool = False,
                 atom_in_channels: int = 49, bond_in_channels: int = 10, atom_dropout: float = 0.3):
        super(AttentiveMolEmbedding, self).__init__()
        self.mol_align = Linear(2*fingerprint_dim,1)
        self.mol_expand = Linear(1,fingerprint_dim)

        self.mol_attend = Linear(fingerprint_dim,fingerprint_dim)
        self.dropout = Dropout(p=dropout)
        self.debug = debug

        # The atom embedding must use the same fingerprint_dim as the layers
        # above; a hard-coded 200 here broke every other setting.
        self.atom_embedding = AttentionAtomEmbedding(
            atom_in_channels=atom_in_channels,
            bond_in_channels=bond_in_channels,
            fingerprint_dim=fingerprint_dim,
            dropout=atom_dropout,
            debug=debug,
        )

        self.rnn = torch.nn.GRUCell(fingerprint_dim, fingerprint_dim)
        self.output = Linear(fingerprint_dim,1)

    def forward(self, data):
        """Compute one output value per graph in the batch.

        Args:
            data: Batch object exposing ``x``, ``edge_index``, ``batch`` and
                ``edge_feat``.

        Returns:
            The per-node outputs summed per graph by ``global_add_pool``.
        """
        if self.debug:
            print('0 Go Run!')
        x, edge_index, batch, edge_attr = data.x, data.edge_index, data.batch, data.edge_feat

        # Radius = 1 for the moment: a single atom-embedding pass.
        activated_features =  self.atom_embedding(x, edge_index, edge_attr)

        if self.debug:
            # TODO(author's open question): message passing, why from 4 => 3 atoms?
            print('1 back to h_v back from AE:',activated_features.shape)

        # NOTE: the original code expands dimensions and applies a mask
        # before the GRU; this port does not.
        # Collapse each node's embedding to one scalar (column vector), then
        # expand it back to fingerprint_dim.
        mol_feature = torch.sum(activated_features, dim=-1, keepdim=True) # 113
        mol_feature_expanded = self.mol_expand(mol_feature)

        if self.debug:
            print('2 mol_feature',mol_feature_expanded.shape)

        mol_feature_expanded = F.relu(mol_feature_expanded) # 116
        if self.debug:
            print('3 activated_features', activated_features.shape)

        mol_cat_feature = torch.cat([mol_feature_expanded, activated_features], dim=-1)

        if self.debug:
            print('4 mol_cat_feature: ',mol_cat_feature.shape)

        mol_align = self.mol_align(mol_cat_feature)

        # Normalize the scores over the nodes of each graph. A plain softmax
        # over the last dimension acts on a size-1 axis and always yields 1,
        # which turned the attention into a no-op.
        mol_align_score = softmax(F.leaky_relu(mol_align), batch) # 127 ,129

        if self.debug:
            print('5 mol_align_score:',mol_align_score.shape)

        activated_features_transform = self.mol_attend(self.dropout(activated_features)) # 132
        if self.debug:
            print('6 activated_features_transform:', activated_features_transform.shape)

        mol_context = torch.mul(mol_align_score, activated_features_transform) # 134
        if self.debug:
            print('7 mol_context' ,mol_context.shape)
        mol_context = F.elu(mol_context) # line 136

        # NOTE: the original code expands dimensions and applies a mask
        # before the GRU; this port does not.
        mol_feature = self.rnn(mol_context, mol_feature_expanded) # 137
        if self.debug:
            print('8 mol_feature' ,mol_feature.shape)

        out = self.output(self.dropout(mol_feature)) # 142

        # Sum the per-node outputs into one value per graph.
        out = global_add_pool(out, batch)

        return out

In [125]:
# Reload the pickled loader dictionary; the context manager closes the file
# handle that a bare open() call would leak.
# NOTE: unpickling can execute arbitrary code - only load trusted files.
with open("trainloader.p", "rb") as handle:
    trainloader = pickle.load(handle)


In [126]:
# Pull the training loader back out of the reloaded dictionary.
train_loader = trainloader["train_loader"]

In [133]:
# Build the model: radius 1, T 1, fingerprint dimension 200.
model = AttentiveMolEmbedding(
    radius=1,
    T=1,
    fingerprint_dim=200,
    dropout=0.2,
    debug=False,
)

In [134]:
#model

In [137]:
# Time one forward pass of the model over every batch in the loader.
import time

# perf_counter() is a monotonic, high-resolution clock meant for measuring
# intervals; time.time() can jump when the system clock is adjusted.
start = time.perf_counter()
for data in train_loader:
    y = model(data)
stop = time.perf_counter()
print(stop-start)


0.4055018424987793
