In [None]:
from typing import List, Dict, Union, Tuple
import numpy as np
import matplotlib.pyplot as plt
import logging

logger = logging.getLogger(__name__)

# Custom decorators for logging and validation
def log_method_call(func):
    """Decorator that logs ``Calling <name>`` at INFO level before each call of ``func``.

    ``functools.wraps`` copies ``func``'s name, docstring and ``__wrapped__`` onto
    the wrapper, so stacked decorators still report the real method name.
    """
    from functools import wraps

    @wraps(func)
    def wrapper(*args, **kwargs):
        # Lazy %-formatting: the message is only built when INFO is enabled.
        logger.info("Calling %s", func.__name__)
        return func(*args, **kwargs)

    return wrapper

def validate_fbands(func):
    """Decorator for ``Wreconst0`` methods that range-checks band indices before the call.

    The bands checked are ``kwargs['fbands']`` when given, otherwise ``self.fbands``;
    ``'all'`` and an empty band list pass unchecked.

    Raises:
        ValueError: if a band index is negative or not smaller than
            ``len(self.wdecomp.frequencies)``.
    """
    from functools import wraps

    @wraps(func)
    def wrapper(self, *args, **kwargs):
        fbands = kwargs.get('fbands', self.fbands)
        # An empty band list has no last element to index, so skip the check for it.
        if fbands != 'all' and len(fbands) > 0:
            n_bands = len(self.wdecomp.frequencies)
            # min/max instead of fbands[-1]: a keyword-supplied list need not be sorted.
            for band in (min(fbands), max(fbands)):
                if band < 0 or band >= n_bands:
                    raise ValueError(f'{band} is outside the total number of considered frequency bands: {self.wdecomp.frequencies}')
        return func(self, *args, **kwargs)

    return wrapper

class Wreconst0:
    """Per-subject sum of selected wavelet-band reconstructions.

    For each subject key (e.g. ``'subj3'``), ``recdictsubjs_ts[subj]`` is the sum,
    over the band indices in ``fbands``, of ``wdecomp.recdictsubjs_coeffs_ts[subj][band]``.
    """

    def __init__(self, wdecomp: 'Wdecomp0', fbands: Union[List[int], str], subjects: List[str]):
        """Select and sum the requested bands for every subject.

        Args:
            wdecomp: decomposition providing ``frequencies``, ``periods``, ``sr``,
                ``feature``, ``original`` and ``recdictsubjs_coeffs_ts``.
            fbands: band indices into ``wdecomp.frequencies`` (duplicates removed,
                result sorted), or ``'all'`` for every band. An empty list is
                allowed and yields all-zero series.
            subjects: subject keys such as ``'subj3'``.

        Raises:
            ValueError: if a band index is negative or not smaller than
                ``len(wdecomp.frequencies)``.
        """
        if fbands == 'all':
            logger.info(f'Reconstructs the signal from all available frequency bands: {wdecomp.frequencies}')
            fbands = list(range(len(wdecomp.frequencies)))
        elif fbands == []:
            logger.info(f"Empty set of bands: Nothing to reconstruct")
        else:
            fbands = sorted(set(fbands))  # remove duplicate values and sort

        # fbands is sorted, so its two ends bound every entry.
        if fbands and (fbands[0] < 0 or fbands[-1] >= len(wdecomp.frequencies)):
            bad_band = fbands[0] if fbands[0] < 0 else fbands[-1]
            raise ValueError(f'{bad_band} is outside the total number of considered frequency bands: {wdecomp.frequencies}')

        self.wdecomp = wdecomp
        self.fbands = fbands
        self.subjects = subjects
        self.recdictsubjs_ts = {subj: self._sum_bands(subj) for subj in subjects}

    def _sum_bands(self, subj: str) -> np.ndarray:
        """Sum the reconstructions of ``self.fbands`` for ``subj`` (zeros when no band is selected)."""
        # Integer dtype keeps an empty band list a valid (empty) index array;
        # np.array([]) would be float and rejected as an index.
        band_idx = np.asarray(self.fbands, dtype=int)
        return self.wdecomp.recdictsubjs_coeffs_ts[subj][band_idx].sum(axis=0)

    @log_method_call
    def add_subj(self, subj: str, recdictsubjs_ts: np.ndarray):
        """Register ``subj`` with its already-summed reconstruction ``recdictsubjs_ts``.

        The subject list stays sorted and free of duplicates; an existing entry is overwritten.
        """
        if subj not in self.subjects:
            self.subjects = sorted(self.subjects + [subj])
        self.recdictsubjs_ts[subj] = recdictsubjs_ts

    @log_method_call
    @validate_fbands
    def plotwreconstruct(self, title='', subjs='all'):
        """Plot the original signals, then the reconstructed signals, one subplot per subject.

        Args:
            title: title of the original-signal subplots; defaults to the feature name.
            subjs: subject numbers (turned into ``'subj<n>'`` keys), ``'all'`` for
                every stored subject, or ``[]`` (logs a warning and plots nothing).

        Raises:
            ValueError: if a requested subject has no stored reconstruction.
        """
        if title == '':
            title = f'FEATURE: {self.wdecomp.feature}'

        if subjs == 'all':
            subjects = self.subjects
        elif subjs == []:
            logger.warning(f"Empty set of subjects: nothing will be calculated or plotted!")
            subjects = []
        else:
            subjects = [f'subj{subj}' for subj in subjs]

        for subj in subjects:
            if subj not in self.subjects:
                raise ValueError(f'{subj} is outside the range of considered subjects: {self.subjects}')

        if not subjects:
            # Nothing to draw; also avoids zero-height figures and an empty-dict lookup.
            return

        size = self.recdictsubjs_ts[subjects[0]].shape[-1]
        t = np.linspace(0, size / self.wdecomp.sr, size)
        subgraphs_nr = len(subjects)

        # Band descriptions are the same for every subject, so build them once.
        freq_text = ','.join([f'{intvl[0]} - {intvl[1]}' for intvl in calc_interval_list(self.fbands, self.wdecomp.frequencies)])
        period_text = ','.join([f'{intvl[0]} - {intvl[1]}' for intvl in calc_interval_list(self.fbands, self.wdecomp.periods, reverse=True)])

        plt.figure(figsize=(12, int(2 * subgraphs_nr)))
        for i, subj in enumerate(subjects):
            plt.subplot(subgraphs_nr, 1, i + 1)
            plt.plot(t, self.wdecomp.original[subj], label=f'Original Signal of {subj}')
            plt.legend()
            plt.title(f"{title}")
        plt.tight_layout()

        plt.figure(figsize=(12, int(2 * subgraphs_nr)))
        for j, subj in enumerate(subjects):
            plt.subplot(subgraphs_nr, 1, j + 1)
            plt.plot(t, self.recdictsubjs_ts[subj], label=f'Reconstructed {self.fbands}', linestyle='--')
            plt.legend()
            # subj is already 'subj<n>', so it is not prefixed again.
            plt.title(f"{subj}: Frequencies {freq_text} Hz \n"
                      f"period intervals {period_text} s")
        plt.tight_layout()

        plt.show()

    @log_method_call
    def bbootstrap(self, err=None, block_length=10.0, rseed=None, subjs='all'):
        """Block-bootstrap this reconstruction via ``self.wdecomp.bbootstrap``.

        The whole band set ``self.fbands`` is passed as a single (summed) band entry.
        With an empty band set, only the original series is bootstrapped.
        """
        # An empty inner list would make the decomposition's bootstrap fail on max([]).
        fbands = [self.fbands] if self.fbands else []
        return self.wdecomp.bbootstrap(err=err, block_length=block_length, rseed=rseed, subjs=subjs, fbands=fbands)

