In [1]:
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
import json
from sklearn.metrics import mean_squared_error, r2_score
from matplotlib.colors import LinearSegmentedColormap

# Device used for inference: the CUDA GPU when PyTorch reports one, otherwise the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class Config:
    """Static settings shared by the feature-engineering, prediction and plotting code."""

    # Name of the column the model predicts; also the prefix of the derived
    # lag / rolling-statistic feature columns.
    TARGET: str = "AWS"

    # Feature-engineering switches read by prepare_data_for_prediction.
    USE_LAG_FEATURES: bool = True
    USE_ROLLING_STATISTICS: bool = True

    # Closed interval that predictions and plotted target values are clipped to.
    MIN_VALUE: float = 0.0
    MAX_VALUE: float = 10.0

    # Not referenced anywhere in the visible code.
    PLOT_MAX: int = 1

# Base predictor columns. prepare_data_for_prediction keeps the ones that exist
# in the test frame and appends the derived lag / rolling-statistic columns.
selected_features = [
    'TCW', 'TCLW', 'R250', 'R500', 'R850', 'U850', 'V850', 'EWSS', 'KX', 'CAPE', 'SSHF', 'PEV'
]

class LSTMModel(nn.Module):
    """Stacked LSTM followed by a linear head that reads the last time step.

    Args:
        input_size: Number of features per time step.
        hidden_size: Width of each LSTM layer.
        num_layers: Number of stacked LSTM layers.
        dropout: Dropout probability applied between stacked LSTM layers.
        time_step_out: Number of values produced per input sequence.
    """

    def __init__(self, input_size, hidden_size=64, num_layers=2, dropout=0.0, time_step_out=1):
        super().__init__()
        # nn.LSTM only applies dropout *between* layers, so with a single layer a
        # non-zero value has no effect and merely triggers a UserWarning. Passing
        # 0.0 in that case is behaviour-preserving and keeps checkpoints loadable
        # (parameter names and shapes are unchanged).
        inter_layer_dropout = dropout if num_layers > 1 else 0.0
        self.lstm = nn.LSTM(
            input_size,
            hidden_size,
            num_layers,
            batch_first=True,
            dropout=inter_layer_dropout,
        )
        self.fc = nn.Linear(hidden_size, time_step_out)
        self.time_step_out = time_step_out

    def forward(self, x):
        """Map a batch-first ``(batch, time, features)`` tensor to ``(batch, time_step_out)``."""
        out, _ = self.lstm(x)
        # Only the hidden state of the final time step feeds the linear head.
        return self.fc(out[:, -1, :])

def load_data(file_path):
    """Read the CSV at ``file_path`` with pandas' default options and return the DataFrame."""
    return pd.read_csv(file_path)

def predict_with_model(model, test_x, batch_size=128):
    """Run ``model`` over ``test_x`` in mini-batches and return clipped predictions.

    Args:
        model: Module called as ``model(batch)``; it is switched to eval mode.
        test_x: Tensor whose first dimension indexes the sequences.
        batch_size: Number of sequences sent to the device per forward pass.

    Returns:
        A NumPy array with all batch outputs stacked along the first axis and
        clipped to ``[Config.MIN_VALUE, Config.MAX_VALUE]``.

    Raises:
        ValueError: If ``test_x`` is ``None`` or empty (``create_sequences``
            returns ``None`` when no window fits), or if ``batch_size`` < 1.
    """
    if test_x is None or test_x.shape[0] == 0:
        raise ValueError("predict_with_model received no input sequences")
    if batch_size < 1:
        raise ValueError("batch_size must be a positive integer")

    model.eval()
    batch_outputs = []
    with torch.no_grad():
        for start in range(0, test_x.shape[0], batch_size):
            batch_x = test_x[start:start + batch_size].to(DEVICE)
            batch_outputs.append(model(batch_x).cpu().numpy())

    # Clipping once after stacking is equivalent to clipping every batch.
    return np.clip(np.vstack(batch_outputs), Config.MIN_VALUE, Config.MAX_VALUE)

