# **libraries**

In [None]:
!pip uninstall -y numpy
!pip install numpy==1.24

In [None]:
!pip uninstall -y pandas
!pip install pandas

In [None]:
!pip install rdkit-pypi

In [None]:
import math
import random
import torch
import numpy as np
import pandas as pd
import csv
from rdkit import Chem
from rdkit.Chem import AllChem, DataStructs, Descriptors, rdChemReactions
from real_reactions import REAL_REACTIONS
from torch import distributions
import matplotlib.pyplot as plt
from rdkit.Chem import Draw
from PIL import Image  # Ensure it's from PIL

##############################
# 1. Property Predictor
##############################

def compute_rdkit_descriptors(mol):
    """Build the fixed-length descriptor vector consumed by the property model.

    The vector concatenates EState indices, QED, molecular weight, partial
    charges, BCUT2D values, topological indices, VSA families, ring counts,
    logP and every ``fr_*`` fragment counter found on ``Fragments``; the
    result is truncated to the first 140 entries.

    Args:
        mol: RDKit molecule to describe.

    Returns:
        A ``float32`` NumPy array. On any descriptor failure an all-zero
        array of length 140 is returned and a ``RuntimeWarning`` is emitted.

    NOTE(review): several names are looked up on ``rdMolDescriptors``
    (e.g. ``CalcMaxPartialCharge``, ``BCUT2D_MWHI``, ``PEOE_VSA_``); confirm
    they exist in the installed RDKit build. If any is missing, every call
    falls into the zero-vector fallback.
    """
    import warnings
    from rdkit.Chem import Descriptors, rdMolDescriptors, Crippen, Lipinski, Fragments, EState, GraphDescriptors

    n_features = 140  # width expected by the model instantiated in this file
    try:
        return np.array([
            EState.MaxEStateIndex(mol), EState.MinEStateIndex(mol), EState.MinAbsEStateIndex(mol),
            Descriptors.qed(mol), Descriptors.MolWt(mol), Descriptors.NumRadicalElectrons(mol),
            rdMolDescriptors.CalcMaxPartialCharge(mol), rdMolDescriptors.CalcMinPartialCharge(mol),
            rdMolDescriptors.CalcFractionCSP3(mol),
            rdMolDescriptors.BCUT2D_MWHI(mol), rdMolDescriptors.BCUT2D_MWLOW(mol),
            rdMolDescriptors.BCUT2D_CHGHI(mol), rdMolDescriptors.BCUT2D_CHGLO(mol),
            rdMolDescriptors.BCUT2D_MRHI(mol), rdMolDescriptors.BCUT2D_MRLOW(mol),
            GraphDescriptors.BalabanJ(mol), Descriptors.HallKierAlpha(mol), Descriptors.Kappa3(mol),
            *rdMolDescriptors.PEOE_VSA_(mol), *rdMolDescriptors.SMR_VSA_(mol),
            *rdMolDescriptors.SlogP_VSA_(mol), *rdMolDescriptors.EState_VSA_(mol),
            *rdMolDescriptors.VSA_EState_(mol),
            # FractionCSP3 appears a second time here; kept so feature
            # positions stay aligned with whatever the model was trained on.
            rdMolDescriptors.CalcFractionCSP3(mol),
            Lipinski.NumAliphaticCarbocycles(mol), Lipinski.NumAliphaticHeterocycles(mol),
            Lipinski.NumAliphaticRings(mol), Lipinski.NumAromaticHeterocycles(mol),
            Descriptors.MolLogP(mol),
            *(getattr(Fragments, name)(mol) for name in dir(Fragments) if name.startswith('fr_'))
        ], dtype=np.float32)[:n_features]
    except Exception as exc:
        # Keep the best-effort zero-vector fallback, but let Ctrl-C and
        # SystemExit propagate and make the failure visible.
        warnings.warn(
            f"compute_rdkit_descriptors failed, returning zeros: {exc!r}",
            RuntimeWarning,
        )
        return np.zeros(n_features, dtype=np.float32)