# Refactoring `Wreconst0` and `Wdecomp0`

1. Use Property Decorators (planned; not yet applied in the classes shown here):
   * Convert getter methods into properties using the `@property` decorator for cleaner attribute access.

2. Use Type Hinting:
   * Improve readability and maintainability by adding type hints.

3. Refactor Common Patterns Using Decorators:
   * Create custom decorators for repeated patterns, such as logging or validation.

### Key Changes:

1. Decorator Functions:

    * log_method_call decorator for logging method calls.
    * validate_fbands decorator for validating frequency bands.

2. Type Hinting:

    * Added type hints to methods for better readability and maintainability.

3. Improved Readability:

    * Organized the code into clear sections and used meaningful variable names.

### Usage:

> You can use these classes as before: the decorators log each call of a decorated method (at INFO level) and validate the selected frequency bands on the methods they are applied to.

In [None]:
from typing import List, Dict, Union, Tuple
import numpy as np
import pywt
import matplotlib.pyplot as plt
import logging

logger = logging.getLogger(__name__)

# Custom decorators for logging and validation
def log_method_call(func):
    """Decorator that logs ``Calling <name>`` at INFO level before each call of ``func``.

    ``functools.wraps`` copies ``func``'s name, docstring and ``__wrapped__`` onto
    the wrapper, so stacked decorators still report the real method name.
    """
    from functools import wraps

    @wraps(func)
    def wrapper(*args, **kwargs):
        # Lazy %-formatting: the message is only built when INFO is enabled.
        logger.info("Calling %s", func.__name__)
        return func(*args, **kwargs)

    return wrapper

def validate_fbands(func):
    """Decorator for ``Wreconst0`` methods that range-checks band indices before the call.

    The bands checked are ``kwargs['fbands']`` when given, otherwise ``self.fbands``;
    ``'all'`` and an empty band list pass unchecked.

    Raises:
        ValueError: if a band index is negative or not smaller than
            ``len(self.wdecomp.frequencies)``.
    """
    from functools import wraps

    @wraps(func)
    def wrapper(self, *args, **kwargs):
        fbands = kwargs.get('fbands', self.fbands)
        # An empty band list has no last element to index, so skip the check for it.
        if fbands != 'all' and len(fbands) > 0:
            n_bands = len(self.wdecomp.frequencies)
            # min/max instead of fbands[-1]: a keyword-supplied list need not be sorted.
            for band in (min(fbands), max(fbands)):
                if band < 0 or band >= n_bands:
                    raise ValueError(f'{band} is outside the total number of considered frequency bands: {self.wdecomp.frequencies}')
        return func(self, *args, **kwargs)

    return wrapper

