# Split train, val, test sets for multimodal pipeline
### Split: train, val, test
### Disease: CAD, Infarction
### Store in UKBB/cardiac_segmentations/projects/SelfSuperBio/18545/final

- cardiac_features_{split}_imputed_noOH_tabular_imaging_reordered.csv
- cardiac_{split}_paths_imaging.pt
- cardiac_labels_{disease}_{split}.pt
- tabular_lengths_reordered.pt

- cardiac_features_{split}_imputed_noOH_tabular_imaging_{disease}_balanced_reordered.csv
- cardiac_{split}_paths_imaging_{disease}_balanced.pt
- cardiac_labels_{disease}_{split}_balanced.pt


In [1]:
import os
from os.path import join
import pandas as pd
from Utils import check_or_save
import torch
import random
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm

from sklearn.experimental import enable_iterative_imputer
from sklearn.impute import IterativeImputer
from sklearn.impute import KNNImputer
from sklearn.model_selection import train_test_split
# Let pandas display up to 700 columns so wide feature tables are not truncated.
pd.options.display.max_columns = 700

In [2]:
# --- Pipeline switches -------------------------------------------------------
IMPUTE, BALANCED = True, True
one_hot_encoded = False
# 'simple' selects constant/mode filling; any other value selects the
# iterative + KNN imputation path in the imputation cell.
impute_strategy = 'iterative'

# --- Locations ---------------------------------------------------------------
# BASE = '/vol/biomedic3/sd1523/data/mm/UKBB'
BASE = '/bigdata/siyi/data/UKBB'
RAW_DATA = '/vol/biodata/data/biobank/18545/data'

_PROJECT_PARTS = ('cardiac_segmentations', 'projects', 'SelfSuperBio', '18545')
RAW_FEATURES = join(BASE, 'features')
FEATURES = join(BASE, *_PROJECT_PARTS)
# Final per-split files are written below FEATURES/final.
STORE_PATH = join(FEATURES, 'final')

#### Delete ids that don't have img folder

In [None]:
# Load the raw labelled feature table (non-one-hot version).
raw_features_csv = join(RAW_FEATURES, 'cardiac_feature_18545_vector_labeled_noOH.csv')
data_df = pd.read_csv(raw_features_csv)
# Every column from 'Sex-0.0' to the end of the table is inspected as categorical here.
categorical_columns = data_df.loc[:,'Sex-0.0':].columns
for column_name in categorical_columns:
    level_counts = data_df[column_name].value_counts().sort_index()
    print(level_counts)

In [76]:
# Build (or load from cache) the list of eids that have no folder under RAW_DATA.
_ids = list(data_df['eid'].astype(int))
invalid_ids_path = join(FEATURES, 'invalid_ids.pt')
if os.path.exists(invalid_ids_path):
    # Reuse the cached list instead of checking every subject folder again.
    invalid_ids = torch.load(invalid_ids_path)
else:
    # The loop variable is `eid`, not `id`, so the builtin id() is not
    # overwritten in the notebook's global namespace.
    invalid_ids = [
        eid for eid in tqdm(_ids)
        if not os.path.isdir(join(RAW_DATA, str(eid)))
    ]
    print(f'Found {len(invalid_ids)} bad indices in the whole dataset')
    # Save to the same path that the cache check above looks at.
    check_or_save(invalid_ids, invalid_ids_path)

In [77]:
# remove invalid ids
to_del = torch.load(join(FEATURES,'invalid_ids.pt'))
print(f'Dataset length before {len(_ids)}')
# Filter through a set: list.remove() rescans the list on every call, which is
# quadratic when most of the ids are dropped (see the recorded counts below).
# Order of the surviving ids is preserved.
to_del_lookup = {int(x) for x in to_del}
_ids = [eid for eid in _ids if int(eid) not in to_del_lookup]
# This is the whole dataset, not the validation split, hence "Dataset".
print(f'Dataset length after {len(_ids)}')

Dataset length before 442554
Val length after 36194


In [None]:
# Keep only the rows whose eid is not in the invalid-id list.
to_del = {int(x) for x in to_del}
is_valid_row = ~data_df['eid'].isin(to_del)
# reset_index(drop=True) renumbers the rows 0..n-1 and discards the old index.
new_data_df = data_df.loc[is_valid_row].reset_index(drop=True)
new_data_df

In [79]:
# Persist the filtered table ("dropNI" file name) through the project helper check_or_save.
# index=False / header=True are passed through to the helper — presumably CSV writer options; see Utils.
check_or_save(new_data_df, join(FEATURES, 'cardiac_feature_18545_vector_labeled_noOH_dropNI.csv'), index=False, header=True)

### Impute and split tabular data

In [6]:
# Reload the filtered table from disk so the steps below can start from this cell.
data_df = pd.read_csv(join(FEATURES, 'cardiac_feature_18545_vector_labeled_noOH_dropNI.csv'))
_ids = list(data_df['eid'].astype(int))
print(data_df.shape)
# Keep an untouched copy: data_df is later replaced by the selected/imputed features,
# while the label columns are read from all_data_df when the splits are exported.
all_data_df = data_df.copy()

(36194, 81)


In [None]:
# Percentage of non-missing entries per column, best-covered columns first.
data_df_coverage = data_df.count() / len(data_df) * 100
data_df_coverage.sort_values(ascending=False)

In [8]:
# If no train, val, test sets
# Drop the subjects listed in problem_ids_cardiac.pt, then create the id splits.
problem_ids = torch.load(join(FEATURES, 'problem_ids_cardiac.pt'))
print(f'Num of problem ids in cardiac image: {len(problem_ids)}')
print(f'Dataset length before {len(_ids)}')
for _id in problem_ids:
    # list.remove raises ValueError if a problem id is not (or no longer) in _ids,
    # so this cell cannot be run twice on the same _ids list.
    _ids.remove(_id)
# This is the whole dataset, not the validation split, hence "Dataset".
print(f'Dataset length after {len(_ids)}')

# 10% of the ids are held out for test; 20% of the remainder becomes validation.
# The fixed random_state makes both splits reproducible for the same _ids order.
train_set_ids, test_ids = train_test_split(_ids, test_size=0.1, random_state=2022)
train_ids, val_ids = train_test_split(train_set_ids, test_size=0.2, random_state=2022)

