In [None]:
!pip install cloud-tpu-client==0.10 torch==2.0.0 torchvision==0.15.1 https://storage.googleapis.com/tpu-pytorch/wheels/colab/torch_xla-2.0-cp310-cp310-linux_x86_64.whl

In [None]:
!pip install rdkit-pypi

In [None]:
import torch
import torch_xla
import torch_xla.core.xla_model as xm
import numpy as np
from collections import defaultdict
import pandas as pd
from collections import defaultdict, Counter
from rdkit import Chem, DataStructs
from rdkit.Chem import AllChem ,Descriptors, Draw
import matplotlib.pyplot as plt
from PIL import Image

In [None]:
# Read at most the first 1000 rows of each input table, decoding every file as
# latin1. df = transformations, df1 = substances, df2 = metabolic database.
df, df1, df2 = (
    pd.read_csv(csvName, encoding='latin1', nrows=1000)
    for csvName in ('Transformations.csv', 'substances.csv', 'metabolicdb.csv')
)

In [None]:
# Normalise the dtypes of the columns used below so that join keys on both
# sides of each merge compare equal.
df['Predecessor_Name'] = df['Predecessor_Name'].astype(str)
df['Successor_CID'] = df['Successor_CID'].astype(float)
df1['SubstanceName'] = df1['SubstanceName'].astype(str)
df1['PubChem_CID'] = df1['PubChem_CID'].astype(float)
df2['substrate_cid'] = df2['substrate_cid'].astype(str)
df2['prod_cid'] = df2['prod_cid'].astype(str)

# pandas matches null join keys against each other (unlike SQL), so rows with
# a missing CID / enzyme would be cross-joined with every other row that is
# also missing it. Drop rows with a missing key on both sides before merging.

# First merge: transformations joined to the substance record of their successor.
mergedDf1 = pd.merge(df.dropna(subset=['Successor_CID']),
                     df1.dropna(subset=['PubChem_CID']),
                     left_on=['Successor_CID'],
                     right_on=['PubChem_CID'],
                     how='inner')

# Second merge: attach metabolic-db rows that share the same enzyme string.
# Removing null enzymes here also keeps later `row['Enzyme'].split(...)` calls
# from hitting a float NaN.
mergedDf = pd.merge(mergedDf1.dropna(subset=['Enzyme']),
                    df2.dropna(subset=['enzyme']),
                    left_on=['Enzyme'],
                    right_on=['enzyme'],
                    how='inner')

In [None]:
%%time
# Accumulators filled by updateSetsAndCounter: per-enzyme occurrence counts,
# the distinct (predecessor, successor) pairs, and the distinct enzyme names.
enzymeCounter, allPossiblePairs, allPossibleEnzymes = Counter(), set(), set()
def updateSetsAndCounter(row):
    """Record one merged row in the module-level accumulators.

    The row's 'Enzyme' field is split on '; '. Each resulting name is counted
    in ``enzymeCounter`` and added to ``allPossibleEnzymes``; the row's
    (Predecessor_CID, Successor_CID) tuple is added to ``allPossiblePairs``.
    Names are stored exactly as split (no whitespace stripping).
    """
    namesInRow = row['Enzyme'].split('; ')
    for enzymeName in namesInRow:
        enzymeCounter[enzymeName] += 1
    reactionPair = (row['Predecessor_CID'], row['Successor_CID'])
    allPossiblePairs.add(reactionPair)
    allPossibleEnzymes.update(namesInRow)
# Phase 1: walk every merged row once to fill the counter and the two sets.
mergedDf.apply(updateSetsAndCounter, axis=1)
print("[INFO]: PHASE-1 DONE !")

# Phase 2: turn the counter into a frequency table, most frequent enzyme first.
enzymeFrequenciesDf = (
    pd.DataFrame.from_dict(enzymeCounter, orient='index', columns=['Frequency'])
    .reset_index()
    .rename(columns={'index': 'Enzyme'})
    .sort_values(by='Frequency', ascending=False)
)
print("[INFO]: PHASE-2 DONE !")

# Phase 3: enzymes that were counted at least 10 times.
highFrequencyEnzymes = set(
    enzymeFrequenciesDf.loc[enzymeFrequenciesDf['Frequency'] >= 10, 'Enzyme']
)
print("[INFO]: PHASE-3 DONE !")

# Phase 4: transformations that occur with exactly one distinct 'Enzyme' value.
transformationEnzymeGroups = (
    mergedDf.groupby(['Transformation', 'Enzyme'])
    .size()
    .reset_index(name='Counts')
)
# Each row of the grouped table is one (Transformation, Enzyme) combination,
# so a transformation that appears only once has a single enzyme value.
enzymeSpecificTransformationsSet = set(
    transformationEnzymeGroups['Transformation'].drop_duplicates(keep=False)
)
print("[INFO]: PHASE-4 DONE !")

