In [1]:
# from google.colab import runtime
# runtime.unassign()

In [2]:
# 1.2 Imports
## 1.2.1 Import Libraries
import pymc as pm
import tensorflow as tf
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import math
import seaborn as sns
import warnings
import pkg_resources
import jax
import numpyro
import torch
import pytensor.tensor as pt
import os
import arviz as az
import pickle
import time
import random
import datetime as dt
import functools
import jax.numpy as jnp
import jax.random as random
import json
import gc
import psutil
import sys
import threading
import dataclasses
import xarray as xr
from pymc import *
from scipy import stats
from datetime import datetime
from google.cloud import bigquery
from concurrent.futures import ThreadPoolExecutor
from functools import lru_cache
from sklearn.preprocessing import StandardScaler

## 1.2.2 Unit Test for Imports
try:
    _ = [pm, tf, pd, np, plt, sns, stats, bigquery]
    print("All libraries imported successfully.")
except ImportError as e:
    print("Library import failed:", e)

All libraries imported successfully.


In [3]:
# 1.2.3 Retrieve the list of installed and imported packages and their point versions
# Install a specific point version of pymc
# !pip install pymc==3.11.4
# installed_packages = pkg_resources.working_set
# package_versions = {package.key: package.version for package in installed_packages}
# print(f"Installed packages and versions:\n{package_versions}")
%pip freeze > requirements.txt
# %pip install -r requirements.txt

In [4]:
# 1.2.4 Suppress warnings for cleaner output
warnings.filterwarnings('ignore')

In [5]:
# 1.2.7 Set numpyro to use available GPUs
numpyro.set_host_device_count(4)

In [6]:
# 1.2.6 Number of GPUs discovered by jax
import jax
print("\nJAX version:", jax.__version__)
print("\nAvailable devices:")
print(jax.devices())


JAX version: 0.4.35

Available devices:
[CudaDevice(id=0), CudaDevice(id=1), CudaDevice(id=2), CudaDevice(id=3)]


In [7]:
# 1.2.8 View GPU specs
!pip install nvidia-smi torch
!nvidia-smi

Fri Nov 15 00:50:47 2024       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.90.07              Driver Version: 550.90.07      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA L4                      Off |   00000000:00:03.0 Off |                    0 |
| N/A   65C    P0             30W /   72W |     195MiB /  23034MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
|   1  NVIDIA L4                      Off |   00

In [8]:
# Data processor config override
dp_config = {
  # Data processing parameters
  'MIN_FREQUENCY': 1,
  'MIN_REVENUE': 1,
  'MIN_TRANSACTION_VALUE': 1,
  'OUTLIER_THRESHOLD': 3,  # Number of IQRs for outlier detection
  # Query parameters
  'PROJECT_ID': 'logic-dna-240402',
  'DATASET': 'CLV',
  'TABLE': 'T_CLV_360',
  'LIMIT': 1500000,
  'COHORT_MONTH': '2019-02-01',
  'MIN_PURCHASE_DATE': '2022-02-03',
  'MAX_PURCHASE_DATE': None,  # Defaults to current date if None
  'INCLUDE_ONLINE': True,
  'INCLUDE_STORE': True,
  'MIN_LOYALTY_POINTS': 0
}
dp_config

{'MIN_FREQUENCY': 1,
 'MIN_REVENUE': 1,
 'MIN_TRANSACTION_VALUE': 1,
 'OUTLIER_THRESHOLD': 3,
 'PROJECT_ID': 'logic-dna-240402',
 'DATASET': 'CLV',
 'TABLE': 'T_CLV_360',
 'LIMIT': 1500000,
 'COHORT_MONTH': '2019-02-01',
 'MIN_PURCHASE_DATE': '2022-02-03',
 'MAX_PURCHASE_DATE': None,
 'INCLUDE_ONLINE': True,
 'INCLUDE_STORE': True,
 'MIN_LOYALTY_POINTS': 0}

In [9]:
# Data processor
class CLVDataProcessor:
    """
    CLV data processing class with data quality checks
    """
    def __init__(self, config=None):
        self.config = dp_config or {
    # Data processing parameters
    'MIN_FREQUENCY': 1,
    'MIN_REVENUE': 1,
    'MIN_TRANSACTION_VALUE': 1,
    'OUTLIER_THRESHOLD': 3,  # Number of IQRs for outlier detection
    # Query parameters
    'PROJECT_ID': 'logic-dna-240402',
    'DATASET': 'CLV',
    'TABLE': 'T_CLV_360',
    'LIMIT': 10000000,
    'COHORT_MONTH': '2023-02-01',
    'MIN_PURCHASE_DATE': '2022-02-03',
    'MAX_PURCHASE_DATE': 'CURRENT_DATE()',  # Defaults to current date if None
    'INCLUDE_ONLINE': True,
    'INCLUDE_STORE': True,
    'MIN_LOYALTY_POINTS': 0
  }
        self.data = None
        self.quality_flags = None

    def _build_query(self):
        """Build BigQuery query based on configuration parameters"""
        # Handle date parameters
        max_date = self.config['MAX_PURCHASE_DATE'] or 'CURRENT_DATE()'
        min_date = f"DATE('{self.config['MIN_PURCHASE_DATE']}')"
        cohort_month = f"DATE('{self.config['COHORT_MONTH']}')"

        # Build channel condition
        channel_conditions = []
        if self.config['INCLUDE_ONLINE']:
            channel_conditions.append("has_online_purchases = 1")
        if self.config['INCLUDE_STORE']:
            channel_conditions.append("has_store_purchases = 1")
        channel_filter = f"({' OR '.join(channel_conditions)})" if channel_conditions else "TRUE"

        query = f"""
        WITH
        fin AS (
        SELECT
          CAST(customer_id AS STRING) AS customer_id,
          CAST(cohort_month AS STRING) AS cohort_month,
          CAST(recency_days AS INT64) AS recency,
          CAST(frequency AS INT64) AS frequency,
          ROUND(total_revenue,2) AS monetary,
          ROUND(total_revenue,2) AS total_revenue,
          ROUND(revenue_trend,4) AS revenue_trend,
          ROUND(avg_transaction_value,2) AS avg_transaction_value,
          CAST(first_purchase_date AS DATE) AS first_purchase_date,
          CAST(last_purchase_date AS DATE) AS last_purchase_date,
          CAST(customer_age_days AS INT64) AS customer_age_days,
          CAST(distinct_categories AS INT64) AS distinct_categories,
          CAST(distinct_brands AS INT64) AS distinct_brands,
          ROUND(avg_interpurchase_days,2) AS avg_interpurchase_days,
          CAST(has_online_purchases AS INT64) AS has_online_purchases,
          CAST(has_store_purchases AS INT64) AS has_store_purchases,
          ROUND(total_discount_amount,2) AS total_discount_amount,
          ROUND(avg_discount_amount,2) AS avg_discount_amount,
          ROUND(COALESCE(discount_rate,0),3) AS discount_rate,
          CAST(sms_active AS INT64) AS sms_active,
          CAST(email_active AS INT64) AS email_active,
          CAST(is_loyalty_member AS INT64) AS is_loyalty_member,
          CAST(loyalty_points AS INT64) AS loyalty_points
        FROM
          `{self.config['PROJECT_ID']}.{self.config['DATASET']}.{self.config['TABLE']}`
        WHERE
          customer_id IS NOT NULL
          AND cohort_month IS NOT NULL
          AND frequency >= {self.config['MIN_FREQUENCY']}
          AND total_revenue >= {self.config['MIN_REVENUE']}
          AND avg_transaction_value >= {self.config['MIN_TRANSACTION_VALUE']}
          AND cohort_month >= {cohort_month}
          # AND first_purchase_date >= {min_date}
          AND last_purchase_date <= {max_date}
          AND loyalty_points >= {self.config['MIN_LOYALTY_POINTS']}
          AND {channel_filter}
        )
        SELECT
          *
        FROM
          fin
        LIMIT
          {self.config['LIMIT']}
        """

        return query

    def load_data(self, query=None, project_id=None, csv_path=None):
        """Load data either from BigQuery or CSV file"""
        try:
            if csv_path:
                self.data = pd.read_csv(csv_path)
                print(f"Successfully loaded {len(self.data):,} records from CSV")
            else:
                from google.cloud import bigquery

                client = bigquery.Client(project=project_id or self.config['PROJECT_ID'])

                # Build the query using config parameters
                default_query = self._build_query()

                self.data = client.query(query or default_query).to_dataframe()
                print(f"Successfully loaded {len(self.data):,} records from BigQuery")

            # Convert date columns
            date_columns = ['first_purchase_date', 'last_purchase_date', 'cohort_month']
            for col in date_columns:
                if col in self.data.columns:
                    self.data[col] = pd.to_datetime(self.data[col])

            return self

        except Exception as e:
            print(f"Error loading data: {str(e)}")
            raise

    def process_data(self):
        """
        Main data processing pipeline that handles:
        - Basic cleaning
        - RFM calculation
        - Quality validation
        """
        if self.data is None:
            raise ValueError("No data loaded. Call load_data() first.")

        try:
            print(f"Starting data processing. Initial shape: {self.data.shape}")

            # 1. Basic cleaning
            self._clean_basic_data()

            # 2. Calculate RFM metrics
            self._calculate_rfm_metrics()

            # 3. Handle outliers and invalid values
            self._clean_monetary_values()
            self._clean_categorical_features()

            # 4. Validate data quality
            self._validate_data_quality()

            # 5. Generate final report
            self._generate_quality_report()

            return self

        except Exception as e:
            print(f"Error in data processing: {str(e)}")
            raise

    def _clean_basic_data(self):
        """Basic data cleaning operations"""
        print("Cleaning data...")
        initial_count = len(self.data)

        # Remove invalid records
        self.data = self.data[
            (self.data['frequency'] >= self.config['MIN_FREQUENCY']) &
            (self.data['total_revenue'] > 0) &
            (self.data['avg_transaction_value'] > 0)
        ]

        # Drop duplicates
        self.data = self.data.drop_duplicates(subset=['customer_id'])

        # Handle missing values
        self.data['discount_rate'] = self.data['discount_rate'].fillna(0)

        records_removed = initial_count - len(self.data)
        print(f"Records removed: {records_removed}")

    def _calculate_rfm_metrics(self):
        """Calculate Recency, Frequency, Monetary metrics"""
        current_date = pd.Timestamp.today()

        # Calculate or update RFM metrics
        if 'recency' not in self.data.columns:
            self.data['recency'] = (
                current_date - self.data['last_purchase_date']
            ).dt.days

        if 'customer_age_days' not in self.data.columns:
            self.data['customer_age_days'] = (
                current_date - self.data['first_purchase_date']
            ).dt.days

        if 'avg_transaction_value' not in self.data.columns:
            self.data['avg_transaction_value'] = (
                self.data['total_revenue'] / self.data['frequency']
            )

    def _clean_monetary_values(self):
        """Handle monetary value cleaning and outliers"""
        for col in ['frequency', 'monetary', 'avg_transaction_value', 'total_revenue']:
            if col in self.data.columns:
                # Remove negative values
                self.data = self.data[self.data[col] >= 0]

                # Handle outliers using IQR method
                Q1 = self.data[col].quantile(0.25)
                Q3 = self.data[col].quantile(0.75)
                IQR = Q3 - Q1
                lower_bound = Q1 - 3 * IQR
                upper_bound = Q3 + 3 * IQR

                self.data = self.data[
                    (self.data[col] >= lower_bound) &
                    (self.data[col] <= upper_bound)
                ]

    def _clean_categorical_features(self):
        """Clean and encode categorical features"""
        categorical_features = {
            'has_online_purchases': 0,
            'has_store_purchases': 0,
            'is_loyalty_member': 0,
            'loyalty_points': 0,
            'sms_active': 0,
            'email_active': 0
        }

        for col, default in categorical_features.items():
            if col in self.data.columns:
                self.data[col] = self.data[col].fillna(default).astype(int)

    def _validate_data_quality(self):
        """Validate data quality and create quality flags"""
        # Create quality flags
        self.quality_flags = pd.DataFrame(index=self.data.index)

        # Define validation rules
        validations = {
            'valid_frequency': self.data['frequency'] >= self.config['MIN_FREQUENCY'],
            'valid_recency': self.data['recency'] >= 0,
            'valid_monetary': self.data['avg_transaction_value'] > 0,
            'valid_dates': self.data['last_purchase_date'] >= self.data['first_purchase_date']
        }

        # Apply validations
        for flag_name, condition in validations.items():
            self.quality_flags[flag_name] = condition

        # Overall validation
        self.quality_flags['overall_valid'] = self.quality_flags.all(axis=1)

    def _generate_quality_report(self):
        """Generate final data quality report"""
        report = {
            'record_count': len(self.data),
            'metrics': {
                'frequency_mean': self.data['frequency'].mean(),
                'recency_mean': self.data['recency'].mean(),
                'monetary_mean': self.data['avg_transaction_value'].mean()
            },
            'quality_flags': {
                col: self.quality_flags[col].mean() * 100
                for col in self.quality_flags.columns
            }
        }

        print("\nQuality Report:")
        print(f"Records processed: {report['record_count']:,}")
        print("\nKey Metrics:")
        for metric, value in report['metrics'].items():
            print(f"{metric}: {value:.2f}")
        print("\nQuality Flags (% passing):")
        for flag, pct in report['quality_flags'].items():
            print(f"{flag}: {pct:.1f}%")

    def get_processed_data(self):
        """Return the processed DataFrame"""
        return self.data.copy()