check_or_save(train_ids, join(FEATURES,'ids_train_tabular_imaging.pt'))
check_or_save(val_ids, join(FEATURES,'ids_val_tabular_imaging.pt'))
check_or_save(test_ids, join(FEATURES,'ids_test_tabular_imaging.pt'))

Num of problem ids in cardiac image: 27
Dataset length before 36194
Val length after 36167


In [9]:
# Reload the stored id splits so the cells below do not depend on re-running the split cell.
train_ids = torch.load(join(FEATURES, 'ids_train_tabular_imaging.pt'))
val_ids = torch.load(join(FEATURES, 'ids_val_tabular_imaging.pt'))
test_ids = torch.load(join(FEATURES, 'ids_test_tabular_imaging.pt'))
# Sanity check: the three split sizes and their total.
print(f'train: {len(train_ids)}, val: {len(val_ids)}, test: {len(test_ids)}, total: {len(train_ids)+len(val_ids)+len(test_ids)}')

train: 26040, val: 6510, test: 3617, total: 36167


In [10]:
# Subject ids as integers; kept separately so they can be concatenated back
# in front of the selected feature columns in the imputation cell.
eid_df = data_df.loc[:,'eid'].astype('int')
# eid_df = data_df.loc[:,'eid_old'].astype('int')

