In [None]:
import os
import sys
# NOTE(review): both paths below are hardcoded, user-specific absolute paths, so the
# notebook only runs on a machine with this exact layout.
# Switch the working directory; the relative paths used further down
# (e.g. 'notebooks/api_key.txt', 'data/api_mol_hg/') are resolved against it.
os.chdir('/home/msun415/induction/')
# Add a second local checkout to the module search path.
sys.path.append('/home/msun415/my_data_efficient_grammar')
import argparse 
from rdkit import Chem
from rdkit.Chem import rdchem
from multiprocessing import Pool
from tqdm import tqdm
from itertools import permutations
from functools import reduce
from rdkit.Chem import Draw
import matplotlib.pyplot as plt
from matplotlib.patches import FancyArrowPatch
import networkx as nx
from networkx.algorithms.isomorphism import GraphMatcher
from collections import defaultdict, Counter
import sys
import numpy as np
import pandas as pd
from PIL import Image, ImageDraw
import json
from rdkit.Chem.rdmolops import FastFindRings
from itertools import accumulate, product
from copy import deepcopy
from private.molecule_graph import MolGraph
import networkx.algorithms.chordal as chordal
import os
import networkx as nx
import pandas as pd
import pickle
from src.draw.graph import draw_graph
import networkx as nx
import random
from collections.abc import Iterable
import re

def flatten(nested_iterable):
    """Recursively flatten arbitrarily nested iterables into a single flat list.

    Parameters:
        nested_iterable: any object; iterables are descended into, everything
            else is treated as a leaf
    Output:
        A new list holding the leaves in traversal order. A non-iterable input
        yields a one-element list. Strings and bytes are treated as leaves.
    """
    # A string is an iterable of one-character strings, each of which is again
    # iterable, so descending into it would recurse until the stack overflows.
    if isinstance(nested_iterable, (str, bytes)) or not isinstance(nested_iterable, Iterable):
        return [nested_iterable]
    flat = []
    for item in nested_iterable:
        # extend() keeps this linear instead of re-copying the list on every `+`
        flat.extend(flatten(item))
    return flat

# SEED = 0
# random.seed(SEED)
# np.random.seed(SEED)
# import pygsp as gsp
# from pygsp import graphs

from src.api.get_motifs import prepare_images
import openai
# Read the API key from the first line of a local file (kept out of the notebook itself).
# The context manager closes the file handle, which the one-liner open(...).readline() did not.
with open('notebooks/api_key.txt') as key_file:
    openai.api_key = key_file.readline().rstrip('\n')

In [None]:
# IMAGE_PATHS = [
#     "/home/msun415/SynTreeNet/induction/CCOC(C(N=C=O)CCCCN=C=O)=O.png",
#     "/home/msun415/SynTreeNet/induction/O=C=NC1CCC(CC2CCC(CC2)N=C=O)CC1.png",
#     "/home/msun415/SynTreeNet/induction/CC1=C(C=C(C=C1)CN=C=O)N=C=O.png",
#     "/home/msun415/SynTreeNet/induction/CC1(CC(CC(CN=C=O)(C1)C)N=C=O)C.png",
#     "/home/msun415/SynTreeNet/induction/O=C=NCCCCCCCCCCCCCCCCCCCCCCCCN=C=O.png"
#     ]


In [None]:
from rdkit.Chem.Draw import IPythonConsole
import rdkit.Chem as Chem
from rdkit.Chem import AllChem
from rdkit.Chem import Draw
from rdkit.Chem.Draw import rdMolDraw2D
from rdkit.Geometry.rdGeometry import Point2D
import io
IPythonConsole.ipython_useSVG=True  #< set this to False if you want PNGs instead of SVGs
from fuseprop import find_clusters, extract_subgraph, get_mol, get_smiles, find_fragments, find_fragments_with_scaffold, __extract_subgraph
from private import *

def GetBondPosition(mol, bond, return_atom_pos=False):
    """Return the midpoint of a bond in the coordinates of the molecule's conformer.

    Parameters:
        mol: molecule providing `GetBondWithIdx` and `GetConformer`
        bond: a bond object, or an int bond index that is resolved through `mol`
        return_atom_pos: if True, also return the coordinates of both end atoms
    Output:
        (mid_x, mid_y), or (mid_x, mid_y, x1, y1, x2, y2) when return_atom_pos is set
    """
    bond_obj = mol.GetBondWithIdx(bond) if isinstance(bond, int) else bond
    conformer = mol.GetConformer()
    begin = conformer.GetAtomPosition(bond_obj.GetBeginAtomIdx())
    end = conformer.GetAtomPosition(bond_obj.GetEndAtomIdx())
    midpoint = ((begin.x + end.x) / 2, (begin.y + end.y) / 2)
    if not return_atom_pos:
        return midpoint
    return midpoint + (begin.x, begin.y, end.x, end.y)


def draw_smiles(smiles, ax=None, order=[]):
    """Draw a molecule with every atom and bond labelled by its index.

    Parameters:
        smiles: SMILES string of the molecule to draw
        ax: if given, the rendered image is shown on this matplotlib axis
        order: optional sequence of bond indices; a bond found in it is labelled
            "(k) bond_<idx>" where k is its 1-based position in the sequence.
            The default list is only read, never mutated.
    Output:
        An IPython Image when `ax` is not given, otherwise None.
    Raises:
        ValueError: if the SMILES string cannot be parsed.
    """
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        # Fail with a clear message instead of an AttributeError on None further down.
        raise ValueError(f"could not parse SMILES: {smiles!r}")
    for j, a in enumerate(mol.GetAtoms()):
        a.SetProp('atomLabel', f"{a.GetSymbol()}{j}")
    AllChem.Compute2DCoords(mol)
    drawer = rdMolDraw2D.MolDraw2DCairo(600, 600)
    options = drawer.drawOptions()
    options.maxFontSize = 12
    options.atomLabelFontSize = 12
    drawer.DrawMolecule(mol)
    drawer.SetFontSize(15)
    # Position of each bond in `order`, built once; setdefault keeps the first
    # occurrence, which is what list.index() would have returned.
    rank = {}
    for position, bond_idx in enumerate(order):
        rank.setdefault(bond_idx, position)
    for bond in mol.GetBonds():
        mid_x, mid_y = GetBondPosition(mol, bond)
        bond_idx = bond.GetIdx()
        if bond_idx in rank:
            bond_label = f'({rank[bond_idx]+1}) bond_{bond_idx}'
        else:
            bond_label = f'bond_{bond_idx}'
        # The label is anchored at the bond midpoint.
        drawer.DrawString(bond_label, Point2D(mid_x, mid_y))
    drawer.FinishDrawing()
    # drawer.WriteDrawingText(os.path.join(dir_name, f'{i}.png'))
    img_data = drawer.GetDrawingText()
    if ax:
        image = io.BytesIO(img_data)
        img = plt.imread(image)
        ax.imshow(img)
    else:
        from IPython.display import Image
        return Image(data=img_data)
    