class Wreconst0:
    """Per-subject sum of selected wavelet-band reconstructions.

    Instances are created and cached by ``Wdecomp0.get_wreconstruct``. For each
    subject key (e.g. ``'subj3'``), ``recdictsubjs_ts[subj]`` is the sum, over the
    band indices in ``fbands``, of ``wdecomp.recdictsubjs_coeffs_ts[subj][band]``.
    """

    def __init__(self, wdecomp: 'Wdecomp0', fbands: Union[List[int], str], subjects: List[str]):
        """Select and sum the requested bands for every subject.

        Args:
            wdecomp: decomposition providing ``frequencies``, ``periods``, ``sr``,
                ``feature``, ``original`` and ``recdictsubjs_coeffs_ts``.
            fbands: band indices into ``wdecomp.frequencies`` (duplicates removed,
                result sorted), or ``'all'`` for every band. An empty list is
                allowed and yields all-zero series.
            subjects: subject keys such as ``'subj3'``.

        Raises:
            ValueError: if a band index is negative or not smaller than
                ``len(wdecomp.frequencies)``.
        """
        if fbands == 'all':
            logger.info(f'Reconstructs the signal from all available frequency bands: {wdecomp.frequencies}')
            fbands = list(range(len(wdecomp.frequencies)))
        elif fbands == []:
            logger.info(f"Empty set of bands: Nothing to reconstruct")
        else:
            fbands = sorted(set(fbands))  # remove duplicate values and sort

        # fbands is sorted, so its two ends bound every entry.
        if fbands and (fbands[0] < 0 or fbands[-1] >= len(wdecomp.frequencies)):
            bad_band = fbands[0] if fbands[0] < 0 else fbands[-1]
            raise ValueError(f'{bad_band} is outside the total number of considered frequency bands: {wdecomp.frequencies}')

        self.wdecomp = wdecomp
        self.fbands = fbands
        self.subjects = subjects
        self.recdictsubjs_ts = {subj: self._sum_bands(subj) for subj in subjects}

    def _sum_bands(self, subj: str) -> np.ndarray:
        """Sum the reconstructions of ``self.fbands`` for ``subj`` (zeros when no band is selected)."""
        # Integer dtype keeps an empty band list a valid (empty) index array;
        # np.array([]) would be float and rejected as an index.
        band_idx = np.asarray(self.fbands, dtype=int)
        return self.wdecomp.recdictsubjs_coeffs_ts[subj][band_idx].sum(axis=0)

    @log_method_call
    def add_subj(self, subj: str, recdictsubjs_ts: np.ndarray):
        """Register ``subj`` with its already-summed reconstruction ``recdictsubjs_ts``.

        The subject list stays sorted and free of duplicates; an existing entry is overwritten.
        """
        if subj not in self.subjects:
            self.subjects = sorted(self.subjects + [subj])
        self.recdictsubjs_ts[subj] = recdictsubjs_ts

    @log_method_call
    @validate_fbands
    def plotwreconstruct(self, title='', subjs='all'):
        """Plot the original signals, then the reconstructed signals, one subplot per subject.

        Args:
            title: title of the original-signal subplots; defaults to the feature name.
            subjs: subject numbers (turned into ``'subj<n>'`` keys), ``'all'`` for
                every stored subject, or ``[]`` (logs a warning and plots nothing).

        Raises:
            ValueError: if a requested subject has no stored reconstruction.
        """
        if title == '':
            title = f'FEATURE: {self.wdecomp.feature}'

        if subjs == 'all':
            subjects = self.subjects
        elif subjs == []:
            logger.warning(f"Empty set of subjects: nothing will be calculated or plotted!")
            subjects = []
        else:
            subjects = [f'subj{subj}' for subj in subjs]

        for subj in subjects:
            if subj not in self.subjects:
                raise ValueError(f'{subj} is outside the range of considered subjects: {self.subjects}')

        if not subjects:
            # Nothing to draw; also avoids zero-height figures and an empty-dict lookup.
            return

        size = self.recdictsubjs_ts[subjects[0]].shape[-1]
        t = np.linspace(0, size / self.wdecomp.sr, size)
        subgraphs_nr = len(subjects)

        # Band descriptions are the same for every subject, so build them once.
        freq_text = ','.join([f'{intvl[0]} - {intvl[1]}' for intvl in calc_interval_list(self.fbands, self.wdecomp.frequencies)])
        period_text = ','.join([f'{intvl[0]} - {intvl[1]}' for intvl in calc_interval_list(self.fbands, self.wdecomp.periods, reverse=True)])

        plt.figure(figsize=(12, int(2 * subgraphs_nr)))
        for i, subj in enumerate(subjects):
            plt.subplot(subgraphs_nr, 1, i + 1)
            plt.plot(t, self.wdecomp.original[subj], label=f'Original Signal of {subj}')
            plt.legend()
            plt.title(f"{title}")
        plt.tight_layout()

        plt.figure(figsize=(12, int(2 * subgraphs_nr)))
        for j, subj in enumerate(subjects):
            plt.subplot(subgraphs_nr, 1, j + 1)
            plt.plot(t, self.recdictsubjs_ts[subj], label=f'Reconstructed {self.fbands}', linestyle='--')
            plt.legend()
            # subj is already 'subj<n>', so it is not prefixed again.
            plt.title(f"{subj}: Frequencies {freq_text} Hz \n"
                      f"period intervals {period_text} s")
        plt.tight_layout()

        plt.show()

    @log_method_call
    def bbootstrap(self, err=None, block_length=10.0, rseed=None, subjs='all'):
        """Block-bootstrap this reconstruction via ``self.wdecomp.bbootstrap``.

        The whole band set ``self.fbands`` is passed as a single (summed) band entry.
        With an empty band set, only the original series is bootstrapped.
        """
        # An empty inner list would make the decomposition's bootstrap fail on max([]).
        fbands = [self.fbands] if self.fbands else []
        return self.wdecomp.bbootstrap(err=err, block_length=block_length, rseed=rseed, subjs=subjs, fbands=fbands)


