# In [None]:
import os
import pickle
from itertools import product

import cv2
import h5py
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
from matplotlib import colors
from PIL import Image

import natSF.mapping as mapping

from . import utils
from .create_gratings import generate_sinusoidal_grating

# Global matplotlib style for every figure produced by this module.
rc = {
    # All text, tick marks and axis frames are drawn in black.
    **dict.fromkeys(('text.color', 'axes.labelcolor', 'xtick.color',
                     'ytick.color', 'axes.edgecolor'), 'black'),
    'font.family': 'Helvetica',
    # Axes: thin frame with the top/right spines removed.
    'axes.linewidth': 1,
    'axes.labelpad': 5,
    'axes.titlepad': 5,
    'axes.spines.right': False,
    'axes.spines.top': False,
    # Ticks and lines.
    'xtick.major.pad': 5,
    'xtick.major.width': 1,
    'ytick.major.width': 1,
    'lines.linewidth': 1,
    # Font sizes in points.
    'font.size': 11,
    'axes.titlesize': 11,
    'axes.labelsize': 11,
    'xtick.labelsize': 9,
    'ytick.labelsize': 9,
    'legend.title_fontsize': 11,
    'legend.fontsize': 11,
    'figure.titlesize': 14,
    # Resolution: on-screen figures at 3 x 72 dpi, saved files at 300 dpi.
    'figure.dpi': 72 * 3,
    'savefig.dpi': 300,
}
mpl.rcParams.update(rc)