class ImprovedMolecularNN(torch.nn.Module):
    """Fully connected binary classifier over a molecular descriptor vector.

    Four hidden stages (512, 256, 128, 64 units), each a Linear layer followed
    by BatchNorm1d and LeakyReLU(0.1); Dropout(0.4) follows the first three
    stages. A final Linear(64, 1) is passed through a sigmoid, so ``forward``
    returns values in (0, 1).
    """

    def __init__(self, input_dim):
        """Create the layers for inputs with ``input_dim`` features."""
        super().__init__()
        nn = torch.nn
        widths = (input_dim, 512, 256, 128, 64)
        # Register fc1/bn1 ... fc4/bn4 in order; the attribute names are part
        # of the state-dict layout and must stay as they are.
        for stage, (n_in, n_out) in enumerate(zip(widths[:-1], widths[1:]), start=1):
            setattr(self, f"fc{stage}", nn.Linear(n_in, n_out))
            setattr(self, f"bn{stage}", nn.BatchNorm1d(n_out))
        self.fc5 = nn.Linear(widths[-1], 1)
        self.leaky_relu = nn.LeakyReLU(0.1)
        self.dropout = nn.Dropout(0.4)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return the sigmoid-activated score for a batch of feature rows."""
        hidden = x
        for stage in (1, 2, 3, 4):
            linear = getattr(self, f"fc{stage}")
            norm = getattr(self, f"bn{stage}")
            hidden = self.leaky_relu(norm(linear(hidden)))
            # Dropout follows every hidden stage except the last one.
            if stage < 4:
                hidden = self.dropout(hidden)
        return self.sigmoid(self.fc5(hidden))

# NOTE(review): this freshly initialised network is never used; the next
# statement rebinds `model` to the object loaded from disk.
model = ImprovedMolecularNN(140)
# Load the serialized model onto the CPU. weights_only=False asks torch.load
# to unpickle arbitrary objects, so only load checkpoint files you trust.
model = torch.load("Best_MPP_ANN_model.pth", map_location=torch.device('cpu'), weights_only=False)
# Switch to evaluation mode before the model is used for prediction.
model.eval()

def property_predictor(mol, threshold=0.5):
    """Classify a molecule as 1 or 0 with the module-level ``model``.

    Args:
        mol: RDKit molecule to score.
        threshold: Minimum model output that maps to label 1.

    Returns:
        ``1`` when the model output is at least ``threshold``, else ``0``.
        Any error during featurization or inference also yields ``0``.
    """
    try:
        features = compute_rdkit_descriptors(mol)
        input_tensor = torch.tensor(features).unsqueeze(0)
        with torch.no_grad():
            # ImprovedMolecularNN.forward already ends in a sigmoid. Applying
            # torch.sigmoid again would squeeze the value into about
            # [0.5, 0.73], so every molecule would pass a 0.5 threshold.
            # NOTE(review): assumes the loaded checkpoint is an
            # ImprovedMolecularNN as defined in this file.
            probability = model(input_tensor).item()
        return int(probability >= threshold)
    except Exception:
        # Best-effort: a molecule that cannot be scored counts as negative.
        return 0

##############################
# 2. Load Data
##############################

def load_building_blocks(path):
    """Read building-block molecules from the ``smiles`` column of a CSV.

    Args:
        path: CSV file that contains a ``smiles`` column.

    Returns:
        List of RDKit molecules, one per SMILES that could be parsed.
        Missing cells and unparsable SMILES are skipped.
    """
    df = pd.read_csv(path)
    molecules = []
    # dropna: an empty cell is read as NaN, which MolFromSmiles cannot accept.
    for smi in df['smiles'].dropna():
        mol = Chem.MolFromSmiles(smi)  # parse once and reuse the result
        if mol is not None:
            molecules.append(mol)
    return molecules

def load_reactions():
    """Turn every entry of ``REAL_REACTIONS`` into an initialized RDKit reaction.

    Each entry is read through its ``"reactants"``, ``"product"`` and
    ``"reaction_id"`` keys. Reactants and product are converted to SMARTS
    (presumably they are RDKit molecules/patterns; verify against
    ``real_reactions``), joined as ``reactants>>product`` and parsed as a
    reaction. The entry's id is attached to the reaction as ``reaction_id``.

    Returns:
        List of the reactions that loaded without error; failures are printed
        and skipped.
    """
    loaded = []
    for entry in REAL_REACTIONS:
        try:
            lhs = ".".join(Chem.MolToSmarts(template) for template in entry["reactants"])
            rhs = Chem.MolToSmarts(entry["product"])
            reaction = rdChemReactions.ReactionFromSmarts(lhs + ">>" + rhs)
            reaction.Initialize()
            reaction.reaction_id = entry["reaction_id"]
        except Exception as e:
            print(f"Error loading reaction {entry['reaction_id']}: {e}")
        else:
            loaded.append(reaction)
    return loaded

# Module-level data used by the search: parsed building-block molecules from
# the CSV, and the reaction templates built from REAL_REACTIONS.
building_block_pool = load_building_blocks("1_building_blocks.csv")
reactions = load_reactions()

##############################
# 3. MCTS
##############################

class MCTSNode:
    """One node of the search tree: a product plus its visit statistics."""

    def __init__(self, building_blocks, parent=None, reaction=None, product=None):
        """Store the node's provenance and start with empty statistics."""
        # Tree linkage.
        self.parent = parent
        self.children = []
        # What produced this node.
        self.building_blocks = building_blocks
        self.reaction = reaction
        self.product = product
        # Search statistics, updated during backpropagation.
        self.visits = 0
        self.score = 0.0
        self.is_terminal = False

    def ucb1(self, exploration_constant=1.4):
        """Return the UCB1 value of this node relative to its parent.

        Unvisited nodes return infinity. Reads ``self.parent.visits``, so it
        must only be called on nodes that have a parent.
        """
        if self.visits == 0:
            return float('inf')
        mean_reward = self.score / self.visits
        exploration = math.sqrt(math.log(self.parent.visits + 1) / self.visits)
        return mean_reward + exploration_constant * exploration

