# In [None]:
import pandas as pd
import numpy as np
from sklearn.metrics import mean_absolute_error
from sklearn.model_selection import KFold, StratifiedKFold, TimeSeriesSplit, GroupKFold
from sklearn.preprocessing import RobustScaler, StandardScaler
from sklearn.svm import SVR, LinearSVR
import lightgbm as lgb
from tqdm import tqdm
import warnings
warnings.filterwarnings('ignore')

# --- 1. Define Column Lists and Location Blacklist ---
# Categorical model inputs.
cat_cols = ['Category_Health_Facility_UUID', 'Disease']
# Coordinate columns; together with the calendar columns they form the key
# used to join the environmental tables.
loc_cols = ['Transformed_Latitude', 'Transformed_Longitude']
key_cols = loc_cols + ['Year', 'Month']
# Locations whose diarrhea rows are routed to a separate prediction column
# in the final prediction step.
blacklist_locations = [
    'ID_02f25962-4419-4e72-9402-3f6c513ec454',
    'ID_1b04544b-5c96-4053-9a26-a1ab26f8a4a2',
    'ID_361566e9-3fa1-4c0e-9778-84d5dc605feb',
    'ID_57abdcf0-63df-4ea8-ace6-816787f63911',
    'ID_ecc21577-6c70-4ba3-a2a2-1d9c0e0477f0',
]
# One identifier per line: the previous single-line literal had been split in
# the middle of 'ID_a81e0fdf-34f9-4334-b25d-20194fad0381', which left an
# unterminated string and made the file unparsable.
unique_locations = [
    'ID_00cd8292-dd85-4fa3-8148-9592e88a1651',
    'ID_02b8390e-f332-4ed5-9e24-9163f09ba478',
    'ID_02f25962-4419-4e72-9402-3f6c513ec454',
    'ID_030ee103-7194-452c-b0e6-58142de19bf6',
    'ID_0358ea0e-af2d-450e-8124-4144ca7860fa',
    'ID_069981aa-3e3d-4008-8460-0b9a564b93ce',
    'ID_0bb8622b-8f9f-4443-b63e-5d1cc9cc46dd',
    'ID_13386124-f7bf-404c-a905-4960e15f1b42',
    'ID_1881eac3-60eb-4a11-b124-3f435335bda0',
    'ID_1883a4d5-4803-4d4b-8c36-af9a1ada280b',
    'ID_193775ae-c108-4564-917e-b99ce259dddf',
    'ID_1b04544b-5c96-4053-9a26-a1ab26f8a4a2',
    'ID_1e0e8bde-3d9e-4846-8ad8-1fa3787d66e6',
    'ID_283f5ebc-25e8-45c0-b20b-f063d851dc2f',
    'ID_2cdf865f-07a9-428f-a5c5-fa19cdf03a49',
    'ID_309b6f8a-c625-4efd-891a-c2d654984998',
    'ID_361566e9-3fa1-4c0e-9778-84d5dc605feb',
    'ID_3a11929e-3317-476d-99f7-1bd9fb58f018',
    'ID_3b04a772-80cb-421b-a64a-01d098fa7e46',
    'ID_41ee755f-c8fb-4605-998d-5dc3d005167c',
    'ID_41fbe2d6-4f32-4976-9048-94252493b5a9',
    'ID_42d6c301-12c2-487e-8955-a126843210be',
    'ID_4705dcf9-f034-4584-b8ab-c8d855f57f94',
    'ID_4716e9d7-2b2b-4d04-adff-e11f43993707',
    'ID_47f5bc9e-f219-451c-b2f7-27e83c58cf7b',
    'ID_57abdcf0-63df-4ea8-ace6-816787f63911',
    'ID_679ca97a-9b9d-400b-9809-b04fc6b4a1c1',
    'ID_69a65631-3c68-4800-9b1a-32996da92c7a',
    'ID_6aa33452-3f11-4c71-819d-e15e6b11f949',
    'ID_6bf0aeb9-67be-4f3a-ad4d-2ac5b7f9c497',
    'ID_6f399bd7-f181-432a-bddc-9fff8dca7f81',
    'ID_704a38c1-35ca-4e81-ab81-02fcf41d1f72',
    'ID_776a9566-8bcb-4936-83a9-e8ef57834a9f',
    'ID_777b4bae-d725-4a1a-9f4a-31b42ebc3e10',
    'ID_79e64e06-dccb-401a-b74d-5d8e4d863615',
    'ID_8016b072-00e9-47d3-b26b-0bc55f541664',
    'ID_82613626-d68d-4d21-b6f0-f90be35004d7',
    'ID_83880e91-c9db-47f3-8cbb-2a1b485260eb',
    'ID_880270f1-8d75-4bc9-965d-27c5a22ec0eb',
    'ID_8b36c0ac-b46c-4c9a-af37-d6717faa340e',
    'ID_8c25ad79-9b7f-4ba0-8e03-69cbfcd39569',
    'ID_8efc0ee1-e183-4518-afc4-83b63b79b1f8',
    'ID_8f26ab15-3d25-43ed-bd52-8b2431c6ba90',
    'ID_90ae3848-6fe9-4326-95ee-d7b8679e7494',
    'ID_90d75acb-4528-43d5-b688-2853541772ea',
    'ID_989f8bc0-6d2a-4ad0-9b8b-641540895ddf',
    'ID_9a9f0ca2-bff2-4e4a-bb38-d9c83ef0602c',
    'ID_a81e0fdf-34f9-4334-b25d-20194fad0381',
    'ID_ac8a77f4-353d-4351-82ed-29cf82cc7775',
    'ID_bd1b57a9-066c-497a-bdef-7ae8e849db14',
    'ID_ddf39ebf-663b-4f32-92b1-a700598dd4b9',
    'ID_e5b10b72-c677-430e-8eee-289c77eeac0b',
    'ID_ebec04b4-99d8-462d-983e-cfbe2294a90e',
    'ID_ecc21577-6c70-4ba3-a2a2-1d9c0e0477f0',
    'ID_ef027061-3d99-4215-8487-374b4ab4699a',
    'ID_f2c528af-f0d7-4c88-ba81-257166928227',
    'ID_f8aa9c00-4ac5-4cea-a386-de165a3671ca',
    'ID_feed50e1-074b-47c9-8ecd-0f65ac9305e6',
]
train_cols = ['Month', 'Year', 'Total_1', 'Total_2',
              'Transformed_Latitude', 'Transformed_Longitude'] + cat_cols