# Select the continuous feature columns, in the order listed here.
# Commented-out names are columns deliberately left out of the selection.
# The last ten entries ('LVESV (mL)' ... 'RVEF (%)') end up as the final ten columns
# after reordering; they are the ones dropped for the "noExt" files further down.
continuous_df = data_df.loc[:,(
    'Pulse wave Arterial Stiffness index-2.0',
    'Systolic blood pressure-2.mean',
    'Diastolic blood pressure-2.mean',
    'Pulse rate-2.mean',
    'Body fat percentage-2.0',
    'Whole body fat mass-2.0',
    # 'Whole body fat-free mass-2.0',
    # 'Whole body water mass-2.0',
    'Body mass index (BMI)-2.0',
    # 'Cooked vegetable intake-2.0',
    # 'Salad / raw vegetable intake-2.0',
    # 'Cardiac operations performed',
    # 'Total mass-2.0',
    'Basal metabolic rate-2.0',
    # 'Impedance of whole body-2.0',
    'Waist circumference-2.0',
    'Hip circumference-2.0',
    # 'Standing height-2.0',
    # 'Height-2.0',
    # 'Sitting height-2.0',
    'Weight-2.0',
    'Ventricular rate-2.0',
    'P duration-2.0',
    'QRS duration-2.0',
    # 'PQ interval-2.0',
    # 'RR interval-2.0',
    # 'PP interval-2.0',
    'Cardiac output-2.0',
    'Cardiac index-2.0',
    'Average heart rate-2.0',
    'Body surface area-2.0',
    'Duration of walks-2.0',
    'Duration of moderate activity-2.0',
    'Duration of vigorous activity-2.0',
    # 'Time spent watching television (TV)-2.0',
    # 'Time spent using computer-2.0',
    # 'Time spent driving-2.0',
    'Heart rate during PWA-2.0',
    'Systolic brachial blood pressure during PWA-2.0',
    'Diastolic brachial blood pressure during PWA-2.0',
    'Peripheral pulse pressure during PWA-2.0',
    'Central systolic blood pressure during PWA-2.0',
    'Central pulse pressure during PWA-2.0',
    'Number of beats in waveform average for PWA-2.0',
    'Central augmentation pressure during PWA-2.0',
    'Augmentation index for PWA-2.0',
    'Cardiac output during PWA-2.0',
    'End systolic pressure during PWA-2.0',
    'End systolic pressure index during PWA-2.0',
    'Total peripheral resistance during PWA-2.0',
    'Stroke volume during PWA-2.0',
    # 'Mean arterial pressure during PWA-2.0',
    'Cardiac index during PWA-2.0',
    'Sleep duration-2.0',
    'Exposure to tobacco smoke at home-2.0',
    'Exposure to tobacco smoke outside home-2.0',
    # 'Pack years of smoking-2.0',
    # 'Pack years adult smoking as proportion of life span exposed to smoking-2.0',
    'LVESV (mL)',
    'LVEDV (mL)',
    'LVSV (mL)',
    'LVEF (%)',
    'LVCO (L/min)',
    'LVM (g)',
    'RVEDV (mL)',
    'RVESV (mL)', 
    'RVSV (mL)', 
    'RVEF (%)',
)]
# Select the categorical feature columns.
# With one_hot_encoded=True the multi-level variables are addressed as one column per
# level (names ending in -0, -1, ...) — presumably indicator columns; this requires a
# source table that actually contains those columns.
# Otherwise each variable is taken as a single column (the "noOH" layout loaded above).
if one_hot_encoded:
  categorical_df = data_df.loc[:,(
    'Worrier / anxious feelings-2.0',
    'Shortness of breath walking on level ground-2.0',
    'Sex-0.0',
    'Diabetes diagnosis',
    #'Heart attack diagnosed by doctor',   # commented by LaaF
    'Angina diagnosed by doctor',
    'Stroke diagnosed by doctor',
    'High blood pressure diagnosed by doctor',
    'Cholesterol lowering medication regularly taken',
    'Blood pressure medication regularly taken',
    'Insulin medication regularly taken',
    'Hormone replacement therapy medication regularly taken',
    'Oral contraceptive pill or minipill medication regularly taken',
    'Pace-maker-2.0',
    'Ever had diabetes (Type I or Type II)-0.0',
    'Long-standing illness, disability or infirmity-2.0',
    'Tense / \'highly strung\'-2.0',
    'Ever smoked-2.0',

    'Alcohol intake frequency.-2.0-0',
    'Alcohol intake frequency.-2.0-1',
    'Alcohol intake frequency.-2.0-2',
    'Alcohol intake frequency.-2.0-3',
    'Alcohol intake frequency.-2.0-4',
    'Alcohol intake frequency.-2.0-5',
    'Processed meat intake-2.0-0',
    'Processed meat intake-2.0-1',
    'Processed meat intake-2.0-2',
    'Processed meat intake-2.0-3',
    'Processed meat intake-2.0-4',
    'Processed meat intake-2.0-5',
    'Beef intake-2.0-0',
    'Beef intake-2.0-1',
    'Beef intake-2.0-2',
    'Beef intake-2.0-3',
    'Beef intake-2.0-4',
    'Beef intake-2.0-5',
    'Pork intake-2.0-0',
    'Pork intake-2.0-1',
    'Pork intake-2.0-2',
    'Pork intake-2.0-3',
    'Pork intake-2.0-4',
    'Pork intake-2.0-5',
    'Lamb/mutton intake-2.0-0',
    'Lamb/mutton intake-2.0-1',
    'Lamb/mutton intake-2.0-2',
    'Lamb/mutton intake-2.0-3',
    'Lamb/mutton intake-2.0-4',
    'Lamb/mutton intake-2.0-5',
    'Overall health rating-2.0-0',
    'Overall health rating-2.0-1',
    'Overall health rating-2.0-2',
    'Overall health rating-2.0-3',
    'Alcohol usually taken with meals-2.0-0',
    'Alcohol usually taken with meals-2.0-1',
    'Alcohol usually taken with meals-2.0-2',
    'Alcohol drinker status-2.0-0',
    'Alcohol drinker status-2.0-1',
    'Alcohol drinker status-2.0-2',
    'Frequency of drinking alcohol-0.0-0',
    'Frequency of drinking alcohol-0.0-1',
    'Frequency of drinking alcohol-0.0-2',
    'Frequency of drinking alcohol-0.0-3',
    'Frequency of drinking alcohol-0.0-4',
    'Frequency of consuming six or more units of alcohol-0.0-0',
    'Frequency of consuming six or more units of alcohol-0.0-1',
    'Frequency of consuming six or more units of alcohol-0.0-2',
    'Frequency of consuming six or more units of alcohol-0.0-3',
    'Frequency of consuming six or more units of alcohol-0.0-4',
    'Amount of alcohol drunk on a typical drinking day-0.0-0',
    'Amount of alcohol drunk on a typical drinking day-0.0-1',
    'Amount of alcohol drunk on a typical drinking day-0.0-2',
    'Amount of alcohol drunk on a typical drinking day-0.0-3',
    'Amount of alcohol drunk on a typical drinking day-0.0-4',
    'Amount of alcohol drunk on a typical drinking day-0.0-5',
    'Falls in the last year-2.0-0',
    'Falls in the last year-2.0-1',
    'Falls in the last year-2.0-2',
    'Weight change compared with 1 year ago-2.0-0',
    'Weight change compared with 1 year ago-2.0-1',
    'Weight change compared with 1 year ago-2.0-2',
    'Number of days/week walked 10+ minutes-2.0-0',
    'Number of days/week walked 10+ minutes-2.0-1',
    'Number of days/week walked 10+ minutes-2.0-2',
    'Number of days/week walked 10+ minutes-2.0-3',
    'Number of days/week walked 10+ minutes-2.0-4',
    'Number of days/week walked 10+ minutes-2.0-5',
    'Number of days/week walked 10+ minutes-2.0-6',
    'Number of days/week walked 10+ minutes-2.0-7',
    'Number of days/week of moderate physical activity 10+ minutes-2.0-0',
    'Number of days/week of moderate physical activity 10+ minutes-2.0-1',
    'Number of days/week of moderate physical activity 10+ minutes-2.0-2',
    'Number of days/week of moderate physical activity 10+ minutes-2.0-3',
    'Number of days/week of moderate physical activity 10+ minutes-2.0-4',
    'Number of days/week of moderate physical activity 10+ minutes-2.0-5',
    'Number of days/week of moderate physical activity 10+ minutes-2.0-6',
    'Number of days/week of moderate physical activity 10+ minutes-2.0-7',
    'Number of days/week of vigorous physical activity 10+ minutes-2.0-0',
    'Number of days/week of vigorous physical activity 10+ minutes-2.0-1',
    'Number of days/week of vigorous physical activity 10+ minutes-2.0-2',
    'Number of days/week of vigorous physical activity 10+ minutes-2.0-3',
    'Number of days/week of vigorous physical activity 10+ minutes-2.0-4',
    'Number of days/week of vigorous physical activity 10+ minutes-2.0-5',
    'Number of days/week of vigorous physical activity 10+ minutes-2.0-6',
    'Number of days/week of vigorous physical activity 10+ minutes-2.0-7',
    'Usual walking pace-2.0-0',
    'Usual walking pace-2.0-1',
    'Usual walking pace-2.0-2',
    'Frequency of stair climbing in last 4 weeks-2.0-0',
    'Frequency of stair climbing in last 4 weeks-2.0-1',
    'Frequency of stair climbing in last 4 weeks-2.0-2',
    'Frequency of stair climbing in last 4 weeks-2.0-3',
    'Frequency of stair climbing in last 4 weeks-2.0-4',
    'Frequency of stair climbing in last 4 weeks-2.0-5',
    'Frequency of walking for pleasure in last 4 weeks-2.0-0',
    'Frequency of walking for pleasure in last 4 weeks-2.0-1',
    'Frequency of walking for pleasure in last 4 weeks-2.0-2',
    'Frequency of walking for pleasure in last 4 weeks-2.0-3',
    'Frequency of walking for pleasure in last 4 weeks-2.0-4',
    'Frequency of walking for pleasure in last 4 weeks-2.0-5',
    'Frequency of walking for pleasure in last 4 weeks-2.0-6',
    'Duration walking for pleasure-2.0-0',
    'Duration walking for pleasure-2.0-1',
    'Duration walking for pleasure-2.0-2',
    'Duration walking for pleasure-2.0-3',
    'Duration walking for pleasure-2.0-4',
    'Duration walking for pleasure-2.0-5',
    'Duration walking for pleasure-2.0-6',
    'Duration walking for pleasure-2.0-7',
    'Frequency of strenuous sports in last 4 weeks-2.0-0',
    'Frequency of strenuous sports in last 4 weeks-2.0-1',
    'Frequency of strenuous sports in last 4 weeks-2.0-2',
    'Frequency of strenuous sports in last 4 weeks-2.0-3',
    'Frequency of strenuous sports in last 4 weeks-2.0-4',
    'Frequency of strenuous sports in last 4 weeks-2.0-5',
    'Frequency of strenuous sports in last 4 weeks-2.0-6',
    'Duration of strenuous sports-2.0-0',
    'Duration of strenuous sports-2.0-1',
    'Duration of strenuous sports-2.0-2',
    'Duration of strenuous sports-2.0-3',
    'Duration of strenuous sports-2.0-4',
    'Duration of strenuous sports-2.0-5',
    'Duration of strenuous sports-2.0-6',
    'Duration of strenuous sports-2.0-7',
    'Duration of light DIY-2.0-0',
    'Duration of light DIY-2.0-1',
    'Duration of light DIY-2.0-2',
    'Duration of light DIY-2.0-3',
    'Duration of light DIY-2.0-4',
    'Duration of light DIY-2.0-5',
    'Duration of light DIY-2.0-6',
    'Duration of light DIY-2.0-7',
    'Frequency of heavy DIY in last 4 weeks-2.0-0',
    'Frequency of heavy DIY in last 4 weeks-2.0-1',
    'Frequency of heavy DIY in last 4 weeks-2.0-2',
    'Frequency of heavy DIY in last 4 weeks-2.0-3',
    'Frequency of heavy DIY in last 4 weeks-2.0-4',
    'Frequency of heavy DIY in last 4 weeks-2.0-5',
    'Frequency of heavy DIY in last 4 weeks-2.0-6',
    'Duration of heavy DIY-2.0-0',
    'Duration of heavy DIY-2.0-1',
    'Duration of heavy DIY-2.0-2',
    'Duration of heavy DIY-2.0-3',
    'Duration of heavy DIY-2.0-4',
    'Duration of heavy DIY-2.0-5',
    'Duration of heavy DIY-2.0-6',
    'Duration of heavy DIY-2.0-7',
    'Frequency of other exercises in last 4 weeks-2.0-0',
    'Frequency of other exercises in last 4 weeks-2.0-1',
    'Frequency of other exercises in last 4 weeks-2.0-2',
    'Frequency of other exercises in last 4 weeks-2.0-3',
    'Frequency of other exercises in last 4 weeks-2.0-4',
    'Frequency of other exercises in last 4 weeks-2.0-5',
    'Frequency of other exercises in last 4 weeks-2.0-6',
    'Duration of other exercises-2.0-0',
    'Duration of other exercises-2.0-1',
    'Duration of other exercises-2.0-2',
    'Duration of other exercises-2.0-3',
    'Duration of other exercises-2.0-4',
    'Duration of other exercises-2.0-5',
    'Duration of other exercises-2.0-6',
    'Duration of other exercises-2.0-7',
    'Sleeplessness / insomnia-2.0-0',
    'Sleeplessness / insomnia-2.0-1',
    'Sleeplessness / insomnia-2.0-2',
    'Current tobacco smoking-2.0-0',
    'Current tobacco smoking-2.0-1',
    'Current tobacco smoking-2.0-2',
    'Past tobacco smoking-2.0-0',
    'Past tobacco smoking-2.0-1',
    'Past tobacco smoking-2.0-2',
    'Past tobacco smoking-2.0-3',
    'Smoking/smokers in household-2.0-0',
    'Smoking/smokers in household-2.0-1',
    'Smoking/smokers in household-2.0-2',
    'Smoking status-2.0-0',
    'Smoking status-2.0-1',
    'Smoking status-2.0-2'
)]
else:
  # Non-one-hot layout: one column per variable. Commented-out names are excluded.
  categorical_df = data_df.loc[:,(
    # 'Worrier / anxious feelings-2.0',
    'Shortness of breath walking on level ground-2.0',
    'Sex-0.0',
    'Diabetes diagnosis',
    #'Heart attack diagnosed by doctor',  # commented by LaaF
    'Angina diagnosed by doctor',
    'Stroke diagnosed by doctor',
    'High blood pressure diagnosed by doctor',
    'Cholesterol lowering medication regularly taken',
    'Blood pressure medication regularly taken',
    'Insulin medication regularly taken',
    'Hormone replacement therapy medication regularly taken',
    'Oral contraceptive pill or minipill medication regularly taken',
    # 'Pace-maker-2.0',
    # 'Ever had diabetes (Type I or Type II)-0.0',
    'Long-standing illness, disability or infirmity-2.0',
    # 'Tense / \'highly strung\'-2.0',
    'Ever smoked-2.0',

    "Sleeplessness / insomnia-2.0",
    # "Frequency of heavy DIY in last 4 weeks-2.0",
    "Alcohol intake frequency.-2.0",
    # "Processed meat intake-2.0",
    # "Beef intake-2.0",
    # "Pork intake-2.0",
    # "Lamb/mutton intake-2.0",
    "Overall health rating-2.0",
    # "Alcohol usually taken with meals-2.0",
    "Alcohol drinker status-2.0",
    # "Frequency of drinking alcohol-0.0",
    # "Frequency of consuming six or more units of alcohol-0.0",
    # "Amount of alcohol drunk on a typical drinking day-0.0",
    "Falls in the last year-2.0",
    # "Weight change compared with 1 year ago-2.0",
    "Number of days/week walked 10+ minutes-2.0",
    "Number of days/week of moderate physical activity 10+ minutes-2.0",
    "Number of days/week of vigorous physical activity 10+ minutes-2.0",
    "Usual walking pace-2.0",
    # "Frequency of stair climbing in last 4 weeks-2.0",
    # "Frequency of walking for pleasure in last 4 weeks-2.0",
    # "Duration walking for pleasure-2.0",
    # "Frequency of strenuous sports in last 4 weeks-2.0",
    "Duration of strenuous sports-2.0",
    # "Duration of light DIY-2.0",
    # "Duration of heavy DIY-2.0",
    # "Frequency of other exercises in last 4 weeks-2.0",
    # "Duration of other exercises-2.0",
    "Current tobacco smoking-2.0",
    "Past tobacco smoking-2.0",
    # "Smoking/smokers in household-2.0",
    "Smoking status-2.0"
)]

