In [1]:
import torch
import numpy as np

# Toy inputs for the scoring experiment in this cell:
#   S     - 2x5 matrix of 0/1 entries; p_x_g_k reads it as S[k][i].
#   y     - length-5 vector of 0/1 labels.
#   ypred - length-5 vector of 0/1 labels (identical to y here).
S = np.array([
    [0, 0, 0, 1, 1],
    [1, 1, 1, 0, 0],
])

y = np.array([1, 1, 0, 0, 0])
ypred = np.array([1, 1, 0, 0, 0])

def p_x_g_k(S, i, j, k):
    """Return ``S[k][i] * S[k][j]`` normalised by the sum of row ``k`` of ``S``.

    Args:
        S: Indexable collection of rows (a 2-D array or nested sequence).
        i: Index of the first entry within row ``k``.
        j: Index of the second entry within row ``k``.
        k: Index of the row of ``S`` to use.

    Returns:
        The product of the two selected entries divided by ``np.sum(S[k])``.
    """
    row = S[k]
    pair_product = row[i] * row[j]
    return pair_product / np.sum(row)

def score(S, y, ypred):
    """Accumulate ``p_x_g_k(S, i, j, k) * y[i] * ypred[j]`` over all i, j and k.

    Args:
        S: Collection of rows passed through to ``p_x_g_k``; every row is
            visited (the row count is taken from ``len(S)``).
        y: Sequence of values indexed by ``i``.
        ypred: Sequence of values indexed by ``j``.

    Returns:
        The accumulated sum (``0`` when ``y``, ``ypred`` or ``S`` is empty).
    """
    s = 0
    for i in range(len(y)):
        for j in range(len(ypred)):
            # Iterate over every row of S rather than assuming exactly two
            # rows; for a 2-row S this is identical to the former range(2).
            for k in range(len(S)):
                s += p_x_g_k(S, i, j, k) * y[i] * ypred[j]
    return s
# Evaluate the score on the toy S, y and ypred defined earlier in this cell;
# as the cell's last expression, the value is shown as the cell output.
score(S, y, ypred)

1.3333333333333333

In [99]:
%load_ext autoreload
%autoreload 2
import torch
from molga.scoring import MolMetrics
from molga.utils.data import smiles_to_mols
from rdkit import Chem
from rdkit.Chem.Draw import IPythonConsole #Needed to show molecules
from rdkit.Chem.Draw.MolDrawing import MolDrawing, DrawingOptions #Only needed if modifying defaults
from molga.generator.aae import AAE
from molga.generator.networks import *
from molga.utils.graph_utils import to_cuda
from molga.utils import sanifix
# Override the drawing default for bond line width (this is why DrawingOptions
# is imported above).
DrawingOptions.bondLineWidth=1.8

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [161]:
import os

In [162]:
from molga.configs import gen_model_getter, sup_model_getter

def copy_edit_mols(mol, cp=False):
    """Optionally clean up and repair a molecule.

    Args:
        mol: Molecule to process, or ``None``. It is handed to
            ``Chem.rdmolops.Cleanup`` and ``sanifix.fix_mol``, so it is
            presumably an RDKit ``Mol``.
        cp: When False (the default) ``mol`` is returned untouched. When True
            the molecule is cleaned up and then passed through
            ``sanifix.fix_mol``.

    Returns:
        ``mol`` itself when ``cp`` is False or ``mol`` is ``None``; otherwise
        whatever ``sanifix.fix_mol`` returns.
    """
    # Nothing to repair: either repair was not requested, or there is no
    # molecule at all (e.g. the SMILES could not be parsed upstream).
    if not cp or mol is None:
        return mol

    # The return value of Cleanup is ignored, so it is presumably an in-place
    # operation on ``mol`` -- NOTE(review): confirm against the RDKit docs.
    Chem.rdmolops.Cleanup(mol)
    return sanifix.fix_mol(mol)
    
