Skip to content

Commit

Permalink
Parallel plot refactor (#247)
Browse files Browse the repository at this point in the history
* Fix duplicate plotting in CRISPRessoBatch aggregate

* Refactor mulltiprocessing plots in CRISPRessoBatch

* Refactor multiprocessing plots in CRISPRessoCORE

* Refactor multiprocessing plots for CRISPRessoAggregate
  • Loading branch information
Colelyman committed Sep 14, 2022
1 parent 4ed5e24 commit c5f79ae
Show file tree
Hide file tree
Showing 4 changed files with 161 additions and 390 deletions.
85 changes: 31 additions & 54 deletions CRISPResso2/CRISPRessoAggregateCORE.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import glob
from copy import deepcopy
from concurrent.futures import ProcessPoolExecutor, wait
from functools import partial
import sys
import argparse
import numpy as np
Expand All @@ -18,7 +19,7 @@
from CRISPResso2 import CRISPRessoShared
from CRISPResso2 import CRISPRessoPlot
from CRISPResso2 import CRISPRessoReport
from CRISPResso2.CRISPRessoMultiProcessing import get_max_processes
from CRISPResso2.CRISPRessoMultiProcessing import get_max_processes, run_plot


import logging
Expand Down Expand Up @@ -108,6 +109,13 @@ def main():
process_pool = ProcessPoolExecutor(n_processes)
process_results = []

plot = partial(
run_plot,
num_processes=n_processes,
process_pool=process_pool,
process_results=process_results,
)

#glob returns paths including the original prefix
all_files = []
for prefix in args.prefix:
Expand Down Expand Up @@ -491,13 +499,10 @@ def main():
'quantification_window_idxs': include_idxs,
'group_column': 'Folder',
}
if n_processes > 1:
process_results.append(process_pool.submit(
CRISPRessoPlot.plot_nucleotide_quilt,
**nucleotide_quilt_input,
))
else:
CRISPRessoPlot.plot_nucleotide_quilt(**nucleotide_quilt_input)
plot(
CRISPRessoPlot.plot_nucleotide_quilt,
nucleotide_quilt_input,
)

plot_name = os.path.basename(this_window_nuc_pct_quilt_plot_name)
window_nuc_pct_quilt_plot_names.append(plot_name)
Expand Down Expand Up @@ -529,13 +534,10 @@ def main():
'quantification_window_idxs': include_idxs,
'group_column': 'Folder',
}
if n_processes > 1:
process_results.append(process_pool.submit(
CRISPRessoPlot.plot_nucleotide_quilt,
**nucleotide_quilt_input,
))
else:
CRISPRessoPlot.plot_nucleotide_quilt(**nucleotide_quilt_input)
plot(
CRISPRessoPlot.plot_nucleotide_quilt,
nucleotide_quilt_input,
)

plot_name = os.path.basename(this_nuc_pct_quilt_plot_name)
nuc_pct_quilt_plot_names.append(plot_name)
Expand Down Expand Up @@ -571,13 +573,10 @@ def main():
'quantification_window_idxs': consensus_include_idxs,
'group_column': 'Folder',
}
if n_processes > 1:
process_results.append(process_pool.submit(
CRISPRessoPlot.plot_nucleotide_quilt,
**nucleotide_quilt_input,
))
else:
CRISPRessoPlot.plot_nucleotide_quilt(**nucleotide_quilt_input)
plot(
CRISPRessoPlot.plot_nucleotide_quilt,
nucleotide_quilt_input,
)

plot_name = os.path.basename(this_nuc_pct_quilt_plot_name)
nuc_pct_quilt_plot_names.append(plot_name)
Expand Down Expand Up @@ -633,15 +632,10 @@ def main():
'plot_path': plot_path,
'title': modification_type,
}
if n_processes > 1:
process_results.append(process_pool.submit(
CRISPRessoPlot.plot_allele_modification_heatmap,
**allele_modification_heatmap_input,
))
else:
CRISPRessoPlot.plot_allele_modification_heatmap(
**allele_modification_heatmap_input,
)
plot(
CRISPRessoPlot.plot_allele_modification_heatmap,
allele_modification_heatmap_input,
)

