In [1]:
%load_ext autoreload
%autoreload 2

In [92]:
import itertools
import multiprocessing
import pickle
from collections import defaultdict
from typing import Any, Dict, Iterator, List, Optional, Tuple

import numpy as np
import scipy.sparse
import torch
from transformers import AutoModel, AutoModelForMaskedLM, AutoTokenizer

import piton
import piton.datasets

In [3]:
def load_from_file(path_to_file: str):
    """Unpickle and return the object stored at ``path_to_file``.

    Security note: ``pickle.load`` can execute arbitrary code while loading, so
    only use this on files written by a trusted process.
    """
    with open(path_to_file, mode="rb") as pickled_file:
        return pickle.load(pickled_file)

In [6]:
# Bio_ClinicalBERT checkpoint, loaded as a bare encoder via AutoModel; the
# checkpoint's `cls.*` pretraining-head weights go unused (see the warning output).
path_to_model = "/local-scratch/nigam/projects/clmbr_text_assets/models/Bio_ClinicalBERT"
tokenizer = AutoTokenizer.from_pretrained(path_to_model)
model = AutoModel.from_pretrained(
    path_to_model
)

Some weights of the model checkpoint at /local-scratch/nigam/projects/clmbr_text_assets/models/Bio_ClinicalBERT were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.bias', 'cls.predictions.decoder.weight', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [None]:
# NOTE(review): `notes` is not defined in any cell visible here, so this cell
# raises NameError on a fresh kernel. TODO: bind `notes` (objects exposing a
# `.value` string) before running it.
max_length = 512
padding = True
truncation = True

note_texts = [note.value for note in notes]
# Run notes through an already-trained tokenizer
notes_tokenized = tokenizer(
    note_texts,
    max_length=max_length,
    truncation=truncation,
    padding=padding,
    return_tensors="pt",
)

In [None]:
# Tokenize `notes` and run one forward pass through the encoder.
# NOTE(review): `notes` is not defined in any cell visible here. TODO: bind it
# (objects exposing a `.value` string) before running this cell.
tokenizer = AutoTokenizer.from_pretrained(path_to_model)
max_length = 512
padding = True
truncation = True

# Run notes through an already-trained tokenizer
notes_tokenized = tokenizer(
    [note.value for note in notes],
    padding=padding,
    truncation=truncation,
    max_length=max_length,
    return_tensors="pt",
)

# There is no `args` at notebook level; load from the same checkpoint as the tokenizer.
model = AutoModel.from_pretrained(path_to_model)

# Inference only: skip autograd bookkeeping to save memory.
with torch.no_grad():
    outputs = model(**notes_tokenized)


In [10]:
LABELED_PATIENTS_PATH = "/local-scratch/nigam/projects/rthapa84/data/HighHbA1c_labeled_patients_v3.pickle"
labeled_patients = load_from_file(LABELED_PATIENTS_PATH)

In [11]:
# Size of the labeled cohort container (presumably one entry per patient — TODO confirm).
labeled_patients_size = len(labeled_patients)
labeled_patients_size

418465

In [12]:
# Parallelism and data location for the text-featurization experiments.
num_threads = 10
database_path = "/local-scratch/nigam/projects/ethanid/som-rit-phi-starr-prod.starr_omop_cdm5_deid_2022_09_05_extract2"

database = piton.datasets.PatientDatabase(database_path)

# Debug subset: only the first 40 labeled patient ids (in sorted order).
pids = sorted(labeled_patients.get_all_patient_ids())[:40]
pids_parts = np.array_split(pids, num_threads)
tasks = [
    (database_path, pid_part, labeled_patients)
    for pid_part in pids_parts
]

In [13]:
# Unpack the first shard of `tasks` to debug a single worker's inputs in-process.
# Note: this rebinds `pids` to that shard only.
first_task = tasks[0]
database_path, pids, labeled_patients = first_task

In [14]:
# Open the database at `database_path` and keep its ontology handy for featurizers.
database = piton.datasets.PatientDatabase(database_path)
ontology = database.get_ontology()

In [15]:
# Text events shorter than this many characters are skipped when assembling note
# text. Despite its name this is a *minimum* length: `_get_patient_text_data`
# drops notes with `len(text) < MAX_CHAR`.
MAX_CHAR = 100

In [16]:
# Scratch: str.join with a single space, the same joiner used to concatenate note text.
" ".join(("Mu", "Name", "is"))

'Mu Name is'