In [10]:
# Create processor instance
processor = CLVDataProcessor(dp_config)

# Load the query with parameters
processor._build_query()

# Load your DataFrame
processor.load_data()

# Process the data
processor.process_data()

# Get the processed data
processed_df = processor.get_processed_data()

# View first few rows
print("\nFirst few rows of the loaded data:")
print(processed_df.head())

# Display dataframe info
print("\nDataframe information:")
print(processed_df.info())

Successfully loaded 1,409,887 records from BigQuery
Starting data processing. Initial shape: (1409887, 23)
Cleaning data...
Records removed: 0

Quality Report:
Records processed: 1,200,823

Key Metrics:
frequency_mean: 1.92
recency_mean: 855.97
monetary_mean: 43.45

Quality Flags (% passing):
valid_frequency: 100.0%
valid_recency: 100.0%
valid_monetary: 100.0%
valid_dates: 100.0%
overall_valid: 100.0%

First few rows of the loaded data:
                               customer_id cohort_month  recency  frequency  \
3567  95e48406-b3f7-3d04-94a7-9c1a8759597e   2024-02-01      256          1   
3568  c9e404f3-3b9b-3731-ac35-b7fca2506e96   2024-02-01      256          1   
3569  6909ef74-cf90-374b-833e-0ec339c61050   2024-02-01      256          1   
3570  26794b35-c9e6-3346-a848-ee97599fd568   2024-02-01      256          1   
3571  039f159b-d60f-35ac-8df0-4e6c2100defb   2024-02-01      256          1   

      monetary  total_revenue  revenue_trend  avg_transaction_value  \
3567     16.0

In [11]:
# Re-index the processed_df and create a customer_index column
processed_df.reset_index(drop=True, inplace=True)
processed_df['customer_index'] = processed_df.index

# Create a new DataFrame with the specified columns
customer_lookup_df = processed_df[['customer_index', 'customer_id']].copy()

# Display the resulting DataFrame
print(customer_lookup_df)

# Drop the 'customer_id' field from processed_df
processed_df = processed_df.drop(columns=['customer_index', 'customer_id'])

# Display the updated processed_df structure
print(processed_df.info())

customer_lookup_df.head()

processed_df.head()

         customer_index                           customer_id
0                     0  95e48406-b3f7-3d04-94a7-9c1a8759597e
1                     1  c9e404f3-3b9b-3731-ac35-b7fca2506e96
2                     2  6909ef74-cf90-374b-833e-0ec339c61050
3                     3  26794b35-c9e6-3346-a848-ee97599fd568
4                     4  039f159b-d60f-35ac-8df0-4e6c2100defb
...                 ...                                   ...
1200818         1200818  4168e105-5a45-3caa-a67a-2d34db664364
1200819         1200819  07979f36-5756-3e1c-8a43-d87b5517815c
1200820         1200820  3fe1ec23-bd3e-3dde-b3d2-759f75b469ef
1200821         1200821  2943023b-a840-301b-9829-29fee52bee13
1200822         1200822  2af337d4-fb15-3a22-ae5e-6c97b347ab21

[1200823 rows x 2 columns]
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 1200823 entries, 0 to 1200822
Data columns (total 22 columns):
 #   Column                  Non-Null Count    Dtype         
---  ------                  --------------    ----- 

Unnamed: 0,cohort_month,recency,frequency,monetary,total_revenue,revenue_trend,avg_transaction_value,first_purchase_date,last_purchase_date,customer_age_days,...,avg_interpurchase_days,has_online_purchases,has_store_purchases,total_discount_amount,avg_discount_amount,discount_rate,sms_active,email_active,is_loyalty_member,loyalty_points
0,2024-02-01,256,1,16.0,16.0,0.0,5.33,2024-02-28,2024-02-28,256,...,0.0,1,0,0.0,0.0,0.0,0,1,0,0
1,2024-02-01,256,1,48.0,48.0,0.0,48.0,2024-02-28,2024-02-28,256,...,0.0,0,1,0.0,0.0,0.0,0,1,1,480
2,2024-02-01,256,1,174.0,174.0,0.0,43.5,2024-02-28,2024-02-28,256,...,0.0,0,1,0.0,0.0,0.0,0,0,0,0
3,2024-02-01,256,1,23.22,23.22,0.0,5.81,2024-02-28,2024-02-28,256,...,0.0,1,0,2.58,0.65,0.1,0,0,1,232
4,2024-02-01,256,1,42.0,42.0,0.0,14.0,2024-02-28,2024-02-28,256,...,0.0,1,0,0.0,0.0,0.0,0,0,1,420


P(frequency, recency | parameters, covariates) = \
      Poisson(freq | exp(r + Xβ_r) * T) × \
      Exponential(rec | exp(α + Xβ_α))

In [12]:
from typing import Dict, Any, Optional, Tuple
import numpy as np
import pandas as pd
import torch
import gc
import psutil
import jax
import os
import sys
from datetime import datetime
import traceback