class Wdecomp0:
    """Multilevel discrete wavelet decomposition of several subjects' time series.

    For every band (one per ``pywt.wavedec`` coefficient array), the signal is
    reconstructed from that band's coefficients alone, with all other arrays
    zeroed. Band index ``b`` refers to coefficient array ``b``; index 0 is the
    approximation band, whose lower frequency bound is 0 and upper period bound
    is ``'inf'``.
    """

    def __init__(self, sr: int, orig: Dict[str, np.ndarray], ts_subjs: np.matrix, subjects: Tuple[int, int], time_interval: Tuple[int, int], feature: str, wavelet: str, level: int):
        """Decompose ``ts_subjs`` and precompute the single-band reconstructions.

        Args:
            sr: sampling rate, used to convert sample counts into s / Hz.
            orig: per-subject data; ``orig[key][FEATURE.get(feature, feature)]`` is
                stored as ``self.original[key]``. NOTE(review): the values are
                indexed by feature name, so they are presumably mappings or
                DataFrames rather than plain arrays — confirm.
            ts_subjs: transposed before ``pywt.wavedec`` (which decomposes along the
                last axis); presumably shaped (time, subjects) — TODO confirm.
            subjects: half-open range ``(first, stop)`` of subject numbers; column
                ``j`` becomes key ``f'subj{first + j}'``.
            time_interval: ``(start, stop)``; only its length is used (plot time axis).
            feature: feature name, mapped through ``FEATURE`` when present there.
            wavelet: ``pywt`` wavelet name.
            level: decomposition level; yields ``level + 1`` bands.
        """
        _feat = FEATURE.get(feature, feature)
        coeffs_subjs_ts = pywt.wavedec(ts_subjs.T, wavelet, level=level)

        _frequencies = [(round(sr / 2 ** (i + 2), 2) if i + 1 < len(coeffs_subjs_ts) else 0, round(sr / (2 ** (i + 1)), 2)) for i in reversed(range(len(coeffs_subjs_ts)))]
        _periods = [(round((2 ** (i + 1)) / sr, 2), round(2 ** (i + 2) / sr, 2) if i + 1 < len(coeffs_subjs_ts) else 'inf') for i in reversed(range(len(coeffs_subjs_ts)))]

        # Row b: waverec of the coefficient list with every array except band b zeroed.
        # Indexed below as [band, subject, sample].
        reccoeffs_subjs_ts = np.stack(
            [pywt.waverec(
                [np.zeros_like(coef_j) if coef_i is not coef_j else coef_j for coef_j in coeffs_subjs_ts],
                wavelet
            )
            for coef_i in coeffs_subjs_ts])

        recdictsubjs_coeffs_ts: Dict[str, np.ndarray] = {f'subj{subjects[0] + j}': reccoeffs_subjs_ts[:, j, :] for j in range(reccoeffs_subjs_ts.shape[1])}
        recdictcoeffs_subjs_ts: Dict[int, np.ndarray] = {i: reccoeffs_subjs_ts[i, ...] for i in range(reccoeffs_subjs_ts.shape[0])}
        assert reccoeffs_subjs_ts.shape[0] == level + 1
        original = {k: v[_feat] for k, v in orig.items()}

        self.frequencies = _frequencies
        self.periods = _periods
        self.recdictsubjs_coeffs_ts = recdictsubjs_coeffs_ts
        self.recdictcoeffs_subjs_ts = recdictcoeffs_subjs_ts
        self.original = original
        self.subjects = subjects
        self.time_interval = time_interval
        self.sr = sr
        self.feature = feature
        self.wavelet = wavelet
        self.level = level
        self._wreconstruct = {}  # cache: str(sorted band list) -> Wreconst0

    def _resolve_subjects(self, subjs, empty_msg):
        """Expand ``'all'``, validate subject numbers and return ``(numbers, keys)``.

        Keys have the form ``f'subj{number}'``. ``empty_msg`` is logged as a warning
        when ``subjs`` is an empty list.

        Raises:
            ValueError: if a number lies outside ``range(*self.subjects)``.
        """
        if subjs == 'all':
            subjs = list(range(*self.subjects))
        elif subjs == []:
            logger.warning(empty_msg)

        subject_keys = []
        for subj in subjs:
            if subj < self.subjects[0] or subj >= self.subjects[1]:
                raise ValueError(f'{subj} is outside the range of considered subjects: {self.subjects}')
            subject_keys.append(f'subj{subj}')
        return subjs, subject_keys

    def _check_band(self, band):
        """Raise ValueError unless ``band`` is a valid index into ``self.frequencies``."""
        if band < 0 or band >= len(self.frequencies):
            raise ValueError(f'{band} is outside the total number of considered frequency bands: {self.frequencies}')

    @log_method_call
    def get_wreconstruct(self, subjs: Union[List, str] = 'all', fbands: Union[List, str] = 'all') -> Wreconst0:
        """Return the cached ``Wreconst0`` summing ``fbands``, extended to the requested subjects.

        Reconstructions are cached per band set (key ``str(sorted(set(fbands)))``);
        subjects missing from a cached entry are added to it.

        Returns:
            The ``Wreconst0`` for this band set, or ``None`` when ``fbands`` is empty.

        Raises:
            ValueError: for out-of-range subjects or band indices.
        """
        if subjs == 'all':
            subjs = list(range(*self.subjects))
            logger.debug(f"Reconstructing for all subjects: {subjs}")
        _, subjects = self._resolve_subjects(subjs, f"Empty set of subjects")

        if fbands == 'all':
            logger.info(f'Reconstructs the signal from all available frequency bands: {self.frequencies}')
            fbands = list(range(len(self.frequencies)))
        else:
            fbands = sorted(set(fbands))  # remove duplicate values and sort

        if not fbands:
            logger.info(f"Empty set of bands: Nothing to reconstruct")
            return None

        # fbands is sorted, so checking both ends covers every entry.
        self._check_band(fbands[0])
        self._check_band(fbands[-1])

        key = str(fbands)
        _wreconst = self._wreconstruct.get(key)
        if _wreconst is None:
            logger.debug(f'Reconstructing {subjects} from frequency bands {fbands}')
            _wreconst = Wreconst0(self, fbands, subjects)
            self._wreconstruct[key] = _wreconst
        else:
            logger.debug(f'Reconstruction for {fbands} already calculated for some subjects {_wreconst.subjects}')
            for subj in subjects:
                if subj in _wreconst.subjects:
                    logger.debug(f"get_wreconstruct: Using already calculated value for {subj}")
                else:
                    logger.debug(f"Reconstructing frequency bands {fbands} for {subj}")
                    _wreconst.add_subj(subj, sum(self.recdictsubjs_coeffs_ts[subj][np.array(fbands)]))
        return _wreconst

    @log_method_call
    def plotwdecomp(self, title='', subjs='all', fbands='all'):
        """Plot the original series, then each selected band's reconstruction.

        One figure holds the original signals (one subplot per subject), followed
        by one figure per band.

        Raises:
            ValueError: for out-of-range subjects or band indices.
        """
        if title == '':
            title = f'FEATURE: {self.feature}'

        _, subjects = self._resolve_subjects(subjs, f"Empty set of subjects: nothing will be calculated or plotted!")

        if fbands == 'all':
            fbands = list(range(len(self.frequencies)))
        elif fbands == []:
            logger.info(f"Empty set of bands: only the original time series will be calculated and plotted!")

        for band in fbands:
            self._check_band(band)

        if not subjects:
            return  # the warning above already announced that nothing is plotted

        size = self.time_interval[1] - self.time_interval[0]
        t = np.linspace(0, size / self.sr, size)
        subgraphs_nr = len(subjects)

        plt.figure(figsize=(12, int(2 * subgraphs_nr)))
        for i, subj in enumerate(subjects):
            plt.subplot(subgraphs_nr, 1, i + 1)
            plt.plot(t, self.original[subj], label=f'Original Signal of {subj}')
            plt.legend()
            plt.title(f"{title}")
        plt.tight_layout()

        for band in fbands:
            plt.figure(figsize=(12, int(2 * subgraphs_nr)))
            for j, subj in enumerate(subjects):
                plt.subplot(subgraphs_nr, 1, j + 1)
                plt.plot(t, self.recdictsubjs_coeffs_ts[subj][band], label=f'Reconstructed band {band}', linestyle='--')
                plt.legend()
                # subj is already 'subj<n>', so it is not prefixed again.
                plt.title(f"{subj}: Frequency interval {self.frequencies[band]} Hz, period interval {self.periods[band]} s")
            plt.tight_layout()
        plt.show()

    @log_method_call
    def plotwreconstruct(self, title='', subjs='all', fbands='all'):
        """Plot the originals and the summed reconstruction of ``fbands`` for ``subjs``."""
        _wreconst = self.get_wreconstruct(subjs, fbands)
        if _wreconst is None:
            # get_wreconstruct logged and returned None for an empty band set.
            return
        _wreconst.plotwreconstruct(title=title, subjs=subjs)

    @log_method_call
    def bbootstrap(self, err=None, block_length=10.0, rseed=None, subjs='all', fbands='all'):
        """Block-bootstrap the original series and selected band reconstructions.

        For every entry of ``fbands``, and then for the original series, blocks of
        consecutive samples starting at random positions are concatenated until
        ``num_frames`` samples are filled. The maximum block duration is
        ``block_length`` for the original series. For a band entry it is twice the
        lower period bound of that band (of the lowest-index band for a list entry).

        Args:
            err: optional per-subject series keyed like ``'subj3'``, resampled with
                the same blocks (presumably measurement errors — confirm).
            block_length: maximum block duration (in the units of ``mjd``, i.e.
                samples / ``sr``) for the original series.
            rseed: seed for the block starts; every entry restarts from this seed.
                A local ``np.random.RandomState`` is used, so the global NumPy
                random state is not modified.
            subjs: subject numbers or ``'all'``.
            fbands: ``'all'`` or entries that are band indices (int) or lists of
                band indices (summed through ``get_wreconstruct``).

        Returns:
            ``(mjd_boot, mag_boot, err_boot)``.
            - ``mjd_boot``: the resampled time axis; each entry overwrites it, so
              it reflects the last one (``'original'``).
            - ``mag_boot``: nested dict ``entry -> subject key -> resampled series``;
              list entries are keyed by ``str(sorted(set(band)))``.
            - ``err_boot``: per-subject resampled errors (also from the last
              entry), or ``None`` when ``err`` is None.

        Raises:
            ValueError: for out-of-range subjects/bands, malformed band entries,
                or a block duration too long for the series.
        """
        subjs, subjects = self._resolve_subjects(subjs, f"Empty set of subjects: nothing will be calculated or plotted!")

        if fbands == 'all':
            fbands = list(range(len(self.frequencies)))
        elif fbands == []:
            logger.info(f"Empty set of bands: only the original time series will be calculated and plotted!")

        block_lengths = {}
        magnitudes = {}
        _fbands = []
        for band in fbands:
            if isinstance(band, (list, tuple)):
                if len(band) == 0:
                    raise ValueError(f'{band} is an empty list of frequency bands!')
                if min(band) < 0 or max(band) >= len(self.frequencies):
                    raise ValueError(f'The range of {band} is outside the total number of considered frequency bands: {self.frequencies}')
                band_str = str(sorted(set(band)))
                block_lengths[band_str] = 2 * self.periods[min(band)][0]
                magnitudes[band_str] = self.get_wreconstruct(subjs=subjs, fbands=band).recdictsubjs_ts
                _fbands.append(band_str)

            elif isinstance(band, (int, np.integer)):
                self._check_band(band)
                block_lengths[band] = 2 * self.periods[band][0]
                magnitudes[band] = {k: w[band] for k, w in self.recdictsubjs_coeffs_ts.items()}
                _fbands.append(band)

            else:
                raise ValueError(f'{band} must be an integer or a list of integers!')

        block_lengths['original'] = block_length
        magnitudes['original'] = self.original
        num_frames = next(iter(self.original.values())).shape[0]
        logger.debug(f"bbotstrap: num_frames = {num_frames}")

        mjd: np.ndarray = np.arange(num_frames) / self.sr  # sample times; loop-invariant
        mjd_boot = np.zeros(shape=(num_frames,))
        mag_boot = {band: {subj: np.zeros(shape=(num_frames,)) for subj in subjects} for band in _fbands + ['original']}
        err_boot = {subj: np.zeros(shape=(num_frames,)) for subj in subjects} if err is not None else err

        for band in _fbands + ['original']:
            # Same draws as np.random.seed(rseed) + np.random.randint, without global side effects.
            rng = np.random.RandomState(rseed)
            max_block = block_lengths[band]

            # First offset from the end whose span exceeds one block; block starts are drawn
            # from [0, num_frames - max_idx - 1), keeping them clear of the end of the series.
            max_idx = None
            for max_idx in range(2, num_frames):
                if mjd[-1] - mjd[-max_idx] > max_block:
                    break
            high = num_frames - max_idx - 1 if max_idx is not None else 0
            if high <= 0:
                raise ValueError(f'Block length {max_block} for {band} is too long for a series of {num_frames} frames')

            k = 0
            last_time = 0.0
            while k < num_frames:
                idx_start = rng.randint(high)
                for idx_end in range(idx_start + 1, num_frames):
                    if mjd[idx_end] - mjd[idx_start] > max_block or k + idx_end - idx_start >= num_frames - 1:
                        break
                n_block = idx_end - idx_start
                mjd_boot[k:k + n_block] = mjd[idx_start:idx_end] - mjd[idx_start] + last_time

                for subj in subjects:
                    mag_boot[band][subj][k:k + n_block] = magnitudes[band][subj][idx_start:idx_end]
                    # Errors are resampled with the same block for every subject (not only the last one).
                    if err is not None:
                        err_boot[subj][k:k + n_block] = err[subj][idx_start:idx_end]

                last_time = mjd[idx_end] - mjd[idx_start] + last_time
                k += n_block
        return mjd_boot, mag_boot, err_boot

