In [None]:
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from types import SimpleNamespace

# Project-local inference helpers. `plot_last_outputs` was previously listed twice in
# this import; a multi-series plotting helper (`plot_predictions_multi`) is still a
# work in progress and is therefore not imported here.
from tools.inference_utils import FinCast_Inference, plot_last_outputs

# Show the working directory, since the relative paths used below are resolved against it.
print(os.getcwd())

In [None]:
# Paths to set up
#
# actual_data_path -> CSV with the full series; the overlay cell at the end of the
#                     notebook reads the "future" rows from it
# data_path        -> CSV with the truncated series that is fed to the model as context
# model_path       -> the .pth weights file for the model
# save_output_path -> directory where numeric outputs are saved

actual_data_path = '../data/my_data/clean_data/SIL_cleaned_1S.csv'

data_path = '../data/my_data/split_data/SIL_cleaned_1S.csv'

# The weights location is machine-specific. Set the FINCAST_MODEL_PATH environment
# variable to override it without editing the notebook; the original location is kept
# as the default so existing setups keep working unchanged.
model_path = os.environ.get(
    "FINCAST_MODEL_PATH",
    "/Users/michaelharoon/Projects/tasty/FinCast-fts/weights/FinCast/v1.pth",
)

save_output_path = "../notebooks/fincast_outputs"



In [None]:
# Load both CSVs the same way: the first two columns become a two-level index.
# `df` comes from `data_path`, `df_actual` from `actual_data_path`.
df, df_actual = (
    pd.read_csv(csv_path, index_col=[0, 1])
    for csv_path in (data_path, actual_data_path)
)


In [None]:
## set up

# All inference settings live on one namespace object that is handed to FinCast_Inference.
config = SimpleNamespace(
    # ---- device ----
    backend="cpu",                  # "cpu" for CPU only, "gpu" for CUDA GPU

    # ---- model ----
    model_path=model_path,
    model_version="v1",             # only v1 for now; v1 is the 1b model in the CIKM 2025 paper,
                                    # a better-performing, smaller v2 is announced for later

    # ---- data ----
    data_path=data_path,            # used when the input is a CSV file
    df=df,                          # used when the input is a pandas DataFrame
                                    # (which of the two is read is chosen via `use_df` at inference time)

    data_frequency="1S",            # uses pandas conversion (original note: minutes is T or MIN)
                                    # Valid suffixes:
                                    # - Minutes: "MIN"
                                    # - Hours: "H"
                                    # - Seconds: "S"
                                    # - Days: "D"
                                    # - Business days: "B"
                                    # - Microseconds: "U"
                                    # - Weeks: "W"
                                    # - Months: "M"
                                    # - Month start: "MS"
                                    # - Years: "Y" or "A"
                                    # - Quarters: "Q"

    context_len=1024,               # forecast input length, from 32 to 1024
    horizon_len=256,                # forecast output length, from 1 to 256

    all_data=False,                 # False => only the last window (input = context length) is used;
                                    # True  => stride 1, every slice of the input is inferred

    columns_target=['close'],       # columns to forecast; either int (column index) or str (column name)

    series_norm=False,              # True to normalise each series, False to leave it as is

    batch_size=4,                   # go lower if you have less VRAM

    # ---- output ----
    forecast_mode="median",         # "mean" or "median"
    quantile_outputs=[],            # optional quantile outputs, from q1 to q9; leave empty for none

    save_output=True,               # save numeric outputs to CSV
    save_output_path=save_output_path,

    plt_outputs=True,               # plot all the last outputs
    plt_quantiles=[1, 3, 7, 9],     # quantiles to plot, ints from 1 to 9
)


In [None]:
# Auto-run cell: just execute it.
#
# use_df=True  -> the input is the pandas DataFrame set as config.df
# use_df=False -> the input is the CSV file set as config.data_path
fincast_inference = FinCast_Inference(config, use_df=True)

# Keep all three results; the plotting cells below consume them.
preds, mapping, full_outputs = fincast_inference.run_inference()



