The encoder architecture is now ready. Next we need to build the decoder, which takes the encoder's output as its input and predicts the reaction conditions: in this case the reagent, the solvent and the catalyst.

# Importing Libraries

In [None]:
!pip install torch-geometric rdkit-pypi

In [2]:
from rdkit import Chem
from rdkit.Chem import Draw
from rdkit.Chem import AllChem

import torch
import torch.nn as nn
import torch.nn.functional as F

from torch_geometric.data import Data
from torch_geometric.nn import GATConv
from torch_geometric.data import DataLoader

import numpy as np
import matplotlib.pyplot as plt
from PIL import Image

In [3]:
# Compute device shared by the notebook: the GPU when CUDA is available,
# otherwise the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Features File

In [4]:
def featurize_molecule(mol):
    """Compute one atom-centred Morgan fingerprint per atom of ``mol``.

    Args:
        mol: An RDKit molecule.

    Returns:
        A NumPy array with one entry per atom. Each entry is the radius-2
        Morgan bit vector built only from environments rooted at that atom
        (the vector length is RDKit's default ``nBits``).
    """
    # RDKit restricts a Morgan fingerprint to given root atoms through the
    # ``fromAtoms`` keyword; ``atomIndices`` is not an accepted argument and
    # makes the call fail.
    atom_features = [
        np.array(
            AllChem.GetMorganFingerprintAsBitVect(
                mol, radius=2, fromAtoms=[atom.GetIdx()]
            )
        )
        for atom in mol.GetAtoms()
    ]

    return np.array(atom_features)

In [5]:
def smiles_to_graph(smiles):
    """Convert a SMILES string into a PyTorch Geometric ``Data`` graph.

    Hydrogens are made explicit, so every hydrogen is a node of the graph.

    Args:
        smiles: SMILES string of a single compound.

    Returns:
        ``Data`` with
        * ``x``: float tensor of shape (num_atoms, 3) holding, per atom, the
          atomic number, the number of attached hydrogens and the formal
          charge;
        * ``edge_index``: long tensor of shape (2, num_directed_edges) with
          both directions of every bond.

    Raises:
        ValueError: If RDKit cannot parse ``smiles``.
    """
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        # MolFromSmiles signals a parse failure by returning None; fail here
        # with a clear message instead of an AttributeError further down.
        raise ValueError(f"Could not parse SMILES string: {smiles!r}")

    # Add explicit hydrogens
    mol = Chem.AddHs(mol)

    # No 3D embedding is computed: the graph below only uses topology and
    # per-atom properties, so conformer coordinates would go unused.

    num_atoms = mol.GetNumAtoms()
    atom_features = np.zeros((num_atoms, 3))
    adjacency_matrix = np.zeros((num_atoms, num_atoms))

    for bond in mol.GetBonds():
        i = bond.GetBeginAtomIdx()
        j = bond.GetEndAtomIdx()
        adjacency_matrix[i, j] = adjacency_matrix[j, i] = 1  # undirected bond

    for atom in mol.GetAtoms():
        idx = atom.GetIdx()
        atom_features[idx, 0] = atom.GetAtomicNum()
        # After AddHs the hydrogens are neighbouring atoms, which
        # GetTotalNumHs() ignores by default (it would always report 0).
        atom_features[idx, 1] = atom.GetTotalNumHs(includeNeighbors=True)
        atom_features[idx, 2] = atom.GetFormalCharge()

    atom_features = torch.tensor(atom_features, dtype=torch.float)

    # Non-zero adjacency entries, as (num_directed_edges, 2) index pairs.
    edge_pairs = np.column_stack(np.where(adjacency_matrix))
    edge_index = torch.tensor(edge_pairs, dtype=torch.long)

    # PyTorch Geometric expects edge_index with shape (2, num_edges).
    return Data(x=atom_features, edge_index=edge_index.t().contiguous())

# GCN