def draw_mol(mol, ax=None, bonds=None, return_drawer=False):
    """Render a molecule, optionally highlighting some of its bonds in blue.

    Parameters:
        mol: the molecule to draw
        ax: if given, the rendered image is shown on this matplotlib axis
        bonds: optional iterable of bond indices to highlight
        return_drawer: if True, return the drawer before the drawing is
            finished so the caller can add further annotations
    Output:
        The unfinished drawer when return_drawer is set; otherwise an IPython
        Image when `ax` is not given, or None when it is.
    """
    drawer = rdMolDraw2D.MolDraw2DCairo(600, 600)
    if bonds is None:
        drawer.DrawMolecule(mol)
    else:
        highlight_bond_map = {b: [(0, 0, 1)] for b in bonds}
        drawer.DrawMoleculeWithHighlights(mol, '',
                                          highlight_atom_map={},
                                          highlight_bond_map=highlight_bond_map,
                                          highlight_radii={},
                                          highlight_linewidth_multipliers={})
    if return_drawer:
        return drawer
    drawer.FinishDrawing()
    img_data = drawer.GetDrawingText()
    if ax:
        image = io.BytesIO(img_data)
        img = plt.imread(image)
        ax.imshow(img)
    else:
        # Imported here, as draw_smiles does: the file's top-level imports bind
        # `Image` to the PIL module, which is not callable with `data=`.
        from IPython.display import Image
        return Image(data=img_data)
    


def draw_cliques(cg, mol, ax=None, cq=None, label=True):
    """
    This function draws the cliques in cq, highlighting them in mol.
    Parameters:
        cg: clique graph, where nodes are bonds of mol, edges are atoms;
            only consulted when cq is not given
        mol: the mol object to draw (its atoms get an 'atomLabel' property and
            its 2D coordinates are recomputed as a side effect)
        ax: if given, draw on ax
        cq: if given, draws predefined cliques, can be one of following:
            tuple (id, nodes, *color)
            list of tuples
            if omitted, the maximal cliques of cg are enumerated and drawn
        label: whether to annotate the id as text
    Output:
        Image drawn (None when ax is given)
    """
    for j, a in enumerate(mol.GetAtoms()):
        a.SetProp('atomLabel', f"{a.GetSymbol()}{j}")
    global_color = (1, 0, 0)
    AllChem.Compute2DCoords(mol)
    drawer = rdMolDraw2D.MolDraw2DCairo(600, 600)
    options = drawer.drawOptions()
    options.noAtomLabels = True
    drawer.SetFillPolys(True)
    drawer.SetColour(global_color)
    options.maxFontSize = 20
    if cq:
        if isinstance(cq, tuple):
            cqs = [cq]
        else:
            assert isinstance(cq, list)
            cqs = cq
    else:
        cqs = list(enumerate(nx.find_cliques(cg)))
    for cq_arg in cqs:
        # Local names are used here so the `cq` and `label` parameters are not
        # rebound in the middle of the loop.
        if len(cq_arg) == 2:
            e, members = cq_arg
            color = global_color
        else:
            e, members, color = cq_arg
        x, y = 0, 0
        highlight_bond_map = {}
        for b in members:
            bx, by = GetBondPosition(mol, b)
            x += bx
            y += by
            highlight_bond_map[b] = [color]
        drawer.DrawMoleculeWithHighlights(mol, '',
                                          highlight_atom_map={},
                                          highlight_bond_map=highlight_bond_map,
                                          highlight_radii={},
                                          highlight_linewidth_multipliers={})
        if label and len(members): # TODO: Support len(cqs) > 1
            # The id is written at the mean of the member bonds' midpoints;
            # an empty clique has no such point and is left unlabelled.
            text = f"e{e}"
            drawer.DrawString(text, Point2D(x / len(members), y / len(members)))
    drawer.FinishDrawing()
    # drawer.WriteDrawingText(os.path.join(dir_name, f'{i}.png'))
    img_data = drawer.GetDrawingText()
    if ax:
        image = io.BytesIO(img_data)
        img = plt.imread(image)
        ax.imshow(img)
    else:
        # Imported here, as draw_smiles does: the file's top-level imports bind
        # `Image` to the PIL module, which is not callable with `data=`.
        from IPython.display import Image
        return Image(data=img_data)

# subgraphs = []
# subgraphs_idx_i = []
# clusters, atom_cls = find_clusters(mol)
# for i,cls in enumerate(clusters):
#     clusters[i] = set(list(cls))
# for i, cluster in enumerate(clusters):
#     _, subgraph_i_mapped, _ = extract_subgraph(smiles, cluster)
#     subgraphs.append(SubGraph(subgraph_i_mapped, mapping_to_input_mol=subgraph_i_mapped, subfrags=list(cluster)))
#     subgraphs_idx_i.append(list(cluster))
    


In [None]:
def mol_to_graph(mol):
    """Build the bond-adjacency graph of a molecule.

    Nodes are the bond indices of `mol`; two bonds are joined by an edge when
    they are both incident to the same atom.
    """
    bond_graph = nx.Graph()
    bond_graph.add_nodes_from(bond.GetIdx() for bond in mol.GetBonds())
    for atom in mol.GetAtoms():
        incident = [bond.GetIdx() for bond in atom.GetBonds()]
        # Connect every ordered pair; the self-pairs are stripped again below.
        for first in incident:
            for second in incident:
                bond_graph.add_edge(first, second)
    bond_graph.remove_edges_from(nx.selfloop_edges(bond_graph))
    return bond_graph

class MolHG:
    """Pairs a molecule with the result of `llm_chordalize` on it.

    `llm_chordalize` is still a stub, so `chordal_graph` is currently always None.
    """

    def __init__(self, mol):
        # A SMILES string is accepted in place of a molecule object.
        self.mol = Chem.MolFromSmiles(mol) if isinstance(mol, str) else mol
        self.chordal_graph = self.llm_chordalize(self.mol)

    @staticmethod
    def llm_chordalize(mol):
        """Placeholder that is not implemented yet; returns None."""
        return None


