Interpretation in This Context:
The "Restricted Mean Baresoil Time (up to t 
∗
  years)" tells you that, on average, pixels in your cohort remained baresoil for that many years, when looking at the period from 0 to t 
∗
  years. It's the average "event-free" time within this window, where the "event" is transitioning from baresoil to regrowth.

In [2]:
import os
import numpy as np
import pandas as pd
import rasterio
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
from lifelines import KaplanMeierFitter, LogNormalFitter
from lifelines.utils import restricted_mean_survival_time # <<<--- IMPORT THIS

# =====================================================================
# CONFIGURATION PARAMETERS
# =====================================================================
DATA_FOLDER = "G:/Vis/new3" # <<<--- UPDATE THIS PATH
if not os.path.exists(DATA_FOLDER):
    print(f"Warning: Data folder '{DATA_FOLDER}' does not exist.")
    try: os.makedirs(DATA_FOLDER, exist_ok=True)
    except OSError as e: print(f"Could not create data folder: {e}"); exit()

YEARS = range(2018, 2025)
INITIAL_YEARS = range(2018, 2022)
FILE_TEMPLATE = 'classified_extract_p_{}_4.tif'
CLASS_CODES = {"BS": 0, "ER": 1, "MR": 2}

TARGET_TRANSITION_KEY = 'BS_to_R'
TRANSITION_DETAILS = {
    'BS_to_R': {
        'time_col': 'bs_to_r_time',
        'event_col': 'bs_to_r_event',
        'title_segment': 'Baresoil to Any Regrowth',
        'cohort_text': '(2018-2021 Cohort)',
        'km_plot_file': 'BS_to_R_recovery_cohort.png'
    }
}

# =====================================================================
# DATA LOADING AND PROCESSING FUNCTIONS
# =====================================================================
def load_classified_image(file_path):
    try:
        with rasterio.open(file_path) as src:
            print(f"  Successfully opened: {file_path}")
            return src.read(1), src.meta
    except Exception as e:
        print(f"  Error loading {file_path}: {e}")
        return None, None

def load_all_images():
    print("Loading classified images...")
    images = {}
    metadata = None
    ref_shape = None
    ref_crs = None
    ref_transform = None
    loaded_years = []

    for year in YEARS:
        file_path = os.path.join(DATA_FOLDER, FILE_TEMPLATE.format(year))
        if os.path.exists(file_path):
            img, meta = load_classified_image(file_path)
            if img is not None and meta is not None:
                current_shape = (meta.get('height'), meta.get('width'))
                current_crs = meta.get('crs')
                current_transform = meta.get('transform')

                if metadata is None:
                    metadata = meta
                    ref_shape = current_shape
                    ref_crs = current_crs
                    ref_transform = current_transform
                    print(f"  Reference metadata set from year {year}")
                    images[year] = img
                    loaded_years.append(year)
                elif current_shape == ref_shape and current_crs == ref_crs and current_transform == ref_transform:
                    images[year] = img
                    loaded_years.append(year)
                else:
                    print(f"  Warning: Image metadata for {year} differs from reference. Skipping.")
            else:
                print(f"  Skipping year {year} due to loading error or missing metadata.")
        else:
            print(f"  File not found, skipping: {file_path}")

    if not images:
        print("\nError: No valid, consistent images found.")
        return None, None

    if metadata is None:
        print("\nError: Could not determine reference metadata.")
        return None, None

    metadata['height'] = ref_shape[0]
    metadata['width'] = ref_shape[1]
    metadata['count'] = 1

    print(f"\nFinished loading {len(images)} images for years: {loaded_years}.")
    return images, metadata