In [50]:
# Dummy sequence 0..1238, presumably for experimenting with chunked iteration.
items = list(range(1239))

In [None]:
# Scratch: split `items` into chunks of 10 (the last chunk may be shorter), the
# chunk size used when batching notes through the model.
# NOTE(review): this cell was an unfinished `for` statement (a SyntaxError that
# stopped Run-All); the chunking here is a guess at its intent.
item_chunks = [items[start:start + 10] for start in range(0, len(items), 10)]

In [89]:
def _get_patient_text_data(patient, labels):
    text_for_each_label = []

    label_idx = 0
    current_text = []
    for event in patient.events:
        while event.start > labels[label_idx].time:
            label_idx += 1
            text_for_each_label.append(" ".join(current_text))

            if label_idx >= len(labels):
                return text_for_each_label

        if type(event.value) is not memoryview:
            continue

        text_data = bytes(event.value).decode("utf-8")

        if len(text_data) < MAX_CHAR:
            continue

        current_text.append(text_data)

    if label_idx < len(labels):
        for label in labels[label_idx:]:
            text_for_each_label.append(" ".join(current_text))


    return text_for_each_label
    


def _run_text_featurizer(database_path, pids, labeled_patients, path_to_model, chunk_size=10):
    """Embed, for every label, the note text available at that label's time.

    Opens its own ``PatientDatabase``, tokenizer and model, so it can run inside
    a worker process. For each patient id in ``pids`` with at least one label,
    the per-label text from ``_get_patient_text_data`` is tokenized (padded,
    truncated to 512 tokens) in batches of ``chunk_size`` rows. Each row is
    embedded as the hidden state of its first token (position 0 of
    ``last_hidden_state``).

    Args:
        database_path: Path passed to ``piton.datasets.PatientDatabase``.
        pids: Iterable of patient ids to process.
        labeled_patients: Object exposing ``pat_idx_to_label(patient_id)``.
        path_to_model: Hugging Face model/tokenizer directory.
        chunk_size: Number of label rows per forward pass.

    Returns:
        ``(embeddings, result_labels, patient_ids, labeling_time)``.
        ``embeddings`` is a 2-D array with one row per label; it has shape
        ``(0, hidden_size)`` when no patient in ``pids`` has labels. The other
        three are lists aligned with its rows.
    """
    database = piton.datasets.PatientDatabase(database_path)
    tokenizer = AutoTokenizer.from_pretrained(path_to_model)
    model = AutoModel.from_pretrained(path_to_model)
    max_length = 512

    data = []
    patient_ids = []
    result_labels = []
    labeling_time = []

    for patient_id in pids:
        labels = labeled_patients.pat_idx_to_label(patient_id)
        if len(labels) == 0:
            continue

        patient = database[patient_id]
        patient_text_data = _get_patient_text_data(patient, labels)
        for i, label in enumerate(labels):
            data.append(patient_text_data[i])
            result_labels.append(label.value)
            patient_ids.append(patient.patient_id)
            labeling_time.append(label.time)

    embeddings = []
    # Inference only: no autograd graph, so memory stays flat across chunks.
    with torch.no_grad():
        for start in range(0, len(data), chunk_size):
            # NOTE(review): with the tokenizer's default truncation side (check
            # `tokenizer.truncation_side`), only the first `max_length` tokens are
            # kept. Text is appended in event order, so the notes closest to the
            # label time may be the ones dropped.
            notes_tokenized = tokenizer(
                data[start:start + chunk_size],
                padding=True,
                truncation=True,
                max_length=max_length,
                return_tensors="pt",
            )
            outputs = model(**notes_tokenized)
            # First-token hidden state per row, shape (batch, hidden). No .squeeze():
            # a one-row chunk would otherwise lose its batch axis and break
            # np.concatenate below.
            first_token_states = outputs.last_hidden_state[:, 0, :]
            embeddings.append(first_token_states.cpu().numpy())

    if embeddings:
        embeddings = np.concatenate(embeddings)
    else:
        # np.concatenate([]) raises; a shard without labels yields an empty matrix.
        embeddings = np.empty((0, model.config.hidden_size), dtype=np.float32)

    return embeddings, result_labels, patient_ids, labeling_time

        

In [None]:
# Parallel text featurization over the first 100 labeled patients.
# NOTE(review): with the 'forkserver' (or 'spawn') start method, workers look up
# `_run_text_featurizer` by name, and functions defined interactively in a
# notebook's __main__ are generally not importable there ("Can't get attribute").
# If that happens, move the function into an importable .py module.
pids = sorted(labeled_patients.get_all_patient_ids())[:100]
num_threads = 5