In [11]:
# Standardise each continuous column: subtract its mean, divide by its standard deviation.
# NOTE(review): the statistics are taken over all rows of continuous_df, i.e. before the
# train/val/test ids are applied — confirm this is intended.
column_means = continuous_df.mean()
column_stds = continuous_df.std()
continuous_df = continuous_df.sub(column_means).div(column_stds)
print(f'Number of continuous features: {len(continuous_df.columns)}')
print(f'Number of samples: {len(continuous_df)}')
print(f'Number of categorical features: {len(categorical_df.columns)}')
print(f'data_df shape: {data_df.shape}')

Number of continuous features: 49
Number of samples: 36194
Number of categorical features: 26
data_df shape: (36194, 81)


In [12]:
# Fill missing values and rebuild data_df as [eid, continuous..., categorical...].
#   'simple'      : continuous -> 0, categorical -> most frequent value of the column.
#   anything else : continuous -> IterativeImputer, categorical -> KNNImputer output
#                   rounded to the nearest integer.
# The categorical columns are written exactly once per branch. Assigning the raw KNN
# output again after the branches would undo the rounding and would raise NameError
# under the 'simple' strategy, where no KNN result exists.
if IMPUTE:
  if impute_strategy.lower()=='simple':
    # Re-assign instead of filling in place, so no chained in-place write is involved.
    continuous_df = continuous_df.fillna(0)

    # Most frequent value per column, only for the columns that actually have gaps.
    columns_with_gaps = categorical_df.columns[categorical_df.isnull().any(axis=0)]
    categorical_df = categorical_df.fillna(
        {column: categorical_df[column].mode()[0] for column in columns_with_gaps}
    )

    data_df = pd.concat([eid_df, continuous_df, categorical_df], axis=1)
  else:
    data_df = pd.concat([eid_df, continuous_df, categorical_df], axis=1)

    mi = 10  # max_iter of the iterative imputer; also part of its cache file name
    nn = 9   # n_neighbors of the KNN imputer; also part of its cache file name
    # NOTE(review): the cache files are keyed by mi / nn only, so they become stale if
    # the selected columns or rows change — delete them in that case.
    it_path = os.path.join(FEATURES, f'data_df_imputed_{mi}.pt')
    knn_path = os.path.join(FEATURES, f'data_df_NNimputed_{nn}.pt')
    if os.path.exists(it_path):
      continuous_df_it_imputed = torch.load(it_path)
    else:
      # Imputed values are clipped to each column's observed min / max.
      imp = IterativeImputer(max_iter=mi, random_state=0, sample_posterior=True, skip_complete=True, min_value=continuous_df.min(), max_value=continuous_df.max())
      continuous_df_it_imputed = imp.fit_transform(continuous_df)
      torch.save(continuous_df_it_imputed, it_path)

    if os.path.exists(knn_path):
      data_df_nn_imputed = torch.load(knn_path)
    else:
      knnimputer = KNNImputer(n_neighbors=nn)
      data_df_nn_imputed = knnimputer.fit_transform(data_df.iloc[:,1:]) # exclude eid
      torch.save(data_df_nn_imputed, knn_path)

    # Positions of the categorical columns inside the KNN output.
    categorical_column_indices = [data_df.columns.get_loc(c) for c in data_df.columns if c in categorical_df.columns]
    categorical_column_indices = [c-1 for c in categorical_column_indices] # shift by one because no more eid
    data_df.loc[:,categorical_df.columns] = data_df_nn_imputed[:,categorical_column_indices]
    data_df.loc[:,categorical_df.columns] = data_df.loc[:,categorical_df.columns].round(0) # round to nearest integer because categorical can only be integer

    data_df.loc[:,continuous_df.columns] = continuous_df_it_imputed