def get_standard_model(model_path, id_number):
    """Build an AAE from the "aae" configuration and load saved weights into it.

    Args:
        model_path: Path of the checkpoint file read with ``torch.load``.
        id_number: Identifier passed to ``gen_model_getter`` to select the
            configuration; its "encoder", "decoder" and "discriminator"
            entries supply extra keyword arguments for the sub-networks.

    Returns:
        A ``(network, gpu)`` tuple: the AAE with the checkpoint's ``'model'``
        state loaded, and the checkpoint's ``"gpu"`` entry.
    """
    cfg = gen_model_getter("aae", id_number)
    print(cfg)

    # Sub-networks are created in the same order as before (AAE, encoder,
    # decoder, discriminator) and attached afterwards.
    aae = AAE(metrics={}, mse_loss=True)
    enc = Encoder(
        5, 128,
        nedges=3, feat_dim=12, gather="sum",
        **cfg["encoder"]
    )
    dec = Decoder(
        128, 5,
        nedges=3, max_vertex=9, other_feat_dim=0,
        **cfg["decoder"]
    )
    disc = MLPdiscriminator(128, **cfg["discriminator"])

    aae.set_encoder(enc)
    aae.set_decoder(dec)
    aae.set_discriminator(disc)

    # Load on CPU regardless of the device the checkpoint was saved from.
    ckpt = torch.load(model_path, map_location="cpu")
    aae.load_state_dict(ckpt['model'])
    return aae, ckpt["gpu"]

def sample_molecules(model, n_mols=25000, batch_size=512, gpu=False, cp=False):
    """Sample latent vectors from ``model`` and decode them into SMILES strings.

    Args:
        model: Generator exposing ``eval``, ``sample``, ``forward_generator``
            and ``postprocess``.
        n_mols: Total number of latent samples to draw.
        batch_size: Number of samples per batch; a falsy value means a single
            batch of ``n_mols`` samples.
        gpu: When True each latent batch is moved with ``to_cuda``.
        cp: Forwarded to ``copy_edit_mols`` to enable molecule repair.

    Returns:
        List of the SMILES produced for the molecules flagged valid. It can be
        shorter than ``n_mols`` because invalid molecules are dropped.
    """
    mols = []
    if not batch_size:
        batch_size = n_mols
    model.eval()
    if n_mols <= 0:
        return mols

    # Cover all n_mols samples: full batches plus one smaller trailing batch.
    # Plain floor division silently dropped the remainder and produced no
    # batches at all when n_mols < batch_size.
    n_full, remainder = divmod(n_mols, batch_size)
    batch_sizes = [batch_size] * n_full
    if remainder:
        batch_sizes.append(remainder)

    with torch.no_grad():
        for size in batch_sizes:
            # `model` is the function argument; there is no `self` here.
            z = model.sample(size)
            if gpu:
                z = to_cuda(z)
            gen_data = model.forward_generator(z)
            # Postprocess with Gumbel softmax
            fake_data = model.postprocess(gen_data, hard=True, sample=True)
            # NOTE(review): data2mol and convert_mol_to_smiles are not defined
            # in the visible cells -- presumably provided by molga; confirm the
            # import.
            fake_mols, valids = data2mol(fake_data)
            # Assumes `valids` holds one truthy/falsy flag per entry of
            # `fake_mols` -- TODO confirm against data2mol.
            kept = [
                copy_edit_mols(fake_mol, cp)
                for fake_mol, is_valid in zip(fake_mols, valids)
                if is_valid
            ]
            mols.extend(convert_mol_to_smiles(kept))
    return mols