pids_parts = np.array_split(pids, num_threads)

tasks = [(database_path, pid_part, labeled_patients, path_to_model) for pid_part in pids_parts]

ctx = multiprocessing.get_context('forkserver')
with ctx.Pool(num_threads) as pool:
    # Each task is a tuple of positional arguments, so use starmap; imap would
    # pass the whole tuple as a single argument and raise TypeError.
    text_featurizers_tuple_list = pool.starmap(_run_text_featurizer, tasks)

In [87]:
# In-process run on the current `pids` (debugging without the pool); keep only the embeddings.
text_features = _run_text_featurizer(database_path, pids, labeled_patients, path_to_model)
embeddings, _, _, _ = text_features

Some weights of the model checkpoint at /local-scratch/nigam/projects/clmbr_text_assets/models/Bio_ClinicalBERT were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.bias', 'cls.predictions.decoder.weight', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [88]:
# (rows, hidden size) of the embedding matrix.
np.shape(embeddings)

(4, 768)

In [39]:
# One-off forward pass over `data` with a freshly loaded tokenizer and model.
# NOTE(review): `data` is not defined in this cell and must already hold the note
# strings to embed before it runs.
tokenizer = AutoTokenizer.from_pretrained(path_to_model)
max_length, padding, truncation = 512, True, True

tokenizer_kwargs = dict(
    padding=padding,
    truncation=truncation,
    max_length=max_length,
    return_tensors="pt",
)
# Run notes through an already-trained tokenizer
notes_tokenized = tokenizer(data, **tokenizer_kwargs)

model = AutoModel.from_pretrained(path_to_model)
outputs = model(**notes_tokenized)

Some weights of the model checkpoint at /local-scratch/nigam/projects/clmbr_text_assets/models/Bio_ClinicalBERT were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.bias', 'cls.predictions.decoder.weight', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [46]:
def embed_with_cls(embeddings):
    """Return the first-token vector of every sequence in a batch.

    For BERT-style models the first position holds the [CLS] token.

    Args:
        embeddings: Hidden states indexed as ``[batch, position, hidden]``, e.g.
            ``outputs.last_hidden_state``. Works for torch tensors and numpy arrays.

    Returns:
        ``embeddings[:, 0, :]`` with shape ``(batch, hidden)``. There is no
        ``.squeeze()``, so a batch of one keeps its batch axis and results from
        several batches can be stacked.
    """
    return embeddings[:, 0, :]

In [48]:
last_hidden = outputs.last_hidden_state
embed_with_cls(last_hidden)

torch.Size([4, 768])

In [67]:
# First-token embeddings as a (batch, hidden) numpy array. No .squeeze(), so a
# batch of one keeps its batch axis.
outputs.last_hidden_state[:, 0, :].cpu().detach().numpy()

array([[-0.27650014, -0.013582  , -0.66790944, ...,  0.05187273,
         0.1168906 ,  0.02797246],
       [-0.11857603,  0.25548074, -0.6114352 , ..., -0.08299585,
         0.17861778, -0.36386025],
       [ 0.36844736,  0.3425017 , -0.5480165 , ..., -0.3408979 ,
         0.60281944, -0.22029027],
       [ 0.22482194, -0.24363495, -0.03791487, ...,  0.08538163,
         0.39754403,  0.08916175]], dtype=float32)

{'input_ids': tensor([[ 101, 1218, 2027,  ..., 1336, 4506,  102],
        [ 101, 4113, 3516,  ..., 5682,  129,  102],
        [ 101, 3372, 1106,  ...,    0,    0,    0],
        [ 101,  102,    0,  ...,    0,    0,    0]]), 'token_type_ids': tensor([[0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 0,  ..., 0, 0, 0]])}

In [53]:
# Spot-check the text assembled for one patient's first label.
patient_id = 10
patient = database[patient_id]
labels = labeled_patients.pat_idx_to_label(patient_id)
texts_per_label = _get_patient_text_data(patient, labels)
text_1 = texts_per_label[0]

In [25]:
# Build one text row per label for every patient in `pids`, with aligned label
# values, patient ids and label times (the same bookkeeping as `_run_text_featurizer`).
data = []
result_labels = []
patient_ids = []
labeling_time = []