In [2]:
def handle_missing_values(df, lag_steps=None, window_sizes=None):
    """Return a copy of ``df`` in which every missing value is replaced by 0.

    Lag and rolling-statistic columns contain NaN where no history is available
    (start of each series, single-sample standard deviation); they are
    zero-filled together with every other column.

    Args:
        df: Input frame; it is not modified.
        lag_steps: Accepted for backward compatibility. The blanket fill below
            already covers the lag columns, so the value is not needed.
        window_sizes: Accepted for backward compatibility, same reason.

    Returns:
        A new DataFrame without NaN values.
    """
    # A single frame-wide fill yields exactly what the former per-column fills
    # followed by a frame-wide fill produced.
    return df.fillna(0)

def prepare_data_for_prediction(train_df, val_df, test_df, best_params):
    """Build the model input tensor for the test split.

    The train and validation frames are only used as history so that lag and
    rolling features of the first test rows can look back past the start of the
    test split. The input frames are not modified.

    Args:
        train_df, val_df, test_df: Frames containing at least ``ROW``, ``COL``,
            the ``Config.TARGET`` column and the feature columns; ``DATETIME``
            is used for ordering when present.
        best_params: Mapping with ``time_step_in``, ``time_step_out`` and
            ``stride``, plus optionally ``num_lags`` with ``lag_0..`` and
            ``num_windows`` with ``window_0..``.

    Returns:
        ``(test_x, original_test_df, feature_cols)`` where ``test_x`` is the
        tensor returned by ``create_sequences`` (``None`` when no window fits),
        ``original_test_df`` is the time-sorted test frame before any feature
        engineering, and ``feature_cols`` lists the input columns in order.
    """
    time_step_in = best_params["time_step_in"]
    time_step_out = best_params["time_step_out"]
    stride = best_params["stride"]

    # range() of a non-positive count is empty, so missing/zero counts give [].
    lag_steps = [best_params[f"lag_{i}"] for i in range(best_params.get("num_lags", 0))]
    window_sizes = [best_params[f"window_{i}"] for i in range(best_params.get("num_windows", 0))]

    def _time_sorted(frame):
        # Always hand back a new object so the caller's frame keeps its order.
        if 'DATETIME' in frame.columns:
            return frame.sort_values("DATETIME")
        return frame.copy()

    train_sorted = _time_sorted(train_df)
    val_sorted = _time_sorted(val_df)
    test_sorted = _time_sorted(test_df)
    original_test_df = test_sorted.copy()

    use_lags = bool(Config.USE_LAG_FEATURES and lag_steps)
    use_rolling = bool(Config.USE_ROLLING_STATISTICS and window_sizes)

    # 1-2. Lag / rolling features are computed on train + val + test so the
    # test rows see real history, then the test rows are sliced back out.
    test_features = test_sorted
    if use_lags or use_rolling:
        combined_df = pd.concat([train_sorted, val_sorted, test_sorted]).reset_index(drop=True)
        if use_lags:
            combined_df = create_lag_features(combined_df, Config.TARGET, lag_steps)
        if use_rolling:
            combined_df = create_rolling_statistics(combined_df, Config.TARGET, window_sizes)
        history_rows = len(train_sorted) + len(val_sorted)
        test_features = combined_df.iloc[history_rows:].reset_index(drop=True)

    # 3. Handle missing values
    test_features = handle_missing_values(test_features, lag_steps, window_sizes)

    # 4. Feature columns: base predictors, then lags, then mean/std per window.
    available = set(test_features.columns)
    feature_cols = [col for col in selected_features if col in available]
    feature_cols += [
        f'{Config.TARGET}_lag{lag}' for lag in lag_steps
        if f'{Config.TARGET}_lag{lag}' in available
    ]
    for window in window_sizes:
        for candidate in (f'{Config.TARGET}_rollmean_{window}', f'{Config.TARGET}_rollstd_{window}'):
            if candidate in available:
                feature_cols.append(candidate)

    # 5. Create sequences (the targets are not needed for prediction).
    test_x, _ = create_sequences(
        test_features, feature_cols, Config.TARGET, time_step_in, time_step_out, stride
    )
    return test_x, original_test_df, feature_cols