class HierarchicalCLVSystem:
    """Hierarchical Customer Lifetime Value System"""
  ## Class Attributes
    # Configurable default segment configuration
    default_segment_config = {
        'use_rfm': True,              # RFM segmentation
        'rfm_bins': {
            'frequency': 4,           # Number of bins for frequency
            'recency': 4,             # Number of bins for recency
            'monetary': 4             # Number of bins for monetary
        },

        # Engagement Segmentation
        'use_engagement': True,        # Engagement metrics
        'engagement_bins': 4,          # Number of engagement level bins
        'engagement_metrics': [        # Which metrics to use
            'sms_active',
            'email_active',
            'is_loyalty_member'
        ],

        # Channel Segmentation
        'use_channel': True,           # Channel preference

        # Loyalty Segmentation
        'use_loyalty': True,           # Loyalty segments
        'loyalty_bins': 4,             # Number of loyalty level bins

        # Cohort Segmentation
        'use_cohorts': True,          # Cohort grouping
        'cohort_type': 'quarterly',     # 'monthly' or 'quarterly'

        # Combined Segmentation
        'use_combined': True,         # Whether to create combined segments

        # Segment Size Controls
        'min_segment_size': 200,      # Minimum customers per segment
        'merge_small_segments': True  # Whether to merge small segments into 'other'
    }

    # Core data requirements
    required_columns = {
        'transaction': [
            'frequency', 'recency', 'customer_age_days',
            'monetary', 'avg_transaction_value'
        ],
        'customer': [
            'cohort_month', 'distinct_categories', 'distinct_brands',
            'avg_interpurchase_days', 'has_online_purchases', 'has_store_purchases'
        ],
        'engagement': [
            'sms_active', 'email_active', 'is_loyalty_member', 'loyalty_points'
        ]
    }

    segment_configs = {
        'transaction': {
            'frequency': {
                'n_bins': 4,
                'labels': ['low', 'medium', 'high']
            },
            'recency': {
                'n_bins': 4,
                'labels': ['recent', 'mid', 'old']
            },
            'monetary': {
                'n_bins': 4,
                'labels': ['low', 'medium', 'high']
            }
        },
        'engagement': {
            'metrics': ['sms_active', 'email_active', 'is_loyalty_member'],
            'loyalty_points': {
                'n_bins': 4,
                'labels': ['low', 'medium', 'high']
            }
        }
    }

    # Initialization
    def __init__(self, config: Dict[str, Any], segment_config: Optional[Dict] = None):
        """Initialize system with configuration

        Parameters:
        -----------
        config : Dict[str, Any]
            Configuration dictionary with all parameters
        segment_config : Optional[Dict]
            Optional separate segmentation config (falls back to config['segment_config'])
        """
        self.config = config
        self.segment_config = segment_config or config.get('segment_config', {})
        self.data = None
        self.model = None
        self.trace = None

        # Initialize segmentation bins dictionary
        self.segment_bins = {
            'rfm': {},
            'loyalty': {},
            'engagement': {}
        }

        # Store medians for later use
        self.medians = {}

        # Initialize tracking
        self.gpu_enabled = False
        self.convergence_history = []
        self.training_metrics = {
            'n_divergent': 0,
            'max_rhat': None,
            'min_ess': None
        }

        # Setup GPU if enabled
        if self.config.get('use_gpu', True):
            self._setup_gpu()

        # Setup monitoring if enabled
        if self.config.get('monitor_resources', True):
            self.setup_monitoring()

    def run_analysis(self, processed_df=None, custom_config=None):
        """Run complete CLV analysis pipeline with progress monitoring"""
        try:
            print("\n=== Starting CLV Analysis ===")

            if processed_df is None:
                raise ValueError("No data provided")

            # Store original data
            self.original_data = processed_df.copy()

            # 1. Update configurations if provided
            if custom_config:
                print("\n1. Updating Configuration...")
                self.config = self.config_manager.setup_config(custom_config)

            # 2. Pre-process data and create Segments
            try:
                print("\n2.1 Processing Data...")
                # Print initial data info
                print(f"Initial shape: {processed_df.shape}")
                print("Columns:", processed_df.columns.tolist())

                # Verify and preprocess
                self._verify_model_columns(processed_df)
                processed_data = self._preprocess_data(processed_df)
                if processed_data is None:
                    raise ValueError("Data preprocessing failed")

                print("\n2.2 Creating Segments...")
                segmented_data = self.create_segments(processed_data)
                if segmented_data is None:
                    raise ValueError("Segmentation failed")

                # Print segmentation summary
                print("\nSegmentation Summary:")
                for col in ['rfm_segment', 'engagement_level', 'loyalty_segment', 'cohort_segment']:
                    if col in segmented_data.columns:
                        print(f"\n{col} distribution:")
                        print(segmented_data[col].value_counts().head())

                # Sample N records from the segmented data
                sample_size = self.config.get('sample_size', 25000)
                sample_size = min(sample_size, len(segmented_data))
                sampled_data = segmented_data.sample(n=sample_size, random_state=42)
                sampled_data_shape = sampled_data.shape
                num_columns = min(10, sampled_data.shape[1])
                columns_to_sort = sampled_data.columns[:num_columns]
                sorted_data = sampled_data.sort_values(by=list(columns_to_sort))
                self.data = sorted_data

                # Save the sorted DataFrame to a CSV file (Optional)
                file_path = '/content/sample_sorted_data.csv'
                sorted_data.to_csv(file_path, index=True)

                print(f"\nSampled {sample_size} records from segmented data")

            except Exception as e:
                print(f"Data processing and segmentation error: {str(e)}")
                raise

            # 3. Batch Data Processing
            try:
                # Process in batches
                print("\n3. Processing Batches...")
                self.data = self.process_batches(self.data)
                if self.data is None:
                    raise ValueError("Batch processing failed")

                # Print data quality metrics
                print("\nProcessed Data Summary:")
                print(f"Final shape: {self.data.shape}")
                print("\nMissing values:")
                print(self.data.isnull().sum()[self.data.isnull().sum() > 0])

            except Exception as e:
                print(f"Batch processing error: {str(e)}")
                raise

            # 4. Build and Train Model
            try:
                print("\n5. Building Model...")
                self.build_model()
                if self.model is None:
                    raise ValueError("Model building failed")

                print("\n6. Training Model...")
                # Setup GPU if enabled
                if self.config.get('use_gpu', False):
                    print("\nOptimizing GPU resources...")
                    self._setup_gpu()
                    self.optimize_gpu_memory()

                # Train model
                self.train_model()
                if self.trace is None:
                    raise ValueError("Model training failed")

                # Check convergence
                convergence_ok = self._check_convergence()
                if not convergence_ok:
                    print("\nWarning: Model may not have converged properly")
                    print("Consider adjusting parameters or reducing model complexity")

            except Exception as e:
                print(f"Model building/training error: {str(e)}")
                raise

            # 5. Generate Results
            try:
                print("\n7. Generating Results...")
                results = self._generate_results()
                if results is None:
                    raise ValueError("Results generation failed")

                # Print key metrics
                print("\nKey Model Metrics:")
                if 'metrics' in results:
                    for metric, value in results['metrics'].items():
                        if isinstance(value, dict):  # Handle nested dictionaries
                            print(f"\n{metric}:")
                            for k, v in value.items():
                                if isinstance(v, (int, float)):
                                    print(f"  {k}: {v:.4f}")  # Format as float
                                else:
                                    print(f"  {k}: {v}")  # Print as string
                        else:
                            if isinstance(value, (int, float)):
                                print(f"\n{metric}: {value:.4f}")  # Format as float
                            else:
                                print(f"\n{metric}: {value}")  # Print as string

            except Exception as e:
                print(f"Results generation error: {str(e)}")
                raise

            # 6. Save Model
            try:
                print("\n8. Saving Model...")
                # Check model status before saving
                is_ready, message = self.check_model_status(self)
                if is_ready:
                    saved_path = self.save_trained_model(prefix='clv_model')
                    print(f"Model saved to: {saved_path}")
                else:
                    print(f"Warning: Model not saved - {message}")

            except Exception as e:
                print(f"Model saving error: {str(e)}")
                print("Continuing without saving...")

            print("\n=== Analysis Complete ===")
            return self, results

        except Exception as e:
            print(f"\nAnalysis pipeline error: {str(e)}")
            import traceback
            print("\nDetailed error:")
            print(traceback.format_exc())
            return None, None

    # Check model status before trying to save:
    @staticmethod
    def check_model_status(model):
        """Check if model is properly trained and ready to save"""
        if model is None:
            return False, "Model is None"

        if not hasattr(model, 'trace') or model.trace is None:
            return False, "No trace available - model may not have trained successfully"

        if not hasattr(model, 'data') or model.data is None:
            return False, "No data available - model may not have processed data"

        return True, "Model ready to save"

    # 2. Data processing methods

    def _preprocess_data(self, processed_df):
        """Preprocess data with proper type conversion"""
        try:

            df = processed_df.copy()

            # Convert data types
            numeric_cols = (
                self.required_columns['transaction'] +
                ['distinct_categories', 'distinct_brands', 'loyalty_points']
            )

            bool_cols = (
                self.required_columns['engagement'][:-1] +  # Exclude loyalty_points
                ['has_online_purchases', 'has_store_purchases']
            )

            # Convert numeric columns
            for col in numeric_cols:
                if col in df.columns:
                    df[col] = pd.to_numeric(df[col], errors='coerce')

                    # Integer conversion for count columns
                    if col in ['frequency', 'distinct_categories', 'distinct_brands', 'loyalty_points']:
                        df[col] = df[col].fillna(0).astype('int32')
                    else:
                        df[col] = df[col].fillna(0).astype('float32')

            # Convert boolean columns
            for col in bool_cols:
                if col in df.columns:
                    df[col] = df[col].astype('int32')

            # Convert dates
            df['cohort_month'] = pd.to_datetime(df['cohort_month'])

            return df

        except Exception as e:
            print(f"Error preprocessing data: {str(e)}")
            raise

    def _verify_model_columns(self, processed_df):
        """Verify all required columns exist in DataFrame"""
        missing_cols = []
        for category, cols in self.required_columns.items():
            missing = [col for col in cols if col not in processed_df.columns]
            if missing:
                missing_cols.extend(missing)

        if missing_cols:
            raise ValueError(f"Missing required columns: {missing_cols}")

    def create_segments(self, df):
        """Create comprehensive customer segments including engagement"""
        try:
            df = df.copy()
            print(f"Creating segments for {len(df):,} customers")
            segment_components = []

            # 1. Create RFM Segments
            if self.segment_config['use_rfm']:
                df = self._create_rfm_segments(df)
                segment_components.append('rfm_segment')
                print(f"\nCreated {df['rfm_segment'].nunique():,} RFM segments")

            # 2. Create Engagement Segments
            if self.segment_config.get('use_engagement', True):
                df = self._create_engagement_segments(df)
                segment_components.append('engagement_level')
                print(f"\nCreated {df['engagement_level'].nunique():,} engagement segments")

            # 3. Create Loyalty Segments
            if self.segment_config.get('use_loyalty', True):
                df = self._create_loyalty_segments(df)
                segment_components.append('loyalty_segment')
                print(f"\nCreated {df['loyalty_segment'].nunique():,} loyalty segments")

            # 4. Create Channel Segments
            if self.segment_config.get('use_channel', True):
                df = self._determine_channel(df)
                segment_components.append('channel_segment')
                print(f"\nCreated {df['channel_segment'].nunique():,} channel segments")

            # 5. Create Combined Segments
            use_combined = self.segment_config.get('use_combined', True)
            if use_combined and len(segment_components) > 1:
                df['customer_segment'] = df[segment_components].apply(
                    lambda x: '_'.join(x.astype(str)), axis=1
                )

                # Handle small segments if enabled
                if self.segment_config.get('merge_small_segments', True):
                    segment_counts = df['customer_segment'].value_counts()
                    small_segments = segment_counts[
                        segment_counts < self.segment_config['min_segment_size']
                    ].index
                    if len(small_segments) > 0:
                        df.loc[df['customer_segment'].isin(small_segments), 'customer_segment'] = 'other'

            # 6. Create Cohort Groups
            if self.segment_config.get('use_cohorts', True):
                df = self._create_cohort_groups(df)
                print(f"\nCreated {df['cohort_segment'].nunique():,} cohort groups")

            # 7. Create final group indices
            grouping_column = 'cohort_segment' if self.segment_config.get('use_cohorts', True) else 'customer_segment'
            unique_groups = df[grouping_column].unique()
            group_mapping = {group: idx for idx, group in enumerate(unique_groups)}
            df['group_idx'] = df[grouping_column].map(group_mapping)

            # Store grouping info
            self.groups = unique_groups
            self.coords = {
                "group_idx": np.arange(len(unique_groups)),
                "group": unique_groups
            }

            print(f"\nCreated {len(unique_groups):,} total groups with indices")

            return df  # Return the segmented DataFrame

        except Exception as e:
            print(f"Segmentation error: {str(e)}")
            raise

    def _create_rfm_segments(self, df):
        """Create RFM segments with configurable bins"""
        try:
            self.segment_bins['rfm'] = {}

            # Create segments for each RFM component
            for metric in ['frequency', 'recency', 'monetary']:
                n_bins = self.segment_config['rfm_bins'][metric]

                # Get the data range for the current metric
                metric_min = df[metric].min()
                metric_max = df[metric].max()

                # Calculate the bin edges
                bin_edges = np.linspace(metric_min, metric_max, n_bins + 1)

                # Create labels based on the number of bins
                labels = [f"{metric.capitalize()} {i+1}" for i in range(n_bins)]

                # Create segments using pd.cut()
                segment_col = f"{metric[0].lower()}_segment"
                df[segment_col] = pd.cut(
                    df[metric],
                    bins=bin_edges,
                    labels=labels,
                    include_lowest=True
                )

                print(f"\n{segment_col} distribution:")
                print(df[segment_col].value_counts().sort_index())

            # Combine RFM segments
            df['rfm_segment'] = (
                df['r_segment'].astype(str) + '_' +
                df['f_segment'].astype(str) + '_' +
                df['m_segment'].astype(str)
            )

            return df

        except Exception as e:
            print(f"RFM segmentation error: {str(e)}")
            raise

    def _create_engagement_segments(self, df):
        """Create engagement segments based on customer behavior"""
        try:
            # 1. Binary engagement metrics
            engagement_score = 0
            metrics = self.segment_config['engagement_metrics']

            for metric in metrics:
                if metric in df.columns:
                    engagement_score += df[metric]

            # Create engagement level with proper bin handling
            try:
                print("\nEngagement score summary:")
                print(engagement_score.describe())

                # Create explicit bins based on the data
                n_bins = self.segment_config['engagement_bins']
                labels = ['low', 'high'] if n_bins == 2 else ['low', 'medium', 'high']

                df['engagement_level'] = pd.qcut(
                    engagement_score,
                    q=n_bins,
                    labels=labels,
                    duplicates='drop'
                )

                print("\nEngagement level distribution:")
                print(df['engagement_level'].value_counts().sort_index())

            except Exception as e:
                print(f"Error creating engagement levels: {str(e)}")
                # Fallback to binary segmentation
                median = engagement_score.median()
                df['engagement_level'] = np.where(
                    engagement_score > median,
                    'high', 'low'
                )
                print("Fell back to binary engagement segmentation")

            # Handle small segments if enabled
            if self.segment_config['merge_small_segments']:
                segment_counts = df['engagement_level'].value_counts()
                small_segments = segment_counts[
                    segment_counts < self.segment_config['min_segment_size']
                ].index
                if len(small_segments) > 0:
                    df.loc[df['engagement_level'].isin(small_segments), 'engagement_level'] = 'other'

            return df

        except Exception as e:
            print(f"Engagement segmentation error: {str(e)}")
            raise

    def _create_loyalty_segments(self, df):
        """Create loyalty segments with configurable bins"""
        try:
            if 'loyalty_points' in df.columns:
                try:
                    n_bins = self.segment_config['loyalty_bins']
                    labels = ['low', 'high'] if n_bins == 2 else ['low', 'medium', 'high']

                    # Calculate loyalty points bins
                    loyalty_stats = df['loyalty_points'].describe()
                    print("\nLoyalty points summary:")
                    print(loyalty_stats)

                    # Create segments using qcut for even distribution
                    df['loyalty_segment'] = pd.qcut(
                        df['loyalty_points'],
                        q=n_bins,
                        labels=labels,
                        duplicates='drop'
                    )

                    print("\nLoyalty segment distribution:")
                    print(df['loyalty_segment'].value_counts().sort_index())

                    # Handle small segments if enabled
                    if self.segment_config['merge_small_segments']:
                        segment_counts = df['loyalty_segment'].value_counts()
                        small_segments = segment_counts[
                            segment_counts < self.segment_config['min_segment_size']
                        ].index
                        if len(small_segments) > 0:
                            df.loc[df['loyalty_segment'].isin(small_segments), 'loyalty_segment'] = 'other'

                except Exception as e:
                    print(f"Error creating loyalty segments: {str(e)}")
                    # Fallback to binary segmentation
                    median = df['loyalty_points'].median()
                    df['loyalty_segment'] = np.where(
                        df['loyalty_points'] > median,
                        'high', 'low'
                    )
                    print("Fell back to binary loyalty segmentation")

            return df

        except Exception as e:
            print(f"Loyalty segmentation error: {str(e)}")
            raise

    def _create_cohort_groups(self, df):
        """Create cohort groups for hierarchical modeling"""
        try:
            # Extract cohort features
            df['cohort_year'] = df['cohort_month'].dt.year

            if self.segment_config['cohort_type'] == 'quarterly':
                df['cohort_period'] = df['cohort_month'].dt.quarter
                period_name = 'quarter'
            else:  # monthly
                df['cohort_period'] = df['cohort_month'].dt.month
                period_name = 'month'

            # Create cohort-segment combination
            if 'customer_segment' in df.columns:
                df['cohort_segment'] = (
                    df['cohort_year'].astype(str) + '_' +
                    df['cohort_period'].astype(str) + '_' +
                    df['customer_segment']
                )
            else:
                df['cohort_segment'] = (
                    df['cohort_year'].astype(str) + '_' +
                    df['cohort_period'].astype(str)
                )

            # Handle small cohorts if enabled
            if self.segment_config['merge_small_segments']:
                segment_counts = df['cohort_segment'].value_counts()
                small_segments = segment_counts[
                    segment_counts < self.segment_config['min_segment_size']
                ].index

                if len(small_segments) > 0:
                    # For cohorts, merge into nearest time period instead of 'other'
                    for small_seg in small_segments:
                        year, period, *rest = small_seg.split('_')
                        # Find closest cohort in time
                        nearby_cohorts = [
                            seg for seg in segment_counts.index
                            if seg not in small_segments and
                            seg.startswith(f"{year}_")
                        ]
                        if nearby_cohorts:
                            closest_cohort = min(nearby_cohorts, key=lambda x: abs(
                                int(x.split('_')[1]) - int(period)
                            ))
                            df.loc[df['cohort_segment'] == small_seg, 'cohort_segment'] = closest_cohort

            print(f"\nCreated {df['cohort_segment'].nunique():,} cohort groups")
            print(f"\nCohort distribution by {period_name}:")
            print(df.groupby(['cohort_year', 'cohort_period']).size().unstack())

            # Store cohort information
            self.cohort_info = {
                'type': self.segment_config['cohort_type'],
                'years': sorted(df['cohort_year'].unique()),
                'periods': sorted(df['cohort_period'].unique())
            }

            return df

        except Exception as e:
            print(f"Cohort grouping error: {str(e)}")
            raise

    def _determine_channel(self, df):
        """Determine customer channel preference"""
        df['has_online_purchases'] = df['has_online_purchases'].astype(bool)
        df['has_store_purchases'] = df['has_store_purchases'].astype(bool)

        df['channel_segment'] = df.apply(
            lambda row: self._determine_channel(row['has_online_purchases'], row['has_store_purchases']),
            axis=1
        )

        return df

    def _determine_channel(self, online, store):
        """Determine customer channel preference"""
        if online and store:
            return 'Multichannel'
        elif online:
            return 'Online'
        elif store:
            return 'Store'
        return 'Unknown'

    def process_batches(self, df):
        """Process data in batches with proper type handling"""
        try:
            # First ensure correct dtypes
            numeric_cols = [
                'recency', 'frequency', 'monetary',
                'total_revenue', 'revenue_trend', 'avg_transaction_value',
                'customer_age_days', 'distinct_categories', 'distinct_brands',
                'avg_interpurchase_days', 'total_discount_amount',
                'avg_discount_amount', 'discount_rate'
            ]

            bool_cols = [
                'has_online_purchases', 'has_store_purchases',
                'sms_active', 'email_active', 'is_loyalty_member'
            ]

            int_cols = ['frequency', 'loyalty_points', 'group_idx']

            # Convert dtypes
            for col in numeric_cols:
                if col in df.columns:
                    df[col] = pd.to_numeric(df[col], errors='coerce')
                    df[col] = df[col].astype('float32')

            for col in bool_cols:
                if col in df.columns:
                    df[col] = df[col].astype('int32')

            for col in int_cols:
                if col in df.columns:
                    df[col] = df[col].astype('int32')

            # Process in batches
            batch_size = self.config['batch_size']
            n_batches = len(df) // batch_size + (1 if len(df) % batch_size != 0 else 0)

            print(f"\nProcessing {len(df):,} records in {n_batches:,} batches")

            processed_batches = []
            for i in range(n_batches):
                start_idx = i * batch_size
                end_idx = min((i + 1) * batch_size, len(df))

                print(f"Processing batch {i+1}/{n_batches}")
                batch = df.iloc[start_idx:end_idx].copy()

                # Process batch
                processed_batch = self.process_batch(batch)
                if processed_batch is not None:
                    processed_batches.append(processed_batch)

                # Clear batch from memory
                del batch
                gc.collect()

            # Concatenate processed batches
            result_df = pd.concat(processed_batches, ignore_index=True)

            # Clear processed batches from memory
            del processed_batches
            gc.collect()

            # Final type check
            print("\nChecking final data types...")
            for col in result_df.columns:
                print(f"{col}: {result_df[col].dtype}")

            return result_df

        except Exception as e:
            print(f"Batch processing error: {str(e)}")
            raise

    def process_batch(self, batch):
        """Process a single batch with outlier handling and cleaning"""
        try:
            # Basic cleaning
            batch = batch[batch['monetary'] > 0]

            # Handle outliers
            for col in ['monetary', 'frequency', 'avg_transaction_value']:
                if col in batch.columns:
                    Q1 = batch[col].quantile(0.25)
                    Q3 = batch[col].quantile(0.75)
                    IQR = Q3 - Q1

                    lower_bound = Q1 - 3 * IQR
                    upper_bound = Q3 + 3 * IQR

                    # Cap values instead of removing
                    batch[col] = batch[col].clip(lower=lower_bound, upper=upper_bound)

            # Additional cleaning steps
            if 'avg_interpurchase_days' in batch.columns:
                batch['avg_interpurchase_days'] = batch['avg_interpurchase_days'].clip(lower=0)

            if 'recency' in batch.columns:
                batch['recency'] = batch['recency'].clip(lower=0)

            if 'customer_age_days' in batch.columns:
                batch['customer_age_days'] = batch['customer_age_days'].clip(lower=1)

            # Ensure proper relationships
            batch = batch[
                (batch['customer_age_days'] >= batch['recency']) &
                (batch['monetary'] >= 0) &
                (batch['frequency'] >= 1)
            ]

            return batch

        except Exception as e:
            print(f"Batch processing error: {str(e)}")
            return None

    # 4. Model Building Methods

    def build_model(self):
        """Build hierarchical BG/NBD model with covariates"""
        try:
            if self.data is None:
                raise ValueError("No data available. Run data processing first.")

            print("\nBuilding model...")
            print("Preparing model data...")

            # 1. Prepare model data
            self.model_data = self._prepare_model_data()

            # 2. Build PyMC model
            print("Building PyMC model...")
            with pm.Model() as self.model:
                # Add priors
                self._add_hierarchical_priors()

                # Add covariate effects
                self._add_covariate_effects()

                # Add likelihood
                x_tensor = pt.as_tensor_variable(self.model_data['frequency'])
                t_x_tensor = pt.as_tensor_variable(self.model_data['recency'])
                T_tensor = pt.as_tensor_variable(self.model_data['T'])
                x_zero_tensor = pt.as_tensor_variable(self.model_data['frequency'] == 0)

                # Store log likelihood for use during training
                self.model.log_likelihood = self.compute_log_likelihood(
                    x_tensor, t_x_tensor, T_tensor, x_zero_tensor)

            print("Model built successfully")
            return self

        except Exception as e:
            print(f"Model building error: {str(e)}")
            raise

    def _prepare_model_data(self):
        """Prepare data for PyMC model with covariates"""
        try:
            if self.data is None:
                raise ValueError("No data available. Run data processing first.")

            if 'group_idx' not in self.data.columns:
                raise ValueError("group_idx column missing. Run create_segments first.")

            # Convert main variables to float32
            model_data = {
                'frequency': self.data['frequency'].values.astype('float32'),
                'recency': self.data['recency'].values.astype('float32'),
                'T': self.data['customer_age_days'].values.astype('float32'),
                'monetary': self.data['monetary'].values.astype('float32'),
                'avg_transaction': self.data['avg_transaction_value'].values.astype('float32'),
                'group_idx': self.data['group_idx'].values.astype('int32')
            }

            # Prepare engagement covariates
            engagement_cols = [
                'sms_active', 'email_active', 'is_loyalty_member',
                'loyalty_points'
            ]

            # Standardize numeric covariates
            numeric_covs = ['loyalty_points', 'avg_interpurchase_days',
                          'distinct_categories', 'distinct_brands']

            self.scalers = {}  # Store scalers for later use/inverse transform
            for col in numeric_covs:
                if col in self.data.columns:
                    scaler = StandardScaler()
                    # Reshape to 2D array for sklearn (-1 means infer length)
                    reshaped_data = self.data[col].values.reshape(-1, 1)
                    # Fit scaler and transform data
                    # This standardizes the data to mean=0, std=1
                    scaled_data = scaler.fit_transform(reshaped_data)
                    # Convert to float32 and flatten back to 1D
                    model_data[f'{col}_scaled'] = scaled_data.astype('float32').flatten()
                    # Store scaler for this column
                    self.scalers[col] = scaler

                    if self.config.get('VERBOSE', False):
                        print(f"\nScaling {col}:")
                        print(f"Mean: {scaler.mean_[0]:.2f}")
                        print(f"Std: {scaler.scale_[0]:.2f}")

            # Binary covariates
            binary_covs = ['has_online_purchases', 'has_store_purchases',
                          'sms_active', 'email_active', 'is_loyalty_member']

            for col in binary_covs:
                if col in self.data.columns:
                    model_data[col] = self.data[col].values.astype('float32')

            print("\nPrepared model data:")
            for key, value in model_data.items():
                print(f"{key}: shape={value.shape}, dtype={value.dtype}")

            return model_data

        except Exception as e:
            print(f"Error preparing model data: {str(e)}")
            raise

    def _add_hierarchical_priors(self):
        """Add hierarchical priors to PyMC model"""
        try:
            n_groups = len(self.groups)

            # Group-level hyperpriors
            self.hyper_priors = {
                'r': {
                    'alpha': pm.Gamma('r_alpha', alpha=2, beta=1),
                    'beta': pm.Gamma('r_beta', alpha=2, beta=1)
                },
                'alpha': {
                    'mu': pm.Gamma('alpha_mu', alpha=2, beta=1),
                    'sigma': pm.HalfNormal('alpha_sigma', sigma=1)
                },
                's': {
                    'alpha': pm.Gamma('s_alpha', alpha=2, beta=1),
                    'beta': pm.Gamma('s_beta', alpha=2, beta=1)
                },
                'beta': {
                    'mu': pm.Gamma('beta_mu', alpha=2, beta=1),
                    'sigma': pm.HalfNormal('beta_sigma', sigma=1)
                }
            }

            # Group-level parameters
            self.group_params = {
                'r': pm.Gamma('r',
                    alpha=self.hyper_priors['r']['alpha'],
                    beta=self.hyper_priors['r']['beta'],
                    shape=n_groups
                ),
                'alpha': pm.Gamma('alpha',
                    mu=self.hyper_priors['alpha']['mu'],
                    sigma=self.hyper_priors['alpha']['sigma'],
                    shape=n_groups
                ),
                's': pm.Gamma('s',
                    alpha=self.hyper_priors['s']['alpha'],
                    beta=self.hyper_priors['s']['beta'],
                    shape=n_groups
                ),
                'beta': pm.Gamma('beta',
                    mu=self.hyper_priors['beta']['mu'],
                    sigma=self.hyper_priors['beta']['sigma'],
                    shape=n_groups
                )
            }

            print("\nAdded hierarchical priors")

        except Exception as e:
            print(f"Error adding hierarchical priors: {str(e)}")
            raise

    def _add_covariate_effects(self):
        """Add covariate effects to the model"""
        try:
            # Covariate coefficients for each parameter
            self.coef_priors = {}

            # Define covariate groups
            covariate_groups = {
                'transaction': ['avg_interpurchase_days_scaled'],
                'customer': ['distinct_categories_scaled', 'distinct_brands_scaled'],
                'channel': ['has_online_purchases', 'has_store_purchases'],
                'engagement': ['sms_active', 'email_active', 'is_loyalty_member',
                              'loyalty_points_scaled']
            }

            # Add coefficient priors for each parameter and covariate group
            for param in ['r', 'alpha']:
                self.coef_priors[param] = {}

                for group, covariates in covariate_groups.items():
                    for cov in covariates:
                        if cov in self.model_data:
                            coef_name = f"{param}_{cov}_coef"
                            self.coef_priors[param][cov] = pm.Normal(
                                coef_name,
                                mu=0,
                                sigma=1
                            )

            print("\nAdded covariate effects")

        except Exception as e:
            print(f"Error adding covariate effects: {str(e)}")
            raise
    def compute_log_likelihood(self, x, t_x, T, x_zero):
        """Compute log-likelihood for the hierarchical BG/NBD model"""
        try:
            import pytensor.tensor as pt

            # Convert inputs to tensors
            x = pt.as_tensor_variable(x)
            t_x = pt.as_tensor_variable(t_x)
            T = pt.as_tensor_variable(T)
            x_zero = pt.as_tensor_variable(x_zero)

            # Get group indices
            group_idx = self.model_data['group_idx']

            # Create distributions
            purchase_dist = pm.NegativeBinomial.dist(
                mu=pt.exp(self.group_params['r'][group_idx]),
                alpha=pt.exp(self.group_params['alpha'][group_idx])
            )
            dropout_dist = pm.Gamma.dist(
                alpha=pt.exp(self.group_params['s'][group_idx]),
                beta=pt.exp(self.group_params['beta'][group_idx])
            )

            # Calculate likelihoods using correct pm.logp syntax
            purchase_likelihood = pm.logp(rv=purchase_dist, value=x)
            dropout_likelihood = pm.logp(rv=dropout_dist, value=T - t_x)

            # Calculate alive probability
            p_alive = pt.exp(-pt.exp(self.group_params['alpha'][group_idx]) * T)

            # Combine components
            log_likelihood = (
                purchase_likelihood +
                dropout_likelihood * (1 - x_zero) +
                pt.log(p_alive + (1 - p_alive) * x_zero)
            )

            return log_likelihood

        except Exception as e:
            print(f"Error computing log-likelihood: {str(e)}")
            raise

    def _calculate_covariate_effects(self, param, group_idx):
        """Calculate combined covariate effects for a parameter"""
        try:
            # Start with group-level parameter
            effects = pt.log(self.group_params[param][group_idx])

            # Add covariate effects
            for cov, coef in self.coef_priors[param].items():
                effects = effects + coef * self.model_data[cov]

            return effects

        except Exception as e:
            print(f"Error calculating covariate effects: {str(e)}")
            raise

    # 5. Training and Convergence Methods

    def train_model(self):
        """Train model with GPU memory optimization and log likelihood computation"""
        try:
            if self.model is None:
                raise ValueError("No model built. Call build_model first.")

            # Optimize GPU memory before training
            if self.gpu_enabled:
                self.optimize_gpu_memory()

            print("\nTraining model...")
            with self.model:
                # Setup NUTS sampler with optimized parameters
                step = pm.NUTS(
                    target_accept=self.config['target_accept'],
                    max_treedepth=self.config['max_treedepth']
                )

                # Use gradient accumulation if enabled
                if self.config.get('gradient_accumulation', 0) > 1:
                    draws_per_step = self.config['mcmc_samples'] // self.config['gradient_accumulation']
                else:
                    draws_per_step = self.config['mcmc_samples']

                # Run sampling with memory monitoring
                try:
                    for i in range(self.config.get('gradient_accumulation', 1)):
                        if i > 0:
                            print(f"\nGradient accumulation step {i+1}")

                        # Monitor GPU memory before sampling
                        if self.gpu_enabled:
                            self._monitor_gpu()

                        self.trace = pm.sample(
                            draws=draws_per_step,
                            tune=self.config['mcmc_tune'],
                            chains=self.config['chains'],
                            cores=self.config['cores'],
                            random_seed=self.config['random_seed'],
                            step=step,
                            return_inferencedata=True,
                            progressbar=True,
                            compute_convergence_checks=True,
                            discard_tuned_samples=False,
                            mp_ctx='spawn',  # Keep stable multiprocessing
                            idata_kwargs={"log_likelihood": True}  # Enable log likelihood computation
                        )

                        # If log-likelihood not computed during sampling, compute it explicitly
                        if not hasattr(self.trace, 'log_likelihood'):
                            print("\nComputing log-likelihood...")
                            try:
                                with self.model:
                                    log_like = pm.compute_log_likelihood(
                                        self.trace,
                                        var_names=['frequency', 'recency', 'T']
                                    )
                                    self.trace.add_groups({'log_likelihood': log_like})
                            except Exception as ll_error:
                                print(f"Warning: Could not compute log-likelihood: {str(ll_error)}")
                                # Fallback to manual computation
                                log_like = self.compute_log_likelihood(
                                    self.model_data['frequency'],
                                    self.model_data['recency'],
                                    self.model_data['T'],
                                    self.model_data['frequency'] == 0
                                )
                                self.trace.add_groups({'log_likelihood': log_like})

                        # Check convergence metrics
                        converged = self._check_convergence()

                        # Clear memory between accumulation steps
                        if i < self.config.get('gradient_accumulation', 1) - 1:
                            self._clear_memory()
                            if self.gpu_enabled:
                                self._monitor_gpu()

                        # Early stopping if convergence is poor
                        if not converged:
                            print("\nWarning: Poor convergence detected")
                            if i > 0:  # Only warn if not first iteration
                                print("Consider reducing model complexity or increasing tuning steps")

                        # Monitor resources
                        if self.config.get('monitor_resources', True):
                            self._monitor_resources()

                except Exception as e:
                    print(f"\nSampling error: {str(e)}")
                    if hasattr(self, 'trace') and self.trace is not None:
                        print("\nPartial trace available - attempting to salvage results")
                        self._check_convergence()  # Check what we have
                    else:
                        raise

                finally:
                    # Final convergence check
                    if hasattr(self, 'trace') and self.trace is not None:
                        print("\nFinal convergence check:")
                        self._check_convergence()

                    # Clear memory
                    self._clear_memory()

                return self

        except Exception as e:
            print(f"\nTraining error: {str(e)}")
            return self  # Return self even on error to allow saving partial results

    def optimize_gpu_memory(self):
        """Optimize GPU memory usage"""
        try:
            import torch
            import gc

            if not self.gpu_enabled:
                return False

            # 1. Memory Management Functions
            def clear_gpu_memory():
                """Clear unused memory"""
                gc.collect()
                torch.cuda.empty_cache()
                if hasattr(torch.cuda, 'memory_summary'):
                    print("\nGPU Memory Summary:")
                    print(torch.cuda.memory_summary())

            def get_gpu_memory_usage():
                """Get current memory usage"""
                allocated = torch.cuda.memory_allocated() / 1024**2
                reserved = torch.cuda.memory_reserved() / 1024**2
                return allocated, reserved

            # 2. Set GPU Memory Optimization Settings
            torch.backends.cudnn.benchmark = True  # Optimize for fixed input sizes
            if hasattr(torch, 'backends') and hasattr(torch.backends, 'cuda') and \
              hasattr(torch.backends.cuda, 'matmul') and \
              hasattr(torch.backends.cuda.matmul, 'allow_tf32'):
                torch.backends.cuda.matmul.allow_tf32 = True  # Use TF32 for better performance

            # 3. Implement Memory-Efficient Settings in Config
            self.config.update({
                'gradient_accumulation': 4,  # Reduce memory by accumulating gradients
                'mixed_precision': True,     # Use mixed precision training
                'memory_efficient': True     # Enable memory-efficient options
            })

            # 4. Print Initial Memory State
            allocated, reserved = get_gpu_memory_usage()
            print("\nInitial GPU Memory State:")
            print(f"Allocated: {allocated:.1f} MB")
            print(f"Reserved:  {reserved:.1f} MB")

            # 5. Clear Memory
            clear_gpu_memory()

            # 6. Optimize Batch Size Based on GPU Memory
            try:
                total_memory = torch.cuda.get_device_properties(0).total_memory / 1024**2
                available_memory = total_memory * 0.85  # Use 85% of total memory

                # Estimate memory per sample (adjust these based on your model)
                estimated_memory_per_sample = 0.5  # MB per sample
                optimal_batch_size = int(available_memory / estimated_memory_per_sample)

                # Cap batch size at reasonable limits
                optimal_batch_size = min(max(1000, optimal_batch_size), 10000)

                self.config['batch_size'] = optimal_batch_size
                print(f"\nOptimized batch size: {optimal_batch_size}")

            except Exception as e:
                print(f"Error optimizing batch size: {str(e)}")

            return True

        except Exception as e:
            print(f"GPU memory optimization error: {str(e)}")
            return False

    def optimize_resources(self):
        """Automatically optimize resources based on available memory"""
        try:
            import psutil
            import torch
            import gc

            # 1. System Memory Analysis
            system_memory = psutil.virtual_memory()
            total_ram_gb = system_memory.total / (1024**3)
            available_ram_gb = system_memory.available / (1024**3)

            print("\nSystem Memory Analysis:")
            print(f"Total RAM: {total_ram_gb:.1f} GB")
            print(f"Available RAM: {available_ram_gb:.1f} GB")

            # 2. GPU Memory Analysis
            if self.gpu_enabled:
                gpu_memory = {}
                for i in range(torch.cuda.device_count()):
                    total = torch.cuda.get_device_properties(i).total_memory / (1024**3)
                    reserved = torch.cuda.memory_reserved(i) / (1024**3)
                    allocated = torch.cuda.memory_allocated(i) / (1024**3)
                    available = total - allocated

                    gpu_memory[i] = {
                        'total': total,
                        'available': available,
                        'reserved': reserved,
                        'allocated': allocated
                    }

                    print(f"\nGPU {i} Memory:")
                    print(f"Total: {total:.1f} GB")
                    print(f"Available: {available:.1f} GB")

            # 3. Calculate Optimal Parameters
            def calculate_optimal_batch_size():
                """Calculate optimal batch size based on available memory"""
                if self.gpu_enabled:
                    # Use minimum available GPU memory across all GPUs
                    min_available_gpu = min(gpu['available'] for gpu in gpu_memory.values())
                    memory_for_batches = min_available_gpu * 0.8  # Use 80% of available GPU memory
                else:
                    memory_for_batches = available_ram_gb * 0.8

                # Estimate memory per sample based on data size
                sample_size_mb = (
                    self.data.memory_usage().sum() / len(self.data) / 1024**2
                ) if self.data is not None else 0.5

                optimal_batch_size = int(
                    (memory_for_batches * 1024) / sample_size_mb
                )

                # Apply reasonable bounds
                return min(max(1000, optimal_batch_size), 10000)

            def calculate_optimal_chains():
                """Calculate optimal number of chains based on CPU cores and memory"""
                cpu_count = psutil.cpu_count(logical=False)
                memory_based_chains = int(available_ram_gb / 20)  # Assume 20GB per chain

                if self.gpu_enabled:
                    gpu_based_chains = sum(
                        int(mem['available'] / 20) for mem in gpu_memory.values()
                    )
                    return min(cpu_count, memory_based_chains, gpu_based_chains, 4)

                return min(cpu_count, memory_based_chains, 4)

            # 4. Update Configuration
            optimal_batch_size = calculate_optimal_batch_size()
            optimal_chains = calculate_optimal_chains()

            memory_optimized_config = {
                'batch_size': optimal_batch_size,
                'chains': optimal_chains,
                'cores': optimal_chains,  # Match cores to chains
                'gradient_accumulation': max(1, int(4096 / optimal_batch_size)),
                'mixed_precision': True if self.gpu_enabled else False,

                # Memory thresholds
                'memory_warn_threshold': 0.85,
                'memory_critical_threshold': 0.95,

                # Automatic cleanup triggers
                'auto_cleanup_threshold': 0.9,
                'cleanup_interval': 100  # Clean every N batches
            }

            # 5. Update instance configuration
            self.config.update(memory_optimized_config)

            # 6. Print Optimization Summary
            print("\nResource Optimization Summary:")
            print(f"Optimal Batch Size: {optimal_batch_size}")
            print(f"Optimal Chains: {optimal_chains}")
            print(f"Gradient Accumulation Steps: {memory_optimized_config['gradient_accumulation']}")

            # 7. Setup Memory Monitoring
            def setup_memory_monitoring():
                """Setup automatic memory monitoring and cleanup"""
                def memory_monitor():
                    while True:
                        current_usage = (
                            torch.cuda.memory_allocated() / torch.cuda.max_memory_allocated()
                            if self.gpu_enabled else
                            psutil.virtual_memory().percent / 100
                        )

                        if current_usage > self.config['memory_critical_threshold']:
                            print("\nCRITICAL: Memory usage too high! Forcing cleanup...")
                            self._clear_memory()
                        elif current_usage > self.config['memory_warn_threshold']:
                            print("\nWARNING: High memory usage detected")

                        time.sleep(5)  # Check every 5 seconds

                import threading
                monitor_thread = threading.Thread(
                    target=memory_monitor,
                    daemon=True
                )
                monitor_thread.start()

            setup_memory_monitoring()
            return True

        except Exception as e:
            print(f"Resource optimization error: {str(e)}")
            return False

    def _check_convergence(self):
        """Check MCMC convergence diagnostics with better error handling"""
        try:
            # Get R-hat statistics safely
            rhat_stats = pm.rhat(self.trace)
            print("\nR-hat statistics:")
            for var in rhat_stats.data_vars:
                rhat_value = rhat_stats[var].values
                if np.size(rhat_value) == 1:
                    print(f"{var}: {float(rhat_value):.3f}")
                else:
                    print(f"{var}: mean = {np.mean(rhat_value):.3f}, max = {np.max(rhat_value):.3f}")

            # Get effective sample size
            ess = pm.ess(self.trace)
            print("\nEffective sample sizes:")
            for var in ess.data_vars:
                ess_value = ess[var].values
                if np.size(ess_value) == 1:
                    print(f"{var}: {float(ess_value):.0f}")
                else:
                    print(f"{var}: mean = {np.mean(ess_value):.0f}, min = {np.min(ess_value):.0f}")

            # Check convergence criteria
            max_rhat = max(float(np.max(rhat_stats[var].values)) for var in rhat_stats.data_vars)
            min_ess = min(float(np.min(ess[var].values)) for var in ess.data_vars)

            print("\nConvergence Summary:")
            print(f"Maximum R-hat: {max_rhat:.3f} (should be < 1.1)")
            print(f"Minimum ESS: {min_ess:.0f} (should be > 400)")

            return max_rhat < 1.1 and min_ess > 400

        except Exception as e:
            print(f"\nError checking convergence: {str(e)}")
            return False

    def _clear_memory(self):
        """Clear memory between processing steps"""
        try:
            import gc

            # Force garbage collection
            gc.collect()

            # Clear JAX memory if GPU is enabled
            if hasattr(self, 'gpu_enabled') and self.gpu_enabled:
                try:
                    import jax
                    jax.clear_caches()
                except Exception as e:
                    print(f"Error clearing JAX memory: {str(e)}")

            print("\nMemory cleared")

        except Exception as e:
            print(f"Error clearing memory: {str(e)}")

    # 6. Results and Metrics Methods

    def _generate_results(self):
        """Generate comprehensive analysis results"""
        try:
            # Initialize results dictionary
            results = {
                'parameters': self._extract_parameters(),
                'metrics': self._calculate_metrics(),
                'diagnostics': self._get_diagnostics(),
                'metadata': {
                    'n_customers': len(self.data),
                    'n_groups': len(self.groups),
                    'timestamp': pd.Timestamp.now().strftime('%Y-%m-%d %H:%M:%S'),
                    'model_type': 'Hierarchical BG/NBD',
                    'config': self.config
                }
            }

            # Add segmentation summary
            results['segmentation'] = {
                'rfm_segments': self.data['rfm_segment'].value_counts().to_dict(),
                'n_cohort_groups': len(self.groups)
            }

            # Add model performance metrics
            if hasattr(self, 'trace'):
                div = self.trace.sample_stats.diverging.sum()
                results['performance'] = {
                    'n_divergent': float(div),
                    'percent_divergent': float(div) / len(self.trace.sample_stats.diverging) * 100,
                    'n_chains': self.config.get('chains', 2),
                    'n_samples': self.config.get('mcmc_samples', 500)
                }

                # Add log-likelihood if available
                if hasattr(self.trace, 'log_likelihood'):
                    results['performance']['log_likelihood'] = float(self.trace.log_likelihood.sum())

                print("\nResults generated successfully")
                print(f"Contains {len(results['parameters'])} parameter sets")
                print(f"Tracked {len(results['metrics'])} metrics")
                print(f"Generated {len(results['diagnostics'])} diagnostic measures")

                # Print key metrics with proper type handling
                print("\nKey Model Metrics:")
                if 'metrics' in results:
                    for metric, value in results['metrics'].items():
                        if isinstance(value, dict):
                            print(f"\n{metric}:")
                            for k, v in value.items():
                                if isinstance(v, (int, float)):
                                    print(f"  {k}: {v:.4f}")
                                else:
                                    print(f"  {k}: {v}")
                        else:
                            if isinstance(value, (int, float)):
                                print(f"{metric}: {value:.4f}")
                            else:
                                print(f"{metric}: {value}")

            return results

        except Exception as e:
            print(f"Error generating results: {str(e)}")
            raise

    def _extract_parameters(self):
        """Extract model parameters from trace"""
        params = {}
        try:
            # Group-level parameters
            for param in ['r', 'alpha', 's', 'beta']:
                params[param] = {
                    'mean': self.trace.posterior[param].mean(dim=['chain', 'draw']).values,
                    'std': self.trace.posterior[param].std(dim=['chain', 'draw']).values,
                    'hdi_3%': pm.hdi(self.trace.posterior[param], hdi_prob=0.94).values
                }

            # Covariate coefficients
            for param in ['r', 'alpha']:
                for cov in self.coef_priors.get(param, {}):
                    coef_name = f"{param}_{cov}_coef"
                    if coef_name in self.trace.posterior:
                        params[coef_name] = {
                            'mean': float(self.trace.posterior[coef_name].mean()),
                            'std': float(self.trace.posterior[coef_name].std()),
                            'hdi_94%': pm.hdi(self.trace.posterior[coef_name], hdi_prob=0.94).values
                        }

            return params

        except Exception as e:
            print(f"Error extracting parameters: {str(e)}")
            raise

    def _calculate_metrics(self):
        """Calculate model performance metrics with robust log-likelihood handling"""
        metrics = {}
        try:
            # Ensure we have log-likelihood
            if hasattr(self.trace, 'log_likelihood'):
                try:
                    # Try to calculate WAIC
                    waic = pm.waic(self.trace, scale='deviance')
                    metrics['waic'] = float(waic.waic)
                    metrics['waic_se'] = float(waic.waic_se)
                except Exception as e:
                    print(f"Warning: Could not calculate WAIC: {str(e)}")
                    metrics['waic'] = "Not available"

                try:
                    # Try to calculate LOO
                    loo = pm.loo(self.trace, scale='deviance')
                    metrics['loo'] = float(loo.loo)
                    metrics['loo_se'] = float(loo.loo_se)
                except Exception as e:
                    print(f"Warning: Could not calculate LOO: {str(e)}")
                    metrics['loo'] = "Not available"

                # Calculate log likelihood
                try:
                    log_like = self.trace.log_likelihood.values
                    metrics['log_likelihood'] = float(np.mean(log_like))
                    metrics['log_likelihood_std'] = float(np.std(log_like))
                except Exception as e:
                    print(f"Warning: Could not process log likelihood: {str(e)}")
                    metrics['log_likelihood'] = "Not available"
            else:
                print("Warning: No log likelihood found in trace")
                metrics['log_likelihood'] = "Not available"
                metrics['waic'] = "Not available"
                metrics['loo'] = "Not available"

            return metrics

        except Exception as e:
            print(f"Error calculating metrics: {str(e)}")
            return {'error': str(e)}

    def _get_diagnostics(self):
        """Get model diagnostics"""
        diagnostics = {}
        try:
            # MCMC diagnostics
            diagnostics['r_hat'] = pm.rhat(self.trace).to_dict()
            diagnostics['ess'] = pm.ess(self.trace).to_dict()
            diagnostics['mcse'] = pm.mcse(self.trace).to_dict()

            # Parameter diagnostics
            for param in ['r', 'alpha', 's', 'beta']:
                # Compute diagnostics for the specific parameter
                r_hat_result = pm.rhat(self.trace.posterior[param])
                ess_result = pm.ess(self.trace.posterior[param])

                # Ensure r_hat_result and ess_result are xarray.Dataset
                if isinstance(r_hat_result, xr.Dataset) and isinstance(ess_result, xr.Dataset):
                    # Access the specific data variable within the Dataset
                    r_hat_data = r_hat_result[param].values
                    ess_data = ess_result[param].values
                else:
                    raise ValueError(f"Unexpected format for r_hat or ess: {r_hat_result}, {ess_result}")

                # Compute diagnostics
                diagnostics[f'{param}_diagnostics'] = {
                    'mean_r_hat': float(np.mean(r_hat_data)),  # Compute mean safely
                    'min_ess': float(np.min(ess_data))        # Compute min safely
                }

            return diagnostics

        except Exception as e:
            print(f"Error getting diagnostics: {str(e)}")
            raise

    # Save method with error handling
    def save_trained_model(self, prefix='clv_model'):
        """Save trained model with better error handling"""
        try:
            import pickle
            import os
            from datetime import datetime

            # Check model status before saving
            is_ready, message = HierarchicalCLVSystem.check_model_status(self)
            if not is_ready:
                raise ValueError(f"Cannot save model: {message}")

            # Create directory
            os.makedirs('trained_models', exist_ok=True)

            # Generate timestamp
            timestamp = datetime.now().strftime('%Y%m%d_%H%M')
            base_path = f'trained_models/{prefix}_{timestamp}'

            # Save what we can
            model_data = {
                'config': self.config,
                'segment_bins': self.segment_bins,
                'segments': self.segments if hasattr(self, 'segments') else None,
                'groups': self.groups if hasattr(self, 'groups') else None,
                'coords': self.coords if hasattr(self, 'coords') else None,
                'scalers': self.scalers if hasattr(self, 'scalers') else None,
                'metadata': {
                    'save_time': timestamp,
                    'model_type': 'Hierarchical BG/NBD',
                    'convergence_warning': not self._check_convergence() if hasattr(self, 'trace') else True
                }
            }

            # Save model data
            data_path = f'{base_path}_data.pkl'
            with open(data_path, 'wb') as f:
                pickle.dump(model_data, f)
            print(f"\nModel data saved to: {data_path}")

            # Try to save trace if it exists
            if hasattr(self, 'trace') and self.trace is not None:
                try:
                    trace_path = f'{base_path}_trace.nc'  # Note the .nc extension
                    self.trace.to_netcdf(trace_path)  # Use to_netcdf instead of pm.save_trace
                    print(f"Trace saved to: {trace_path}")
                except Exception as e:
                    print(f"Warning: Could not save trace: {str(e)}")

            return base_path

        except Exception as e:
            print(f"Error saving model: {str(e)}")
            raise

    # 7. Resource Management Methods

    def _setup_gpu(self):
        """Configure GPU environment"""
        try:
            import jax
            import numpyro

            # Print JAX config
            print("\nJAX Configuration:")
            print(f"JAX version: {jax.__version__}")
            print(f"Backend: {jax.config.x64_enabled}")

            # Configure JAX
            jax.config.update('jax_platform_name', 'gpu')
            jax.config.update('jax_enable_x64', True)  # Enable 64-bit precision if needed

            # Set device count for NumPyro
            numpyro.set_host_device_count(self.config['chains'])

            # Check available devices
            devices = jax.devices()
            print(f"\nGPU Setup:")
            print(f"Available devices: {len(devices)}")

            # Initialize each device
            for device in devices:
                try:
                    # Test device
                    jax.device_get(jax.device_put(1, device=device))
                    print(f"Device {device.id} initialized successfully")

                    # Try to get memory info
                    if hasattr(device, 'memory_stats'):
                        stats = device.memory_stats()
                        if stats:
                            memory_mb = stats.get('bytes_limit', 0) / (1024 * 1024)
                            print(f"Device memory: {memory_mb:.0f} MB")

                except Exception as e:
                    print(f"Warning: Could not initialize device {device.id}: {str(e)}")

            self.gpu_enabled = True
            print("\nGPU setup completed successfully")

        except ImportError as e:
            print(f"\nJAX/NumPyro import error: {str(e)}")
            print("Falling back to CPU.")
            self.config['USE_GPU'] = False
            self.gpu_enabled = False
        except Exception as e:
            print(f"\nGPU setup error: {str(e)}")
            print("Falling back to CPU.")
            self.config['USE_GPU'] = False
            self.gpu_enabled = False


    def setup_monitoring(self):
        """Setup resource monitoring"""
        if not self.config.get('monitor_resources', True):
            return

        def monitor_loop():
            """Monitoring loop to track resource usage"""
            while True:
                self._monitor_resources()
                time.sleep(self.config.get('monitor_interval', 300))  # Adjust the interval as needed

        # Start monitoring thread
        monitor_thread = threading.Thread(target=monitor_loop, daemon=True)
        monitor_thread.start()
        print("\nResource monitoring started")

    def _monitor_resources(self):
        """Monitor system resources"""
        try:
            import psutil
            process = psutil.Process()

            # Get memory info
            memory_info = process.memory_info()
            memory_used_gb = memory_info.rss / (1024 * 1024 * 1024)  # Convert to GB
            memory_percent = process.memory_percent()

            # Get CPU info
            cpu_percent = process.cpu_percent(interval=5)  # Increase the interval to 5 seconds

            print("\nResource Usage:")
            print(f"Memory Used: {memory_used_gb:.2f} GB ({memory_percent:.1f}%)")
            print(f"CPU Usage: {cpu_percent}%")
            print(f"Active Threads: {process.num_threads()}")

            # Get system-wide memory info
            system = psutil.virtual_memory()
            print(f"System Memory: {system.percent}% used")
            print(f"Available Memory: {system.available / (1024 * 1024 * 1024):.2f} GB")

            # Monitor GPU if enabled
            if hasattr(self, 'gpu_enabled') and self.gpu_enabled:
                self._monitor_gpu()

        except Exception as e:
            print(f"Monitoring error: {str(e)}")

    def _monitor_gpu(self):
        """Monitor GPU memory usage and status"""
        if not hasattr(self, 'gpu_enabled') or not self.gpu_enabled:
            return

        try:
            import jax
            devices = jax.devices()

            for device in devices:
                try:
                    # Get device info
                    device_info = jax.device_get(jax.device_put(1, device=device))
                    print(f"\nGPU {device.id} status: Active")

                    # Try to get memory info if available
                    if hasattr(device, 'memory_stats'):
                        memory_stats = device.memory_stats()
                        if memory_stats:
                            used_memory = memory_stats.get('bytes_in_use', 0) / (1024 * 1024)  # Convert to MB
                            total_memory = memory_stats.get('bytes_limit', 0) / (1024 * 1024)  # Convert to MB
                            print(f"Memory Usage: {used_memory:.1f}MB / {total_memory:.1f}MB")

                except Exception as e:
                    print(f"Could not query GPU {device.id}: {str(e)}")

        except Exception as e:
            print(f"GPU monitoring error: {str(e)}")

    def add_sampling_extensions(self):
        """Add all sampling extensions to the model"""
        # Add parallel tempering
        self.add_parallel_tempering()

        # Create adaptive Metropolis sampler
        self.create_adaptive_metropolis()

        # Add diagnostics
        self.add_diagnostics()

    def run_sampler_comparison(self, n_samples=1000):
        """Run comparison of different samplers"""
        samplers_to_test = [
            SamplerConfig("nuts"),
            SamplerConfig("metropolis"),
            SamplerConfig("slice"),
            SamplerConfig("hmc")
        ]

        results = self.diagnostics.compare_samplers(
            self, samplers_to_test, n_samples
        )

        print("\nSampler Comparison Results:")
        print(results)

        return results

    def plot_diagnostics(self):
        """Plot diagnostic visualizations"""
        if not hasattr(self, 'diagnostics'):
            print("No diagnostics available. Run sampler comparison first.")
            return

        # Create diagnostic plots
        fig, axes = plt.subplots(2, 2, figsize=(15, 10))

        # Plot 1: Sampling times
        times = [m['sampling_time'] for m in self.diagnostics.performance_metrics.values()]
        axes[0,0].bar(self.diagnostics.performance_metrics.keys(), times)
        axes[0,0].set_title('Sampling Times')
        axes[0,0].set_xticklabels(self.diagnostics.performance_metrics.keys(), rotation=45)

        # Plot 2: ESS
        ess = [m['mean_ess'] for m in self.diagnostics.performance_metrics.values()]
        axes[0,1].bar(self.diagnostics.performance_metrics.keys(), ess)
        axes[0,1].set_title('Mean Effective Sample Size')
        axes[0,1].set_xticklabels(self.diagnostics.performance_metrics.keys(), rotation=45)

        # Plot 3: Efficiency
        efficiency = [m['mean_ess']/m['sampling_time'] for m in self.diagnostics.performance_metrics.values()]
        axes[1,0].bar(self.diagnostics.performance_metrics.keys(), efficiency)
        axes[1,0].set_title('Sampling Efficiency')
        axes[1,0].set_xticklabels(self.diagnostics.performance_metrics.keys(), rotation=45)

        # Plot 4: R-hat values
        r_hats = [m['max_r_hat'] for m in self.diagnostics.performance_metrics.values()]
        axes[1,1].bar(self.diagnostics.performance_metrics.keys(), r_hats)
        axes[1,1].set_title('Maximum R-hat Values')
        axes[1,1].set_xticklabels(self.diagnostics.performance_metrics.keys(), rotation=45)

        plt.tight_layout()
        plt.show()