for patient_id in pids:
    patient = database[patient_id]
    labels = labeled_patients.pat_idx_to_label(patient_id)
    if len(labels) == 0:
        continue

    # `_get_patient_text_data` returns one string per label, in label order.
    texts = _get_patient_text_data(patient, labels)
    for label, text in zip(labels, texts):
        data.append(text)
        result_labels.append(label.value)
        patient_ids.append(patient.patient_id)
        labeling_time.append(label.time)
        
        
    
    
    

[Label(time=datetime.datetime(2019, 10, 10, 7, 49), value=False)]

[Label(time=datetime.datetime(2021, 7, 3, 10, 23), value=False)]

[Label(time=datetime.datetime(2017, 10, 13, 11, 11), value=False)]

[Label(time=datetime.datetime(2017, 4, 18, 11, 35), value=True)]



In [None]:
def _run_featurizer(
    args: "Tuple[str, List[int], LabeledPatients, List[Featurizer]]",
) -> "Tuple[Any, Any, Any, Any]":
    """Featurize one shard of patients into a sparse matrix (pool worker entry point).

    Args:
        args: ``(database_path, pids, labeled_patients, featurizers)`` packed into
            one tuple so the function can be mapped with ``Pool.imap``.

    Returns:
        ``(data_matrix, result_labels, patient_ids, labeling_time)``:
        ``data_matrix`` is a ``scipy.sparse.csr_matrix`` with one row per label,
        whose columns are the featurizers' columns concatenated in order. The
        other three are aligned numpy arrays of label values, patient ids (int32)
        and label times (datetime64[us]).
    """
    database_path, pids, labeled_patients, featurizers = args

    # CSR components: `data`/`indices` hold the non-zeros, `indptr` the row starts.
    data = []
    indices: List[int] = []
    indptr = []

    result_labels = []
    patient_ids = []
    labeling_time = []

    database = piton.datasets.PatientDatabase(database_path)
    ontology = database.get_ontology()

    for patient_id in pids:
        patient = database[patient_id]
        labels = labeled_patients.pat_idx_to_label(patient_id)
        if len(labels) == 0:
            continue

        columns_by_featurizer = []
        for featurizer in featurizers:
            columns = featurizer.featurize(patient, labels, ontology)
            assert len(columns) == len(labels), (
                f"The featurizer {featurizer} didn't provide enough rows for "
                f"patient {patient.patient_id} ({len(columns)} != {len(labels)})"
            )
            columns_by_featurizer.append(columns)

        for i, label in enumerate(labels):
            indptr.append(len(indices))
            result_labels.append(label.value)
            patient_ids.append(patient.patient_id)
            labeling_time.append(label.time)

            # Each featurizer's columns are shifted past those of the previous ones.
            column_offset = 0
            for featurizer, feature_columns in zip(featurizers, columns_by_featurizer):
                num_columns = featurizer.num_columns()
                for column, value in feature_columns[i]:
                    assert 0 <= column < num_columns, (
                        f"The featurizer {featurizer} provided an out of bounds column for "
                        f"patient {patient.patient_id} ({column} should be between 0 and "
                        f"{num_columns})"
                    )
                    indices.append(column_offset + column)
                    data.append(value)
                column_offset += num_columns
    indptr.append(len(indices))

    data = np.array(data, dtype=np.float32)
    indices = np.array(indices, dtype=np.int32)
    indptr = np.array(indptr, dtype=np.int32)
    result_labels = np.array(result_labels)
    patient_ids = np.array(patient_ids, dtype=np.int32)
    # Label times (not patient ids); an explicit unit keeps shards concatenable.
    labeling_time = np.array(labeling_time, dtype="datetime64[us]")

    total_columns = sum(featurizer.num_columns() for featurizer in featurizers)

    data_matrix = scipy.sparse.csr_matrix(
        (data, indices, indptr), shape=(len(result_labels), total_columns)
    )

    return data_matrix, result_labels, patient_ids, labeling_time