In [3]:
def create_sequences(df, input_cols, target_col, time_step_in, time_step_out=1, stride=1):
    """Cut sliding input/target windows out of every ``(ROW, COL)`` series.

    Rows are grouped by ``ROW`` and ``COL`` and, when a ``DATETIME`` column
    exists, ordered by it inside each group. Every window uses
    ``time_step_in`` consecutive rows of ``input_cols`` as input and the
    following ``time_step_out`` values of ``target_col`` as target. Groups are
    visited in the key order produced by ``groupby``; groups shorter than
    ``time_step_in + time_step_out`` contribute nothing.

    Args:
        df: Source frame.
        input_cols: Columns used as model inputs.
        target_col: Column holding the value to predict.
        time_step_in: Input window length.
        time_step_out: Number of future target values per window.
        stride: Step between the starts of consecutive windows.

    Returns:
        ``(x, y)`` float32 tensors of shape ``(n, time_step_in, len(input_cols))``
        and ``(n, time_step_out)``, or ``(None, None)`` when no window fits.

    Raises:
        ValueError: If ``stride`` is smaller than 1.
    """
    if stride < 1:
        raise ValueError("stride must be a positive integer")

    windows, targets = [], []
    for _, group in df.groupby(['ROW', 'COL']):
        if 'DATETIME' in group.columns:
            group = group.sort_values("DATETIME")
        features = group[input_cols].to_numpy()
        target_values = group[target_col].to_numpy()

        # Largest valid window start; negative when the series is too short,
        # which makes the range below empty.
        last_start = len(features) - time_step_in - time_step_out
        for start in range(0, last_start + 1, stride):
            stop = start + time_step_in
            windows.append(features[start:stop])
            # Always a slice, so a single-step target is already shaped (1,).
            targets.append(target_values[stop:stop + time_step_out])

    if not windows:
        return None, None

    # Stack in NumPy first: building a tensor straight from a Python list of
    # ndarrays is the slow path PyTorch warns about.
    x = torch.from_numpy(np.stack(windows).astype(np.float32))
    y = torch.from_numpy(np.stack(targets).astype(np.float32))
    return x, y

def create_lag_features(df, target_column, lag_steps, groupby_cols=['ROW', 'COL']):
    """Return a copy of ``df`` with one ``<target>_lag<k>`` column per lag step.

    Each new column holds ``target_column`` shifted by ``k`` rows within its
    ``groupby_cols`` group, so the first ``k`` rows of every group are NaN.
    """
    enriched = df.copy()
    # The grouping keys and the target column never change inside the loop,
    # so the grouped view is built once.
    per_group_target = df.groupby(groupby_cols)[target_column]
    for steps_back in lag_steps:
        column_name = f'{target_column}_lag{steps_back}'
        enriched[column_name] = per_group_target.shift(steps_back)
    return enriched

def create_rolling_statistics(df, target_column, window_sizes, groupby_cols=['ROW', 'COL']):
    """Return a copy of ``df`` with rolling mean and std columns of the target.

    For every window size ``w`` the columns ``<target>_rollmean_<w>`` and
    ``<target>_rollstd_<w>`` are added, computed inside each ``groupby_cols``
    group with ``min_periods=1`` (so the std of a single sample is NaN).
    """
    enriched = df.copy()
    per_group_target = df.groupby(groupby_cols)[target_column]
    for size in window_sizes:
        # Mean first, then std, to keep the column order of the result stable.
        for statistic in ("mean", "std"):
            enriched[f'{target_column}_roll{statistic}_{size}'] = per_group_target.transform(
                lambda series, statistic=statistic, size=size: getattr(
                    series.rolling(size, min_periods=1), statistic
                )()
            )
    return enriched

In [None]:
def _first_step_predictions(predictions):
    """Return the first forecast step of every prediction as a clipped 1-D float array."""
    values = np.asarray(predictions, dtype=float)
    if values.size == 0:
        return np.empty(0, dtype=float)
    values = values.reshape(values.shape[0], -1)[:, 0]
    return np.clip(values, Config.MIN_VALUE, Config.MAX_VALUE)


def _target_row_positions(frame, time_step_in, time_step_out, stride):
    """Return, per sequence, the position in ``frame`` of the first row it predicts.

    Replays the windowing of ``create_sequences`` (groups by ROW/COL, ordered
    by DATETIME) so that element ``k`` belongs to prediction ``k``.
    ``frame`` must have a default 0..n-1 index.
    """
    positions = []
    for _, group in frame.groupby(['ROW', 'COL']):
        group = group.sort_values("DATETIME")
        last_start = len(group) - time_step_in - time_step_out
        if last_start < 0:
            continue
        positions.extend(group.index[time_step_in:time_step_in + last_start + 1:stride])
    return np.asarray(positions, dtype=int)