class SteerablePyramidSF:
    """Steerable-pyramid helpers for per-band (spatial frequency x orientation) image energy.

    NOTE(review): this class defines no ``__init__``. Its methods read
    ``self.img_size``, ``self.dtype``, ``self.device``, ``self.filters``,
    ``self.n_sf`` and ``self.n_ori``, and ``energy_heatmap`` calls
    ``self.find_ori_preference_for_each_filter`` /
    ``self.find_sf_preference_for_each_filter``. None of these are set or
    defined here — confirm where they come from.
    """

    def reshape_img(self, image):
        """
        Convert an image to a ``[1, 1, img_size, img_size]`` tensor for the pyramid.

        Args:
            image: 2-D grayscale image or 3-D ``(H, W, C)`` colour image, as a
                NumPy array (or array-like) or a PyTorch tensor.

        Returns:
            torch.Tensor: Image tensor of shape ``[1, 1, img_size, img_size]``
            with dtype ``self.dtype`` on ``self.device``.
        """
        if isinstance(image, torch.Tensor):
            image = image.detach().cpu().numpy()
        image = np.asarray(image)

        if image.ndim == 3:
            if image.shape[-1] == 1:
                # Single-channel image with an explicit channel axis: just drop it.
                image = image[..., 0]
            else:
                if image.dtype == np.float64:
                    # cv2.cvtColor only accepts 8U, 16U and 32F input.
                    image = image.astype(np.float32)
                # NOTE(review): BGR channel weights are applied; arrays coming from
                # PIL/NumPy are usually RGB — confirm the channel order.
                image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)

        if image.shape != (self.img_size, self.img_size):
            print('Resizing image to', self.img_size, 'x', self.img_size)
            image = cv2.resize(image, (self.img_size, self.img_size), interpolation=cv2.INTER_CUBIC)

        # Add batch and channel axes: [1, 1, H, W].
        return torch.tensor(image, dtype=self.dtype).to(self.device).unsqueeze(0).unsqueeze(0)

    def forward(self, image):
        """
        Run an image through the steerable pyramid.

        Args:
            image: Anything accepted by ``reshape_img``.

        Returns:
            The output of ``self.filters`` applied to the ``[1, 1, img_size, img_size]``
            image tensor. ``get_energy_of_filters`` indexes it as ``pyr_coeffs[sf, ori]``.
        """
        image_batch = self.reshape_img(image)
        return self.filters(image_batch)

    def normalize_filters(self, pyr_energy, is_magnitude=False):
        """
        Rescale each spatial-frequency band by a fixed per-scale factor.

        Band ``sf`` is divided by ``2**sf`` when ``is_magnitude`` is truthy, and by
        ``4**sf`` (= ``(2**sf)**2``) otherwise. This is a fixed rescaling, not an
        L2 normalization. NOTE(review): presumably it compensates for the gain of
        the pyramid levels — confirm.

        Args:
            pyr_energy (dict): Maps ``(sf, ori)`` to a band response.
            is_magnitude (bool): True if the responses are magnitudes ``|c|``
                rather than energies ``|c|**2``. Truthiness is used, the same
                test ``get_energy_of_filters`` applies to ``to_magnitude``.

        Returns:
            dict: Same keys, with rescaled values.
        """
        base = 2 if is_magnitude else 4
        return {(sf, ori): band / (base ** sf) for (sf, ori), band in pyr_energy.items()}

    def get_energy_of_filters(self, pyr_coeffs, to_numpy=True, reshape=True, normalize=True, to_magnitude=False):
        """
        Compute the energy ``|c|**2`` (or magnitude ``|c|``) of every pyramid band.

        Args:
            pyr_coeffs: Pyramid coefficients from ``forward``, indexable as
                ``pyr_coeffs[sf, ori]`` with tensor values.
            to_numpy (bool): Convert each band to a NumPy array before reshaping
                and normalizing.
            reshape (bool): Reshape each band to ``(img_size, img_size)``.
            normalize (bool): Apply ``normalize_filters`` to the bands.
            to_magnitude (bool): Use ``|c|`` instead of ``|c|**2``.

        Returns:
            np.ndarray: Float64 array of shape ``(n_sf, n_ori, img_size, img_size)``.
        """
        power = 1 if to_magnitude else 2
        bands = {(sf, ori): pyr_coeffs[sf, ori].abs() ** power
                 for sf, ori in product(range(self.n_sf), range(self.n_ori))}
        if to_numpy:
            # detach() so coefficients that require grad can still be converted.
            bands = {k: v.detach().cpu().numpy() for k, v in bands.items()}
        # TODO: This needs to be fixed. The downsample is not working.
        if reshape:
            bands = {k: v.reshape((self.img_size, self.img_size)) for k, v in bands.items()}
        if normalize:
            bands = self.normalize_filters(bands, is_magnitude=to_magnitude)

        energy = np.zeros((self.n_sf, self.n_ori, self.img_size, self.img_size))
        for (sf, ori), band in bands.items():
            if isinstance(band, torch.Tensor):
                # With to_numpy=False the bands are still tensors (possibly on GPU / requiring grad).
                band = band.detach().cpu().numpy()
            energy[sf, ori] = band
        return energy

    def visualize_filter_magnitudes(self, pyr_mags, title=None, cbar=False, share_color_range=False, pRF_loc=None, angle_in_radians=False, max_eccentricity=4.2, scale_factor=1, save_path=None):
        """
        Plot every band in an ``n_sf`` x ``n_ori`` grid of grayscale images.

        Args:
            pyr_mags: Band images indexable as ``pyr_mags[sf, ori]``. Either a dict
                keyed by ``(sf, ori)`` or the 4-D array from ``get_energy_of_filters``.
            title (str, optional): Figure title.
            cbar (bool): If True, add a colorbar to every subplot.
            share_color_range (bool): If True, all subplots share one colour scale
                spanning the min/max over the plotted bands.
            pRF_loc: Optional ``(eccentricity, angle)`` pair, or a list/tuple of such
                pairs, to mark on every band. Multiple pRFs are colour-coded and labelled.
            angle_in_radians (bool): Forwarded to ``mapping.find_pRF_loc_in_image``.
            max_eccentricity (float): Forwarded to ``mapping.find_pRF_loc_in_image``.
            scale_factor (float): Figure inches per subplot row/column.
            save_path (str, optional): If given, save the figure here before showing it.
        """
        fig, axes = plt.subplots(self.n_sf, self.n_ori,
                                 figsize=(self.n_ori * scale_factor, self.n_sf * scale_factor),
                                 squeeze=False)  # always 2-D, even for a single row/column
        grid = list(product(range(self.n_sf), range(self.n_ori)))

        norm = None
        if share_color_range:
            bands = [np.asarray(pyr_mags[i, j]) for i, j in grid]
            norm = colors.Normalize(vmin=min(b.min() for b in bands),
                                    vmax=max(b.max() for b in bands))

        def _pixel(pRF):
            eccentricity, angle = pRF
            return mapping.find_pRF_loc_in_image(eccentricity=eccentricity,
                                                 angle=angle,
                                                 angle_in_radians=angle_in_radians,
                                                 image_size=self.img_size,
                                                 max_eccentricity=max_eccentricity)

        # pRF markers are the same for every subplot, so compute them once.
        markers = []  # (row, col, colour, label or None)
        if pRF_loc is not None:
            if isinstance(pRF_loc[0], (list, tuple)):
                my_colors = ['red', 'blue', 'green', 'yellow', 'purple', 'orange']
                for c, pRF in enumerate(pRF_loc):
                    row, col = _pixel(pRF)
                    markers.append((row, col, my_colors[c % len(my_colors)], f'{pRF}'))
            else:
                row, col = _pixel(pRF_loc)
                markers.append((row, col, 'red', None))

        for i, j in grid:
            ax = axes[i, j]
            ax.axis('off')
            im = ax.imshow(pyr_mags[i, j], cmap='gray', norm=norm)
            if cbar:
                # Separate name so the `cbar` flag is not overwritten inside the loop.
                colorbar = plt.colorbar(im, ax=ax, format="%.2e", fraction=0.046, pad=0.04)
                colorbar.ax.tick_params(labelsize=8)
            for row, col, color, label in markers:
                ax.scatter(col, row, color=color, s=30, marker='o')
                if label is not None:
                    ax.text(col, row, label, fontsize=10, color=color)

        fig.suptitle(title)
        fig.tight_layout()
        fig.subplots_adjust(wspace=0.03, hspace=0.03)
        if save_path is not None:
            fig.savefig(save_path, bbox_inches='tight', pad_inches=0)
        plt.show()

    def find_energy_in_image_location(self, pyr_energy, x, y):
        """
        Read every band's value at one pixel.

        Args:
            pyr_energy: Bands indexable as ``pyr_energy[sf, ori]`` (dict or 4-D array).
            x: Index into the first (row) axis of each band.
            y: Index into the second (column) axis of each band.

        Returns:
            np.ndarray: ``(n_sf, n_ori)`` array of ``pyr_energy[sf, ori][x, y]``.
        """
        energy_array = np.zeros((self.n_sf, self.n_ori))
        for i, v in product(range(self.n_sf), range(self.n_ori)):
            energy_array[i, v] = pyr_energy[i, v][x, y]
        return energy_array

    def find_energy_in_pRFloc(self, pyr_energy, pRF_loc, max_eccentricity=4.2, angle_in_radians=False):
        """
        Read every band's value at a pRF location.

        Args:
            pyr_energy: Bands indexable as ``pyr_energy[sf, ori]`` (dict or 4-D array).
            pRF_loc: ``(eccentricity, angle)`` of the pRF.
            max_eccentricity (float): Forwarded to ``mapping.find_pRF_value_in_image``.
            angle_in_radians (bool): Forwarded to ``mapping.find_pRF_value_in_image``.

        Returns:
            np.ndarray: ``(n_sf, n_ori)`` array of per-band values.
        """
        energy_array = np.zeros((self.n_sf, self.n_ori))
        for i, v in product(range(self.n_sf), range(self.n_ori)):
            energy_array[i, v] = mapping.find_pRF_value_in_image(
                pyr_energy[i, v], pRF_loc[0], pRF_loc[1], angle_in_radians, max_eccentricity)
        return energy_array

    def energy_heatmap(self, energy_array, ax=None, cbar=True, title=None, save_path=None):
        """
        Plot an ``(n_sf, n_ori)`` energy matrix as a grayscale heatmap.

        White is the highest value and black the lowest. Columns are labelled with
        the filters' orientation preferences (converted with ``np.rad2deg``).
        Rows are labelled with their spatial-frequency preferences.

        Parameters:
        - energy_array: 2-D NumPy array or pandas DataFrame.
        - ax: Matplotlib Axes to draw into. If None, a new figure is created and shown.
        - cbar (bool): If True, add a colorbar labelled "Energy" next to ``ax``.
        - title (str, optional): Axes title.
        - save_path (str, optional): If given, save the figure containing ``ax`` here.
        """
        if isinstance(energy_array, pd.DataFrame):
            energy_array = energy_array.to_numpy()
        # Remember whether we own the figure: `ax` is reassigned below.
        standalone = ax is None
        if standalone:
            _, ax = plt.subplots()
        ax.grid(False)
        im = ax.imshow(energy_array, cmap='gray', alpha=1, interpolation='none',
                       aspect='equal')

        n_rows, n_cols = energy_array.shape
        # Thin semi-transparent yellow lines on the cell borders.
        for i in range(n_rows + 1):
            ax.hlines(i - 0.5, -0.5, n_cols - 0.5, color='yellow', alpha=0.5, linewidth=0.5)
        for j in range(n_cols + 1):
            ax.vlines(j - 0.5, -0.5, n_rows - 0.5, color='yellow', alpha=0.5, linewidth=0.5)

        if cbar is True:
            colorbar = plt.colorbar(im, ax=ax)
            colorbar.set_label("Energy")
        ax.set_title(title)
        ax.set(xlabel='Orientation(\u00b0)', ylabel='Spatial Frequency (cyc/deg)')
        ax.set(xticks=range(self.n_ori),
               xticklabels=[str(int(np.rad2deg(k))) for k in self.find_ori_preference_for_each_filter()])
        ax.set(yticks=range(self.n_sf),
               yticklabels=[str(np.round(k, 2)) for k in self.find_sf_preference_for_each_filter()])
        if save_path is not None:
            ax.figure.savefig(save_path, bbox_inches='tight', pad_inches=0)
        if standalone:
            plt.show()