crispresso2_info['results']['general_plots']['allele_modification_heatmap_plot_names'].append(plot_name)
crispresso2_info['results']['general_plots']['allele_modification_heatmap_plot_paths'][plot_name] = plot_path
Expand All @@ -668,15 +662,10 @@ def main():
'plot_path': plot_path,
'title': modification_type,
}
if n_processes > 1:
process_results.append(process_pool.submit(
CRISPRessoPlot.plot_allele_modification_line,
**allele_modification_line_input,
))
else:
CRISPRessoPlot.plot_allele_modification_line(
**allele_modification_line_input
)
plot(
CRISPRessoPlot.plot_allele_modification_line,
allele_modification_line_input,
)
crispresso2_info['results']['general_plots']['allele_modification_line_plot_names'].append(plot_name)
crispresso2_info['results']['general_plots']['allele_modification_line_plot_paths'][plot_name] = plot_path
crispresso2_info['results']['general_plots']['allele_modification_line_plot_titles'][plot_name] = 'CRISPRessoAggregate {0} Across Samples for {1}'.format(
Expand Down Expand Up @@ -778,13 +767,7 @@ def main():
'save_png': save_png,
'cutoff': args.min_reads_for_inclusion,
}
if n_processes > 1:
process_results.append(process_pool.submit(
CRISPRessoPlot.plot_reads_total,
**reads_total_input,
))
else:
CRISPRessoPlot.plot_reads_total(**reads_total_input)
plot(CRISPRessoPlot.plot_reads_total, reads_total_input)

plot_name = os.path.basename(plot_root)
crispresso2_info['results']['general_plots']['summary_plot_root'] = plot_name
Expand All @@ -801,13 +784,7 @@ def main():
'save_png': save_png,
'cutoff': args.min_reads_for_inclusion,
}
if n_processes > 1:
process_results.append(process_pool.submit(
CRISPRessoPlot.plot_unmod_mod_pcts,
**unmod_mod_pcts_input,
))
else:
CRISPRessoPlot.plot_unmod_mod_pcts(**unmod_mod_pcts_input)
plot(CRISPRessoPlot.plot_unmod_mod_pcts, unmod_mod_pcts_input)

plot_name = os.path.basename(plot_root)
crispresso2_info['results']['general_plots']['summary_plot_root'] = plot_name
Expand Down
109 changes: 39 additions & 70 deletions CRISPResso2/CRISPRessoBatchCORE.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import os
from copy import deepcopy
from concurrent.futures import ProcessPoolExecutor, wait
from functools import partial
import sys
import traceback
from datetime import datetime
Expand Down Expand Up @@ -337,6 +338,13 @@ def main():
process_results = []
process_pool = ProcessPoolExecutor(n_processes_for_batch)

plot = partial(
CRISPRessoMultiProcessing.run_plot,
num_processes=n_processes_for_batch,
process_results=process_results,
process_pool=process_pool,
)