# Variable names shared by the three environmental tables (toilets, waste
# management, water sources).
toilets_cols = ['10u', '10v', '2d', '2t', 'evabs', 'evaow', 'evatc', 'evavt', 'albedo', 'lshf', 'lai_hv', 'lai_lv', 'pev', 'ro', 'src',
                'skt', 'es', 'stl1', 'stl2', 'stl3', 'stl4', 'ssro', 'slhf', 'ssr', 'str', 'sp', 'sro', 'sshf', 'ssrd', 'strd', 'e', 'tp', 'swvl1', 'swvl2', 'swvl3', 'swvl4']
disease_cols = ['Diarrhea', 'Dysentery', 'Intestinal Worms', 'Malaria', 'Typhoid', 'Cholera']
# Diseases handled by the shared "others" model ('Dysentery' used to be listed twice).
dis_others = ['Dysentery', 'Typhoid', 'Cholera']
# Suffixed names the environmental variables receive once the waste-management
# and water-source tables are merged in.
waste_cols = [f'{c}_wm' for c in toilets_cols]
water_cols = [f'{c}_water' for c in toilets_cols]
grp_cols = ['Disease', 'Location', 'Category_Health_Facility_UUID']
target = 'Total'
# Placeholder default (all water-source variables); the correlation analysis
# further down replaces it with the features actually selected.
corr_features = list(water_cols)


# --- 2. Load Datasets ---
# Read the five competition tables in one pass; the order of the paths below
# decides which file is bound to which name.
train, test, toilets, waste_management, water_sources = (
    pd.read_csv(csv_path)
    for csv_path in (
        "/kaggle/input/sua-outsmarting-outbreaks-challenge/outsmarting-outbreaks-challenge20241207-28044-iehqcg/Train.csv",
        "/kaggle/input/sua-outsmarting-outbreaks-challenge/outsmarting-outbreaks-challenge20241207-28044-iehqcg/Test.csv",
        "/kaggle/input/sua-outsmarting-outbreaks-challenge/outsmarting-outbreaks-challenge20241207-28044-iehqcg/toilets.csv",
        "/kaggle/input/sua-outsmarting-outbreaks-challenge/outsmarting-outbreaks-challenge20241207-28044-iehqcg/waste_management.csv",
        "/kaggle/input/sua-outsmarting-outbreaks-challenge/outsmarting-outbreaks-challenge20241207-28044-iehqcg/water_sources.csv",
    )
)

# The test frame gets a zero-filled target column and an empty prediction column.
test[target] = 0
test['Predicted_Total'] = np.nan

# --- 3. Aggregate Training Data ---
# Sum the target over rows sharing disease, location, month, facility category
# and coordinates.
aggregation_keys = ['Disease', 'Location', 'Year', 'Month', 'Category_Health_Facility_UUID', *loc_cols]
train_sum = train.groupby(aggregation_keys)[target].sum().reset_index()

# --- 4. Feature Engineering and Preprocessing Function ---
def feature_engineering_and_preprocessing(df):
    """Add calendar and bookkeeping columns and return the rows sorted by date.

    The input frame is left untouched; all columns are added to a copy.

    Columns added, in this order:
        day         constant 1 (first day of the month)
        date        timestamp assembled from ``Year``, ``Month`` and ``day``
        tag         constant 1
        tag_id      running count of rows within each
                    (Disease, Location, Year, Month, Category_Health_Facility_UUID)
                    group, starting at 1
        time_index  ``date`` converted to a 64-bit integer
        start_date  earliest ``date`` seen for the row's ``Location``
        diff_date   whole days between ``date`` and ``start_date``

    Args:
        df: frame with at least the columns ``Year``, ``Month``, ``Disease``,
            ``Location`` and ``Category_Health_Facility_UUID``.

    Returns:
        A new DataFrame with the columns above, sorted by ``date``.
    """
    group_keys = ['Disease', 'Location', 'Year', 'Month', 'Category_Health_Facility_UUID']

    # Work on a copy so the caller's frame does not silently gain columns.
    df = df.copy()
    df['day'] = 1
    # Assemble the timestamp from its numeric components instead of joining
    # the parts into a string per row and parsing that string back.
    df['date'] = pd.to_datetime(
        pd.DataFrame({'year': df['Year'], 'month': df['Month'], 'day': df['day']})
    )
    df['tag'] = 1
    df['tag_id'] = df.groupby(group_keys)['tag'].cumsum()
    # An explicit 64-bit width avoids the platform-dependent meaning of ``int``.
    df['time_index'] = df['date'].astype('int64')
    df = df.sort_values(by=['date'])
    # The per-location minimum equals "first row after sorting by date" but
    # does not rely on the row order.
    df['start_date'] = df.groupby('Location')['date'].transform('min')
    df['diff_date'] = (df['date'] - df['start_date']).dt.days
    return df