def predict_beta_from_image(image, str_model, ecc, angle,  W):
    """Predict a response (beta) for one pRF from an image's pyramid energy.

    The image is decomposed with ``str_model``. Its per-band energy is read at the
    pRF location ``(ecc, angle)``, with the angle given in degrees. The flattened
    ``(n_sf * n_ori)`` feature row is then multiplied by the weights ``W``.

    Returns:
        np.ndarray: ``np.dot(features, W)`` where ``features`` has shape ``(1, n_sf * n_ori)``.
    """
    coefficients = str_model.forward(np.array(image))
    band_energy = str_model.get_energy_of_filters(
        coefficients, to_magnitude=False, to_numpy=True, reshape=True, normalize=True)
    features = str_model.find_energy_in_pRFloc(
        band_energy, pRF_loc=[ecc, angle], angle_in_radians=False)
    return np.dot(features.reshape(1, -1), W)



def sort_values_by_position(matrix):
    """
    List every cell of a 2-D matrix with its position, largest value first.

    Ties keep row-major order (the sort is stable).

    Parameters:
    - matrix: 2-D NumPy array or pandas DataFrame.

    Returns:
    - DataFrame with columns "SF" (row index), "Ori" (column index) and "Energy" (cell value).
    """
    frame = pd.DataFrame(matrix) if isinstance(matrix, np.ndarray) else matrix.copy()
    n_rows, n_cols = frame.shape
    entries = [(r, c, frame.iloc[r, c]) for r in range(n_rows) for c in range(n_cols)]
    entries.sort(key=lambda entry: entry[2], reverse=True)
    return pd.DataFrame(entries, columns=["SF", "Ori", "Energy"])

