In [1]:
%load_ext autoreload
%autoreload 2

In [None]:
from sklearn.model_selection import train_test_split
import pandas as pd

from mra_midas_skin_cancer_ml.utils.process_metadata import (
    create_lesion_key,
    import_metadata,
    process_target,
    sort_metadata,
)

In [3]:
def process_metadata_for_img(distances=("1ft", "6in", "dscope")):
    """
    Process metadata to develop separate models based on image distance.

    Loads the metadata via ``import_metadata``, derives the binary target with
    ``process_target``, drops rows whose ``midas_path_binary`` is ``"missing"``
    and then builds one de-duplicated frame per image distance.

    Parameters
    ----------
    distances : iterable of str, optional
        Values of the ``midas_distance`` column to build subsets for.
        Defaults to ``("1ft", "6in", "dscope")``.

    Returns
    -------
    dict[str, pandas.DataFrame]
        ``"all"`` maps to the filtered metadata with every column, unsorted and
        not de-duplicated. NOTE(review): this frame can contain several rows for
        the same lesion, so splitting it row by row may put one lesion in more
        than one split; confirm before training on it.

        Each distance key maps to a frame restricted to ``lesion_key``,
        ``midas_record_id``, ``midas_file_name`` and ``midas_path_binary`` that
        has exactly one row per ``lesion_key``.
    """

    meta_df = import_metadata()
    meta_df = process_target(meta_df)
    # Records without a pathology label cannot be used as training targets.
    meta_df = meta_df[meta_df["midas_path_binary"] != "missing"]

    dist_dict = {"all": meta_df}

    cols = [
        "lesion_key",
        "midas_record_id",
        "midas_file_name",
        "midas_path_binary",
    ]

    for dist in distances:
        subset_df = meta_df[meta_df["midas_distance"] == dist]

        # Sort by patient_id (asc), lesion (asc) and control (desc) so that the
        # non-control record of each lesion ends up last (relied on below).
        subset_df = sort_metadata(subset_df)

        # Create unique patient lesion key
        subset_df = create_lesion_key(subset_df)

        subset_df = subset_df[cols]

        # Drop duplicates and keep last record (non-control record)
        subset_df = subset_df.drop_duplicates(subset="lesion_key", keep="last")

        print(f"{dist} Unique: {subset_df['lesion_key'].is_unique}")

        dist_dict[dist] = subset_df

    return dist_dict


# Build one metadata frame per image distance, plus the unprocessed "all" frame.
dist_dict = process_metadata_for_img()

1ft Unique: True
6in Unique: True
dscope Unique: True


In [None]:
def check_split_ratios(dist_dict):
    """Report how rows and the target are distributed across splits for each image set.

    For every entry in ``dist_dict`` this prints:

    * the raw row count per split,
    * the share of all rows that falls in each split, and
    * when a ``midas_path_binary`` column is present, the proportion of each
      target class within each split (each row of that table sums to 1).
      Stratified splitting should give nearly identical rows.

    Parameters
    ----------
    dist_dict : dict[str, pandas.DataFrame]
        Mapping of image-set name to a frame with a ``"split"`` column, such as
        the output of ``train_test_split_by_lesion``.

    Returns
    -------
    None
        Results are printed only.
    """
    for dist, df in dist_dict.items():
        print(f"\n{dist}")
        print("Raw counts:")
        print(df["split"].value_counts())
        print("\nProportions:")
        print(df["split"].value_counts(normalize=True))
        # Class mix inside each split. This is what verifies that the target
        # distribution is preserved across train/val/test. Skipped for frames
        # without a target column so such callers keep working.
        if "midas_path_binary" in df.columns:
            print("\nTarget proportions by split:")
            print(pd.crosstab(df["split"], df["midas_path_binary"], normalize="index"))

def train_test_split_by_lesion(dist_dict, test_size=0.2, random_state=42, holdout_test_size=0.5):
    """Split each image set into stratified train/val/test partitions.

    The split runs in two stages, both stratified on ``midas_path_binary``:

    1. ``test_size`` of the rows are held out from training. This is the
       combined val+test fraction, not the test fraction alone.
    2. ``holdout_test_size`` of the held-out rows become the test set and the
       remainder become the validation set.

    With the defaults this gives about 80% train, 10% val and 10% test.

    Rows are split individually. The split is therefore lesion-level only when
    each ``lesion_key`` appears once in a frame, as in the per-distance frames
    built by ``process_metadata_for_img``. NOTE(review): a frame with several
    rows per lesion (e.g. an ``"all"`` entry) can place the same lesion in more
    than one split.

    Parameters
    ----------
    dist_dict : dict[str, pandas.DataFrame]
        Mapping of image-set name to a frame containing ``midas_path_binary``.
    test_size : float, default 0.2
        Fraction of rows held out from training (val + test together).
    random_state : int, default 42
        Seed passed to both ``train_test_split`` calls for reproducibility.
    holdout_test_size : float, default 0.5
        Fraction of the held-out rows assigned to the test set.

    Returns
    -------
    dict[str, pandas.DataFrame]
        Same keys as ``dist_dict``. Each frame has the original non-target
        columns followed by ``midas_path_binary`` and a ``split`` column
        (``"train"``, ``"val"`` or ``"test"``). Rows are concatenated in that
        split order with a fresh RangeIndex.

    Raises
    ------
    ValueError
        Propagated from ``train_test_split``, e.g. when a class has too few
        members to stratify.
    """

    result_dict = {}

    for dist, subset_df in dist_dict.items():
        X = subset_df.drop(columns=["midas_path_binary"])
        y = subset_df["midas_path_binary"]

        # Stage 1: separate the training rows from the held-out rows.
        X_train, X_temp, y_train, y_temp = train_test_split(
            X, y,
            test_size=test_size,
            random_state=random_state,
            stratify=y
        )

        # Stage 2: divide the held-out rows into validation and test.
        X_val, X_test, y_val, y_test = train_test_split(
            X_temp, y_temp,
            test_size=holdout_test_size,
            random_state=random_state,
            stratify=y_temp
        )

        splits = {
            "train": (X_train, y_train),
            "val":   (X_val, y_val),
            "test":  (X_test, y_test)
        }

        # assign() aligns y on the index, re-attaching each label to its row.
        split_dfs = [
            X_split.assign(midas_path_binary=y_split, split=split_name)
            for split_name, (X_split, y_split) in splits.items()
        ]

        result_dict[dist] = pd.concat(split_dfs, ignore_index=True)

    return result_dict