In [13]:
# Training configuration
import dataclasses
import json
import os
import psutil
import torch
from datetime import datetime
from typing import Dict, Any, Optional

@dataclasses.dataclass
class CLVConfig:
    """Unified configuration for CLV analysis"""
    # Experiment tracking
    experiment_id: str = None
    override_memory_config: bool = False
    sample_size: int = 30000
    random_seed: int = 42

    # Hardware Settings
    use_gpu: bool = True
    cores: int = 36
    max_memory_gb: int = 150
    device_batch_size: int = 512

    # MCMC Settings
    batch_size: int = 10000
    mcmc_samples: int = 500
    mcmc_tune: int = 200
    target_accept: float = 0.995
    max_treedepth: int = 15 # TODO: 25, which is quite high. Review trace plots to see if the sampler consistently hits this limit, and increase it slightly if needed. If divergences persist despite higher depth, focus on reparameterization.
    chains: int = 2
    thinning: int = 0

    # Memory Optimization
    gradient_accumulation: int = 2
    mixed_precision: bool = True
    memory_efficient: bool = True
    memory_warn_threshold: float = 0.85
    memory_critical_threshold: float = 0.95
    auto_cleanup_threshold: float = 0.90
    cleanup_interval: int = 100

    # Training Controls
    early_stop_patience: int = 25
    early_stop_delta: float = 0.02
    min_ess: int = 400
    max_rhat: float = 1.05
    burn_in: int = 500

    # Monitoring
    monitor_resources: bool = True
    monitor_interval: int = 120

    # Segmentation Settings
    segment_config: Dict = dataclasses.field(default_factory=lambda: {
        'use_rfm': True,
        'rfm_bins': {'frequency': 4, 'recency': 4, 'monetary': 4},
        'use_engagement': False,
        'use_channel': False,
        'use_loyalty': False,
        'use_cohorts': True,
        'cohort_type': 'quarterly',
        'merge_small_segments': True,
        'min_segment_size': 200,
        'use_combined': True
    })

    def __post_init__(self):
        """Generate experiment ID if not provided"""
        if not self.experiment_id:
            self.experiment_id = f"clv_experiment_{datetime.now().strftime('%y%m%d%H%M')}"