# Run the shared preprocessing on every frame. ``test`` has to be processed
# before ``test_sum`` because ``test_sum`` is derived from the processed result.
train_sum, train = (
    feature_engineering_and_preprocessing(frame) for frame in (train_sum, train)
)
test = feature_engineering_and_preprocessing(test)
test_sum = feature_engineering_and_preprocessing(test)

# --- 5. Lag Feature Engineering Function ---
def create_lag_features(df, data, lag_months=36):
    """Attach the target observed 1..``lag_months`` months earlier to each row.

    For every lag ``k`` the dates of ``data`` are moved ``k`` months forward and
    the result is left-joined onto ``df`` by disease, date, location and
    ``tag_id``. The joined target column is named via the merge suffix
    ``_k``, which only takes effect because ``df`` already carries a column
    with the target's name.

    Args:
        df: frame to enrich; must contain the join keys.
        data: history frame with the target, ``Disease``, ``date``,
            ``Location`` and ``tag_id`` columns. It is not modified.
        lag_months: number of monthly lags to create.

    Returns:
        ``df`` with one additional lagged-target column per lag.
    """
    join_keys = ['Disease', 'date', 'Location', 'tag_id']
    original_dates = data['date']
    history = data[[target, 'Disease', 'date', 'Location', 'tag_id']]

    for lag in range(1, lag_months + 1):
        # Shifting the history forward by ``lag`` months lines an old
        # observation up with the row that should receive it as a feature.
        shifted = history.assign(date=original_dates + pd.DateOffset(months=lag))
        df = df.merge(shifted, how='left', on=join_keys, suffixes=('', f'_{lag}'))
    return df

train = train.sort_values(by=['Year', 'Month'])

# Mean target per disease/date/location/tag, used as the lag history for the
# row-level frames and for the aggregated frames respectively.
lag_keys = ['Disease', 'date', 'Location', 'tag_id']
data_for_lag = train.groupby(lag_keys)[target].mean().reset_index()
data_sum_for_lag = train_sum.groupby(lag_keys)[target].mean().reset_index()

train, test = (create_lag_features(frame, data_for_lag) for frame in (train, test))
train_sum, test_sum = (
    create_lag_features(frame, data_sum_for_lag) for frame in (train_sum, test_sum)
)

# --- 6. Merge Environmental Data ---
# Round the coordinates to whole numbers so the environmental tables can be
# joined on them.
for env_table in (toilets, waste_management, water_sources):
    env_table[loc_cols] = env_table[loc_cols].astype(float).round(0)

# Collapse each table to one row per (coordinates, year, month).
# NOTE(review): waste_management is summed while the other two tables are
# averaged - confirm that this difference is intended.
toilets = toilets.groupby(key_cols)[toilets_cols].mean().reset_index()
waste_management = waste_management.groupby(key_cols)[toilets_cols].sum().reset_index()
water_sources = water_sources.groupby(key_cols)[toilets_cols].mean().reset_index()

def merge_environmental_data(df):
    """Join the three aggregated environmental tables onto ``df``.

    Side effects on the frame passed in: a ``hosp_id`` column
    ("<Disease>_<Location>") is added and the coordinate columns are rounded
    to whole numbers, both in place.

    The tables are left-joined on ``key_cols`` in the order toilets, waste
    management, water sources. A merge suffix is only applied to incoming
    columns whose name already exists on the left side.
    NOTE(review): presumably ``df`` has no environmental columns beforehand, so
    the toilets variables arrive unsuffixed and '_toilets' is never used -
    confirm.

    Returns:
        The merged DataFrame (a new object).
    """
    id_parts = df[['Disease', 'Location']].astype(str)
    df['hosp_id'] = id_parts.apply('_'.join, axis=1)
    df[loc_cols] = df[loc_cols].astype(float).round(0)

    merged = df
    for env_table, clash_suffix in (
        (toilets, '_toilets'),
        (waste_management, '_wm'),
        (water_sources, '_water'),
    ):
        merged = merged.merge(env_table, how='left', on=key_cols, suffixes=('', clash_suffix))
    return merged

# Enrich all four frames with the environmental variables.
train, train_sum, test, test_sum = (
    merge_environmental_data(frame) for frame in (train, train_sum, test, test_sum)
)

