In [2]:
import pandas as pd
import numpy as np
import torch
from chronos import BaseChronosPipeline
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import warnings
warnings.filterwarnings('ignore')

In [9]:


class ShortTermFinancialPredictor:
    def __init__(self, model_name="amazon/chronos-bolt-base", device="cuda"):
        """Initialize predictor for short-term (1-day) forecasting"""
        self.model_name = model_name
        self.device = device if torch.cuda.is_available() else "cpu"
        self.pipeline = None
        self.df = None
        self.features_df = None
        
    def load_data(self, csv_file_path, date_col="DATE", close_col="CLOSE PRICE"):
        """Load and preprocess financial data"""
        print("Loading financial data for short-term prediction...")
        
        # Load CSV
        self.df = pd.read_csv(csv_file_path)
        self.df.columns = self.df.columns.str.strip()
        
        # Convert date and sort
        self.df[date_col.strip()] = pd.to_datetime(self.df[date_col.strip()], format='%d-%b-%Y')
        self.df = self.df.sort_values(date_col.strip()).reset_index(drop=True)
        
        # Convert price columns to numeric
        price_cols = ['OPEN PRICE', 'HIGH PRICE', 'LOW PRICE', 'CLOSE PRICE', 'SETTLE PRICE']
        for col in price_cols:
            if col in self.df.columns:
                self.df[col] = pd.to_numeric(self.df[col], errors='coerce')
        
        # Handle volume (remove commas)
        if 'Volume' in self.df.columns:
            self.df['Volume'] = self.df['Volume'].astype(str).str.replace(',', '')
            self.df['Volume'] = pd.to_numeric(self.df['Volume'], errors='coerce')
        
        # Remove rows with missing OHLC data
        self.df = self.df.dropna(subset=price_cols[:4])
        
        print(f"Data loaded: {len(self.df)} records from {self.df[date_col.strip()].min()} to {self.df[date_col.strip()].max()}")
        return self.df
    
    def engineer_short_term_features(self, target_col="CLOSE PRICE"):
        """Create features optimized for short-term (1-day) prediction"""
        print("Engineering short-term features...")
        
        df = self.df.copy()
        
        # Short-term price patterns (last 3-7 days)
        df['price_change_1d'] = df[target_col].pct_change(1)
        df['price_change_2d'] = df[target_col].pct_change(2)
        df['price_change_3d'] = df[target_col].pct_change(3)
        
        # Short-term momentum
        df['momentum_3d'] = df[target_col] / df[target_col].shift(3) - 1
        df['momentum_5d'] = df[target_col] / df[target_col].shift(5) - 1
        
        # Short-term volatility
        df['volatility_3d'] = df['price_change_1d'].rolling(window=3).std()
        df['volatility_5d'] = df['price_change_1d'].rolling(window=5).std()
        
        # High-Low spread (intraday volatility)
        df['hl_spread'] = (df['HIGH PRICE'] - df['LOW PRICE']) / df[target_col]
        df['hl_spread_ma3'] = df['hl_spread'].rolling(window=3).mean()
        
        # Short-term moving averages
        df['SMA_3'] = df[target_col].rolling(window=3).mean()
        df['SMA_5'] = df[target_col].rolling(window=5).mean()
        df['SMA_7'] = df[target_col].rolling(window=7).mean()
        
        # Price position relative to recent range
        df['price_position_3d'] = (df[target_col] - df[target_col].rolling(3).min()) / (df[target_col].rolling(3).max() - df[target_col].rolling(3).min())
        df['price_position_5d'] = (df[target_col] - df[target_col].rolling(5).min()) / (df[target_col].rolling(5).max() - df[target_col].rolling(5).min())
        
        # Volume patterns (if available)
        if 'Volume' in df.columns:
            df['volume_change_1d'] = df['Volume'].pct_change(1)
            df['volume_ma3'] = df['Volume'].rolling(window=3).mean()
            df['volume_ratio'] = df['Volume'] / df['volume_ma3']
        
        # Gap analysis (open vs previous close)
        df['gap'] = (df['OPEN PRICE'] - df[target_col].shift(1)) / df[target_col].shift(1)
        
        # Create feature matrix focusing on recent patterns
        feature_columns = [
            target_col, 'price_change_1d', 'price_change_2d', 'price_change_3d',
            'momentum_3d', 'momentum_5d', 'volatility_3d', 'volatility_5d',
            'hl_spread', 'SMA_3', 'SMA_5', 'SMA_7', 
            'price_position_3d', 'price_position_5d', 'gap'
        ]
        
        # Add volume features if available
        if 'Volume' in df.columns:
            feature_columns.extend(['volume_change_1d', 'volume_ratio'])
        
        # Select available columns
        available_features = [col for col in feature_columns if col in df.columns]
        self.features_df = df[['DATE'] + available_features].copy()
        
        # Fill missing values with forward fill only for short-term
        self.features_df = self.features_df.fillna(method='ffill')
        
        print(f"Created {len(available_features)} short-term features")
        return self.features_df
    
    def load_model(self):
        """Load the Chronos-Bolt model"""
        print(f"Loading {self.model_name} model...")
        
        self.pipeline = BaseChronosPipeline.from_pretrained(
            self.model_name,
            device_map=self.device,
            torch_dtype=torch.bfloat16,
        )
        print("Model loaded successfully!")
    
    def predict_next_day(self, context_days=7):
        """Predict next day using only last 7 days of data"""
        if self.pipeline is None:
            self.load_model()
        
        print(f"Generating next-day forecast using last {context_days} days...")
        
        # Use only close price for context (univariate works best for Chronos)
        target_col = 'CLOSE PRICE'
        close_prices = self.features_df[target_col].values
        
        # Take only last context_days observations
        context = torch.tensor(close_prices[-context_days:], dtype=torch.float32)
        
        # Predict only 1 day ahead
        forecast = self.pipeline.predict(
            context=context,
            prediction_length=1
        )
        
        # Extract forecast
        forecast_np = forecast[0].cpu().numpy()  # Remove batch dimension
        
        if len(forecast_np.shape) == 2:  # [prediction_length, num_quantiles]
            num_quantiles = forecast_np.shape[1]
            if num_quantiles >= 5:
                quantiles = {
                    'q10': forecast_np[0, 0],  # 10th percentile
                    'q30': forecast_np[0, min(2, num_quantiles-1)],  # 30th percentile
                    'q50': forecast_np[0, num_quantiles//2],  # Median
                    'q70': forecast_np[0, min(num_quantiles-2, 3*num_quantiles//4)],  # 70th percentile
                    'q90': forecast_np[0, -1]   # 90th percentile
                }
            else:
                # Fallback if fewer quantiles
                quantiles = {
                    'q10': forecast_np[0, 0],
                    'q30': forecast_np[0, 0],
                    'q50': forecast_np[0, num_quantiles//2],
                    'q70': forecast_np[0, -1],
                    'q90': forecast_np[0, -1]
                }
            
            mean_forecast = quantiles['q50']
        else:
            # Single forecast value
            mean_forecast = float(forecast_np[0])
            std_forecast = abs(mean_forecast * 0.02)  # 2% std dev assumption
            quantiles = {
                'q10': mean_forecast - 1.28 * std_forecast,
                'q30': mean_forecast - 0.52 * std_forecast,
                'q50': mean_forecast,
                'q70': mean_forecast + 0.52 * std_forecast,
                'q90': mean_forecast + 1.28 * std_forecast
            }
        
        return {
            'mean_forecast': mean_forecast,
            'quantiles': quantiles,
            'context_used': context.numpy(),
            'current_price': close_prices[-1]
        }
    
    def rolling_validation(self, validation_days=10, context_days=7):
        """Validate model using rolling window approach"""
        print(f"Running rolling validation for last {validation_days} days...")
        
        if self.pipeline is None:
            self.load_model()
        
        target_col = 'CLOSE PRICE'
        close_prices = self.features_df[target_col].values
        dates = self.features_df['DATE'].values
        
        predictions = []
        actuals = []
        prediction_dates = []
        
        # Start validation from day that has enough history
        start_idx = max(context_days, len(close_prices) - validation_days - 1)
        
        for i in range(start_idx, len(close_prices) - 1):
            # Use last context_days for prediction
            context = torch.tensor(close_prices[i-context_days+1:i+1], dtype=torch.float32)
            actual_next = close_prices[i+1]
            
            try:
                # Predict next day
                forecast = self.pipeline.predict(
                    context=context,
                    prediction_length=1
                )
                
                forecast_np = forecast[0].cpu().numpy()
                if len(forecast_np.shape) == 2:
                    pred_next = forecast_np[0, forecast_np.shape[1]//2]  # Use median
                else:
                    pred_next = float(forecast_np[0])
                
                predictions.append(pred_next)
                actuals.append(actual_next)
                prediction_dates.append(dates[i+1])
                
            except Exception as e:
                print(f"Prediction failed for day {i}: {e}")
                continue
        
        # Calculate accuracy metrics
        predictions = np.array(predictions)
        actuals = np.array(actuals)
        
        mape = np.mean(np.abs((actuals - predictions) / actuals)) * 100
        rmse = np.sqrt(np.mean((actuals - predictions) ** 2))
        mae = np.mean(np.abs(actuals - predictions))
        
        # Direction accuracy (up/down prediction)
        actual_directions = np.sign(np.diff(actuals))
        pred_directions = np.sign(predictions[1:] - actuals[:-1])
        direction_accuracy = np.mean(actual_directions == pred_directions) * 100
        
        print(f"\nValidation Results ({len(predictions)} predictions):")
        print(f"MAPE: {mape:.2f}%")
        print(f"RMSE: {rmse:.4f}")
        print(f"MAE: {mae:.4f}")
        print(f"Direction Accuracy: {direction_accuracy:.1f}%")
        
        return {
            'predictions': predictions,
            'actuals': actuals,
            'dates': prediction_dates,
            'mape': mape,
            'rmse': rmse,
            'mae': mae,
            'direction_accuracy': direction_accuracy
        }
    
    def plot_short_term_forecast(self, forecast_result, validation_result=None):
        """Corrected plot for short-term forecast with proper validation visualization"""
        
        # Get recent data for context
        lookback = 20  # Show more context
        recent_dates = self.features_df['DATE'].values[-lookback:]
        recent_prices = self.features_df['CLOSE PRICE'].values[-lookback:]
        
        # Create next day date
        last_date = pd.to_datetime(recent_dates[-1])
        next_date = last_date + pd.Timedelta(days=1)
        
        fig = make_subplots(
            rows=3, cols=1,
            shared_xaxes=True,
            vertical_spacing=0.05,
            subplot_titles=[
                'Historical Prices + Next Day Prediction',
                'Validation: Actual vs Predicted (Recent 10 Days)',
                'Prediction Accuracy Analysis'
            ],
            row_heights=[0.4, 0.35, 0.25]
        )
        
        # --- ROW 1: Historical prices with next day prediction ---
        fig.add_trace(
            go.Scatter(
                x=recent_dates,
                y=recent_prices,
                mode='lines+markers',
                name='Historical Prices',
                line=dict(color='blue', width=2),
                marker=dict(size=4)
            ),
            row=1, col=1
        )
        
        # Next day prediction with error bars
        mean_pred = forecast_result['mean_forecast']
        quantiles = forecast_result['quantiles']
        
        # Main prediction point
        fig.add_trace(
            go.Scatter(
                x=[next_date],
                y=[mean_pred],
                mode='markers',
                name='Next Day Prediction',
                marker=dict(color='red', size=15, symbol='star'),
                error_y=dict(
                    type='data',
                    symmetric=False,
                    array=[quantiles['q90'] - mean_pred],  # Upper error
                    arrayminus=[mean_pred - quantiles['q10']],  # Lower error
                    visible=True,
                    color='red',
                    thickness=3,
                    width=5
                )
            ),
            row=1, col=1
        )
        
        # Add prediction range as shaded area
        fig.add_trace(
            go.Scatter(
                x=[next_date, next_date, next_date],
                y=[quantiles['q10'], mean_pred, quantiles['q90']],
                mode='markers',
                name='90% Range',
                marker=dict(color='red', size=8, opacity=0.3),
                showlegend=False
            ),
            row=1, col=1
        )
        
        # --- ROW 2: Validation results (if available) ---
        if validation_result is not None:
            val_dates = pd.to_datetime(validation_result['dates'])
            val_actuals = validation_result['actuals']
            val_predictions = validation_result['predictions']
            
            # Actual prices
            fig.add_trace(
                go.Scatter(
                    x=val_dates,
                    y=val_actuals,
                    mode='lines+markers',
                    name='Actual Prices',
                    line=dict(color='green', width=2),
                    marker=dict(size=6)
                ),
                row=2, col=1
            )
            
            # Predicted prices
            fig.add_trace(
                go.Scatter(
                    x=val_dates,
                    y=val_predictions,
                    mode='lines+markers',
                    name='Predicted Prices',
                    line=dict(color='red', width=2, dash='dash'),
                    marker=dict(size=6, symbol='x')
                ),
                row=2, col=1
            )
            
            # --- ROW 3: Prediction accuracy analysis ---
            # Prediction errors
            errors = val_actuals - val_predictions
            error_pct = (errors / val_actuals) * 100
            
            # Error bars
            colors = ['green' if abs(e) < 0.5 else 'orange' if abs(e) < 1.0 else 'red' for e in error_pct]
            
            fig.add_trace(
                go.Bar(
                    x=val_dates,
                    y=error_pct,
                    name='Prediction Error %',
                    marker_color=colors,
                    text=[f'{e:.1f}%' for e in error_pct],
                    textposition='outside'
                ),
                row=3, col=1
            )
            
            # Add horizontal lines for acceptable error ranges
            fig.add_hline(y=1.0, line_dash="dash", line_color="orange", 
                        annotation_text="±1% Error", row=3, col=1)
            fig.add_hline(y=-1.0, line_dash="dash", line_color="orange", row=3, col=1)
            fig.add_hline(y=0.5, line_dash="dot", line_color="green", row=3, col=1)
            fig.add_hline(y=-0.5, line_dash="dot", line_color="green", row=3, col=1)
        
        # Update layout
        fig.update_layout(
            title='Short-Term Financial Forecast Analysis (7-Day Context → 1-Day Prediction)',
            height=800,
            showlegend=True,
            legend=dict(
                orientation="v",
                yanchor="top",
                y=0.99,
                xanchor="left",
                x=1.01
            )
        )
        
        # Update axes
        fig.update_xaxes(title_text="Date", row=3, col=1)
        fig.update_yaxes(title_text="Price (₹)", row=1, col=1)
        fig.update_yaxes(title_text="Price (₹)", row=2, col=1)
        fig.update_yaxes(title_text="Error (%)", row=3, col=1)
        
        # Add current price and prediction info as annotation
        current_price = forecast_result['current_price']
        change_pct = (mean_pred / current_price - 1) * 100
        
        fig.add_annotation(
            x=0.02, y=0.98,
            xref="paper", yref="paper",
            text=f"Current: ₹{current_price:.4f}<br>Predicted: ₹{mean_pred:.4f}<br>Change: {change_pct:+.2f}%",
            showarrow=False,
            align="left",
            bgcolor="rgba(255,255,255,0.8)",
            bordercolor="black",
            borderwidth=1
        )
        
        # Save as HTML
        fig.write_html("corrected_short_term_forecast.html")
        print("Corrected forecast chart saved to 'corrected_short_term_forecast.html'")
        
        fig.show()
        return fig

    # Additional function to create a model performance summary
    def plot_model_performance_summary(self, validation_result):
        """Create a comprehensive model performance dashboard"""
        
        if validation_result is None:
            print("No validation results available for performance analysis")
            return None
        
        val_actuals = validation_result['actuals']
        val_predictions = validation_result['predictions']
        dates = pd.to_datetime(validation_result['dates'])
        
        # Calculate metrics
        errors = val_actuals - val_predictions
        error_pct = (errors / val_actuals) * 100
        abs_error_pct = np.abs(error_pct)
        
        # Direction accuracy
        actual_directions = np.diff(val_actuals) > 0  # True if price went up
        pred_directions = val_predictions[1:] > val_actuals[:-1]  # True if predicted higher than previous actual
        direction_matches = actual_directions == pred_directions
        
        fig = make_subplots(
            rows=2, cols=2,
            subplot_titles=[
                'Prediction vs Actual Scatter',
                'Error Distribution',
                'Direction Accuracy Over Time',
                'Model Performance Metrics'
            ],
            specs=[[{"type": "scatter"}, {"type": "histogram"}],
                [{"type": "scatter"}, {"type": "table"}]]
        )
        
        # 1. Prediction vs Actual scatter plot
        fig.add_trace(
            go.Scatter(
                x=val_actuals,
                y=val_predictions,
                mode='markers',
                name='Predictions',
                marker=dict(size=8, color='blue', opacity=0.6)
            ),
            row=1, col=1
        )
        
        # Perfect prediction line
        min_val, max_val = min(val_actuals.min(), val_predictions.min()), max(val_actuals.max(), val_predictions.max())
        fig.add_trace(
            go.Scatter(
                x=[min_val, max_val],
                y=[min_val, max_val],
                mode='lines',
                name='Perfect Prediction',
                line=dict(color='red', dash='dash')
            ),
            row=1, col=1
        )
        
        # 2. Error distribution histogram
        fig.add_trace(
            go.Histogram(
                x=error_pct,
                nbinsx=15,
                name='Error Distribution',
                marker_color='lightblue'
            ),
            row=1, col=2
        )
        
        # 3. Direction accuracy over time
        fig.add_trace(
            go.Scatter(
                x=dates[1:],  # Skip first date since we need differences
                y=direction_matches.astype(int),
                mode='markers+lines',
                name='Direction Correct',
                marker=dict(
                    size=8,
                    color=['green' if x else 'red' for x in direction_matches]
                )
            ),
            row=2, col=1
        )
        
        # 4. Performance metrics table
        metrics_data = [
            ['MAPE', f"{validation_result['mape']:.2f}%"],
            ['RMSE', f"{validation_result['rmse']:.4f}"],
            ['MAE', f"{validation_result['mae']:.4f}"],
            ['Direction Accuracy', f"{validation_result['direction_accuracy']:.1f}%"],
            ['Avg Error', f"{np.mean(error_pct):.2f}%"],
            ['Std Error', f"{np.std(error_pct):.2f}%"],
            ['Max Error', f"{np.max(abs_error_pct):.2f}%"],
            ['Errors < 0.5%', f"{np.sum(abs_error_pct < 0.5)/len(abs_error_pct)*100:.1f}%"],
            ['Errors < 1.0%', f"{np.sum(abs_error_pct < 1.0)/len(abs_error_pct)*100:.1f}%"]
        ]
        
        fig.add_trace(
            go.Table(
                header=dict(values=['Metric', 'Value']),
                cells=dict(values=[[row[0] for row in metrics_data], 
                                [row[1] for row in metrics_data]])
            ),
            row=2, col=2
        )
        
        fig.update_layout(
            title='Model Performance Analysis Dashboard',
            height=700,
            showlegend=True
        )
        
        fig.update_xaxes(title_text="Actual Price", row=1, col=1)
        fig.update_yaxes(title_text="Predicted Price", row=1, col=1)
        fig.update_xaxes(title_text="Error %", row=1, col=2)
        fig.update_yaxes(title_text="Frequency", row=1, col=2)
        fig.update_xaxes(title_text="Date", row=2, col=1)
        fig.update_yaxes(title_text="Correct (1) / Wrong (0)", row=2, col=1)
        
        fig.write_html("model_performance_dashboard.html")
        print("Performance dashboard saved to 'model_performance_dashboard.html'")
        
        fig.show()
        return fig

    # Update the main class to include these corrected methods
    def add_corrected_plotting_methods(predictor_instance):
        """Add corrected plotting methods to existing predictor instance"""
        import types
        
        predictor_instance.plot_short_term_forecast = types.MethodType(plot_short_term_forecast, predictor_instance)
        predictor_instance.plot_model_performance_summary = types.MethodType(plot_model_performance_summary, predictor_instance)
        
        return predictor_instance

# Usage example:
# predictor = add_corrected_plotting_methods(predictor)
# predictor.plot_short_term_forecast(forecast, validation)
# predictor.plot_model_performance_summary(validation)

    


In [10]:


# Main execution function for short-term prediction
def run_short_term_forecast(csv_file="Quote-CD-USDINR-15-09-2024-to-15-09-2025.csv"):
    """Run short-term (next day) forecasting pipeline"""
    
    # Initialize predictor
    predictor = ShortTermFinancialPredictor(
        model_name="amazon/chronos-bolt-base",
        device="cuda"
    )
    
    # Load and process data
    predictor.load_data(csv_file)
    predictor.engineer_short_term_features()
    
    # Run validation first to check accuracy
    validation_results = predictor.rolling_validation(
        validation_days=10,
        context_days=20
    )
    
    # Generate next day forecast
    forecast_result = predictor.predict_next_day(context_days=7)
    
    # Plot results
    fig = predictor.plot_short_term_forecast(forecast_result, validation_results)
    
    # Print detailed results
    current_price = forecast_result['current_price']
    predicted_price = forecast_result['mean_forecast']
    change_pct = (predicted_price / current_price - 1) * 100
    
    print("\n" + "="*60)
    print("SHORT-TERM FORECAST SUMMARY (NEXT DAY)")
    print("="*60)
    print(f"Current Price: ₹{current_price:.4f}")
    print(f"Predicted Next Day Price: ₹{predicted_price:.4f}")
    print(f"Expected Change: {change_pct:.2f}%")
    print(f"Confidence Range: ₹{forecast_result['quantiles']['q10']:.4f} - ₹{forecast_result['quantiles']['q90']:.4f}")
    
    # Trading signal
    if abs(change_pct) > 0.5:
        signal = "BUY 📈" if change_pct > 0 else "SELL 📉"
        strength = "STRONG" if abs(change_pct) > 1.0 else "WEAK"
    else:
        signal = "HOLD ➖"
        strength = ""
    
    print(f"Trading Signal: {strength} {signal}")
    print(f"Model Validation Accuracy: {validation_results['direction_accuracy']:.1f}%")
    
    return predictor, forecast_result, validation_results


In [11]:

# Run the short-term forecast
predictor, forecast, validation = run_short_term_forecast(
    csv_file="Data/Quote-CD-USDINR-15-09-2024-to-15-09-2025.csv"
)


Loading financial data for short-term prediction...
Data loaded: 92 records from 2024-09-30 00:00:00 to 2025-09-12 00:00:00
Engineering short-term features...
Created 17 short-term features
Running rolling validation for last 10 days...
Loading amazon/chronos-bolt-base model...
Model loaded successfully!

Validation Results (10 predictions):
MAPE: 0.53%
RMSE: 0.4996
MAE: 0.4679
Direction Accuracy: 55.6%
Generating next-day forecast using last 7 days...
Corrected forecast chart saved to 'corrected_short_term_forecast.html'



SHORT-TERM FORECAST SUMMARY (NEXT DAY)
Current Price: ₹88.3200
Predicted Next Day Price: ₹88.1672
Expected Change: -0.17%
Confidence Range: ₹88.1672 - ₹88.1672
Trading Signal:  HOLD ➖
Model Validation Accuracy: 55.6%
