In [None]:
# GNN model to test the performance of machine learning on original and reorganized SMILES strings #
# The code was initially generated with OpenAI's LLM, i.e. ChatGPT ("ChatGPT.com") #
# The code was then modified according to the needs of the study #
# Initial code obtained on 12 Dec 2024 #

import torch
import torch.nn.functional as F
from torch_geometric.data import Data, DataLoader
from torch_geometric.nn import GCNConv, global_mean_pool
from rdkit import Chem
import pandas as pd

# Step 1: SMILES to Graph Conversion
def smiles_to_graph(smiles):
    """Convert a SMILES string into a PyTorch Geometric ``Data`` graph.

    The string is parsed by RDKit with ``sanitize=False``, so RDKit's
    sanitization checks are not applied at this stage.

    Args:
        smiles: The SMILES string to convert.

    Returns:
        A ``Data`` object with
          * ``x``: float tensor of shape ``(num_atoms, 1)`` holding each
            atom's atomic number, and
          * ``edge_index``: long tensor of shape ``(2, 2 * num_bonds)`` in
            which every bond appears in both directions,
        or ``None`` if RDKit cannot parse the string or conversion fails
        (the reason is printed).
    """
    try:
        mol = Chem.MolFromSmiles(smiles, sanitize=False)
        if mol is None:
            print(f"Invalid SMILES: {smiles}")
            return None

        # Node features: atomic numbers. reshape(-1, 1) keeps ``x`` 2-D,
        # i.e. (num_atoms, 1), even for a molecule with zero atoms.
        atomic_numbers = [atom.GetAtomicNum() for atom in mol.GetAtoms()]
        x = torch.tensor(atomic_numbers, dtype=torch.float).reshape(-1, 1)

        # Edge indices: each bond is added in both directions.
        edge_pairs = []
        for bond in mol.GetBonds():
            start, end = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
            edge_pairs.append((start, end))
            edge_pairs.append((end, start))
        # PyG expects edge_index to be (2, num_edges). Without reshape(-1, 2)
        # a bond-less molecule (e.g. a single atom) would yield a 1-D tensor
        # of shape (0,), because transposing a 1-D tensor is a no-op.
        edge_index = (
            torch.tensor(edge_pairs, dtype=torch.long).reshape(-1, 2).t().contiguous()
        )

        return Data(x=x, edge_index=edge_index)
    except Exception as e:
        # Best-effort conversion: report and skip this string rather than
        # aborting the whole dataset load.
        print(f"Error processing SMILES: {smiles}, Error: {e}")
        return None

# Step 2: Load Data from Text Files
def _read_smiles_column(path):
    """Return the stripped, non-empty SMILES in the first column of *path*.

    Raises ``FileNotFoundError`` if *path* does not exist; an empty file
    yields an empty list.
    """
    try:
        # dtype=str / keep_default_na=False: keep every entry as literal text
        # so pandas never converts a token such as "NA" into a float NaN.
        column = pd.read_csv(path, header=None, dtype=str, keep_default_na=False)[0]
    except pd.errors.EmptyDataError:
        return []  # an empty file simply contributes no samples
    return [smiles.strip() for smiles in column.tolist() if smiles.strip()]


def load_data(valid_file, invalid_file):
    """Build labelled graphs from a file of valid and a file of invalid SMILES.

    Args:
        valid_file: Path to a one-SMILES-per-line file; its graphs get label 1.
        invalid_file: Path to a one-SMILES-per-line file; its graphs get label 0.

    Returns:
        A list of ``Data`` graphs, each carrying a ``y`` tensor of shape
        ``(1,)``. Strings that ``smiles_to_graph`` cannot convert are skipped.
        Graphs from *valid_file* come first, followed by those from
        *invalid_file*. Returns an empty list if either file does not exist.
    """
    try:
        valid_smiles = _read_smiles_column(valid_file)
        invalid_smiles = _read_smiles_column(invalid_file)
    except FileNotFoundError:
        print("Error: One or both of the input files do not exist.")
        # Return a single empty list (not a tuple of lists) so the result has
        # the same type as the success path and is falsy for `if not data_list`.
        return []

    data_list = []
    # Valid SMILES are labeled as 1, invalid SMILES as 0.
    for smiles_list, label in ((valid_smiles, 1), (invalid_smiles, 0)):
        for smiles in smiles_list:
            graph = smiles_to_graph(smiles)
            if graph is not None:
                graph.y = torch.tensor([label], dtype=torch.long)
                data_list.append(graph)

    return data_list

# Step 3: Define the GNN Model
class GNNClassifier(torch.nn.Module):
    """Graph classifier built from two GCN layers and a linear head.

    Node features pass through ``conv1`` and ``conv2`` (each followed by a
    ReLU), are mean-pooled per graph, and are mapped by ``fc`` to
    log-probabilities over the output classes.

    Args:
        input_dim: Number of features per node.
        hidden_dim: Width of both GCN layers and input width of ``fc``.
        output_dim: Number of output classes.
    """

    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        # Attribute names are part of the saved state_dict keys; keep them.
        self.conv1 = GCNConv(input_dim, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, hidden_dim)
        self.fc = torch.nn.Linear(hidden_dim, output_dim)

    def forward(self, data):
        """Return per-graph class log-probabilities for a batch of graphs.

        ``data`` must expose ``x`` (node features), ``edge_index``
        (connectivity) and ``batch`` (the graph id of each node, used to
        pool nodes into one vector per graph).
        """
        hidden = data.x
        for layer in (self.conv1, self.conv2):
            hidden = F.relu(layer(hidden, data.edge_index))
        pooled = global_mean_pool(hidden, data.batch)  # one row per graph
        logits = self.fc(pooled)
        return F.log_softmax(logits, dim=1)