# --- 7. Function to Generate Static Features ---
def get_static_features(df, cols):
    """Add row-wise summary statistics of ``df[cols]`` to ``df`` and return it.

    The statistics are written into ``df`` itself (the same object is
    returned). Every new column name ends in ``T``:
    medianT, meanT, maxT, sumT, prodT, skewT, kurtT, semT, stdT,
    q2T / q4T / q5T (0.75 / 0.95 / 0.99 quantiles) and zeros_countT.

    Args:
        df: frame holding the columns listed in ``cols``.
        cols: column names summarised across each row.

    Returns:
        ``df`` with the statistic columns added or overwritten.
    """
    window = df[cols]
    summaries = {
        'medianT': window.median(axis=1),
        'meanT': window.mean(axis=1),
        'maxT': window.max(axis=1),
        'sumT': window.sum(axis=1),
        'prodT': window.prod(axis=1),
        'skewT': window.skew(axis=1),
        'kurtT': window.kurt(axis=1),
        'semT': window.sem(axis=1),
        'stdT': window.std(axis=1),
        'q2T': window.quantile(0.75, axis=1),
        'q4T': window.quantile(0.95, axis=1),
        'q5T': window.quantile(0.99, axis=1),
        # Despite its name this counts the entries that are NOT equal to zero.
        'zeros_countT': window.ne(0).sum(axis=1),
    }
    for column_name, values in summaries.items():
        df[column_name] = values
    return df


# --- 8. Feature Correlation Analysis ---
# Absolute correlation of every environmental variable with the target,
# computed on the aggregated training rows before 2023. The ordering comes
# from the signed correlation (descending) so the selected feature order is
# unchanged.
env_feature_pool = water_cols + waste_cols + toilets_cols
correlation_rows = train_sum.loc[train_sum['Year'] < 2023, env_feature_pool + [target]]
correlation = correlation_rows.corr()[target].sort_values(ascending=False).abs()
# Remove the target by label rather than by position: slicing off "the first
# entry" drops a real feature whenever the target is not first, e.g. when its
# self-correlation is NaN and has already been filtered out.
corr_features = correlation[correlation > 0.05].drop(target, errors='ignore').index.tolist()
print(f"Features selected based on correlation: {corr_features}")


# --- 9. Centralized Cross-Validation Function ---
def generic_train_cv(X, y, valid_df, test_df, model, feature_cols, cv_type='kfold', n_splits=3, shuffle=True, rs=42, use_scaler=False):
    """Fit ``model`` on each CV training fold and average its predictions.

    For every fold the model is fitted on that fold's training part only and
    then predicts ``valid_df`` and ``test_df``; the per-fold predictions are
    averaged. The held-out part of each fold is not scored. The same
    estimator instance is refitted on every fold.

    Missing feature values are replaced by 0 in the training, validation and
    test features before any scaling. ``valid_df`` and ``test_df`` are never
    modified.

    Args:
        X: training frame containing at least ``feature_cols``.
        y: training target, a Series aligned row by row with ``X``.
        valid_df: frame to predict for validation.
        test_df: frame to predict for the test period.
        model: estimator exposing ``fit`` and ``predict``.
        feature_cols: names of the feature columns.
        cv_type: 'kfold' or 'timeseries'.
        n_splits: number of folds.
        shuffle: whether KFold shuffles (ignored for 'timeseries').
        rs: random seed for a shuffled KFold.
        use_scaler: scale features and target with a RobustScaler fitted on
            the fold's training part; predictions are mapped back to the
            original target scale.

    Returns:
        Tuple ``(valid_predictions, test_predictions)`` of arrays averaged over
        the folds.

    Raises:
        ValueError: if ``cv_type`` is not 'kfold' or 'timeseries'.
    """
    if cv_type == 'kfold':
        # KFold rejects a seed when shuffling is off, so only pass it along
        # when it is actually used.
        cv = KFold(n_splits=n_splits, shuffle=shuffle, random_state=rs if shuffle else None)
    elif cv_type == 'timeseries':
        cv = TimeSeriesSplit(n_splits=n_splits, test_size=50)
    else:
        raise ValueError(f"Invalid cv_type: {cv_type}. Must be 'kfold' or 'timeseries'.")

    X_cols_filled = X[feature_cols].fillna(0)
    valid_filled = valid_df[feature_cols].fillna(0)
    test_filled = test_df[feature_cols].fillna(0)

    valid_predictions = []
    test_predictions = []

    for train_index, _ in cv.split(X, y):
        X_train = X_cols_filled.iloc[train_index]
        y_train = y.iloc[train_index]
        valid_features, test_features = valid_filled, test_filled

        if use_scaler:
            scaler_x = RobustScaler()
            X_train = scaler_x.fit_transform(X_train)
            # Scale fresh copies on every fold; rescaling the caller's frames
            # in place would compound the transformation fold after fold.
            valid_features = scaler_x.transform(valid_filled)
            test_features = scaler_x.transform(test_filled)
            scaler_y = RobustScaler()
            y_train = scaler_y.fit_transform(y_train.to_frame()).ravel()

        model.fit(X_train, y_train)

        y_hat_valid = np.asarray(model.predict(valid_features))
        y_hat_test = np.asarray(model.predict(test_features))

        if use_scaler:
            # The target scaler was fitted on a single column, so predictions
            # go back through it as an (n_samples, 1) array.
            y_hat_valid = scaler_y.inverse_transform(y_hat_valid.reshape(-1, 1)).ravel()
            y_hat_test = scaler_y.inverse_transform(y_hat_test.reshape(-1, 1)).ravel()

        valid_predictions.append(y_hat_valid)
        test_predictions.append(y_hat_test)

    return np.mean(valid_predictions, axis=0), np.mean(test_predictions, axis=0)