In [None]:
%%time
def calculateWeights(row):
    """Return the additive row-based weight for one merged row.

    The weight is the sum of:
      * 10 if any enzyme listed in the row's 'Enzyme' field is in
        ``highFrequencyEnzymes``,
      * 5 if the row's 'Transformation' is in ``enzymeSpecificTransformationsSet``,
      * 3 if the row's 'Biosystem' equals 'Human'.

    ``highFrequencyEnzymes`` is built from enzyme names split on '; ', so the
    field is split the same way here; testing the whole unsplit string would
    never match a row that lists several enzymes.
    """
    enzymeField = row['Enzyme']
    if isinstance(enzymeField, str):
        # Accept the name both as stored and stripped, since other cells strip
        # enzyme names before using them.
        hasFrequentEnzyme = any(
            name in highFrequencyEnzymes or name.strip() in highFrequencyEnzymes
            for name in enzymeField.split('; ')
        )
    else:
        # Non-string values (e.g. NaN) keep the plain membership test.
        hasFrequentEnzyme = enzymeField in highFrequencyEnzymes

    rowBasedWeight = 10 * hasFrequentEnzyme
    rowBasedWeight += 5 * (row['Transformation'] in enzymeSpecificTransformationsSet)
    rowBasedWeight += 3 * (row['Biosystem'] == 'Human')
    return rowBasedWeight
# Score every merged row with calculateWeights and store the result as a new
# 'Row_Based_Weight' column (read later by updateQ).
mergedDf['Row_Based_Weight'] = mergedDf.apply(calculateWeights, axis=1)
print("[INFO]: PHASE-5 DONE !")

def calculateChemicalStructureWeight(inchi, smiles):
    """Return a structure-based weight computed from an InChI and a SMILES string.

    Both strings are parsed with RDKit. If either parse yields a falsy result
    the weight is 0. Otherwise the weight is

        MolWt(InChI molecule) * similarity + MolWt(SMILES molecule)

    where ``similarity`` is ``DataStructs.FingerprintSimilarity`` applied to
    the RDKit fingerprints of the two molecules.
    """
    fromInchi = Chem.MolFromInchi(inchi)
    fromSmiles = Chem.MolFromSmiles(smiles)
    if not (fromInchi and fromSmiles):
        return 0

    similarity = DataStructs.FingerprintSimilarity(
        Chem.RDKFingerprint(fromInchi),
        Chem.RDKFingerprint(fromSmiles),
    )
    return Descriptors.MolWt(fromInchi) * similarity + Descriptors.MolWt(fromSmiles)
print("[INFO]: PHASE-6 DONE !")


# QUBO coefficients keyed by (variableName, variableName) tuples, where the
# names are "<enzyme>_<predecessor>_<target>" strings built in updateQ.
# defaultdict(int) makes every missing coefficient read as 0.
Q = defaultdict(int)
N = 5 # author's intent: at least 5 pairs catalyzed by enzyme X; updateQ adds 2 * N to each (pair, otherPair) term

# Updating Q Model (Dictionary) [Q: QUBO Model]
def updateQ(row):
    """Add one merged row's contribution to the module-level QUBO dict ``Q``.

    For every enzyme listed in the row ('Enzyme' split on '; ', each name
    stripped) the variable "<enzyme>_<predecessor>_<target>" gets
      * minus the row's total weight (Row_Based_Weight + structure weight)
        added to its diagonal entry, and
      * ``2 * N`` added to the entry against the same enzyme's variable for
        every pair in ``allPossiblePairs`` (this includes the row's own pair).

    Any exception raised while processing the row is reported with print and
    the row is skipped; updates made before the failure are kept.
    """
    try:
        enzymeNames = row['Enzyme'].split('; ')
        successorCid = row['Successor_CID']
        predecessorCid = row['Predecessor_CID']
        combinedWeight = row['Row_Based_Weight'] + calculateChemicalStructureWeight(
            row['InChI'], row['SMILES']
        )
        for rawName in enzymeNames:
            name = rawName.strip()
            variable = f"{name}_{predecessorCid}_{successorCid}"
            Q[(variable, variable)] += -combinedWeight
            for otherPredecessor, otherTarget in allPossiblePairs:
                Q[(variable, f"{name}_{otherPredecessor}_{otherTarget}")] += 2 * N
    except Exception as e:
        print(f"An error occurred finding in chemical informatics: {e}")
        print(f"Skipping row: {row}")
# Run updateQ on every merged row for its side effect on Q; the Series that
# apply returns is discarded.
mergedDf.apply(updateQ, axis=1)
print("[INFO]: PHASE-7 DONE !")

