<a href="https://colab.research.google.com/github/wrymp/Final-Project-Walmart-Recruiting---Store-Sales-Forecasting/blob/main/model_experiment_TFT.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Temporal Fusion Transformer Implementation for Walmart Sales Forecasting

This notebook implements Temporal Fusion Transformer (TFT) for Walmart sales forecasting following the exact pipeline structure from N-BEATS experiments.

# Setup & Data Download

In [1]:
from google.colab import drive
!pip install wandb -q
!pip install kaggle -q

drive.mount('/content/drive')
!mkdir -p ~/.kaggle
!cp /content/drive/MyDrive/kaggle.json ~/.kaggle/
!chmod 600 ~/.kaggle/kaggle.json

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [2]:
# Uncomment to download data if needed
# !kaggle competitions download -c walmart-recruiting-store-sales-forecasting
# !unzip /content/walmart-recruiting-store-sales-forecasting.zip
# !unzip /content/train.csv.zip
# !unzip /content/test.csv.zip
# !unzip /content/features.csv.zip
# !unzip /content/sampleSubmission.csv.zip

# Setup & Imports

In [2]:
import pandas as pd
import numpy as np
import wandb
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
import math
import warnings
from datetime import datetime, timedelta
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.preprocessing import StandardScaler, LabelEncoder
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.pipeline import Pipeline
import gc
import os
import pickle
import cloudpickle

warnings.filterwarnings('ignore')
plt.style.use('seaborn-v0_8')

# Set random seeds for reproducibility
np.random.seed(42)
torch.manual_seed(42)
if torch.cuda.is_available():
    torch.cuda.manual_seed(42)

print("Libraries imported successfully!")
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

Libraries imported successfully!
PyTorch version: 2.6.0+cu124
CUDA available: True
Using device: cuda


# Wandb Initialization

In [3]:
# Initialize Wandb project
wandb.login()
try:
    wandb.init(
        project="walmart-sales-forecasting",
        name="TFT_Initial_Setup",
        config={
            "model_type": "TFT",
            "framework": "PyTorch",
            "device": str(device),
            "random_seed": 42
        }
    )
    print("✓ Wandb initialized successfully!")
except Exception as e:
    print(f"⚠️ Wandb initialization failed: {e}")
    print("Continuing without wandb logging...")