# --- 10. Centralized Model Training Function ---
def _select_training_frame(frame, row_mask, columns, fill_value):
    """Return the masked rows limited to ``columns`` (duplicates removed), NaNs filled."""
    unique_columns = list(dict.fromkeys(columns))
    return frame.loc[row_mask, unique_columns].fillna(fill_value).reset_index(drop=True)


def _drop_missing_targets(X, y):
    """Remove the rows whose training target is NaN from both ``X`` and ``y``."""
    keep = y.notna()
    return X[keep], y[keep]


def _fit_lgbm_svr_blend(X, y, valid_df, test_df, lgbm_params, lgbm_features, svr_features, cv_kwargs, use_scaler):
    """Cross-validate a LightGBM and a LinearSVR model and return their 50/50 blend."""
    lgbm_model = lgb.LGBMRegressor(**lgbm_params)
    svr_model = LinearSVR(max_iter=400, random_state=432)

    valid_lgbm, test_lgbm = generic_train_cv(X, y, valid_df, test_df, lgbm_model, lgbm_features, **cv_kwargs)
    # Only the SVR call may scale its inputs; it gets copies so that scaling
    # done inside the CV routine cannot reach the frames the caller reads
    # afterwards (their 'medianT' column is used to undo the target transform).
    valid_svr, test_svr = generic_train_cv(
        X, y, valid_df.copy(), test_df.copy(), svr_model, svr_features,
        use_scaler=use_scaler, **cv_kwargs
    )
    return (valid_lgbm * 0.5) + (valid_svr * 0.5), (test_lgbm * 0.5) + (test_svr * 0.5)