def identify_initial_baresoil(images, metadata):
    if not images or not metadata:
        print("Error: Cannot identify initial baresoil without images or metadata.")
        return None

    print(f"\nIdentifying initial baresoil pixels from years {min(INITIAL_YEARS)}-{max(INITIAL_YEARS)}...")
    height, width = metadata['height'], metadata['width']
    initial_bs_mask = np.zeros((height, width), dtype=bool)
    processed_years_count = 0
    available_initial_years = []

    for year in INITIAL_YEARS:
        if year in images:
            current_bs_mask = (images[year] == CLASS_CODES["BS"])
            initial_bs_mask = initial_bs_mask | current_bs_mask
            processed_years_count += 1
            available_initial_years.append(year)
        else:
            print(f"  Warning: Image for initial year {year} not available.")

    if processed_years_count == 0:
        print("\nError: No images available for the specified initial years.")
        return None

    num_initial_bs = np.sum(initial_bs_mask)
    print(f"\nIdentified {num_initial_bs:,} pixels as baresoil during years {available_initial_years}.")

    if num_initial_bs == 0:
        print("Warning: No initial baresoil pixels found.")

    return initial_bs_mask

def track_transitions(images, metadata, initial_bs_mask):
    if initial_bs_mask is None or np.sum(initial_bs_mask) == 0 or not images:
        print("\nError: Cannot track transitions with missing data.")
        return None, None, None, None

    print("\nTracking transitions for initial baresoil pixels...")
    height, width = metadata['height'], metadata['width']

    bs_year = np.zeros_like(initial_bs_mask, dtype=np.uint16)
    er_year = np.zeros_like(initial_bs_mask, dtype=np.uint16)
    mr_year = np.zeros_like(initial_bs_mask, dtype=np.uint16)
    r_year = np.zeros_like(initial_bs_mask, dtype=np.uint16)

    sorted_years = sorted(images.keys())
    print(f"  Tracking across years: {sorted_years}")

    not_yet_bs = (bs_year == 0)
    not_yet_er = (er_year == 0)
    not_yet_mr = (mr_year == 0)
    not_yet_r = (r_year == 0)

    for year in tqdm(sorted_years, desc="Tracking years"):
        img = images[year]
        active_pixels = initial_bs_mask

        # Track BS
        is_bs_now = (img == CLASS_CODES["BS"]) & active_pixels
        newly_bs = is_bs_now & not_yet_bs
        bs_year[newly_bs] = year
        not_yet_bs[newly_bs] = False

        # Track ER
        is_er_now = (img == CLASS_CODES["ER"]) & active_pixels
        newly_er = is_er_now & not_yet_er
        er_year[newly_er] = year
        not_yet_er[newly_er] = False

        # Track MR
        is_mr_now = (img == CLASS_CODES["MR"]) & active_pixels
        newly_mr = is_mr_now & not_yet_mr
        mr_year[newly_mr] = year
        not_yet_mr[newly_mr] = False

        # Track any Regrowth (ER or MR)
        is_regrowth_now = (is_er_now | is_mr_now)
        newly_r = is_regrowth_now & not_yet_r
        r_year[newly_r] = year
        not_yet_r[newly_r] = False

    print("\nFinished tracking transitions.")
    return bs_year, er_year, mr_year, r_year

