In [1]:
import numpy as np
import scipy.sparse as sp
import pandas as pd
import random
from prediction_utils.pytorch_utils.datasets import ArrayLoaderGenerator

In [64]:
def get_features_sparse(num_samples=100, num_features=1000, seed=10):
    """Generate a reproducible random sparse feature matrix.

    Args:
        num_samples: Number of rows in the matrix.
        num_features: Number of columns in the matrix.
        seed: Value forwarded to ``scipy.sparse.random`` as ``random_state``.

    Returns:
        A CSR-format sparse matrix of shape ``(num_samples, num_features)``,
        generated with SciPy's default density.
    """
    feature_matrix = sp.random(
        num_samples,
        num_features,
        format="csr",
        random_state=seed,
    )
    return feature_matrix

def get_cohort(
    num_samples=100,
    row_id_col="row_id",
    fold_id_test="test",
    label_col="outcome",
    attributes=["gender"],
    attribute="gender",
    seed=None,
):
    """Build a synthetic cohort table with random folds, labels and one attribute.

    Args:
        num_samples: Number of rows to generate.
        row_id_col: Name of the column holding the row ids ``0..num_samples-1``.
        fold_id_test: Fold label added to the candidates ``"1"``, ``"2"``, ``"3"``
            when sampling the ``"fold_id"`` column.
        label_col: Name of the binary (0/1) label column.
        attributes: Accepted for signature compatibility; not used by this function.
        attribute: Name of the column filled with ``"male"`` / ``"female"``.
        seed: Optional seed. When ``None`` (the default) values are drawn from
            the global ``random`` and ``np.random`` state, exactly as before.
            When given, dedicated generators seeded with this value are used so
            the cohort is reproducible and the global state is left untouched.

    Returns:
        A ``pandas.DataFrame`` with columns ``row_id_col``, ``"fold_id"``,
        ``label_col`` and ``attribute`` and ``num_samples`` rows.
    """
    if seed is None:
        # Fall back to the module-level (global-state) generators.
        py_rng = random
        np_rng = np.random
    else:
        py_rng = random.Random(seed)
        np_rng = np.random.RandomState(seed)

    fold_choices = ["1", "2", "3", fold_id_test]
    attribute_choices = ["male", "female"]

    # The dict entries are evaluated in order, so the sequence of random draws
    # (folds, then labels, then attribute values) is fixed.
    return pd.DataFrame(
        {
            row_id_col: np.arange(num_samples),
            "fold_id": [py_rng.choice(fold_choices) for _ in range(num_samples)],
            label_col: np_rng.randint(0, 1 + 1, size=num_samples),
            attribute: [
                py_rng.choice(attribute_choices) for _ in range(num_samples)
            ],
        }
    )

def test_get_data_dict():
    """Build an ``ArrayLoaderGenerator`` on synthetic data and expose its data dict.

    A 100 x 1000 sparse feature matrix and a 100-row cohort are generated, the
    loader generator is constructed for fold ``"1"`` with no attributes, and
    its ``data_dict`` attribute is returned together with the cohort so the
    two can be compared by the caller.

    Returns:
        tuple: ``(data_dict, cohort)``.
    """
    n_rows, n_cols = 100, 1000

    sparse_features = get_features_sparse(num_samples=n_rows, num_features=n_cols)
    synthetic_cohort = get_cohort(num_samples=n_rows)

    generator_kwargs = dict(
        features=sparse_features,
        cohort=synthetic_cohort,
        row_id_col="row_id",
        fold_id_test="test",
        label_col="outcome",
        attributes=None,
        attribute=None,
        fold_id="1",
        num_workers=0,
    )
    generator = ArrayLoaderGenerator(**generator_kwargs)

    return generator.data_dict, synthetic_cohort

#     data_df = pd.concat(
#         {
#             key: pd.concat(
#                 {key2: pd.Series(value2) for key2, value2 in value.items()}
#             )
#             .rename(key)
#             .reset_index(level=1, drop=True)
#             .rename_axis("fold_id")
#             .reset_index()
#             for key, value in data_dict.items()
#             if key != "features"
#         },
#     )
#     data_df = data_df.sort_values(['row_id'])
#     assert data_df.shape[0] == cohort.shape[0]
#     assert data_df[['row_id', 'outcome']].equals(cohort[['row_id', 'outcome']])

In [65]:
# Build the loader generator on synthetic data; keep both its data dict and
# the cohort it was built from so they can be cross-checked in later cells.
data_dict, cohort = test_get_data_dict()

In [75]:
# Flatten the nested ``data_dict[field][fold] -> values`` mapping into one long
# DataFrame per field, keeping only the row ids and the labels.
data_df_dict = {}
for field_name, per_fold_values in data_dict.items():
    if field_name not in ('row_id', 'labels'):
        continue
    # Concatenating a dict of Series yields a (fold, position) MultiIndex.
    stacked = pd.concat(
        {fold: pd.Series(values) for fold, values in per_fold_values.items()}
    )
    data_df_dict[field_name] = (
        stacked.to_frame()
        .rename(columns={0: field_name})
        .rename_axis(['fold_id', 'dict_row_id'])
        .reset_index()
    )

# Join the per-field frames; merge() with no keys uses the columns they share
# ('fold_id' and 'dict_row_id').
for position, frame in enumerate(data_df_dict.values()):
    print(frame)
    result = frame if position == 0 else result.merge(frame)

   fold_id  dict_row_id  row_id
0    train            0       0
1    train            1       1
2    train            2       2
3    train            3       3
4    train            4       4
..     ...          ...     ...
95    test           27      82
96    test           28      89
97    test           29      94
98    test           30      96
99    test           31      98

[100 rows x 3 columns]
   fold_id  dict_row_id  labels
0    train            0       0
1    train            1       1
2    train            2       0
3    train            3       0
4    train            4       0
..     ...          ...     ...
95    test           27       1
96    test           28       1
97    test           29       1
98    test           30       1
99    test           31       0

[100 rows x 3 columns]


In [76]:
# Display the synthetic cohort returned by test_get_data_dict() for inspection.
cohort

Unnamed: 0,row_id,fold_id,outcome,gender
0,0,3,0,male
1,1,3,1,male
2,2,2,0,male
3,3,2,0,male
4,4,3,0,male
...,...,...,...,...
95,95,3,1,male
96,96,test,1,female
97,97,2,1,male
98,98,test,0,female


In [77]:
# Cross-check the loader output against the cohort: rename 'labels' to
# 'outcome' so that merge() (inner join on all shared columns, here 'row_id'
# and 'outcome') keeps only rows whose label matches the cohort's outcome.
result.sort_values('row_id')[['row_id', 'labels']].rename(
    columns={'labels': 'outcome'}
).merge(cohort)

Unnamed: 0,row_id,outcome,fold_id,gender
0,0,0,3,male
1,1,1,3,male
2,2,0,2,male
3,3,0,2,male
4,4,0,3,male
...,...,...,...,...
95,95,1,3,male
96,96,1,test,female
97,97,1,2,male
98,98,0,test,female