In [None]:
def isRingAromatic(mol, bondRing):
    """Return True when every bond index in `bondRing` refers to an aromatic bond of `mol`.

    An empty ring yields True.
    """
    return all(mol.GetBondWithIdx(bond_idx).GetIsAromatic() for bond_idx in bondRing)


def GetBondsAmongAtoms(mol, ring):
    """Return the indices of all bonds of `mol` whose two ends both lie in `ring`.

    Parameters:
        mol: molecule providing `GetBondBetweenAtoms(i, j)`
        ring: iterable of atom indices
    Output:
        List of bond indices, each listed once, in the order the atom pairs are
        first met when scanning `ring`. (Scanning every ordered pair would visit
        each bond from both ends and list it twice.)
    """
    atoms = list(ring)
    bond_ids = []
    seen = set()
    for pos, i in enumerate(atoms):
        # Only look at atoms after `i`, so each unordered pair is queried once.
        for j in atoms[pos + 1:]:
            bond = mol.GetBondBetweenAtoms(i, j)
            if not bond:
                continue
            idx = bond.GetIdx()
            if idx not in seen:  # guards against repeated atoms in `ring`
                seen.add(idx)
                bond_ids.append(idx)
    return bond_ids


def GetRingsWithBond(mol, b_idx):
    """Return every bond ring of `mol` (as reported by its ring info) that contains bond `b_idx`.

    The rings are returned in the order `BondRings()` yields them; the result is
    an empty list when the bond is in no ring.
    """
    return [bond_ring for bond_ring in mol.GetRingInfo().BondRings() if b_idx in bond_ring]

# def get_clique_graph(mol):
#     g = mol_to_graph(mol)
#     for ring in mol.GetRingInfo().BondRings():
#         if isRingAromatic(mol, ring):
#             g.add_edges_from(product(ring, ring))
#     g.remove_edges_from(nx.selfloop_edges(g))
#     return g

def get_clique_graph(mol):
    """Bond-adjacency graph of `mol` in which the bonds of each ring are fully connected.

    Starts from `mol_to_graph(mol)` and, for every ring returned by
    `Chem.GetSymmSSSR`, joins all pairs of bonds found among that ring's atoms.
    """
    clique_graph = mol_to_graph(mol)
    for atom_ring in Chem.GetSymmSSSR(mol): # TODO
        ring_bonds = GetBondsAmongAtoms(mol, list(atom_ring))
        # Connect every ordered pair; the self-pairs are stripped again below.
        for first in ring_bonds:
            for second in ring_bonds:
                clique_graph.add_edge(first, second)
    clique_graph.remove_edges_from(nx.selfloop_edges(clique_graph))
    return clique_graph

def my_complete_to_chordal(cg, mol):
    """Add edges to `cg` in place until no chordality breaker is reported.

    Each reported breaker (u, _, w) is resolved by fully connecting the single
    ring containing bond u with the single ring containing bond w.

    Parameters:
        cg: bond graph to complete (modified in place and returned)
        mol: molecule whose ring info is used to look up the rings of u and w
    Output:
        The same graph object `cg`.
    Raises:
        NotImplementedError: if u or w belongs to more than one ring.
        ValueError: if u or w belongs to no ring.
    """
    while True:
        try:
            # NOTE(review): this relies on a private networkx helper and unpacks its
            # result into two names. Stock networkx releases appear to return a bare
            # (u, v, w) triple or an empty tuple here, in which case this unpacking
            # raises and the loop exits on the first pass - confirm against the
            # installed networkx.
            res, order = chordal._find_chordality_breaker(cg)
        except Exception:
            # Any failure is taken to mean "nothing left to fix". Narrowed from a
            # bare `except:` so Ctrl-C / SystemExit are no longer swallowed.
            break
        u, _, w = res
        u_rings = GetRingsWithBond(mol, u)
        w_rings = GetRingsWithBond(mol, w)
        if len(u_rings) > 1 or len(w_rings) > 1:
            # Was a leftover breakpoint(); resuming from it would have retried the
            # same breaker forever because nothing is changed on this branch.
            raise NotImplementedError(
                f"bond {u} or {w} lies in more than one ring; this case is not handled"
            )
        if not u_rings or not w_rings:
            raise ValueError(f"bond {u} or {w} is not part of any ring")
        u_r = u_rings[0]
        w_r = w_rings[0]
        print(u_r, w_r)
        cg.add_edges_from(product(u_r, w_r))
        cg.remove_edges_from(nx.selfloop_edges(cg))
    return cg


In [None]:
# smiles = 'C1=CC=C2[NH]C(C3=N[NH]N=N3)=CC2=C1'
# smiles = 'COC(=O)c1cc2csc(c3cc4c(s3)c(OC)c3ccsc3c4OC)c2s1'
# Input molecule for the rest of the notebook; the two SMILES above are
# alternative inputs that are currently disabled.
smiles = 'C12(Cl)C(Cl)=C(Cl)C(Cl)(C1(Cl)Cl)C3C2C4CC3C5OC45'
def chordal_mol_graph(smiles):
    """Parse a SMILES string and build its chordal bond graph.

    Parameters:
        smiles: SMILES string of the molecule
    Output:
        (mol, cg): the parsed molecule and its clique graph after
        `my_complete_to_chordal` has been applied.
    Raises:
        ValueError: if the SMILES string cannot be parsed.
        AssertionError: if the completed graph is still not chordal.
    """
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        # Fail with a clear message instead of an AttributeError on None further down.
        raise ValueError(f"could not parse SMILES: {smiles!r}")
    cg = get_clique_graph(mol)
    # res, order = chordal._find_chordality_breaker(cg)
    # cg, _, chords = chordal.complete_to_chordal_graph(cg)
    cg = my_complete_to_chordal(cg, mol)
    assert chordal.is_chordal(cg), "bond graph is still not chordal after completion"
    return mol, cg
    # chordal.chordal_graph_cliques(cg)    