In [13]:
# Persist the imputed feature table; it is read back by the check cells at the end of the notebook.
print(data_df.shape)
check_or_save(data_df, join(FEATURES, 'cardiac_feature_18545_vector_labeled_noOH_dropNI_imputed.csv'), index=False, header=True)

(36194, 76)


In [14]:
# Field length per feature column: 1 for every continuous feature and
# (largest observed value + 1) for every categorical feature.
# Assumes data_df is laid out as [eid, continuous..., categorical...], which is how
# the imputation cell builds it.
n_continuous = len(continuous_df.columns)
lengths = [1] * n_continuous
# Skip eid (first column) and the continuous block; the rest are the categorical columns.
# Named `categorical_max` rather than `max`, so the builtin max() stays usable
# in every later cell of the notebook.
categorical_max = list(data_df.max(axis=0))[n_continuous+1:]
lengths = lengths + [int(value)+1 for value in categorical_max]

check_or_save(lengths, join(FEATURES, 'tabular_lengths.pt'))
print(len(lengths), lengths)

75 [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 6, 4, 3, 3, 8, 8, 8, 3, 8, 3, 4, 3]


In [15]:
# Split the field positions by kind (length 1 = continuous, anything else = categorical)
# and build the reordered layout: categorical fields first, continuous fields last.
field_lengths_tabular = lengths
categorical_ids = [pos for pos, n_levels in enumerate(field_lengths_tabular) if n_levels != 1]
continuous_ids = [pos for pos, n_levels in enumerate(field_lengths_tabular) if n_levels == 1]
print('Categorical Index: {}, '.format(len(categorical_ids)), categorical_ids)
print('Numerical Index: {}, '.format(len(continuous_ids)), continuous_ids)

reorder_ids = [*categorical_ids, *continuous_ids]
reorder_field_lengths_tabular = [field_lengths_tabular[pos] for pos in reorder_ids]
# "noExt" variant: the same layout without its last 10 fields.
reorder_field_lengths_tabular_noExt = reorder_field_lengths_tabular[:-10]
check_or_save(reorder_field_lengths_tabular, join(STORE_PATH, f'tabular_lengths_reordered.pt'), index=False, header=False)
check_or_save(reorder_field_lengths_tabular_noExt, join(STORE_PATH, f'tabular_lengths_reordered_noExt.pt'), index=False, header=False)

Categorical Index: 26,  [49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74]
Numerical Index: 49,  [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48]


### Get train, val, test tabular features, labels, and image paths

In [16]:
# Load the preprocessed image paths. The export cells below look entries up with an
# eid (all_image_paths[k] for k in ids), so this is presumably a mapping keyed by eid — verify.
all_image_paths = torch.load(join(FEATURES, 'preprocessed_cardiac_npy_path.pt'))
print(len(all_image_paths))

36167