def generic_model_trainer(df_train, df_test, disease_name, model_params, train_params):
    """Train the configured target variants month by month and store predictions.

    For every distinct ``date`` in ``df_test`` the lagged-target statistics are
    recomputed for that calendar month, the enabled variants are fitted on the
    training rows with ``Year < train_year`` and predictions are written

    * to ``df_train`` as ``pred_<name>_<variant>`` for the rows of
      ``year_valid`` and the same month, and
    * to ``df_test`` as ``Predicted_Total_<name>_<variant>`` for the rows of
      that date,

    where ``<name>`` is ``disease_name`` in lower case and ``<variant>`` is one
    of ``org`` (raw target), ``diff`` (target minus 'medianT'), ``ratio``
    (target divided by 'medianT') or ``median`` ('medianT' itself).

    Both frames are modified in place and also returned. ``df_train``
    additionally keeps the statistic columns of the last processed month.

    Args:
        df_train: training frame with the lag columns ``Total_<k>``.
        df_test: test frame with the same feature columns and a ``date`` column.
        disease_name: label used in the log line and in the output column names.
        model_params: LightGBM parameter dicts under the optional keys
            'lgbm_org', 'lgbm_diff' and 'lgbm_ratio'; a missing key falls back
            to ``lgb_params_general``.
        train_params: options read with these defaults - year_valid (2022),
            train_year (2023), use_org_target (False), use_diff_target (False),
            use_ratio_target (False), use_median_target (True), na_val (9),
            use_scaler (False), cv_type ('kfold'), cv_splits (4),
            cv_shuffle (True), cv_rs (42). Other keys are ignored.

    Returns:
        Tuple ``(df_train, df_test)``.
    """
    year_valid = train_params.get('year_valid', 2022)
    train_year = train_params.get('train_year', 2023)
    use_diff_target = train_params.get('use_diff_target', False)
    use_org_target = train_params.get('use_org_target', False)
    use_ratio_target = train_params.get('use_ratio_target', False)
    use_median_target = train_params.get('use_median_target', True)
    na_val = train_params.get('na_val', 9)
    use_scaler = train_params.get('use_scaler', False)
    cv_kwargs = dict(
        cv_type=train_params.get('cv_type', 'kfold'),
        n_splits=train_params.get('cv_splits', 4),
        shuffle=train_params.get('cv_shuffle', True),
        rs=train_params.get('cv_rs', 42),
    )

    print(f"\nTraining model for {disease_name}")

    name_tag = disease_name.lower()
    lag_target_median = 'medianT'
    svr_features = ['medianT']

    for i in tqdm(df_test['date'].unique()):
        test_rows = df_test['date'] == i
        temp_test = df_test[test_rows].copy()
        month = temp_test['date'].dt.month.values[0]

        # Twelve lags starting at the month number of the date being predicted.
        static_feature_cols = [f'Total_{j}' for j in range(month, month + 12)]
        temp_train = get_static_features(df_train, static_feature_cols)
        temp_test = get_static_features(temp_test, static_feature_cols)
        stat_cols = [c for c in temp_train.columns if c.endswith('T') and c not in ['Category_Health_Facility_UUID', 'ID']]

        valid_rows = (df_train['Year'] == year_valid) & (df_train['Month'] == month)
        X_valid_month = temp_train[valid_rows].copy()
        train_rows = temp_train['Year'] < train_year

        if use_org_target:
            feature_cols_org = static_feature_cols + corr_features + ['Year', 'Month'] + stat_cols
            train_data_org = _select_training_frame(
                temp_train, train_rows, feature_cols_org + svr_features + [target], na_val
            )
            X_org, y_org = _drop_missing_targets(train_data_org.drop(columns=target), train_data_org[target])

            y_hat_valid_org, y_hat_test_org = _fit_lgbm_svr_blend(
                X_org, y_org, X_valid_month, temp_test,
                model_params.get('lgbm_org', lgb_params_general),
                feature_cols_org, svr_features, cv_kwargs, use_scaler
            )

            df_train.loc[valid_rows, f'pred_{name_tag}_org'] = np.asarray(y_hat_valid_org)
            df_test.loc[test_rows, f'Predicted_Total_{name_tag}_org'] = np.asarray(y_hat_test_org)

        if use_diff_target:
            feature_cols_diff = corr_features + ['Month'] + stat_cols + [f'Total_{j}' for j in range(month, month + 6)]
            # 'medianT' is already among the statistic columns; the helper
            # de-duplicates the selection so it stays a single column.
            train_data_diff = _select_training_frame(
                temp_train, train_rows, feature_cols_diff + svr_features + [target, lag_target_median], na_val
            )
            X_diff, y_diff = _drop_missing_targets(
                train_data_diff.drop(columns=target),
                train_data_diff[target] - train_data_diff[lag_target_median],
            )

            y_hat_valid_diff, y_hat_test_diff = _fit_lgbm_svr_blend(
                X_diff, y_diff, X_valid_month, temp_test,
                model_params.get('lgbm_diff', lgb_params_general),
                feature_cols_diff, svr_features, cv_kwargs, use_scaler
            )

            # Undo the target transform by adding the median back.
            # NOTE(review): a missing median is treated as ``na_val`` during
            # training but as 0 here - confirm this asymmetry is intended.
            y_hat_valid_diff = y_hat_valid_diff + X_valid_month[lag_target_median].fillna(0).to_numpy()
            y_hat_test_diff = y_hat_test_diff + temp_test[lag_target_median].fillna(0).to_numpy()

            df_train.loc[valid_rows, f'pred_{name_tag}_diff'] = y_hat_valid_diff
            df_test.loc[test_rows, f'Predicted_Total_{name_tag}_diff'] = y_hat_test_diff

        if use_ratio_target:
            feature_cols_ratio = static_feature_cols + corr_features + stat_cols
            train_data_ratio = _select_training_frame(
                temp_train, train_rows, feature_cols_ratio + svr_features + [target, lag_target_median], na_val
            )
            y_ratio = train_data_ratio[target] / train_data_ratio[lag_target_median]
            # A zero median yields +/-inf (or NaN for 0/0); those rows are dropped.
            y_ratio = y_ratio.astype(np.float32).replace([np.inf, -np.inf], np.nan)
            X_ratio, y_ratio = _drop_missing_targets(train_data_ratio.drop(columns=target), y_ratio)

            y_hat_valid_ratio, y_hat_test_ratio = _fit_lgbm_svr_blend(
                X_ratio, y_ratio, X_valid_month, temp_test,
                model_params.get('lgbm_ratio', lgb_params_general),
                feature_cols_ratio, svr_features, cv_kwargs, use_scaler
            )

            # Undo the target transform by multiplying with the median.
            y_hat_valid_ratio = y_hat_valid_ratio * X_valid_month[lag_target_median].fillna(0).to_numpy()
            y_hat_test_ratio = y_hat_test_ratio * temp_test[lag_target_median].fillna(0).to_numpy()

            df_train.loc[valid_rows, f'pred_{name_tag}_ratio'] = y_hat_valid_ratio
            df_test.loc[test_rows, f'Predicted_Total_{name_tag}_ratio'] = y_hat_test_ratio

        if use_median_target:
            # Baseline: the median of the lagged targets is the prediction.
            df_train.loc[valid_rows, f'pred_{name_tag}_median'] = X_valid_month['medianT'].fillna(na_val).to_numpy()
            df_test.loc[test_rows, f'Predicted_Total_{name_tag}_median'] = temp_test['medianT'].fillna(na_val).to_numpy()

    return df_train, df_test


# --- 11. Model and Training Parameters ---
def _mae_lgb_params(n_estimators, lambda_l1=5):
    """Build one LightGBM parameter dict; only tree count and L1 penalty vary."""
    return dict(
        n_estimators=n_estimators,
        objective='mae',
        verbose=-1,
        lambda_l1=lambda_l1,
        max_depth=5,
        max_bin=100,
        random_state=41,
        learning_rate=0.05,
    )


lgb_params_general = _mae_lgb_params(40)
lgb_params_ml = _mae_lgb_params(200)
lgb_params_diff_ml = _mae_lgb_params(200)
lgb_params_ratio_ml = _mae_lgb_params(100, lambda_l1=10)
lgb_params_dr = _mae_lgb_params(500)
lgb_params_diff_dr = _mae_lgb_params(1000)
lgb_params_ratio_dr = _mae_lgb_params(100, lambda_l1=10)
lgb_params_bdr = _mae_lgb_params(50)