## Refactored `plotwdecomp` and `plotwreconstruct`

### Helper Function (plot_signals):

> This function takes the time array `t`, a dictionary of signals, a title, and one label per signal, and plots each signal in its own subplot of a new figure.

### Refactored Methods:

> Both plotwdecomp and plotwreconstruct methods now call the plot_signals helper function to handle the plotting logic. This eliminates the repetitive plt blocks and makes the code more concise.


### Benefits:

* Maintainability: Easier to modify the plotting logic in one place.
* Readability: Cleaner and more readable code.
* Reusability: The helper function can be reused in other parts of the code if needed.


## Helper Function for Plotting:

In [None]:
def plot_signals(t: np.ndarray, signals: Dict[str, np.ndarray], title: str, labels: List[str]):
    """Plot each signal in its own stacked subplot of a new figure, then show the figure.

    Args:
        t: x-axis values shared by all signals.
        signals: mapping from key to y-values; one subplot per entry, in dict order.
        title: title applied to every subplot.
        labels: legend label for each signal, in the same order as ``signals``.
            Pass ``None`` to use the dict keys.

    Raises:
        ValueError: if ``labels`` does not provide exactly one label per signal.
    """
    if labels is None:
        labels = list(signals)
    elif len(labels) != len(signals):
        raise ValueError(f'{len(labels)} labels given for {len(signals)} signals')

    subgraphs_nr = len(signals)
    if subgraphs_nr == 0:
        return  # nothing to plot; avoids creating a zero-height figure

    plt.figure(figsize=(12, int(2 * subgraphs_nr)))
    for i, (label, signal) in enumerate(zip(labels, signals.values())):
        plt.subplot(subgraphs_nr, 1, i + 1)
        plt.plot(t, signal, label=label)
        plt.legend()
        plt.title(title)
    # Lay out once, after all subplots exist.
    plt.tight_layout()

    plt.show()