class FeaturizerList:
    """
    Featurizer list consists of a list of featurizers that will be used (in sequence) to featurize data.
    It enables preprocessing of featurizers, featurization, and column name extraction.
    The featurizers' columns are concatenated in list order.
    """

    def __init__(self, featurizers: "List[Featurizer]"):
        """Create a :class:`FeaturizerList` from a sequence of featurizers.

        Args:
            featurizers (List[Featurizer]): The featurizers to use for featurizing patients.
        """
        self.featurizers = featurizers

    def preprocess_featurizers(
        self,
        labeled_patients: "LabeledPatients",
        database_path: str,
        num_threads: int = 1,
    ) -> None:
        """Preprocess (fit) the featurizers on the patients of ``labeled_patients``.

        Does nothing unless at least one featurizer reports ``needs_preprocessing()``.
        Otherwise the sorted patient ids are split into ``num_threads`` shards, each
        shard is handled by ``_run_preprocess_featurizers`` in a 'forkserver' pool,
        and the per-shard results are merged back into ``self.featurizers``.
        Finally ``finalize_preprocessing()`` is called on every featurizer.

        NOTE(review): the merge assumes ``self.featurizers[0]`` is an age featurizer
        and ``self.featurizers[1]`` a count featurizer, with worker results indexed
        the same way. It copies age statistics from the first shard whose
        ``current_mean`` is non-zero rather than combining shards.
        ``_run_preprocess_featurizers`` is not defined in any visible cell.

        Args:
            labeled_patients: Labels whose patient ids define the cohort.
            database_path: Path to the patient database, opened by each worker.
            num_threads: Number of worker processes / shards.
        """
        any_needs_preprocessing = any(
            featurizer.needs_preprocessing() for featurizer in self.featurizers
        )
        if not any_needs_preprocessing:
            return

        pids = sorted(labeled_patients.get_all_patient_ids())
        pids_parts = np.array_split(pids, num_threads)
        tasks = [(database_path, pid_part, labeled_patients, self.featurizers) for pid_part in pids_parts]

        ctx = multiprocessing.get_context('forkserver')
        with ctx.Pool(num_threads) as pool:
            trained_featurizers_tuple_list = list(pool.imap(_run_preprocess_featurizers, tasks))

        age_featurizers = [result[0] for result in trained_featurizers_tuple_list]
        count_featurizers = [result[1] for result in trained_featurizers_tuple_list]

        # Aggregating age featurizers
        for age_featurizer in age_featurizers:
            if age_featurizer.to_dict()["age_statistics"]["current_mean"] != 0:
                self.featurizers[0].from_dict(age_featurizer.to_dict())
                break

        # Aggregating count featurizers: concatenate every shard's code list.
        patient_codes_dict_list = [
            count_featurizer.to_dict()["patient_codes"]["values"]
            for count_featurizer in count_featurizers
        ]
        patient_codes = list(itertools.chain.from_iterable(patient_codes_dict_list))
        self.featurizers[1].from_dict({"patient_codes": {"values": patient_codes}})

        for featurizer in self.featurizers:
            featurizer.finalize_preprocessing()

    def featurize(
        self,
        labeled_patients: "LabeledPatients",
        database_path: str,
        num_threads: int = 1,
    ) -> "Tuple[Any, Any, Any, Any]":
        """
        Apply the featurizers to every labeled patient and stack the per-shard results.

        Args:
            labeled_patients: Labels whose (sorted) patient ids are featurized.
            database_path: Path to the patient database, opened by each worker.
            num_threads: Number of worker processes / shards.

        Returns:
            This returns a tuple (data_matrix, labels, patient_ids, labeling_time).
            data_matrix is a sparse matrix of all the features of all the featurizers.
            labels is an array of label values for each row in the matrix.
            patient_ids is an array of the patient ids for each row.
            labeling_time is an array of labeling/prediction times for each row.
        """
        pids = sorted(labeled_patients.get_all_patient_ids())
        pids_parts = np.array_split(pids, num_threads)
        tasks = [(database_path, pid_part, labeled_patients, self.featurizers) for pid_part in pids_parts]

        ctx = multiprocessing.get_context('forkserver')
        with ctx.Pool(num_threads) as pool:
            # imap yields results in task order, so rows stay in sorted-pid order.
            results = list(pool.imap(_run_featurizer, tasks))

        data_matrix_list = []
        result_labels_list = []
        patient_ids_list = []
        labeling_time_list = []
        for result in results:
            data_matrix_list.append(result[0])
            result_labels_list.append(result[1])
            patient_ids_list.append(result[2])
            labeling_time_list.append(result[3])

        data_matrix = scipy.sparse.vstack(data_matrix_list)
        result_labels = np.concatenate(result_labels_list, axis=None)
        patient_ids = np.concatenate(patient_ids_list, axis=None)
        labeling_time = np.concatenate(labeling_time_list, axis=None)

        return (
            data_matrix,
            result_labels,
            patient_ids,
            labeling_time,
        )

    def get_column_name(self, column_index: int) -> str:
        """Return a readable name for a column of the concatenated feature matrix.

        Raises:
            AssertionError: if ``column_index`` is outside ``[0, total columns)``.
        """
        offset = 0

        for featurizer in self.featurizers:
            if offset <= column_index < (offset + featurizer.num_columns()):
                return f"Featurizer {featurizer}, {featurizer.get_column_name(column_index - offset)}"

            offset += featurizer.num_columns()

        # Explicit raise (not `assert False`) so this still fails under `python -O`.
        raise AssertionError(
            f"Column index {column_index} is out of range for {offset} total columns"
        )

