<a href="https://colab.research.google.com/github/rsaran-BioAI/AGILE/blob/main/ConnectionAware_VAE.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
# Mount Google Drive so the project sources and data under /content/drive are reachable.
from google.colab import drive

DRIVE_MOUNT_POINT = "/content/drive"
drive.mount(DRIVE_MOUNT_POINT)

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [2]:
#%%bash
#MINICONDA_INSTALLER_SCRIPT=Miniconda3-latest-Linux-x86_64.sh
#MINICONDA_PREFIX=/usr/local
#wget https://repo.continuum.io/miniconda/$MINICONDA_INSTALLER_SCRIPT
#chmod +x $MINICONDA_INSTALLER_SCRIPT
#./$MINICONDA_INSTALLER_SCRIPT -b -f -p $MINICONDA_PREFIX

Process is interrupted.


In [3]:
!pip install rdkit



In [4]:
!pip install torch_geometric



In [5]:
!pip install torch



In [6]:
!pip install guacamol



In [7]:
!pip install tensorboardX



In [8]:
!pip install networkx



## Merging Operation Learning

In [9]:
%cd /content/drive/MyDrive/AGILE2/AI4Sci-MiCaM/src/

/content/drive/MyDrive/AGILE2/AI4Sci-MiCaM/src


In [10]:
import multiprocessing as mp
import os
from collections import defaultdict
from dataclasses import dataclass
from datetime import datetime
from multiprocessing import Process, Queue
from typing import Dict, List, Tuple

import networkx as nx
from rdkit import Chem

In [11]:
import arguments
import model.mydataclass
from arguments import parse_arguments
from model.mydataclass import Paths

In [12]:
@dataclass
class MolGraph:
    """A molecule plus the contractible graph used for merging-operation learning.

    ``merging_graph`` starts as the atom adjacency graph of ``mol_graph``. Every node carries an
    ``"atom_indices"`` set holding the original atom indices it currently covers. Applying a
    merging operation contracts every edge whose two nodes jointly form the chosen motif.
    """

    idx: int  # 0-based line number in the training file; the molecule's key in `indices`
    mol_graph: Chem.rdchem.Mol  # RDKit molecule parsed from the input SMILES
    merging_graph: nx.Graph  # contracted graph; node attribute "atom_indices" is a set of atom ids

    def __init__(self, smiles: str, idx: int = 0) -> None:
        """Parse ``smiles`` and build the initial one-atom-per-node merging graph.

        Raises:
            ValueError: if RDKit cannot parse ``smiles``.
        """
        self.idx = idx
        self.mol_graph = Chem.MolFromSmiles(smiles)
        if self.mol_graph is None:
            # MolFromSmiles reports a parse failure by returning None. Fail here with a clear
            # message instead of deep inside GetAdjacencyMatrix.
            raise ValueError(f"Could not parse SMILES at index {idx}: {smiles!r}")
        self.merging_graph = nx.Graph(Chem.rdmolops.GetAdjacencyMatrix(self.mol_graph))
        for atom in self.mol_graph.GetAtoms():
            self.merging_graph.nodes[atom.GetIdx()]["atom_indices"] = {atom.GetIdx()}

    def _contract_motif_edges(self, motif: str, on_merge) -> nx.Graph:
        """Return a copy of ``merging_graph`` with every edge that forms ``motif`` contracted.

        Edges are visited in the order of the pre-merge graph. An edge that no longer exists
        (one endpoint was already merged away) is skipped. After each contraction,
        ``on_merge(graph_before_merge, new_graph, node1, node2)`` is called so the caller can
        record the resulting motif-count changes.
        """
        new_graph = self.merging_graph.copy()
        for (node1, node2) in self.merging_graph.edges:
            if not new_graph.has_edge(node1, node2):
                continue
            atom_indices = new_graph.nodes[node1]["atom_indices"].union(new_graph.nodes[node2]["atom_indices"])
            if fragment2smiles(self, atom_indices) == motif:
                graph_before_merge = new_graph.copy()
                merge_nodes(new_graph, node1, node2)
                on_merge(graph_before_merge, new_graph, node1, node2)
        return new_graph

    def apply_merging_operation(self, motif: str, stats: Dict[str, int], indices: Dict[str, Dict[int, int]]) -> None:
        """Contract every occurrence of ``motif`` in place (single-process variant).

        ``stats`` and ``indices`` are updated directly for every contraction. Afterwards this
        molecule's count for ``motif`` is reset to 0. A single-node graph has nothing to merge.
        """
        if self.merging_graph.number_of_nodes() == 1:
            return
        self.merging_graph = self._contract_motif_edges(
            motif,
            lambda before, after, node1, node2: update_stats(
                self, before, after, node1, node2, stats, indices, self.idx
            ),
        )
        indices[motif][self.idx] = 0

    def apply_merging_operation_producer(self, motif: str, q: Queue) -> None:
        """Contract every occurrence of ``motif`` and report the results through ``q``.

        Worker-process variant. Count deltas are sent as ``(motif_smiles, idx, +/-1)``
        messages, followed by ``(motif, idx, new_graph)``. ``self.merging_graph`` is left
        untouched here; the consumer installs the new graph.
        """
        if self.merging_graph.number_of_nodes() == 1:
            return
        new_graph = self._contract_motif_edges(
            motif,
            lambda before, after, node1, node2: update_stats_producer(
                self, before, after, node1, node2, q, self.idx
            ),
        )
        q.put((motif, self.idx, new_graph))

def load_batch_mols(batch: List[Tuple[int, str]]) -> List[MolGraph]:
    """Build one MolGraph per ``(idx, smiles)`` pair of a worker batch, keeping batch order."""
    # NOTE(review): `MolGraph` is resolved at call time. A later cell rebinds that name via
    # `from model.mol_graph import MolGraph`, so re-run the class cell before reusing this.
    molecules = []
    for idx, smiles in batch:
        molecules.append(MolGraph(smiles, idx))
    return molecules

def load_mols(train_path: str, num_workers: int) -> List["MolGraph"]:
    """Read one SMILES per line from ``train_path`` and build MolGraphs in a process pool.

    A molecule's ``idx`` is its 0-based line number. ``Pool.imap`` yields batches in order, so
    the returned list is in file order and ``mols[i].idx == i``.

    Args:
        train_path: text file with one SMILES per line.
        num_workers: pool size; the file is split into at most this many contiguous batches.

    Returns:
        The parsed molecules, or an empty list for an empty file.
    """
    print(f"[{datetime.now()}] Loading molecules...")
    # Use a context manager so the file handle is closed deterministically.
    with open(train_path) as f:
        smiles_list = [(i, line.strip("\n")) for i, line in enumerate(f)]

    mols: List["MolGraph"] = []
    if smiles_list:  # an empty file would otherwise produce a zero slice step below
        # Ceiling division: at most `num_workers` batches.
        batch_size = (len(smiles_list) - 1) // num_workers + 1
        batches = [smiles_list[i : i + batch_size] for i in range(0, len(smiles_list), batch_size)]
        with mp.Pool(num_workers) as pool:
            for mols_batch in pool.imap(load_batch_mols, batches):
                mols.extend(mols_batch)

    print(f"[{datetime.now()}] Loading molecules finished. Total: {len(mols)} molecules.\n")
    return mols

def fragment2smiles(mol: MolGraph, indices: List[int]) -> str:
    """SMILES of the sub-fragment of ``mol.mol_graph`` covering the given atom indices.

    Callers pass sets of atom ids; any iterable of ints works, since it is converted to a
    tuple. The fragment SMILES is re-parsed without sanitization and written back out with
    ``Chem.MolToSmiles``, so the returned string is RDKit's output form for that fragment.
    """
    atoms_to_use = tuple(indices)
    raw_fragment = Chem.MolFragmentToSmiles(mol.mol_graph, atoms_to_use)
    reparsed = Chem.MolFromSmiles(raw_fragment, sanitize=False)
    return Chem.MolToSmiles(reparsed)

def merge_nodes(graph: nx.Graph, node1: int, node2: int) -> None:
    """Contract ``node2`` into ``node1`` in place.

    ``node1`` becomes adjacent to every other neighbour of ``node2`` and takes the union of
    both nodes' ``"atom_indices"``. ``node2`` is then removed, along with its incident edges.
    """
    merged_atoms = graph.nodes[node1]["atom_indices"].union(graph.nodes[node2]["atom_indices"])
    # Materialize the neighbour list first: the graph is mutated while we walk it.
    for other in list(graph.neighbors(node2)):
        if other == node1 or graph.has_edge(node1, other):
            continue
        graph.add_edge(node1, other)
    # remove_node also drops every edge still incident to node2.
    graph.remove_node(node2)
    graph.nodes[node1]["atom_indices"] = merged_atoms

def get_stats_producer(batch: List[MolGraph], q: Queue):
    """Worker: emit ``(mol.idx, motif_smiles)`` for every merging-graph edge, then a None sentinel."""
    for mol in batch:
        graph = mol.merging_graph
        for left, right in graph.edges:
            merged = graph.nodes[left]["atom_indices"].union(graph.nodes[right]["atom_indices"])
            q.put((mol.idx, fragment2smiles(mol, merged)))
    q.put(None)  # tells get_stats_consumer this producer is done

def get_stats_consumer(stats: Dict[str, int], indices: Dict[str, Dict[int, int]], q: Queue, num_workers: int):
    """Drain ``(mol_idx, motif_smiles)`` messages from ``q`` into the frequency tables.

    Each producer ends its stream with a ``None`` sentinel. The function returns once
    ``num_workers`` sentinels have been received.

    Args:
        stats: motif -> total count, mutated in place. Must default missing keys to 0
            (get_stats passes a ``defaultdict(int)``).
        indices: motif -> {mol_idx: count}, mutated in place. Must default missing keys
            (get_stats passes nested defaultdicts).
        q: queue shared with the producers.
        num_workers: number of producers, i.e. how many sentinels to wait for.
    """
    finished = 0
    while finished < num_workers:
        info = q.get()
        if info is None:
            finished += 1
            continue
        idx, smi = info
        stats[smi] += 1
        indices[smi][idx] += 1

def get_stats(mols: List["MolGraph"], num_workers: int) -> Tuple[Dict[str, int], Dict[str, Dict[int, int]]]:
    """Count the motif formed by every merging-graph edge across all molecules.

    Args:
        mols: molecules whose ``merging_graph`` edges are scanned.
        num_workers: number of worker processes. 1 (or less) runs serially in this process.

    Returns:
        ``(stats, indices)``. ``stats[motif]`` is the number of edges whose merged fragment
        yields ``motif``, and ``indices[motif][mol.idx]`` is the per-molecule count. Both are
        defaultdicts so later updates can touch unseen keys.
    """
    print(f"[{datetime.now()}] Begin getting statistics.")
    stats = defaultdict(int)
    indices = defaultdict(lambda: defaultdict(int))

    if num_workers <= 1 or not mols:
        for mol in mols:
            graph = mol.merging_graph
            for (node1, node2) in graph.edges:
                atom_indices = graph.nodes[node1]["atom_indices"].union(graph.nodes[node2]["atom_indices"])
                motif_smiles = fragment2smiles(mol, atom_indices)
                stats[motif_smiles] += 1
                indices[motif_smiles][mol.idx] += 1
        return stats, indices

    batch_size = (len(mols) - 1) // num_workers + 1
    batches = [mols[i : i + batch_size] for i in range(0, len(mols), batch_size)]
    q = Queue()
    # One producer per batch. Ceiling division can give fewer batches than workers
    # (e.g. 100 mols / 60 workers -> 50 batches), so indexing batches[0..num_workers) would fail.
    producers = [Process(target=get_stats_producer, args=(batch, q)) for batch in batches]
    for p in producers:
        p.start()
    # Drain the queue before joining, and wait for exactly one sentinel per producer.
    get_stats_consumer(stats, indices, q, len(producers))
    for p in producers:
        p.join()
    return stats, indices

def _merge_stat_deltas(mol: MolGraph, graph: nx.Graph, new_graph: nx.Graph, node1: int, node2: int):
    """Yield ``(motif_smiles, delta)`` count changes caused by merging ``node2`` into ``node1``.

    First yields -1 for every pair ``(node1, n)`` and ``(node2, n)`` in the pre-merge ``graph``
    (the merged pair itself excluded). Then yields +1 for every pair ``(node1, n)`` in the
    post-merge ``new_graph``.
    """
    for center, partner in ((node1, node2), (node2, node1)):
        for n in list(graph.neighbors(center)):
            if n == partner:
                continue
            atom_indices = graph.nodes[center]["atom_indices"].union(graph.nodes[n]["atom_indices"])
            yield fragment2smiles(mol, atom_indices), -1
    for n in list(new_graph.neighbors(node1)):
        atom_indices = new_graph.nodes[node1]["atom_indices"].union(new_graph.nodes[n]["atom_indices"])
        yield fragment2smiles(mol, atom_indices), 1

def update_stats(mol: MolGraph, graph: nx.Graph, new_graph: nx.Graph, node1: int, node2: int, stats: Dict[str, int], indices: Dict[str, Dict[int, int]], i: int):
    """Apply the motif-count changes of one contraction directly to ``stats``/``indices``.

    ``graph`` is the merging graph before the contraction and ``new_graph`` is the graph after
    it. ``i`` is the molecule's index key in ``indices``.
    """
    for motif_smiles, delta in _merge_stat_deltas(mol, graph, new_graph, node1, node2):
        stats[motif_smiles] += delta
        indices[motif_smiles][i] += delta

def update_stats_producer(mol: MolGraph, graph: nx.Graph, new_graph: nx.Graph, node1: int, node2: int, q: Queue, i: int):
    """Send the motif-count changes of one contraction as ``(motif_smiles, i, delta)`` messages.

    The multiprocessing counterpart of ``update_stats``; apply_merging_operation_consumer
    folds these messages back into the shared tables.
    """
    for motif_smiles, delta in _merge_stat_deltas(mol, graph, new_graph, node1, node2):
        q.put((motif_smiles, i, delta))

def apply_merging_operation_producer(motif: str, batch: List[MolGraph], q: Queue):
    """Worker entry point: merge ``motif`` in every molecule of ``batch``, then send a None sentinel."""
    for mol in batch:
        mol.apply_merging_operation_producer(motif, q)
    q.put(None)  # tells apply_merging_operation_consumer this worker is done

def apply_merging_operation_consumer(mols: List["MolGraph"], stats: Dict[str, int], indices: Dict[str, Dict[int, int]], q: Queue, num_workers: int):
    """Fold worker results for one merging operation back into the main process.

    Messages are one of:
      * ``(motif_smiles, i, delta)`` with an int ``delta``, a count change for molecule ``i``;
      * ``(motif, i, new_graph)``, the contracted graph for molecule ``i``, which is installed
        and resets that molecule's count for ``motif`` to 0;
      * ``None``, a producer's end-of-stream sentinel.
    Returns after ``num_workers`` sentinels.

    Args:
        mols: molecules indexed by position. Assumes ``mols[i].idx == i``, as built by load_mols.
        stats, indices: frequency tables, mutated in place (defaultdicts from get_stats).
        q: queue shared with the producers.
        num_workers: number of producers / sentinels to wait for.
    """
    finished = 0
    while finished < num_workers:
        info = q.get()
        if info is None:
            finished += 1
            continue
        motif, i, change = info
        if isinstance(change, int):
            stats[motif] += change
            indices[motif][i] += change
        else:
            assert isinstance(change, nx.Graph)
            indices[motif][i] = 0
            mols[i].merging_graph = change

def apply_merging_operation(
    motif: str,
    mols: List["MolGraph"],
    stats: Dict[str, int],
    indices: Dict[str, Dict[int, int]],
    num_workers: int = 1,
):
    """Contract every occurrence of ``motif`` in all molecules that currently contain it.

    Only molecules with a positive count in ``indices[motif]`` are touched. With
    ``num_workers > 1`` and more than one such molecule, the work is split across processes
    and their messages are folded back by apply_merging_operation_consumer. Otherwise each
    molecule is processed serially in place. Finally ``stats[motif]`` is reset to 0.

    Args:
        motif: SMILES of the motif to merge.
        mols: all molecules, indexed by position. Assumes ``mols[i].idx == i``.
        stats, indices: frequency tables from get_stats, mutated in place.
        num_workers: maximum number of worker processes.
    """
    mols_to_process = [mols[i] for i, freq in indices[motif].items() if freq > 0]
    if num_workers > 1 and len(mols_to_process) > 1:
        batch_size = (len(mols_to_process) - 1) // num_workers + 1
        batches = [mols_to_process[i : i + batch_size] for i in range(0, len(mols_to_process), batch_size)]
        q = Queue()
        # One producer per batch. There can be fewer batches than workers, and the
        # consumer must wait for exactly one sentinel per producer.
        producers = [Process(target=apply_merging_operation_producer, args=(motif, batch, q)) for batch in batches]
        for p in producers:
            p.start()
        apply_merging_operation_consumer(mols, stats, indices, q, len(producers))
        for p in producers:
            p.join()
    else:
        for mol in mols_to_process:
            mol.apply_merging_operation(motif, stats, indices)
    stats[motif] = 0

def merging_operation_learning(
    train_path: str,
    operation_path: str,
    num_iters: int,
    min_frequency: int,
    num_workers: int,
    mp_threshold: int,
):
    """Greedily learn up to ``num_iters`` merging operations (motifs) from a SMILES file.

    Each iteration picks the most frequent edge motif (ties go to the larger SMILES string),
    contracts all of its occurrences, and appends it to ``operation_path``.

    Args:
        train_path: training SMILES file, one molecule per line.
        operation_path: output file, one motif SMILES per line in learning order.
            Its parent directory is created if needed.
        num_iters: maximum number of operations to learn.
        min_frequency: stop once the best motif occurs fewer times than this. Values below 1
            are treated as 1, because a motif with count 0 has nothing left to merge.
        num_workers: worker processes used for loading, statistics and merging.
        mp_threshold: a merge uses multiprocessing only if the motif's frequency is at least this.

    Returns:
        List of ``(motif, frequency)`` pairs in the order they were learned.
    """
    print(f"[{datetime.now()}] Learning merging operations from {train_path}.")
    print(f"Number of workers: {num_workers}. Total number of CPUs: {mp.cpu_count()}.\n")

    mols = load_mols(train_path, num_workers)
    stats, indices = get_stats(mols, num_workers)

    # After merging, a motif's count is reset to 0. Without a floor of 1, min_frequency <= 0
    # would keep re-selecting an exhausted motif (ties are broken by name) and write duplicates.
    stop_below = max(min_frequency, 1)

    trace = []
    output_dir = os.path.dirname(operation_path)
    if output_dir:  # os.makedirs("") raises for a bare file name
        os.makedirs(output_dir, exist_ok=True)
    with open(operation_path, "w") as output:
        for i in range(num_iters):
            print(f"[{datetime.now()}] Iteration {i}.")
            # max() raises on an empty dict (e.g. every molecule is a single atom).
            motif = max(stats, key=lambda x: (stats[x], x)) if stats else None
            if motif is None or stats[motif] < stop_below:
                print(f"No motif has frequency >= {stop_below}. Stopping.\n")
                break
            print(f"[Iteration {i}] Most frequent motif: {motif}, frequency: {stats[motif]}.\n")
            trace.append((motif, stats[motif]))

            apply_merging_operation(
                motif = motif,
                mols = mols,
                stats = stats,
                indices = indices,
                num_workers = num_workers if stats[motif] >= mp_threshold else 1,
            )

            output.write(f"{motif}\n")

    print(f"[{datetime.now()}] Merging operation learning finished.")
    print(f"The merging operations are in {operation_path}.\n\n")

    return trace


In [13]:
import argparse
from arguments import parse_arguments

In [14]:
from model.mydataclass import Paths

In [15]:
# Sanity check: confirm that arguments.py imports cleanly and print the parsed defaults.
import argparse # Just checking that the arguments.py file is imported and working
from arguments import parse_arguments
args = parse_arguments()  # binds the notebook-global `args`
print(args)

Namespace(data_dir='/content/drive/MyDrive/AGILE2/AI4Sci-MiCaM/data/', preprocess_dir='/content/drive/MyDrive/AGILE2/AI4Sci-MiCaM/preprocess/', output_dir='/content/drive/MyDrive/AGILE2/AI4Sci-MiCaM/output/', tensorboard_dir='/content/drive/MyDrive/AGILE2/AI4Sci-MiCaM/tensorboard/', dataset='QM9/', job_name='', model_dir=None, generate_path='samples', num_workers=60, cuda=0, seed=2, num_operations=1000, num_iters=3000, min_frequency=0, mp_thd=100000.0, hidden_size=256, atom_embed_size=[192, 16, 16, 16, 16], edge_embed_size=256, motif_embed_size=[256, 256], latent_size=256, depth=15, motif_depth=6, dropout=0.3, virtual=False, pooling='add', steps=50000, batch_size=128, lr=100000.0, lr_anneal_iter=500, lr_anneal_rate=0.99, grad_clip_norm=1.0, beta_warmup=3000, beta_min=0.001, beta_max=0.6, beta_anneal_period=20000, prop_weight=0.5, num_sample=10000)


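The following default hyperparameters in `arguments.py` were changed. Note that the `Namespace` printed above still shows the earlier values (e.g. `num_iters=3000`, `steps=50000`, `lr=100000.0`), so re-run that cell to confirm the new defaults are in effect: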

parser.add_argument('--num_workers', type=int, default=60)  # Reduced number of workers

parser.add_argument('--num_iters', type=int, default=1000)  # Fewer iterations (previously 3000)  

parser.add_argument('--steps', type=int, default=10000)  # Fewer training steps (previously 50000)

parser.add_argument('--batch_size', type=int, default=32)  # Smaller batch size (previously 128); attempting this change raised an error

parser.add_argument('--dropout', type=float, default=0.2)  # Slightly lower dropout (previously 0.3)

parser.add_argument('--lr', type=float, default=1e-3)  # Lower learning rate (previously 1e5)

In [15]:
if __name__ == "__main__":
    # Re-parse the arguments, derive the dataset paths, and learn the merging operations.
    # `learning_trace` keeps the (motif, frequency) pairs for later inspection.
    args = parse_arguments()
    paths = Paths(args)
    learning_trace = merging_operation_learning(
        train_path=paths.train_path,
        operation_path=paths.operation_path,
        num_iters=args.num_iters,
        min_frequency=args.min_frequency,
        num_workers=args.num_workers,
        mp_threshold=args.mp_thd,
    )

[1;30;43mStreaming output truncated to the last 5000 lines.[0m

[2023-12-07 17:33:41.168302] Iteration 1335.
[Iteration 1335] Most frequent motif: CCCCCCCC/C=C\CCCCCCCN(C(=O)C(CCCCCC)CCCCCCCC)C(CC1CCCN1CC)C(=O)NC(C)(C)C, frequency: 1.

[2023-12-07 17:33:41.170109] Iteration 1336.
[Iteration 1336] Most frequent motif: CCCCCCCC/C=C\CCCCCCCN(C(=O)C(CCCCCC)CCCCCCCC)C(C(=O)NCCN1CCOCC1)c1cc(C)nn1C, frequency: 1.

[2023-12-07 17:33:41.172286] Iteration 1337.
[Iteration 1337] Most frequent motif: CCCCCCCC/C=C\CCCCCCCN(C(=O)C(CCCCCC)CCCCCCCC)C(C(=O)NCCN1CCOCC1)N1CCOCC1, frequency: 1.

[2023-12-07 17:33:41.174500] Iteration 1338.
[Iteration 1338] Most frequent motif: CCCCCCCC/C=C\CCCCCCCN(C(=O)C(CCCCCC)CCCCCCCC)C(C(=O)NCCN1CCOCC1)N1CCCCC1, frequency: 1.

[2023-12-07 17:33:41.177448] Iteration 1339.
[Iteration 1339] Most frequent motif: CCCCCCCC/C=C\CCCCCCCN(C(=O)C(CCCCCC)CCCCCCCC)C(C(=O)NCCN1CCOCC1)N(C)C, frequency: 1.

[2023-12-07 17:33:41.180705] Iteration 1340.
[Iteration 1340] Most frequen

# Motif Vocab Construction

In [1]:
pwd # just checking

'/content'

In [16]:
from rdkit import Contrib
from rdkit.Contrib import SA_Score
from rdkit.Contrib.SA_Score import sascorer

In [17]:
import multiprocessing as mp
import os
import os.path as path
import pickle
from collections import Counter
from datetime import datetime
from functools import partial
from typing import List, Tuple

from tqdm import tqdm

from arguments import parse_arguments
from model.mol_graph import MolGraph
from model.mydataclass import Paths

In [18]:
def apply_operations(batch: List[Tuple[int, str]], mols_pkl_dir: str) -> Counter:
    """Pool worker: tokenize each molecule into motifs, pickle it, and count its motifs.

    Args:
        batch: ``(idx, smiles)`` pairs. ``idx`` names the pickle file ``{idx}.pkl``.
        mols_pkl_dir: directory that receives one pickle per molecule.

    Returns:
        Counter over the entries of every molecule's ``mol.motifs``.
    """
    vocab = Counter()
    # Each pool worker gets its own tqdm row. NOTE(review): `_identity` is a private
    # multiprocessing attribute, presumably 1-based in Pool workers (hence `pos - 1`) and
    # empty outside a Pool.
    pos = mp.current_process()._identity[0]
    with tqdm(total=len(batch), desc=f"Processing {pos}", position=pos - 1, ncols=80, leave=False) as pbar:
        for idx, smi in batch:
            mol = MolGraph(smi, tokenizer="motif")
            with open(path.join(mols_pkl_dir, f"{idx}.pkl"), "wb") as f:
                pickle.dump(mol, f)
            # Update in place. `vocab + Counter(...)` rebuilds the whole running Counter
            # for every molecule. Assumes mol.motifs is an iterable of hashable motif keys.
            vocab.update(mol.motifs)
            pbar.update()
    return vocab

def motif_vocab_construction(
    train_path: str,
    vocab_path: str,
    operation_path: str,
    num_operations: int,
    num_workers: int,
    mols_pkl_dir: str,
):
    """Tokenize every training molecule into motifs, pickle it, and write the motif vocabulary.

    Every vocabulary key is an ``(x, y)`` pair taken from ``mol.motifs``. Entries whose ``x`` is
    not in ``MolGraph.OPERATIONS`` are written first (sorted by ``x``), followed by those whose
    ``x`` is an operation, in ``MolGraph.OPERATIONS`` order. Entries sharing the same ``x`` keep
    their first-seen order.

    Args:
        train_path: training SMILES file, one molecule per line (line number = molecule idx).
        vocab_path: output text file, one ``"x y"`` line per vocabulary entry.
        operation_path: merging-operation file passed to ``MolGraph.load_operations``.
        num_operations: how many operations to load from ``operation_path``.
        num_workers: pool size, and the maximum number of batches.
        mols_pkl_dir: directory receiving one ``{idx}.pkl`` per molecule (created if missing).
    """
    print(f"[{datetime.now()}] Construcing motif vocabulary from {train_path}.")
    print(f"Number of workers: {num_workers}. Total number of CPUs: {mp.cpu_count()}.")

    with open(train_path) as f:
        data_set = [(idx, smi.strip("\n")) for idx, smi in enumerate(f)]
    # Ceiling division into at most `num_workers` batches. max(1, ...) keeps the slice
    # step positive when the training file is empty.
    batch_size = max(1, (len(data_set) - 1) // num_workers + 1)
    batches = [data_set[i : i + batch_size] for i in range(0, len(data_set), batch_size)]
    print(f"Total: {len(data_set)} molecules.\n")

    print(f"Processing...")
    vocab = Counter()
    os.makedirs(mols_pkl_dir, exist_ok=True)
    MolGraph.load_operations(operation_path, num_operations)
    func = partial(apply_operations, mols_pkl_dir=mols_pkl_dir)
    with mp.Pool(num_workers, initializer=tqdm.set_lock, initargs=(mp.RLock(),)) as pool:
        for batch_vocab in pool.imap(func, batches):
            # In place: `vocab + batch_vocab` would copy the running total every batch.
            vocab.update(batch_vocab)

    operations = MolGraph.OPERATIONS
    operation_set = set(operations)
    # Deduplicate: the same x appears once per distinct y among the vocab keys.
    atom_list = sorted({x for (x, _) in vocab if x not in operation_set})
    full_list = atom_list + operations
    index_dict = {token: rank for rank, token in enumerate(full_list)}
    new_vocab = []
    for (x, y), value in vocab.items():
        assert x in index_dict
        new_vocab.append((x, y, value))

    sorted_vocab = sorted(new_vocab, key=lambda entry: index_dict[entry[0]])
    with open(vocab_path, "w") as f:
        for (x, y, _) in sorted_vocab:
            f.write(f"{x} {y}\n")

    print(f"\r[{datetime.now()}] Motif vocabulary construction finished.")
    print(f"The motif vocabulary is in {vocab_path}.\n\n")


In [19]:
if __name__ == "__main__":
    # Re-parse the arguments and make sure the preprocessing directory exists. Then build
    # the motif vocabulary (and the per-molecule pickles) from the learned operations.
    args = parse_arguments()
    paths = Paths(args)
    os.makedirs(paths.preprocess_dir, exist_ok=True)
    motif_vocab_construction(
        train_path=paths.train_path,
        vocab_path=paths.vocab_path,
        operation_path=paths.operation_path,
        num_operations=args.num_operations,
        mols_pkl_dir=paths.mols_pkl_dir,
        num_workers=args.num_workers,
    )

[2023-12-07 17:52:53.627021] Construcing motif vocabulary from /content/drive/MyDrive/AGILE2/AI4Sci-MiCaM/data/QM9/train.smiles.
Number of workers: 60. Total number of CPUs: 2.
Total: 1920 molecules.

Processing...


[1;30;43mStreaming output truncated to the last 5000 lines.[0m



Processing 16:  34%|█████████▋                  | 11/32 [05:47<10:18, 29.43s/it][A[A[A[A[A[A[A[A[A[A[A[A[A[A[A


Processing 4:  22%|██████▌                       | 7/32 [05:59<19:33, 46.93s/it][A[A[A















Processing 17:  34%|█████████▋                  | 11/32 [06:02<11:40, 33.34s/it][A[A[A[A[A[A[A[A[A[A[A[A[A[A[A[A






Processing 8:  12%|███▊                          | 4/32 [05:58<40:02, 85.79s/it][A[A[A[A[A[A[A

Processing 3:  22%|██████▌                       | 7/32 [06:01<20:29, 49.18s/it][A[A
















Processing 18:  34%|█████████▋                  | 11/32 [06:05<10:13, 29.22s/it][A[A[A[A[A[A[A[A[A[A[A[A[A[A[A[A[A

















Processing 19:  31%|████████▊                   | 10/32 [06:03<13:04, 35.66s/it][A[A[A[A[A[A[A[A[A[A[A[A[A[A[A[A[A[A











Processing 13:  16%|████▌                        | 5/32 [06:05<

[2023-12-07 18:19:35.682970] Motif vocabulary construction finished.
The motif vocabulary is in /content/drive/MyDrive/AGILE2/AI4Sci-MiCaM/preprocess/QM9/num_ops_1000/vocab.txt.




# Make Training Data

In [20]:
import multiprocessing as mp
import os
import os.path as path
import pickle
from datetime import datetime
from functools import partial
from typing import List, Tuple

import torch
from tqdm import tqdm

from arguments import parse_arguments
from model.mol_graph import MolGraph
from model.mydataclass import Paths