def _random_sample_indices(n_available, num_samples):
    """Pick ``num_samples`` distinct indices in ascending order, or all when too few exist."""
    if n_available <= num_samples:
        return list(range(n_available))
    picked = np.random.choice(range(n_available), num_samples, replace=False)
    picked.sort()  # Sort for consistent ordering
    return picked


def _select_by_rainfall(frame, actual_col, valid_timestamps, num_samples, highest):
    """Rank timestamps by mean rainfall (highest, or closest to the median) and return indices."""
    mean_by_timestamp = frame.groupby('DATETIME')[actual_col].mean().reindex(valid_timestamps)
    rainfall_by_timestamp = {}
    for ts, avg_rainfall in zip(valid_timestamps, mean_by_timestamp.to_numpy()):
        if np.isnan(avg_rainfall):
            print(f"Skipping timestamp {ts} because all AWS values are NaN or -inf")
            continue
        rainfall_by_timestamp[ts] = avg_rainfall

    if not rainfall_by_timestamp:
        return _random_sample_indices(len(valid_timestamps), num_samples)

    if highest:
        ranked = sorted(rainfall_by_timestamp.items(), key=lambda item: item[1], reverse=True)

        def describe(ts, avg_rain):
            return f"Timestamp: {ts}, Average rainfall: {avg_rain:.4f}"
    else:
        median_rainfall = np.median(list(rainfall_by_timestamp.values()))
        ranked = sorted(rainfall_by_timestamp.items(), key=lambda item: abs(item[1] - median_rainfall))

        def describe(ts, avg_rain):
            diff_from_median = abs(avg_rain - median_rainfall)
            return (f"Timestamp: {ts}, Average rainfall: {avg_rain:.4f}, "
                    f"Diff from median: {diff_from_median:.4f}")

    for ts, avg_rain in ranked[:10]:  # Show top 10 for reference
        print(describe(ts, avg_rain))
    selected_timestamps = [ts for ts, _ in ranked[:num_samples]]
    for ts in selected_timestamps:
        print(describe(ts, rainfall_by_timestamp[ts]))

    position_of = {ts: idx for idx, ts in enumerate(valid_timestamps)}
    return [position_of[ts] for ts in selected_timestamps]


def _draw_rainfall_panel(fig, ax, snapshot, values, cmap, title, colorbar_label, vmin, vmax):
    """Scatter one rainfall field on ``ax`` (COL on x, ROW on y) with its colorbar."""
    scatter = ax.scatter(
        snapshot['COL'], snapshot['ROW'],
        c=values,
        cmap=cmap,
        s=30, alpha=0.9,  # Larger dots with more opacity
        vmin=vmin, vmax=vmax,
        edgecolors='none'
    )
    ax.set_title(title)
    ax.set_xlabel("Column")
    ax.set_ylabel("Row")
    ax.invert_yaxis()  # Invert y-axis to match image convention
    ax.grid(True, linestyle='--', alpha=0.7)  # Add grid for better visibility
    colorbar = fig.colorbar(scatter, ax=ax)
    colorbar.set_label(colorbar_label, fontsize=10)