class ConfigManager:
    """Manages CLV configurations and experiment tracking"""
    def __init__(self, base_config: Optional[CLVConfig] = None):
        self.base_config = base_config or CLVConfig()
        self.config_history = []

    def setup_config(self, custom_config: Optional[Dict] = None) -> Dict[str, Any]:
        """Setup configuration with optional custom overrides"""
        # Start with base config
        config = dataclasses.asdict(self.base_config)

        # Apply memory optimization if needed
        if not (custom_config and custom_config.get('override_memory_config', False)):
            memory_config = self._optimize_memory_settings()
            config.update(memory_config)

        # Apply custom overrides
        if custom_config:
            config.update(custom_config)

        # Save configuration
        self._save_config(config)

        return config

    def _optimize_memory_settings(self) -> Dict[str, Any]:
        """Get optimized memory settings based on system resources"""
        memory_config = {}
        available_memory = psutil.virtual_memory().available / (1024**3)

        # Adjust batch size based on available memory
        memory_config['device_batch_size'] = min(
            512, int(available_memory * 100)
        )

        # Adjust chain count based on GPU availability
        if torch.cuda.is_available():
            memory_config['chains'] = min(
                torch.cuda.device_count(),
                self.base_config.chains
            )

        return memory_config

    def _save_config(self, config: Dict[str, Any]):
        """Save configuration to JSON file and history"""
        # Add to history
        self.config_history.append({
            'timestamp': datetime.now().isoformat(),
            'config': config
        })

        # Save to file
        filename = f"active_training_config_{config['experiment_id']}.json"
        os.makedirs('configs', exist_ok=True)
        with open(os.path.join('configs', filename), 'w') as f:
            json.dump(config, f, indent=2)

    @staticmethod
    def load_config(experiment_id: str) -> CLVConfig:
        """Load configuration from saved experiment"""
        config_path = os.path.join('configs', f"active_training_config_{experiment_id}.json")
        with open(config_path, 'r') as f:
            config_dict = json.load(f)
        return CLVConfig(**config_dict)