In [None]:
def updatePairConstraint(predecessor, target):
    """Subtract 2 from the Q entry of every ordered pair of distinct enzyme variables.

    One variable name "<enzyme>_<predecessor>_<target>" is built per entry of
    ``allPossibleEnzymes`` (enzyme names stripped). For every two different
    positions in that list, Q[(first, second)] is decreased by 2.
    """
    variables = [
        f"{enzyme.strip()}_{predecessor}_{target}"
        for enzyme in allPossibleEnzymes
    ]
    for position, first in enumerate(variables):
        # Every other position, in ascending order, skipping `position` itself.
        for second in variables[:position] + variables[position + 1:]:
            Q[(first, second)] -= 2  # EXPERIMENTAL::Encourage at least one enzyme to catalyze each (predecessor, target)
# Apply the experimental pair constraint to every known (predecessor, target)
# pair. The progress message is printed only after the loop has actually run.
for predecessor, target in allPossiblePairs:
    updatePairConstraint(predecessor, target)
print("[INFO]: PHASE-8 DONE !")

# Row / column indices for the pair-by-enzyme adjacency matrix.
pairToIndex = {pair: i for i, pair in enumerate(allPossiblePairs)}
enzymeToIndex = {enzyme: j for j, enzyme in enumerate(allPossibleEnzymes)}
adjacencyMatrix = np.zeros((len(allPossiblePairs), len(allPossibleEnzymes)))
def updateAdjacencyMatrix(pair, i):
    """Fill row ``i`` of the module-level ``adjacencyMatrix`` for one pair.

    Column j is set to 1 for every enzyme in ``enzymeToIndex`` whose variable
    "<enzyme>_<pair[0]>_<pair[1]>" has a diagonal entry in ``Q``. Other cells
    are left untouched.
    """
    for enzymeName, column in enzymeToIndex.items():
        variable = f"{enzymeName.strip()}_{pair[0]}_{pair[1]}"
        if (variable, variable) not in Q:
            continue
        adjacencyMatrix[i, column] = 1
# Mark, for every pair row, which enzymes have a diagonal entry in Q. The
# progress message is printed only after the matrix has actually been filled.
for pair, i in pairToIndex.items():
    updateAdjacencyMatrix(pair, i)
print("[INFO]: PHASE-9 DONE !")
def refineConstraints(adjacencyMatrix):
    """Scale selected diagonal entries of ``Q`` by 1.5.

    An enzyme counts as strongly connected when its column sum in
    ``adjacencyMatrix`` exceeds 1, a pair when its row sum exceeds 1. Every
    non-zero diagonal entry Q[(v, v)] whose enzyme or pair is strongly
    connected is replaced by 1.5 times its value. The factor is applied once
    even when both conditions hold. Entries that are missing or 0 are skipped.
    """
    enzymeIsStrong = np.sum(adjacencyMatrix, axis=0) > 1
    pairIsStrong = np.sum(adjacencyMatrix, axis=1) > 1
    for pair, rowIndex in pairToIndex.items():
        for enzyme, columnIndex in enzymeToIndex.items():
            variable = f"{enzyme.strip()}_{pair[0]}_{pair[1]}"
            diagonalKey = (variable, variable)
            current = Q.get(diagonalKey, 0)  # .get avoids inserting new keys
            if not current:
                continue
            if enzymeIsStrong[columnIndex] or pairIsStrong[rowIndex]:
                Q[diagonalKey] = current * 1.5  # Increase by 50%

# Scale the diagonal Q entries of strongly connected enzymes / pairs in place.
refineConstraints(adjacencyMatrix)
print("[INFO]: PHASE-10 DONE !")

In [None]:
%%time
def _buildPairQuboMatrix(pairs, enzymes, qubo):
    """Collapse the enzyme-level QUBO dict into a dense pair-by-pair matrix.

    ``qubo`` is keyed by tuples of "<enzyme>_<predecessor>_<target>" strings.
    Entry [i, j] of the result is the sum of all coefficients whose first
    variable belongs to ``pairs[i]`` and whose second belongs to ``pairs[j]``.
    Keys that do not correspond to a known pair/enzyme combination are ignored.
    """
    variableToPairIndex = {
        f"{enzyme.strip()}_{pair[0]}_{pair[1]}": index
        for index, pair in enumerate(pairs)
        for enzyme in enzymes
    }
    matrix = np.zeros((len(pairs), len(pairs)), dtype=np.float32)
    for (rowVariable, columnVariable), coefficient in qubo.items():
        rowIndex = variableToPairIndex.get(rowVariable)
        columnIndex = variableToPairIndex.get(columnVariable)
        if rowIndex is not None and columnIndex is not None:
            matrix[rowIndex, columnIndex] += coefficient
    return matrix


# Freeze one ordering of the pairs; index i below always means allPossiblePairsList[i].
allPossiblePairsList = list(allPossiblePairs)
n = len(allPossiblePairsList)

