In [1]:
import config as cfg
import torch.nn as nn
from torch_geometric.nn import MLP
from torch_geometric.loader import DataLoader
from data.moldataset import MolDataset 
import torch

In [3]:
train_data = MolDataset(root=cfg.TRAIN_DIR)
train_data.data

Data(x=[22752, 24], edge_index=[2, 50672], edge_attr=[50672, 8], y=[1648], pos=[22764, 10, 3], atoms=[22752], u=[824, 790])

In [4]:
trian_loader = DataLoader(train_data, batch_size=cfg.BATCH_SIZE, shuffle=True, num_workers=cfg.NUM_WORKERS)

In [5]:
batch = next(iter(trian_loader))
batch


DataBatch(x=[52, 24], edge_index=[2, 116], edge_attr=[116, 8], y=[4], pos=[52, 10, 3], atoms=[52], u=[2, 790], batch=[52], ptr=[3])

In [6]:
f= nn.Flatten(start_dim=0, end_dim=-1)
f(batch.x).size(), batch.x.size()
len(batch.x)

52

In [12]:

class FeatureEmbedder(nn.Module):
    """_summary_
    """
    
    def __init__(
                self, input_mol_emb_dim=cfg.IN_MOL_DIM, input_node_emb_dim=cfg.NODE_DIM,
                out_emb_dim=cfg.OUT_EMB_DIM, **kwargs
                ):
        super(FeatureEmbedder, self).__init__()
        self.flatten = nn.Flatten(0, -1)
        self.mlp_u_emb = MLP(
            in_channels=input_mol_emb_dim,
            hidden_channels=128,
            out_channels=out_emb_dim,
            num_layers=3,
            act=torch.nn.SiLU(),
            norm='LayerNorm'
            )
                
        self.mlp_x_emb = MLP(
            in_channels=input_node_emb_dim,
            hidden_channels=128,
            out_channels=out_emb_dim,
            num_layers=3,
            act=torch.nn.SiLU(),
            norm='LayerNorm'
            )
            
    def forward(self, batch_data):
        
        batch_data.u = self.mlp_u_emb(batch_data.u)
        batch_data.x = self.mlp_x_emb(batch_data.x)
        
        return batch_data
        
        

In [13]:
model = FeatureEmbedder()
model.eval()

FeatureEmbedder(
  (flatten): Flatten(start_dim=0, end_dim=-1)
  (mlp_u_emb): MLP(790, 128, 128, 64)
  (mlp_x_emb): MLP(24, 128, 128, 64)
)

In [14]:
model(batch)

DataBatch(x=[52, 64], edge_index=[2, 116], edge_attr=[116, 8], y=[4], pos=[52, 10, 3], atoms=[52], u=[2, 64], batch=[52], ptr=[3])

In [None]:
class FeatureEmbedder(nn.Module):
    """_summary_
    """
    
    def __init__(
                self, batch_data,
                out_emb_dim=cfg.OUT_EMB_DIM, **kwargs
                ):
        super(FeatureEmbedder, self).__init__()
        self.flatten = nn.Flatten(0, -1)
        self.batch_data = batch_data
        self.mlp_u_emb = MLP(
            in_channels=len(self.flatten(self.batch_data.u)),
            hidden_channels=128,
            out_channels=out_emb_dim,
            num_layers=3,
            act=torch.nn.SiLU(),
            norm='LayerNorm'
            )
                
        self.mlp_x_emb = MLP(
            in_channels=len(self.flatten(self.batch_data.x)),
            hidden_channels=128,
            out_channels=out_emb_dim,
            num_layers=3,
            act=torch.nn.SiLU(),
            norm='LayerNorm'
            )
            
    def forward(self):
        
        u = self.flatten(self.batch_data.u)
        u = self.mlp_u_emb(u)
        self.batch_data.u = u
        
        x = self.flatten(self.batch_data.x)
        x = self.mlp_x_emb(x)
        self.batch_data.x = x
        
        return self.batch_data

In [75]:
model = FeatureEmbedder(batch)
model.eval()

FeatureEmbedder(
  (flatten): Flatten(start_dim=0, end_dim=-1)
  (mlp_u_emb): MLP(1580, 128, 128, 64)
  (mlp_x_emb): MLP(1128, 128, 128, 64)
)

In [77]:
model()

torch.Size([1580])


DataBatch(x=[64], edge_index=[2, 104], edge_attr=[104, 8], y=[4], pos=[47, 10, 3], atoms=[47], u=[64], batch=[47], ptr=[3])