In [6]:
class PositionalEncoding(nn.Module):
    """Learned positional embedding that is added to a batched sequence.

    Args:
        input_size: Feature width of the sequence elements (and of the
            positional embedding).
        max_len: Number of positions in the embedding table; sequences may
            not be longer than this.
    """

    def __init__(self, input_size, max_len=1000):
        super().__init__()
        self.encoding = nn.Embedding(max_len, input_size)

    def forward(self, x):
        """Return ``x`` plus the embedding of each position along dim 1.

        ``x.size(0)`` is used as the batch size and ``x.size(1)`` as the
        sequence length.
        """
        batch_size, seq_len = x.size(0), x.size(1)
        # One row of position ids 0..seq_len-1, repeated for every batch item.
        position_row = torch.arange(seq_len, device=x.device)
        position_grid = position_row.unsqueeze(0).expand(batch_size, -1)
        return x + self.encoding(position_grid)

In [7]:
class DistanceAttentionEncoder(nn.Module):
    """Pool a sequence into one vector with learned additive attention.

    Args:
        input_size: Feature width of the incoming sequence elements.
        hidden_size: Width of the encoded elements and of the returned
            context vector.
    """

    def __init__(self, input_size, hidden_size):
        super().__init__()

        self.embedding = PositionalEncoding(input_size)
        self.encoder = nn.Linear(input_size, hidden_size)
        self.decoder = nn.Linear(hidden_size, 1)
        # Normalises the scores across dim 1 (the sequence axis in forward).
        self.softmax = nn.Softmax(dim=1)

    def pairwise_distances(self, x):
        """Return the L2 norm of ``x[:, None, :] - x`` over the last axis.

        Not called by ``forward``.
        """
        differences = x[:, None, :] - x
        return differences.norm(p=2, dim=-1)

    def forward(self, input_sequence):
        """Return the attention-weighted sum of the encoded sequence.

        ``input_sequence`` is treated as (batch, sequence, input_size); the
        result has shape (batch, hidden_size).
        """
        # Positional embedding followed by the linear encoder.
        encoded = self.encoder(self.embedding(input_sequence))

        # One scalar score per position, normalised across the sequence.
        weights = self.softmax(self.decoder(torch.tanh(encoded)))

        # Weighted sum over the sequence axis.
        return (encoded * weights).sum(dim=1)