# Q is keyed by "<enzyme>_<predecessor>_<target>" strings, so looking it up
# with (predecessor, target) tuples always returned 0 and produced an all-zero
# matrix. The optimisation keeps one variable per pair, so the enzyme-level
# coefficients are summed per pair instead.
Q_np = _buildPairQuboMatrix(allPossiblePairsList, allPossibleEnzymes, Q)
Q_tpu = torch.tensor(Q_np, dtype=torch.float32, device=xm.xla_device())

# Seeded CPU generator so the starting point is the same on every run.
x = (
    torch.rand(n, generator=torch.Generator().manual_seed(0))
    .to(xm.xla_device())
    .requires_grad_(True)
)
optimizer = torch.optim.SGD([x], lr=0.01)
def roundTensor(tensor):
    """Round every element of ``tensor`` to the nearest integer."""
    return torch.round(tensor)
def objectiveFunction(x):
    """Return the quadratic form x^T Q x as a one-element tensor."""
    return torch.matmul(torch.matmul(x, Q_tpu), x.reshape(-1, 1))
numEpochs = 10000
for epoch in range(numEpochs):
    optimizer.zero_grad()
    loss = objectiveFunction(x)
    loss.backward()
    optimizer.step()
    with torch.no_grad():
        # Keep the relaxed variables inside [0, 1]. Without bounds the
        # objective can be driven to -inf, and rounding after every step would
        # discard any update smaller than 0.5.
        x.clamp_(0.0, 1.0)
    # Execute the pending XLA graph each step instead of letting it grow.
    xm.mark_step()

# Round once, at the end, to obtain the binary assignment.
result = roundTensor(x.detach()).cpu().numpy()
# Built from the same list that indexes Q_np so labels and indices cannot drift.
reverseIndexMapping = {i: '_'.join(map(str, pair)) for i, pair in enumerate(allPossiblePairsList)}
optimalPairs = [reverseIndexMapping[i] for i, value in enumerate(result) if value == 1]

In [None]:
# Map predecessor SMILES -> target SMILES for every selected pair.
optimalTransformations = {}
for pair in optimalPairs:
    print(f"Pair: {pair}")
    predecessor, target = pair.split('_')
    # df1['PubChem_CID'] is a float column and the labels may look like "123"
    # or "123.0"; int("123.0") raises ValueError, so both sides are parsed as
    # floats.
    predQuery = df1.loc[df1['PubChem_CID'] == float(predecessor), 'SMILES']
    targetQuery = df1.loc[df1['PubChem_CID'] == float(target), 'SMILES']
    if predQuery.empty or targetQuery.empty:
        # Skip pairs without a SMILES on either side. Storing a 'Not Found'
        # placeholder would merge all unknown predecessors under one key and
        # hand an unparsable string to the drawing code.
        continue
    optimalTransformations[predQuery.iloc[0]] = targetQuery.iloc[0]

In [None]:
def Visualize(predecessorSmiles):
    """Draw a predecessor molecule and its predicted target molecule.

    The target is looked up in the module-level ``optimalTransformations``
    dict. Nothing is drawn, and a message is printed instead, when
      * the predecessor has no entry,
      * the stored target is a 'Not Found' placeholder or not a string, or
      * RDKit cannot parse either SMILES string.

    Returns None in every case; the figures are shown with matplotlib.
    """
    targetSmiles = optimalTransformations.get(predecessorSmiles, "Not found")
    if targetSmiles == "Not found":
        print("Predecessor SMILES not found in the optimal transformations.")
        return
    # The lookup cell stores 'Not Found' (capital F) for a missing target, and
    # a missing SMILES read from a CSV may be NaN rather than a string.
    if not isinstance(targetSmiles, str) or targetSmiles == "Not Found":
        print("No target SMILES is available for this predecessor.")
        return

    # MolFromSmiles returns None for strings it cannot parse; check both
    # molecules before drawing anything.
    predMol = Chem.MolFromSmiles(predecessorSmiles)
    targetMol = Chem.MolFromSmiles(targetSmiles)
    if predMol is None or targetMol is None:
        print("Could not parse the predecessor or target SMILES; nothing to draw.")
        return

    for mol, title in ((predMol, "Predecessor Molecule"),
                       (targetMol, "Predicted Target Molecule")):
        Draw.MolToMPL(mol, size=(300, 300), kekulize=True)
        plt.title(title)
        plt.axis("off")
        plt.show()

# Example input: one predecessor SMILES; Visualize looks up its predicted
# target in optimalTransformations.
predecessorSmiles = "C[C@]12CCC(=O)C=C1CC[C@H]1[C@@H]3CC[C@@H]([C@@]3(C)CC[C@H]21)O" # EXAMPLE
Visualize(predecessorSmiles)