def visualize_rainfall(test_df, predictions, time_step_in, sample_indices=None, num_samples=3,
                       select_highest_rainfall=False, select_moderate_rainfall=False,
                       time_step_out=1, stride=1):
    """Plot actual vs. predicted rainfall maps for a few timestamps.

    Predictions are matched to rows of ``test_df`` by replaying the window
    order of ``create_sequences``: prediction ``k`` belongs to the row that
    window ``k`` forecasts first. For multi-step predictions only the first
    step is shown. Cells without a prediction are drawn as 0 and excluded from
    the metrics.

    Args:
        test_df: Time-sorted test frame with ``DATETIME``, ``ROW``, ``COL`` and
            the ``Config.TARGET`` column, in the row order that was used to
            build the sequences. It is not modified.
        predictions: Array-like with one entry per sequence.
        time_step_in: Input window length used to build the sequences; the
            first ``time_step_in`` distinct timestamps are never plotted.
        sample_indices: Indices into the plottable timestamps. Ignored when a
            ``select_*`` flag is set; chosen at random when ``None``.
        num_samples: Number of timestamps to plot when they are auto-selected.
        select_highest_rainfall: Pick the timestamps with the highest mean rainfall.
        select_moderate_rainfall: Pick the timestamps whose mean rainfall is
            closest to the median (used when ``select_highest_rainfall`` is False).
        time_step_out: Forecast horizon used to build the sequences.
        stride: Window stride used to build the sequences. Together with
            ``time_step_out`` it must match the values that produced
            ``predictions``; a warning is printed when the counts disagree.

    Returns:
        The matplotlib figure, or ``None`` when nothing can be plotted. The
        figure is also saved as ``rainfall_prediction_comparison.png``.
    """
    if 'DATETIME' not in test_df.columns:
        print("❌ No DATETIME column found in test_df. Cannot create visualization.")
        return

    actual_col = Config.TARGET
    # Work on a private copy with a 0..n-1 index so positions can address rows.
    frame = test_df.copy().reset_index(drop=True)
    # Assign the cleaned column back instead of chained in-place edits, which
    # are not guaranteed to reach the parent frame. clip() leaves NaN untouched.
    frame[actual_col] = (
        frame[actual_col]
        .replace([-np.inf, np.inf], np.nan)
        .clip(Config.MIN_VALUE, Config.MAX_VALUE)
    )

    # Attach every prediction to the row it forecasts.
    pred_values = _first_step_predictions(predictions)
    row_positions = _target_row_positions(frame, time_step_in, time_step_out, stride)
    if len(row_positions) != len(pred_values):
        print(f"Warning: expected {len(row_positions)} predictions for the given "
              f"time_step_in/time_step_out/stride but received {len(pred_values)}; "
              f"predicted maps may be misaligned.")
    n_aligned = min(len(row_positions), len(pred_values))
    predicted_column = np.full(len(frame), np.nan)
    predicted_column[row_positions[:n_aligned]] = pred_values[:n_aligned]
    frame['PREDICTED_AWS'] = predicted_column

    unique_timestamps = frame['DATETIME'].unique()
    valid_timestamps = unique_timestamps[time_step_in:]

    if select_highest_rainfall or select_moderate_rainfall:
        sample_indices = _select_by_rainfall(
            frame, actual_col, valid_timestamps, num_samples, select_highest_rainfall
        )
        print(f"Sample indices for visualization: {sample_indices}")
    elif sample_indices is None:
        sample_indices = _random_sample_indices(len(valid_timestamps), num_samples)

    timestamps_to_plot = [valid_timestamps[i] for i in sample_indices]
    if not timestamps_to_plot:
        # plt.subplots cannot create a grid with zero rows.
        print("No timestamps available for visualization.")
        return

    for ts in timestamps_to_plot:
        ts_actual = frame.loc[frame['DATETIME'] == ts, actual_col]
        print(f"Timestamp: {ts}")
        print(f"  - Mean AWS: {ts_actual.mean():.4f}")
        print(f"  - Max AWS: {ts_actual.max():.4f}")
        print(f"  - Min AWS: {ts_actual.min():.4f}")
        print(f"  - Number of data points: {len(ts_actual)}")

    # squeeze=False keeps `axes` two-dimensional even for a single row.
    fig, axes = plt.subplots(len(timestamps_to_plot), 2,
                             figsize=(15, 5 * len(timestamps_to_plot)), squeeze=False)
    white_to_blue = [(1, 1, 1, 0.7), (0.8, 0.8, 1, 0.8), (0, 0, 1, 1)]  # White to blue with alpha
    cmap_actual = LinearSegmentedColormap.from_list('actual_rainfall', white_to_blue)
    cmap_pred = LinearSegmentedColormap.from_list('predicted_rainfall', white_to_blue)

    # Fixed colour range: values above 0.5 all appear at maximum intensity.
    vmin, vmax = 0.0, 0.5

    for row, timestamp in enumerate(timestamps_to_plot):
        snapshot = frame[frame['DATETIME'] == timestamp]
        actual = snapshot[actual_col]
        predicted = snapshot['PREDICTED_AWS']

        nan_count = int(actual.isna().sum())
        if nan_count > 0:
            print(f"  - Replacing {nan_count} NaN values with 0 for timestamp {timestamp}")

        # Metrics only use cells that have both an observation and a prediction.
        mse = np.nan
        r2 = np.nan
        comparable = actual.notna() & predicted.notna()
        if not actual.notna().any():
            print(f"  - Cannot calculate metrics for timestamp {timestamp}: No valid AWS values")
        elif not comparable.any():
            print(f"  - Cannot calculate metrics for timestamp {timestamp}: No predictions available")
        else:
            try:
                mse = mean_squared_error(actual[comparable], predicted[comparable])
                if actual[comparable].nunique() > 1:  # R² requires variance in the data
                    r2 = r2_score(actual[comparable], predicted[comparable])
                else:
                    print(f"  - Cannot calculate R² for timestamp {timestamp}: No variance in actual values")
            except ValueError as e:
                print(f"Error calculating metrics for timestamp {timestamp}: {e}")
                mse = np.nan
                r2 = np.nan

        avg_actual = actual.mean()
        if np.isnan(avg_actual):
            avg_actual = 0
        avg_pred = predicted.mean()
        if np.isnan(avg_pred):
            avg_pred = 0

        # Missing values are drawn as 0 purely for visualization.
        _draw_rainfall_panel(
            fig, axes[row, 0], snapshot, actual.fillna(0), cmap_actual,
            f"Actual Rainfall ({timestamp})\nAvg: {avg_actual:.4f}",
            'Actual Rainfall (0 - 0.5+)', vmin, vmax,
        )
        _draw_rainfall_panel(
            fig, axes[row, 1], snapshot, predicted.fillna(0), cmap_pred,
            f"Predicted Rainfall ({timestamp})\nAvg: {avg_pred:.4f}, MSE: {mse:.4f}, R²: {r2:.4f}",
            'Predicted Rainfall (0 - 0.5+)', vmin, vmax,
        )

    fig.tight_layout()
    fig.savefig('rainfall_prediction_comparison.png', dpi=300, bbox_inches='tight')
    plt.show()
    return fig

