In [1]:
import pandas as pd
from collections import defaultdict

In [2]:
def preprocess_and_group_columns(df, columns_to_remove=None, prefixes_to_remove=None):
    """
    Drops the requested columns, then groups the remaining numeric columns by value set.

    A column's value set is the sorted tuple of its distinct, non-missing values
    that are either non-negative or one of the two codes -6 / -8; every other
    negative value is ignored. Columns that share a value set end up in the
    same group. Non-numeric (and boolean) columns are skipped.

    Parameters:
        df (pd.DataFrame): The full DataFrame (it is not modified)
        columns_to_remove (str or list): Exact column names to remove. Names that are
            not present in df are ignored. A single string is treated as one name.
        prefixes_to_remove (str, tuple or list): Column prefixes to remove
            (e.g., ('DSM_', 'INT_')). A single string is treated as one prefix.

    Returns:
        grouped_columns (defaultdict): Keys = sorted tuple of the retained unique
        values, values = list of column names having exactly that value set
    """
    if columns_to_remove is None:
        columns_to_remove = []
    elif isinstance(columns_to_remove, str):
        # set('AGE') would split the name into single characters.
        columns_to_remove = [columns_to_remove]

    if prefixes_to_remove is None:
        prefixes_to_remove = ()
    elif isinstance(prefixes_to_remove, str):
        # tuple('DSM_') would become ('D', 'S', 'M', '_') and match far too much.
        prefixes_to_remove = (prefixes_to_remove,)
    prefixes = tuple(prefixes_to_remove)

    # Step 1: Build full list of columns to remove
    to_drop = set(columns_to_remove)
    to_drop.update(
        col for col in df.columns
        # Only string labels can be prefix-matched; other labels are left alone.
        if isinstance(col, str) and col.startswith(prefixes)
    )

    # Step 2: Drop columns (only those that actually exist in df)
    df_cleaned = df.drop(columns=[col for col in to_drop if col in df.columns])
    # Report what was really removed, not how many names were requested.
    dropped_count = df.shape[1] - df_cleaned.shape[1]
    print(f"[Step 1] Dropped {dropped_count} columns. Remaining: {df_cleaned.shape[1]}")

    # Step 3: Group by value sets allowing only non-negative values + (-6, -8)
    grouped_columns = defaultdict(list)

    for col in df_cleaned.columns:
        series = df_cleaned[col]
        # Accept every integer/float width (int32, float32, nullable Int64, ...),
        # not just int64/float64; bool and object columns are skipped.
        is_number = (
            pd.api.types.is_integer_dtype(series.dtype)
            or pd.api.types.is_float_dtype(series.dtype)
        )
        if not is_number:
            continue

        keep_mask = (series >= 0) | series.isin([-6, -8])
        value_set = tuple(sorted(series[keep_mask].dropna().unique()))
        grouped_columns[value_set].append(col)

    # Count only the columns that were actually placed in a group.
    grouped_count = sum(len(cols) for cols in grouped_columns.values())
    print(f"[Step 2] Grouped {grouped_count} columns into {len(grouped_columns)} value patterns.")
    return grouped_columns

In [3]:
# Read the raw CSV from the local data directory.
df = pd.read_csv('./data/mental-health-comorbidity-raw.csv')

# First list of columns removed by exact name.
admin_cols = [
    'RESPID', 'NCS1YR', 'AGE', 'STR', 'CASEID', 'COMPLETE', 'SECU', 'CASEWGT',
]

# Second list of columns removed by exact name, laid out one questionnaire
# prefix per line (order is unchanged).
checkpoint_cols = [
    'M5A', 'IR3', 'IR11_4', 'IR36', 'IR47',
    'PD0A', 'PD2', 'PD5', 'PD14', 'PD20', 'PD23',
    'AG2', 'AG7', 'AG10',
    'FD4_1', 'FD6', 'FD7_1', 'FD9_1',
    'PR1', 'PR11_1', 'PR15',
    'FN1', 'FN4', 'FN24',
    'CN1_2', 'CN4', 'CN4_1', 'CN7_1', 'CN7_2', 'CN8', 'CN14',
    'DA36_2B', 'DA36_3A_1', 'DE20_3', 'DE20_6',
    'CH23', 'CH38_1', 'CH74_1', 'CH104',
    'AD0', 'AD2', 'AD7', 'AD29', 'AD31', 'AD36', 'AD43_2',
    'OD2', 'OD27',
    'CD3', 'CD17_1', 'CD24',
    'SA1E_1', 'SA2', 'SA3', 'SA7A1', 'SA10', 'SA11E_1', 'SA12', 'SA18_5',
]

