In [1]:
import os
import re
from typing import Any
from pathlib import Path
import numpy as np
import pandas as pd
import torch
from icdmappings import Mapper
from sklearn.preprocessing import LabelEncoder, MinMaxScaler
from torch.utils.data import DataLoader, Dataset, dataloader
from transformers import AutoModel, AutoTokenizer
from tqdm.auto import tqdm
from functools import reduce
import gc
from copy import deepcopy


pd.set_option("future.no_silent_downcasting", True)
pd.set_option("mode.chained_assignment", None)

In [2]:
pklData = pd.read_pickle("./data/patientData_1.pkl")
with pd.option_context('display.max_rows', None, 'display.max_columns', None):
    print(pklData.loc[0])

Patient IDX                                                                                               1
PatientAdministrationGenderCode                                                                        male
PatientBirthDateTime                                                                                     62
Systolic Blood Pressure                                                                                    
Diastolic Blood Pressure                                                                                   
Body Weight                                                                                                
Body Height                                                                                                
BMI                                                                                                        
Body Temperature                                                                                           
Heart Rate                  

In [3]:
class HIEDATA2(Dataset[Any]):
    def __init__(
        self,
        dataLocation="/code/app/data/",
        cacheRowLimit=200,
        cacheDir="./cache",
        debug=False,
    ):
        super().__init__()
        self.debug = debug
        self.dataPath = Path(dataLocation)
        self.cacheDir = Path(cacheDir)
        os.makedirs(self.cacheDir, exist_ok=True)
        self.cacheIndex = 0
        if not self.dataPath.is_dir():
            raise FileNotFoundError("Data directory doesn't exist")
        if debug:
            print("Found data directory")
        self.cacheRowLimit = cacheRowLimit
        self.dataFiles = list(self.dataPath.glob("*.pkl"))
        if not len(self.dataFiles):
            raise FileNotFoundError("No PKLs found in data directory")
        if debug:
            print(f"Found {len(self.dataFiles)} PKL files")
        self.device = "cuda:0" if torch.cuda.is_available() else "cpu"
        if debug:
            print(f"Using {self.device}")
        self.labelEncoder = LabelEncoder()
        self.currentFileIdx = None

        self.IDColumn = "Patient IDX"
        self.categoricalColumns = {
            "PatientAdministrationGenderCode": None,
            "Smoking Status": None,
            # "Urine Routine",
        }
        self.categEmbedColumns = {
            "SnomedEmbed": [
                "SNOMED Codes 0",
                "SNOMED Codes 1",
                "SNOMED Codes 2",
                "SNOMED Codes 3",
                "SNOMED Codes Other",
            ],
            "ProcedureEmbed": [
                "Procedure Codes 0",
                "Procedure Codes 1",
                "Procedure Codes 2",
                "Procedure Codes 3",
                "Procedure Codes Other",
            ],
        }
        self.ignoreColumns = [
            "ICD-10 Codes 0",
            "ICD-10 Codes 1",
            "ICD-10 Codes Other",
            "ICD-10 Codes 2",
            "ICD-10 Codes 3",
        ]
        self.cliNotesColumns = [
            "Projected_Med_Embeddings",
            "Projected_Note_Embeddings",
        ]
        self.yVar = [
            "Chronic kidney disease all stages (1 through 5)",
            "Acute Myocardial Infarction",
            "Hypertension Pulmonary hypertension",
            "Ischemic Heart Disease",
        ]
        self.yVarList = {
            "Diabetes": [
                "Type 1 Diabetes",
                "Type II Diabetes",
            ]
        }
        self.continuousColumns = [
            "PatientBirthDateTime",
            "Systolic Blood Pressure",
            "Diastolic Blood Pressure",
            "Body Weight",
            "Body Height",
            "BMI",
            "Body Temperature",
            "Heart Rate",
            "Oxygen Saturation",
            "Respiratory Rate",
            "Hemoglobin A1C",
            "Blood Urea Nitrogen",
            "Bilirubin lab test",
            "Troponin Lab Test",
            "Ferritin",
            "Glucose Tolerance Testing",
            "Cerebral Spinal Fluid (CSF) Analysis",
            "Arterial Blood Gas",
            "Comprehensive Metabolic Panel",
            "Chloride  Urine",
            "Calcium in Blood  Serum or Plasma",
            "Magnesium in Blood  Serum or Plasma",
            "Magnesium in Urine",
            "Chloride  Blood  Serum  or Plasma",
            "Creatinine  Urine",
            "Creatinine  Blood  Serum  or Plasma",
            "Phosphate Blood  Serum  or Plasma",
            "Coagulation Assay",
            "Complete Blood Count",
            "Creatine Kinase Blood  Serum  or Plasma",
            "D Dimer Test",
            "Electrolytes Panel Blood  Serum  or Plasma",
            "Inflammatory Markers (CRP) Blood  Serum  or Plasma",
            "Lipid Serum  or Plasma",
            "Sputum Culture",
            "Urine Collection 24 Hours",
            # "Urine Routine", This is string, should be included in categorical
        ]
        self.contData = []
        self.categData = []
        self.clinData = []
        self.labels = []
        self.dataLen = 0
        maxRows = 0
        self.labelEncoder = LabelEncoder()
        self.clinicalbert_model_name = (
            "emilyalsentzer/Bio_ClinicalBERT"  # "medicalai/ClinicalBERT"
        )
        self.clinicalbert_model = AutoModel.from_pretrained(
            self.clinicalbert_model_name
        ).to(self.device)
        self.clinicalbert_tokenizer = AutoTokenizer.from_pretrained(
            self.clinicalbert_model_name
        )
        if debug:
            print("Pre Processing Data and creating Cache...")
        initCategData = []
        for f in tqdm(self.dataFiles):
            df = pd.read_pickle(f)
            initCategData.append(
                df[list(self.categoricalColumns.keys())].astype("string")
            )
        initCategData = pd.concat(initCategData, ignore_index=True)
        for categLabel in self.categoricalColumns:
            self.categoricalColumns[categLabel] = LabelEncoder().fit(
                initCategData[categLabel]
            )
        for f in tqdm(self.dataFiles):
            df = pd.read_pickle(f)
            patientIDXs = df[self.IDColumn].unique()
            for patientID in tqdm(patientIDXs[:5]):
                contStack = []
                categStack = []
                clinStack = []
                # print(patientID)
                patientRows = df.loc[df[self.IDColumn] == patientID]
                patientRows = patientRows.ffill().bfill().fillna(0)
                contData = patientRows[self.continuousColumns]
                contData = contData.reset_index().drop(columns="index")
                clinicalEmbeddings = patientRows[self.cliNotesColumns]
                # categData = self.labelEncoder.fit_transform(
                # patientRows[self.categoricalColumns].fillna(0)
                # )
                for categEmbed in self.categEmbedColumns:
                    subset = (
                        patientRows[self.categEmbedColumns[categEmbed]]
                        .reset_index()
                        .drop(columns="index")
                    )
                    newData = []
                    subset = subset.astype("string")
                    # print(patientID, categEmbed)
                    for i in range(len(subset)):
                        # print(i)
                        newData.append(
                            self.genEmbeddings("".join(subset.loc[i].to_list()))
                        )
                    clinicalEmbeddings[categEmbed] = newData
                    # print(clinicalEmbeddings)
                    # print("-------------------")
                    del newData
                # for i in clinicalEmbeddings.columns:
                #     print(len(clinicalEmbeddings[i].to_numpy()[0]))
                clinicalEmbeddings = clinicalEmbeddings.to_numpy()
                categData = patientRows[self.categoricalColumns.keys()].to_numpy()
                for i, d in enumerate(self.categoricalColumns):
                    categData[:, i] = self.categoricalColumns[d].transform(
                        categData[:, i]
                    )
                # for d in range(len(contData["PatientBirthDateTime"])):
                #     contData.at[d, "PatientBirthDateTime"] = self.dateToInt(
                #         contData["PatientBirthDateTime"][d]
                #     )
                for col in contData:
                    contData[col] = pd.to_numeric(contData[col])
                # with pd.option_context('display.max_rows', None, 'display.max_columns', None):
                #     print(contData)
                # Z Score Norming continuous values.
                contData = (
                    ((contData - contData.mean()) / (contData.std() + 1e-100))
                    .fillna(0)
                    .to_numpy()
                )
                assert (
                    len(contData) == len(categData) == len(clinicalEmbeddings)
                ), print(
                    f"{len(contData)}, {len(categData)}, {len(clinicalEmbeddings)}, {patientID}"
                )
                labels = patientRows[self.yVar]
                for col in self.yVarList:
                    orValues = reduce(
                        lambda a, b: a | b, patientRows[self.yVarList[col]].T.to_numpy()
                    )
                    labels.loc[:, col] = orValues
                labels = (
                    labels.reset_index()
                    .drop(columns="index")
                    .replace("", np.nan)
                    .fillna(0)
                    .astype(int)
                    .to_numpy()
                )
                for dRow in range(len(labels)):
                    contStack.append(contData[dRow])
                    categStack.append(categData[dRow])
                    clinStack.append(np.stack(clinicalEmbeddings[dRow]))
                    if np.any(labels[dRow]):
                        # print(dRow)
                        self.contData.append(np.stack(deepcopy(contStack)))
                        self.categData.append(np.stack(deepcopy(categStack)))
                        self.clinData.append(np.stack(deepcopy(clinStack)))
                        self.labels.append(np.stack(deepcopy(labels[dRow])))
                        self.dataLen += 1
                        # if len(self.contData) > self.cacheRowLimit:
                        #     self.cacheToDisk()
                    else:
                        pass
                if self.contData:
                    if maxRows < len(self.contData[-1]):
                        maxRows = len(self.contData[-1])
                del contStack, categStack, clinStack

                # break
                # break

        print(f"Max roll up Rows: {maxRows}")
        for i, d in enumerate(self.contData):
            self.contData[i] = np.vstack(
                [
                    d,
                    np.zeros(
                        (maxRows - d.shape[0], *d.shape[1:]),
                        dtype=d.dtype,
                    ),
                ]
            )
        for i, d in enumerate(self.categData):
            self.categData[i] = np.vstack(
                [
                    d,
                    np.zeros(
                        (maxRows - d.shape[0], *d.shape[1:]),
                        dtype=d.dtype,
                    ),
                ]
            )
        for i, d in enumerate(self.clinData):
            self.clinData[i] = np.vstack(
                [
                    d,
                    np.zeros(
                        (maxRows - d.shape[0], *d.shape[1:]),
                        dtype=d.dtype,
                    ),
                ]
            )
        # print(np.stack(self.contData).shape)
        self.contData = torch.tensor(self.contData)
        self.contData = self.contData.reshape((len(self.contData), -1))
        self.categData = torch.tensor(np.array(self.categData).astype(int))
        self.categData = self.categData.reshape((len(self.categData), -1))
        self.clinData = torch.tensor(self.clinData)
        self.clinData = self.clinData.reshape((len(self.clinData), -1))
        # print(self.labels)
        self.contDataInputShape = self.contData.shape[-1]
        self.categDataInputShape = self.categData.shape[-1]
        self.clinDataInputShape = self.clinData.shape[-1]
        self.labels = torch.tensor(np.array(self.labels).astype(int))
        self.labelOutputShape = self.labels.shape[-1]
        while len(self.contData) != 0:
            self.cacheToDisk()
        if debug:
            print("Created Cache")
        self.loadRequiredFile(0)
        if debug:
            print("Done initializing Dataset")

    def getTotalRowCount(self):
        totalRows = 0
        for dataFile in self.dataFiles:
            totalRows += len(pd.read_pickle(dataFile))
        return totalRows

    def getShapes(self):
        return (
            self.contDataInputShape,
            self.categDataInputShape,
            self.clinDataInputShape,
            self.labelOutputShape
        )

    def __getitem__(self, index):
        fileIDX = index // self.cacheRowLimit
        localIDX = index % self.cacheRowLimit
        if self.currentFileIdx != fileIDX:
            self.loadRequiredFile(fileIDX)
        return (
            self.contData[localIDX],
            self.categData[localIDX],
            self.clinData[localIDX],
            self.labels[localIDX],
        )

    def __len__(self):
        return self.dataLen

    def loadRequiredFile(self, fileIDX):
        if self.currentFileIdx == fileIDX:
            return
        else:
            if self.debug:
                print(f"Loading Cache File {fileIDX+1}")
            data = torch.load(self.cacheDir / f"dataCache{fileIDX}.pkl")
            self.contData = data["cont"]
            self.categData = data["categ"]
            self.clinData = data["cli"]
            self.labels = data["labels"]
            self.currentFileIdx = fileIDX

    def cacheToDisk(self):
        torch.save(
            {
                "cont": self.contData[: self.cacheRowLimit],
                "categ": self.categData[: self.cacheRowLimit],
                "cli": self.clinData[: self.cacheRowLimit],
                "labels": self.labels[: self.cacheRowLimit],
            },
            self.cacheDir / f"dataCache{self.cacheIndex}.pkl",
        )
        self.contData = self.contData[self.cacheRowLimit :]
        self.categData = self.categData[self.cacheRowLimit :]
        self.clinData = self.clinData[self.cacheRowLimit :]
        self.labels = self.labels[self.cacheRowLimit :]
        gc.collect()
        if self.debug:
            print(f"Creating Cache File {self.cacheIndex + 1}")
        self.cacheIndex += 1

    def chunkText(self, text, chunkSize=512):
        chunks = []
        idx = 0
        while idx < len(text):
            end = min(idx + chunkSize, len(text))
            chunk = text[idx:end]
            if end < len(text) and not re.match(r"\b\w+\b$", chunk):
                chunk += " " + text[end]
                end += 1
            chunks.append(chunk)
            idx = end
        return chunks

    def genEmbeddings(self, text):
        chunks = self.chunkText(text)
        embeddings = []
        for chunk in chunks:
            encoded_input = self.clinicalbert_tokenizer(
                chunk, return_tensors="pt", padding="max_length", truncation=True
            )
            encoded_input = encoded_input.to(self.device)
            with torch.no_grad():
                output = self.clinicalbert_model(**encoded_input)
                last_hidden_state = output.last_hidden_state
                chunk_embedding = torch.mean(last_hidden_state, dim=1)
                embeddings.append(chunk_embedding.cpu().detach().numpy())
        if embeddings:
            return np.mean(np.concatenate(embeddings), axis=0)
        else:
            return np.zeros((768,))