In [None]:
def clique_drawing(cg, mol, path):
    """Draw every maximal clique of `cg` in its own panel and save the figure.

    Parameters:
        cg: bond graph whose maximal cliques are enumerated
        mol: molecule on which each clique is highlighted
        path: file the figure is written to
    Output:
        The list of cliques; panel titles are the indices into this list.
    """
    # draw_cliques(cg, mol)
    cliques = list(nx.find_cliques(cg))
    n = len(cliques)
    # Roughly square grid. The column count is rounded UP so that there is a panel
    # for every clique; `n // rows` can leave too few panels (n=5 -> 2x2), and zip()
    # would then silently drop the trailing cliques. max(1, ...) keeps n == 0 valid.
    rows = max(1, int(np.sqrt(n)))
    cols = max(1, -(-n // rows))
    fig, axes = plt.subplots(rows, cols, figsize=(20, 20))
    axes = flatten(axes)
    for ax in axes:
        # also blanks the spare panels of a grid that is larger than n
        ax.axis('off')
    for i, (cq, ax) in enumerate(zip(cliques, axes)):
        ax.set_title(f"{i}")
        draw_cliques(cg, mol, ax=ax, cq=(i, cq), label=False)
    fig.set_facecolor('white')
    fig.savefig(path, bbox_inches='tight', dpi=100)
    print(os.path.abspath(path))
    return cliques

In [None]:
# Directory under which llm_edit_cliques and llm_break_cycles create a new
# numbered sub-directory for the figures of each call.
fig_dir = 'data/api_mol_hg/'
# Prompt text files; further down they are passed to llm_edit_cliques,
# llm_choose_root and llm_break_cycles respectively.
prompt_1_path = 'data/api_mol_hg_1.txt'
prompt_2_path = 'data/api_mol_hg_2.txt'
prompt_3_path = 'data/api_mol_hg_3.txt'

def get_next_version(fig_dir, dir=True):
    """Return the next free integer version number inside `fig_dir`.

    Parameters:
        fig_dir: directory to scan
        dir: if True, versions are sub-directories named by an integer;
            if False, versions are files whose name up to the first '.' is an integer
    Output:
        One more than the largest version found, or 0 when there is none.
        Entries whose name is not an integer (e.g. '.ipynb_checkpoints',
        '.DS_Store') are ignored instead of raising ValueError.
    """
    check = os.path.isdir if dir else os.path.isfile
    versions = []
    for name in os.listdir(fig_dir):
        if not check(os.path.join(fig_dir, name)):
            continue
        stem = name if dir else name.split('.')[0]
        try:
            versions.append(int(stem))
        except ValueError:
            # not a versioned entry
            continue
    return max(versions) + 1 if versions else 0


def llm_call(img_paths, prompt_path, optional_prompt=None):
    """
    This function sends the prompt read from prompt_path together with a list of
    images to the chat model and returns the text of the reply.
    Parameters:
        img_paths: list of paths to img files
        prompt_path: a .txt file path
        optional_prompt: optional function of a single arg; when given, it is
            called with the first reply and its return value is sent as a second,
            text-only request whose reply is returned instead
    Output:
        Response of call (text of the last reply)
    """
    base64_images = prepare_images(img_paths)
    # Context manager so the prompt file is closed after reading.
    with open(prompt_path) as prompt_file:
        prompt = prompt_file.read()
    completion = openai.ChatCompletion.create(model="gpt-4o", 
        messages=[{"role": "user", 
                "content": [
                    {"type": "text",
                    "text": prompt}]+[
                        {"type": "image_url",
                        "image_url": {"url": f"data:image/jpeg;base64,{base64_image}"}
                        } for base64_image in base64_images
                    ]}]
    )
    res = completion.choices[0].message.content    
    if optional_prompt:
        completion = openai.ChatCompletion.create(model="gpt-4o", 
            messages=[{"role": "user", 
                    "content": [
                        {"type": "text",
                        "text": optional_prompt(res)}]}]
        )
        res = completion.choices[0].message.content
    return res    
                                                     

def llm_choose_edit(img_path, prompt_path):
    """Ask the model about one image and reduce its reply to an "x,y" pair or NONE.

    The first reply is passed through a second, text-only request that extracts
    the pair; the text of that second reply is returned unchanged.
    """
    def post_prompt(res):
        # Builds the follow-up prompt around the first reply.
        return f"I want you to perform a simple data post-processing step of the following response:\n{res}\n The input is a response from another language agent. It may or may not contain an answer in the form of a single pair. If it does, output the pair in x,y format and NOTHING ELSE. Don't include explanations or superlatives. Just output the answer. If it doesn't contain an answer in the form of a pair, output the single word NONE."

    return llm_call([img_path], prompt_path, post_prompt)


def llm_edit_cliques(cg, mol, prompt_path):
    """Repeatedly let the model pick two cliques of `cg` to merge.

    Each round draws the current cliques to a new numbered PNG, asks the model
    for a pair of clique indices, and connects all bonds of the two chosen
    cliques in `cg` (in place). The loop stops when the reply is not a usable
    pair: no "x,y" found, an index out of range, or both indices equal.

    Parameters:
        cg: bond graph to edit (modified in place)
        mol: molecule used for drawing
        prompt_path: prompt file for the clique-selection request
    Output:
        (cg, path): the edited graph and the path of the last figure drawn.
    """
    # A fresh checkout may not have the figure directory yet.
    os.makedirs(fig_dir, exist_ok=True)
    d = get_next_version(fig_dir)
    dir_name = os.path.join(fig_dir, f'{d}')
    os.makedirs(dir_name)
    while True:
        i = get_next_version(dir_name, dir=False)
        path = os.path.join(dir_name, f'{i}.png')
        cliques = clique_drawing(cg, mol, path)
        pair = llm_choose_edit(path, prompt_path)
        # Raw string (the old non-raw f-string relied on invalid escapes), and
        # whitespace around the numbers is tolerated so that "3, 5" is accepted.
        match = re.match(r"\s*(\d+)\s*,\s*(\d+)", pair)
        if match is None:
            break
        e1, e2 = (int(grp) for grp in match.groups())
        if max(e1, e2) >= len(cliques):
            break
        if e1 == e2:
            # Merging a clique with itself adds no edge; without this stop the same
            # figure would be redrawn and the model queried again indefinitely.
            break
        merged = cliques[e1] + cliques[e2]
        cg.add_edges_from(product(merged, merged))
        cg.remove_edges_from(nx.selfloop_edges(cg))
    return cg, path


def llm_choose_root(img_path, prompt_path):
    """Ask the model to choose a root index from one image.

    Parameters:
        img_path: path of the image shown to the model
        prompt_path: prompt file for the root-selection request
    Output:
        The integer extracted from the reply, or 0 when the (post-processed)
        reply is not a bare non-negative integer.
    """
    # The closing sentence used to say "in the form of a pair" (copied from the
    # pair-extraction prompt), contradicting the integer requested just before.
    post_prompt = lambda res: f"I want you to perform a simple data post-processing step of the following response:\n{res}\n The input is a response from another language agent. It may or may not contain an answer in the form of a single integer. If it does, output the integer and NOTHING ELSE. Don't include explanations or superlatives. Just output the answer. If it doesn't contain an answer in the form of a single integer, output the single word NONE."
    # Surrounding whitespace in the reply must not turn a valid answer into the fallback.
    root = llm_call([img_path], prompt_path, post_prompt).strip()
    if re.fullmatch(r"\d+", root):
        return int(root)
    return 0


def init_tree(cg):
    """Build the intersection graph of the maximal cliques of `cg`.

    Every maximal clique becomes a node (numbered in enumeration order, with the
    clique stored under the 'nodes' attribute). Two cliques are joined when they
    share at least one member; the edge 'weight' is the size of the overlap.
    Despite the name, the result is not necessarily a tree.
    """
    tree = nx.Graph()
    for clique in nx.find_cliques(cg):
        tree.add_node(len(tree), nodes=clique)
    members = {node: set(data['nodes']) for node, data in tree.nodes(data=True)}
    node_ids = list(tree)
    for pos, first in enumerate(node_ids):
        # Each unordered pair is visited once, so no self-loops are ever created.
        for second in node_ids[pos + 1:]:
            shared = members[first] & members[second]
            if shared:
                tree.add_edge(first, second, weight=len(shared))
    return tree


def llm_break_edge(img_path, prompt_path):
    """Ask the model which edge to break and reduce its reply to an integer or NONE.

    The first reply is passed through a second, text-only request that extracts
    the integer; the text of that second reply is returned unchanged.
    """
    # The closing sentence used to say "in the form of a pair" (copied from the
    # pair-extraction prompt), contradicting the integer requested just before.
    post_prompt = lambda res: f"I want you to perform a simple data post-processing step of the following response:\n{res}\n The input is a response from another language agent. It may or may not contain an answer in the form of a single integer. If it does, output the integer and NOTHING ELSE. Don't include explanations or superlatives. Just output the answer. If it doesn't contain an answer in the form of a single integer, output the single word NONE."
    return llm_call([img_path], prompt_path, post_prompt)


def draw_cycle(cyc, tree, mol, path):
    """Draw every edge of a cycle in its own panel and save the figure.

    Parameters:
        cyc: list of (clique_id, clique_id) edges of `tree`
        tree: graph whose nodes carry their clique under the 'nodes' attribute
        mol: molecule on which the two cliques of each edge are highlighted,
            the first in red and the second in green
        path: file the figure is written to
    Output:
        `path`; panel titles are the indices into `cyc`.
    """
    n = len(cyc)
    # Roughly square grid. The column count is rounded UP so that there is a panel
    # for every edge; `n // rows` can leave too few panels (n=5 -> 2x2), and zip()
    # would then silently drop the trailing edges. max(1, ...) keeps n == 0 valid.
    rows = max(1, int(np.sqrt(n)))
    cols = max(1, -(-n // rows))
    fig, axes = plt.subplots(rows, cols, figsize=(20, 20))
    axes = flatten(axes)
    for ax in axes:
        # also blanks the spare panels of a grid that is larger than n
        ax.axis('off')
    for i, ((cq_1_id, cq_2_id), ax) in enumerate(zip(cyc, axes)):
        ax.set_title(f"{i}")
        cq_1 = tree.nodes[cq_1_id]['nodes']
        cq_2 = tree.nodes[cq_2_id]['nodes']
        draw_cliques(None, mol, ax, [(None, cq_1, (1, 0, 0)), (None, cq_2, (0, 1, 0))], label=False)
    fig.set_facecolor('white')
    fig.savefig(path, bbox_inches='tight', dpi=100)
    print(os.path.abspath(path))
    return path


def llm_break_cycles(tree, mol, root, prompt_path):
    """Remove edges chosen by the model until `tree` is a tree.

    Each round finds a cycle starting the search from `root`, draws its edges to
    a new numbered PNG, asks the model for the index of the edge to remove and
    removes it from `tree` (in place). An unusable reply (not an integer, or out
    of range) leaves the graph unchanged and the round is repeated.

    Parameters:
        tree: clique graph to edit (modified in place)
        mol: molecule used for drawing
        root: node from which the cycle search starts
        prompt_path: prompt file for the edge-selection request
    Output:
        The same graph object `tree`.
    """
    # A fresh checkout may not have the figure directory yet.
    os.makedirs(fig_dir, exist_ok=True)
    d = get_next_version(fig_dir)
    dir_name = os.path.join(fig_dir, f'{d}')
    os.makedirs(dir_name)
    while not nx.is_tree(tree):
        i = get_next_version(dir_name, dir=False)
        path = os.path.join(dir_name, f'{i}.png')
        try:
            cyc = nx.find_cycle(tree, root)
        except nx.NetworkXNoCycle:
            # find_cycle signals "no cycle" by raising, not by returning an empty
            # list; this is the exit the old `else: break` was meant to provide
            # (e.g. the graph is not a tree only because it is disconnected).
            break
        path = draw_cycle(cyc, tree, mol, path)
        e = llm_break_edge(path, prompt_path).strip()
        if not re.fullmatch(r"\d+", e):
            continue
        e = int(e)
        if e >= len(cyc):
            continue
        e1, e2 = cyc[e]
        tree.remove_edge(e1, e2)
    return tree


# End-to-end run on `smiles`: build the chordal bond graph, let the model merge
# cliques, build the clique-intersection graph, then let the model pick a root
# and break cycles. Every llm_* step issues requests to the OpenAI API.
mol, cg = chordal_mol_graph(smiles)
cg, path = llm_edit_cliques(cg, mol, prompt_1_path)
tree = init_tree(cg)
# Reports whether the clique-intersection graph is already a tree before any edge is removed.
print(nx.is_tree(tree))
# `path` is the last clique figure written by llm_edit_cliques.
root = llm_choose_root(path, prompt_2_path)
tree = llm_break_cycles(tree, mol, root, prompt_3_path)
# tree = nx.maximum_spanning_tree(tree)
# while not nx.is_tree(tree):
nx.is_tree(tree)

In [None]:
from src.draw.utils import hierarchy_pos

# Compute node positions for `tree` relative to `root` and draw the tree with node labels.
pos = hierarchy_pos(tree, root)
nx.draw(tree, pos, with_labels=True)

In [None]:
# Same draw call as in the preceding cell; depends on `pos` computed there.
nx.draw(tree, pos, with_labels=True)

In [None]:
import HRG.create_production_rules as cpr
import importlib
from collections import defaultdict
importlib.reload(cpr)

# cpr.learn_production_rules(cg, tree, root)
def convert_to_node_set_tree(tree):
    """Re-key the clique tree by clique contents instead of node ids.

    Parameters:
        tree: graph whose nodes carry their clique under the 'nodes' attribute
    Output:
        A defaultdict(set) mapping each clique (as a frozenset) that has at least
        one edge to the set of cliques (frozensets) adjacent to it.
    Raises:
        AssertionError: if two nodes of `tree` hold the same clique.
    """
    canonical = {tuple(sorted(tuple(tree.nodes[n]['nodes']))) for n in tree}
    assert len(tree) == len(canonical) # make sure no repeated nodes
    # T = convert_to_directed(tree, root)    
    T = defaultdict(set)
    for t_1, t_2 in tree.edges:
        side_1, side_2 = (frozenset(tree.nodes[t]['nodes']) for t in (t_1, t_2))
        # adjacency is recorded in both directions
        T[side_1].add(side_2)
        T[side_2].add(side_1)
    return T

# Re-key the clique tree by clique contents, express the chosen root the same way,
# and hand both to the production-rule learner together with the bond graph.
T = convert_to_node_set_tree(tree)
set_root = frozenset(tree.nodes[root]['nodes'])
rules = cpr.learn_production_rules(cg, T, set_root)

In [None]:
# Show the input molecule with atom and bond indices as a reference for the rules below.
draw_smiles(smiles)

In [None]:
def GetAtomsFromBonds(mol, bonds):
    """Return the distinct atom indices touched by the given bond indices of `mol`.

    The result is a list built from a set, so its order is not meaningful.
    """
    endpoints = []
    for bond_idx in bonds:
        bond = mol.GetBondWithIdx(bond_idx)
        endpoints.extend((bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()))
    return list(set(endpoints))


class Grammar:
    """Accessors and visualisations for the production rules learned for one molecule.

    `rules` is indexed as rules[nonterm][rule_string]; each value is unpacked as a
    3-tuple whose last two entries are the RHS nodes and an index map (see
    RHSNodes). Rule strings are scanned for hyperedges written as
    "(i,j,...:N)" or "(i,j,...:T)"; ":N" marks a nonterminal hyperedge
    (":T" presumably a terminal one - confirm against the rule learner).
    `mol` is the molecule whose bond indices the rules refer to.
    """

    def __init__(self, mol, rules):
        self.rules = rules
        self.mol = mol

    
    @staticmethod
    def __extract_subgraph(mol, selected_atoms):
        """Return a copy of `mol` reduced to `selected_atoms`.

        Before atoms are removed, every atom and bond of the copy records its
        original index in the int property 'org_idx'. Selected atoms that have a
        neighbour outside the selection get atom map number 1, and lose their
        aromatic flag when none of their aromatic bonds stays inside the selection.
        """
        selected_atoms = set(selected_atoms)
        # atoms on the boundary of the selection
        roots = []
        for idx in selected_atoms:
            atom = mol.GetAtomWithIdx(idx)
            bad_neis = [y for y in atom.GetNeighbors() if y.GetIdx() not in selected_atoms]
            if len(bad_neis) > 0:
                roots.append(idx)

        new_mol = Chem.RWMol(mol)
        for atom in new_mol.GetAtoms():
            atom.SetIntProp('org_idx', atom.GetIdx())
        for bond in new_mol.GetBonds():
            bond.SetIntProp('org_idx', bond.GetIdx())
        for atom_idx in roots:
            atom = new_mol.GetAtomWithIdx(atom_idx)
            atom.SetAtomMapNum(1)
            aroma_bonds = [bond for bond in atom.GetBonds() if bond.GetBondType() == Chem.rdchem.BondType.AROMATIC]
            aroma_bonds = [bond for bond in aroma_bonds if bond.GetBeginAtom().GetIdx() in selected_atoms and bond.GetEndAtom().GetIdx() in selected_atoms]
            if len(aroma_bonds) == 0:
                atom.SetIsAromatic(False)

        remove_atoms = [atom.GetIdx() for atom in new_mol.GetAtoms() if atom.GetIdx() not in selected_atoms]
        # highest index first, so the indices still to be removed stay valid
        remove_atoms = sorted(remove_atoms, reverse=True)
        for atom in remove_atoms:
            new_mol.RemoveAtom(atom)
        return new_mol.GetMol()        


    def NumNTs(self):
        """Number of nonterminals (keys of the rule table)."""
        return len(self.rules)

    
    def GetNTs(self):
        """List of nonterminals (keys of the rule table)."""
        return list(self.rules)


    def NumRulesForNT(self, nonterm):
        """Number of rules stored for `nonterm`."""
        assert nonterm in self.rules
        return len(self.rules[nonterm])

    def GetRule(self, nonterm, idx):
        """The idx-th rule string of `nonterm`, in the iteration order of its rule table."""
        assert idx < self.NumRulesForNT(nonterm)
        return list(self.rules[nonterm])[idx]        

    def EdgesForRule(self, nonterm, idx, nt_only=False):
        """Extract the hyperedge tokens from a rule string.

        Output:
            (matches, is_nts): the "(…:N)" / "(…:T)" substrings in order of
            appearance (only the ":N" ones when nt_only is set) and, for each,
            whether it contains ":N".
        """
        rule = self.GetRule(nonterm, idx)
        nt_or_t = 'N' if nt_only else 'N|T'
        # raw f-string: the backslashes belong to the regex, not to Python escapes
        matches = re.findall(rf'(\((?:\w+,)*\w+:(?:{nt_or_t})\))', rule)
        is_nts = [':N' in match for match in matches]
        return matches, is_nts


    def GetEdgeForRule(self, nonterm, idx, i, nt_only=False):
        """The i-th hyperedge token of a rule and whether it is a nonterminal."""
        matches, is_nts = self.EdgesForRule(nonterm, idx, nt_only=nt_only)
        return matches[i], is_nts[i]

    def RHSNodes(self, nonterm, idx):
        """Return (nodes, d) stored for a rule: its RHS nodes and its index map.

        `d` is reported by the debug output of RHSEdgeMol as mapping original
        indices to RHS indices.
        """
        rule = self.GetRule(nonterm, idx)
        _, nodes, d = self.rules[nonterm][rule]
        return nodes, d


    def RHSMol(self, nonterm, idx):
        """Sub-molecule spanned by the atoms of the rule's RHS nodes (bond indices of self.mol)."""
        nodes, _ = self.RHSNodes(nonterm, idx)
        return Grammar.__extract_subgraph(self.mol, GetAtomsFromBonds(self.mol, nodes))
    

    def RHSEdgeMol(self, nonterm, idx, i, nt_only=False):
        """Resolve the i-th hyperedge of a rule back to bonds of the molecule.

        Output:
            (bonds, rhs_edge_bonds, rhs_edge_mol, is_nt) where `bonds` are the
            hyperedge's bond indices in self.mol, `rhs_edge_mol` is the
            sub-molecule spanned by those bonds, `rhs_edge_bonds` are the same
            bonds as indices of `rhs_edge_mol`, and `is_nt` tells whether the
            hyperedge is a nonterminal.
        """
        nodes, d = self.RHSNodes(nonterm, idx)
        inv_d = dict(zip(d.values(), d.keys()))
        print(d, "map org idx to rhs idx")
        print(inv_d, "map rhs idx to org idx")
        print(nodes, "nodes of rhs")
        match, is_nt = self.GetEdgeForRule(nonterm, idx, i, nt_only=nt_only)
        if ':N' in match:
            print(match, "match")
        nt_or_t = 'N' if nt_only else 'N|T'
        grps = re.match(rf'\(((?:\w+,)*\w+):({nt_or_t})\)', match)
        nodes_idx_str, _ = grps.groups()
        # numeric entries become ints, anything else (e.g. letters) stays a string
        nodes_idx = list(map(lambda ind: int(ind) if ind.isdigit() else ind, nodes_idx_str.split(',')))
        print(nodes_idx, "rhs edge nodes idx")
        bonds = [inv_d[ind] for ind in nodes_idx] # these refer to bonds of hyperedge
        # print(bonds, "bonds")
        # print([nodes[ind] for ind in nodes_idx], "old bonds")        
        # Uses the instance's molecule; this used to read a module-level `mol`,
        # which only worked when that global happened to be the same molecule.
        rhs_edge_mol = Grammar.__extract_subgraph(self.mol, GetAtomsFromBonds(self.mol, bonds))        
        rhs_edge_bond_lookup = Grammar.bond_lookup(rhs_edge_mol)
        print(rhs_edge_bond_lookup, "rhs edge bond lookup")
        rhs_edge_bonds = [rhs_edge_bond_lookup[b] for b in bonds] # TODO: map bonds of hyperedge to bonds in rhs_mol
        # return rhs_edge_bonds, bonds, rhs_edge_mol
        nodes_rhs_edge_idx = [inv_d[b] for b in nodes_idx] # org idx in mol
        return nodes_rhs_edge_idx, rhs_edge_bonds, rhs_edge_mol, is_nt


    def VisRule(self, nonterm, idx, nt_only=False):
        """Figure with the rule's RHS sub-molecule followed by one panel per hyperedge."""
        edges, _ = self.EdgesForRule(nonterm, idx, nt_only)
        fig, axes = plt.subplots(1, len(edges)+1, figsize=(10, 10))
        axes = flatten(axes)
        nodes, _ = self.RHSNodes(nonterm, idx)
        rhs_mol = Grammar.__extract_subgraph(self.mol, GetAtomsFromBonds(self.mol, nodes))
        draw_mol(rhs_mol, axes[0])
        for i in range(len(edges)):
            _, bonds, rhs_edge_mol, is_nt = self.RHSEdgeMol(nonterm, idx, i, nt_only)
            draw_mol(rhs_edge_mol, axes[i+1], bonds=bonds)
        return fig
    

    @staticmethod
    def bond_lookup(rhs_mol):
        """Map each bond's stored 'org_idx' to its index in `rhs_mol`."""
        return {bd.GetIntProp('org_idx'): bd.GetIdx() for bd in rhs_mol.GetBonds()}
    
    @staticmethod
    def atom_lookup(rhs_mol):
        """Map each atom's stored 'org_idx' to its index in `rhs_mol`."""
        return {at.GetIntProp('org_idx'): at.GetIdx() for at in rhs_mol.GetAtoms()}
    

    def VisRuleAlt(self, nonterm, idx, nt_only=False, ax=None):
        """Draw a rule on a single image of its RHS sub-molecule.

        When `nonterm` has the form "(a,b,...)" the bonds mapped to those letters
        are highlighted. Every nonterminal hyperedge is drawn as a small rectangle
        at the mean of its bond midpoints, with an arrow in a random colour from
        each of its bonds to it; other hyperedges are not drawn.

        Output:
            An IPython Image when `ax` is None; otherwise the image is shown on
            `ax` and None is returned.
        """
        edges, _ = self.EdgesForRule(nonterm, idx, nt_only=nt_only)
        nodes, d = self.RHSNodes(nonterm, idx)
        inv_d = dict(zip(d.values(), d.keys()))
        rhs_mol = Grammar.__extract_subgraph(self.mol, GetAtomsFromBonds(self.mol, nodes))
        nts = re.match(r'\(((?:[a-z],)*[a-z])\)', nonterm)
        bond_lookup = Grammar.bond_lookup(rhs_mol)
        if nts is None:
            anchors = None
        else:
            nts = nts.groups()[0].split(',')
            anchors = [inv_d[nt] for nt in nts]            
            anchors = [bond_lookup[b] for b in anchors]         
        drawer = draw_mol(rhs_mol, bonds=anchors, return_drawer=True)             
        # half side length of the rectangle marking a nonterminal hyperedge
        dim = 0.05
        for i in range(len(edges)):
            random_color = (random.random(),random.random(),random.random())
            bonds, _, _, is_nt = self.RHSEdgeMol(nonterm, idx, i, nt_only=nt_only)        
            bonds = [bond_lookup[b] for b in bonds]
            print(bonds, "bonds", is_nt, "is_nt")
            bonds_pos = [GetBondPosition(rhs_mol, bond) for bond in bonds]
            bonds_pos_mean = np.array(bonds_pos).mean(axis=0)            
            drawer.SetFillPolys(not is_nt)            
            if not is_nt:
                continue
            drawer.DrawRect(Point2D(*(bonds_pos_mean-dim)), Point2D(*(bonds_pos_mean+dim)))
            for bond in bonds:
                x, y = GetBondPosition(rhs_mol, bond)
                drawer.DrawArrow(Point2D(x, y), Point2D(*bonds_pos_mean), 
                                 asPolygon=True, color=random_color, frac=0.3)
        drawer.FinishDrawing()
        img_data = drawer.GetDrawingText()          
        if ax is None:
            from IPython.display import Image
            return Image(data=img_data)
        else:
            image = io.BytesIO(img_data)
            img = plt.imread(image)
            ax.imshow(img)             
    

    def VisAllRules(self):
        """Figure with one row per nonterminal and one panel per rule of that nonterminal.

        The row label is the nonterminal; each panel title is the concatenation of
        the rule's hyperedge tokens, cut to its last 20 characters when longer.
        """
        max_title_length = 20
        nonterms = self.GetNTs()
        counts = [self.NumRulesForNT(nonterm) for nonterm in nonterms]
        # squeeze=False keeps `axes` two-dimensional even with a single row or a
        # single column; otherwise axes[i][j] fails for such grids.
        fig, axes = plt.subplots(len(counts), max(counts), figsize=(50, 50), squeeze=False)
        for i in range(len(counts)):            
            for j in range(counts[i]):
                if j == 0:
                    axes[i][j].set_ylabel(nonterms[i], fontsize=36)
                nt_edges = self.EdgesForRule(nonterms[i], j, False)[0]
                nt_edges = ''.join(nt_edges)
                print(nt_edges)
                print(len(nt_edges))
                if len(nt_edges) > max_title_length:
                    title = "..." + nt_edges[-max_title_length:]          
                else:
                    title = nt_edges
                axes[i][j].set_title(title, fontsize=36)
                self.VisRuleAlt(nonterms[i], j, ax=axes[i][j])
        return fig



# test_rhs = list(rules['(S)'])[0]
# print(test_rhs)
# matches = re.findall('(\((?:\d+,)*\d+:(?:N)\))', test_rhs)

# i = 0
# print(matches[i])
# grps = re.match('\(((?:\d+,)*\d+):(N)\)', matches[i])
# nodes_idx_str, symbol = grps.groups()
# nodes_idx = list(map(int, nodes_idx_str.split(',')))
# bonds = [nodes[ind] for ind in nodes_idx]
# __extract_subgraph(mol, GetAtomsFromBonds(mol, bonds))[0]

In [None]:
# Show the indexed input molecule again, next to the rule visualisations below.
draw_smiles(smiles)

In [None]:
# Wrap the learned rules, render all of them into one figure and save it.
# NOTE(review): the output path is a hardcoded, user-specific absolute path.
g = Grammar(mol, rules)
# g.VisRuleAlt('(a)',2)
# rhs_mol = g.RHSMol('(a)',2)
# print(Grammar.atom_lookup(rhs_mol))
# rhs_mol.GetBondBetweenAtoms(1,2).GetIdx()
fig = g.VisAllRules()
fig.set_facecolor('white')
fig.savefig('/home/msun415/rules.png')

In [None]:
# Look up the index of the bond joining atoms 6 and 8 of `mol`.
mol.GetBondBetweenAtoms(6,8).GetIdx()

In [None]:
# Nonterminal hyperedge tokens (and their is-nonterminal flags) of the first rule of '(a,b)'.
g.EdgesForRule('(a,b)',0,nt_only=True)

In [None]:
# Draw the first rule of '(a)', restricted to nonterminal hyperedges (nt_only=True).
g.VisRuleAlt('(a)', 0, True)

In [None]:
# Draw the first rule of '(a,b,c,d,e)', scanning both kinds of hyperedge (nt_only=False).
g.VisRuleAlt('(a,b,c,d,e)', 0, False)

In [None]:
# List the nonterminals, i.e. the keys of the rule table.
g.GetNTs()

In [None]:
# Draw the first rule of '(a,b)', scanning both kinds of hyperedge (nt_only=False).
g.VisRuleAlt('(a,b)', 0, False)

In [None]:
# Resolve the hyperedge at position 3 among the nonterminal hyperedges of the first '(S)' rule;
# the call also prints its intermediate index maps.
g.RHSEdgeMol('(S)', 0, 3, nt_only=True)

In [None]:


# Scratch version of the hyperedge drawing: a rectangle at the mean position of the
# hyperedge's atoms, with an arrow from each of its bonds to that point.
# NOTE(review): `rhs_mol` and `bonds` are not assigned by any top-level statement
# visible in this notebook (the only `rhs_mol = ...` line is commented out), so this
# cell presumably relied on leftover kernel state - confirm before re-running.
drawer = rdMolDraw2D.MolDraw2DCairo(600, 600)
AllChem.Compute2DCoords(rhs_mol)

# translate original bond indices into indices of rhs_mol
bond_lookup = {bd.GetIntProp('org_idx'): bd.GetIdx() for bd in rhs_mol.GetBonds()}
new_bonds = [bond_lookup[b] for b in bonds]
atoms = GetAtomsFromBonds(rhs_mol, new_bonds)
conf = rhs_mol.GetConformer()
atoms_pos = [conf.GetAtomPosition(atom) for atom in atoms]
atoms_pos = [[pos.x, pos.y] for pos in atoms_pos]
atoms_pos_mean = np.array(atoms_pos).mean(axis=0)
drawer.DrawMolecule(rhs_mol)
drawer.DrawRect(Point2D(*(atoms_pos_mean-0.1)), Point2D(*(atoms_pos_mean+0.1)))
for bond in new_bonds:
    x, y = GetBondPosition(rhs_mol, bond)
    drawer.DrawArrow(Point2D(x, y), Point2D(*atoms_pos_mean))

drawer.FinishDrawing()
img_data = drawer.GetDrawingText()
# From here on the module-level name `Image` refers to IPython's Image class.
from IPython.display import Image
Image(data=img_data)

In [None]:
from networkx.algorithms import chordal_graph_cliques, complete_to_chordal_graph
# Complete a graph to a chordal one with networkx, collect its cliques as sorted
# tuples, and draw the chordal graph on the left of a two-panel figure.
# NOTE(review): in the visible cells `g` was last bound to a Grammar instance, not
# a graph, so this cell presumably relied on an earlier kernel state - confirm.
chordal_graph, _ = complete_to_chordal_graph(g)
# t = nx.junction_tree(g)
# fig, ax = plt.subplots(figsize=(20,20))
# pos = nx.spring_layout(t)
# nx.draw(t, pos, with_labels=True, ax=ax,)
cliques = [tuple(sorted(i)) for i in chordal_graph_cliques(chordal_graph)]

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(40,20))
nx.draw(chordal_graph, with_labels=True, ax=ax1)

In [None]:
# Draw this molecule, with atom and bond indices, on `ax2`, the right-hand axis
# of the two-panel figure created in the preceding cell.
draw_smiles('C1=CC2=C3C=CC(=C4C=CC(=C5C=CC(=C6C=CC(=C1C=C2)C=C6)C=C5)C=C4)C=C3', ax2)

In [None]:
# Display the current `fig` as the cell output.
fig