Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Arús-Pous, Josep
committed
Jul 1, 2019
1 parent
eb26691
commit f0951de
Showing
3 changed files
with
98 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
#!/usr/bin/env python | ||
|
||
import argparse | ||
import os | ||
import functools | ||
|
||
import utils.log as ul | ||
import utils.chem as uc | ||
import utils.spark as us | ||
|
||
|
||
def parse_args(): | ||
"""Parses input arguments.""" | ||
parser = argparse.ArgumentParser( | ||
description="Creates many datasets.") | ||
parser.add_argument("--input-smi-path", "-i", help="Path to a SMILES file to convert.", type=str, required=True) | ||
parser.add_argument("--output-smi-folder-path", "-o", | ||
help="Path to a folder that will have the converted SMILES files.", type=str, required=True) | ||
parser.add_argument("--random-type", "-r", help="Type of the converted SMILES TYPES=(restricted,unrestricted) \ | ||
[DEFAULT: restricted].", type=str, default="restricted") | ||
parser.add_argument( | ||
"--num-files", "-n", help="Number of SMILES files to create (numbered from 000 ...) [DEFAULT: 1]", | ||
type=int, default=1) | ||
parser.add_argument("--num-partitions", "-p", help="Number of SPARK partitions to use [DEFAULT: 1000]", | ||
type=int, default=1000) | ||
|
||
return parser.parse_args() | ||
|
||
|
||
def main(): | ||
"""Main function.""" | ||
args = parse_args() | ||
|
||
mols_rdd = SC.textFile(args.input_smi_path) \ | ||
.repartition(args.num_partitions) \ | ||
.map(uc.to_mol)\ | ||
.persist() | ||
|
||
os.makedirs(args.output_smi_folder_path, exist_ok=True) | ||
|
||
smiles_func = functools.partial(uc.randomize_smiles, random_type=args.random_type) | ||
for i in range(args.num_files): | ||
with open("{}/{:03d}.smi".format(args.output_smi_folder_path, i), "w+") as out_file: | ||
for smi in mols_rdd.map(smiles_func).collect(): | ||
out_file.write("{}\n".format(smi)) | ||
|
||
mols_rdd.unpersist() | ||
|
||
|
||
LOG = ul.get_logger("create_randomized_smiles") | ||
if __name__ == "__main__": | ||
SPARK, SC = us.SparkSessionSingleton.get("create_randomized_smiles") | ||
main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
""" | ||
Spark util functions | ||
""" | ||
|
||
import pyspark.sql as ps | ||
|
||
|
||
class SparkSessionSingleton: | ||
"""Manages unique spark sessions for each app name.""" | ||
|
||
SESSIONS = {} | ||
|
||
def __init__(self): | ||
raise NotImplementedError("SparkSessionSingleton is not instantiable.") | ||
|
||
@classmethod | ||
def get(cls, app_name, params_func=None): | ||
""" | ||
Retrieves (or creates) a session with a given app name. | ||
""" | ||
|
||
if app_name not in cls.SESSIONS: | ||
session = ps.SparkSession.builder \ | ||
.appName(app_name) | ||
if params_func: | ||
params_func(session) | ||
session = session.getOrCreate() | ||
context = session.sparkContext | ||
context.setLogLevel("ERROR") | ||
|
||
cls.SESSIONS[app_name] = (session, context) | ||
return cls.SESSIONS[app_name] | ||
|
||
@classmethod | ||
def cleanup(cls): | ||
""" | ||
Closes all sessions. | ||
""" | ||
for session, _ in cls.SESSIONS.values(): | ||
session.close() | ||
cls.SESSIONS = {} |