# LightGBM settings per model family, keyed by the target variant they serve.
model_parameter_sets = {
    'general': {
        'lgbm_org': lgb_params_general,
    },
    'malaria': {
        'lgbm_org': lgb_params_ml,
        'lgbm_diff': lgb_params_diff_ml,
    },
    'intestinal_worms': {
        'lgbm_org': lgb_params_ml,
        'lgbm_diff': lgb_params_diff_ml,
        'lgbm_ratio': lgb_params_ratio_ml,
    },
    'diarrhea': {
        'lgbm_org': lgb_params_dr,
        'lgbm_diff': lgb_params_diff_dr,
        'lgbm_ratio': lgb_params_ratio_dr,
    },
    'blacklisted_diarrhea': {
        'lgbm_org': lgb_params_bdr,
        'lgbm_diff': lgb_params_bdr,
        'lgbm_ratio': lgb_params_bdr,
    },
}


def _training_options(*target_variants, **extra_options):
    """Build one option dict: shared years, a ``use_<variant>_target`` flag per variant, extras."""
    options = {'year_valid': 2022, 'train_year': 2023}
    for variant in target_variants:
        options[f'use_{variant}_target'] = True
    options.update(extra_options)
    return options


# Options handed to the model trainer, one entry per training run.
training_parameter_sets = {
    'dysentery_typhoid_cholera': _training_options('org'),
    'malaria_original': _training_options('diff', 'org', 'median'),
    'malaria_sum': _training_options('diff', 'org', 'median'),
    'intestinal_worms': _training_options(
        'org', 'diff', 'ratio', 'median',
        cv_shuffle=False, cv_rs=None,
    ),
    'diarrhea_original': _training_options('diff', 'median', 'org', 'ratio'),
    'diarrhea_sum': _training_options('diff', 'median', 'org', 'ratio'),
    'blacklisted_diarrhea': _training_options(
        'diff', 'org', 'ratio',
        use_ensemble=True, use_scaler=True, na_val=9, par_dec=1,
        cv_shuffle=False, cv_rs=None,
    ),
}


# --- 12. Train Models ---
# Dysentery / Typhoid / Cholera share one model, fitted on rows from 2021 onwards.
print("\nTraining model for ['Dysentery', 'Typhoid', 'Cholera']")
recent_rows = train['Year'] >= 2021
other_disease_rows = train['Disease'].isin(['Dysentery', 'Typhoid', 'Cholera'])
temp_train_others = train[recent_rows & other_disease_rows].copy()
temp_train_others, test = generic_model_trainer(
    temp_train_others, test, 'Others',
    model_parameter_sets['general'],
    training_parameter_sets['dysentery_typhoid_cholera'],
)


# Malaria is modelled twice: on the row-level data and on the aggregated data.
print("\nTraining model for Malaria - Original Data")
temp_train_malaria = train[train['Disease'] == 'Malaria'].copy()
temp_train_malaria, test = generic_model_trainer(
    temp_train_malaria, test, 'Malaria_ml',
    model_parameter_sets['malaria'],
    training_parameter_sets['malaria_original'],
)

print("\nTraining model for Malaria - Aggregated Data")
temp_train_sum_malaria = train_sum[train_sum['Disease'] == 'Malaria'].copy()
temp_train_sum_malaria, test_sum = generic_model_trainer(
    temp_train_sum_malaria, test_sum, 'Malaria_ml_sum',
    model_parameter_sets['malaria'],
    training_parameter_sets['malaria_sum'],
)


print("\nTraining model for Intestinal Worms")
temp_train_iw = train[train['Disease'] == 'Intestinal Worms'].copy()
temp_train_iw, test = generic_model_trainer(
    temp_train_iw, test, 'IW',
    model_parameter_sets['intestinal_worms'],
    training_parameter_sets['intestinal_worms'],
)


print("\nTraining model for Diarrhea")
temp_train_dr = train[train['Disease'] == 'Diarrhea'].copy()
temp_train_dr, test = generic_model_trainer(
    temp_train_dr, test, 'Dr',
    model_parameter_sets['diarrhea'],
    training_parameter_sets['diarrhea_original'],
)

# The "blacklisted" diarrhea model is fitted on all aggregated diarrhea rows;
# no location filter is applied at training time.
print("\nTraining model for Diarrhea Blacklisted - Aggregated")
temp_train_sum_dr = train_sum[train_sum['Disease'] == 'Diarrhea'].copy()
temp_train_sum_dr, test_sum = generic_model_trainer(
    temp_train_sum_dr, test_sum, 'Bdr',
    model_parameter_sets['blacklisted_diarrhea'],
    training_parameter_sets['blacklisted_diarrhea'],
)