def calculate_transition_times(images, metadata, initial_bs_mask, bs_year, er_year, mr_year, r_year):
    if bs_year is None or er_year is None or mr_year is None or r_year is None:
        print("\nError: Missing tracking arrays.")
        return None

    print("\nCalculating transition times and event flags...")
    height, width = metadata['height'], metadata['width']
    last_year = max(images.keys())

    valid_cohort_mask = initial_bs_mask & (bs_year > 0)
    row_indices, col_indices = np.where(valid_cohort_mask)
    num_pixels_to_process = len(row_indices)

    if num_pixels_to_process == 0:
        print("Warning: No valid cohort pixels found.")
        return pd.DataFrame()

    print(f"  Processing {num_pixels_to_process:,} pixels...")

    first_bs_yr_vals = bs_year[row_indices, col_indices]
    first_r_yr_vals = r_year[row_indices, col_indices]

    # Calculate BS to Regrowth transition times and events
    bs_to_r_event = (first_r_yr_vals > first_bs_yr_vals).astype(np.int8)
    bs_to_r_time = np.where(bs_to_r_event == 1,
                              first_r_yr_vals - first_bs_yr_vals,
                              last_year - first_bs_yr_vals + 1)

    # Create DataFrame
    data_dict = {
        'pixel_id': row_indices * width + col_indices,
        'row': row_indices,
        'col': col_indices,
        'first_bs_year': first_bs_yr_vals.astype(np.uint16),
        'first_r_year': np.where(first_r_yr_vals > 0, first_r_yr_vals, np.nan),
        'bs_to_r_time': bs_to_r_time.astype(np.int16),
        'bs_to_r_event': bs_to_r_event.astype(np.int8),
    }

    transitions_df = pd.DataFrame(data_dict)

    print("\nTransition Calculation Summary:")
    print(f"  BS -> R (any regrowth): {transitions_df['bs_to_r_event'].sum():,} events observed.")
    remained_bs_count = np.sum(transitions_df['bs_to_r_event'] == 0)
    print(f"  Remained BS (never reached R): {remained_bs_count:,} pixels.")

    # Save to CSV
    csv_path = os.path.join(DATA_FOLDER, "pixel_transition_times_cohort.csv")
    try:
        transitions_df.to_csv(csv_path, index=False, na_rep='NA')
        print(f"\nPixel transition data saved to {csv_path}")
    except Exception as e:
        print(f"\nError saving transition data to CSV: {e}")

    return transitions_df

