Skip to content

Commit

Permalink
Merge pull request #309 from erwanp/improve-import-speed
Browse files Browse the repository at this point in the history
Improve radis import speed
  • Loading branch information
erwanp committed Jul 7, 2021
2 parents 1a3094e + a654d6b commit ea7b606
Show file tree
Hide file tree
Showing 15 changed files with 572 additions and 517 deletions.
2 changes: 2 additions & 0 deletions environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,13 @@ name: radis-env
channels:
- conda-forge
- astropy
- cantera
dependencies:
- python=3.8
- numpy
- scipy>=1.4.0
- matplotlib
- cantera # for chemical equilibrium computations
- cython
- pandas>=1.0.5
- pytables # for pandas to HDF5 export
Expand Down
5 changes: 3 additions & 2 deletions radis/io/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,6 @@

import numpy as np
import pandas as pd
from astropy import units as u
from astroquery.hitran import Hitran

import radis
from radis.db.classes import get_molecule, get_molecule_identifier
Expand Down Expand Up @@ -79,6 +77,9 @@ def fetch_astroquery(
:py:attr:`astroquery.query.BaseQuery.cache_location`
"""
from astropy import units as u
from astroquery.hitran import Hitran

# Check input
if not is_float(molecule):
mol_id = get_molecule_identifier(molecule)
Expand Down
3 changes: 2 additions & 1 deletion radis/lbl/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,6 @@

import sys

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from astropy import units as u
Expand Down Expand Up @@ -261,6 +260,8 @@ def plot_hist(self, dataframe="df0", what="int"):
which feature to plot. Default ``'S'`` (scaled linestrength). Could also
be ``'int'`` (reference linestrength intensity), ``'A'`` (Einstein coefficient)
"""
import matplotlib.pyplot as plt

assert dataframe in ["df0", "df1"]
plt.figure()
df = getattr(self, dataframe)
Expand Down
2 changes: 1 addition & 1 deletion radis/lbl/broadening.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,6 @@
"""
from warnings import warn

import matplotlib.pyplot as plt
import numpy as np
from numba import float64, jit
from numpy import arange, exp
Expand Down Expand Up @@ -1504,6 +1503,7 @@ def plot_broadening(self, i=0, pressure_atm=None, mole_fraction=None, Tgas=None)
"""
# TODO #clean: make it a standalone function.

import matplotlib.pyplot as plt
from publib import fix_style, set_style

if pressure_atm is None:
Expand Down
2 changes: 1 addition & 1 deletion radis/misc/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
@author: erwan
"""

import matplotlib.pyplot as plt
import numpy as np


Expand Down Expand Up @@ -39,6 +38,7 @@ def split_and_plot_by_parts(w, I, *args, **kwargs):
ax.plot
"""

import matplotlib.pyplot as plt
from publib.tools import keep_color

# Get defaults
Expand Down
10 changes: 8 additions & 2 deletions radis/phys/units.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,6 @@

import warnings

import astropy.units as u


def Unit(st, *args, **kwargs):
"""Radis evaluation of an unit, using :py:class:`~astropy.units.Unit`
Expand All @@ -29,6 +27,8 @@ def Unit(st, *args, **kwargs):
a += 0.1 * u("W/cm2/sr/nm")
"""

import astropy.units as u

try:
st = st.replace("µ", "u")
except AttributeError:
Expand Down Expand Up @@ -67,6 +67,8 @@ def conv2(quantity, fromunit, tounit):
want to let the users choose another output unit
"""

import astropy.units as u

try:
a = quantity * Unit(fromunit)
a = a.to(Unit(tounit))
Expand All @@ -88,6 +90,8 @@ def is_homogeneous(unit1, unit2):
units
"""

import astropy.units as u

try:
1 * Unit(unit1) + 1 * Unit(unit2)
return True
Expand Down Expand Up @@ -330,6 +334,8 @@ def convert_universal(
wavenumber is needed in case we convert from ~1/nm to ~1/cm-1 (requires
a change of variable in the integral)
"""
import astropy.units as u

Iunit0 = from_unit
Iunit = to_unit
try:
Expand Down
13 changes: 12 additions & 1 deletion radis/phys/units_astropy.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,10 @@
import astropy.units as u
# -*- coding: utf-8 -*-
"""
-------------------------------------------------------------------------------
"""


def convert_and_strip_units(quantity, output_unit=None, digit=10):
Expand Down Expand Up @@ -37,7 +43,12 @@ def convert_and_strip_units(quantity, output_unit=None, digit=10):
TypeError
Raised when ``quantity`` is a astropy.units quantity and ``output_unit`` is ``None``.
"""

import astropy.units as u

if isinstance(quantity, u.Quantity):
# TODO : make it possible to test if not adimensionned, before loading.
# This would allow not to have to load Astropy.units on start-up?
if output_unit in (u.deg_C, u.imperial.deg_F, u.K):
quantity = quantity.to_value(output_unit, equivalencies=u.temperature())
elif isinstance(output_unit, (u.UnitBase, u.Quantity)):
Expand Down
9 changes: 5 additions & 4 deletions radis/spectrum/compare.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,7 @@

from warnings import warn

import matplotlib.pyplot as plt
import numpy as np
from matplotlib import gridspec
from matplotlib.widgets import MultiCursor
from publib import fix_style, set_style

from radis.misc.arrays import array_allclose
from radis.misc.basics import compare_dict, compare_lists
Expand Down Expand Up @@ -689,6 +685,11 @@ def plot_diff(
:meth:`~radis.spectrum.spectrum.compare_with`
"""

import matplotlib.pyplot as plt
from matplotlib import gridspec
from matplotlib.widgets import MultiCursor
from publib import fix_style, set_style

if (not show) and (
not save
): # I added this line to avoid calculus in the case there is nothing to do (Minou)
Expand Down
10 changes: 7 additions & 3 deletions radis/spectrum/spectrum.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,11 +53,8 @@
from warnings import warn

import astropy.units as u
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.widgets import Cursor
from numpy import abs, diff
from publib import fix_style, set_style

# from radis.lbl.base import print_conditions
from radis.misc.arrays import count_nans, evenly_distributed, nantrapz
Expand Down Expand Up @@ -1441,6 +1438,9 @@ def plot(
:ref:`the Spectrum page <label_spectrum>`
"""

import matplotlib.pyplot as plt
from publib import fix_style, set_style

# Deprecated
if "plot_medium" in kwargs:
show_medium = kwargs.pop("plot_medium")
Expand Down Expand Up @@ -1568,6 +1568,8 @@ def clean_error_msg(string):
fig.cursor
# if already exist, do not add again
except AttributeError:
from matplotlib.widgets import Cursor

fig.cursor = Cursor(fig.gca(), useblit=True, color="r", lw=1, alpha=0.2)

# ... Add Ruler
Expand Down Expand Up @@ -1783,6 +1785,8 @@ def plot_populations(
kwargs: **dict
are forwarded to the plot
"""
import matplotlib.pyplot as plt
from publib import fix_style, set_style

# Check input, get defaults
pops = self.populations
Expand Down
8 changes: 7 additions & 1 deletion radis/test/io/test_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,15 @@

import pytest

from radis.io.query import CACHE_FILE_NAME, Hitran, fetch_astroquery
from radis.misc.warning import DeprecatedFileWarning


# ignored by pytest with argument -m "not needs_connection"
@pytest.mark.needs_connection
def test_fetch_astroquery(verbose=True, *args, **kwargs):
""" Test astroquery """
from radis.io.query import fetch_astroquery

df = fetch_astroquery("CO2", 1, 2200, 2400, verbose=verbose, cache=False)

assert df.iloc[0].id == 2
Expand All @@ -30,6 +31,8 @@ def test_fetch_astroquery(verbose=True, *args, **kwargs):
@pytest.mark.needs_connection
def test_fetch_astroquery_empty(verbose=True, *args, **kwargs):
""" Test astroquery: get a spectral range where there are no lines"""
from radis.io.query import fetch_astroquery

df = fetch_astroquery(
2, 1, 25000, 50000, verbose=verbose, cache=False
) # 200-400 nm
Expand All @@ -48,6 +51,9 @@ def test_fetch_astroquery_cache(verbose=True, *args, **kwargs):
- Cache file is loaded.
- Different metadata raises an error
"""
from astroquery.hitran import Hitran

from radis.io.query import CACHE_FILE_NAME, fetch_astroquery

df = fetch_astroquery(
"CO2",
Expand Down
19 changes: 19 additions & 0 deletions radis/test/tools/test_gascomp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
# -*- coding: utf-8 -*-
"""
Created on Wed Jul 7 12:55:28 2021
@author: erwan
Test the :py:func:`~radis.tools.gascomp.get_eq_mole_fraction` function.
"""


def test_get_eq_mole_fraction(*args, **kwargs):
from radis.misc.basics import all_in
from radis.tools.gascomp import get_eq_mole_fraction

gas = get_eq_mole_fraction("CO2:1", 3000, 1e5)

assert all_in(["C", "CO", "CO2", "O", "O2"], gas.keys())

return True
7 changes: 6 additions & 1 deletion radis/tools/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,6 @@ class to manage them all.
from warnings import warn

import json_tricks
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from joblib import Parallel, delayed
Expand Down Expand Up @@ -905,6 +904,8 @@ def plot_spec(file, what="radiance", title=True, **kwargs):
:py:meth:`~radis.spectrum.spectrum.Spectrum.plot`
"""

import matplotlib.pyplot as plt

if isinstance(file, str):
s = load_spec(file)
elif isinstance(file, Spectrum):
Expand Down Expand Up @@ -1636,6 +1637,8 @@ def plot(self, nfig=None, legend=True, **kwargs):
Spectrum :py:meth:`~radis.spectrum.spectrum.Spectrum.plot` method
"""

import matplotlib.pyplot as plt

fig = plt.figure(num=nfig)
ax = fig.gca()
for s in self:
Expand Down Expand Up @@ -1668,6 +1671,8 @@ def plot_cond(self, cond_x, cond_y, z_value=None, nfig=None):
"""
# %%

import matplotlib.pyplot as plt

x = self.df[cond_x]
y = self.df[cond_y]

Expand Down
19 changes: 8 additions & 11 deletions radis/tools/gascomp.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,6 @@
"""

from radis.misc.utils import NotInstalled

try:
import cantera as ct
except:
ct = NotInstalled(
"cantera",
"Cantera is needed to calculate equilibrium mole fractions"
+ ". Install with `pip install cantera`",
)


def get_eq_mole_fraction(initial_mixture, T_K, p_Pa):
"""Calculates chemical equilibrium mole fraction at temperature T, using
Expand Down Expand Up @@ -65,6 +54,14 @@ def get_eq_mole_fraction(initial_mixture, T_K, p_Pa):
[CANTERA]_
"""

try:
import cantera as ct
except ImportError as err:
raise ImportError(
"Cantera is needed to calculate equilibrium mole fractions"
+ ". Install with `pip install cantera` or (better) `conda install -c cantera cantera`",
) from err

# %% Init Cantera
g = ct.Solution("gri30.xml")
g.TPX = T_K, p_Pa, initial_mixture
Expand Down
Loading

0 comments on commit ea7b606

Please sign in to comment.