# --- 13. Ensemble Predictions ---
# Average the per-variant predictions of each model family into one column.
# The trainer names its outputs 'Predicted_Total_<lower-cased disease_name>_<variant>',
# so the prefixes below follow the names passed to it: 'Others', 'Malaria_ml',
# 'Malaria_ml_sum', 'IW', 'Dr' and 'Bdr'.
# NOTE(review): every ``test[...] = test_sum[...]`` assignment aligns on the row
# index and therefore assumes that ``test`` and ``test_sum`` hold the same rows
# under the same index labels - confirm.
use_ensemble = True
if use_ensemble:
    # Dysentery / Typhoid / Cholera: only the raw-target model exists.
    ens_cols_others = ['org']
    test['Predicted_Total_ens'] = test[[f'Predicted_Total_others_{c}' for c in ens_cols_others]].mean(axis=1)

    # Malaria: row-level model on ``test``, aggregated model on ``test_sum``.
    ens_cols_malaria = ['diff', 'org']
    test['Predicted_Total_ml_ens'] = test[[f'Predicted_Total_malaria_ml_{c}' for c in ens_cols_malaria]].mean(axis=1)
    test_sum['Predicted_Total_ml_ens'] = test_sum[[f'Predicted_Total_malaria_ml_sum_{c}' for c in ens_cols_malaria]].mean(axis=1)

    test['Predicted_Total_ml_ens_sum'] = test_sum['Predicted_Total_ml_ens']
    test['Predicted_Total_ml_diff_sum'] = test_sum['Predicted_Total_malaria_ml_sum_diff']
    test['Predicted_Total_ml_org_sum'] = test_sum['Predicted_Total_malaria_ml_sum_org']

    # Intestinal worms: the variant names must not repeat the 'iw_' prefix,
    # otherwise the lookup asks for 'Predicted_Total_iw_iw_org' and fails.
    ens_cols_iw = ['org', 'diff', 'ratio']
    test['Predicted_Total_iw_ens'] = test[[f'Predicted_Total_iw_{c}' for c in ens_cols_iw]].mean(axis=1)

    # Diarrhea: row-level model ('Dr') on ``test``; the aggregated model was
    # trained under the name 'Bdr', so its columns carry the 'bdr' prefix.
    ens_cols_dr = ['diff', 'org']
    test['Predicted_Total_dr_ens'] = test[[f'Predicted_Total_dr_{c}' for c in ens_cols_dr]].mean(axis=1)
    test_sum['Predicted_Total_dr_ens'] = test_sum[[f'Predicted_Total_bdr_{c}' for c in ens_cols_dr]].mean(axis=1)

    test['Predicted_Total_dr_ens_sum'] = test_sum['Predicted_Total_dr_ens']
    test['Predicted_Total_dr_median_sum'] = test_sum['Predicted_Total_bdr_median']
    test['Predicted_Total_dr_diff_sum'] = test_sum['Predicted_Total_bdr_diff']
    test['Predicted_Total_dr_org_sum'] = test_sum['Predicted_Total_bdr_org']


def print_score(df, prefix='pred_others'):
    """Show the mean absolute error per disease for the 2022 rows.

    For each variant in (org, diff, ratio, median) whose prediction column
    ``<prefix>_<variant>`` exists in ``df``, an absolute-error column
    ``mae_<variant>`` is added to ``df`` in place. Variants without a
    prediction column are skipped instead of raising a KeyError.

    Args:
        df: frame with ``Year``, ``Disease``, the target column and any of the
            prediction columns.
        prefix: common prefix of the prediction columns to score.

    Returns:
        None. The per-disease summary is displayed (or printed when the
        notebook-only ``display`` helper is not available).
    """
    mae_cols = []
    for variant in ('org', 'diff', 'ratio', 'median'):
        pred_col = f'{prefix}_{variant}'
        if pred_col not in df.columns:
            continue
        mae_col = f'mae_{variant}'
        df[mae_col] = abs(df[pred_col] - df[target])
        mae_cols.append(mae_col)

    summary = df[(df['Year'] == 2022)].groupby(['Disease'])[mae_cols].mean()
    try:
        display(summary)
    except NameError:
        # ``display`` only exists inside IPython/Jupyter sessions.
        print(summary)


# --- 14. Final Prediction and Submission ---
# Start from a float column so the model outputs can be written into it
# without a dtype change.
test['Predicted_Total'] = 0.0

blacklist = blacklist_locations

# Row selectors per disease family; blacklisted locations are split off from
# the regular diarrhea rows.
malaria_cond = test['Disease'].isin(['Malaria'])
inst_cond = test['Disease'].isin(['Intestinal Worms'])
dr_cond = test['Disease'].isin(['Diarrhea'])
blacklist_cond = test['Location'].isin(blacklist) & dr_cond
dr_cond = dr_cond & ~blacklist_cond
other_cond = test['Disease'].isin(dis_others)


test.loc[other_cond, 'Predicted_Total'] = test.loc[other_cond, 'Predicted_Total_ens']
test.loc[malaria_cond, 'Predicted_Total'] = test.loc[malaria_cond, ['Predicted_Total_ml_ens_sum']].mean(axis=1)
test.loc[inst_cond, 'Predicted_Total'] = test.loc[inst_cond, ['Predicted_Total_iw_ens']].mean(axis=1)
test.loc[dr_cond, 'Predicted_Total'] = test.loc[dr_cond, ['Predicted_Total_dr_ens', 'Predicted_Total_dr_ens_sum']].mean(axis=1)
# The ensemble step stores the aggregated "blacklisted diarrhea" model output
# as 'Predicted_Total_dr_ens_sum'; a column named 'Predicted_Total_bdr_ens_sum'
# is never created, so reading it raised a KeyError.
test.loc[blacklist_cond, 'Predicted_Total'] = test.loc[blacklist_cond, ['Predicted_Total_dr_ens_sum']].mean(axis=1)


# Build the submission on an explicit copy so the clipping below does not
# assign into a slice of ``test``.
sub = test[['ID', 'Predicted_Total']].copy()
# Missing predictions become 0 and negative predictions are raised to 0.
sub['Predicted_Total'] = sub['Predicted_Total'].fillna(0).clip(lower=0)
sub.to_csv('test_sub.csv', index=False)
sub['Predicted_Total'].describe()