In [None]:
# Plot the last forecast outputs unless plotting has been switched off in the config.
# A config without a `plt_outputs` attribute is treated as "plotting enabled".
if getattr(config, "plt_outputs", True):
    plot_last_outputs(
        fincast_inference=fincast_inference,
        mean_all=preds,
        mapping_df=mapping,
        full_all=full_outputs,
        config=config,
    )

In [None]:
# Overlay forecast vs actuals for the last window
# If forecasts aren't computed yet, this cell runs inference for `config` itself.

# Assumptions:
# - `df` is the truncated data (used for forecast context)
# - `df_actual` is the full series (contains future after df)
# - `config` defines context_len (L) and horizon_len (H) and carries the truncated data
# - Model weights at `model_path`
# - Row k of `preds` corresponds to row k (by position) of `mapping` -- TODO confirm
#   against FinCast_Inference.run_inference


def _last_window_rows(mapping_df, n_pred_rows):
    """Return positional row indices of the last window of every series.

    When `mapping_df` has no usable "series_idx" / "window_end" columns, the last
    prediction row is returned as the only candidate.
    """
    usable = (
        mapping_df is not None
        and "series_idx" in mapping_df.columns
        and "window_end" in mapping_df.columns
    )
    if not usable:
        return np.array([n_pred_rows - 1], dtype=int)
    # Work on a 0..n-1 index so that idxmax yields *positions*; the rows are later used
    # with `.iloc` and to index the prediction array, both of which are positional.
    positional = mapping_df.reset_index(drop=True)
    # groupby sorts its keys, so the result is already ordered by series_idx.
    return positional.groupby("series_idx")["window_end"].idxmax().to_numpy()


def _position_of_last_context_row(context_df, actual_df):
    """Return the integer position in `actual_df` of the last row of `context_df`.

    Tries an exact index lookup first, then a numeric comparison of the indices, and
    finally assumes `context_df` is a prefix of `actual_df`.
    """
    last_label = context_df.index[-1]

    # Method 1: direct index lookup (works if the indices are identical).
    try:
        loc = actual_df.index.get_loc(last_label)
    except (KeyError, TypeError):
        loc = None
    # On a non-unique index get_loc returns a slice or a boolean mask instead of an
    # int; reduce both to the last matching position.
    if isinstance(loc, slice):
        candidates = np.arange(len(actual_df))[loc]
        loc = candidates[-1] if len(candidates) else None
    elif isinstance(loc, np.ndarray):
        candidates = np.flatnonzero(loc)
        loc = candidates[-1] if len(candidates) else None
    if loc is not None:
        return int(loc)

    # Method 2: compare the indices numerically and take the last occurrence.
    try:
        context_numeric = np.asarray(pd.to_numeric(context_df.index, errors="coerce"), dtype=float)
        actual_numeric = np.asarray(pd.to_numeric(actual_df.index, errors="coerce"), dtype=float)
        last_numeric = context_numeric[-1]
        if not np.isnan(last_numeric):
            matches = np.flatnonzero(actual_numeric == last_numeric)
            if len(matches) > 0:
                return int(matches[-1])
    except Exception as exc:  # best effort only; the length-based fallback follows
        print(f"Warning: numeric index alignment failed: {exc}")

    # Fallback: assume context_df is a prefix of actual_df.
    print("Warning: could not align indices; assuming `df` is a prefix of `df_actual`.")
    return len(context_df) - 1


L = int(config.context_len)
H = int(config.horizon_len)

# Ensure we have predictions for the truncated data (checked without touching `_`).
need_run = not all(name in globals() for name in ("preds", "mapping"))

if need_run:
    # Same construction as the main inference cell, so both paths use config.df.
    fincast_inference = FinCast_Inference(config, use_df=True)
    preds, mapping, full_outputs = fincast_inference.run_inference()

# Slice predictions to the last H (in case model returned on-context forecasts)
mean_all = preds
if mean_all.shape[1] != H:
    mean_all = mean_all[:, -H:]

# Determine row positions for the last window per series if mapping is available
pick_idx = _last_window_rows(mapping, len(mean_all))
if len(pick_idx) == 0:
    raise ValueError("No forecast windows found in `mapping`; nothing to plot.")