class MCTS:
    """Tree search that proposes reaction products and scores them.

    Each iteration selects a leaf by UCB1, attaches one new product node made
    by running a random reaction on randomly sampled building blocks, scores
    that product with ``property_predictor`` and propagates the score to the
    root.

    NOTE(review): ``select`` only stops at nodes without children and
    ``expand`` attaches exactly one child per call, so as written every node
    ends up with at most one child and the tree grows as a single chain.
    Confirm whether broader branching was intended.
    """

    def __init__(self, root, reactions, building_block_pool, property_predictor):
        """Store the search inputs.

        Args:
            root: Root ``MCTSNode`` of the tree.
            reactions: Reaction objects exposing ``GetNumReactantTemplates``
                and ``RunReactants``.
            building_block_pool: Molecules to sample reactants from.
            property_predictor: Callable mapping a product molecule to a
                numeric label.
        """
        self.root = root
        self.reactions = reactions
        self.building_block_pool = building_block_pool
        self.property_predictor = property_predictor
        self.generated_smiles = set()  # SMILES of every product accepted so far

    def select(self, node):
        """Descend from ``node`` to a leaf, taking the highest-UCB1 child each step."""
        while node.children:
            node = max(node.children, key=lambda n: n.ucb1())
        return node

    def expand(self, node):
        """Attach one new, previously unseen product as a child of ``node``.

        Reactions are tried in random order with freshly sampled building
        blocks. Returns the new child, or ``None`` when no reaction yields a
        new sanitizable product.
        """
        # Shuffle a copy so the caller's reaction list keeps its order.
        candidate_reactions = list(self.reactions)
        random.shuffle(candidate_reactions)
        for rxn in candidate_reactions:
            try:
                k = rxn.GetNumReactantTemplates()
                if len(self.building_block_pool) < k:
                    continue
                sampled = random.sample(self.building_block_pool, k)
                products = rxn.RunReactants(sampled)
            except Exception:
                # A reaction that cannot be applied is skipped, not fatal.
                continue
            for product_tuple in products:
                try:
                    product = product_tuple[0]
                    Chem.SanitizeMol(product)
                    smi = Chem.MolToSmiles(product)
                except Exception:
                    # One invalid product must not discard the remaining
                    # products of the same reaction.
                    continue
                if smi in self.generated_smiles:
                    continue
                self.generated_smiles.add(smi)
                child = MCTSNode(sampled, parent=node, reaction=rxn, product=product)
                child.is_terminal = True
                node.children.append(child)
                return child
        return None

    def simulate(self, node):
        """Return the predictor's label for ``node.product`` (0 if there is none).

        Statistics are left untouched here; ``backpropagate`` is the single
        place that updates ``visits`` and ``score``, so the new node is not
        counted twice.
        """
        if node.product is None:
            return 0
        return self.property_predictor(node.product)

    def backpropagate(self, node, score):
        """Add one visit and ``score`` to ``node`` and each of its ancestors."""
        while node is not None:
            node.visits += 1
            node.score += score
            node = node.parent

    def run(self, num_iterations=10000, log_every=500):
        """Run the search loop, printing progress every ``log_every`` iterations.

        Errors inside an iteration are printed and that iteration is skipped.
        """
        for i in range(num_iterations):
            try:
                node = self.select(self.root)
                child = self.expand(node)
                if child is not None:
                    score = self.simulate(child)
                    self.backpropagate(child, score)
                else:
                    if i % log_every == 0:
                        print(f"[{i}] No valid child. Total unique molecules: {len(self.generated_smiles)}")
            except Exception as e:
                print(f"[{i}] Error: {e}")
                continue
            if i % log_every == 0:
                print(f"[{i}] Progress: {self.root.visits} visits, {len(self.root.children)} root children")