# Step 4: Training and Evaluation
def train_and_evaluate(model, train_loader, test_loader, optimizer, epochs=10):
    """Train *model* with NLL loss, then print and return its test accuracy.

    Batches whose ``x`` has zero rows are skipped in both phases.

    Args:
        model: Module mapping a batch to log-probabilities (one row per sample).
        train_loader: Iterable of training batches exposing ``x`` and ``y``.
        test_loader: Iterable of evaluation batches exposing ``x`` and ``y``.
        optimizer: Optimizer over ``model``'s parameters.
        epochs: Number of passes over ``train_loader``.

    Returns:
        Test accuracy as a float in ``[0, 1]``, or ``None`` when no test
        samples were evaluated.
    """
    for epoch in range(epochs):
        model.train()
        total_loss = 0.0
        num_batches = 0
        for batch in train_loader:
            if batch.x.shape[0] == 0:  # Skip empty batches
                continue
            optimizer.zero_grad()
            output = model(batch)
            loss = F.nll_loss(output, batch.y)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            num_batches += 1

        # Average over the batches actually trained on; report NaN rather
        # than dividing by zero when the loader yielded no usable batch.
        avg_loss = total_loss / num_batches if num_batches else float("nan")
        print(f"Epoch {epoch+1}, Loss: {avg_loss:.4f}")

    # Evaluation
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for batch in test_loader:
            if batch.x.shape[0] == 0:  # Skip empty batches
                continue
            pred = model(batch).argmax(dim=1)
            correct += (pred == batch.y).sum().item()
            total += batch.y.size(0)

    if total == 0:
        print("Accuracy: n/a (no test samples)")
        return None

    accuracy = correct / total
    print(f"Accuracy: {accuracy:.2f}")
    return accuracy

# Step 5: Prediction
def predict(model, smiles_file):
    """Classify every SMILES string in *smiles_file* with *model*.

    Args:
        model: Trained classifier returning per-graph log-probabilities.
        smiles_file: Path to a text file with one SMILES string per line.

    Returns:
        A list of ``(smiles, label)`` tuples, where ``label`` is ``"Valid"``
        (predicted class 1), ``"Invalid"`` (class 0), or
        ``"Error: Invalid SMILES"`` when the string could not be converted
        into a non-empty graph. Returns an empty list if the file is missing
        or empty.
    """
    model.eval()
    predictions = []

    # Read SMILES from the uploaded file. dtype=str / keep_default_na=False
    # keep each entry as literal text; pandas would otherwise turn e.g. "NA"
    # into a float NaN, on which .strip() fails.
    try:
        raw_smiles = pd.read_csv(
            smiles_file, header=None, dtype=str, keep_default_na=False
        )[0].tolist()
    except FileNotFoundError:
        print("Error: The input file does not exist.")
        return []
    except pd.errors.EmptyDataError:
        print("Error: The input file is empty.")
        return []

    new_smiles = [smiles.strip() for smiles in raw_smiles if smiles.strip()]  # Remove empty lines

    with torch.no_grad():
        for smiles in new_smiles:
            graph = smiles_to_graph(smiles)
            if graph is None or graph.x.size(0) == 0:
                predictions.append((smiles, "Error: Invalid SMILES"))
                continue
            # Single-graph inference: the batch vector needs one entry per
            # node (all 0), exactly as DataLoader builds it during training,
            # so that global_mean_pool averages over all atoms of the graph.
            graph.batch = torch.zeros(graph.x.size(0), dtype=torch.long)

            pred = model(graph).argmax(dim=1).item()
            label = "Valid" if pred == 1 else "Invalid"
            predictions.append((smiles, label))
    return predictions

# Step 6: Main Pipeline
if __name__ == "__main__":
    # Load data
    valid_file = input("Enter the path to the file containing Valid SMILES strings: ")
    invalid_file = input("Enter the path to the file containing Invalid SMILES strings: ")
    data_list = load_data(valid_file, invalid_file)

    if not data_list:
        print("Error: No valid data found.")
        raise SystemExit(1)

    # load_data returns all valid graphs followed by all invalid ones, so an
    # unshuffled 80/20 split would give a test set of (almost) only one class.
    # Shuffle with a fixed seed so the split is reproducible between runs.
    split_rng = torch.Generator().manual_seed(42)
    order = torch.randperm(len(data_list), generator=split_rng).tolist()
    data_list = [data_list[i] for i in order]

    # Train/test split (80/20)
    train_size = int(0.8 * len(data_list))
    train_data = data_list[:train_size]
    test_data = data_list[train_size:]
    if not train_data or not test_data:
        print("Error: Not enough data for a train/test split.")
        raise SystemExit(1)

    train_loader = DataLoader(train_data, batch_size=32, shuffle=True)
    test_loader = DataLoader(test_data, batch_size=32)

    # Define model
    input_dim = 1  # Atomic number as node feature
    hidden_dim = 32
    output_dim = 2  # Valid/Invalid classification
    model = GNNClassifier(input_dim, hidden_dim, output_dim)

    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

    # Train and evaluate
    train_and_evaluate(model, train_loader, test_loader, optimizer, epochs=10)

    # Predict new SMILES from uploaded file
    smiles_file = input("Enter the path to the file containing new SMILES strings: ")
    predictions = predict(model, smiles_file)

    print("\nPredictions:")
    for smiles, label in predictions:
        print(f"SMILES: {smiles} -> Prediction: {label}")

    # Save the trained model. The confirmation is printed inside the guard so
    # it only appears when a model was actually saved (not on import).
    torch.save(model.state_dict(), 'gnn_model.pth')
    print("\nModel saved successfully!")