In [None]:
# --- Run configuration ------------------------------------------------------
base_path = "/kaggle/input/ai-dataimputedataset-k-fold"
month = "2020-04"  # Update as needed: "2019-04", "2019-10", "2020-04", "2020-10"
fold = "fold_5"    # Update as needed: "fold_1" to "fold_5"

folder = os.path.join(base_path, month, fold)
train_file = os.path.join(folder, "processed_train.csv")
val_file = os.path.join(folder, "processed_val.csv")
# NOTE(review): the "test" split is read from the same file as the validation
# split, so the validation rows also serve as history for their own lag and
# rolling features. Confirm whether a dedicated test file should be used here.
test_file = os.path.join(folder, "processed_val.csv")

best_params_path = f"/kaggle/input/lstm-checkpoint/best_params_{month}.json"
checkpoint_path = f"/kaggle/input/lstm-checkpoint/best_model_{month}.pt"

# --- Load data and tuned hyper-parameters -----------------------------------
train_df = load_data(train_file)
val_df = load_data(val_file)
test_df = load_data(test_file)

with open(best_params_path, "r") as f:
    best_params = json.load(f)

# --- Build model inputs -----------------------------------------------------
test_x, original_test_df, feature_cols = prepare_data_for_prediction(train_df, val_df, test_df, best_params)
if test_x is None:
    # create_sequences returns None when no (ROW, COL) series is long enough.
    raise ValueError(
        "No input sequences could be built from the test data; every series is "
        "shorter than time_step_in + time_step_out."
    )

# --- Restore the trained model ----------------------------------------------
input_size = test_x.shape[2]  # Number of features
model = LSTMModel(
    input_size=input_size,
    hidden_size=best_params["hidden_size"],
    num_layers=best_params["num_layers"],
    dropout=best_params["dropout"],
    time_step_out=best_params["time_step_out"]
).to(DEVICE)

# torch.load unpickles the file, so only load checkpoints from a trusted source.
model.load_state_dict(torch.load(checkpoint_path, map_location=DEVICE))

# --- Predict and visualize --------------------------------------------------
predictions = predict_with_model(model, test_x)
# Bind the returned figure to a name: as a bare last expression it would be
# rendered a second time by the notebook after plt.show().
fig = visualize_rainfall(original_test_df, predictions, best_params["time_step_in"],
                         num_samples=3, select_moderate_rainfall=True)