In [4]:
#model, gpu =get_standard_model("GEN_RES/qm9_mlp_mse_1/AAE/model.epoch:17-loss:0.15.pth.tar", 1)
def weighted_choice(choices):
    """Pick one item at random with probability proportional to its weight.

    Args:
        choices: Iterable of ``(item, weight)`` pairs. It may be a one-shot
            iterator. Entries whose weight is not positive are never selected.

    Returns:
        The selected item.

    Raises:
        ValueError: If no entry has a positive weight (including empty input).
    """
    # Local import: `random` is not imported in the visible notebook cells.
    import random

    # Materialise once so generators work (the input is traversed twice) and
    # drop entries that can never be selected.
    weighted = [(c, w) for c, w in choices if w > 0]
    if not weighted:
        raise ValueError("weighted_choice needs at least one choice with a positive weight")

    total = sum(w for _, w in weighted)
    r = random.uniform(0, total)
    upto = 0
    for c, w in weighted:
        upto += w
        if r < upto:
            return c
    # r landed exactly on `total` (or rounding left `upto` marginally short):
    # the draw belongs to the last positively weighted item.
    return weighted[-1][0]

In [6]:
# Load the reference set: one whitespace-stripped entry per line of the file,
# then convert the whole list with smiles_to_mols.
with open("../data/qm9/train.csv") as REF:
    ref_smiles = list(map(str.strip, REF))
ref_mols = smiles_to_mols(ref_smiles)

In [185]:
# Result directories to process; each is expected to contain a "mols.txt"
# file (see the loop in the next cell).
seq = [
    "GEN_RES/qm9_mlp_mse_2/",
    "GEN_RES/qm9_mlp_diffpool/",
    "GEN_RES/qm9_mlp_gcn/",
]

In [188]:
# For every result directory: parse (and where possible repair) the generated
# SMILES, resample 10000 of the valid ones, save them, and print the metrics.
for s in seq:
    # Read everything up front so the file is closed before the slow RDKit work.
    with open(os.path.join(s, "mols.txt"), "r") as IN:
        mols_1 = [x.strip() for x in IN]

    new_sm = []
    mols = []
    for sm in mols_1:
        mol = Chem.MolFromSmiles(sm, sanitize=True)
        if not mol:
            # Sanitised parsing failed: parse without sanitisation and try to
            # repair the molecule.
            mol = Chem.MolFromSmiles(sm, sanitize=False)
            # The unsanitised parse can itself fail and give None; do not hand
            # that to the repair helper.
            repaired_mol = copy_edit_mols(mol, True) if mol is not None else None
            if repaired_mol:
                sm = Chem.MolToSmiles(repaired_mol)
            # Re-parse (the repaired SMILES if there is one) with sanitisation;
            # this stays None/falsy when the molecule could not be fixed.
            mol = Chem.MolFromSmiles(sm, sanitize=True)
        new_sm.append(sm)
        mols.append(mol)

    val = MolMetrics.validity(mols)
    if val.sum() == 0:
        # Without any valid molecule the sampling weights below would be NaN.
        print(s, "no valid molecules, skipping")
        continue

    # Sample with replacement (np.random.choice default), weighting each entry
    # by its validity so zero-validity entries are never drawn.
    prob = val / val.sum()
    ind = np.random.choice(len(new_sm), 10000, p=prob)
    new_sm = np.asarray(new_sm)
    mols = np.asarray(mols)
    new_sm = new_sm[ind]
    mols = mols[ind]

    with open(os.path.join(s, "fixed_mols2.txt"), "w") as OUT:
        # writelines() adds no separators, so terminate every SMILES explicitly
        # to get one molecule per line.
        OUT.writelines(smi + "\n" for smi in new_sm)

    f_mols = [x for x in mols if x is not None]
    novelty = MolMetrics.novelty(f_mols, ref_smiles)
    uniq = MolMetrics.uniqueness(f_mols)
    val = MolMetrics.validity(mols)
    print(s, novelty.mean(), uniq, val.mean())

KeyboardInterrupt: 

In [179]:
# Recompute validity for whatever `mols` currently holds.
# NOTE(review): `mols` is not defined in this cell; it relies on the value left
# in the kernel by the processing loop, so the result depends on execution
# order.
val = MolMetrics.validity(mols)


In [180]:
# Scratch calculation: 23% of the number of entries in `val`.
# NOTE(review): the origin of the 0.23 factor is not shown in the notebook --
# document it or replace it with a named constant.
len(val)*0.23

11489.880000000001