In [4]:
data = HIEDATA2("./data", debug=True, )

Found data directory
Found 1 PKL files
Using cpu




Pre Processing Data and creating Cache...


  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

Asking to pad to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no padding.
Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


Max roll up Rows: 46


  self.contData = torch.tensor(self.contData)


Creating Cache File 1
Created Cache
Loading Cache File 1
Done initializing Dataset


  data = torch.load(self.cacheDir / f"dataCache{fileIDX}.pkl")


In [4]:
# categCols = [            "PatientAdministrationGenderCode",
#             "SNOMED Codes 0",
#             "SNOMED Codes 1",
#             "SNOMED Codes 2",
#             "SNOMED Codes 3",
#             "SNOMED Codes Other",
#             "Smoking Status",
#             "Procedure Codes 0",
#             "Procedure Codes 1",
#             "Procedure Codes 2",
#             "Procedure Codes 3",
#             "Procedure Codes Other",
#             "Urine Routine",]

# with pd.option_context('display.max_rows', None, 'display.max_columns', None):
#     print(pklData.loc[pklData["Patient IDX"] == 1][categCols].dtypes)
#     print(pklData.loc[pklData["Patient IDX"] == 1][categCols])
#     print(pklData.loc[pklData["Patient IDX"] == 1][categCols]["Smoking Status"].astype("string").fillna("asd"))

