In [2]:
#V1.2.10.25

import os
import numpy as np
import tifffile as tiff
from scipy.optimize import curve_fit
import logging
from typing import Tuple, Dict, Optional
import pandas as pd
from PIL import Image
import cv2
from skimage.draw import polygon
from roifile import roiread
from scipy.signal import find_peaks
from scipy.stats import linregress
import matplotlib.pyplot as plt
import pickle
import concurrent.futures
import time
from datetime import datetime

# Logging setup: install the default root handler, then create this module's
# logger and let it emit everything from DEBUG upwards.
logging.basicConfig()
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)

class CalciumImagingAnalysis:
    def __init__(self, base_dir: str, video_dims: Tuple[int, int] = (1080, 1920)):
        self.base_dir = base_dir
        self.video_dims = video_dims
        self.memmap_path = None
        self.current_base_name = None
        self._video_cache = None
        self.processing_times = {}  # Add timing dictionary
        self.background_values = {}  # Add background values dictionary

    def log_time(self, stage: str, start_time: float):
        """Log the elapsed time for a processing stage"""
        elapsed = time.time() - start_time
        self.processing_times[stage] = elapsed
        logger.info(f"Time for {stage}: {elapsed:.2f} seconds")
        return elapsed
    
    def find_matching_files(self) -> list:
        """
        Find matching .tif and .zip files based on concentration patterns.
        
        Returns:
            list: List of tuples containing matching (tif_path, zip_path) pairs
        """
        concentrations = ['_0um', '_10um', '_25um']
        matches = []
        
        tif_files = [f for f in os.listdir(self.base_dir) if f.endswith('.tif')]
        zip_files = [f for f in os.listdir(self.base_dir) if f.endswith('.zip')]
        
        for tif_file in tif_files:
            tif_path = os.path.join(self.base_dir, tif_file)
            # Find concentration pattern in tif file
            conc_pattern = next((c for c in concentrations if c in tif_file), None)
            if conc_pattern:
                # Look for matching zip file with same concentration
                matching_zip = next((f for f in zip_files if conc_pattern in f), None)
                if matching_zip:
                    zip_path = os.path.join(self.base_dir, matching_zip)
                    matches.append((tif_path, zip_path))
        
        return matches

    def process_folder(self):
        """Process all matching .tif and .zip file pairs in the base directory."""
        matching_pairs = self.find_matching_files()
        total_start = time.time()
        
        for tif_path, roi_path in matching_pairs:
            pair_start = time.time()
            logger.info(f"\nStarting processing of pair at {datetime.now()}:")
            logger.info(f"TIF: {os.path.basename(tif_path)}")
            logger.info(f"ROI: {os.path.basename(roi_path)}")
            
            try:
                # Reset timing dictionary for new pair
                self.processing_times = {}
                
                # Process each stage with timing
                stage_start = time.time()
                memmap_path = self.exponential_correct_tif(tif_path)
                self.log_time("Bleach Correction", stage_start)
                
                stage_start = time.time()
                masks_dir = self.convert_rois_to_masks(roi_path)
                self.log_time("ROI Conversion", stage_start)
                
                stage_start = time.time()
                results = self.analyze_fluorescence(masks_dir)
                self.log_time("Fluorescence Analysis", stage_start)
                
                # Log total time for this pair
                pair_time = self.log_time(f"Total time for {os.path.basename(tif_path)}", pair_start)
                logger.info(f"Successfully processed pair in {pair_time:.2f} seconds")
                
            except Exception as e:
                logger.error(f"Error processing {os.path.basename(tif_path)}: {str(e)}")
        
        # Log total processing time
        total_time = time.time() - total_start
        logger.info(f"\nTotal processing time for all pairs: {total_time:.2f} seconds")

    
    def get_output_paths(self, input_tif: str) -> Dict[str, str]:
        """
        Generate standardized output paths based on input .tif filename.
        
        Args:
            input_tif (str): Path to input .tif file
            
        Returns:
            Dict[str, str]: Dictionary containing all output paths
        """
        # Extract the base name without extension
        self.current_base_name = os.path.splitext(os.path.basename(input_tif))[0]
        output_dir = os.path.dirname(input_tif)
        
        return {
            'memmap': os.path.join(output_dir, f"{self.current_base_name}_corrected.dat"),
            'masks_dir': os.path.join(output_dir, f"{self.current_base_name}_masks"),
            'dff_traces': os.path.join(output_dir, f"{self.current_base_name}_dff_traces.pkl"),
            'metrics': os.path.join(output_dir, f"{self.current_base_name}_fluorescence_metrics.xlsx")
        }

        
    def exp(self, x: np.ndarray, a: float, b: float) -> np.ndarray:
        return a * np.exp(-b * x)

    def bi_exp(self, x: np.ndarray, a: float, b: float, c: float, d: float) -> np.ndarray:
        return (a * np.exp(-b * x)) + (c * np.exp(-d * x))
        
    def find_background(self, images: np.ndarray, roi_shape: Tuple[int, int]) -> Tuple[np.ndarray, Tuple[int, int]]:
        """Modified to return both background values and ROI center"""
        t, y, x = images.shape[0], images.shape[1], images.shape[2]
        first_frame = images[0, :, :]
        min_avg_fluorescence = np.inf
        min_roi_center = None
        
        for i in range(0, y - roi_shape[0], roi_shape[0]):
            for j in range(0, x - roi_shape[1], roi_shape[1]):
                roi = first_frame[i:i + roi_shape[0], j:j + roi_shape[1]]
                avg_fluorescence = np.mean(roi)
                
                if avg_fluorescence < min_avg_fluorescence:
                    min_avg_fluorescence = avg_fluorescence
                    min_roi_center = (i + roi_shape[0] // 2, j + roi_shape[1] // 2)

        background = np.mean(images[:, 
                           min_roi_center[0] - roi_shape[0] // 2: min_roi_center[0] + roi_shape[0] // 2,
                           min_roi_center[1] - roi_shape[1] // 2: min_roi_center[1] + roi_shape[1] // 2],
                           axis=(1, 2))
        
        return background, min_roi_center

    def exponential_correct_tif(self,
            input_tif: str,
            method: str = "bi",
            contrast_limits: Tuple[int, int] = (0, 65535),
            roi_shape: Tuple[int, int] = (10, 10)) -> str:
        """
        Performs bleach correction on a .tif file.

        The mean intensity of every frame is fitted with an exponential model,
        the fitted curve is normalised to a maximum of 1, and each frame is
        divided by its value on that curve. The result is written as uint16 to
        a raw memory-mapped file next to the input.

        Args:
            input_tif: Path to the input stack (3D or 4D, first axis = frame).
            method: "mono" fits a single exponential; any other value fits the
                sum of two exponentials.
            contrast_limits: (low, high) range the corrected values are
                clipped to; it is additionally limited to the uint16 range.
            roi_shape: Currently unused; kept for interface compatibility.

        Returns:
            str: Path to the created memmap file

        Raises:
            ValueError: If the stack is not 3D or 4D.
        """
        paths = self.get_output_paths(input_tif)
        self.memmap_path = paths['memmap']

        with tiff.TiffFile(input_tif) as tif:
            images = tif.asarray()

        dtype = np.uint16
        shape = images.shape

        # Explicit check instead of `assert`, which is stripped under -O
        if not 3 <= len(shape) <= 4:
            raise ValueError(f"Expected 3D or 4D stack, got {len(shape)}D")

        # get_video_shape() later reconstructs the frame count from
        # self.video_dims, so a mismatch would misread the output file.
        if tuple(shape[-2:]) != tuple(self.video_dims):
            logger.warning(
                f"Frame size {tuple(shape[-2:])} differs from configured video_dims {tuple(self.video_dims)}"
            )

        func = self.exp if method == "mono" else self.bi_exp
        axes = tuple(range(1, len(shape)))
        I_mean = np.mean(images, axis=axes)

        x_data = np.arange(shape[0])
        try:
            popt, _ = curve_fit(func, x_data, I_mean)
            f_ = func(x_data, *popt)
        except (ValueError, RuntimeError, Warning) as fit_error:
            logger.warning(f"Bleaching fit failed ({fit_error}); data is left uncorrected")
            f_ = np.ones_like(I_mean)

        # A curve containing NaN/inf or non-positive values cannot be used as
        # a divisor; fall back to "no correction" exactly like a failed fit.
        if not np.all(np.isfinite(f_)) or np.any(f_ <= 0):
            logger.warning("Bleaching fit produced an unusable curve; data is left uncorrected")
            f_ = np.ones_like(I_mean)

        f = f_ / np.max(f_)
        f = f.reshape((-1,) + (1,) * (len(shape) - 1))

        # One clipping pass: intersect the requested limits with what the
        # output dtype can hold before the values are cast on assignment.
        dtype_info = np.iinfo(dtype)
        low = max(dtype_info.min, contrast_limits[0])
        high = min(dtype_info.max, contrast_limits[1])

        # Correct a bounded number of frames at a time so that the float64
        # intermediate never has to hold the whole stack at once.
        frame_bytes = int(np.prod(shape[1:])) * np.dtype(np.float64).itemsize
        frames_per_chunk = max(1, (128 * 1024 * 1024) // max(1, frame_bytes))

        corrected_memmap = np.memmap(self.memmap_path, dtype=dtype, mode='w+', shape=shape)
        for begin in range(0, shape[0], frames_per_chunk):
            stop = begin + frames_per_chunk
            corrected_memmap[begin:stop] = np.clip(images[begin:stop] / f[begin:stop], low, high)

        # Make sure the data is on disk before the file is re-opened for reading
        corrected_memmap.flush()
        del corrected_memmap

        logger.info(f"Corrected image saved as a memmap at {self.memmap_path}")
        return self.memmap_path

    def convert_rois_to_masks(self, roi_zip_path: str) -> str:
        """
        Rasterise every ROI of an ROI archive into its own mask image.

        Each ROI outline is filled as a polygon on an empty frame of size
        ``self.video_dims`` (filled pixels = 255, others = 0) and saved as
        ``<base>_mask_<NNN>.png`` (numbered from 001) in the masks folder
        derived from ``self.memmap_path``.

        Returns:
            str: Path to the masks directory
        """
        masks_dir = self.get_output_paths(self.memmap_path)['masks_dir']
        os.makedirs(masks_dir, exist_ok=True)

        for number, roi in enumerate(roiread(roi_zip_path), start=1):
            canvas = np.zeros(self.video_dims, dtype=np.uint16)

            # Column 1 of the coordinates is passed as rows and column 0 as
            # columns (presumably (x, y) vertex pairs - confirm with roifile).
            vertices = roi.coordinates()
            rows, cols = polygon(vertices[:, 1], vertices[:, 0], self.video_dims)
            canvas[rows, cols] = 255

            mask_file = os.path.join(masks_dir, f"{self.current_base_name}_mask_{number:03d}.png")
            Image.fromarray(canvas).save(mask_file)
            logger.info(f"Saved mask: {mask_file}")

        return masks_dir

    # Define function to save dF/F traces externally (using pickle)
    def save_dff_traces(self, dff_traces: Dict[str, np.ndarray], file_path: str):
        """Save dF/F traces to a pickle file"""
        with open(file_path, 'wb') as f:
            pickle.dump(dff_traces, f)
        logger.info(f"Saved dF/F traces to {file_path}")

    # Define function to load dF/F traces from a pickle file
    def load_dff_traces(self, file_path: str) -> Dict[str, np.ndarray]:
        """Load dF/F traces from a pickle file"""
        with open(file_path, 'rb') as f:
            dff_traces = pickle.load(f)
        logger.info(f"Loaded dF/F traces from {file_path}")
        return dff_traces
    
    def analyze_fluorescence(self, 
            masks_dir: str,
            baseline_range: Tuple[int, int] = (0, 200),
            analysis_range: Tuple[int, int] = (233, 580),
            roi_shape: Tuple[int, int] = (100, 100)) -> pd.DataFrame:
        """Run the per-ROI fluorescence analysis on the corrected video.

        Reads the memmap at ``self.memmap_path``, extracts a
        background-subtracted trace for every mask in ``masks_dir``, converts
        the traces to dF/F, computes the metrics and writes three files next
        to the memmap: the pickled dF/F traces, an Excel sheet with metrics,
        background info and distances, and the pickled background info.

        Args:
            masks_dir: Folder containing the ROI mask PNG files.
            baseline_range: (start, end) frame indices of the baseline window.
            analysis_range: (start, end) frame indices of the analysis window.
            roi_shape: (height, width) of the tile used for background search.

        Returns:
            DataFrame indexed by mask file name with metric, background and
            distance columns; stage timings are attached in
            ``attrs['processing_times']``.
        """
        paths = self.get_output_paths(self.memmap_path)
        stage_start = time.time()

        # Background entries are keyed by mask file name and filled in during
        # trace extraction. Clear them first so that ROIs from a previously
        # analysed recording do not appear as extra rows in this result.
        self.background_values = {}

        # Calculate distances first
        distances = self.compute_distances(masks_dir)

        T, Y, X = self.get_video_shape()
        video = np.memmap(self.memmap_path, dtype=np.uint16, mode='r', shape=(T, Y, X))
        self.log_time("Video Loading", stage_start)

        # Extract fluorescence traces in parallel
        stage_start = time.time()
        fluorescence_traces = self.extract_fluorescence_traces_parallel(video, masks_dir, roi_shape)
        self.log_time("Fluorescence Extraction", stage_start)

        # Calculate dF/F
        stage_start = time.time()
        dff_traces = self.calculate_dff(fluorescence_traces)
        self.log_time("dF/F Calculation", stage_start)

        # Save dF/F traces externally
        self.save_dff_traces(dff_traces, paths['dff_traces'])

        # Compute metrics
        stage_start = time.time()
        metrics = self.compute_fluorescence_metrics(dff_traces, *baseline_range, *analysis_range)
        self.log_time("Metrics Computation", stage_start)

        # Create DataFrame with metrics
        columns = ["Peak Amplitude", "Time of Peak", "Std Dev", "AUC", "Max Rise Slope",
                  "Time of Max Rise Slope", "Slope", "Rise Time", "Rise Slope"]
        df_metrics = pd.DataFrame.from_dict(metrics, orient='index', columns=columns)

        # Add background information to metrics
        background_df = pd.DataFrame.from_dict(
            {roi: {
                'Background Mean': info['background_mean'],
                'Background Std': info['background_std'],
                'Background ROI Center X': info['background_roi_center'][1],
                'Background ROI Center Y': info['background_roi_center'][0]
            } for roi, info in self.background_values.items()},
            orient='index'
        )

        # Create distances DataFrame
        df_distances = pd.DataFrame.from_dict(distances, orient='index', columns=['Distance to Lamina'])

        # Merge all the DataFrames (aligned on the mask file name index)
        df_final = pd.concat([df_metrics, background_df, df_distances], axis=1)

        # Add processing times (kept in memory only; attrs are not written to Excel)
        df_final.attrs['processing_times'] = self.processing_times

        # Save results
        df_final.to_excel(paths['metrics'])
        logger.info(f"Saved fluorescence metrics, background info, and distances to {paths['metrics']}")

        # Save background traces separately
        background_traces_path = os.path.join(os.path.dirname(paths['metrics']), 
                                            f"{self.current_base_name}_background_traces.pkl")
        with open(background_traces_path, 'wb') as f:
            pickle.dump(self.background_values, f)
        logger.info(f"Saved background traces to {background_traces_path}")

        return df_final

    def get_video_shape(self) -> Tuple[int, int, int]:
        """Gets the shape of the memory-mapped video."""
        file_size = os.path.getsize(self.memmap_path)
        bytes_per_pixel = np.dtype(np.uint16).itemsize
        t = file_size // (self.video_dims[0] * self.video_dims[1] * bytes_per_pixel)
        
        if file_size % (self.video_dims[0] * self.video_dims[1] * bytes_per_pixel) != 0:
            raise ValueError("File size is not evenly divisible by expected frame size.")
            
        return (t, *self.video_dims)

    def extract_fluorescence_traces_parallel(self, video: np.ndarray, masks_dir: str, roi_shape: Tuple[int, int]) -> Dict[str, np.ndarray]:
        """Extracts fluorescence traces in parallel."""
        fluorescence_traces = {}

        # Using ThreadPoolExecutor for parallel processing of masks
        with concurrent.futures.ThreadPoolExecutor() as executor:
            # Create a list of future tasks for each ROI
            future_to_roi = {
                executor.submit(self.process_roi_trace, roi_filename, video, masks_dir, roi_shape): roi_filename
                for roi_filename in sorted(os.listdir(masks_dir)) if roi_filename.endswith(".png")
            }
            
            for future in concurrent.futures.as_completed(future_to_roi):
                roi_filename = future_to_roi[future]
                try:
                    # Retrieve the result for each completed task
                    corrected_trace = future.result()
                    fluorescence_traces[roi_filename] = corrected_trace
                except Exception as e:
                    logger.error(f"Error processing ROI {roi_filename}: {e}")
        
        return fluorescence_traces

    def process_roi_trace(self, roi_filename: str, video: np.ndarray, masks_dir: str, roi_shape: Tuple[int, int]) -> np.ndarray:
        """Compute the background-subtracted mean trace of one ROI.

        The mask image is thresholded at ``> 0``; the mean over the selected
        pixels is taken for every frame, the background trace from
        ``find_background`` is subtracted, and negative values are clipped to
        zero. The background details are stored in
        ``self.background_values[roi_filename]``.

        Args:
            roi_filename: File name of the mask inside ``masks_dir``.
            video: Stack indexed as (frame, row, column).
            masks_dir: Folder containing the mask image.
            roi_shape: (height, width) of the background tile.

        Returns:
            One value per frame, never negative.

        Raises:
            ValueError: If the mask cannot be read or does not have the same
                (rows, columns) shape as a video frame.
        """
        roi_path = os.path.join(masks_dir, roi_filename)

        # cv2.imread signals failure by returning None instead of raising
        mask_image = cv2.imread(roi_path, cv2.IMREAD_UNCHANGED)
        if mask_image is None:
            raise ValueError(f"Could not read ROI mask image: {roi_path}")
        roi_mask = mask_image > 0

        frame_shape = tuple(video.shape[1:])
        if roi_mask.shape != frame_shape:
            raise ValueError(
                f"ROI mask {roi_filename} has shape {roi_mask.shape}, expected {frame_shape}"
            )

        # Vectorized computation
        fluorescence_trace = np.mean(video[:, roi_mask], axis=1)
        background, roi_center = self.find_background(video, roi_shape)

        # Store background information
        self.background_values[roi_filename] = {
            'background_trace': background,
            'background_roi_center': roi_center,
            'background_mean': np.mean(background),
            'background_std': np.std(background)
        }

        return np.clip(fluorescence_trace - background, 0, None)

    # Define function to calculate dF/F
    def calculate_dff(self, fluorescence_traces):
        dff_traces = {} 
        for roi, trace in fluorescence_traces.items():
            F0 = np.percentile(trace, 8)  # Baseline as 8th percentile
            dff_traces[roi] = (trace - F0) / F0
        return dff_traces

    # Define function to compute fluorescence metrics
    def compute_fluorescence_metrics(self, dff_traces, baseline_start=0, baseline_end=200, start=233, end=580):
        """Compute fluorescence metrics in parallel."""
        def compute_single_metric(roi_trace):
            # Extract baseline and compute F0
            baseline_segment = roi_trace[baseline_start:baseline_end]
            F0 = np.mean(baseline_segment)
            
            # Get trace segment for analysis
            trace_segment = roi_trace[start:end]
            time_range = np.arange(start, end)
            
            # Peak metrics
            peak_idx = np.argmax(trace_segment)
            peak_amplitude = trace_segment[peak_idx] - F0
            time_of_peak = time_range[peak_idx]
            
            # Other metrics
            std_dev = np.std(trace_segment)
            auc = np.trapz(trace_segment, dx=1)
            
            # Slope calculations
            diff_trace = np.diff(trace_segment)
            max_rise_idx = np.argmax(diff_trace)
            max_rise_slope = diff_trace[max_rise_idx]
            time_of_max_rise_slope = time_range[max_rise_idx]
            slope, _, _, _, _ = linregress(time_range, trace_segment)
            
            # Rise time calculations
            ten_percent = 0.1 * peak_amplitude
            ninety_percent = 0.9 * peak_amplitude
            try:
                rise_start = np.where(trace_segment >= ten_percent)[0][0] + start
                rise_end = np.where(trace_segment >= ninety_percent)[0][0] + start
                rise_time = rise_end - rise_start
                rise_slope = (ninety_percent - ten_percent) / rise_time
            except IndexError:
                rise_time = np.nan
                rise_slope = np.nan
                
            return [peak_amplitude, time_of_peak, std_dev, auc,
                   max_rise_slope, time_of_max_rise_slope, slope, rise_time, rise_slope]
        
        metrics = {}
        with concurrent.futures.ThreadPoolExecutor() as executor:
            future_to_roi = {executor.submit(compute_single_metric, trace): roi 
                           for roi, trace in dff_traces.items()}
            
            for future in concurrent.futures.as_completed(future_to_roi):
                roi = future_to_roi[future]
                try:
                    metrics[roi] = future.result()
                except Exception as e:
                    logger.error(f"Error processing ROI {roi}: {e}")
                    metrics[roi] = [np.nan] * 9  # Fill with NaN if processing fails
        
        return metrics

    # Define function to compute distance to laminar boundary
    def compute_distances(self, roi_folder: str) -> Dict[str, float]:
        """
        Measure how far each ROI lies from the top edge of the image.

        For every ``.png`` mask in ``roi_folder`` (in sorted order) the mean
        row index of its non-zero pixels is used as the distance to the
        laminar boundary, which is taken to be the top border (row 0).

        Args:
            roi_folder (str): Path to folder containing ROI mask files

        Returns:
            Dict[str, float]: Mask file name -> distance in pixels, or NaN for
            a mask without any non-zero pixel
        """
        distances = {}
        png_names = [name for name in sorted(os.listdir(roi_folder)) if name.endswith(".png")]

        for roi_filename in png_names:
            mask_file = os.path.join(roi_folder, roi_filename)
            roi_mask = cv2.imread(mask_file, cv2.IMREAD_UNCHANGED) > 0

            # Row indices of every pixel that belongs to the ROI
            row_indices, _ = np.where(roi_mask)

            # Mean row == vertical centre of the ROI == distance from row 0
            distance_to_lamina = np.mean(row_indices) if len(row_indices) > 0 else np.nan

            distances[roi_filename] = distance_to_lamina
            logger.debug(f"ROI {roi_filename}: Distance to lamina = {distance_to_lamina:.2f} pixels")

        return distances



# Example usage to load dF/F traces externally
def main(base_dir: str = "F:/Recovered/Research/BoninLab/PainModelingProject/calcium_imaging_data/@Disinhibition/Disinhib_m1_7.23.20/Disinhib_m1_s2"):
    """Run the pipeline on a folder and return the dF/F traces it produced.

    Args:
        base_dir: Folder containing the .tif / .zip pairs. Defaults to the
            folder this script was originally written for.

    Returns:
        Dictionary mapping ROI mask file name to its dF/F trace, merged over
        every ``*_dff_traces.pkl`` file found in ``base_dir`` after
        processing, or ``None`` if the folder is missing or holds no trace
        files.
    """
    if not os.path.isdir(base_dir):
        logger.error(f"Data folder not found: {base_dir}")
        return None

    # Initialize analysis pipeline
    analysis = CalciumImagingAnalysis(base_dir)

    # Process all matching pairs in the folder
    analysis.process_folder()

    # The pipeline writes one "<base>_dff_traces.pkl" per recording; a plain
    # "dff_traces.pkl" is also accepted for files created by hand.
    trace_files = sorted(
        name for name in os.listdir(base_dir)
        if name == "dff_traces.pkl" or name.endswith("_dff_traces.pkl")
    )
    if not trace_files:
        return None

    dff_traces = {}
    for name in trace_files:
        dff_traces.update(analysis.load_dff_traces(os.path.join(base_dir, name)))
    return dff_traces

# Run the full pipeline when this file is executed directly.
if __name__ == "__main__":
    main()



# Alternative entry point that processes a whole folder without main():
# if __name__ == "__main__":
#     base_dir = "path/to/your/data/folder"
#     analysis = CalciumImagingAnalysis(base_dir)
#     analysis.process_folder()
#
# Example usage for processing a single file:
#     analysis = CalciumImagingAnalysis(base_dir)
#     memmap_path = analysis.exponential_correct_tif("your_file.tif")
#     masks_dir = analysis.convert_rois_to_masks("your_rois.zip")
#     results = analysis.analyze_fluorescence(masks_dir)

INFO:__main__:
Starting processing of pair at 2025-02-10 15:43:50.847064:
INFO:__main__:TIF: Disinhib1_7.23.20_ipsi2_0um_cor.tif
INFO:__main__:ROI: RoiSet_m1_s3_0um.zip
  return (a * np.exp(-b * x)) + (c * np.exp(-d * x))
INFO:__main__:Corrected image saved as a memmap at F:/Recovered/Research/BoninLab/PainModelingProject/calcium_imaging_data/@Disinhibition/Disinhib_m1_7.23.20/Disinhib_m1_s2\Disinhib1_7.23.20_ipsi2_0um_cor_corrected.dat
INFO:__main__:Time for Bleach Correction: 195.24 seconds
INFO:__main__:Saved mask: F:/Recovered/Research/BoninLab/PainModelingProject/calcium_imaging_data/@Disinhibition/Disinhib_m1_7.23.20/Disinhib_m1_s2\Disinhib1_7.23.20_ipsi2_0um_cor_corrected_masks\Disinhib1_7.23.20_ipsi2_0um_cor_corrected_mask_001.png
INFO:__main__:Saved mask: F:/Recovered/Research/BoninLab/PainModelingProject/calcium_imaging_data/@Disinhibition/Disinhib_m1_7.23.20/Disinhib_m1_s2\Disinhib1_7.23.20_ipsi2_0um_cor_corrected_masks\Disinhib1_7.23.20_ipsi2_0um_cor_corrected_mask_002.png
