In [1]:
import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)
import polars as pl

In [2]:
# Ensemble CSVs produced by two upstream notebooks (presumably covering
# batches 001-005 and 006-010 respectively, going by their dataset names).
sub_path_001_005 = (
    '/kaggle/input/leap-pytorch-5m-f64-batch-001-005-ensemble/ensemble_submission.csv'
)
sub_path_006_010 = (
    '/kaggle/input/leap-pytorch-5m-f64-batch-006-010-ensemble/ensemble_submission.csv'
)

# Files to be averaged; the first one also supplies the column header.
submission_paths = [sub_path_001_005, sub_path_006_010]

In [3]:
# Read the sample_id column from the competition's sample submission (not from
# one of the ensembled files); these ids label the averaged predictions later.
# Selecting just this column avoids parsing every target column of the file.
sample_id_col = pl.read_csv(
    '/kaggle/input/leap-atmospheric-physics-ai-climsim/sample_submission.csv',
    columns=['sample_id'],
    n_threads=1,
)['sample_id'].to_pandas()

In [4]:
# Column labels of the first submission file. nrows=0 parses the header line only.
# The first label is expected to be the id column; the rest are prediction targets.
header = list(pd.read_csv(submission_paths[0], nrows=0).columns)

In [5]:
def ensemble_submissions(submission_paths):
    """
    Ensemble predictions from multiple submission files by equal-weight averaging.

    Files are combined positionally: each one must have the same column names
    (in the same order) and list the same first-column ids (sample_id) in the
    same row order as the first file. Every column after the first is averaged
    element-wise.

    Parameters
    ----------
    submission_paths : list of str
        List of file paths to submission CSV files.

    Returns
    -------
    np.ndarray
        float64 array of shape (n_rows, n_columns - 1) with the averaged
        predictions, rows in the order of the first file.

    Raises
    ------
    ValueError
        If ``submission_paths`` is empty, or if a file's column names or
        first-column ids/order differ from those of the first file.
    """
    if len(submission_paths) == 0:
        raise ValueError("submission_paths must contain at least one file")

    first_path = submission_paths[0]
    reference_columns = None
    reference_ids = None
    ensemble_predictions = None

    for path in submission_paths:
        submission_df = pl.read_csv(path)
        sample_ids = submission_df[:, 0].to_list()
        predictions = submission_df[:, 1:].to_numpy()  # Exclude the first column (sample_id)

        if ensemble_predictions is None:
            reference_columns = submission_df.columns
            reference_ids = sample_ids
            # Own a writable float64 accumulator: to_numpy() may return a
            # read-only zero-copy view, and integer-inferred columns would make
            # the in-place division below fail. Copy only when actually needed.
            ensemble_predictions = predictions.astype(np.float64, copy=False)
            if not ensemble_predictions.flags.writeable:
                ensemble_predictions = ensemble_predictions.copy()
        else:
            # Averaging is positional, so a mismatched layout would silently
            # mix predictions belonging to different samples or targets.
            if submission_df.columns != reference_columns:
                raise ValueError(
                    f"columns of {path!r} differ from those of {first_path!r}"
                )
            if sample_ids != reference_ids:
                raise ValueError(
                    f"sample ids/row order of {path!r} differ from those of {first_path!r}"
                )
            ensemble_predictions += predictions

        # Delete the per-file frame and array to free up memory
        del submission_df, predictions

    # Average ensemble predictions
    ensemble_predictions /= len(submission_paths)

    return ensemble_predictions


In [6]:
# Average every configured submission file into a single prediction matrix.
ensemble_preds = ensemble_submissions(submission_paths=submission_paths)

In [7]:
# Label the averaged prediction matrix with the target column names;
# header[0] is skipped because ensemble_preds excludes the id column.
ensemble_df = pd.DataFrame(
    data=ensemble_preds,
    columns=header[1:],
)

ensemble_df