##############################
# 4. Diversity Filter + Output + Uniqueness Check
##############################

def get_top_molecules(root, top_k=1000):
    """Collect terminal product nodes under ``root``, ranked by mean score.

    Args:
        root: Node whose subtree is searched (``root`` itself included).
        top_k: Maximum number of entries to return.

    Returns:
        List of ``(mean_score, node)`` tuples sorted by descending mean score,
        where ``mean_score`` is ``node.score / max(node.visits, 1)``. Ties keep
        depth-first pre-order.
    """
    scored = []
    # Explicit stack instead of recursion: the tree can be as deep as the
    # number of search iterations, far beyond Python's recursion limit.
    stack = [root]
    while stack:
        node = stack.pop()
        if node.is_terminal and node.product is not None:
            scored.append((node.score / max(node.visits, 1), node))
        # Push children reversed so they are visited in their original order.
        stack.extend(reversed(node.children))
    scored.sort(key=lambda pair: -pair[0])
    return scored[:top_k]

def filter_diverse_molecules(scored_nodes, threshold=0.8):
    """Greedily keep nodes whose product differs enough from those already kept.

    Nodes are visited in the given order. A node is kept only when the
    Tanimoto similarity of its radius-2 Morgan fingerprint to every previously
    kept fingerprint is below ``threshold``.

    Args:
        scored_nodes: Iterable of ``(score, node)`` pairs.
        threshold: Similarity at or above which a candidate is rejected.

    Returns:
        List of the ``(score, node)`` pairs that were kept.
    """
    kept = []
    kept_fps = []
    for score, node in scored_nodes:
        candidate_fp = AllChem.GetMorganFingerprint(node.product, radius=2)
        for accepted_fp in kept_fps:
            if not DataStructs.TanimotoSimilarity(candidate_fp, accepted_fp) < threshold:
                break  # too close to an already accepted molecule
        else:
            kept.append((score, node))
            kept_fps.append(candidate_fp)
    return kept