In [17]:
# Export, for every split: reordered features (with and without the last 10 columns),
# image paths, and the labels of the selected target.
target = 'Infarction'
# target = 'CAD'
# directly save
print(target)
# Both eid-indexed frames are loop-invariant, so they are built once up front.
features_by_eid = data_df.set_index('eid')
all_features_by_eid = all_data_df.set_index('eid')
for split, ids in (('train', train_ids), ('val', val_ids), ('test', test_ids)):
    # tabular features: rows in id order, columns in the reordered layout
    split_df = features_by_eid.loc[ids].iloc[:,reorder_ids]
    print(f'{split} df shape: {split_df.shape}')
    check_or_save(split_df, join(STORE_PATH, f'cardiac_features_{split}_imputed_noOH_tabular_imaging_reordered.csv'), index=False, header=False)
    # tabular features no extracted: same table without its last 10 columns
    split_df_no_extracted = split_df.iloc[:,:-10]
    print(f'{split} df shape: {split_df_no_extracted.shape}')
    check_or_save(split_df_no_extracted, join(STORE_PATH, f'cardiac_features_{split}_imputed_noOH_tabular_imaging_reordered_noExt.csv'), index=False, header=False)
    # image paths, in the same id order as the feature rows
    split_image_paths = [all_image_paths[eid] for eid in ids]
    check_or_save(split_image_paths, join(STORE_PATH, f'cardiac_{split}_paths_imaging.pt'))
    print(f'{split} image path shape: {len(split_image_paths)}')
    # labels come from the un-imputed table, which still holds the target columns
    split_all_df = all_features_by_eid.loc[ids]
    split_labels = split_all_df[target].values
    check_or_save(split_labels, join(STORE_PATH, f'cardiac_labels_{target}_{split}.pt'))
    print(f'{split} label shape: {len(split_labels)}')

Infarction
train df shape: (26040, 75)


train df shape: (26040, 65)
train image path shape: 26040
train label shape: 26040
val df shape: (6510, 75)
val df shape: (6510, 65)
val image path shape: 6510
val label shape: 6510
test df shape: (3617, 75)
test df shape: (3617, 65)
test image path shape: 3617
test label shape: 3617


In [19]:
# balance
# data_df contains final imputed features, all_data_df contains all features and labels
# For train and val: keep every positive subject of `target` plus an equally sized
# random sample of negatives, then export features, image paths, labels and ids.
target = 'Infarction'
# target = 'CAD'
# The eid_old index is attached to temporary copies (assign) instead of adding an
# 'eid_old' column to data_df / all_data_df. Mutating the shared frames would leak the
# extra column into any re-run of the cells that save data_df or derive the field
# lengths from it.
data_by_eid_old = data_df.assign(eid_old=data_df['eid']).set_index('eid_old')
all_data_by_eid_old = all_data_df.assign(eid_old=all_data_df['eid']).set_index('eid_old')
train_df = data_by_eid_old.loc[train_ids]
val_df = data_by_eid_old.loc[val_ids]
all_train_df = all_data_by_eid_old.loc[train_ids]
all_val_df = all_data_by_eid_old.loc[val_ids]
# Label lookup by eid; loop-invariant, so it is built once.
all_data_by_eid = all_data_df.set_index('eid')
print(target)

for all_split_df, split_df, split in zip([all_train_df, all_val_df], [train_df, val_df], ['train', 'val']):
    pos_eids = list(all_split_df.loc[all_split_df[target] == 1]['eid'])
    # Re-seed for every split so the negative sample is reproducible.
    random.seed(2022)
    # random.sample raises ValueError if there are fewer negatives than positives.
    neg_eids = random.sample(list(all_split_df.loc[all_split_df[target] == 0]['eid']), len(pos_eids))
    # Positives first, then the sampled negatives; every file below uses this order.
    all_eids = pos_eids + neg_eids
    # tabular features
    balanced_split_df = split_df.set_index('eid').loc[all_eids]
    balanced_split_df = balanced_split_df.iloc[:,reorder_ids]
    print(f'{split} tabular feature shape: {balanced_split_df.shape}')
    check_or_save(balanced_split_df, join(STORE_PATH, f'cardiac_features_{split}_imputed_noOH_tabular_imaging_{target}_balanced_reordered.csv'), index=False, header=False)
    # tabular features no extracted: same table without its last 10 columns
    balanced_split_df_no_extracted = balanced_split_df.iloc[:,:-10]
    print(f'{split} tabular feature shape: {balanced_split_df_no_extracted.shape}')
    check_or_save(balanced_split_df_no_extracted, join(STORE_PATH, f'cardiac_features_{split}_imputed_noOH_tabular_imaging_{target}_balanced_reordered_noExt.csv'), index=False, header=False)
    # image paths
    balanced_split_image_paths = [all_image_paths[k] for k in all_eids]
    check_or_save(balanced_split_image_paths, join(STORE_PATH, f'cardiac_{split}_paths_imaging_{target}_balanced.pt'))
    print(f'{split} image path shape: {len(balanced_split_image_paths)}')
    # labels
    split_all_df = all_data_by_eid.loc[all_eids]
    split_labels = split_all_df[target].values
    check_or_save(split_labels, join(STORE_PATH, f'cardiac_labels_{target}_{split}_balanced.pt'))
    print(f'{split} label shape: {len(split_labels)}')
    # ids
    check_or_save(all_eids, join(STORE_PATH, f'ids_{split}_tabular_imaging_{target}_balanced.pt'))

Infarction
train tabular feature shape: (1552, 75)
train tabular feature shape: (1552, 65)
train image path shape: 1552
train label shape: 1552
Saving to /bigdata/siyi/data/UKBB/cardiac_segmentations/projects/SelfSuperBio/18545/final/ids_train_tabular_imaging_Infarction_balanced.pt
val tabular feature shape: (472, 75)
val tabular feature shape: (472, 65)
val image path shape: 472
val label shape: 472
Saving to /bigdata/siyi/data/UKBB/cardiac_segmentations/projects/SelfSuperBio/18545/final/ids_val_tabular_imaging_Infarction_balanced.pt


### Get low-data subsets (10% and 1%) of the balanced training split