def _dict_into_array(data_dict, image_size):
    """
    Stack a ``{(sf, ori): band}`` dict into a 4-D NumPy array.

    Args:
        data_dict (dict): Must contain every key ``(i, j)`` with ``0 <= i < n_sf`` and
            ``0 <= j < n_ori``. ``n_sf`` and ``n_ori`` are inferred from the largest key
            components. Each value must be assignable to shape ``(image_size, image_size)``.
        image_size (int): Side length of each band (Python or NumPy integer, not bool).

    Returns:
        np.ndarray: ``(n_sf, n_ori, image_size, image_size)`` array with the dtype of ``data_dict[0, 0]``.

    Raises:
        ValueError: If ``data_dict`` is empty or ``image_size`` is not an integer.
        AssertionError: If any ``(i, j)`` key in the grid is missing.
    """
    if not data_dict:
        raise ValueError('data_dict is empty.')
    n_sf = max(key[0] for key in data_dict) + 1
    n_ori = max(key[1] for key in data_dict) + 1

    # Raised explicitly (not via `assert`) so the check survives `python -O`;
    # AssertionError is kept for backward compatibility.
    if not all((i, j) in data_dict for i in range(n_sf) for j in range(n_ori)):
        raise AssertionError("Dictionary is missing required (i, j) keys.")

    # bool is an int subclass but is never a meaningful size.
    if isinstance(image_size, bool) or not isinstance(image_size, (int, np.integer)):
        raise ValueError('Image size is not a scalar.')

    array_4d = np.empty((n_sf, n_ori, image_size, image_size), dtype=data_dict[0, 0].dtype)
    for i in range(n_sf):
        for j in range(n_ori):
            array_4d[i, j] = data_dict[i, j]
    return array_4d