## Refactor `plotwdecomp` in `Wdecomp0` and `plotwreconstruct` in `Wreconst0` to use the helper function

In [None]:
class Wdecomp0:
    """Refactored ``plotwdecomp`` for ``Wdecomp0``, built on ``plot_signals``.

    NOTE(review): this cell defines only ``plotwdecomp``. Executing it as-is
    rebinds ``Wdecomp0`` to a class without ``__init__`` or the other methods of
    the full definition. Merge this method into the full class, or copy the
    remaining methods here, before using it.
    """
    # Other methods...

    @log_method_call
    def plotwdecomp(self, title='', subjs='all', fbands='all'):
        """Plot the original series, then each selected band's reconstruction.

        Args:
            title: figure title; defaults to the feature name.
            subjs: subject numbers, ``'all'``, or ``[]`` (warns, plots nothing).
            fbands: band indices into ``self.frequencies``, ``'all'``, or ``[]``
                (only the originals are plotted).

        Raises:
            ValueError: for subjects outside ``range(*self.subjects)`` or invalid
                band indices, matching the full ``Wdecomp0.plotwdecomp``.
        """
        if title == '':
            title = f'FEATURE: {self.feature}'

        if subjs == 'all':
            subjs = [subj for subj in range(*self.subjects)]
        elif subjs == []:
            logger.warning(f"Empty set of subjects: nothing will be calculated or plotted!")

        # Reject (rather than silently drop) subjects outside the decomposed range.
        subjects = []
        for subj in subjs:
            if not self.subjects[0] <= subj < self.subjects[1]:
                raise ValueError(f'{subj} is outside the range of considered subjects: {self.subjects}')
            subjects.append(f'subj{subj}')

        if fbands == 'all':
            fbands = list(range(len(self.frequencies)))
        elif fbands == []:
            logger.info(f"Empty set of bands: only the original time series will be calculated and plotted!")

        for band in fbands:
            if band < 0 or band >= len(self.frequencies):
                raise ValueError(f'{band} is outside the total number of considered frequency bands: {self.frequencies}')

        if not subjects:
            return  # the warning above already announced that nothing is plotted

        size = self.time_interval[1] - self.time_interval[0]
        t = np.linspace(0, size / self.sr, size)

        # Plot original signals
        original_signals = {subj: self.original[subj] for subj in subjects}
        plot_signals(t, original_signals, title, subjects)

        # Plot reconstructed signals for each band
        for band in fbands:
            reconstructed_signals = {subj: self.recdictsubjs_coeffs_ts[subj][band] for subj in subjects}
            band_title = f"{title}: Frequency interval {self.frequencies[band]} Hz, period interval {self.periods[band]} s"
            plot_signals(t, reconstructed_signals, band_title, subjects)