In [14]:
def run_clv_analysis(
    processed_df: pd.DataFrame,
    config: Optional[CLVConfig] = None
) -> Tuple[Any, dict]:
    """Run CLV analysis with unified configuration"""
    try:
        # Initialize config manager
        config_manager = ConfigManager(config or CLVConfig())
        active_config = config_manager.setup_config()

        print(f"\nActive experiment ID: {active_config['experiment_id']}")

        # 1. Initialize system
        print("\n=== Initializing CLV System ===")
        clv_system = HierarchicalCLVSystem(
            config=active_config,
            segment_config=active_config['segment_config']
        )

        # 2. Data Validation
        print("\n=== Validating Input Data ===")
        print(f"Input shape: {processed_df.shape}")
        print(f"Memory usage: {processed_df.memory_usage().sum() / 1024**2:.1f} MB")

        # Check for missing values
        missing = processed_df.isnull().sum()
        if missing.any():
            print("\nMissing values detected:")
            print(missing[missing > 0])

        # 3. Create Segments
        print("\n=== Creating Segments ===")
        segmented_df = clv_system.create_segments(processed_df)

        # 4. Sample data if requested
        sample_size = active_config.get('sample_size')
        if sample_size:
            sampled_df = segmented_df.sample(n=sample_size, random_state=42)
            print(f"\nSampled {sample_size} records from segmented dataset")
        else:
            sampled_df = segmented_df

        # 5. Run Analysis
        print("\n=== Running Analysis ===")
        model, results = clv_system.run_analysis(processed_df=sampled_df)

        if model is None or results is None:
            raise ValueError("Analysis failed - check error messages above")

        # 5. Model Validation
        print("\n=== Validating Model ===")
        # Call the static method correctly
        is_ready, message = HierarchicalCLVSystem.check_model_status(model)
        if not is_ready:
            raise ValueError(f"Model validation failed: {message}")

        # 6. Save Model
        print("\n=== Saving Model ===")
        try:
            if is_ready:
                timestamp = datetime.now().strftime('%Y%m%d_%H%M')
                saved_path = model.save_trained_model(prefix=f'clv_model_{timestamp}')

                # Verify saved files
                data_path = f"{saved_path}_data.pkl"
                trace_path = f"{saved_path}_trace"

                files_exist = all(os.path.exists(p) for p in [data_path, trace_path])
                if files_exist:
                    print("\nSave verification:")
                    print(f"Data file: {os.path.getsize(data_path) / 1024**2:.1f}MB")
                    print(f"Trace file: {os.path.getsize(trace_path) / 1024**2:.1f}MB")
                else:
                    print("Warning: Some save files are missing")
            else:
                print(f"Warning: Model not saved - {message}")

        except Exception as e:
            print(f"Model saving error: {str(e)}")
            print("Continuing without saving...")

        # 7. Results Summary
        print("\n=== Analysis Results ===")
        if results:
            print("\nModel Overview:")
            print(f"Customers analyzed: {results['metadata']['n_customers']:,}")
            print(f"Segments created: {results['metadata']['n_groups']:,}")

            if 'segmentation' in results:
                print("\nTop Segments:")
                for segment, count in list(results['segmentation']['rfm_segments'].items())[:5]:
                    print(f"{segment}: {count:,} customers")

            print("\nModel Metrics:")
            if 'metrics' in results:
                for metric_name, metric_value in results['metrics'].items():
                    # Handle dictionary-type metrics
                    if isinstance(metric_value, dict):
                        print(f"\n{metric_name}:")
                        for k, v in metric_value.items():
                            if isinstance(v, (int, float)):
                                print(f"  {k}: {v:.4f}")
                            else:
                                print(f"  {k}: {v}")
                    # Handle scalar metrics
                    else:
                        if isinstance(metric_value, (int, float)):
                            print(f"{metric_name}: {metric_value:.4f}")
                        else:
                            print(f"{metric_name}: {metric_value}")

            print("\nConvergence Diagnostics:")
            if 'diagnostics' in results:
                diag = results['diagnostics']
                # Print R-hat statistics
                if 'r_hat' in diag:
                    try:
                        valid_rhats = [float(rhat) for rhat in diag['r_hat'].values()
                                     if isinstance(rhat, (int, float))]
                        if valid_rhats:
                            max_rhat = max(valid_rhats)
                            print(f"Maximum R-hat: {max_rhat:.3f}")
                    except Exception as e:
                        print(f"Could not calculate max R-hat: {str(e)}")

                # Print ESS statistics
                if 'ess' in diag:
                    try:
                        valid_ess = [float(ess) for ess in diag['ess'].values()
                                   if isinstance(ess, (int, float))]
                        if valid_ess:
                            min_ess = min(valid_ess)
                            print(f"Minimum ESS: {min_ess:.0f}")
                    except Exception as e:
                        print(f"Could not calculate min ESS: {str(e)}")

            # Print performance metrics if available
            if 'performance' in results:
                print("\nPerformance Metrics:")
                perf = results['performance']
                for metric_name, metric_value in perf.items():
                    if isinstance(metric_value, (int, float)):
                        print(f"{metric_name}: {metric_value:.4f}")
                    else:
                        print(f"{metric_name}: {metric_value}")

        return model, results

    except ValueError as ve:
        print(f"\nValidation Error: {str(ve)}")
        return None, None
    except Exception as e:
        print(f"\nExecution Error: {str(e)}")
        traceback.print_exc()
        return None, None
    finally:
        # Clean up resources
        print("\n=== Cleanup ===")
        try:
            if 'model' in locals() and hasattr(model, '_clear_memory'):
                model._clear_memory()
            gc.collect()
            if hasattr(torch, 'cuda'):
                torch.cuda.empty_cache()
            print("Memory cleared successfully")
        except Exception as e:
            print(f"Cleanup warning: {str(e)}")
        print("\n=== Analysis Complete ===")