# =====================================================================
# KAPLAN-MEIER ANALYSIS WITH LOG-NORMAL FIT
# =====================================================================
def plot_km_with_lognormal(df, time_col, event_col, title, filename):
    """Plot Kaplan-Meier recovery curve with confidence intervals, event times, and log-normal fit"""
    print(f"\nGenerating recovery curve with log-normal fit for: {title}")

    # Prepare data
    analysis_df = df[[time_col, event_col]].dropna()
    times = analysis_df[time_col].to_numpy(dtype=float)
    events = analysis_df[event_col].to_numpy(dtype=float)

    if len(times) == 0:
        print("  No data available. Skipping plot.")
        return None, None

    # Fix any non-positive times
    if (times <= 0).any():
        print("  Warning: Non-positive times found. Setting to minimum of 1.")
        times = np.maximum(times, 1)

    # Set up plotting
    plt.style.use('seaborn-v0_8-whitegrid') if 'seaborn-v0_8-whitegrid' in plt.style.available else plt.style.use('seaborn-whitegrid')
    fig, ax = plt.subplots(figsize=(10, 6), dpi=150)

    # Use lifelines to fit KM and log-normal
    kmf = KaplanMeierFitter()
    lnf = LogNormalFitter()

    try:
        # Fit Kaplan-Meier
        kmf.fit(times, event_observed=events, label='KM_estimate')

        # Plot KM confidence intervals
        try:
            ci = kmf.confidence_interval_survival_function_
            ci_lower_col = f'{kmf._label}_lower_0.95'
            ci_upper_col = f'{kmf._label}_upper_0.95'

            if ci_lower_col in ci.columns and ci_upper_col in ci.columns:
                recovery_lower = 1.0 - ci[ci_upper_col]
                recovery_upper = 1.0 - ci[ci_lower_col]
                recovery_lower = np.maximum(0, recovery_lower)
                recovery_upper = np.minimum(1, recovery_upper)
                ax.fill_between(ci.index, recovery_lower, recovery_upper, step='post',
                                  alpha=0.3, color='lightblue', label='95% CI')
        except Exception as e:
            print(f"  Error plotting confidence intervals: {e}")

        # Plot KM curve (as recovery = 1 - survival)
        km_survival = kmf.survival_function_[kmf._label].values
        km_times = kmf.survival_function_.index.values
        km_recovery = 1.0 - km_survival

        ax.step(km_times, km_recovery, where='post', label='Kaplan-Meier Estimate',
                  color='black', linewidth=2, linestyle='-')

        # Plot event dots
        try:
            event_table = kmf.event_table
            event_occurred_times = event_table.index[event_table['observed'] > 0]
            if len(event_occurred_times) > 0:
                survival_prob_at_event = kmf.predict(event_occurred_times)
                recovery_prob_at_event = 1.0 - survival_prob_at_event
                ax.plot(event_occurred_times, recovery_prob_at_event, marker='o', linestyle='none',
                          color='red', markersize=4, label='Event Time')
        except Exception as e:
            print(f"  Error plotting event times: {e}")

        # Get KM median
        median_recovery_time = kmf.median_survival_time_
        if np.isinf(median_recovery_time):
            print(f"  Median time: > {np.max(times)} yrs (KM est.)")
        else:
            print(f"  Median time: {median_recovery_time:.2f} yrs (KM est.)")

        # Plot median line if available
        if not np.isinf(median_recovery_time) and median_recovery_time > 0:
            try:
                # Find the recovery probability at the median time
                # We need to find the step in km_recovery that corresponds to median_recovery_time
                # The KM curve is 1 - S(t), so we are looking for P(T <= median_recovery_time) approx 0.5
                # The median_survival_time_ is when S(t) crosses 0.5, so recovery is 1 - 0.5 = 0.5
                # However, due to step function nature, we should use the actual value from the curve.
                # kmf.predict(median_recovery_time) gives S(median_recovery_time)
                median_survival_prob_on_curve = kmf.predict(median_recovery_time)
                median_recovery_prob_on_curve = 1.0 - median_survival_prob_on_curve

                ax.hlines(median_recovery_prob_on_curve, 0, median_recovery_time, color='red', linestyle='--',
                              linewidth=1.5, label=f'Median Rec. Time ({median_recovery_time:.2f} yrs)')
                ax.vlines(median_recovery_time, 0, median_recovery_prob_on_curve, color='red', linestyle='--', linewidth=1.5)
            except Exception as e:
                print(f"  Error plotting median lines: {e}")
        
        # Fit Log-normal model
        try:
            lnf.fit(times, event_observed=events, label='LogNormal_fit')
            params = lnf.params_
            mu, sigma = params
            print(f"  Log-normal parameters: mu={mu:.4f}, sigma={sigma:.4f}")
            print(f"  Log-normal median: {np.exp(mu):.2f} years")
            print(f"  Log-normal mean: {np.exp(mu + (sigma**2)/2):.2f} years")

            # Plot log-normal fit
            max_time_for_plot = max(times) * 1.5 
            plot_times_ln = np.linspace(0.1, max_time_for_plot, 500) 
            ln_survival = lnf.survival_function_at_times(plot_times_ln)
            ln_recovery = 1 - ln_survival

            ax.plot(plot_times_ln, ln_recovery, label=f'Log-normal Fit',
                      color='#2ca02c', linewidth=2, alpha=0.8)

            # Calculate recovery at specific times
            predict_times_ln = np.array([5, 10, 15, 20]) 
            for t_ln in predict_times_ln: 
                try:
                    sf_df = lnf.survival_function_at_times(t_ln)
                    if not sf_df.empty:
                         rec_prob_ln = 1 - sf_df.iloc[0, 0]
                         print(f"  Recovery at {t_ln} years (LogNormal): {rec_prob_ln*100:.1f}%")
                    else:
                        print(f"  LogNormal survival function returned empty for {t_ln} years.")
                except Exception as e_ln_pred_plot:
                     print(f"  Could not calculate LogNormal recovery for plot at {t_ln} years: {e_ln_pred_plot}")


        except Exception as e:
            print(f"  Error with log-normal fit: {e}")
            lnf = None 

    except Exception as e:
        print(f"  Error in KM estimation: {e}")
        if 'kmf' in locals(): # Check if kmf was initialized
             kmf = None # Ensure kmf is None if fitting failed before it's fully assigned
        lnf = None # Also ensure lnf is None
        return kmf, lnf


    # Finalize plot
    ax.set_xlabel("Time (years)", fontsize=12)
    ax.set_ylabel("Probability of Recovery", fontsize=12)
    ax.set_title(f"Recovery Curve: {title}", fontsize=14, fontweight='bold')

    handles, labels = ax.get_legend_handles_labels()
    # Filter out duplicate labels for median line if it was attempted multiple times or has multiple segments
    unique_labels = {}
    for handle, label in zip(handles, labels):
        if label not in unique_labels:
            unique_labels[label] = handle
    ax.legend(unique_labels.values(), unique_labels.keys(), fontsize=10)


    ax.grid(True, linestyle=':', alpha=0.7)
    ax.set_ylim(0, 1.05)

    max_time_obs = km_times.max() if 'km_times' in locals() and len(km_times) > 0 else 10
    xlim_upper = max(15, max_time_obs * 1.5)
    ax.set_xlim(left=0, right=xlim_upper)

    plt.tight_layout()

    # Save plots
    plot_base_path = os.path.join(DATA_FOLDER, os.path.splitext(filename)[0])
    png_path = plot_base_path + ".png"
    pdf_path = plot_base_path + ".pdf"

    try:
        plt.savefig(png_path, bbox_inches='tight')
        print(f"  Recovery curve saved to {png_path}")
        plt.savefig(pdf_path, bbox_inches='tight')
        print(f"  Recovery curve also saved as PDF to {pdf_path}")
    except Exception as e:
        print(f"  Error saving plots: {e}")

    plt.close(fig)
    
    # Ensure kmf is defined if an early error in try block occurred before its full fitting
    if 'kmf' not in locals() or not hasattr(kmf, 'survival_function_'):
        kmf = None
    if 'lnf' not in locals() or not hasattr(lnf, 'params_'): # Check if lnf was fitted
        lnf = None

    return kmf, lnf