print(np.stack(pklData["Projected_Med_Embeddings"]).shape)
# for col in pklData:
#     print(col)

NameError: name 'pklData' is not defined

### Code block to print data shapes and types

In [5]:
for i in range(0, len(data), 200):
    print(f"index: {i}", data[i][0].shape, data[i][1].shape, data[i][2].shape, data[i][3].shape)
i=60
print()
print("continuous data breakdown:")
print(data[i][0].shape, type(data[i][0]) )
print()
print("categorical data breakdown:")
print(data[i][1].shape, type(data[i][1]) )
print()
print("clinical data breakdown:")
print(data[i][2].shape, type(data[i][2]))
print("labels breakdown:")
print(data[i][3])

index: 0 torch.Size([1656]) torch.Size([92]) torch.Size([141312]) torch.Size([5])

continuous data breakdown:
torch.Size([1656]) <class 'torch.Tensor'>

categorical data breakdown:
torch.Size([92]) <class 'torch.Tensor'>

clinical data breakdown:
torch.Size([141312]) <class 'torch.Tensor'>
labels breakdown:
tensor([0, 0, 0, 1, 1])


In [6]:
data.getShapes()

(1656, 92, 141312, 5)

In [7]:
categoricalColumns = [
            "PatientAdministrationGenderCode",
            "Smoking Status",
            "Urine Routine",
        ]