def _array_into_dict(array_4d):
    """
    Split an array into a dict keyed by its first two indices.

    Parameters:
        array_4d (numpy.ndarray): Array with shape ``(n_sf, n_ori, ...)``. Typically
            ``(n_sf, n_ori, image_size, image_size)``, but any array with at least
            two dimensions works.

    Returns:
        dict: Maps ``(i, j)`` to ``array_4d[i, j]`` (a view, not a copy).
    """
    n_sf, n_ori = array_4d.shape[:2]
    return {(i, j): array_4d[i, j] for i in range(n_sf) for j in range(n_ori)}



def sample_prf_location(ecc_range=(0.5, 4.2), angle_range=(0, 2 * np.pi), n_samples=1):
    """
    Randomly sample pRF locations uniformly in eccentricity and angle.

    Uses NumPy's global random state (``np.random.uniform``).

    Args:
        ecc_range (tuple): ``(low, high)`` eccentricity bounds in visual degrees
            (default: 0.5 to 4.2).
        angle_range (tuple): ``(low, high)`` angle bounds in radians (default: 0 to 2π).
        n_samples (int): Number of locations to draw.

    Returns:
        tuple: ``(eccentricity, angle)``. Each is a 1-D array of length ``n_samples``
        with values in ``[low, high)`` of the corresponding range.
    """
    eccentricity = np.random.uniform(*ecc_range, n_samples)
    angle = np.random.uniform(*angle_range, n_samples)
    return eccentricity, angle




def make_Gaussian_2D(eccentricity, angle, sigma, img_size=512, screen_height=39.29, n_pixel_height=1080, visual_distance=176.5):
    """Generate a pRF mask image as an isotropic 2-D Gaussian in visual degrees.

    The Gaussian is centred at the point given by ``eccentricity`` and ``angle``,
    converted with ``utils.polar_to_cartesian``, and has standard deviation ``sigma``.
    The image spans ``utils.pix_to_deg(img_size, ...)`` degrees, centred on 0.

    Parameters:
    -eccentricity (float): Distance from the stimulus centre to the pRF centre, in visual degrees.
    -angle (float): The pRF angle in radians (as expected by ``utils.polar_to_cartesian``).
    -sigma (float): The pRF size (Gaussian standard deviation) in visual degrees.
    -img_size (int, optional): Side length of the square image in pixels. Default is 512.
    -screen_height (float, optional): Physical height of the display in cm. Default is 39.29 (NSD).
    -n_pixel_height (int, optional): Number of pixels along the display height. Default is 1080 (NSD).
    -visual_distance (float, optional): Viewing distance in cm. Default is 176.5 (NSD).
    Returns:
    -np.ndarray: ``(img_size, img_size)`` array scaled so its maximum is 1. It is all
     zeros if the Gaussian underflows to 0 everywhere."""

    x, y = utils.polar_to_cartesian(eccentricity, angle)
    visual_angle = utils.pix_to_deg(img_size, screen_height, n_pixel_height, visual_distance)
    coords = np.linspace(-visual_angle / 2, visual_angle / 2, img_size)
    # NOTE(review): image row 0 maps to Y = -visual_angle/2, so +y points down the
    # image — confirm this matches polar_to_cartesian's angle convention.
    X, Y = np.meshgrid(coords, coords)

    gaussian = np.exp(-(((X - x)**2) / (2 * sigma**2) + ((Y - y)**2) / (2 * sigma**2)))

    # Peak-normalise; skip when the Gaussian underflowed everywhere to avoid 0/0 = NaN.
    peak = gaussian.max()
    if peak > 0:
        gaussian /= peak
    return gaussian