def create_transition_boxplot(transitions_df):
    """Create boxplot of observed BS to R transition times (events only)"""
    if transitions_df is None or transitions_df.empty:
        print("\nWarning: Transition DataFrame empty.")
        return

    print("\nCreating box plot for observed transition times (events only)...")

    plt.style.use('seaborn-v0_8-whitegrid') if 'seaborn-v0_8-whitegrid' in plt.style.available else plt.style.use('seaborn-whitegrid')

    time_col = TRANSITION_DETAILS[TARGET_TRANSITION_KEY]['time_col']
    event_col = TRANSITION_DETAILS[TARGET_TRANSITION_KEY]['event_col']

    bs_r_t = transitions_df.loc[transitions_df[event_col] == 1, time_col].dropna()
    if bs_r_t.empty:
        print("  No observed events found. Skipping box plot.")
        return

    bs_r_t = pd.to_numeric(bs_r_t, errors='coerce').dropna()
    if bs_r_t.empty: # Check again after numeric conversion
        print("  No valid numeric observed event times found after conversion. Skipping box plot.")
        return

    fig, ax = plt.subplots(figsize=(7, 7), dpi=150)
    sns.boxplot(y=bs_r_t, ax=ax, color='skyblue')
    sns.stripplot(y=bs_r_t, ax=ax, color='black', alpha=0.3, size=3, jitter=True)

    mean_val = bs_r_t.mean()
    # median_val = bs_r_t.median() # Not used in text annotation directly
    count_val = len(bs_r_t)

    ax.axhline(mean_val, color='red', linestyle='--', linewidth=1.5)
    ax.text(0.95, mean_val, f' Mean: {mean_val:.1f} yrs',
            ha='right', va='center', color='red', fontweight='bold',
            bbox=dict(facecolor='white', alpha=0.8, edgecolor='red', boxstyle='round,pad=0.2'))

    ax.text(0.5, 0.01, f'n = {count_val} events observed',
            ha='center', va='bottom', transform=ax.transAxes, fontweight='bold',
            bbox=dict(facecolor='white', alpha=0.8, edgecolor='black', boxstyle='round,pad=0.2'))

    ax.set_ylabel('Time to Recovery (years)', fontweight='bold')
    ax.set_title('Distribution of Observed BS to Regrowth Times\n(Events Only)', fontweight='bold')
    ax.set_xlabel('')
    ax.set_xticks([])

    description = (f"Plot shows durations only for pixels where transition was observed by {max(YEARS)}.\n"
                   f"Cohort: Pixels classified as BS at least once during {min(INITIAL_YEARS)}-{max(INITIAL_YEARS)}.")
    fig.text(0.5, 0.01, description, ha='center', va='bottom', fontsize=9, style='italic')

    plt.tight_layout(rect=[0, 0.05, 1, 0.95])

    boxplot_base_path = os.path.join(DATA_FOLDER, "observed_BS_to_R_transition_boxplot")
    png_path = boxplot_base_path + ".png"
    pdf_path = boxplot_base_path + ".pdf"

    try:
        plt.savefig(png_path, bbox_inches='tight')
        print(f"  Box plot saved to {png_path}")
        plt.savefig(pdf_path, bbox_inches='tight')
        print(f"  Box plot also saved as PDF to {pdf_path}")
    except Exception as e:
        print(f"  Error saving box plots: {e}")

    plt.close(fig)