In [15]:
# 1. Basic usage with defaults and sampling
model, results = run_clv_analysis(
    processed_df=processed_df
)


Active experiment ID: clv_experiment_2411150051

=== Initializing CLV System ===

JAX Configuration:
JAX version: 0.4.35
Backend: False

GPU Setup:
Available devices: 4
Device 0 initialized successfully
Device memory: 16859 MB
Device 1 initialized successfully
Device memory: 16859 MB
Device 2 initialized successfully
Device memory: 16859 MB
Device 3 initialized successfully
Device memory: 16859 MB

GPU setup completed successfully

Resource monitoring started

=== Validating Input Data ===
Input shape: (1200823, 22)
Memory usage: 207.3 MB

=== Creating Segments ===
Creating segments for 1,200,823 customers

f_segment distribution:
f_segment
Frequency 1    1045538
Frequency 2      98719
Frequency 3      39553
Frequency 4      17013
Name: count, dtype: int64

r_segment distribution:
r_segment
Recency 1    440116
Recency 2    328762
Recency 3    254907
Recency 4    177038
Name: count, dtype: int64

m_segment distribution:
m_segment
Monetary 1    858368
Monetary 2    211419
Monetary 3    

Output()

ERROR:pymc.stats.convergence:There were 2 divergences after tuning. Increase `target_accept` or reparameterize.
ERROR:pymc.stats.convergence:The effective sample size per chain is smaller than 100 for some parameters.  A higher number is needed for reliable rhat and ess computation. See https://arxiv.org/abs/1903.08008 for details



Computing log-likelihood...

Sampling error: Arguments to add_groups() must be xr.Dataset, xr.Dataarray or dicts                    (argument 'log_likelihood' was type '<class 'pytensor.tensor.variable.TensorVariable'>')