class Wreconst0:
    """Per-subject sum of selected wavelet-band reconstructions (plotting via ``plot_signals``).

    For each subject key (e.g. ``'subj3'``), ``recdictsubjs_ts[subj]`` is the sum,
    over the band indices in ``fbands``, of ``wdecomp.recdictsubjs_coeffs_ts[subj][band]``.
    """

    def __init__(self, wdecomp: 'Wdecomp0', fbands: Union[List[int], str], subjects: List[str]):
        """Select and sum the requested bands for every subject.

        Args:
            wdecomp: decomposition providing ``frequencies``, ``sr``, ``feature``,
                ``original`` and ``recdictsubjs_coeffs_ts``.
            fbands: band indices into ``wdecomp.frequencies`` (duplicates removed,
                result sorted), or ``'all'`` for every band. An empty list is
                allowed and yields all-zero series.
            subjects: subject keys such as ``'subj3'``.

        Raises:
            ValueError: if a band index is negative or not smaller than
                ``len(wdecomp.frequencies)``.
        """
        if fbands == 'all':
            logger.info(f'Reconstructs the signal from all available frequency bands: {wdecomp.frequencies}')
            fbands = list(range(len(wdecomp.frequencies)))
        elif fbands == []:
            logger.info(f"Empty set of bands: Nothing to reconstruct")
        else:
            fbands = sorted(set(fbands))  # remove duplicate values and sort

        # fbands is sorted, so its two ends bound every entry.
        if fbands and (fbands[0] < 0 or fbands[-1] >= len(wdecomp.frequencies)):
            bad_band = fbands[0] if fbands[0] < 0 else fbands[-1]
            raise ValueError(f'{bad_band} is outside the total number of considered frequency bands: {wdecomp.frequencies}')

        self.wdecomp = wdecomp
        self.fbands = fbands
        self.subjects = subjects
        self.recdictsubjs_ts = {subj: self._sum_bands(subj) for subj in subjects}

    def _sum_bands(self, subj: str) -> np.ndarray:
        """Sum the reconstructions of ``self.fbands`` for ``subj`` (zeros when no band is selected)."""
        # Integer dtype keeps an empty band list a valid (empty) index array;
        # np.array([]) would be float and rejected as an index.
        band_idx = np.asarray(self.fbands, dtype=int)
        return self.wdecomp.recdictsubjs_coeffs_ts[subj][band_idx].sum(axis=0)

    @log_method_call
    def add_subj(self, subj: str, recdictsubjs_ts: np.ndarray):
        """Register ``subj`` with its already-summed reconstruction ``recdictsubjs_ts``.

        The subject list stays sorted and free of duplicates; an existing entry is overwritten.
        """
        if subj not in self.subjects:
            self.subjects = sorted(self.subjects + [subj])
        self.recdictsubjs_ts[subj] = recdictsubjs_ts

    @log_method_call
    @validate_fbands
    def plotwreconstruct(self, title='', subjs='all'):
        """Plot the original signals, then the reconstructed signals, one subplot per subject.

        Args:
            title: figure title; defaults to the feature name.
            subjs: subject numbers (turned into ``'subj<n>'`` keys), ``'all'`` for
                every stored subject, or ``[]`` (logs a warning and plots nothing).

        Raises:
            ValueError: if a requested subject has no stored reconstruction.
        """
        if title == '':
            title = f'FEATURE: {self.wdecomp.feature}'

        if subjs == 'all':
            subjects = self.subjects
        elif subjs == []:
            logger.warning(f"Empty set of subjects: nothing will be calculated or plotted!")
            subjects = []
        else:
            subjects = [f'subj{subj}' for subj in subjs]

        for subj in subjects:
            if subj not in self.subjects:
                raise ValueError(f'{subj} is outside the range of considered subjects: {self.subjects}')

        if not subjects:
            # Nothing to draw; also avoids an empty-dict lookup below.
            return

        size = self.recdictsubjs_ts[subjects[0]].shape[-1]
        t = np.linspace(0, size / self.wdecomp.sr, size)

        # Plot original signals
        original_signals = {subj: self.wdecomp.original[subj] for subj in subjects}
        plot_signals(t, original_signals, title, subjects)

        # Plot reconstructed signals
        reconstructed_signals = {subj: self.recdictsubjs_ts[subj] for subj in subjects}
        band_title = f"{title}: Reconstructed {self.fbands}"
        plot_signals(t, reconstructed_signals, band_title, subjects)

    @log_method_call
    def bbootstrap(self, err=None, block_length=10.0, rseed=None, subjs='all'):
        """Block-bootstrap this reconstruction via ``self.wdecomp.bbootstrap``.

        The whole band set ``self.fbands`` is passed as a single (summed) band entry.
        With an empty band set, only the original series is bootstrapped.
        """
        # An empty inner list would make the decomposition's bootstrap fail on max([]).
        fbands = [self.fbands] if self.fbands else []
        return self.wdecomp.bbootstrap(err=err, block_length=block_length, rseed=rseed, subjs=subjs, fbands=fbands)