def create_summary_report(transitions_df, kmf, lnf):
    """Create a summary report of all analysis results"""
    if transitions_df is None or transitions_df.empty:
        print("\nWarning: Transition DataFrame empty.")
        return

    print("\n" + "="*80)
    print(" RECOVERY ANALYSIS SUMMARY ")
    print("="*80)

    total_pixels = len(transitions_df)
    time_col = TRANSITION_DETAILS[TARGET_TRANSITION_KEY]['time_col']
    event_col = TRANSITION_DETAILS[TARGET_TRANSITION_KEY]['event_col']
    events_count = transitions_df[event_col].sum() # Renamed to avoid conflict
    censored = total_pixels - events_count

    print(f"\nCohort Summary:")
    print(f"  Total pixels in cohort: {total_pixels:,}")
    print(f"  Pixels that transitioned to regrowth: {events_count:,} ({events_count/total_pixels*100:.1f}%)")
    print(f"  Pixels still baresoil: {censored:,} ({censored/total_pixels*100:.1f}%)")

    event_times_df = transitions_df.loc[transitions_df[event_col] == 1, time_col] # Renamed
    if not event_times_df.empty:
        print(f"\nObserved Recovery Times (events only):")
        print(f"  Mean: {event_times_df.mean():.2f} years")
        print(f"  Median: {event_times_df.median():.2f} years")
        print(f"  Min: {event_times_df.min():.0f} years")
        print(f"  Max: {event_times_df.max():.0f} years")

    if kmf is not None and hasattr(kmf, 'median_survival_time_'): # Check if kmf is properly fitted
        print(f"\nKaplan-Meier Estimates:")
        try:
            median_time_km = kmf.median_survival_time_ # Renamed
            if np.isinf(median_time_km):
                print(f"  Median recovery time: Not reached within observation period")
            else:
                print(f"  Median recovery time: {median_time_km:.2f} years")

            # Calculate Restricted Mean Survival Time (RMST)
            # We define "survival" here as "remaining baresoil". So RMST is restricted mean time as baresoil.
            # For "recovery time", we'd want 1 - S(t). Or, RMST is "mean time to event, restricted".
            # The function calculates area under S(t). So this is "Restricted Mean Baresoil Time".
            # To get "Restricted Mean Recovery Time", it's more complex directly from S(t)'s RMST.
            # Let's report RMST of being in the initial state (Baresoil).
            
            # Define t_star for RMST, e.g., the maximum observed time in the study
            # This should be from the 'times' array used to fit kmf.
            # We can get it from kmf.timeline.max() if available and populated.
            if kmf.timeline is not None and len(kmf.timeline) > 0:
                t_star_rmst = kmf.timeline.max()
                # RMST for "time spent as baresoil"
                rmst_baresoil = restricted_mean_survival_time(kmf, t=t_star_rmst)
                if rmst_baresoil is not None:
                     # The interpretation for recovery is tricky. RMST is area under S(t).
                     # S(t) here is prob of *not* recovering (still baresoil).
                     # So, rmst_baresoil is mean time spent as baresoil, up to t_star_rmst.
                    print(f"  Restricted Mean Baresoil Time (up to {t_star_rmst:.2f} years): {rmst_baresoil:.2f} years")

                    # If you want "restricted mean time lost to not recovering by t_star"
                    # rmtl = t_star_rmst - rmst_baresoil
                    # print(f"  Restricted Mean Time Lost (due to not recovering by {t_star_rmst:.2f} years): {rmtl:.2f} years")
                else:
                    print(f"  Restricted Mean Baresoil Time could not be calculated.")
            else:
                print(f"  KM timeline not available for RMST calculation.")


            times_to_predict_km = [5, 10, 15, 20] # Renamed
            for t_km_pred in times_to_predict_km: # Renamed
                if kmf.timeline is not None and len(kmf.timeline) > 0 and t_km_pred <= kmf.timeline.max():
                    surv_prob_km = kmf.predict(t_km_pred)
                    rec_prob_km = 1 - surv_prob_km
                    print(f"  Recovery probability by {t_km_pred} years: {rec_prob_km*100:.1f}%")
                elif kmf.timeline is not None and len(kmf.timeline) > 0: # t_km_pred is beyond timeline max
                    surv_prob_at_max_km = kmf.predict(kmf.timeline.max())
                    rec_prob_at_max_km = 1 - surv_prob_at_max_km
                    print(f"  Recovery probability by {t_km_pred} years: {rec_prob_at_max_km*100:.1f}% (value from max timeline: {kmf.timeline.max():.1f} yrs)")
                else:
                     print(f"  KM timeline not available to predict recovery at {t_km_pred} years.")


        except Exception as e:
            print(f"  Error extracting KM estimates: {e}")
    elif kmf is None:
        print(f"\nKaplan-Meier model (kmf) is None. Skipping KM estimates.")


    if lnf is not None and hasattr(lnf, 'params_'): # Check if lnf is properly fitted
        print(f"\nLog-normal Model Estimates:")
        try:
            params_ln = lnf.params_ # Renamed
            mu_ln, sigma_ln = params_ln # Renamed

            print(f"  Parameters: mu={mu_ln:.4f}, sigma={sigma_ln:.4f}")
            median_time_ln_model = np.exp(mu_ln) # Renamed
            mean_time_ln_model = np.exp(mu_ln + (sigma_ln**2)/2) # Renamed

            print(f"  Median recovery time (model): {median_time_ln_model:.2f} years")
            print(f"  Mean recovery time (model): {mean_time_ln_model:.2f} years")

            times_to_predict_ln_report = [5, 10, 15, 20] # Renamed
            for t_ln_rep in times_to_predict_ln_report: # Renamed
                try:
                    sf_val_ln = lnf.survival_function_at_times(t_ln_rep).iloc[0, 0] # Renamed
                    rec_prob_ln_rep = 1 - sf_val_ln # Renamed
                    print(f"  Recovery probability by {t_ln_rep} years: {rec_prob_ln_rep*100:.1f}%")
                except Exception as e_ln_rep_pred: # Renamed
                    print(f"  Could not calculate Lognormal recovery for report at {t_ln_rep} years: {e_ln_rep_pred}")

            param_data_ln = { # Renamed
                'Parameter': ['mu', 'sigma', 'median_time_model', 'mean_time_model'],
                'Value': [mu_ln, sigma_ln, median_time_ln_model, mean_time_ln_model]
            }
            param_df_ln = pd.DataFrame(param_data_ln) # Renamed
            param_path_ln = os.path.join(DATA_FOLDER, "lognormal_parameters.csv") # Renamed
            param_df_ln.to_csv(param_path_ln, index=False)
            print(f"\nLog-normal parameters saved to: {param_path_ln}")

        except Exception as e:
            print(f"  Error extracting log-normal estimates: {e}")
    elif lnf is None:
         print(f"\nLog-normal model (lnf) is None. Skipping Log-normal estimates.")


    print("\n" + "="*80)