# Compute radial frequency distribution
def radial_profile(magnitude_spectrum):
    """
    Average a 2-D spectrum over rings of integer radius around its centre.

    The centre is pixel ``(h // 2, w // 2)`` (row, column). Each pixel's Euclidean
    distance from the centre is truncated to an integer radius bin.

    Args:
        magnitude_spectrum (np.ndarray): Real 2-D array of shape ``(h, w)``.

    Returns:
        tuple: ``(radii, radial_mean)``. ``radii`` is ``np.arange(max_radius + 1)``, and
        ``radial_mean[k]`` is the mean of all pixels in bin ``k`` (0 for an empty bin).
    """
    magnitude_spectrum = np.asarray(magnitude_spectrum)
    h, w = magnitude_spectrum.shape
    center_row, center_col = h // 2, w // 2

    # Column offsets are measured from the column centre and row offsets from the
    # row centre (this matters for non-square input).
    rows, cols = np.indices((h, w))
    r = np.sqrt((cols - center_col) ** 2 + (rows - center_row) ** 2).astype(np.int32)
    n_bins = int(r.max()) + 1

    # Vectorised per-bin sum and count (replaces a Python double loop).
    radial_sum = np.bincount(r.ravel(), weights=magnitude_spectrum.ravel(), minlength=n_bins)
    counts = np.bincount(r.ravel(), minlength=n_bins)

    # out= is required: with only where=, skipped entries would be uninitialised memory.
    radial_mean = np.divide(radial_sum, counts, out=np.zeros(n_bins), where=counts != 0)
    return np.arange(n_bins), radial_mean




def energy_heatmap(energy_array, ax=None, cbar=True, title=None,
                   xlabel='Orientation Filter',
                   ylabel='Spatial Frequency Filter',
                   xticks=None, xticklabels=None,
                   yticks=None, yticklabels=None, save_path=None):
    """
    Plot a 2-D matrix as a grayscale heatmap (white = highest value, black = lowest).

    Parameters:
    - energy_array: 2-D NumPy array or pandas DataFrame.
    - ax: Matplotlib Axes to draw into. If None, a new figure is created and shown.
    - cbar (bool): If True, add a colorbar labelled "Energy" next to ``ax``.
    - title, xlabel, ylabel: Axes title and axis labels.
    - xticks, xticklabels, yticks, yticklabels: Optional tick positions/labels; left
      at matplotlib defaults when None.
    - save_path (str, optional): If given, save the figure containing ``ax`` here.
    """
    if isinstance(energy_array, pd.DataFrame):
        energy_array = energy_array.to_numpy()
    # Remember whether we own the figure: `ax` is reassigned below.
    standalone = ax is None
    if standalone:
        _, ax = plt.subplots()
    ax.grid(False)
    im = ax.imshow(energy_array, cmap='gray', alpha=1, interpolation='none',
                   aspect='equal')

    n_rows, n_cols = energy_array.shape
    # Thin semi-transparent yellow lines on the cell borders.
    for i in range(n_rows + 1):
        ax.hlines(i - 0.5, -0.5, n_cols - 0.5, color='yellow', alpha=0.5, linewidth=0.5)
    for j in range(n_cols + 1):
        ax.vlines(j - 0.5, -0.5, n_rows - 0.5, color='yellow', alpha=0.5, linewidth=0.5)

    if cbar is True:
        # Attach to `ax` explicitly so a caller-supplied subplot gets the colorbar.
        colorbar = plt.colorbar(im, ax=ax)
        colorbar.set_label("Energy")
    ax.set_title(title)
    ax.set(xlabel=xlabel, ylabel=ylabel)
    if xticks is not None:
        ax.set(xticks=xticks)
    if yticks is not None:
        ax.set(yticks=yticks)
    if xticklabels is not None:
        ax.set(xticklabels=xticklabels)
    if yticklabels is not None:
        ax.set(yticklabels=yticklabels)
    if save_path is not None:
        ax.figure.savefig(save_path, bbox_inches='tight', pad_inches=0)
    if standalone:
        plt.show()