Unnamed: 0,ptend_t_0,ptend_t_1,ptend_t_2,ptend_t_3,ptend_t_4,ptend_t_5,ptend_t_6,ptend_t_7,ptend_t_8,ptend_t_9,...,ptend_v_58,ptend_v_59,cam_out_NETSW,cam_out_FLWDS,cam_out_PRECSC,cam_out_PRECC,cam_out_SOLS,cam_out_SOLL,cam_out_SOLSD,cam_out_SOLLD
0,-0.240605,-0.514768,-0.414768,-0.617056,-0.795912,-0.909707,-0.941093,-0.946086,-0.950893,-0.950653,...,-0.094108,0.068964,0.015512,5.194903,0.019998,0.004688,-0.000311,0.012071,0.009443,-0.018621
1,-0.189888,-0.430200,-0.439067,-0.668852,-0.877671,-0.972547,-0.969087,-0.965670,-0.978381,-0.990967,...,-0.129141,0.423531,0.034073,5.158769,0.050641,0.150926,0.004044,0.016107,0.010673,0.013420
2,-0.248405,-0.941084,-0.587242,-0.615612,-0.785433,-0.915164,-0.956804,-0.955637,-0.939955,-0.927477,...,-0.154351,-0.031226,-0.011583,5.709803,0.019263,0.005436,-0.003573,-0.006266,-0.011131,-0.024672
3,-0.290976,-0.769755,-0.602398,-0.621971,-0.826239,-0.970950,-0.997067,-0.974548,-0.948202,-0.937874,...,-0.243370,-0.274822,-0.004546,5.751472,0.058090,0.245144,0.020742,0.014384,0.004509,-0.027713
4,-0.199230,-0.250085,-0.403759,-0.723925,-0.926101,-1.002715,-0.986452,-0.978449,-0.977556,-0.990342,...,0.559024,0.103322,0.027400,5.084572,0.044448,0.152916,-0.005384,0.002509,0.020860,0.041488
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
624995,1.243302,0.704741,1.015835,1.142871,1.297883,1.315854,1.245037,1.191253,1.199678,1.247306,...,0.089324,0.131364,2.017672,5.412830,0.046568,0.034535,1.897592,1.902543,1.625645,1.121087
624996,1.303961,0.663198,1.000224,1.212504,1.287996,1.253108,1.119913,1.052780,1.193922,1.339654,...,-0.047775,-0.004737,1.602201,5.703972,0.033443,0.092254,1.148337,1.113886,2.324075,2.131416
624997,1.292895,0.398499,0.896352,1.141986,1.091592,0.921941,0.721742,0.692916,0.819667,0.908161,...,-0.105182,-0.108229,1.143164,5.691248,0.017042,0.028041,0.884144,0.885550,1.511482,1.271262
624998,1.383824,0.440285,0.904473,1.071636,1.179318,1.092865,0.932861,0.875537,0.911289,0.972626,...,0.186191,-0.127692,1.590675,5.387972,0.074061,0.013962,1.564957,1.607105,1.178976,0.707465


In [8]:
# Insert the sample_id column at the beginning of the DataFrame.
# Values are passed as a NumPy array so they are placed by position; a pandas
# Series would be aligned on the index and a length mismatch would silently
# pad with NaN or truncate. The membership guard makes the cell safe to re-run.
# NOTE(review): assumes the submission files list rows in the same order as
# sample_submission.csv — confirm against the upstream notebooks.
if len(sample_id_col) != len(ensemble_df):
    raise ValueError(
        f"sample_id has {len(sample_id_col)} rows but the predictions have {len(ensemble_df)}"
    )
if 'sample_id' not in ensemble_df.columns:
    ensemble_df.insert(0, 'sample_id', sample_id_col.to_numpy())

ensemble_df

Unnamed: 0,sample_id,ptend_t_0,ptend_t_1,ptend_t_2,ptend_t_3,ptend_t_4,ptend_t_5,ptend_t_6,ptend_t_7,ptend_t_8,...,ptend_v_58,ptend_v_59,cam_out_NETSW,cam_out_FLWDS,cam_out_PRECSC,cam_out_PRECC,cam_out_SOLS,cam_out_SOLL,cam_out_SOLSD,cam_out_SOLLD
0,test_169651,-0.240605,-0.514768,-0.414768,-0.617056,-0.795912,-0.909707,-0.941093,-0.946086,-0.950893,...,-0.094108,0.068964,0.015512,5.194903,0.019998,0.004688,-0.000311,0.012071,0.009443,-0.018621
1,test_524862,-0.189888,-0.430200,-0.439067,-0.668852,-0.877671,-0.972547,-0.969087,-0.965670,-0.978381,...,-0.129141,0.423531,0.034073,5.158769,0.050641,0.150926,0.004044,0.016107,0.010673,0.013420
2,test_634129,-0.248405,-0.941084,-0.587242,-0.615612,-0.785433,-0.915164,-0.956804,-0.955637,-0.939955,...,-0.154351,-0.031226,-0.011583,5.709803,0.019263,0.005436,-0.003573,-0.006266,-0.011131,-0.024672
3,test_403572,-0.290976,-0.769755,-0.602398,-0.621971,-0.826239,-0.970950,-0.997067,-0.974548,-0.948202,...,-0.243370,-0.274822,-0.004546,5.751472,0.058090,0.245144,0.020742,0.014384,0.004509,-0.027713
4,test_484578,-0.199230,-0.250085,-0.403759,-0.723925,-0.926101,-1.002715,-0.986452,-0.978449,-0.977556,...,0.559024,0.103322,0.027400,5.084572,0.044448,0.152916,-0.005384,0.002509,0.020860,0.041488
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
624995,test_578220,1.243302,0.704741,1.015835,1.142871,1.297883,1.315854,1.245037,1.191253,1.199678,...,0.089324,0.131364,2.017672,5.412830,0.046568,0.034535,1.897592,1.902543,1.625645,1.121087
624996,test_395695,1.303961,0.663198,1.000224,1.212504,1.287996,1.253108,1.119913,1.052780,1.193922,...,-0.047775,-0.004737,1.602201,5.703972,0.033443,0.092254,1.148337,1.113886,2.324075,2.131416
624997,test_88942,1.292895,0.398499,0.896352,1.141986,1.091592,0.921941,0.721742,0.692916,0.819667,...,-0.105182,-0.108229,1.143164,5.691248,0.017042,0.028041,0.884144,0.885550,1.511482,1.271262
624998,test_79382,1.383824,0.440285,0.904473,1.071636,1.179318,1.092865,0.932861,0.875537,0.911289,...,0.186191,-0.127692,1.590675,5.387972,0.074061,0.013962,1.564957,1.607105,1.178976,0.707465