# =====================================================================
# MAIN EXECUTION
# =====================================================================
def main():
    print("="*70)
    print(" Baresoil to Regrowth Recovery Analysis (KM with Log-normal)")
    print("="*70)

    images, metadata = load_all_images()
    if images is None or metadata is None:
        print("Exiting due to data loading errors.")
        return

    initial_bs_mask = identify_initial_baresoil(images, metadata)
    if initial_bs_mask is None or np.sum(initial_bs_mask) == 0:
        print("Exiting: No initial baresoil pixels identified.")
        return

    bs_year, er_year, mr_year, r_year = track_transitions(images, metadata, initial_bs_mask)
    if bs_year is None:
        print("Exiting due to errors during transition tracking.")
        return

    transitions_df = calculate_transition_times(images, metadata, initial_bs_mask, bs_year, er_year, mr_year, r_year)
    if transitions_df is None or transitions_df.empty:
        print("Exiting: No transition data available.")
        return

    key = TARGET_TRANSITION_KEY
    params_km = TRANSITION_DETAILS[key] # Renamed
    full_title_km = f"{params_km['title_segment']} {params_km['cohort_text']}" # Renamed

    kmf_model, lnf_model = plot_km_with_lognormal( # Renamed
        transitions_df,
        params_km['time_col'],
        params_km['event_col'],
        full_title_km,
        params_km['km_plot_file']
    )

    create_transition_boxplot(transitions_df)
    create_summary_report(transitions_df, kmf_model, lnf_model) # Use renamed models

    print("\nAnalysis complete!")