# Every exact-name removal in a single list.
rem_cols = [*admin_cols, *checkpoint_cols]

# Any column whose name starts with one of these prefixes is removed as well.
prefixes = (
    'IR48VALUES',
    'PD27VALUES',
    'PD28VALUES',
)

grouped = preprocess_and_group_columns(
    df,
    columns_to_remove=rem_cols,
    prefixes_to_remove=prefixes,
)

# View the results
#for value_set, cols in grouped.items():
#    print(f"\nColumns with values {value_set}:")
#    for col in cols:
#        print(f"  - {col}")

[Step 1] Dropped 109 columns. Remaining: 891
[Step 2] Grouped 891 columns into 251 value patterns.


In practice, this issue will show up most in regression models and decision trees. One-hot encoding can help with it, and a 30% threshold will also help by filtering out weak categories.

In [4]:
# `grouped` maps each value set to the list of columns that share it.
# Rank the (value_set, columns) pairs so the largest group comes first.
sorted_summary = list(grouped.items())
sorted_summary.sort(key=lambda pattern: len(pattern[1]), reverse=True)

# One summary line per value pattern.
print("🔢 Most common non-negative value patterns by column count:\n")
for value_set, cols in sorted_summary:
    print(f"{value_set} → {len(cols)} columns")


🔢 Most common non-negative value patterns by column count:

(np.int64(-8), np.int64(-6), np.int64(1), np.int64(5)) → 174 columns
(np.int64(-6), np.int64(1), np.int64(5)) → 94 columns
(np.int64(1), np.int64(5)) → 86 columns
(np.int64(-6), np.int64(1), np.int64(2), np.int64(3), np.int64(4)) → 53 columns
(np.int64(0), np.int64(1)) → 41 columns
(np.int64(-8), np.int64(-6), np.int64(1), np.int64(2), np.int64(3), np.int64(4)) → 40 columns
(np.int64(1), np.int64(2), np.int64(3), np.int64(4)) → 26 columns
(np.int64(-8), np.int64(-6), np.int64(1), np.int64(5), np.int64(7)) → 16 columns
(np.int64(-6), np.int64(1), np.int64(2), np.int64(3), np.int64(4), np.int64(5)) → 15 columns
(np.int64(-8), np.int64(-6), np.int64(1), np.int64(2)) → 13 columns
(np.int64(1), np.int64(2), np.int64(3), np.int64(4), np.int64(5), np.int64(6), np.int64(7), np.int64(8), np.int64(9)) → 12 columns
(np.int64(-8), np.int64(1), np.int64(5)) → 10 columns
(np.int64(5),) → 7 columns
(np.int64(1), np.int64(2)) → 7 columns
(np.

In [5]:
def remove_sparse_value_sets(df, grouped, min_columns=6):
    """
    Keeps only the columns whose value set is shared by at least `min_columns` columns.

    Columns of df that do not appear in any sufficiently large group of
    `grouped` are left out of the result.

    Parameters:
        df (pd.DataFrame): DataFrame to take the columns from
        grouped (dict): Maps each value set to the list of column names that have it
        min_columns (int): Smallest group size (inclusive) for a value set to be kept

    Returns:
        pd.DataFrame: The selected columns of df, in the iteration order of `grouped`
    """
    columns_to_keep = []
    for cols in grouped.values():
        # Groups below the size threshold contribute nothing.
        if len(cols) >= min_columns:
            columns_to_keep.extend(cols)

    # Column selection by list of names.
    return df[columns_to_keep]

In [6]:
# Keep just the value sets that are shared by 6 or more columns.
filtered_grouped = dict(
    filter(lambda pattern: len(pattern[1]) >= 6, grouped.items())
)

# Rank the surviving sets, largest group first.
sorted_remaining = list(filtered_grouped.items())
sorted_remaining.sort(key=lambda pattern: len(pattern[1]), reverse=True)

print("✅ Remaining value sets (with 6 or more columns):\n")
for value_set, cols in sorted_remaining:
    print(f"{value_set} → {len(cols)} columns")

✅ Remaining value sets (with 6 or more columns):

(np.int64(-8), np.int64(-6), np.int64(1), np.int64(5)) → 174 columns
(np.int64(-6), np.int64(1), np.int64(5)) → 94 columns
(np.int64(1), np.int64(5)) → 86 columns
(np.int64(-6), np.int64(1), np.int64(2), np.int64(3), np.int64(4)) → 53 columns
(np.int64(0), np.int64(1)) → 41 columns
(np.int64(-8), np.int64(-6), np.int64(1), np.int64(2), np.int64(3), np.int64(4)) → 40 columns
(np.int64(1), np.int64(2), np.int64(3), np.int64(4)) → 26 columns
(np.int64(-8), np.int64(-6), np.int64(1), np.int64(5), np.int64(7)) → 16 columns
(np.int64(-6), np.int64(1), np.int64(2), np.int64(3), np.int64(4), np.int64(5)) → 15 columns
(np.int64(-8), np.int64(-6), np.int64(1), np.int64(2)) → 13 columns
(np.int64(1), np.int64(2), np.int64(3), np.int64(4), np.int64(5), np.int64(6), np.int64(7), np.int64(8), np.int64(9)) → 12 columns
(np.int64(-8), np.int64(1), np.int64(5)) → 10 columns
(np.int64(5),) → 7 columns
(np.int64(1), np.int64(2)) → 7 columns


In [7]:
# For every retained value set, print a header followed by one bullet per column.
for value_set in filtered_grouped:
    cols = filtered_grouped[value_set]
    print(f"\n🔹 Columns with values {value_set} ({len(cols)} columns):")
    for col in cols:
        print(f"  - {col}")


🔹 Columns with values (np.int64(-6), np.int64(1), np.int64(5)) (94 columns):
  - M1
  - M18
  - M18B2
  - M18B3
  - M47
  - IR2
  - IR4
  - IR20
  - IR20B4
  - IR21
  - PD9
  - PD9B3
  - PD13A
  - PD13C
  - PD13D
  - PD17
  - PD21B2
  - PD21B3
  - PD25B
  - AG3INTR1
  - AG3INTR2
  - AAG3B2
  - AAG3B3
  - AAG3B4
  - AG4A
  - AG4B
  - AG4F
  - AG4G
  - AG4H
  - AG5
  - AG6
  - AAG6A2
  - AAG6A3
  - AG8
  - AG8A
  - AG9A
  - AG9B
  - AG9C
  - AG14
  - AG17
  - AG37
  - FD4A
  - PR16A
  - PR20
  - PR20B2
  - PR21
  - CN12B
  - LE3
  - LE4
  - LE9
  - LE11
  - LE12
  - CH61
  - CH61A
  - CH90
  - CH90A
  - AD3
  - AD3B2
  - AD4
  - AD6B
  - AD6C
  - AD6D
  - AD32
  - AD32B2
  - AD35A
  - AD35B
  - AD35C
  - AD35D
  - OD3B2
  - OD3B3
  - CD7B1
  - CD7B2
  - CD16
  - CD18C2
  - CD18C3
  - CD20
  - CD38
  - SA1F
  - SA1G
  - SA1H
  - SA1I
  - SA1J
  - SA1K
  - SA8
  - SA8B2
  - SA8B3
  - SA11F
  - SA11G
  - SA11H
  - SA11I
  - SA19
  - SA19B2
  - SA19B3
  - SA20

🔹 Columns with values (np.int

Columns prefixed with `D_` or `DSM_` are diagnoses; remove them if desired.