In [8]:
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product self-attention over a set of nodes.

    Every node attends to every other node. ``out_channels`` is split evenly
    across the heads, so the output keeps ``out_channels`` features per node.

    Args:
        in_channels: Feature width of the input nodes.
        out_channels: Feature width of the output (all heads combined).
        attention_heads: Number of heads. ``None`` is accepted and treated
            as a single head.

    Raises:
        ValueError: If ``attention_heads`` is smaller than 1 or does not
            divide ``out_channels``.
    """

    def __init__(self, in_channels, out_channels, attention_heads=1):
        super().__init__()
        if attention_heads is not None:
            if attention_heads < 1:
                raise ValueError("attention_heads must be at least 1.")
            if out_channels % attention_heads != 0:
                raise ValueError(
                    "out_channels must be divisible by attention_heads."
                )
        self.W_q = nn.Linear(in_channels, out_channels)
        self.W_k = nn.Linear(in_channels, out_channels)
        self.W_v = nn.Linear(in_channels, out_channels)
        self.attention_heads = attention_heads

    def forward(self, x):
        """Apply self-attention across the node axis (second-to-last dim).

        Args:
            x: Tensor of shape (num_nodes, in_channels); extra leading
                dimensions are treated as batch dimensions.

        Returns:
            Tensor reshaped to ``(x.size(0), -1)``; for 2-D input this is
            (num_nodes, out_channels).
        """
        num_heads = 1 if self.attention_heads is None else self.attention_heads

        q = self.W_q(x)
        k = self.W_k(x)
        v = self.W_v(x)

        head_dim = q.size(-1) // num_heads

        def split_heads(t):
            # (..., nodes, out) -> (..., heads, nodes, head_dim). The head
            # axis is carved out of the feature axis, never the node axis,
            # so any number of nodes is supported.
            return t.reshape(*t.shape[:-1], num_heads, head_dim).transpose(-3, -2)

        q, k, v = split_heads(q), split_heads(k), split_heads(v)

        # Scaled dot-product attention between all pairs of nodes, per head.
        scores = torch.matmul(q, k.transpose(-2, -1)) / (head_dim ** 0.5)
        attention_weights = F.softmax(scores, dim=-1)
        attended = torch.matmul(attention_weights, v)

        # (..., heads, nodes, head_dim) -> (..., nodes, heads * head_dim)
        merged = attended.transpose(-3, -2).reshape(*x.shape[:-1], -1)

        return merged.reshape(x.size(0), -1)

In [9]:
class GATModel(nn.Module):
    """Two-layer graph attention encoder with optional external attention.

    Args:
        in_channels: Number of input features per node.
        hidden_channels: Per-head width of the first GAT layer.
        out_channels: Width of the returned node embeddings.
        heads: Number of attention heads in the first GAT layer.
        external_attention_heads: When not ``None``, a ``MultiHeadAttention``
            block with this many heads runs on the first layer's output and
            its result is concatenated to it before the second GAT layer.
            When ``None`` the block is skipped and ``external_attention`` is
            ``None``.
    """

    def __init__(self, in_channels, hidden_channels, out_channels, heads=1, external_attention_heads=None):
        super().__init__()

        gat_width = hidden_channels * heads
        self.conv1 = GATConv(in_channels, hidden_channels, heads=heads)

        self.external_attention_heads = external_attention_heads
        if external_attention_heads is not None:
            self.external_attention = MultiHeadAttention(
                gat_width, hidden_channels, attention_heads=external_attention_heads
            )
            # forward() concatenates the attention output (hidden_channels
            # wide) to the GAT output, so conv2 must accept both.
            conv2_in_channels = gat_width + hidden_channels
        else:
            self.external_attention = None
            conv2_in_channels = gat_width

        self.conv2 = GATConv(conv2_in_channels, out_channels, heads=1)

        # Built once here so its weights are registered with the model, move
        # with .to(device) and persist between forward passes. conv2 has a
        # single head, so its output width is out_channels.
        self.distance_attention_encoder = DistanceAttentionEncoder(out_channels, hidden_size=64)

    def forward(self, data, return_graph_embedding=False):
        """Encode a graph into node embeddings.

        Args:
            data: Graph object exposing ``x`` (node features) and
                ``edge_index``, on the same device as the model.
            return_graph_embedding: When True, also return the attention-
                pooled graph vector from ``distance_attention_encoder``.

        Returns:
            Node embeddings of shape (num_nodes, out_channels), or the tuple
            ``(node_embeddings, graph_embedding)`` when
            ``return_graph_embedding`` is True.
        """
        x, edge_index = data.x, data.edge_index

        # First GAT layer
        x = torch.relu(self.conv1(x, edge_index))

        if self.external_attention is not None:
            # Concatenate GAT output and external attention output
            x = torch.cat([x, self.external_attention(x)], dim=-1)

        # Second GAT layer
        x = self.conv2(x, edge_index)

        if not return_graph_embedding:
            # The pooled vector is not part of the default return value, so
            # it is only computed on request.
            return x

        # The pooling encoder expects a leading batch dimension.
        graph_embedding = self.distance_attention_encoder(x.unsqueeze(0)).squeeze(0)
        return x, graph_embedding

# Separating compounds in the SMILES

In [10]:
def separate_compounds(smiles_reaction):
    """Split a reaction SMILES of the form ``reactant>>product``.

    Args:
        smiles_reaction: Reaction string containing exactly one ``>>``.

    Returns:
        Tuple ``(reactant, product)`` with surrounding whitespace removed.

    Raises:
        ValueError: If the string does not contain exactly one ``>>``
            separator, or if either side of it is empty.
    """
    compounds = smiles_reaction.split(">>")

    if len(compounds) != 2:
        raise ValueError("Invalid SMILES reaction format. Expected one '>>' separator.")

    reactant, product = (compound.strip() for compound in compounds)

    # "CCO>>" or ">>CCO" split into two parts but leave one side empty, which
    # would otherwise surface later as an empty molecule graph.
    if not reactant or not product:
        raise ValueError("Invalid SMILES reaction format. Reactant and product must both be non-empty.")

    return reactant, product

# Example reaction SMILES in "reactant>>product" form; the product side is the
# reactant without the benzyl group (Cc2ccccc2) on the ring nitrogen.
smiles_reaction = "O=C1CCCN1C1CCN(Cc2ccccc2)CC1>>O=C1CCCN1C1CCNCC1"

# Split the reaction at ">>" into its two sides.
reactant, product = separate_compounds(smiles_reaction)

# Show each side on its own line.
print("Reactant:", reactant)
print("Product:", product)


Reactant: O=C1CCCN1C1CCN(Cc2ccccc2)CC1
Product: O=C1CCCN1C1CCNCC1


# Cross Attention for both
Concatenate the embeddings of both compounds and apply cross attention to map them into a single embedding space.

In [11]:
def concatenate_with_cross_attention(emb1, emb2, out_channels, heads):
    """Fuse two node-embedding matrices and run multi-head attention on them.

    The embedding with fewer rows is resized along the node axis with a
    linear layer so both have the same number of rows; the two are then
    concatenated along the feature axis and passed through
    ``nn.MultiheadAttention`` as query, key and value.

    NOTE: the resizing layer and the attention layer are created inside this
    function, so they carry fresh random weights on every call and are not
    trained.

    Args:
        emb1: Tensor of shape (n1, out_channels).
        emb2: Tensor of shape (n2, out_channels), on the same device and with
            the same dtype as ``emb1``.
        out_channels: Feature width of each embedding; the attention layer
            uses ``2 * out_channels`` as its embedding size.
        heads: Number of attention heads.

    Returns:
        Tensor of shape (max(n1, n2), 2 * out_channels).
    """
    # Layers created here default to CPU/float32; align them with the inputs
    # so the function also works for GPU or double-precision embeddings.
    device, dtype = emb1.device, emb1.dtype
    rows1, rows2 = emb1.shape[0], emb2.shape[0]

    # Resize the embedding with fewer rows to the row count of the other one.
    # The linear layer acts on the node axis, hence the transposes.
    if rows1 > rows2:
        linear_projection = nn.Linear(rows2, rows1).to(device=device, dtype=dtype)
        emb2 = linear_projection(emb2.T).T
    elif rows1 < rows2:
        linear_projection = nn.Linear(rows1, rows2).to(device=device, dtype=dtype)
        emb1 = linear_projection(emb1.T).T

    # Concatenate the embeddings along the feature dimension
    concatenated_emb = torch.cat((emb1, emb2), dim=1)

    multihead_attention = nn.MultiheadAttention(
        embed_dim=2 * out_channels, num_heads=heads
    ).to(device=device, dtype=dtype)
    cross_attended_emb, _ = multihead_attention(
        concatenated_emb, concatenated_emb, concatenated_emb
    )

    return cross_attended_emb

In [18]:
def get_cross_attention_output(smiles_string, hidden_channels=64, out_channels=32, heads=2):
    """Encode a reaction SMILES into one fused embedding matrix.

    The reactant and the product are each turned into a graph, encoded by
    their own ``GATModel`` and then fused with
    ``concatenate_with_cross_attention``.

    NOTE: both GAT models are constructed inside this function, so every call
    uses freshly initialised (untrained) weights.

    Args:
        smiles_string: Reaction SMILES of the form ``reactant>>product``.
        hidden_channels: Per-head hidden width of each GAT model.
        out_channels: Output width of each GAT model.
        heads: Number of heads, used both for the GAT models and for the
            cross attention.

    Returns:
        Tensor of shape (max(num_reactant_atoms, num_product_atoms),
        2 * out_channels).
    """
    reactant, product = separate_compounds(smiles_string)

    # The models are moved to `device`, so the graphs must live there too;
    # otherwise the forward pass fails whenever CUDA is selected.
    graphs = [smiles_to_graph(compound).to(device) for compound in (reactant, product)]

    embeddings = []
    for graph in graphs:
        in_channels = graph.x.size(1)  # Number of input features per node
        gat_model = GATModel(in_channels, hidden_channels, out_channels, heads).to(device)
        embeddings.append(gat_model(graph))
    output_reactant, output_product = embeddings

    return concatenate_with_cross_attention(
        output_reactant, output_product, out_channels=out_channels, heads=heads
    )

In [19]:
# Run the reactant/product encoder and cross attention on an example reaction.
smiles_string = "O=C1CCCN1C1CCN(Cc2ccccc2)CC1>>O=C1CCCN1C1CCNCC1"
out=get_cross_attention_output(smiles_string)

In [14]:
print(out.shape)

torch.Size([41, 64])


# Vocab Size

# Decoder
This decodes the output for the given encoded input. I am using a decoder with three output heads for this task because we have to predict three different things in this case.

In [64]:
class TransformerDecoder(nn.Module):
    """Transformer decoder with three classification heads.

    The encoder output is projected to ``d_model``, passed through a stack of
    transformer decoder layers and mapped to one probability distribution per
    condition.

    Args:
        vocab_size_condition1: Number of classes of the first condition; also
            the width of the intermediate input projection.
        vocab_size_condition2: Number of classes of the second condition.
        vocab_size_condition3: Number of classes of the third condition.
        d_model: Model width of the transformer layers.
        nhead: Number of attention heads per transformer layer.
        num_layers: Number of transformer decoder layers.
        input_dim: Feature width of the encoder output. When given, the input
            projection ``linear`` is created in the constructor. When
            ``None`` it is created on the first forward pass from the input's
            width; in that case create the optimizer only after that first
            pass so the projection's parameters are included.
    """

    def __init__(self, vocab_size_condition1, vocab_size_condition2, vocab_size_condition3, d_model=512, nhead=8, num_layers=6, input_dim=None):
        super().__init__()

        self.vocab_size_condition1 = vocab_size_condition1
        self.linear_layer = nn.Linear(vocab_size_condition1, d_model)

        # Not used by forward(); kept so the module's attributes and
        # state-dict keys stay the same.
        self.embedding = nn.Embedding(vocab_size_condition1, d_model)

        self.transformer_layers = nn.TransformerDecoder(
            nn.TransformerDecoderLayer(d_model, nhead),
            num_layers
        )

        # Linear layers for each condition
        self.fc_condition1 = nn.Linear(d_model, vocab_size_condition1)
        self.fc_condition2 = nn.Linear(d_model, vocab_size_condition2)
        self.fc_condition3 = nn.Linear(d_model, vocab_size_condition3)

        # Normalise over the class axis (the last one) whatever the input rank.
        self.softmax_condition1 = nn.Softmax(dim=-1)
        self.softmax_condition2 = nn.Softmax(dim=-1)
        self.softmax_condition3 = nn.Softmax(dim=-1)

        # Input projection: encoder width -> vocab_size_condition1.
        self.linear = None
        if input_dim is not None:
            self.linear = nn.Linear(input_dim, vocab_size_condition1)

    def _input_projection(self, encoded_input):
        """Return the input projection, creating it once when still missing."""
        in_features = encoded_input.size(-1)
        if self.linear is None or self.linear.in_features != in_features:
            # Created a single time (or again only if the input width
            # changes) instead of being re-initialised on every forward pass.
            self.linear = nn.Linear(in_features, self.vocab_size_condition1)
        self.linear.to(device=encoded_input.device, dtype=encoded_input.dtype)
        return self.linear

    def forward(self, encoded_input):
        """Decode the encoder output into three probability distributions.

        Args:
            encoded_input: Encoder output whose last dimension is the feature
                width, e.g. (sequence_length, features).

        Returns:
            Tuple of three tensors of class probabilities, one per condition,
            each with the class axis last.
        """
        # Project the encoder output: features -> vocab_size_condition1 -> d_model.
        projection = self._input_projection(encoded_input)
        embedded = self.linear_layer(projection(encoded_input))

        # The projected encoder output serves both as the target sequence and
        # as the memory the decoder attends to, so the result depends on the
        # encoder rather than on random noise and works for any d_model.
        transformer_out = self.transformer_layers(embedded, memory=embedded)

        # Fully connected layers and softmax activations for each condition
        output_probs_condition1 = self.softmax_condition1(self.fc_condition1(transformer_out))
        output_probs_condition2 = self.softmax_condition2(self.fc_condition2(transformer_out))
        output_probs_condition3 = self.softmax_condition3(self.fc_condition3(transformer_out))

        return output_probs_condition1, output_probs_condition2, output_probs_condition3


In [65]:
# Smoke-test input: a random tensor with the same shape as the encoder output
# printed earlier in the notebook ([41, 64]).
x=torch.randn([41, 64])
# Decoder with 1000 classes for each of the three conditions and the default
# transformer settings.
decoder=TransformerDecoder(vocab_size_condition1=1000, vocab_size_condition2=1000, vocab_size_condition3=1000)

In [66]:
out=decoder(x)

In [67]:
print(out)

(tensor([[0.0006, 0.0015, 0.0004,  ..., 0.0006, 0.0017, 0.0022],
        [0.0007, 0.0017, 0.0004,  ..., 0.0007, 0.0023, 0.0019],
        [0.0011, 0.0022, 0.0006,  ..., 0.0005, 0.0021, 0.0017],
        ...,
        [0.0006, 0.0017, 0.0006,  ..., 0.0009, 0.0019, 0.0016],
        [0.0009, 0.0014, 0.0004,  ..., 0.0010, 0.0028, 0.0012],
        [0.0008, 0.0016, 0.0006,  ..., 0.0010, 0.0016, 0.0018]],
       grad_fn=<SoftmaxBackward0>), tensor([[0.0019, 0.0013, 0.0009,  ..., 0.0015, 0.0008, 0.0004],
        [0.0016, 0.0012, 0.0010,  ..., 0.0012, 0.0005, 0.0003],
        [0.0020, 0.0014, 0.0013,  ..., 0.0023, 0.0007, 0.0004],
        ...,
        [0.0014, 0.0012, 0.0013,  ..., 0.0022, 0.0004, 0.0005],
        [0.0023, 0.0010, 0.0013,  ..., 0.0015, 0.0009, 0.0005],
        [0.0021, 0.0008, 0.0010,  ..., 0.0011, 0.0008, 0.0004]],
       grad_fn=<SoftmaxBackward0>), tensor([[0.0035, 0.0003, 0.0021,  ..., 0.0006, 0.0014, 0.0010],
        [0.0025, 0.0002, 0.0029,  ..., 0.0006, 0.0022, 0.0011],
   

# Seq2Seq

In [68]:
class TransformerSeq2Seq(nn.Module):
    """End-to-end model: reaction SMILES in, three probability tensors out.

    Args:
        decoder_input_dim: Accepted for interface compatibility; this class
            does not read it.
        decoder_d_model: Model width passed to the decoder.
        decoder_nhead: Number of attention heads passed to the decoder.
        decoder_layers: Number of decoder layers.
        vocab_size_condition1: Number of classes of the first condition.
        vocab_size_condition2: Number of classes of the second condition.
        vocab_size_condition3: Number of classes of the third condition.
    """

    def __init__(self,
                 decoder_input_dim, decoder_d_model, decoder_nhead, decoder_layers,
                 vocab_size_condition1, vocab_size_condition2, vocab_size_condition3):
        super().__init__()

        # Three-headed decoder; the encoder side is invoked in forward().
        self.decoder = TransformerDecoder(
            vocab_size_condition1,
            vocab_size_condition2,
            vocab_size_condition3,
            d_model=decoder_d_model,
            nhead=decoder_nhead,
            num_layers=decoder_layers,
        )

    def forward(self, input_sequence):
        """Encode ``input_sequence`` and decode it.

        Args:
            input_sequence: Reaction SMILES string handed to
                ``get_cross_attention_output``.

        Returns:
            The three probability tensors produced by the decoder, in the
            order condition 1, condition 2, condition 3.
        """
        # The encoder output is kept on the instance as ``hidden_state``.
        self.hidden_state = get_cross_attention_output(input_sequence)

        probs_first, probs_second, probs_third = self.decoder(self.hidden_state)
        return probs_first, probs_second, probs_third

In [69]:
# Hyperparameters for the end-to-end model.
decoder_input_dim = 256  # NOTE(review): TransformerSeq2Seq as defined in this notebook accepts this value but does not appear to use it
decoder_d_model = 512  # Model width of the decoder's transformer layers
decoder_nhead = 8  # Number of attention heads per decoder layer
decoder_layers = 6  # Number of stacked decoder layers

# Build the model with 1000 classes for each of the three predicted conditions.
seq2seq_model = TransformerSeq2Seq(
                                   decoder_input_dim, decoder_d_model, decoder_nhead, decoder_layers,
                                   vocab_size_condition1=1000, vocab_size_condition2=1000, vocab_size_condition3=1000)



# Forward pass on one example reaction; returns one probability tensor per condition.
output_probs_condition1, output_probs_condition2, output_probs_condition3 = seq2seq_model(input_sequence='O=C1CCCN1C1CCN(Cc2ccccc2)CC1>>O=C1CCCN1C1CCNCC1')

# Display the shapes of the output probabilities

In [72]:
output_probs_condition1.shape

torch.Size([41, 1000])

# Training Loop