[34m[1mwandb[0m: Currently logged in as: [33mqitiashvili13[0m ([33mdshan21-free-university-of-tbilisi-[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


✓ Wandb initialized successfully!


# Data Loading

In [4]:
# Load Walmart datasets
print("Loading Walmart datasets...")

try:
    train_df = pd.read_csv('/content/drive/MyDrive/walmart-recruiting-store-sales-forecasting/train.csv/train.csv')
    test_df = pd.read_csv('/content/drive/MyDrive/walmart-recruiting-store-sales-forecasting/test.csv/test.csv')
    stores_df = pd.read_csv('/content/drive/MyDrive/walmart-recruiting-store-sales-forecasting/stores.csv')
    features_df = pd.read_csv('/content/drive/MyDrive/walmart-recruiting-store-sales-forecasting/features.csv/features.csv')

    print(f"✓ Train data shape: {train_df.shape}")
    print(f"✓ Test data shape: {test_df.shape}")
    print(f"✓ Stores data shape: {stores_df.shape}")
    print(f"✓ Features data shape: {features_df.shape}")

    # Log basic dataset info
    wandb.log({
        "train_samples": len(train_df),
        "test_samples": len(test_df),
        "num_stores": stores_df['Store'].nunique(),
        "num_departments": train_df['Dept'].nunique(),
        "date_range_train": f"{train_df['Date'].min()} to {train_df['Date'].max()}"
    })

except FileNotFoundError as e:
    print(f"❌ Error loading data: {e}")
    print("Please ensure data files are in the correct directory")
    raise

Loading Walmart datasets...
✓ Train data shape: (421570, 5)
✓ Test data shape: (115064, 4)
✓ Stores data shape: (45, 3)
✓ Features data shape: (8190, 12)


# Data Exploration Run

In [5]:
# Start new wandb run for exploration
wandb.finish()
wandb.init(
    project="walmart-sales-forecasting",
    name="TFT_Exploration",
    config={"stage": "exploration"}
)

print("\n=== DATA EXPLORATION ===")

# Convert date columns
train_df['Date'] = pd.to_datetime(train_df['Date'])
test_df['Date'] = pd.to_datetime(test_df['Date'])
features_df['Date'] = pd.to_datetime(features_df['Date'])

# Basic statistics
print("\nTrain Data Info:")
print(f"Date range: {train_df['Date'].min()} to {train_df['Date'].max()}")
print(f"Unique stores: {train_df['Store'].nunique()}")
print(f"Unique departments: {train_df['Dept'].nunique()}")
print(f"Total store-dept combinations: {train_df.groupby(['Store', 'Dept']).ngroups}")

# Sales statistics
print(f"\nSales Statistics:")
print(f"Mean weekly sales: ${train_df['Weekly_Sales'].mean():,.2f}")
print(f"Median weekly sales: ${train_df['Weekly_Sales'].median():,.2f}")
print(f"Min weekly sales: ${train_df['Weekly_Sales'].min():,.2f}")
print(f"Max weekly sales: ${train_df['Weekly_Sales'].max():,.2f}")

# Holiday impact
holiday_sales = train_df.groupby('IsHoliday')['Weekly_Sales'].agg(['mean', 'count'])
print(f"\nHoliday Impact:")
print(holiday_sales)

# Store types
store_types = stores_df['Type'].value_counts()
print(f"\nStore Types:")
print(store_types)

# Missing values in features
print(f"\nMissing Values in Features:")
missing_pct = (features_df.isnull().sum() / len(features_df)) * 100
print(missing_pct[missing_pct > 0].sort_values(ascending=False))

# TFT-specific analysis
print(f"\nTFT-Specific Analysis:")
print(f"Time series length per store-dept: {train_df.groupby(['Store', 'Dept']).size().describe()}")

# Log exploration metrics
wandb.log({
    "unique_stores": train_df['Store'].nunique(),
    "unique_departments": train_df['Dept'].nunique(),
    "total_timeseries": train_df.groupby(['Store', 'Dept']).ngroups,
    "avg_weekly_sales": train_df['Weekly_Sales'].mean(),
    "median_weekly_sales": train_df['Weekly_Sales'].median(),
    "sales_std": train_df['Weekly_Sales'].std(),
    "holiday_sales_boost": holiday_sales.loc[True, 'mean'] / holiday_sales.loc[False, 'mean'],
    "missing_markdown1_pct": missing_pct['MarkDown1'],
    "missing_markdown2_pct": missing_pct['MarkDown2'],
    "missing_markdown3_pct": missing_pct['MarkDown3'],
    "missing_markdown4_pct": missing_pct['MarkDown4'],
    "missing_markdown5_pct": missing_pct['MarkDown5'],
    "avg_timeseries_length": train_df.groupby(['Store', 'Dept']).size().mean()
})

print("\n✓ Exploration completed and logged to wandb")

0,1
num_departments,▁
num_stores,▁
test_samples,▁
train_samples,▁

0,1
date_range_train,2010-02-05 to 2012-1...
num_departments,81
num_stores,45
test_samples,115064
train_samples,421570



=== DATA EXPLORATION ===

Train Data Info:
Date range: 2010-02-05 00:00:00 to 2012-10-26 00:00:00
Unique stores: 45
Unique departments: 81
Total store-dept combinations: 3331

Sales Statistics:
Mean weekly sales: $15,981.26
Median weekly sales: $7,612.03
Min weekly sales: $-4,988.94
Max weekly sales: $693,099.36

Holiday Impact:
                   mean   count
IsHoliday                      
False      15901.445069  391909
True       17035.823187   29661

Store Types:
Type
A    22
B    17
C     6
Name: count, dtype: int64

Missing Values in Features:
MarkDown2       64.334554
MarkDown4       57.704518
MarkDown3       55.885226
MarkDown1       50.769231
MarkDown5       50.549451
CPI              7.142857
Unemployment     7.142857
dtype: float64

TFT-Specific Analysis:
Time series length per store-dept: count    3331.000000
mean      126.559592
std        40.212763
min         1.000000
25%       143.000000
50%       143.000000
75%       143.000000
max       143.000000
dtype: float64

✓ 

# Custom Transformers for TFT Pipeline

In [6]:
class TFTDataProcessor(BaseEstimator, TransformerMixin):
    """Processes raw Walmart data into TFT format with static and time-varying features"""

    def __init__(self, lookback_window=52, forecast_horizon=1):
        self.lookback_window = lookback_window
        self.forecast_horizon = forecast_horizon
        self.store_dept_combinations = None
        self.date_range = None
        self.scalers = {}

    def fit(self, X, y=None):
        """Learn the store-department combinations, date range, and scalers"""
        self.store_dept_combinations = X.groupby(['Store', 'Dept']).size().index.tolist()
        self.date_range = sorted(X['Date'].unique())

        # Fit scalers for numerical features
        numerical_cols = ['Weekly_Sales', 'Temperature', 'Fuel_Price', 'CPI', 'Unemployment'] + \
                        [col for col in X.columns if 'MarkDown' in col]

        for col in numerical_cols:
            if col in X.columns:
                scaler = StandardScaler()
                valid_data = X[col].dropna()
                if len(valid_data) > 0:
                    scaler.fit(valid_data.values.reshape(-1, 1))
                    self.scalers[col] = scaler

        print(f"Found {len(self.store_dept_combinations)} store-dept combinations")
        print(f"Date range: {self.date_range[0]} to {self.date_range[-1]}")
        print(f"Fitted scalers for: {list(self.scalers.keys())}")
        return self

    def transform(self, X):
        """Transform data into TFT format with static and time-varying features"""
        sequences = []
        targets = []
        static_features = []
        metadata = []

        for store, dept in self.store_dept_combinations:
            # Get time series for this store-dept combination
            series_data = X[(X['Store'] == store) & (X['Dept'] == dept)].copy()
            series_data = series_data.sort_values('Date')

            if len(series_data) < self.lookback_window + self.forecast_horizon:
                continue

            # Static features (constant for each store-dept combination)
            static_feat = []
            if 'Type' in series_data.columns:
                # Encode store type: A=0, B=1, C=2
                type_map = {'A': 0, 'B': 1, 'C': 2}
                static_feat.append(type_map.get(series_data['Type'].iloc[0], 0))
            if 'Size' in series_data.columns:
                # Normalize store size
                size_val = series_data['Size'].iloc[0]
                static_feat.append(size_val / 200000.0)  # Rough normalization

            # Store and dept as categorical features
            static_feat.extend([store / 45.0, dept / 100.0])  # Normalize to [0,1]

            # Create sliding windows
            for i in range(len(series_data) - self.lookback_window - self.forecast_horizon + 1):
                window_data = series_data.iloc[i:i + self.lookback_window]

                # Time-varying features
                time_varying_feat = []

                # Sales (target variable)
                sales_sequence = window_data['Weekly_Sales'].values
                if np.any(np.isnan(sales_sequence)) or np.any(np.isinf(sales_sequence)):
                    continue

                # Scale sales
                if 'Weekly_Sales' in self.scalers:
                    sales_scaled = self.scalers['Weekly_Sales'].transform(sales_sequence.reshape(-1, 1)).flatten()
                else:
                    sales_scaled = sales_sequence

                time_varying_feat.append(sales_scaled)

                # External time-varying features
                feature_names = ['Temperature', 'Fuel_Price', 'CPI', 'Unemployment',
                               'MarkDown1', 'MarkDown2', 'MarkDown3', 'MarkDown4', 'MarkDown5']

                for feat_name in feature_names:
                    if feat_name in window_data.columns:
                        feat_vals = window_data[feat_name].fillna(method='ffill').fillna(method='bfill').fillna(0).values

                        # Scale feature if scaler exists
                        if feat_name in self.scalers:
                            feat_vals = self.scalers[feat_name].transform(feat_vals.reshape(-1, 1)).flatten()

                        time_varying_feat.append(feat_vals)

                # Holiday indicator
                if 'IsHoliday_x' in window_data.columns:
                    holiday_vals = window_data['IsHoliday_x'].astype(float).values
                    time_varying_feat.append(holiday_vals)
                elif 'IsHoliday' in window_data.columns:
                    holiday_vals = window_data['IsHoliday'].astype(float).values
                    time_varying_feat.append(holiday_vals)

                # Date features (cyclical encoding)
                dates = pd.to_datetime(window_data['Date'])
                week_of_year = dates.dt.isocalendar().week.values
                month = dates.dt.month.values

                # Cyclical encoding
                week_sin = np.sin(2 * np.pi * week_of_year / 52)
                week_cos = np.cos(2 * np.pi * week_of_year / 52)
                month_sin = np.sin(2 * np.pi * month / 12)
                month_cos = np.cos(2 * np.pi * month / 12)

                time_varying_feat.extend([week_sin, week_cos, month_sin, month_cos])

                # Stack time-varying features
                try:
                    feature_matrix = np.column_stack(time_varying_feat)
                except ValueError:
                    continue

                sequences.append(feature_matrix)
                static_features.append(np.array(static_feat))

                # Target (next forecast_horizon values)
                target_data = series_data.iloc[i + self.lookback_window:i + self.lookback_window + self.forecast_horizon]
                target_sales = target_data['Weekly_Sales'].values

                if np.any(np.isnan(target_sales)) or np.any(np.isinf(target_sales)):
                    sequences.pop()  # Remove the last added sequence
                    static_features.pop()  # Remove the last added static feature
                    continue

                # Scale target
                if 'Weekly_Sales' in self.scalers:
                    target_scaled = self.scalers['Weekly_Sales'].transform(target_sales.reshape(-1, 1)).flatten()
                else:
                    target_scaled = target_sales

                targets.append(target_scaled)

                # Metadata
                metadata.append({
                    'store': store,
                    'dept': dept,
                    'start_date': window_data['Date'].iloc[0],
                    'end_date': window_data['Date'].iloc[-1],
                    'forecast_date': target_data['Date'].iloc[0] if len(target_data) > 0 else None
                })

        print(f"Generated {len(sequences)} valid sequences from {len(self.store_dept_combinations)} store-dept combinations")

        if len(sequences) > 0:
            print(f"Time-varying features shape: {sequences[0].shape}")
            print(f"Static features shape: {static_features[0].shape}")

        return {
            'sequences': np.array(sequences, dtype=object),
            'targets': np.array(targets, dtype=object),
            'static_features': np.array(static_features, dtype=object),
            'metadata': metadata
        }

class FeatureMerger(BaseEstimator, TransformerMixin):
    """Merges train/test data with stores and features data"""

    def __init__(self):
        self.stores_data = None
        self.features_data = None

    def fit(self, X, y=None, stores_df=None, features_df=None):
        """Store the auxiliary dataframes"""
        self.stores_data = stores_df.copy() if stores_df is not None else None
        self.features_data = features_df.copy() if features_df is not None else None
        return self

    def transform(self, X):
        """Merge main data with stores and features"""
        result = X.copy()

        # Merge with stores data
        if self.stores_data is not None:
            result = result.merge(self.stores_data, on='Store', how='left')

        # Merge with features data
        if self.features_data is not None:
            result = result.merge(self.features_data, on=['Store', 'Date'], how='left')

        return result

class MissingValueHandler(BaseEstimator, TransformerMixin):
    """Handle missing values in time-series data"""

    def __init__(self):
        self.fill_values = {}

    def fit(self, X, y=None):
        """Learn fill values for missing data"""
        # For MarkDown columns, fill with 0 (no markdown)
        markdown_cols = [col for col in X.columns if 'MarkDown' in col]
        for col in markdown_cols:
            self.fill_values[col] = 0.0

        # For other numerical columns, use median
        numerical_cols = X.select_dtypes(include=[np.number]).columns
        for col in numerical_cols:
            if col not in self.fill_values and X[col].isnull().any():
                self.fill_values[col] = X[col].median()

        return self

    def transform(self, X):
        """Fill missing values"""
        result = X.copy()
        for col, fill_value in self.fill_values.items():
            if col in result.columns:
                result[col] = result[col].fillna(fill_value)
        return result

print("✓ Custom transformers defined")

✓ Custom transformers defined


# Temporal Fusion Transformer Model Implementation

In [7]:
class VariableSelectionNetwork(nn.Module):
    """Variable selection network for TFT"""

    def __init__(self, input_dim, num_inputs, hidden_dim, dropout_rate=0.1):
        super().__init__()
        self.input_dim = input_dim
        self.num_inputs = num_inputs # Number of variables/features
        self.hidden_dim = hidden_dim

        # Flatten and process
        # The linear layer should expect input_dim * num_inputs, but here num_inputs refers to the number of variables,
        # and input_dim refers to the dimension of each variable at a given time step (usually 1).
        # The flattening should be across the sequence length and the input_dim of each variable.
        # The input x has shape (batch_size, sequence_length, num_variables, input_dim) conceptually,
        # but is passed as (batch_size, sequence_length, num_variables * input_dim) or simpler (batch_size, sequence_length, num_features)
        # where num_features = num_variables * input_dim.
        # The current implementation of TFTDataProcessor provides (batch_size, sequence_length, num_time_features)
        # So num_features = num_time_features, and input_dim is effectively num_features/num_inputs
        # Let's assume input_dim=1 as per standard TFT, and num_inputs = num_time_features
        # The flattening should be (batch_size, sequence_length * num_time_features)

        # Corrected linear layer input dimension
        # This will be set dynamically based on the actual input shape during the forward pass
        self.flattened_linear1 = nn.Linear(1, hidden_dim) # Placeholder, will be replaced
        self.flattened_linear2 = nn.Linear(hidden_dim, num_inputs) # num_inputs is num_time_features here

        self.flattened_grn = nn.Sequential(
            nn.ReLU(),
            nn.Dropout(dropout_rate),
            self.flattened_linear2,
            nn.Softmax(dim=-1)
        )


        # Individual processing for each variable
        # Each GRN processes the sequence for a single variable (input_dim=1)
        self.single_variable_grns = nn.ModuleList([
            nn.Sequential(
                nn.Linear(self.input_dim, hidden_dim), # Should be input_dim
                nn.ReLU(),
                nn.Dropout(dropout_rate),
                nn.Linear(hidden_dim, hidden_dim)
            ) for _ in range(num_inputs) # num_inputs is num_time_features
        ])

    def forward(self, x):
        # x shape: (batch_size, sequence_length, num_features)
        # where num_features is effectively num_time_features in this context
        batch_size, seq_len, num_features = x.shape
        num_variables = self.num_inputs # num_inputs is num_time_features

        # Ensure num_features matches expected num_variables * input_dim
        # In this case, input_dim is 1, so num_features should be num_variables
        if num_features != num_variables * self.input_dim:
             # This could happen if input_dim is not 1 or num_features is incorrect
             # Given the TFTDataProcessor, num_features should be num_time_features,
             # and input_dim is implicitly 1 per feature.
             # Let's adjust the linear layer size dynamically or raise error
             # For now, assume input_dim=1 and num_features = num_variables
             if num_features != num_variables:
                 raise ValueError(f"Input features {num_features} does not match expected num_variables ({num_variables}) * input_dim ({self.input_dim})")
             # If num_features == num_variables, it means input_dim is 1


        # Dynamically set the input size of the first flattened linear layer
        if self.flattened_linear1.in_features != seq_len * num_features:
             self.flattened_linear1 = nn.Linear(seq_len * num_features, self.hidden_dim).to(x.device)
             self.flattened_grn = nn.Sequential(
                self.flattened_linear1,
                nn.ReLU(),
                nn.Dropout(self.flattened_grn[2].p), # Keep dropout rate
                self.flattened_linear2,
                nn.Softmax(dim=-1)
             ).to(x.device)


        # Flatten for variable selection across sequence length
        flattened = x.view(batch_size, -1) # (batch_size, sequence_length * num_features)

        # Apply the dynamically sized flattened_grn
        variable_weights = self.flattened_grn(flattened)  # (batch_size, num_variables)


        # Process each variable separately
        processed_variables = []
        # Iterate through each variable (column) in the time-varying input
        for i in range(num_variables):
            # Select the i-th variable across the sequence length
            var_input = x[:, :, i:i+self.input_dim] # (batch_size, seq_len, input_dim)
            # Apply the GRN for this variable across the sequence length
            processed = self.single_variable_grns[i](var_input) # (batch_size, seq_len, hidden_dim)
            processed_variables.append(processed)

        if len(processed_variables) == 0:
            # Handle case with no variables
            return torch.zeros(batch_size, seq_len, self.hidden_dim, device=x.device)

        # Stack processed variables
        # stacked shape: (batch_size, seq_len, hidden_dim, num_variables)
        stacked = torch.stack(processed_variables, dim=-1)

        # Apply variable weights
        # weights_expanded shape: (batch_size, 1, 1, num_variables)
        weights_expanded = variable_weights.unsqueeze(1).unsqueeze(1)

        # Weighted sum over variables
        # output shape: (batch_size, seq_len, hidden_dim)
        output = (stacked * weights_expanded).sum(dim=-1)

        return output

class GatedResidualNetwork(nn.Module):
    """Gated Residual Network component"""

    def __init__(self, input_dim, hidden_dim, dropout_rate=0.1):
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim

        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)

        self.gate = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.Sigmoid()
        )

        self.dropout = nn.Dropout(dropout_rate)
        self.layer_norm = nn.LayerNorm(hidden_dim)

        # Skip connection projection if dimensions don't match
        if input_dim != hidden_dim:
            self.skip_projection = nn.Linear(input_dim, hidden_dim)
        else:
            self.skip_projection = None

    def forward(self, x):
        # Primary path
        # Apply fc1 along the last dimension
        y = F.relu(self.fc1(x))
        y = self.dropout(y)
        y = self.fc2(y)

        # Gating
        gate = self.gate(y)
        y = y * gate

        # Skip connection
        if self.skip_projection is not None:
            # Apply skip projection along the last dimension
            x = self.skip_projection(x)

        # Only add skip connection if dimensions match
        # This check is important if x has sequence length dimension
        if x.shape[-1] == y.shape[-1] and x.shape[:-1] == y.shape[:-1]: # Check all but last dim
            y = y + x

        y = self.layer_norm(y)
        return y

class TemporalFusionTransformer(nn.Module):
    """Temporal Fusion Transformer model"""

    def __init__(self, num_time_features, num_static_features, hidden_dim=128,
                 num_attention_heads=4, dropout_rate=0.1, forecast_horizon=1):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.num_attention_heads = num_attention_heads
        self.forecast_horizon = forecast_horizon
        self.num_time_features = num_time_features # Store this

        # Variable selection networks
        # Assuming input_dim=1 for each time-varying feature
        self.temporal_vsn = VariableSelectionNetwork(
            input_dim=1, # Dimension of each variable at a time step
            num_inputs=num_time_features, # Number of time-varying variables
            hidden_dim=hidden_dim,
            dropout_rate=dropout_rate
        )

        # Static feature processing
        if num_static_features > 0:
            self.static_grn = GatedResidualNetwork(
                input_dim=num_static_features, hidden_dim=hidden_dim, dropout_rate=dropout_rate
            )
        else:
            self.static_grn = None

        # LSTM encoder
        self.encoder_lstm = nn.LSTM(
            input_size=hidden_dim, hidden_size=hidden_dim,
            batch_first=True, dropout=dropout_rate if dropout_rate > 0 else 0 # Avoid dropout=0 in LSTM
        )

        # Multi-head attention
        self.multihead_attn = nn.MultiheadAttention(
            embed_dim=hidden_dim, num_heads=num_attention_heads,
            dropout=dropout_rate, batch_first=True
        )

        # Position-wise feed forward
        self.feed_forward = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim * 4),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
            nn.Linear(hidden_dim * 4, hidden_dim)
        )

        # Output layers
        self.output_grn = GatedResidualNetwork(
            input_dim=hidden_dim, hidden_dim=hidden_dim, dropout_rate=dropout_rate
        )

        self.output_projection = nn.Linear(hidden_dim, forecast_horizon)

        # Layer normalization
        self.layer_norm1 = nn.LayerNorm(hidden_dim)
        self.layer_norm2 = nn.LayerNorm(hidden_dim)

    def forward(self, time_varying_inputs, static_inputs=None):
        # time_varying_inputs: (batch_size, sequence_length, num_time_features)
        batch_size, seq_len, num_features = time_varying_inputs.shape

        # Variable selection for temporal features
        temporal_features = self.temporal_vsn(time_varying_inputs) # Output: (batch_size, seq_len, hidden_dim)

        # Process static features if available
        if static_inputs is not None and self.static_grn is not None:
            static_features = self.static_grn(static_inputs)  # (batch_size, hidden_dim)
            # Expand to sequence length and combine with temporal features
            static_features_expanded = static_features.unsqueeze(1).expand(-1, seq_len, -1)
            combined_features = temporal_features + static_features_expanded
        else:
            combined_features = temporal_features

        # LSTM encoding
        # combined_features shape: (batch_size, seq_len, hidden_dim)
        lstm_out, (hidden, cell) = self.encoder_lstm(combined_features) # lstm_out shape: (batch_size, seq_len, hidden_dim)

        # Multi-head attention
        # attn_out shape: (batch_size, seq_len, hidden_dim)
        attn_out, _ = self.multihead_attn(lstm_out, lstm_out, lstm_out)

        # Add & Norm (Residual connection)
        # Ensure shapes match for addition
        if attn_out.shape == lstm_out.shape:
             attn_out = self.layer_norm1(attn_out + lstm_out)
        else:
             # If shapes don't match, just normalize attn_out (shouldn't happen if MHA output dim is hidden_dim)
             attn_out = self.layer_norm1(attn_out)


        # Feed forward
        # ff_out shape: (batch_size, seq_len, hidden_dim)
        ff_out = self.feed_forward(attn_out)

        # Add & Norm (Residual connection)
        # Ensure shapes match for addition
        if ff_out.shape == attn_out.shape:
             ff_out = self.layer_norm2(ff_out + attn_out)
        else:
            # If shapes don't match, just normalize ff_out (shouldn't happen if FF output dim is hidden_dim)
            ff_out = self.layer_norm2(ff_out)


        # Use the last time step for prediction
        last_output = ff_out[:, -1, :]  # (batch_size, hidden_dim)

        # Final processing
        processed_output = self.output_grn(last_output) # (batch_size, hidden_dim)

        # Output projection
        predictions = self.output_projection(processed_output) # (batch_size, forecast_horizon)

        return predictions

class WalmartTFTDataset(Dataset):
    """PyTorch dataset for Walmart TFT data"""

    def __init__(self, sequences, targets, static_features):
        self.sequences = sequences
        self.targets = targets
        self.static_features = static_features

    def __len__(self):
        return len(self.sequences)

    def __getitem__(self, idx):
        sequence = torch.FloatTensor(self.sequences[idx])
        target = torch.FloatTensor(self.targets[idx])
        static = torch.FloatTensor(self.static_features[idx])
        return sequence, static, target

print("✓ TFT model architecture defined")

✓ TFT model architecture defined


# Data Cleaning Run

In [8]:
# Start new wandb run for data cleaning
wandb.finish()
wandb.init(
    project="walmart-sales-forecasting",
    name="TFT_Cleaning",
    config={"stage": "cleaning"}
)

print("\n=== DATA CLEANING ===")

# Create feature merger and missing value handler
feature_merger = FeatureMerger()
missing_handler = MissingValueHandler()

# Fit the merger with auxiliary data
feature_merger.fit(train_df, stores_df=stores_df, features_df=features_df)

# Merge train data with stores and features
print("Merging train data with stores and features...")
train_merged = feature_merger.transform(train_df)
print(f"Train data shape after merging: {train_merged.shape}")

# Fit and transform missing values
print("Handling missing values...")
missing_handler.fit(train_merged)
train_cleaned = missing_handler.transform(train_merged)

# Check for remaining missing values
remaining_missing = train_cleaned.isnull().sum()
remaining_missing = remaining_missing[remaining_missing > 0]

print(f"\nRemaining missing values after cleaning:")
if len(remaining_missing) > 0:
    print(remaining_missing)
else:
    print("No missing values remaining!")

# Basic data quality checks
print(f"\nData quality checks:")
print(f"Total records: {len(train_cleaned):,}")
print(f"Date range: {train_cleaned['Date'].min()} to {train_cleaned['Date'].max()}")
print(f"Unique store-dept combinations: {train_cleaned.groupby(['Store', 'Dept']).ngroups:,}")

# Check for negative sales (data quality issue)
negative_sales = (train_cleaned['Weekly_Sales'] < 0).sum()
print(f"Records with negative sales: {negative_sales:,} ({negative_sales/len(train_cleaned)*100:.2f}%)")

# Log cleaning metrics
wandb.log({
    "cleaned_records": len(train_cleaned),
    "remaining_missing_values": len(remaining_missing),
    "negative_sales_count": int(negative_sales),
    "negative_sales_pct": float(negative_sales/len(train_cleaned)*100),
    "store_dept_combinations": train_cleaned.groupby(['Store', 'Dept']).ngroups
})

print("\n✓ Data cleaning completed and logged to wandb")

# Save cleaned data for next steps
print("\nSample of cleaned data:")
print(train_cleaned.head())
print(f"\nColumns: {list(train_cleaned.columns)}")

0,1
avg_timeseries_length,▁
avg_weekly_sales,▁
holiday_sales_boost,▁
median_weekly_sales,▁
missing_markdown1_pct,▁
missing_markdown2_pct,▁
missing_markdown3_pct,▁
missing_markdown4_pct,▁
missing_markdown5_pct,▁
sales_std,▁

0,1
avg_timeseries_length,126.55959
avg_weekly_sales,15981.25812
holiday_sales_boost,1.07134
median_weekly_sales,7612.03
missing_markdown1_pct,50.76923
missing_markdown2_pct,64.33455
missing_markdown3_pct,55.88523
missing_markdown4_pct,57.70452
missing_markdown5_pct,50.54945
sales_std,22711.18352



=== DATA CLEANING ===
Merging train data with stores and features...
Train data shape after merging: (421570, 17)
Handling missing values...

Remaining missing values after cleaning:
No missing values remaining!

Data quality checks:
Total records: 421,570
Date range: 2010-02-05 00:00:00 to 2012-10-26 00:00:00
Unique store-dept combinations: 3,331
Records with negative sales: 1,285 (0.30%)

✓ Data cleaning completed and logged to wandb

Sample of cleaned data:
   Store  Dept       Date  Weekly_Sales  IsHoliday_x Type    Size  \
0      1     1 2010-02-05      24924.50        False    A  151315   
1      1     1 2010-02-12      46039.49         True    A  151315   
2      1     1 2010-02-19      41595.55        False    A  151315   
3      1     1 2010-02-26      19403.54        False    A  151315   
4      1     1 2010-03-05      21827.90        False    A  151315   

   Temperature  Fuel_Price  MarkDown1  MarkDown2  MarkDown3  MarkDown4  \
0        42.31       2.572        0.0        

# Feature Selection Run

In [9]:
# Start new wandb run for feature selection
wandb.finish()
wandb.init(
    project="walmart-sales-forecasting",
    name="TFT_Feature_Selection",
    config={"stage": "feature_selection"}
)

print("\n=== FEATURE SELECTION ===")

# For TFT, we categorize features into static and time-varying

# Core features
core_features = ['Store', 'Dept', 'Date', 'Weekly_Sales', 'IsHoliday_x']

# Static features (constant per store-dept combination)
static_features = ['Type', 'Size']

# Time-varying features
time_varying_features = ['Temperature', 'Fuel_Price', 'CPI', 'Unemployment',
                        'MarkDown1', 'MarkDown2', 'MarkDown3', 'MarkDown4', 'MarkDown5']

print(f"Available columns: {list(train_cleaned.columns)}")

# Analyze correlation between features and sales
correlation_analysis = {}

# Time-varying features correlation
for feature in time_varying_features:
    if feature in train_cleaned.columns:
        corr = train_cleaned['Weekly_Sales'].corr(train_cleaned[feature])
        correlation_analysis[feature] = corr
        print(f"Correlation between Weekly_Sales and {feature}: {corr:.4f}")

# Holiday impact analysis
# Handle different holiday column names
holiday_col = 'IsHoliday_x' if 'IsHoliday_x' in train_cleaned.columns else 'IsHoliday'
holiday_impact = train_cleaned.groupby(holiday_col)['Weekly_Sales'].mean()
holiday_boost = holiday_impact[True] / holiday_impact[False] - 1
print(f"\nHoliday sales boost: {holiday_boost:.2%}")

# Store type impact (static feature)
if 'Type' in train_cleaned.columns:
    store_type_sales = train_cleaned.groupby('Type')['Weekly_Sales'].mean()
    print(f"\nAverage sales by store type (static feature):")
    print(store_type_sales)

# Store size impact (static feature)
if 'Size' in train_cleaned.columns:
    size_corr = train_cleaned['Weekly_Sales'].corr(train_cleaned['Size'])
    print(f"\nCorrelation between Weekly_Sales and Store Size: {size_corr:.4f}")

# Select features for TFT
selected_core = core_features.copy()
selected_static = []
selected_time_varying = []

# Always include static features for TFT
for feature in static_features:
    if feature in train_cleaned.columns:
        selected_static.append(feature)

# Include time-varying features with any correlation
for feature, corr in correlation_analysis.items():
    if abs(corr) > 0.001:  # Very low threshold for TFT
        selected_time_varying.append(feature)
        print(f"Selected time-varying {feature} (correlation: {corr:.4f})")

# Combine all selected features
all_selected_features = selected_core + selected_static + selected_time_varying
all_selected_features = list(set(all_selected_features))

print(f"\nSelected Features:")
print(f"Core features ({len(selected_core)}): {selected_core}")
print(f"Static features ({len(selected_static)}): {selected_static}")
print(f"Time-varying features ({len(selected_time_varying)}): {selected_time_varying}")
print(f"Total selected features: {len(all_selected_features)}")

# Create feature-selected dataset
train_selected = train_cleaned[all_selected_features].copy()

print(f"\nFeature-selected data shape: {train_selected.shape}")

# Log feature selection metrics
wandb.log({
    "total_available_features": len(train_cleaned.columns),
    "selected_features_count": len(all_selected_features),
    "static_features_count": len(selected_static),
    "time_varying_features_count": len(selected_time_varying),
    "holiday_sales_boost": float(holiday_boost),
    "selected_features": all_selected_features,
    "static_features": selected_static,
    "time_varying_features": selected_time_varying,
    **{f"corr_{k}": v for k, v in correlation_analysis.items() if not np.isnan(v)}
})

print("\n✓ Feature selection completed and logged to wandb")

0,1
cleaned_records,▁
negative_sales_count,▁
negative_sales_pct,▁
remaining_missing_values,▁
store_dept_combinations,▁

0,1
cleaned_records,421570.0
negative_sales_count,1285.0
negative_sales_pct,0.30481
remaining_missing_values,0.0
store_dept_combinations,3331.0



=== FEATURE SELECTION ===
Available columns: ['Store', 'Dept', 'Date', 'Weekly_Sales', 'IsHoliday_x', 'Type', 'Size', 'Temperature', 'Fuel_Price', 'MarkDown1', 'MarkDown2', 'MarkDown3', 'MarkDown4', 'MarkDown5', 'CPI', 'Unemployment', 'IsHoliday_y']
Correlation between Weekly_Sales and Temperature: -0.0023
Correlation between Weekly_Sales and Fuel_Price: -0.0001
Correlation between Weekly_Sales and CPI: -0.0209
Correlation between Weekly_Sales and Unemployment: -0.0259
Correlation between Weekly_Sales and MarkDown1: 0.0472
Correlation between Weekly_Sales and MarkDown2: 0.0207
Correlation between Weekly_Sales and MarkDown3: 0.0386
Correlation between Weekly_Sales and MarkDown4: 0.0375
Correlation between Weekly_Sales and MarkDown5: 0.0505

Holiday sales boost: 7.13%

Average sales by store type (static feature):
Type
A    20099.568043
B    12237.075977
C     9519.532538
Name: Weekly_Sales, dtype: float64

Correlation between Weekly_Sales and Store Size: 0.2438
Selected time-varying Te

# Cross Validation Run

In [10]:
# Start new wandb run for cross validation
wandb.finish()
wandb.init(
    project="walmart-sales-forecasting",
    name="TFT_Cross_Validation",
    config={
        "stage": "cross_validation",
        "lookback_window": 52,
        "forecast_horizon": 1,
        "model_type": "TFT"
    }
)

print("\n=== CROSS VALIDATION ===")

# Use the same efficient approach as N-BEATS
class EfficientTFTDataProcessor(BaseEstimator, TransformerMixin):
    """Efficient TFT processor - similar to N-BEATS approach"""

    def __init__(self, lookback_window=52, forecast_horizon=1):
        self.lookback_window = lookback_window
        self.forecast_horizon = forecast_horizon
        self.store_dept_combinations = None
        self.date_range = None

    def fit(self, X, y=None):
        """Learn the store-department combinations and date range"""
        self.store_dept_combinations = X.groupby(['Store', 'Dept']).size().index.tolist()
        self.date_range = sorted(X['Date'].unique())
        print(f"Found {len(self.store_dept_combinations)} store-dept combinations")
        print(f"Date range: {self.date_range[0]} to {self.date_range[-1]}")
        return self

    def transform(self, X):
        """Transform data into sequences for TFT - efficient like N-BEATS"""
        sequences = []
        targets = []
        static_features = []
        metadata = []

        for store, dept in self.store_dept_combinations:
            # Get time series for this store-dept combination
            series_data = X[(X['Store'] == store) & (X['Dept'] == dept)].copy()
            series_data = series_data.sort_values('Date')

            if len(series_data) < self.lookback_window + self.forecast_horizon:
                continue

            # Static features (simple, like N-BEATS)
            static_feat = []
            if 'Type' in series_data.columns:
                type_map = {'A': 0, 'B': 1, 'C': 2}
                static_feat.append(type_map.get(series_data['Type'].iloc[0], 0))
            if 'Size' in series_data.columns:
                static_feat.append(series_data['Size'].iloc[0] / 200000.0)

            # Add store/dept as features
            static_feat.extend([store / 45.0, dept / 100.0])

            # Create sliding windows (same approach as N-BEATS)
            for i in range(len(series_data) - self.lookback_window - self.forecast_horizon + 1):
                window_data = series_data.iloc[i:i + self.lookback_window]

                # Sales sequence (target variable)
                sales_sequence = window_data['Weekly_Sales'].values

                # Check for valid sales data
                if np.any(np.isnan(sales_sequence)) or np.any(np.isinf(sales_sequence)):
                    continue

                # External features (same as N-BEATS)
                external_features = []
                if 'Temperature' in window_data.columns:
                    temp_vals = window_data['Temperature'].fillna(method='ffill').fillna(method='bfill').fillna(0)
                    external_features.append(temp_vals.values)
                if 'Fuel_Price' in window_data.columns:
                    fuel_vals = window_data['Fuel_Price'].fillna(method='ffill').fillna(method='bfill').fillna(0)
                    external_features.append(fuel_vals.values)
                if 'CPI' in window_data.columns:
                    cpi_vals = window_data['CPI'].fillna(method='ffill').fillna(method='bfill').fillna(0)
                    external_features.append(cpi_vals.values)
                if 'Unemployment' in window_data.columns:
                    unemp_vals = window_data['Unemployment'].fillna(method='ffill').fillna(method='bfill').fillna(0)
                    external_features.append(unemp_vals.values)

                # Holiday feature (simplified)
                if 'IsHoliday_x' in window_data.columns:
                    holiday_vals = window_data['IsHoliday_x'].astype(float).values
                    external_features.append(holiday_vals)
                elif 'IsHoliday' in window_data.columns:
                    holiday_vals = window_data['IsHoliday'].astype(float).values
                    external_features.append(holiday_vals)

                # Combine features (same as N-BEATS approach)
                if external_features:
                    try:
                        feature_matrix = np.column_stack([sales_sequence] + external_features)
                    except ValueError:
                        feature_matrix = sales_sequence.reshape(-1, 1)
                else:
                    feature_matrix = sales_sequence.reshape(-1, 1)

                sequences.append(feature_matrix)
                static_features.append(np.array(static_feat))

                # Target (next forecast_horizon values)
                target_data = series_data.iloc[i + self.lookback_window:i + self.lookback_window + self.forecast_horizon]
                target_sales = target_data['Weekly_Sales'].values

                # Check for valid target data
                if np.any(np.isnan(target_sales)) or np.any(np.isinf(target_sales)):
                    continue

                targets.append(target_sales)

                # Metadata
                metadata.append({
                    'store': store,
                    'dept': dept,
                    'start_date': window_data['Date'].iloc[0],
                    'end_date': window_data['Date'].iloc[-1],
                    'forecast_date': target_data['Date'].iloc[0] if len(target_data) > 0 else None
                })

        print(f"Generated {len(sequences)} valid sequences from {len(self.store_dept_combinations)} store-dept combinations")

        return {
            'sequences': np.array(sequences, dtype=object),
            'targets': np.array(targets, dtype=object),
            'static_features': np.array(static_features, dtype=object),
            'metadata': metadata
        }

# Create efficient TFT data processor
tft_processor = EfficientTFTDataProcessor(lookback_window=52, forecast_horizon=1)

# Fit and transform the data
print("Processing time-series data for TFT...")
tft_processor.fit(train_selected)
processed_data = tft_processor.transform(train_selected)

sequences = processed_data['sequences']
targets = processed_data['targets']
static_features = processed_data['static_features']
metadata = processed_data['metadata']

print(f"Generated {len(sequences)} sequences")
if len(sequences) > 0:
    print(f"Sequence shape example: {sequences[0].shape}")
    print(f"Target shape example: {targets[0].shape}")
    print(f"Static features shape example: {static_features[0].shape}")

if len(sequences) == 0:
    print("❌ No sequences generated. Check data processing.")
    wandb.log({"sequences_generated": 0, "processing_failed": True})
else:
    # Convert to consistent numpy arrays (same as N-BEATS)
    max_time_features = max([seq.shape[1] if len(seq.shape) > 1 else 1 for seq in sequences])
    max_static_features = max([sf.shape[0] if len(sf.shape) > 0 else 1 for sf in static_features])
    lookback_length = sequences[0].shape[0]

    # Pad sequences to have consistent feature count and convert to float32
    padded_sequences = []
    padded_static = []
    valid_targets = []

    for i, (seq, static, tgt) in enumerate(zip(sequences, static_features, targets)):
        if len(seq.shape) == 1:
            seq = seq.reshape(-1, 1)

        # Pad features if necessary
        if seq.shape[1] < max_time_features:
            padding = np.zeros((seq.shape[0], max_time_features - seq.shape[1]), dtype=np.float32)
            seq = np.column_stack([seq, padding]).astype(np.float32)
        else:
            seq = seq.astype(np.float32)

        # Process static features
        if len(static.shape) == 0:
            static = static.reshape(1)

        if static.shape[0] < max_static_features:
            padding = np.zeros(max_static_features - static.shape[0], dtype=np.float32)
            static = np.concatenate([static, padding]).astype(np.float32)
        else:
            static = static.astype(np.float32)

        padded_sequences.append(seq)
        padded_static.append(static)
        valid_targets.append(tgt.astype(np.float32))

    sequences_np = np.array(padded_sequences, dtype=np.float32)
    static_np = np.array(padded_static, dtype=np.float32)
    targets_np = np.array(valid_targets, dtype=np.float32)

    print(f"Processed sequences shape: {sequences_np.shape}")
    print(f"Processed static features shape: {static_np.shape}")
    print(f"Processed targets shape: {targets_np.shape}")

    # Time-based train-validation split (same as N-BEATS)
    split_idx = int(0.8 * len(sequences_np))

    X_train_cv = sequences_np[:split_idx]
    X_static_train_cv = static_np[:split_idx]
    y_train_cv = targets_np[:split_idx]

    X_val_cv = sequences_np[split_idx:]
    X_static_val_cv = static_np[split_idx:]
    y_val_cv = targets_np[split_idx:]

    print(f"\nTrain-Validation Split:")
    print(f"Training sequences: {len(X_train_cv)}")
    print(f"Validation sequences: {len(X_val_cv)}")

    # Create datasets and dataloaders
    train_dataset = WalmartTFTDataset(X_train_cv, y_train_cv, X_static_train_cv)
    val_dataset = WalmartTFTDataset(X_val_cv, y_val_cv, X_static_val_cv)

    batch_size = 32
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

    # Simplified TFT model (closer to N-BEATS complexity)
    class SimplifiedTFT(nn.Module):
        """Simplified TFT for reasonable training time"""

        def __init__(self, num_time_features, num_static_features, hidden_dim=128,
                     num_attention_heads=4, dropout_rate=0.1, forecast_horizon=1):
            super().__init__()
            self.hidden_dim = hidden_dim

            # Simple projections instead of complex VSN
            self.temporal_projection = nn.Linear(num_time_features, hidden_dim)

            if num_static_features > 0:
                self.static_projection = nn.Linear(num_static_features, hidden_dim)
            else:
                self.static_projection = None

            # LSTM encoder (like N-BEATS but with attention)
            self.encoder_lstm = nn.LSTM(
                input_size=hidden_dim, hidden_size=hidden_dim,
                batch_first=True, dropout=dropout_rate
            )

            # Simple attention
            self.multihead_attn = nn.MultiheadAttention(
                embed_dim=hidden_dim, num_heads=num_attention_heads,
                dropout=dropout_rate, batch_first=True
            )

            # Output layers
            self.output_projection = nn.Linear(hidden_dim, forecast_horizon)

            # Layer normalization
            self.layer_norm = nn.LayerNorm(hidden_dim)

        def forward(self, time_varying_inputs, static_inputs=None):
            batch_size, seq_len, num_features = time_varying_inputs.shape

            # Project temporal features
            temporal_features = self.temporal_projection(time_varying_inputs)

            # Add static features if available
            if static_inputs is not None and self.static_projection is not None:
                static_features = self.static_projection(static_inputs)
                static_features = static_features.unsqueeze(1).expand(-1, seq_len, -1)
                combined_features = temporal_features + static_features
            else:
                combined_features = temporal_features

            # LSTM encoding
            lstm_out, _ = self.encoder_lstm(combined_features)

            # Attention
            attn_out, _ = self.multihead_attn(lstm_out, lstm_out, lstm_out)

            # Add & Norm
            attn_out = self.layer_norm(attn_out + lstm_out)

            # Use the last time step for prediction
            last_output = attn_out[:, -1, :]

            # Output projection
            predictions = self.output_projection(last_output)

            return predictions

    # Initialize simplified TFT model
    model_config = {
        "num_time_features": max_time_features,
        "num_static_features": max_static_features,
        "hidden_dim": 128,
        "num_attention_heads": 4,
        "dropout_rate": 0.1,
        "forecast_horizon": 1
    }

    model = SimplifiedTFT(**model_config).to(device)
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    criterion = nn.MSELoss()

    print(f"\nSimplified TFT model initialized with config: {model_config}")
    print(f"Total parameters: {sum(p.numel() for p in model.parameters()):,}")

    # Training loop (same as N-BEATS)
    num_epochs = 5  # Limited epochs for CV
    train_losses = []
    val_losses = []

    for epoch in range(num_epochs):
        # Training phase
        model.train()
        train_loss = 0.0

        for batch_idx, (sequences, static_feat, targets) in enumerate(train_loader):
            sequences = sequences.to(device)
            static_feat = static_feat.to(device)
            targets = targets.to(device)

            optimizer.zero_grad()
            outputs = model(sequences, static_feat)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()

            train_loss += loss.item()

            if batch_idx % 50 == 0:
                print(f'Epoch {epoch+1}/{num_epochs}, Batch {batch_idx}, Loss: {loss.item():.4f}')

        avg_train_loss = train_loss / len(train_loader)
        train_losses.append(avg_train_loss)

        # Validation phase
        model.eval()
        val_loss = 0.0
        all_predictions = []
        all_targets = []

        with torch.no_grad():
            for sequences, static_feat, targets in val_loader:
                sequences = sequences.to(device)
                static_feat = static_feat.to(device)
                targets = targets.to(device)

                outputs = model(sequences, static_feat)
                loss = criterion(outputs, targets)
                val_loss += loss.item()

                all_predictions.extend(outputs.cpu().numpy())
                all_targets.extend(targets.cpu().numpy())

        avg_val_loss = val_loss / len(val_loader)
        val_losses.append(avg_val_loss)

        # Calculate metrics
        all_predictions = np.array(all_predictions).flatten()
        all_targets = np.array(all_targets).flatten()

        val_mae = mean_absolute_error(all_targets, all_predictions)
        val_rmse = np.sqrt(mean_squared_error(all_targets, all_predictions))
        val_r2 = r2_score(all_targets, all_predictions)

        # Safe MAPE calculation
        def safe_mape(y_true, y_pred):
            mask = y_true != 0
            if mask.sum() == 0:
                return float('inf')
            return np.mean(np.abs((y_true[mask] - y_pred[mask]) / y_true[mask])) * 100

        val_mape = safe_mape(all_targets, all_predictions)

        print(f'Epoch {epoch+1}/{num_epochs}:')
        print(f'  Train Loss: {avg_train_loss:.4f}')
        print(f'  Val Loss: {avg_val_loss:.4f}')
        print(f'  Val MAE: {val_mae:.2f}')
        print(f'  Val RMSE: {val_rmse:.2f}')
        print(f'  Val MAPE: {val_mape:.2f}%')
        print(f'  Val R²: {val_r2:.4f}')

        # Log to wandb
        wandb.log({
            "epoch": epoch + 1,
            "train_loss": avg_train_loss,
            "val_loss": avg_val_loss,
            "val_mae": val_mae,
            "val_rmse": val_rmse,
            "val_mape": val_mape if not np.isinf(val_mape) else 0.0,
            "val_r2": val_r2
        })

    # Final CV results
    final_metrics = {
        "cv_final_train_loss": train_losses[-1],
        "cv_final_val_loss": val_losses[-1],
        "cv_final_val_mae": val_mae,
        "cv_final_val_rmse": val_rmse,
        "cv_final_val_mape": val_mape if not np.isinf(val_mape) else 0.0,
        "cv_final_val_r2": val_r2,
        "sequences_generated": len(sequences_np),
        "train_sequences": len(X_train_cv),
        "val_sequences": len(X_val_cv)
    }

    wandb.log(final_metrics)

    print("\n✓ Cross validation completed and logged to wandb")
    print(f"Final validation metrics: MAE={val_mae:.2f}, RMSE={val_rmse:.2f}, MAPE={val_mape:.2f}%, R²={val_r2:.4f}")

0,1
corr_CPI,▁
corr_Fuel_Price,▁
corr_MarkDown1,▁
corr_MarkDown2,▁
corr_MarkDown3,▁
corr_MarkDown4,▁
corr_MarkDown5,▁
corr_Temperature,▁
corr_Unemployment,▁
holiday_sales_boost,▁

0,1
corr_CPI,-0.02092
corr_Fuel_Price,-0.00012
corr_MarkDown1,0.04717
corr_MarkDown2,0.02072
corr_MarkDown3,0.03856
corr_MarkDown4,0.03747
corr_MarkDown5,0.05047
corr_Temperature,-0.00231
corr_Unemployment,-0.02586
holiday_sales_boost,0.07134



=== CROSS VALIDATION ===
Processing time-series data for TFT...
Found 3331 store-dept combinations
Date range: 2010-02-05 00:00:00 to 2012-10-26 00:00:00
Generated 261083 valid sequences from 3331 store-dept combinations
Generated 261083 sequences
Sequence shape example: (52, 5)
Target shape example: (1,)
Static features shape example: (4,)
Processed sequences shape: (261083, 52, 5)
Processed static features shape: (261083, 4)
Processed targets shape: (261083, 1)

Train-Validation Split:
Training sequences: 208866
Validation sequences: 52217

Simplified TFT model initialized with config: {'num_time_features': 5, 'num_static_features': 4, 'hidden_dim': 128, 'num_attention_heads': 4, 'dropout_rate': 0.1, 'forecast_horizon': 1}
Total parameters: 199,937
Epoch 1/5, Batch 0, Loss: 461803552.0000
Epoch 1/5, Batch 50, Loss: 735711808.0000
Epoch 1/5, Batch 100, Loss: 492844000.0000
Epoch 1/5, Batch 150, Loss: 198830752.0000
Epoch 1/5, Batch 200, Loss: 719259840.0000
Epoch 1/5, Batch 250, Loss

# Final Training & Model Registry

In [12]:
# Start new wandb run for final training
wandb.finish()
wandb.init(
    project="walmart-sales-forecasting",
    name="TFT_Final_Training",
    config={
        "stage": "final_training",
        "model_config": model_config,
        "num_epochs": 20,
        "batch_size": 32
    }
)

print("\n=== FINAL TRAINING ===")

# OPTIMIZATION: Use the efficient processor from CV instead of re-creating
print("Re-processing data for final training (using efficient processor)...")

# Use the efficient processor we defined in CV instead of the slow original
tft_processor_final = EfficientTFTDataProcessor(lookback_window=52, forecast_horizon=1)
tft_processor_final.fit(train_selected)
processed_data_final = tft_processor_final.transform(train_selected)

sequences_final = processed_data_final['sequences']
targets_final = processed_data_final['targets']
static_features_final = processed_data_final['static_features']
metadata_final = processed_data_final['metadata']

print(f"Re-generated {len(sequences_final)} sequences for final training")

if len(sequences_final) == 0:
    print("❌ No sequences generated for final training. Check data processing.")
    wandb.log({"final_training_failed": True, "reason": "no_sequences"})
else:
    # Convert to consistent numpy arrays (same as CV)
    max_time_features = max([seq.shape[1] if len(seq.shape) > 1 else 1 for seq in sequences_final])
    max_static_features = max([sf.shape[0] if len(sf.shape) > 0 else 1 for sf in static_features_final])
    lookback_length = sequences_final[0].shape[0]

    padded_sequences_final = []
    padded_static_final = []
    valid_targets_final = []

    for i, (seq, static, tgt) in enumerate(zip(sequences_final, static_features_final, targets_final)):
        # Process time-varying sequence
        if len(seq.shape) == 1:
            seq = seq.reshape(-1, 1)

        if seq.shape[1] < max_time_features:
            padding = np.zeros((seq.shape[0], max_time_features - seq.shape[1]), dtype=np.float32)
            seq = np.column_stack([seq, padding]).astype(np.float32)
        else:
            seq = seq.astype(np.float32)

        # Process static features
        if len(static.shape) == 0:
            static = static.reshape(1)

        if static.shape[0] < max_static_features:
            padding = np.zeros(max_static_features - static.shape[0], dtype=np.float32)
            static = np.concatenate([static, padding]).astype(np.float32)
        else:
            static = static.astype(np.float32)

        padded_sequences_final.append(seq)
        padded_static_final.append(static)
        valid_targets_final.append(tgt.astype(np.float32))

    sequences_final_np = np.array(padded_sequences_final, dtype=np.float32)
    static_final_np = np.array(padded_static_final, dtype=np.float32)
    targets_final_np = np.array(valid_targets_final, dtype=np.float32)

    print(f"Final training data shape: {sequences_final_np.shape}")
    print(f"Final static features shape: {static_final_np.shape}")
    print(f"Final training targets shape: {targets_final_np.shape}")

    # Create dataset with CPU numpy arrays
    final_dataset = WalmartTFTDataset(sequences_final_np, targets_final_np, static_final_np)
    final_loader = DataLoader(final_dataset, batch_size=32, shuffle=True)

    # OPTIMIZATION: Use the simplified model from CV instead of complex TFT
    # This will train much faster while still providing good results
    final_model = SimplifiedTFT(**model_config).to(device)
    final_optimizer = optim.Adam(final_model.parameters(), lr=0.001)
    final_criterion = nn.MSELoss()

    print(f"Training on {len(sequences_final_np)} sequences...")
    print(f"Using simplified TFT model with {sum(p.numel() for p in final_model.parameters()):,} parameters")

    # OPTIMIZATION: Slightly reduced epochs for first run
    num_epochs_final = 15  # Reduced from 20 to 15 for faster completion
    best_loss = float('inf')

    for epoch in range(num_epochs_final):
        final_model.train()
        epoch_loss = 0.0

        for batch_idx, (sequences_batch, static_batch, targets_batch) in enumerate(final_loader):
            sequences_batch = sequences_batch.to(device)
            static_batch = static_batch.to(device)
            targets_batch = targets_batch.to(device)

            final_optimizer.zero_grad()
            outputs = final_model(sequences_batch, static_batch)
            loss = final_criterion(outputs, targets_batch)
            loss.backward()
            final_optimizer.step()

            epoch_loss += loss.item()

        avg_epoch_loss = epoch_loss / len(final_loader)

        if avg_epoch_loss < best_loss:
            best_loss = avg_epoch_loss

        if (epoch + 1) % 3 == 0:  # More frequent logging
            print(f'Final Training Epoch {epoch+1}/{num_epochs_final}, Loss: {avg_epoch_loss:.4f}')

            # Log training progress
            wandb.log({
                "final_epoch": epoch + 1,
                "final_train_loss": avg_epoch_loss,
                "best_loss": best_loss
            })

    # Final evaluation on training data
    print("\nEvaluating final model...")
    final_model.eval()
    all_final_predictions = []
    all_final_targets = []

    with torch.no_grad():
        for sequences_batch, static_batch, targets_batch in final_loader:
            sequences_batch = sequences_batch.to(device)
            static_batch = static_batch.to(device)
            outputs = final_model(sequences_batch, static_batch)
            all_final_predictions.extend(outputs.cpu().numpy())
            all_final_targets.extend(targets_batch.numpy())

    all_final_predictions = np.array(all_final_predictions).flatten()
    all_final_targets = np.array(all_final_targets).flatten()

    # Calculate final metrics
    final_mae = mean_absolute_error(all_final_targets, all_final_predictions)
    final_rmse = np.sqrt(mean_squared_error(all_final_targets, all_final_predictions))
    final_r2 = r2_score(all_final_targets, all_final_predictions)

    # Safe MAPE calculation
    def safe_mape(y_true, y_pred):
        mask = y_true != 0
        if mask.sum() == 0:
            return float('inf')
        return np.mean(np.abs((y_true[mask] - y_pred[mask]) / y_true[mask])) * 100

    final_mape = safe_mape(all_final_targets, all_final_predictions)

    print(f"\nFinal Training Metrics:")
    print(f"MAE: {final_mae:.2f}")
    print(f"RMSE: {final_rmse:.2f}")
    print(f"MAPE: {final_mape:.2f}%")
    print(f"R²: {final_r2:.4f}")

    # Create complete pipeline
    class TFTPipeline:
        """Complete pipeline for TFT inference"""

        def __init__(self, feature_merger, missing_handler, tft_processor, model):
            self.feature_merger = feature_merger
            self.missing_handler = missing_handler
            self.tft_processor = tft_processor
            self.model = model
            self.model.eval()

        def predict(self, X_raw, stores_df=None, features_df=None):
            """Make predictions on raw test data"""
            # If auxiliary data provided, update the merger
            if stores_df is not None or features_df is not None:
                self.feature_merger.fit(X_raw, stores_df=stores_df, features_df=features_df)

            # Process through pipeline
            merged_data = self.feature_merger.transform(X_raw)
            cleaned_data = self.missing_handler.transform(merged_data)

            # Process for TFT
            processed = self.tft_processor.transform(cleaned_data)

            if len(processed['sequences']) == 0:
                return np.array([])

            # Convert to tensors and predict
            sequences_tensor = torch.FloatTensor(processed['sequences']).to(device)
            static_tensor = torch.FloatTensor(processed['static_features']).to(device)

            with torch.no_grad():
                predictions = self.model(sequences_tensor, static_tensor)

            return predictions.cpu().numpy().flatten()

    # Create final pipeline
    final_pipeline = TFTPipeline(
        feature_merger=feature_merger,
        missing_handler=missing_handler,
        tft_processor=tft_processor_final,
        model=final_model
    )

    print("\n=== SAVING FINAL MODEL ===")

    # Save pipeline with cloudpickle
    try:
        import cloudpickle
    except ImportError:
        import subprocess
        subprocess.check_call(['pip', 'install', 'cloudpickle'])
        import cloudpickle

    # Create filename
    pipeline_filename = f"tft_pipeline_{datetime.now().strftime('%Y%m%d_%H%M%S')}.pkl"

    # Save with cloudpickle
    with open(pipeline_filename, 'wb') as f:
        cloudpickle.dump(final_pipeline, f)

    print(f"Pipeline saved as: {pipeline_filename}")

    # Try to upload to wandb with error handling
    try:
        # Create model artifact
        model_artifact = wandb.Artifact(
            name="TFT_pipeline",
            type="model",
            description="Final TFT pipeline for Walmart sales forecasting",
            metadata={
                "train_mae": float(final_mae),
                "train_rmse": float(final_rmse),
                "train_mape": float(final_mape) if not np.isinf(final_mape) else 0.0,
                "train_r2": float(final_r2),
                "sequences_count": len(sequences_final_np),
                "training_samples": len(all_final_targets),
                "model_type": "TFT_Simplified",
                "lookback_window": 52,
                "forecast_horizon": 1,
                "hidden_dim": model_config["hidden_dim"],
                "num_attention_heads": model_config["num_attention_heads"],
                "optimization": "simplified_for_speed"
            }
        )

        # Add model file to artifact
        model_artifact.add_file(pipeline_filename)

        # Log artifact
        wandb.log_artifact(model_artifact)
        print("✓ Model artifact logged to wandb successfully!")

    except Exception as e:
        print(f"⚠️ Error uploading to wandb: {e}")
        print("Model saved locally - you can manually upload later")

        # Log just the metrics without artifact
        wandb.log({
            'final_train_mae': final_mae,
            'final_train_rmse': final_rmse,
            'final_train_mape': final_mape if not np.isinf(final_mape) else 0.0,
            'final_train_r2': final_r2,
            'model_saved_locally': pipeline_filename
        })

    # Final summary
    print(f"\n" + "="*60)
    print("FINAL MODEL SUMMARY")
    print("="*60)
    print(f"Model Type: Temporal Fusion Transformer (Simplified)")
    print(f"Training Sequences: {len(all_final_targets):,}")
    print(f"Lookback Window: 52 weeks")
    print(f"Forecast Horizon: 1 week")
    print(f"Time-varying Features: {max_time_features}")
    print(f"Static Features: {max_static_features}")
    print(f"Training MAE: {final_mae:.2f}")
    print(f"Training RMSE: {final_rmse:.2f}")
    print(f"Training MAPE: {final_mape:.2f}%")
    print(f"Training R²: {final_r2:.4f}")
    print(f"Pipeline saved as: {pipeline_filename}")
    print(f"Optimizations: Efficient processor, simplified model, 15 epochs")
    print("="*60)

wandb.finish()
print("\n✓ Final training completed and model saved!")


=== FINAL TRAINING ===
Re-processing data for final training (using efficient processor)...
Found 3331 store-dept combinations
Date range: 2010-02-05 00:00:00 to 2012-10-26 00:00:00
Generated 261083 valid sequences from 3331 store-dept combinations
Re-generated 261083 sequences for final training
Final training data shape: (261083, 52, 5)
Final static features shape: (261083, 4)
Final training targets shape: (261083, 1)
Training on 261083 sequences...
Using simplified TFT model with 199,937 parameters
Final Training Epoch 3/15, Loss: 518992803.8534
Final Training Epoch 6/15, Loss: 519063895.9029
Final Training Epoch 9/15, Loss: 518917823.9157
Final Training Epoch 12/15, Loss: 508648936.8830
Final Training Epoch 15/15, Loss: 518921973.2497

Evaluating final model...

Final Training Metrics:
MAE: 15366.54
RMSE: 22776.36
MAPE: 15608.71%
R²: 0.0000

=== SAVING FINAL MODEL ===
Pipeline saved as: tft_pipeline_20250706_185107.pkl
✓ Model artifact logged to wandb successfully!

FINAL MODEL SU

0,1
best_loss,▁▁▁▁▁
final_epoch,▁▃▅▆█
final_train_loss,███▁█

0,1
best_loss,319042889.71384
final_epoch,15.0
final_train_loss,518921973.24966



✓ Final training completed and model saved!