In [29]:
# Build the low-data subsets of the balanced training set.
# Each round keeps a stratified 10% of the files written by the previous round:
# the full balanced set gives the 0.1 subset, and the 0.1 subset gives the 0.01 subset.
split='train'
# target = 'Infarction'
target = 'CAD'
print(target)
for k, prev_k in [(0.1, ''), (0.01, '_0.1')]:
  # Inputs of this round: the files carrying the previous round's suffix.
  imputed_df = pd.read_csv(join(STORE_PATH, f'cardiac_features_{split}_imputed_noOH_tabular_imaging_{target}{prev_k}_balanced_reordered.csv'), header=None)
  labels = torch.load(join(STORE_PATH,f'cardiac_labels_{target}{prev_k}_{split}_balanced.pt'))
  image_paths = torch.load(join(STORE_PATH, f'cardiac_{split}_paths_imaging_{target}{prev_k}_balanced.pt'))
  ids = torch.load(join(STORE_PATH, f'ids_{split}_tabular_imaging_{target}{prev_k}_balanced.pt'))
  assert len(imputed_df) == len(labels) == len(image_paths) == len(ids)

  # Only the held-out 10% part of the stratified split is kept.
  indices = list(range(len(imputed_df)))
  low_data_indices = train_test_split(indices, test_size=0.1, random_state=2022, stratify=labels)[1]
  # print(low_data_indices)
  # ids
  low_data_ids = [ids[row] for row in low_data_indices]
  print(f'Low data {k} ids shape: {len(low_data_ids)}')
  check_or_save(low_data_ids, join(STORE_PATH,f'ids_{split}_tabular_imaging_{target}_{k}_balanced.pt'))
  # labels
  low_data_labels = [labels[row] for row in low_data_indices]
  print(f'Low data {k} labels shape: {len(low_data_labels)}')
  check_or_save(low_data_labels, join(STORE_PATH,f'cardiac_labels_{target}_{k}_{split}_balanced.pt'))
  # image paths
  low_data_image_paths = [image_paths[row] for row in low_data_indices]
  print(f'Low data {k} image paths shape: {len(low_data_image_paths)}')
  check_or_save(low_data_image_paths, join(STORE_PATH,f'cardiac_{split}_paths_imaging_{target}_{k}_balanced.pt'))
  # tabular features
  low_data_imputed_df = imputed_df.iloc[low_data_indices]
  print(f'Low data {k} tabular features shape: {low_data_imputed_df.shape}')
  check_or_save(low_data_imputed_df, join(STORE_PATH,f'cardiac_features_{split}_imputed_noOH_tabular_imaging_{target}_{k}_balanced_reordered.csv'), header=False, index=False)


CAD
Low data 0.1 ids shape: 349
Saving to /bigdata/siyi/data/UKBB/cardiac_segmentations/projects/SelfSuperBio/18545/final/ids_train_tabular_imaging_CAD_0.1_balanced.pt
Low data 0.1 labels shape: 349
Saving to /bigdata/siyi/data/UKBB/cardiac_segmentations/projects/SelfSuperBio/18545/final/cardiac_labels_CAD_0.1_train_balanced.pt
Low data 0.1 image paths shape: 349
Saving to /bigdata/siyi/data/UKBB/cardiac_segmentations/projects/SelfSuperBio/18545/final/cardiac_train_paths_imaging_CAD_0.1_balanced.pt
Low data 0.1 tabular features shape: (349, 75)
Low data 0.01 ids shape: 35
Saving to /bigdata/siyi/data/UKBB/cardiac_segmentations/projects/SelfSuperBio/18545/final/ids_train_tabular_imaging_CAD_0.01_balanced.pt
Low data 0.01 labels shape: 35
Saving to /bigdata/siyi/data/UKBB/cardiac_segmentations/projects/SelfSuperBio/18545/final/cardiac_labels_CAD_0.01_train_balanced.pt
Low data 0.01 image paths shape: 35
Saving to /bigdata/siyi/data/UKBB/cardiac_segmentations/projects/SelfSuperBio/18545/f

### Check split data

In [3]:
# Load one exported balanced split, together with the two source tables, for a manual
# consistency check. k is the low-data suffix: '' selects the full balanced split.
split = 'train'
target = 'Infarction'
k = ''
data_df = pd.read_csv(join(FEATURES, f'cardiac_feature_18545_vector_labeled_noOH_dropNI_imputed.csv'))
all_data_df = pd.read_csv(join(FEATURES, f'cardiac_feature_18545_vector_labeled_noOH_dropNI.csv'))

# The feature CSV was written without a header row, hence header=None.
split_df = pd.read_csv(join(STORE_PATH, f'cardiac_features_{split}_imputed_noOH_tabular_imaging_{target}{k}_balanced_reordered.csv'), header=None)
split_labels = torch.load(join(STORE_PATH, f'cardiac_labels_{target}{k}_{split}_balanced.pt'))
split_image_paths = torch.load(join(STORE_PATH, f'cardiac_{split}_paths_imaging_{target}{k}_balanced.pt'))
split_ids = torch.load(join(STORE_PATH, f'ids_{split}_tabular_imaging_{target}{k}_balanced.pt'))

In [4]:
# Number of rows in the loaded split.
print(len(split_df))

1552


In [None]:
# Take the first ten ids of the loaded split and show the first ten feature rows next to them.
# ids = [int(x.split('/')[-2]) for x in split_image_paths[:10]]
ids = split_ids[:10]
split_df.iloc[:10]

In [None]:
# Imputed features of the same ten subjects, looked up by eid, for comparison with the exported rows.
data_df.set_index('eid').loc[ids]

In [None]:
# Target labels of the same ten subjects, read from the un-imputed table.
all_data_df.set_index('eid').loc[ids, target]

In [35]:
# First ten training ids. Relies on train_ids loaded in the id-split cells further up.
print(train_ids[:10])

[1221869, 1560136, 3442183, 1514007, 1892612, 4576079, 3905249, 5998757, 3441553, 1053366]


In [36]:
# First ten image paths of the loaded split.
split_image_paths[:10]

['/bigdata/siyi/data/UKBB/cardiac_segmentations/subjects/2353623/sa_es_ed_mm.npy',
 '/bigdata/siyi/data/UKBB/cardiac_segmentations/subjects/2223143/sa_es_ed_mm.npy',
 '/bigdata/siyi/data/UKBB/cardiac_segmentations/subjects/2444884/sa_es_ed_mm.npy',
 '/bigdata/siyi/data/UKBB/cardiac_segmentations/subjects/3411878/sa_es_ed_mm.npy',
 '/bigdata/siyi/data/UKBB/cardiac_segmentations/subjects/5043249/sa_es_ed_mm.npy',
 '/bigdata/siyi/data/UKBB/cardiac_segmentations/subjects/4809059/sa_es_ed_mm.npy',
 '/bigdata/siyi/data/UKBB/cardiac_segmentations/subjects/2510815/sa_es_ed_mm.npy',
 '/bigdata/siyi/data/UKBB/cardiac_segmentations/subjects/2321244/sa_es_ed_mm.npy',
 '/bigdata/siyi/data/UKBB/cardiac_segmentations/subjects/2708178/sa_es_ed_mm.npy',
 '/bigdata/siyi/data/UKBB/cardiac_segmentations/subjects/5747266/sa_es_ed_mm.npy']