def write_unique_molecules(diverse_nodes, train_csv):
    """Drop molecules already listed in the training CSV and write the rest.

    A node is kept when the SMILES of its product is not among the non-null
    values of the ``smiles`` column of ``train_csv`` (exact string match).
    Kept nodes are printed and written to ``mcts unique batch label.csv`` with
    their rank, rounded label, SMILES, reaction id and building blocks.

    Args:
        diverse_nodes: Iterable of ``(score, node)`` pairs.
        train_csv: Path of a CSV with a ``smiles`` column.

    Returns:
        List of the ``(score, node)`` pairs that were written.
    """
    known_smiles = set(pd.read_csv(train_csv)['smiles'].dropna().unique())

    unique_nodes = []
    for score, node in diverse_nodes:
        if Chem.MolToSmiles(node.product) not in known_smiles:
            unique_nodes.append((score, node))

    header = ["Rank", "Label", "Product_SMILES", "Reaction_ID", "Building_Blocks"]
    with open("mcts unique batch label.csv", "w", newline='') as f:
        writer = csv.writer(f)
        writer.writerow(header)
        for rank, (score, node) in enumerate(unique_nodes, start=1):
            smiles = Chem.MolToSmiles(node.product)
            blocks = ".".join([Chem.MolToSmiles(m) for m in node.building_blocks])
            # Reactions without an attached id are reported as 'N/A'.
            rxn_id = getattr(node.reaction, 'reaction_id', 'N/A')
            label = int(round(score))
            print(f"Unique Rank {rank}: Label = {label}, SMILES = {smiles}")
            writer.writerow([rank, label, smiles, rxn_id, blocks])
    return unique_nodes

##############################
# 5. Visualize Results + Run
##############################

from rdkit import Chem, DataStructs
from rdkit.Chem import AllChem
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE

def visualize_summary(unique_nodes, train_smiles):
    """Summarize generated molecules against the training set.

    For every generated product this computes the maximum Tanimoto similarity
    to the training molecules and a diversity score (1 - maximum similarity to
    the products listed before it). It then:

    * saves and shows a histogram (``similarity_distribution.png``),
    * writes per-molecule metadata to ``generated_molecule_scores.csv``,
    * saves and shows a t-SNE plot of 1024-bit Morgan fingerprints of the
      generated and training molecules (``tsne_plot.png``).

    Args:
        unique_nodes: Sequence of ``(score, node)`` pairs; each node provides
            ``product``, ``reaction`` and ``building_blocks``.
        train_smiles: Iterable of training SMILES strings; entries that RDKit
            cannot parse are ignored.

    Returns:
        None.
    """
    gen_fps = []
    gen_smiles = []
    tanimoto_to_train = []
    diversity_scores = []
    mcts_labels = []
    mcts_scores = []
    reaction_ids = []
    building_blocks_list = []
    gen_bit_fps = []

    # Precompute training molecules and fingerprints (each SMILES parsed once)
    train_mols = [
        mol for mol in (Chem.MolFromSmiles(s) for s in train_smiles) if mol is not None
    ]
    train_fps = [AllChem.GetMorganFingerprint(mol, 2) for mol in train_mols]
    train_bit_fps = [AllChem.GetMorganFingerprintAsBitVect(mol, 2, nBits=1024) for mol in train_mols]

    for i, (score, node) in enumerate(unique_nodes):
        mol = node.product
        smi = Chem.MolToSmiles(mol)
        fp = AllChem.GetMorganFingerprint(mol, 2)
        bit_fp = AllChem.GetMorganFingerprintAsBitVect(mol, 2, nBits=1024)

        gen_smiles.append(smi)
        gen_fps.append(fp)
        gen_bit_fps.append(bit_fp)
        mcts_labels.append(int(round(score)))
        mcts_scores.append(score)

        # Reaction ID and Building Blocks
        rxn_id = getattr(node.reaction, 'reaction_id', 'N/A')
        bb_smiles = ".".join([Chem.MolToSmiles(m) for m in node.building_blocks])
        reaction_ids.append(rxn_id)
        building_blocks_list.append(bb_smiles)

        # Tanimoto to training set; 0.0 when no training molecule is usable
        # (max() over an empty sequence would raise ValueError).
        max_sim_train = max(
            (DataStructs.TanimotoSimilarity(fp, tfp) for tfp in train_fps),
            default=0.0,
        )
        tanimoto_to_train.append(max_sim_train)

        # Diversity score (1 - max similarity to previous)
        if i == 0:
            diversity_scores.append(1.0)
        else:
            max_sim_prev = max(DataStructs.TanimotoSimilarity(fp, prev_fp) for prev_fp in gen_fps[:i])
            diversity_scores.append(1.0 - max_sim_prev)

    # Plot histogram of Tanimoto similarity to training
    plt.figure(figsize=(6, 4))
    plt.hist(tanimoto_to_train, bins=20, color='green', edgecolor='black')
    plt.title("Max Tanimoto Similarity to Training Set")
    plt.xlabel("Tanimoto Similarity")
    plt.ylabel("Frequency")
    plt.tight_layout()
    plt.savefig("similarity_distribution.png")
    plt.show()

    # Save all scores and metadata to CSV
    df = pd.DataFrame({
        'rank': list(range(1, len(gen_smiles) + 1)),
        'generated_smiles': gen_smiles,
        'tanimoto_to_train': tanimoto_to_train,
        'diversity_score': diversity_scores,
        'mcts_label': mcts_labels,
        'mcts_score': mcts_scores,
        'reaction_id': reaction_ids,
        'building_blocks': building_blocks_list
    })
    df.to_csv("generated_molecule_scores.csv", index=False)
    print(" Saved summary: generated_molecule_scores.csv")

    # ========================
    # t-SNE Visualization
    # ========================
    gen_arr = [np.array(list(fp.ToBitString()), dtype=int) for fp in gen_bit_fps]
    train_arr = [np.array(list(fp.ToBitString()), dtype=int) for fp in train_bit_fps]

    all_arr = gen_arr + train_arr
    n_samples = len(all_arr)
    if n_samples < 2:
        # t-SNE cannot embed fewer than two points.
        print(" Skipped t-SNE plot: fewer than 2 fingerprints available")
        return
    all_arr_np = np.array(all_arr)

    # scikit-learn requires perplexity < n_samples; keep 30 whenever possible.
    perplexity = min(30, n_samples - 1)
    tsne = TSNE(n_components=2, perplexity=perplexity, random_state=42)
    tsne_result = tsne.fit_transform(all_arr_np)

    gen_tsne = tsne_result[:len(gen_arr)]
    train_tsne = tsne_result[len(gen_arr):]

    # Plot t-SNE
    plt.figure(figsize=(8, 6))
    plt.scatter(train_tsne[:, 0], train_tsne[:, 1], c='gray', label='Training', alpha=0.5)
    plt.scatter(gen_tsne[:, 0], gen_tsne[:, 1], c='blue', label='Generated', alpha=0.7)
    plt.title("t-SNE of Morgan Fingerprints")
    plt.xlabel("t-SNE 1")
    plt.ylabel("t-SNE 2")
    plt.legend()
    plt.tight_layout()
    plt.savefig("tsne_plot.png")
    plt.show()
    print(" Saved t-SNE plot: tsne_plot.png")

# Load training SMILES and run visualization
train_df = pd.read_csv("1-unique_clean_smiles.csv")
train_smiles = train_df['smiles'].dropna().unique()

# `unique_nodes` is not assigned anywhere earlier in this file, so the call
# below would raise NameError. Reuse the value when an earlier notebook cell
# already produced it; otherwise run search -> rank -> diversity filter ->
# novelty filter with the default settings.
try:
    unique_nodes
except NameError:
    root = MCTSNode(building_blocks=[])
    mcts = MCTS(root, reactions, building_block_pool, property_predictor)
    mcts.run()
    top_nodes = get_top_molecules(root)
    diverse_nodes = filter_diverse_molecules(top_nodes)
    unique_nodes = write_unique_molecules(diverse_nodes, "1-unique_clean_smiles.csv")

visualize_summary(unique_nodes, train_smiles)