def plot_comparison(df, x, y, hue='freq_lvl',
                   lgd=None, logscale=True, xlim=None,
                   xticks=None, yticks=None, xticklabels=None, yticklabels=None,
                   xlabel=None, ylabel=None, title=None, ax=None, save_fig=None):
    """
    Scatter ``df[y]`` against ``df[x]`` with a dashed identity line and equal aspect.

    Parameters:
        df (DataFrame): Data to plot.
        x (str): Column name for x-axis values.
        y (str): Column name for y-axis values.
        hue (str): Column name for hue grouping.
        lgd (str, optional): Legend title; the legend is rebuilt only when given.
        logscale (bool): Use log scale on both axes.
        xlim (tuple, optional): Limits applied to both axes.
        xticks, yticks, xticklabels, yticklabels, xlabel, ylabel, title: Axis decoration;
            left at defaults when None.
        ax: Matplotlib Axes to draw into; a new figure is created when None.
        save_fig (str, optional): Path to save the figure containing ``ax``.
    """
    if ax is None:
        _, ax = plt.subplots()
    # NOTE(review): set_theme resets global rcParams (overriding the module-level
    # `rc` settings for later figures). Kept for backward compatibility.
    sns.set_theme(style='ticks')
    ax = sns.scatterplot(data=df, x=x, y=y, hue=hue, ax=ax)
    # Works whether the Axes was created here or passed in by the caller.
    fig = ax.figure

    # Identity line (y = x) from 0.1 up to 10% beyond the largest value.
    lims = [0.1, max(df[x].max(), df[y].max()) * 1.1]
    ax.plot(lims, lims, '--', color='black')

    if logscale:
        ax.set_xscale('log')
        ax.set_yscale('log')
    ax.set_aspect('equal', adjustable='box')
    if lgd is not None:
        ax.legend(title=lgd)

    if xlabel is not None:
        ax.set_xlabel(xlabel)
    if ylabel is not None:
        ax.set_ylabel(ylabel)
    if xticks is not None:
        ax.set_xticks(xticks)
    if yticks is not None:
        ax.set_yticks(yticks)
    ax.set_title(title)
    if xticklabels is not None:
        ax.set_xticklabels(xticklabels)
    if yticklabels is not None:
        ax.set_yticklabels(yticklabels)
    if xlim is not None:
        ax.set_xlim(xlim)
        ax.set_ylim(xlim)

    fig.tight_layout()
    if save_fig is not None:
        fig.savefig(save_fig, bbox_inches='tight', pad_inches=0)

def prep_image(imgBrick, img_number):
    """
    Load one image from an image stack and convert it to 8-bit grayscale.

    Args:
        imgBrick: Indexable stack of images (e.g. an h5py dataset or NumPy array).
            Each entry is an ``(H, W, 3)`` array, or a channel-first ``(3, H, W)``
            array, which is moved to channel-last. Entries must be uint8 for PIL.
        img_number (int): Index of the image in the stack.

    Returns:
        np.ndarray: ``(H, W)`` grayscale image (PIL ``'L'`` conversion).
    """
    img = np.asarray(imgBrick[img_number])
    # PIL expects channel-last data; only transpose when the input is channel-first.
    if img.ndim == 3 and img.shape[0] == 3 and img.shape[-1] != 3:
        img = np.transpose(img, (1, 2, 0))
    img = Image.fromarray(img).convert('L')  # 'L' = grayscale
    return np.array(img)

def process_single_image_energy(img_path, str_model, save_path,
                              to_magnitude=True, to_numpy=True, reshape=True, normalize=True):
    """
    Compute the pyramid energy of one ``.npy`` image and pickle it to ``save_path``.

    Args:
        img_path (str): Path to a ``.npy`` file holding the image. For files named
            ``...stimuli_<n>.npy``, ``<n>`` is used in the skip message.
        str_model: SteerablePyramidSF model instance.
        save_path (str): Destination pickle file. If it already exists, nothing is computed.
        to_magnitude (bool): Whether to compute magnitude instead of energy.
        to_numpy (bool): Convert results to NumPy arrays.
        reshape (bool): Reshape back to the original image size.
        normalize (bool): Whether to normalize the filter responses.

    Returns:
        bool: True if the image was processed, False if skipped because ``save_path`` exists.
    """
    if os.path.exists(save_path):
        name = os.path.basename(img_path)
        # Report the image number for "stimuli_<n>.npy" files, otherwise the file name.
        _, sep, tail = name.partition('stimuli_')
        img_number = tail.split('.npy')[0] if sep else name
        print(f"File already exists for {img_number}, skipping...")
        return False

    img = np.load(img_path)
    coeffs = str_model.forward(img)
    pyr_energy = str_model.get_energy_of_filters(coeffs,
                                                 to_magnitude=to_magnitude,
                                                 to_numpy=to_numpy,
                                                 reshape=reshape,
                                                 normalize=normalize)

    # Write to a temporary file, then rename. An interrupted run therefore never
    # leaves a truncated pickle that the existence check above would skip forever.
    tmp_path = f"{save_path}.tmp"
    with open(tmp_path, 'wb') as f:
        pickle.dump(pyr_energy, f)
    os.replace(tmp_path, save_path)
    return True