In [37]:
# First ten labels of the loaded split.
print(split_labels[:10])

[0, 1, 1, 0, 1, 0, 1, 1, 1, 1]


In [22]:
# Number of fields in the "noExt" layout. Relies on the variable created in the reordering cell further up.
print(len(reorder_field_lengths_tabular_noExt))

65


In [None]:
# Show the three channels of L randomly chosen images from the loaded split, one image per row.
random.seed(2030)
L = 10
f, axarr = plt.subplots(L, 3, figsize=(20,5*L))
for axes_row in axarr:
    rand_idx = random.randrange(len(split_image_paths))
    img = np.load(split_image_paths[rand_idx])
    # Every stored image is expected to be 210 x 210 with 3 channels.
    assert img.shape == (210,210,3)
    # print(f'max: {np.max(img, axis=(0,1))}, min: {np.min(img, axis=(0,1))}')
    # print(f'mean: {np.mean(img, axis=(0,1))}')
    for channel, ax in enumerate(axes_row):
        ax.imshow(img[:,:,channel])
# dtype of the last image drawn
print(img.dtype)

### Check image

In [101]:
# Setup for the raw-image checks: nibabel for reading .nii.gz files, the raw data
# root (same value as in the config cell), and the stored list of "not matching" ids.
import nibabel as nib
RAW_DATA = '/vol/biodata/data/biobank/18545/data'
not_matching_ids = torch.load(join(FEATURES, 'not_matching_ids_cardiac.pt'))

In [102]:
# First ten ids from the "not matching" list.
print(not_matching_ids[:10])

[1000690, 1001312, 1001440, 1001562, 1002212, 1002361, 1002604, 1004082, 1004288, 1004403]


In [103]:
# Load two files of one subject: the volume in sa.nii.gz (4-D in the recorded output)
# and the volume in sa_ES.nii.gz (3-D in the recorded output).
_id = 1001312
folder = join(RAW_DATA, str(_id))
_file = os.path.join(folder,'sa.nii.gz')
nii = nib.load(_file)
im = nii.get_fdata()
print(im.shape)

nii_es = nib.load(join(folder,'sa_ES.nii.gz'))
im_es = nii_es.get_fdata()
print(im_es.shape)
print(im_es.dtype)

(174, 208, 10, 50)
(174, 208, 10)
float64


In [104]:
# test the effect of nii.affine
# Wrap one frame in a new NIfTI image that carries the original affine, then check that
# the voxel data read back from it equals the raw frame.
test_slice = im[:,:,:,10]
test_slice_affine = nib.Nifti1Image(im[:, :, :, 10], nii.affine)
# Compare against the wrapped image's data. Comparing test_slice with itself would
# always give 1.0 and say nothing about the affine.
overlap_ratio = (test_slice==test_slice_affine.get_fdata()).sum()/test_slice.size
print(overlap_ratio)

1.0


In [105]:
# Find the frame of `im` whose middle slice shares the most identical pixels with the
# middle slice of `im_es`.
best_overlap_es = 0
# Default to frame 0 so best_i_es is always bound, even if no frame shares a single pixel.
best_i_es = 0
es_slice = im_es[:,:,im_es.shape[2]//2]
# Scan every frame on the 4th axis instead of assuming there are exactly 50.
for i in range(im.shape[3]):
    im_slice = im[:,:,im.shape[2]//2,i]
    overlap_es = (es_slice==im_slice).sum()
    # Strict '>' keeps the earliest frame when several frames tie.
    if overlap_es > best_overlap_es:
        best_overlap_es = overlap_es
        best_i_es = i
print(best_i_es)
# best_i_es = 19
im_slice = im[:,:,im.shape[2]//2,best_i_es]
# Share of exactly equal pixels between the best frame and the ES slice.
overlap_ratio_es = (im_slice==es_slice).sum()/es_slice.size
print(overlap_ratio_es)
print(np.allclose(im_slice, es_slice))

22
0.9908819628647215
False


In [None]:
# Side-by-side view of the best-matching frame slice and the slice from sa_ES.nii.gz;
# the figure title reports the overlap ratio as a percentage.
f, axarr = plt.subplots(1, 2, figsize=(13,5))
axarr[0].imshow(im_slice, cmap='gray')
axarr[1].imshow(es_slice, cmap='gray')
axarr[0].set_title('ES slice from sa.nii.gz')
axarr[1].set_title('ES slice from sa_ES.nii.gz')
plt.suptitle(f'Subject {_id}, overlap ratio: {overlap_ratio_es*100:.2f}')
plt.show()


### Check data_df

In [24]:
# Reload the imputed table, the un-imputed table and the stored field lengths from disk.
data_df = pd.read_csv(join(FEATURES, f'cardiac_feature_18545_vector_labeled_noOH_dropNI_imputed.csv'))
all_data_df = pd.read_csv(join(FEATURES, f'cardiac_feature_18545_vector_labeled_noOH_dropNI.csv'))
field_lengths_tabular = torch.load(join(FEATURES, 'tabular_lengths.pt'))

In [108]:
# Positions of the categorical / continuous fields, shifted by one.
# The shift presumably accounts for the eid column that precedes the features in data_df — verify.
categorical_ids = [pos + 1 for pos, n_levels in enumerate(field_lengths_tabular) if n_levels != 1]
continuous_ids = [pos + 1 for pos, n_levels in enumerate(field_lengths_tabular) if n_levels == 1]

In [109]:
# Shifted positions of the categorical columns.
print(categorical_ids)

[50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75]


In [None]:
# Display the reloaded imputed table.
data_df

In [None]:
# Value counts of every categorical column, addressed by position.
# Note: _id is a column position here, not a subject id.
for _id in categorical_ids:
    print(data_df.iloc[:,_id].value_counts())