if __name__ == "__main__":
    main()

 Baresoil to Regrowth Recovery Analysis (KM with Log-normal)
Loading classified images...
  Successfully opened: G:/Vis/new3\classified_extract_p_2018_4.tif
  Reference metadata set from year 2018
  Successfully opened: G:/Vis/new3\classified_extract_p_2019_4.tif
  Successfully opened: G:/Vis/new3\classified_extract_p_2020_4.tif
  Successfully opened: G:/Vis/new3\classified_extract_p_2021_4.tif
  Successfully opened: G:/Vis/new3\classified_extract_p_2022_4.tif
  Successfully opened: G:/Vis/new3\classified_extract_p_2023_4.tif
  Successfully opened: G:/Vis/new3\classified_extract_p_2024_4.tif

Finished loading 7 images for years: [2018, 2019, 2020, 2021, 2022, 2023, 2024].

Identifying initial baresoil pixels from years 2018-2021...

Identified 9,397,673 pixels as baresoil during years [2018, 2019, 2020, 2021].

Tracking transitions for initial baresoil pixels...
  Tracking across years: [2018, 2019, 2020, 2021, 2022, 2023, 2024]


Tracking years: 100%|████████████████████████████████████████████████████████████████████| 7/7 [00:08<00:00,  1.16s/it]



Finished tracking transitions.

Calculating transition times and event flags...
  Processing 9,397,673 pixels...

Transition Calculation Summary:
  BS -> R (any regrowth): 6,251,856 events observed.
  Remained BS (never reached R): 3,145,817 pixels.

Pixel transition data saved to G:/Vis/new3\pixel_transition_times_cohort.csv

Generating recovery curve with log-normal fit for: Baresoil to Any Regrowth (2018-2021 Cohort)
  Median time: 3.00 yrs (KM est.)
  Log-normal parameters: mu=1.2047, sigma=0.8420
  Log-normal median: 3.34 years
  Log-normal mean: 4.75 years
  Could not calculate LogNormal recovery for plot at 5 years: Too many indexers
  Could not calculate LogNormal recovery for plot at 10 years: Too many indexers
  Could not calculate LogNormal recovery for plot at 15 years: Too many indexers
  Could not calculate LogNormal recovery for plot at 20 years: Too many indexers
  Recovery curve saved to G:/Vis/new3\BS_to_R_recovery_cohort.png
  Recovery curve also saved as PDF to G:/