In [8]:
class CountFeaturizer(Featurizer):
    """
    Produces one column per code registered during preprocessing (codes of events
    whose ``value`` is ``None``). The value in each column is the number of times
    that code appears in the patient record at or before the corresponding label time.

    NOTE(review): ``Featurizer``, ``Dictionary``, ``ColumnValue`` and
    ``extension_datasets`` are not imported in any visible cell. They presumably
    come from piton and must be in scope before this cell runs.
    """

    def __init__(
        self,
        rollup: bool = False,
        exclusion_codes: List[int] = [],
        time_bins: Optional[
            List[Optional[int]]
        ] = None,  # [90, 180] refers to [0-90, 90-180]; [90, 180, math.inf] refers to [0-90, 90-180, 180-inf]
    ):
        """
        Args:
            rollup: If True, each event code is counted as the codes returned by
                ``ontology.get_all_parents(code)`` instead of as itself.
            exclusion_codes: Codes never counted by ``featurize``. Only read
                (copied into a set), so the shared default list is never mutated.
            time_bins: Intended time binning of the counts. ``featurize`` only
                supports ``None`` in this version.
        """
        self.patient_codes: Dictionary = Dictionary()
        self.exclusion_codes = set(exclusion_codes)
        self.time_bins = time_bins
        self.rollup = rollup

    def get_codes(self, code: int, ontology: "extension_datasets.Ontology") -> Iterator[int]:
        """Yield the code(s) an event code is counted as, or nothing if it is excluded."""
        if code not in self.exclusion_codes:
            if self.rollup:
                for subcode in ontology.get_all_parents(code):
                    yield subcode
            else:
                yield code

    def preprocess(self, patient: "Patient", labels: "List[Label]"):
        """Adds every value-less event code in this patient's timeline to `patient_codes`.

        NOTE(review): rollup and exclusions are not applied here, so with
        ``rollup=True`` a parent code only gets a column if it also occurs
        directly as an event code.
        """
        for event in patient.events:
            if event.value is None:
                self.patient_codes.add(event.code)

    def num_columns(self) -> int:
        """Number of output columns (codes times bins when ``time_bins`` is set)."""
        if self.time_bins is None:
            return len(self.patient_codes)
        else:
            return len(self.time_bins) * len(self.patient_codes)

    def featurize(
        self, patient: "Patient", labels: "List[Label]", ontology: "extension_datasets.Ontology",
    ) -> "List[List[ColumnValue]]":
        """Return, for each label, sparse ``ColumnValue(column, count)`` entries.

        Counts the (rolled-up, non-excluded) codes of value-less events with
        ``event.start <= label.time`` that were registered during preprocessing.
        Assumes ``patient.events`` and ``labels`` are sorted by time (TODO confirm).

        Raises:
            NotImplementedError: if ``time_bins`` is set. Binned counts are not
                implemented here; previously this path silently returned None.
        """
        if self.time_bins is not None:
            raise NotImplementedError("CountFeaturizer.featurize does not implement time_bins")

        all_columns: List[List["ColumnValue"]] = []
        if len(labels) == 0:
            return all_columns

        current_codes: Dict[int, int] = defaultdict(int)

        def snapshot():
            # Counts accumulated so far, as sparse (column, count) pairs.
            return [ColumnValue(column, count) for column, count in current_codes.items()]

        label_idx = 0
        for event in patient.events:
            while event.start > labels[label_idx].time:
                label_idx += 1
                all_columns.append(snapshot())

                if label_idx >= len(labels):
                    return all_columns

            if event.value is not None:
                continue

            for code in self.get_codes(event.code, ontology):
                if code in self.patient_codes:
                    current_codes[self.patient_codes.transform(code)] += 1

        # Labels after the last event all see the final counts.
        for _ in labels[label_idx:]:
            all_columns.append(snapshot())
        return all_columns

KeyError: 0

In [None]:
class TextFeaturizer(Featurizer):