In [9]:
# Persist the final ensemble as ensemble_submission.csv, converting the pandas
# frame to polars first to use polars' CSV writer (presumably for speed).
(
    pl.from_pandas(ensemble_df)
    .write_csv('ensemble_submission.csv')
)

In [10]:
# Show the finished submission frame (the same object written to ensemble_submission.csv above).
ensemble_df

Unnamed: 0,sample_id,ptend_t_0,ptend_t_1,ptend_t_2,ptend_t_3,ptend_t_4,ptend_t_5,ptend_t_6,ptend_t_7,ptend_t_8,...,ptend_v_58,ptend_v_59,cam_out_NETSW,cam_out_FLWDS,cam_out_PRECSC,cam_out_PRECC,cam_out_SOLS,cam_out_SOLL,cam_out_SOLSD,cam_out_SOLLD
0,test_169651,-0.240605,-0.514768,-0.414768,-0.617056,-0.795912,-0.909707,-0.941093,-0.946086,-0.950893,...,-0.094108,0.068964,0.015512,5.194903,0.019998,0.004688,-0.000311,0.012071,0.009443,-0.018621
1,test_524862,-0.189888,-0.430200,-0.439067,-0.668852,-0.877671,-0.972547,-0.969087,-0.965670,-0.978381,...,-0.129141,0.423531,0.034073,5.158769,0.050641,0.150926,0.004044,0.016107,0.010673,0.013420
2,test_634129,-0.248405,-0.941084,-0.587242,-0.615612,-0.785433,-0.915164,-0.956804,-0.955637,-0.939955,...,-0.154351,-0.031226,-0.011583,5.709803,0.019263,0.005436,-0.003573,-0.006266,-0.011131,-0.024672
3,test_403572,-0.290976,-0.769755,-0.602398,-0.621971,-0.826239,-0.970950,-0.997067,-0.974548,-0.948202,...,-0.243370,-0.274822,-0.004546,5.751472,0.058090,0.245144,0.020742,0.014384,0.004509,-0.027713
4,test_484578,-0.199230,-0.250085,-0.403759,-0.723925,-0.926101,-1.002715,-0.986452,-0.978449,-0.977556,...,0.559024,0.103322,0.027400,5.084572,0.044448,0.152916,-0.005384,0.002509,0.020860,0.041488
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
624995,test_578220,1.243302,0.704741,1.015835,1.142871,1.297883,1.315854,1.245037,1.191253,1.199678,...,0.089324,0.131364,2.017672,5.412830,0.046568,0.034535,1.897592,1.902543,1.625645,1.121087
624996,test_395695,1.303961,0.663198,1.000224,1.212504,1.287996,1.253108,1.119913,1.052780,1.193922,...,-0.047775,-0.004737,1.602201,5.703972,0.033443,0.092254,1.148337,1.113886,2.324075,2.131416
624997,test_88942,1.292895,0.398499,0.896352,1.141986,1.091592,0.921941,0.721742,0.692916,0.819667,...,-0.105182,-0.108229,1.143164,5.691248,0.017042,0.028041,0.884144,0.885550,1.511482,1.271262
624998,test_79382,1.383824,0.440285,0.904473,1.071636,1.179318,1.092865,0.932861,0.875537,0.911289,...,0.186191,-0.127692,1.590675,5.387972,0.074061,0.013962,1.564957,1.607105,1.178976,0.707465
