Skip to content

Commit

Permalink
Use SciencePlots
Browse files Browse the repository at this point in the history
Signed-off-by: Devansh Agarwal <devansh.kv@gmail.com>
  • Loading branch information
devanshkv committed Nov 1, 2020
1 parent ae09be5 commit bda807b
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 96 deletions.
6 changes: 1 addition & 5 deletions bin/your_h5plotter.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from rich.logging import RichHandler
from rich.progress import Progress

from your.utils.plotter import get_params, plot_h5
from your.utils.plotter import plot_h5

os.environ["HDF5_USE_FILE_LOCKING"] = "FALSE"
matplotlib.use("Agg")
Expand Down Expand Up @@ -113,10 +113,6 @@ def mapper(save, detrend_ft, publication, mad_filter, out_dir, h5_file):
else:
raise ValueError(f"Need either --files or --results_csv argument.")

params = get_params()

plt.rcParams.update(params)

with Pool(processes=values.nproc) as p:
max_ = len(h5_files)
func = partial(
Expand Down
116 changes: 25 additions & 91 deletions your/utils/plotter.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,71 +14,6 @@
from your.utils.math import smad_plotter


def figsize(scale, width_by_height_ratio):
"""
Create figure size either a full page or a half page figure
Args:
scale (float): 0.5 for half page figure, 1 for full page
width_by_height_ratio (float): ratio of width to height for the figure
Returns:
list: list of width and height
"""
fig_width_pt = (
513.17 # 469.755 # Get this from LaTeX using \the\textwidth
)
inches_per_pt = 1.0 / 72.27 # Convert pt to inch
golden_mean = (np.sqrt(5.0) - 1.0) / 2.0 # Aesthetic ratio (you could change this)
fig_width = fig_width_pt * inches_per_pt * scale # width in inches
fig_height = fig_width * golden_mean # height in inches
fig_size = [fig_width, width_by_height_ratio * fig_height]
return fig_size


def get_params(scale=0.5, width_by_height_ratio=1):
"""
Create a dictionary for pretty plotting
Args:
scale (float): 0.5 for half page figure, 1 for full page
width_by_height_ratio (float): ratio of width to height for the figure
Returns:
dict: dictionary of parameters
"""
params = {
"backend": "pdf",
"axes.labelsize": 10,
"lines.markersize": 4,
"font.size": 10,
"xtick.major.size": 6,
"xtick.minor.size": 3,
"ytick.major.size": 6,
"ytick.minor.size": 3,
"xtick.major.width": 0.5,
"ytick.major.width": 0.5,
"xtick.minor.width": 0.5,
"ytick.minor.width": 0.5,
"lines.markeredgewidth": 1,
"axes.linewidth": 1.2,
"legend.fontsize": 7,
"xtick.labelsize": 10,
"ytick.labelsize": 10,
"savefig.dpi": 200,
"path.simplify": True,
"font.family": "serif",
"font.serif": "Times",
"text.latex.preamble": [
r"\usepackage{amsmath}",
r"\usepackage{amsbsy}",
r"\DeclareMathAlphabet{\mathcal}{OMS}{cmsy}{m}{n}",
],
"figure.figsize": figsize(scale, width_by_height_ratio),
}
return params


def plot_h5(
h5_file,
save=True,
Expand All @@ -102,7 +37,9 @@ def plot_h5(
None
"""
with h5py.File(h5_file, "r") as f:
with h5py.File(h5_file, "r") as f, plt.style.context(
["science", "no-latex", "ieee"]
):
dm_time = np.array(f["data_dm_time"])
if detrend_ft:
freq_time = detrend(np.array(f["data_freq_time"])[:, ::-1].T)
Expand Down Expand Up @@ -187,11 +124,11 @@ def plot_h5(
filename = outdir + os.path.basename(h5_file)[:-3] + ".png"
else:
filename = h5_file[:-3] + ".png"
plt.savefig(filename, bbox_inches="tight")
plt.savefig(filename, bbox_inches="tight", dpi=300)
else:
plt.close()

return None
return None


def save_bandpass(
Expand All @@ -212,10 +149,6 @@ def save_bandpass(
freqs = your_object.chan_freqs
foff = your_object.your_header.foff

params = get_params()

plt.rcParams.update(params)

if not outdir:
outdir = "./"

Expand All @@ -227,22 +160,23 @@ def save_bandpass(
else:
bp_plot = outname

fig = plt.figure()
ax11 = fig.add_subplot(111)
if foff < 0:
ax11.invert_xaxis()

ax11.plot(freqs, bandpass, "k-", label="Bandpass")
if mask is not None:
if mask.sum():
logging.info("Flagged %d channels", mask.sum())
ax11.plot(freqs[mask], bandpass[mask], "ro", label="Flagged Channels")
ax11.set_xlabel("Frequency (MHz)")
ax11.set_ylabel("Arb. Units")
ax11.legend()

ax21 = ax11.twiny()
ax21.plot(chan_nos, bandpass, alpha=0)
ax21.set_xlabel("Channel Numbers")

return plt.savefig(bp_plot, bbox_inches="tight")
with plt.style.context(["science", "no-latex"]):
fig = plt.figure()
ax11 = fig.add_subplot(111)
if foff < 0:
ax11.invert_xaxis()

ax11.plot(freqs, bandpass, "k-", label="Bandpass")
if mask is not None:
if mask.sum():
logging.info("Flagged %d channels", mask.sum())
ax11.plot(freqs[mask], bandpass[mask], "r.", label="Flagged Channels")
ax11.set_xlabel("Frequency (MHz)")
ax11.set_ylabel("Arb. Units")
ax11.legend()

ax21 = ax11.twiny()
ax21.plot(chan_nos, bandpass, alpha=0)
ax21.set_xlabel("Channel Numbers")

return plt.savefig(bp_plot, bbox_inches="tight", dpi=300)

0 comments on commit bda807b

Please sign in to comment.