Skip to content

Commit

Permalink
added script
Browse files Browse the repository at this point in the history
  • Loading branch information
Arús-Pous, Josep committed Jul 1, 2019
1 parent eb26691 commit f0951de
Show file tree
Hide file tree
Showing 3 changed files with 98 additions and 4 deletions.
53 changes: 53 additions & 0 deletions create_randomized_smiles.py
@@ -0,0 +1,53 @@
#!/usr/bin/env python

import argparse
import os
import functools

import utils.log as ul
import utils.chem as uc
import utils.spark as us


def parse_args():
    """Parses input arguments.

    Arguments are read from ``sys.argv`` by :mod:`argparse`.

    :return: A namespace with the attributes ``input_smi_path``,
             ``output_smi_folder_path``, ``random_type``, ``num_files`` and
             ``num_partitions``.
    """
    parser = argparse.ArgumentParser(
        description="Creates many datasets.")
    parser.add_argument("--input-smi-path", "-i", help="Path to a SMILES file to convert.", type=str, required=True)
    parser.add_argument("--output-smi-folder-path", "-o",
                        help="Path to a folder that will have the converted SMILES files.", type=str, required=True)
    # Restrict the value to the documented randomization types so that a typo
    # is rejected while parsing instead of surfacing later inside the job.
    parser.add_argument("--random-type", "-r",
                        help="Type of the converted SMILES TYPES=(restricted,unrestricted) [DEFAULT: restricted].",
                        type=str, choices=("restricted", "unrestricted"), default="restricted")
    parser.add_argument(
        "--num-files", "-n", help="Number of SMILES files to create (numbered from 000 ...) [DEFAULT: 1]",
        type=int, default=1)
    parser.add_argument("--num-partitions", "-p", help="Number of SPARK partitions to use [DEFAULT: 1000]",
                        type=int, default=1000)

    return parser.parse_args()


def main():
    """Main function.

    Reads the input SMILES file through the module-level Spark context ``SC``,
    converts every line with ``uc.to_mol`` and writes ``--num-files`` files
    (``000.smi``, ``001.smi``, ...) into the output folder. Each file holds one
    freshly randomized SMILES per molecule.
    """
    args = parse_args()

    # The molecules are reused once per output file, so keep them cached.
    mols_rdd = SC.textFile(args.input_smi_path) \
        .repartition(args.num_partitions) \
        .map(uc.to_mol) \
        .persist()

    try:
        os.makedirs(args.output_smi_folder_path, exist_ok=True)

        smiles_func = functools.partial(uc.randomize_smiles, random_type=args.random_type)
        for i in range(args.num_files):
            out_path = os.path.join(args.output_smi_folder_path, "{:03d}.smi".format(i))
            with open(out_path, "w") as out_file:
                for smi in mols_rdd.map(smiles_func).collect():
                    # randomize_smiles returns None for invalid molecules;
                    # skip them so the literal text "None" is never written.
                    if smi is None:
                        continue
                    out_file.write("{}\n".format(smi))
    finally:
        # Release the cached RDD even when writing fails part-way.
        mols_rdd.unpersist()


# Module-level logger for this script.
LOG = ul.get_logger("create_randomized_smiles")

if __name__ == "__main__":
    # main() reads SC as a module global, so it must be bound before the call.
    SPARK, SC = us.SparkSessionSingleton.get(app_name="create_randomized_smiles")
    main()
8 changes: 4 additions & 4 deletions utils/chem.py
Expand Up @@ -85,19 +85,19 @@ def to_smiles(mol):
return rkc.MolToSmiles(mol, isomericSmiles=False)


def randomize_smiles(mol, random_type="order"):
def randomize_smiles(mol, random_type="restricted"):
"""
Returns a random SMILES given a SMILES of a molecule.
:param mol: A Mol object
:param random_type: The type (branching, order) of randomization performed.
:param random_type: The type (unrestricted, restricted) of randomization performed.
:return : A random SMILES string of the same molecule or None if the molecule is invalid.
"""
if not mol:
return None

if random_type == "branching":
if random_type == "unrestricted":
return rkc.MolToSmiles(mol, canonical=False, doRandom=True, isomericSmiles=False)
if random_type == "order":
if random_type == "restricted":
new_atom_order = list(range(mol.GetNumHeavyAtoms()))
random.shuffle(new_atom_order)
random_mol = rkc.RenumberAtoms(mol, newOrder=new_atom_order)
Expand Down
41 changes: 41 additions & 0 deletions utils/spark.py
@@ -0,0 +1,41 @@
"""
Spark util functions
"""

import pyspark.sql as ps


class SparkSessionSingleton:
    """Caches one (session, context) pair per app name.

    The class only acts as a namespace for its class methods and cannot be
    instantiated.
    """

    # Maps an app name to the (session, context) tuple created by get().
    SESSIONS = {}

    def __init__(self):
        raise NotImplementedError("SparkSessionSingleton is not instantiable.")

    @classmethod
    def get(cls, app_name, params_func=None):
        """
        Retrieves (or creates) a session with a given app name.
        :param app_name: Name used both as the Spark application name and as the cache key.
        :param params_func: Optional callable that receives the session builder before the
                            session is created; its return value is ignored.
        :return : A tuple (session, context) for the given app name.
        """

        if app_name not in cls.SESSIONS:
            builder = ps.SparkSession.builder \
                .appName(app_name)
            if params_func:
                params_func(builder)
            session = builder.getOrCreate()
            context = session.sparkContext
            context.setLogLevel("ERROR")

            cls.SESSIONS[app_name] = (session, context)
        return cls.SESSIONS[app_name]

    @classmethod
    def cleanup(cls):
        """
        Stops all cached sessions and empties the cache.
        """
        for session, _ in cls.SESSIONS.values():
            # PySpark sessions are shut down with stop(); they have no close().
            session.stop()
        # Clear in place so the shared registry is emptied even when this is
        # called through a subclass.
        cls.SESSIONS.clear()

0 comments on commit f0951de

Please sign in to comment.