From 3a03b5516b0bdc4403deabd6e5eaadf2a820bd14 Mon Sep 17 00:00:00 2001 From: Neil Vaytet Date: Wed, 19 Apr 2023 23:47:48 +0200 Subject: [PATCH] add type hints --- src/tof/chopper.py | 16 ++++++++-------- src/tof/detector.py | 10 ++++++---- src/tof/model.py | 16 ++++++++++++---- src/tof/pulse.py | 28 +++++++++++++++------------- src/tof/units.py | 18 ++++++++++-------- 5 files changed, 51 insertions(+), 37 deletions(-) diff --git a/src/tof/chopper.py b/src/tof/chopper.py index 3898f95..07eb7f7 100644 --- a/src/tof/chopper.py +++ b/src/tof/chopper.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: BSD-3-Clause # Copyright (c) 2023 Scipp contributors (https://github.com/scipp) -from typing import List, Union +from typing import List, Tuple, Union import matplotlib.pyplot as plt import numpy as np @@ -36,25 +36,25 @@ def __init__( self._mask = None @property - def omega(self): + def omega(self) -> float: return 2.0 * np.pi * self.frequency @property - def open_times(self): + def open_times(self) -> np.ndarray: return (self.open + self.phase) / self.omega @property - def close_times(self): + def close_times(self) -> np.ndarray: return (self.close + self.phase) / self.omega @property - def tofs(self): + def tofs(self) -> np.ndarray: return units.s_to_us(self._arrival_times[self._mask]) - def hist(self, bins=300): + def hist(self, bins: Union[int, np.ndarray] = 300) -> Tuple[np.ndarray, np.ndarray]: return np.histogram(self.tofs, bins=bins) - def plot(self, bins=300): + def plot(self, bins: Union[int, np.ndarray] = 300) -> Plot: h, edges = self.hist(bins=bins) fig, ax = plt.subplots() x = np.concatenate([edges, edges[-1:]]) @@ -66,7 +66,7 @@ def plot(self, bins=300): ax.set_title(f"Chopper: {self.name}") return Plot(fig=fig, ax=ax) - def __repr__(self): + def __repr__(self) -> str: return ( f"Chopper(name={self.name}, distance={self.distance}, " f"frequency={self.frequency}, phase={self.phase}, " diff --git a/src/tof/detector.py b/src/tof/detector.py index 9b50003..7451a98 100644 --- a/src/tof/detector.py +++ b/src/tof/detector.py @@ -1,6 +1,8 @@ # SPDX-License-Identifier: BSD-3-Clause # Copyright (c) 2023 Scipp contributors (https://github.com/scipp) +from typing import Tuple, Union + import matplotlib.pyplot as plt import numpy as np @@ -16,13 +18,13 @@ def __init__(self, distance: float = 0.0, name: str = "detector"): self._mask = None @property - def tofs(self): + def tofs(self) -> np.ndarray: return s_to_us(self._arrival_times[self._mask]) - def hist(self, bins=300): + def hist(self, bins: Union[int, np.ndarray] = 300) -> Tuple[np.ndarray, np.ndarray]: return np.histogram(self.tofs, bins=bins) - def plot(self, bins=300): + def plot(self, bins: Union[int, np.ndarray] = 300) -> Plot: h, edges = self.hist(bins=bins) fig, ax = plt.subplots() x = np.concatenate([edges, edges[-1:]]) @@ -34,5 +36,5 @@ def plot(self, bins=300): ax.set_title(f"Detector: {self.name}") return Plot(fig=fig, ax=ax) - def __repr__(self): + def __repr__(self) -> str: return f"Detector(name={self.name}, distance={self.distance})" diff --git a/src/tof/model.py b/src/tof/model.py index d2fd2f3..e98a8bb 100644 --- a/src/tof/model.py +++ b/src/tof/model.py @@ -2,18 +2,26 @@ # Copyright (c) 2023 Scipp contributors (https://github.com/scipp) from itertools import chain +from typing import List, Union import matplotlib.pyplot as plt import numpy as np from matplotlib.collections import LineCollection +from .chopper import Chopper from .detector import Detector +from .pulse import Pulse from .tools import Plot from .units import s_to_us class Model: - def __init__(self, choppers, detectors, pulse): + def __init__( + self, + choppers: Union[Chopper, List[Chopper]], + detectors: Union[Detector, List[Detector]], + pulse: Pulse, + ): self.choppers = choppers if not isinstance(self.choppers, (list, tuple)): self.choppers = [self.choppers] @@ -22,7 +30,7 @@ def __init__(self, choppers, detectors, pulse): self.detectors = [self.detectors] self.pulse = pulse - def run(self, npulses=1): + def run(self, npulses: int = 1): # TODO: ray-trace multiple pulses components = sorted( chain(self.choppers, self.detectors), @@ -45,7 +53,7 @@ def run(self, npulses=1): comp._mask = combined initial_mask = combined - def plot(self, max_rays=1000): + def plot(self, max_rays: int = 1000) -> Plot: fig, ax = plt.subplots() furthest_detector = max(self.detectors, key=lambda d: d.distance) tofs = furthest_detector.tofs @@ -108,7 +116,7 @@ def plot(self, max_rays=1000): ax.set_ylabel("Distance (m)") return Plot(fig=fig, ax=ax) - def __repr__(self): + def __repr__(self) -> str: return ( f"Model(choppers={self.choppers},\n " f"detectors={self.detectors},\n " diff --git a/src/tof/pulse.py b/src/tof/pulse.py index fa4c4f9..02ad58b 100644 --- a/src/tof/pulse.py +++ b/src/tof/pulse.py @@ -1,6 +1,8 @@ # SPDX-License-Identifier: BSD-3-Clause # Copyright (c) 2023 Scipp contributors (https://github.com/scipp) +from typing import Union + import matplotlib.pyplot as plt import numpy as np @@ -15,11 +17,11 @@ def __init__( tmax: float = None, lmin: float = None, lmax: float = None, - neutrons=1_000_000, - kind=None, - p_wav=None, - p_time=None, - sampling_resolution=10000, + neutrons: int = 1_000_000, + kind: str = None, + p_wav: np.ndarray = None, + p_time: np.ndarray = None, + sampling_resolution: int = 10000, ): self.kind = kind self.neutrons = neutrons @@ -61,17 +63,11 @@ def __init__( self.speeds = units.wavelength_to_speed(self.wavelengths) self.energies = units.speed_to_mev(self.speeds) - def __repr__(self): - return ( - f"Pulse(tmin={self.tmin}, tmax={self.tmax}, lmin={self.lmin}, " - f"lmax={self.lmax}, neutrons={self.neutrons}, kind={self.kind})" - ) - @property - def duration(self): + def duration(self) -> float: return self.tmax - self.tmin - def plot(self, bins=300): + def plot(self, bins: Union[int, np.ndarray] = 300) -> Plot: fig, ax = plt.subplots(1, 2) for i, (data, label) in enumerate( zip([self.birth_times, self.wavelengths], ["Time (s)", "Wavelength (Å)"]) @@ -86,3 +82,9 @@ def plot(self, bins=300): size = fig.get_size_inches() fig.set_size_inches(size[0] * 2, size[1]) return Plot(fig=fig, ax=ax) + + def __repr__(self) -> str: + return ( + f"Pulse(tmin={self.tmin}, tmax={self.tmax}, lmin={self.lmin}, " + f"lmax={self.lmax}, neutrons={self.neutrons}, kind={self.kind})" + ) diff --git a/src/tof/units.py b/src/tof/units.py index 2161916..b4a225f 100644 --- a/src/tof/units.py +++ b/src/tof/units.py @@ -1,6 +1,8 @@ # SPDX-License-Identifier: BSD-3-Clause # Copyright (c) 2023 Scipp contributors (https://github.com/scipp) +from typing import Union + import numpy as np mass = 1.674927471e-27 # Neutron mass in kg @@ -8,33 +10,33 @@ mev = 1.602176634e-22 # meV to Joule -def deg_to_rad(x): +def deg_to_rad(x: Union[float, np.ndarray]) -> Union[float, np.ndarray]: return np.radians(x) -def rad_to_deg(x): +def rad_to_deg(x: Union[float, np.ndarray]) -> Union[float, np.ndarray]: return np.degrees(x) -def us_to_s(x): +def us_to_s(x: Union[float, np.ndarray]) -> Union[float, np.ndarray]: return x * 1.0e-6 -def s_to_us(x): +def s_to_us(x: Union[float, np.ndarray]) -> Union[float, np.ndarray]: return x * 1.0e6 -def speed_to_wavelength(x): +def speed_to_wavelength(x: Union[float, np.ndarray]) -> Union[float, np.ndarray]: return 1.0 / (alpha * x) -def wavelength_to_speed(x): +def wavelength_to_speed(x: Union[float, np.ndarray]) -> Union[float, np.ndarray]: return 1.0 / (alpha * x) -def speed_to_mev(x): +def speed_to_mev(x: Union[float, np.ndarray]) -> Union[float, np.ndarray]: return mass * x * x / mev -def mev_to_speed(x): +def mev_to_speed(x: Union[float, np.ndarray]) -> Union[float, np.ndarray]: return np.sqrt(mev * x / mass)