window_nuc_pct_quilt_plot_names = []
nuc_pct_quilt_plot_names = []
window_nuc_conv_plot_names = []
Expand Down Expand Up @@ -559,15 +567,10 @@ def main():
'sgRNA_intervals': sub_sgRNA_intervals,
'quantification_window_idxs': include_idxs,
}
if n_processes_for_batch > 1:
process_results.append(process_pool.submit(
CRISPRessoPlot.plot_nucleotide_quilt,
**nucleotide_quilt_input,
))
else:
CRISPRessoPlot.plot_nucleotide_quilt(
**nucleotide_quilt_input,
)
plot(
CRISPRessoPlot.plot_nucleotide_quilt,
nucleotide_quilt_input,
)
plot_name = os.path.basename(this_window_nuc_pct_quilt_plot_name)
window_nuc_pct_quilt_plot_names.append(plot_name)
crispresso2_info['results']['general_plots']['summary_plot_titles'][plot_name] = 'sgRNA: ' + sgRNA + ' Amplicon: ' + amplicon_name
Expand All @@ -587,15 +590,10 @@ def main():
'sgRNA_intervals': sub_sgRNA_intervals,
'quantification_window_idxs': include_idxs,
}
if n_processes_for_batch > 1:
process_results.append(process_pool.submit(
CRISPRessoPlot.plot_conversion_map,
**conversion_map_input,
))
else:
CRISPRessoPlot.plot_conversion_map(
**conversion_map_input,
)
plot(
CRISPRessoPlot.plot_conversion_map,
conversion_map_input,
)
plot_name = os.path.basename(this_window_nuc_conv_plot_name)
window_nuc_conv_plot_names.append(plot_name)
crispresso2_info['results']['general_plots']['summary_plot_titles'][plot_name] = 'sgRNA: ' + sgRNA + ' Amplicon: ' + amplicon_name
Expand All @@ -617,15 +615,10 @@ def main():
'sgRNA_intervals': consensus_sgRNA_intervals,
'quantification_window_idxs': include_idxs,
}
if n_processes_for_batch > 1:
process_results.append(process_pool.submit(
CRISPRessoPlot.plot_nucleotide_quilt,
**nucleotide_plot_input,
))
else:
CRISPRessoPlot.plot_nucleotide_quilt(
**nucleotide_plot_input,
)
plot(
CRISPRessoPlot.plot_nucleotide_quilt,
nucleotide_quilt_input,
)
plot_name = os.path.basename(this_nuc_pct_quilt_plot_name)
nuc_pct_quilt_plot_names.append(plot_name)
crispresso2_info['results']['general_plots']['summary_plot_titles'][plot_name] = 'Amplicon: ' + amplicon_name
Expand All @@ -644,15 +637,10 @@ def main():
'sgRNA_intervals': consensus_sgRNA_intervals,
'quantification_window_idxs': include_idxs,
}
if n_processes_for_batch > 1:
process_results.append(process_pool.submit(
CRISPRessoPlot.plot_conversion_map,
**conversion_map_input,
))
else:
CRISPRessoPlot.plot_conversion_map(
**conversion_map_input,
)
plot(
CRISPRessoPlot.plot_conversion_map,
conversion_map_input,
)
plot_name = os.path.basename(this_nuc_conv_plot_name)
nuc_conv_plot_names.append(plot_name)
crispresso2_info['results']['general_plots']['summary_plot_titles'][plot_name] = 'Amplicon: ' + amplicon_name
Expand All @@ -671,15 +659,10 @@ def main():
'fig_filename_root': this_nuc_pct_quilt_plot_name,
'save_also_png': save_png,
}
if n_processes_for_batch > 1:
process_results.append(process_pool.submit(
CRISPRessoPlot.plot_nucleotide_quilt,
**nucleotide_quilt_input,
))
else:
CRISPRessoPlot.plot_nucleotide_quilt(
**nucleotide_quilt_input,
)
plot(
CRISPRessoPlot.plot_nucleotide_quilt,
nucleotide_quilt_input,
)
plot_name = os.path.basename(this_nuc_pct_quilt_plot_name)
nuc_pct_quilt_plot_names.append(plot_name)
crispresso2_info['results']['general_plots']['summary_plot_labels'][plot_name] = 'Composition of each base for the amplicon ' + amplicon_name
Expand All @@ -693,15 +676,10 @@ def main():
'conversion_nuc_to': args.conversion_nuc_to,
'save_also_png': save_png,
}
if n_processes_for_batch > 1:
process_results.append(process_pool.submit(
CRISPRessoPlot.plot_conversion_map,
**conversion_map_input,
))
else:
CRISPRessoPlot.plot_conversion_map(
**conversion_map_input,
)
plot(
CRISPRessoPlot.plot_conversion_map,
conversion_map_input,
)
plot_name = os.path.basename(this_nuc_conv_plot_name)
nuc_conv_plot_names.append(plot_name)
crispresso2_info['results']['general_plots']['summary_plot_labels'][plot_name] = args.conversion_nuc_from + '->' + args.conversion_nuc_to +' conversion rates for the amplicon ' + amplicon_name
Expand Down Expand Up @@ -756,15 +734,10 @@ def main():
'plot_path': plot_path,
'title': modification_type,
}
if n_processes_for_batch > 1:
process_results.append(process_pool.submit(
CRISPRessoPlot.plot_allele_modification_heatmap,
**allele_modification_heatmap_input,
))
else:
CRISPRessoPlot.plot_allele_modification_heatmap(
**allele_modification_heatmap_input,
)
plot(
CRISPRessoPlot.plot_allele_modification_heatmap,
allele_modification_heatmap_input,
)

crispresso2_info['results']['general_plots']['allele_modification_heatmap_plot_names'].append(plot_name)
crispresso2_info['results']['general_plots']['allele_modification_heatmap_plot_paths'][plot_name] = plot_path
Expand All @@ -791,13 +764,9 @@ def main():
'plot_path': plot_path,
'title': modification_type,
}
if n_processes_for_batch > 1:
process_results.append(process_pool.submit(
CRISPRessoPlot.plot_allele_modification_line,
**allele_modification_line_input,
))
CRISPRessoPlot.plot_allele_modification_line(
**allele_modification_line_input,
plot(
CRISPRessoPlot.plot_allele_modification_line,
allele_modification_line_input,
)

crispresso2_info['results']['general_plots']['allele_modification_line_plot_names'].append(plot_name)
Expand Down

0 comments on commit c5f79ae

Please sign in to comment.