Partial trace available - attempting to salvage results

R-hat statistics:
alpha: mean = 1.121, max = 1.193
alpha_avg_interpurchase_days_scaled_coef: 1.010
alpha_distinct_brands_scaled_coef: 1.027
alpha_distinct_categories_scaled_coef: 1.001
alpha_email_active_coef: 0.996
alpha_has_online_purchases_coef: 1.026
alpha_has_store_purchases_coef: 1.017
alpha_is_loyalty_member_coef: 0.999
alpha_loyalty_points_scaled_coef: 1.003
alpha_mu: 1.114
alpha_sigma: 1.526
alpha_sms_active_coef: 1.003
beta: mean = 1.010, max = 1.032
beta_mu: 1.008
beta_sigma: 1.199
r: mean = 1.517, max = 1.611
r_alpha: 1.560
r_avg_interpurchase_days_scaled_coef: 1.007
r_beta: 1.222
r_distinct_brands_scaled_coef: 1.000
r_distinct_categories_scaled_coef: 1.037
r_email_active_coef: 1.020
r_has_online_pu

In [23]:
class CLVPredictor:
    """Prediction interface for Hierarchical CLV System"""

    def __init__(self, model_system, config=None):
        self.model = model_system
        self.config = config or {}
        self.predictions = {}
        self.default_horizons = [30, 60, 90, 180, 365]  # Days

    def predict_clv(self, time_horizons=None, uncertainty=True):
        """
        Generate CLV predictions for specified time horizons

        Parameters:
        -----------
        time_horizons : list, optional
            List of time horizons in days
        uncertainty : bool, default=True
            Whether to include prediction intervals
        """
        try:
            horizons = time_horizons or self.default_horizons
            print(f"\nGenerating predictions for horizons: {horizons} days")

            # Generate predictions for each horizon
            results = {}
            for horizon in horizons:
                results[horizon] = self._predict_horizon(horizon, uncertainty)

            # Store predictions
            self.predictions = results

            return self._format_predictions(results)

        except Exception as e:
            print(f"Prediction error: {str(e)}")
            raise

    def _predict_horizon(self, horizon, uncertainty=True):
        """Generate predictions for a specific time horizon"""
        try:
            # Get parameter posterior samples
            params = self._get_posterior_samples()

            # Calculate expected transactions
            exp_transactions = self._calc_expected_transactions(params, horizon)

            # Calculate expected value
            exp_value = self._calc_expected_value(params, exp_transactions)

            if uncertainty:
                # Calculate prediction intervals
                intervals = self._calc_prediction_intervals(exp_value)
                return {
                    'expected_value': float(np.mean(exp_value)),
                    'lower_95': float(intervals['lower_95']),
                    'upper_95': float(intervals['upper_95']),
                    'std': float(np.std(exp_value))
                }
            else:
                return {
                    'expected_value': float(np.mean(exp_value))
                }

        except Exception as e:
            print(f"Error predicting horizon {horizon}: {str(e)}")
            raise

    def _get_posterior_samples(self):
        """Extract posterior samples from trained model"""
        try:
            # Get samples from trace
            samples = {
                'r': self.model.trace.posterior['r'].values,
                'alpha': self.model.trace.posterior['alpha'].values,
                's': self.model.trace.posterior['s'].values,
                'beta': self.model.trace.posterior['beta'].values
            }

            # Get covariate coefficients if they exist
            for param in ['r', 'alpha']:
                for cov in self.model.coef_priors.get(param, {}):
                    coef_name = f"{param}_{cov}_coef"
                    if coef_name in self.model.trace.posterior:
                        samples[coef_name] = self.model.trace.posterior[coef_name].values

            return samples

        except Exception as e:
            print(f"Error extracting posterior samples: {str(e)}")
            raise

    def _calc_expected_transactions(self, params, horizon):
        """Calculate expected number of transactions"""
        try:
            # Get current customer data
            freq = self.model.data['frequency'].values
            rec = self.model.data['recency'].values
            T = self.model.data['customer_age_days'].values
            group_idx = self.model.data['group_idx'].values

            # Calculate customer-specific parameters
            r = params['r'][:, :, group_idx]  # [chain, draw, customer]
            alpha = params['alpha'][:, :, group_idx]

            # Calculate probability of being alive
            p_alive = self._calc_probability_alive(r, alpha, freq, rec, T)

            # Calculate expected transactions
            exp_transactions = p_alive * r * (horizon / alpha)

            return exp_transactions

        except Exception as e:
            print(f"Error calculating expected transactions: {str(e)}")
            raise

    def _calc_probability_alive(self, r, alpha, freq, rec, T):
        """Calculate probability customer is still alive"""
        try:
            # Calculate components of the probability
            a1 = (alpha + T) ** (r + freq)
            a2 = (alpha + rec) ** (r + freq)
            b1 = (alpha + T) ** r
            b2 = (alpha + rec) ** r

            # Calculate probability
            p_alive = 1 - (
                (a1 - a2) / (a1 + b1)
            )

            return p_alive

        except Exception as e:
            print(f"Error calculating alive probability: {str(e)}")
            raise

    def _calc_expected_value(self, params, exp_transactions):
        """Calculate expected monetary value"""
        try:
            # Get monetary value parameters
            s = params['s'][:, :, self.model.data['group_idx'].values]
            beta = params['beta'][:, :, self.model.data['group_idx'].values]

            # Expected transaction value
            exp_value_per_transaction = s / beta

            # Total expected value
            total_value = exp_transactions * exp_value_per_transaction

            return total_value

        except Exception as e:
            print(f"Error calculating expected value: {str(e)}")
            raise

    def _calc_prediction_intervals(self, values):
        """Calculate prediction intervals"""
        try:
            # Calculate percentiles across chains and draws
            lower_95 = np.percentile(values, 2.5, axis=(0, 1))
            upper_95 = np.percentile(values, 97.5, axis=(0, 1))

            return {
                'lower_95': lower_95,
                'upper_95': upper_95
            }

        except Exception as e:
            print(f"Error calculating prediction intervals: {str(e)}")
            raise

    def _format_predictions(self, results):
        """Format predictions into a clean DataFrame"""
        try:
            # Create prediction DataFrame
            pred_data = []
            for horizon, metrics in results.items():
                pred_data.append({
                    'horizon_days': horizon,
                    'expected_clv': metrics['expected_value'],
                    'lower_95': metrics.get('lower_95', None),
                    'upper_95': metrics.get('upper_95', None),
                    'std': metrics.get('std', None)
                })

            predictions = pd.DataFrame(pred_data)

            # Add metadata
            predictions.attrs['prediction_date'] = pd.Timestamp.now()
            predictions.attrs['model_type'] = 'Hierarchical BG/NBD'
            predictions.attrs['n_customers'] = len(self.model.data)

            return predictions

        except Exception as e:
            print(f"Error formatting predictions: {str(e)}")
            raise

In [22]:
# Generate predictions Usage
def generate_clv_predictions(clv_system, horizons=None):
    """Generate CLV predictions using trained model"""
    try:
        # Create predictor
        predictor = CLVPredictor(clv_system)

        # Generate predictions
        predictions = predictor.predict_clv(
            time_horizons=horizons,
            uncertainty=True
        )

        # Print summary
        print("\nPrediction Summary:")
        print("-" * 50)
        for _, row in predictions.iterrows():
            print(f"\nHorizon: {row['horizon_days']} days")
            print(f"Expected CLV: ${row['expected_clv']:,.2f}")
            print(f"95% CI: (${row['lower_95']:,.2f}, ${row['upper_95']:,.2f})")

        return predictions, predictor

    except Exception as e:
        print(f"Error generating predictions: {str(e)}")
        raise

In [24]:
# Generate predictions and DataFrame
predictions_summary, horizon_preds_df, predictor = generate_clv_predictions(
    clv_system=model,
    processed_df=processed_df,
    horizons=[30, 90, 180, 365]
)

# Save predictions
timestamp = pd.Timestamp.now().strftime('%Y%m%d_%H%M')
horizon_preds_df.to_csv(f'clv_predictions_{timestamp}.csv', index=False)

# Analyze predictions
print("\nPrediction Statistics:")
for horizon in [30, 90, 180, 365]:
    cols = [f'clv_{horizon}d_expected']
    stats = horizon_preds_df[cols].describe()
    print(f"\n{horizon}-day CLV:")
    print(stats)

TypeError: generate_clv_predictions() got an unexpected keyword argument 'processed_df'

In [19]:
@classmethod
def load_trained_model(cls, model_path, config=None):
    """
    Load trained model from files with enhanced error handling
    """
    try:
        import pickle
        import arviz as az
        import os

        # First verify files exist
        data_path = f'{model_path}_data.pkl'
        trace_path = f'{model_path}_trace.nc'

        if not os.path.exists(data_path):
            raise FileNotFoundError(f"Data file not found: {data_path}")

        if not os.path.exists(trace_path):
            print(f"Warning: Trace file not found: {trace_path}")

        # Check file sizes
        data_size = os.path.getsize(data_path) / (1024 * 1024)  # MB
        print(f"\nFile sizes:")
        print(f"Data file: {data_size:.1f} MB")

        # Try to load data with error handling
        try:
            with open(data_path, 'rb') as f:
                model_data = pickle.load(f)
        except (pickle.UnpicklingError, EOFError) as e:
            print(f"Error reading pickle file: {str(e)}")
            print("The model file might be corrupted or in wrong format")
            raise

        # Create new instance
        instance = cls(config=config)

        # Update instance attributes with verification
        required_attrs = ['config', 'segment_bins', 'segments', 'groups',
                        'coords', 'scalers']

        for attr in required_attrs:
            if attr in model_data:
                setattr(instance, attr, model_data[attr])
            else:
                print(f"Warning: Missing attribute '{attr}' in saved model")

        # Override with custom config if provided
        if config is not None:
            instance.config = config

        # Load trace if available
        if os.path.exists(trace_path):
            try:
                instance.trace = az.from_netcdf(trace_path)
                trace_size = os.path.getsize(trace_path) / (1024 * 1024)
                print(f"Trace file: {trace_size:.1f} MB")
                print(f"Trace loaded successfully")
            except Exception as e:
                print(f"Warning: Could not load trace: {str(e)}")

        # Print model info
        print(f"\nLoaded model from: {model_path}")
        if 'metadata' in model_data:
            print(f"Save timestamp: {model_data['metadata']['save_time']}")
            print(f"Model type: {model_data['metadata'].get('model_type', 'Unknown')}")

        # Verify critical components
        if hasattr(instance, 'data') and instance.data is not None:
            print(f"Data shape: {instance.data.shape}")
        if hasattr(instance, 'groups'):
            print(f"Number of groups: {len(instance.groups)}")

        return instance

    except FileNotFoundError as e:
        print(f"\nFile not found: {str(e)}")
        raise
    except pickle.UnpicklingError as e:
        print(f"\nError unpickling model data: {str(e)}")
        print("The model file might be corrupted or in wrong format")
        raise
    except Exception as e:
        print(f"\nError loading model: {str(e)}")
        raise

In [18]:
# Helper function to find most recent model
def get_latest_model_path():
    """Find the most recently saved model"""
    try:
        import glob
        import os

        # Look for all model data files
        model_files = glob.glob('trained_models/clv_model_*_data.pkl')

        if not model_files:
            return None

        # Get most recent file
        latest_file = max(model_files, key=os.path.getctime)

        # Remove _data.pkl to get base path
        base_path = latest_file.replace('_data.pkl', '')

        return base_path

    except Exception as e:
        print(f"Error finding latest model: {str(e)}")
        return None

In [17]:
get_latest_model_path()

'trained_models/clv_model_20241115_0053_20241115_0053'

In [None]:
# Usage with better error handling:
try:
    # Specify model path
    model_path = 'trained_models/clv_model_20241112_0013'

    # First verify files
    data_path = f'{model_path}_data.pkl'
    trace_path = f'{model_path}_trace.nc'

    if not os.path.exists(data_path):
        raise FileNotFoundError(f"Model data file not found: {data_path}")

    # Try to load
    print(f"\nLoading model from: {model_path}")
    loaded_model = HierarchicalCLVSystem.load_trained_model(model_path)

    # Verify loaded model
    if loaded_model is not None:
        print("\nModel loaded successfully")
        print(f"Config parameters:")
        for key in ['MCMC_SAMPLES', 'CHAINS', 'BATCH_SIZE']:
            print(f"{key}: {loaded_model.config.get(key, 'Not found')}")

except FileNotFoundError as e:
    print(f"\nFile error: {str(e)}")
except pickle.UnpicklingError as e:
    print(f"\nPickle error: {str(e)}")
    print("The model file might be corrupted")
except Exception as e:
    print(f"\nError: {str(e)}")
    import traceback
    print("\nDetailed error:")
    print(traceback.format_exc())