def process_and_save_image_energy(image_numbers, str_model, image_path, save_dir,
                                to_magnitude=True, to_numpy=True, reshape=True, normalize=True):
    """
    Process a list of images from an HDF5 image brick and save their pyramid energy.

    Each image ``n`` is read from the ``'imgBrick'`` dataset and converted to grayscale
    with ``prep_image``. Its energy is pickled to ``<save_dir>/energy_img-<n>.pkl``.
    Images whose output file already exists are skipped.

    Args:
        image_numbers (list): List of image numbers (indices into ``imgBrick``) to process.
        str_model: SteerablePyramidSF model instance.
        image_path (str): Path to the HDF5 file containing the ``'imgBrick'`` dataset.
        save_dir (str): Directory to save the processed energy files (created if missing).
        to_magnitude (bool): Whether to compute magnitude instead of energy.
        to_numpy (bool): Convert results to NumPy arrays.
        reshape (bool): Reshape back to the original image size.
        normalize (bool): Whether to normalize the filter responses.

    Returns:
        tuple: (processed_count, skipped_count) - Number of images processed and skipped.
    """
    os.makedirs(save_dir, exist_ok=True)

    processed_count = 0
    skipped_count = 0

    with h5py.File(image_path, 'r') as f:
        imgBrick = f['imgBrick']

        for nat_img in image_numbers:
            save_path = os.path.join(save_dir, f'energy_img-{nat_img}.pkl')
            if os.path.exists(save_path):
                print(f"File already exists for {nat_img}, skipping...")
                skipped_count += 1
                continue

            # The images live in the HDF5 brick, not in .npy files, so they are
            # processed here rather than through process_single_image_energy.
            img = prep_image(imgBrick, nat_img)
            coeffs = str_model.forward(img)
            pyr_energy = str_model.get_energy_of_filters(coeffs,
                                                         to_magnitude=to_magnitude,
                                                         to_numpy=to_numpy,
                                                         reshape=reshape,
                                                         normalize=normalize)

            # Temp file + rename so an interrupted run never leaves a truncated pickle.
            tmp_path = f"{save_path}.tmp"
            with open(tmp_path, 'wb') as out:
                pickle.dump(pyr_energy, out)
            os.replace(tmp_path, save_path)
            processed_count += 1

    print(f"Processing complete. Processed {processed_count} images, skipped {skipped_count} images.")
    return processed_count, skipped_count

def load_single_image_energy(load_path):
    """
    Load the pickled pyramid energy for a single image.

    Loading is best-effort: any error is printed together with the path and the
    reason, and ``None`` is returned.

    SECURITY: ``pickle.load`` can execute arbitrary code. Only load files written by
    this pipeline, never untrusted files.

    Args:
        load_path (str): Path to the saved energy file.

    Returns:
        The unpickled pyramid energy data, or ``None`` if loading failed.
    """
    try:
        with open(load_path, 'rb') as f:
            return pickle.load(f)
    except Exception as e:  # deliberate best-effort: report and continue
        print(f"Error loading energy from {load_path}: {e}")
        return None

def load_multiple_image_energy(image_numbers, load_dir):
    """
    Load pyramid energy for multiple images.

    For image ``n``, ``<load_dir>/pyr_energy_<n>.pkl`` is tried first. If it does not
    exist, ``<load_dir>/energy_img-<n>.pkl`` is used instead; that is the name
    ``process_and_save_image_energy`` writes. Images that fail to load are skipped.

    Args:
        image_numbers (list): List of image numbers to load.
        load_dir (str): Directory containing the saved energy files.

    Returns:
        dict: Maps image numbers to their pyramid energy data.
    """
    loaded_data = {}
    skipped_count = 0

    for nat_img in image_numbers:
        load_path = os.path.join(load_dir, f'pyr_energy_{nat_img}.pkl')
        if not os.path.exists(load_path):
            fallback = os.path.join(load_dir, f'energy_img-{nat_img}.pkl')
            if os.path.exists(fallback):
                load_path = fallback
        energy_data = load_single_image_energy(load_path)

        if energy_data is not None:
            loaded_data[nat_img] = energy_data
        else:
            skipped_count += 1

    print(f"Loading complete. Loaded {len(loaded_data)} images, skipped {skipped_count} images.")
    return loaded_data