# Get the target columns from config (fallback to Close if not specified)
target_columns = getattr(config, "columns_target", ["Close"])
if not isinstance(target_columns, list):
    target_columns = [target_columns]

# Label the forecast line with the aggregation that was actually configured.
forecast_label = getattr(config, "forecast_mode", "mean")

# Prepare index alignment for actuals - find where df ends in df_actual
last_idx = df.index[-1]
actual_start_pos = _position_of_last_context_row(df, df_actual)

# Numeric columns are only needed when a series name does not match a column of df.
numeric_df = df.select_dtypes(include=[np.number])

# Plot overlay for each selected series (one subplot per series)
n_series = len(pick_idx)
fig, axes = plt.subplots(n_series, 1, figsize=(12, 5 * n_series), sharex=True, squeeze=False)
axes = axes.ravel()  # squeeze=False always yields a 2-D array, even for a single subplot

for plot_idx, row_pos in enumerate(pick_idx):
    row_pos = int(row_pos)
    ax = axes[plot_idx]

    # Get series name from mapping
    if mapping is not None and row_pos < len(mapping):
        series_name = mapping.iloc[row_pos].get("series_name", f"series_{row_pos}")
    else:
        series_name = f"series_{row_pos}"

    # Get the forecast for this series
    y_forecast = mean_all[row_pos, :]

    # Get context from df for this specific column
    if series_name in df.columns:
        ctx_vals = df[series_name].to_numpy()[-L:]
    else:
        # Fallback: pick a numeric column by position, wrapping over the numeric columns only.
        if numeric_df.shape[1] == 0:
            raise ValueError(f"`df` has no numeric column to use as context for {series_name}")
        ctx_vals = numeric_df.iloc[:, plot_idx % numeric_df.shape[1]].to_numpy()[-L:]

    # Get actual future values for this column, starting right after the last context row
    actual_future = None
    if series_name in df_actual.columns:
        try:
            future_start = actual_start_pos + 1
            if future_start < len(df_actual):
                actual_future = df_actual[series_name].iloc[future_start:future_start + H].to_numpy()
                # Fewer actuals than forecast steps: only show the part that can be compared.
                if len(actual_future) < len(y_forecast):
                    y_forecast = y_forecast[:len(actual_future)]
        except Exception as e:
            # Do not substitute the tail of df_actual here: it is not aligned with the
            # forecast window and would be plotted as if it were the true future.
            print(f"Warning: Could not extract actuals for {series_name}: {e}")
            actual_future = None

    # Create x-axis indices from the data actually available (df may hold fewer than L rows)
    n_ctx = len(ctx_vals)
    x_ctx = np.arange(n_ctx)
    x_fut = np.arange(n_ctx, n_ctx + len(y_forecast))

    # Plot context
    ax.plot(x_ctx, ctx_vals, label="context", color="tab:blue", linewidth=1.8)

    # Plot forecast
    ax.plot(x_fut, y_forecast, label=f"forecast ({forecast_label})", color="tab:red", linewidth=2.2)

    # Plot actual if available
    if actual_future is not None and len(actual_future) > 0:
        k = min(len(actual_future), len(y_forecast))
        if k > 0:
            ax.plot(x_fut[:k], actual_future[:k], label="actual", color="tab:green", linestyle="--", linewidth=2)

            # Mark the context/forecast boundary when the first actual value jumps by
            # more than 1% relative to the last context value.
            if n_ctx > 0:
                ctx_end = ctx_vals[-1]
                actual_start = actual_future[0]
                if abs(actual_start - ctx_end) > abs(ctx_end) * 0.01:
                    ax.axvline(x=n_ctx, color="gray", linestyle=":", alpha=0.5, linewidth=1)

    ax.set_title(f"{series_name} - Forecast vs Actual (L={L}, H={H})")
    if plot_idx == n_series - 1:
        ax.set_xlabel("Time (relative index)")
    ax.set_ylabel("Value")
    ax.grid(True, alpha=0.3)
    ax.legend(loc="best")

plt.tight_layout()
plt.show()