categEmbedColumns = {
            "SnomedEmbed": [
                "SNOMED Codes 0",
                "SNOMED Codes 1",
                "SNOMED Codes 2",
                "SNOMED Codes 3",
                "SNOMED Codes Other",
            ],
            "ProcedureEmbed": [
                "Procedure Codes 0",
                "Procedure Codes 1",
                "Procedure Codes 2",
                "Procedure Codes 3",
                "Procedure Codes Other",
            ],
        }

for i in categEmbedColumns:
    tempData = pklData[categEmbedColumns[i]]
    tempData = tempData.astype("string")
    break

print("".join(tempData.loc[100].to_list()))

422650009 314529007 423315002 73595000 40055000  160903007  36971009  59621000  5251000175109  162864005  160904001  73438004  741062008  444814009


In [8]:
import torch.nn as nn

class MUFASAFull(nn.Module):
    def __init__(self, contFeatureLen, categLen, clinicNotesLen, outputLen):
        super(MUFASAFull, self).__init__()
        #! # layers for continuous features branch
        self.contInput = nn.Linear(contFeatureLen, 128)
        self.inputNorm = nn.LayerNorm(128)
        self.attention = nn.MultiheadAttention(128, 4)
        self.lRelu = nn.LeakyReLU()
        self.dropout1 = nn.Dropout(0.4)
        # ? Addition Layer
        self.nextLayerNorm = nn.LayerNorm(128)
        self.conv1 = nn.Linear(128, 512)
        self.relu = nn.ReLU()
        self.dropout2 = nn.Dropout(0.3)
        self.next2LayerNorm = nn.LayerNorm(512)
        self.conv2 = nn.Linear(512, 128)
        # ? Addition Layer

        #! # layers for categorical features branch
        self.categInput = nn.Linear(categLen, 128)
        self.cat_2_layerNorm = nn.LayerNorm(128)
        self.cat_3_self_attention = nn.MultiheadAttention(128, 4)
        self.cat_4_conv1 = nn.Linear(256, 256)
        self.cat_5_relu = nn.ReLU()
        self.cat_dropout = nn.Dropout(0.4)
        # ? addition
        self.cat_branch_layerNorm = nn.LayerNorm(256)
        self.cat_branch_conv2 = nn.Linear(256, 384)
        self.cat_branch_relu = nn.ReLU()
        self.cat_branch_dropout = nn.Dropout(0.3)

        #! # layers for clinical features branch
        self.cliInput = nn.Linear(clinicNotesLen, 128)
        self.cli_2_layerNorm = nn.LayerNorm(128)
        self.cli_3_selfAtt = nn.MultiheadAttention(128, 4)
        # ? addition
        self.cli_4_layerNorm = nn.LayerNorm(128)
        self.cli_5_conv1 = nn.Linear(128, 512)
        self.cli_6_relu = nn.ReLU()
        self.cli_6_1_dropout = nn.Dropout(0.4)
        self.cli_7_layerNorm = nn.LayerNorm(512)
        self.cli_8_conv2 = nn.Linear(512, 128)
        # ? addition
        # ? Fuse concatenation between ret1 from categorical and current
        self.cli_9_layerNorm = nn.LayerNorm(384)
        self.cli_10_conv3l = nn.Linear(384, 1536)
        self.cli_11_conv3r = nn.Linear(384, 384)
        self.cli_12_relu = nn.ReLU()
        self.cli_12_1_dropout = nn.Dropout(0.3)
        self.cli_13_conv4 = nn.Linear(1536, 384)
        # ? Fuse addition between concat output, current, right branch conv,
        # ? continuous branch, ret2 from categorical branch

        #! Final Output from addition
        self.out = nn.Linear(384, outputLen)

    def forward(self, sample):
        contIn, cateIn, clinIn = sample
        # print(contIn.shape, cateIn.shape, clinIn.shape)
        contOutput = self.continuousFeaturesForward(contIn)
        ret1, ret2 = self.categoricalFeaturesForward(cateIn)
        res = self.clinicalFeaturesForward(clinIn, ret1, ret2, contOutput)
        res = self.out(res)
        return res

    def categoricalFeaturesForward(self, inp):
        xa = self.categInput(inp)
        x = xa.clone()
        x = self.cat_2_layerNorm(x)
        x = self.cat_3_self_attention(x, x, x, need_weights=False)
        xb = torch.concatenate([x[0], xa], dim=1)
        x = self.cat_4_conv1(xb)
        x = self.cat_dropout(self.cat_5_relu(x))

        xBran = self.cat_branch_layerNorm(xb)
        ret1 = torch.add(xBran, x)
        xBran = self.cat_branch_conv2(xBran)
        ret2 = self.cat_branch_dropout(self.cat_branch_relu(xBran))
        return ret1, ret2

    def clinicalFeaturesForward(self, inp, ret1, ret2, contFeat):
        xa = self.cliInput(inp)
        x = xa.clone()
        x = self.cli_2_layerNorm(x)
        x = self.cli_3_selfAtt(x, x, x, need_weights=False)
        # raise NotImplementedError("Change the skip strategy")
        x = torch.add(x[0], xa)
        xb = self.cli_4_layerNorm(x)
        x = self.cli_5_conv1(xb)
        x = self.cli_6_1_dropout(self.cli_6_relu(x))
        x = self.cli_7_layerNorm(x)
        x = self.cli_8_conv2(x)
        x = torch.add(x, xb)
        xc = torch.concatenate([x, ret1], dim=1)
        x = self.cli_9_layerNorm(xc)
        xdl = self.cli_10_conv3l(x)
        xdr = self.cli_11_conv3r(x)
        x = self.cli_12_1_dropout(self.cli_12_relu(xdl))
        x = self.cli_13_conv4(x)
        # print(x.shape, xc.shape, contFeat.shape, xdr.shape, ret2.shape)
        x = x + xc + nn.functional.pad(contFeat, (0, 256), value=0) + xdr + ret2
        return x

    def continuousFeaturesForward(self, inp):
        x = self.contInput(inp)
        sav0 = x.clone()
        x = self.inputNorm(x)
        x = self.attention(x, x, x, need_weights=False)
        x = self.dropout1(self.lRelu(x[0]))
        addOutput = torch.add(x, sav0)
        x = self.nextLayerNorm(addOutput)
        x = self.conv1(x)
        x = self.dropout2(self.relu(x))
        x = self.next2LayerNorm(x)
        x = self.conv2(x)
        output = torch.add(x, addOutput)
        return output



In [9]:
model=MUFASAFull(*data.getShapes())

In [10]:
dl=DataLoader(data, batch_size=4, shuffle=False)

In [11]:
for d4 in dl:
    cont, cate, clin, lab = d4
    a=model([cont.float(), cate.float(), clin.float()])
    print(a.shape, a)
    break

torch.Size([4, 5]) tensor([[ 0.5068, -1.0627, -1.0535,  1.0040, -1.2035],
        [ 0.6805, -0.9053, -1.2376,  0.9532, -1.1185],
        [ 0.6199, -1.5490, -1.1506,  0.5832, -1.4645],
        [ 0.2582, -1.1938, -0.8834,  1.0256, -1.6013]],
       grad_fn=<AddmmBackward0>)


In [14]:
torch.numel((nn.Sigmoid()(a)>0.5).int())

20

In [70]:
ls=nn.BCEWithLogitsLoss()
l=ls(a, lab.float())
# print(torch.mean(l.flatten()))
print(l)

tensor(1.0427, grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)
