In [1]:
import torch
import os
import datasets
import numpy as np
from collections import defaultdict
# from foresight.datasets import patient_concept_stream
# from foresight.datasets.filters import filter_by_count, filter_by_type
# from foresight.datasets.utils import get_embeddings_for_tokens, stream_to_separate_examples, add_to_stream, \
#                                   remove_parents_from_stream, bucket_concepts, cleanup_stream, \
#                                   split_stream, add_age, get_all_splits, add_ttd, add_position_ids
# from foresight.utils import pickle
# from foresight.utils.cdb_utils import get_parents_map 
# from foresight.utils.stream_utils import docs2stream, calculate_counts
# from foresight.tokenizers.simple_map_tokenizer import SimpleMapTokenizer
# from foresight.metrics.next_concept_prediction import precision, metrics_data2df, ComputePrecisionHF
# from foresight.utils import pickle


In [2]:
from random import Random
from foresight.tokenizers.my_map_tokenizer import MapTokenizer
from datasets import Dataset

## Dummy Data

In [3]:
# Configuration for the synthetic ("dummy") timeline dataset built in the cells below.
NUM_TIMELINES = 1000  # number of timelines in the dummy dataset
LETTERS = "ABCDEFGHIJKLMNOPQRSTUVWXYZ"  # alphabet the timelines are sliced from
NUM_TIMESTEPS = 10  # NOTE(review): not read by any visible cell; SEQUENCE_LENGTH sets the timeline length
RANDOM_SEED = Random(23)  # despite the name, a seeded random.Random instance (shared RNG), not an integer seed
MAX_SKIP = 10  # NOTE(review): not read by any visible cell
MAX_NUM_SAMPLES = 3  # inclusive upper bound for the letters-per-timestep value drawn in get_samples

SEQUENCE_LENGTH = 12  # number of timesteps generated per timeline
SEPARATOR_TOKEN = "<SEP>"  # token emitted in front of every timestep when a timeline is flattened
PADDING_TOKEN = '<PAD>'  # NOTE(review): defined but never added to the vocabulary in the visible cells - confirm

In [4]:
def get_samples():
    """Yield synthetic timelines of consecutive letters for the dummy dataset.

    Every timeline draws three static features from the shared ``RANDOM_SEED`` RNG:

    * ``start_idx``   - index into ``LETTERS`` of the first letter,
      in ``0 .. len(LETTERS) // 2`` (inclusive);
    * ``skip_odd``    - when True, timesteps whose first letter has an odd index
      are left empty;
    * ``num_samples`` - letters per timestep, in ``1 .. MAX_NUM_SAMPLES`` (inclusive).

    Timestep ``k`` then holds ``LETTERS[c : c + num_samples]`` with
    ``c = start_idx + num_samples * k``, or an empty list when it is skipped or
    reaches the end of the alphabet.

    Yields:
        dict with the keys ``"timeline"`` (a list of ``SEQUENCE_LENGTH`` lists of
        single-letter strings), ``"start_idx"``, ``"skip_odd"`` and ``"num_samples"``.
    """
    for _ in range(NUM_TIMELINES):
        # Floor division keeps the bound an int: randint() rejects float bounds
        # on Python 3.12+ (and warns on 3.10/3.11).
        start_idx = RANDOM_SEED.randint(0, len(LETTERS) // 2)
        skip_odd = RANDOM_SEED.choice([True, False])
        num_samples = RANDOM_SEED.randint(1, MAX_NUM_SAMPLES)

        timeline = []
        for seq_idx in range(SEQUENCE_LENGTH):
            char_idx = start_idx + num_samples * seq_idx
            skipped = skip_odd and char_idx % 2 == 1
            # NOTE(review): `>=` also blanks a window that ends exactly on the last
            # letter, so 'Z' is never emitted - confirm this is intended.
            past_end = char_idx + num_samples >= len(LETTERS)
            if skipped or past_end:
                timeline.append([])
            else:
                timeline.append(list(LETTERS[char_idx:char_idx + num_samples]))

        yield {
            "timeline": timeline,
            "start_idx": start_idx,
            "skip_odd": skip_odd,
            "num_samples": num_samples,
        }

In [5]:
dataset = Dataset.from_generator(get_samples)

In [6]:
dataset[0]

{'timeline': [['M'],
  ['N'],
  ['O'],
  ['P'],
  ['Q'],
  ['R'],
  ['S'],
  ['T'],
  ['U'],
  ['V'],
  ['W'],
  ['X']],
 'start_idx': 12,
 'skip_odd': False,
 'num_samples': 1}

In [7]:
def batched_timeline_to_tokens(batched_samples: dict[str, list], separator: str)->dict[str, list]:
    """Flatten each timeline of a batch into one token list, stored under "tokens".

    Args:
        batched_samples: Column-oriented batch holding a "timeline" column; every
            timeline is a list of timesteps and every timestep a list of tokens.
        separator: Token placed in front of every timestep (empty ones included).

    Returns:
        The same ``batched_samples`` mapping, with the "tokens" column added.
    """
    flattened_timelines = []
    for timeline in batched_samples["timeline"]:
        flat_tokens = []
        for timestep in timeline:
            # Separator first, then whatever the timestep contains (possibly nothing).
            flat_tokens.extend([separator] + timestep)
        flattened_timelines.append(flat_tokens)
    batched_samples["tokens"] = flattened_timelines
    return batched_samples

In [8]:
# Flatten every timeline into a token list, with SEPARATOR_TOKEN in front of each timestep.
dataset = dataset.map(lambda batch: batched_timeline_to_tokens(batch, SEPARATOR_TOKEN), batched=True)
# Show the first ten timesteps next to the first ten tokens of sample 0.
dataset[0]["timeline"][:10], dataset[0]["tokens"][:10]

([['M'], ['N'], ['O'], ['P'], ['Q'], ['R'], ['S'], ['T'], ['U'], ['V']],
 ['<SEP>', 'M', '<SEP>', 'N', '<SEP>', 'O', '<SEP>', 'P', '<SEP>', 'Q'])

In [9]:
def batched_insert_static_feature_token(batched_samples: dict[str, list], key:str, insert_idx:int)->dict[str, list]:
    """Insert a ``"<key>_<value>"`` token into every sample's token list, in place.

    Args:
        batched_samples: Column-oriented batch holding a "tokens" column and a
            ``key`` column with one value per sample.
        key: Name of the column whose per-sample value is turned into a token.
        insert_idx: Index handed to ``list.insert`` for the new token.

    Returns:
        The same ``batched_samples`` mapping; its "tokens" lists are mutated.
    """
    token_lists = batched_samples["tokens"]
    for row in range(len(token_lists)):
        feature_token = f"{key}_{batched_samples[key][row]}"
        token_lists[row].insert(insert_idx, feature_token)
    return batched_samples


# Put the three static features at positions 0, 1 and 2, i.e. ahead of the first separator.
# NOTE(review): this cell rebinds `dataset` from itself, so running it a second time
# inserts the same three feature tokens again.
dataset = dataset.map(lambda batch: batched_insert_static_feature_token(batch, "start_idx", insert_idx=0), batched=True)
dataset = dataset.map(lambda batch: batched_insert_static_feature_token(batch, "skip_odd", insert_idx=1), batched=True)
dataset = dataset.map(lambda batch: batched_insert_static_feature_token(batch, "num_samples", insert_idx=2), batched=True)
dataset[0]["tokens"][:10]

['start_idx_12',
 'skip_odd_False',
 'num_samples_1',
 '<SEP>',
 'M',
 '<SEP>',
 'N',
 '<SEP>',
 'O',
 '<SEP>']

## Add position IDs

In [10]:
def batched_add_position_ids(batched_samples: dict[str, list], separators:set[str])->dict[str, list]:
    """Attach a "position_ids" column that counts separators seen so far.

    A token's position id is the number of separator tokens at or before it in
    the same sample, so tokens ahead of the first separator get 0 and each
    separator opens the next position.

    Args:
        batched_samples: Column-oriented batch holding a "tokens" column.
        separators: Tokens that advance the position counter.

    Returns:
        The same ``batched_samples`` mapping, with "position_ids" added
        (one list per sample, same length as that sample's tokens).
    """
    all_position_ids = batched_samples["position_ids"] = []
    for tokens in batched_samples['tokens']:
        separators_seen = 0
        sample_ids = []
        for token in tokens:
            # A separator belongs to the position it opens, so count before recording.
            separators_seen += int(token in separators)
            sample_ids.append(separators_seen)
        all_position_ids.append(sample_ids)
    return batched_samples

# Number the tokens by timestep, treating SEPARATOR_TOKEN as the only separator.
dataset = dataset.map(lambda batch: batched_add_position_ids(batch, {SEPARATOR_TOKEN}), batched=True)
# Show the first ten (token, position id) pairs of sample 0.
[(dataset[0]["tokens"][idx], dataset[0]["position_ids"][idx]) for idx in range(10)]

[('start_idx_12', 0),
 ('skip_odd_False', 0),
 ('num_samples_1', 0),
 ('<SEP>', 1),
 ('M', 1),
 ('<SEP>', 2),
 ('N', 2),
 ('<SEP>', 3),
 ('O', 3),
 ('<SEP>', 4)]

In [11]:
# Hold out 20% of the timelines for evaluation. The split shuffles the rows, so a fixed
# seed is passed to keep the split - and the vocabulary later built from the training
# split - identical across kernel restarts.
dataset = dataset.train_test_split(test_size=0.2, seed=23)
len(dataset["train"]), len(dataset["test"])

(800, 200)

# Make tokenizer

In [12]:
from collections import Counter

# Count how often each token occurs, over the training split only.
token_count = Counter(token for tokens in dataset["train"]["tokens"] for token in tokens)
token_count

Counter({'<SEP>': 9600,
         'O': 642,
         'M': 634,
         'Q': 590,
         'S': 580,
         'U': 542,
         'K': 541,
         'N': 504,
         'W': 481,
         'P': 467,
         'T': 452,
         'R': 449,
         'I': 442,
         'L': 432,
         'V': 419,
         'skip_odd_False': 402,
         'skip_odd_True': 398,
         'X': 369,
         'J': 356,
         'G': 336,
         'H': 291,
         'num_samples_3': 270,
         'num_samples_2': 265,
         'num_samples_1': 265,
         'E': 240,
         'F': 210,
         'Y': 181,
         'C': 154,
         'D': 133,
         'B': 80,
         'start_idx_10': 73,
         'start_idx_0': 67,
         'A': 67,
         'start_idx_12': 63,
         'start_idx_9': 62,
         'start_idx_8': 58,
         'start_idx_5': 57,
         'start_idx_2': 56,
         'start_idx_4': 56,
         'start_idx_3': 55,
         'start_idx_11': 54,
         'start_idx_7': 54,
         'start_idx_13': 51,
       

In [13]:
# Give every distinct training token an integer id, in the order the Counter first saw it.
# NOTE(review): the vocabulary comes from the training split only, so a token that occurs
# solely in the test split gets no id, and PADDING_TOKEN ('<PAD>') is never added -
# confirm MapTokenizer copes with both.
tkn_to_id = {tkn: idx for idx, tkn in enumerate(token_count.keys())}
tkn_to_id

{'start_idx_2': 0,
 'skip_odd_True': 1,
 'num_samples_2': 2,
 '<SEP>': 3,
 'C': 4,
 'D': 5,
 'E': 6,
 'F': 7,
 'G': 8,
 'H': 9,
 'I': 10,
 'J': 11,
 'K': 12,
 'L': 13,
 'M': 14,
 'N': 15,
 'O': 16,
 'P': 17,
 'Q': 18,
 'R': 19,
 'S': 20,
 'T': 21,
 'U': 22,
 'V': 23,
 'W': 24,
 'X': 25,
 'start_idx_10': 26,
 'skip_odd_False': 27,
 'num_samples_3': 28,
 'Y': 29,
 'start_idx_11': 30,
 'start_idx_4': 31,
 'start_idx_8': 32,
 'start_idx_7': 33,
 'start_idx_1': 34,
 'start_idx_12': 35,
 'num_samples_1': 36,
 'start_idx_5': 37,
 'B': 38,
 'start_idx_9': 39,
 'start_idx_13': 40,
 'start_idx_6': 41,
 'start_idx_0': 42,
 'A': 43,
 'start_idx_3': 44}

In [14]:
tokenizer = MapTokenizer(tkn_to_id)


In [24]:
ids = tokenizer(dataset["train"]["tokens"][0], is_split_into_words=True)
ids

{'input_ids': [0, 1, 2, 3, 4, 5, 3, 6, 7, 3, 8, 9, 3, 10, 11, 3, 12, 13, 3, 14, 15, 3, 16, 17, 3, 18, 19, 3, 20, 21, 3, 22, 23, 3, 24, 25, 3], 'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}

In [23]:
tokenizer.convert_ids_to_tokens(ids["input_ids"]) == dataset["train"]["tokens"][0]

True

In [None]:
dataset["train"]["tokens"][0]

In [None]:
tokenizer(dataset["train"]["tokens"][0])

In [None]:
from transformers import AutoTokenizer

AutoTokenizer.from_pretrained("bert-base-cased").encode_plus(["hello", "world"])

In [None]:
tokenizer(dataset["train"]["tokens"][0])

In [None]:
# NOTE(review): this cell looks like a leftover from another notebook. The SimpleMapTokenizer
# import is commented out in the first cell, and tkn2id, tkn2name, token_type2tokens,
# embeddings, token_cnt and MAX_SEQ_LEN are not defined in any earlier visible cell
# (tkn2id, embeddings and tkn2name are only assigned in later cells). The same call is
# repeated in two later cells. Also rebinds `tokenizer`, replacing the MapTokenizer built above.
tokenizer = SimpleMapTokenizer(tkn2id=tkn2id, pad_id=tkn2id['<PAD>'], tkn2name=tkn2name,
                               token_type2tokens=token_type2tokens, embeddings=embeddings,
                               global_token_cnt=token_cnt, max_len=MAX_SEQ_LEN)

In [None]:
_types = list(cdb.addl_info['type_id2name'].keys()) + list(token_type2tokens.keys())
embeddings, tkn2id, id2tkn, = get_embeddings_for_tokens(dataset, cdb, context_type='xlong', types=_types,
                                                        concepts=extra_concepts)

In [None]:
tkn2name = {tkn:cdb.get_name(tkn) for tkn in tkn2id.keys()}
tokenizer = SimpleMapTokenizer(tkn2id=tkn2id, pad_id=tkn2id['<PAD>'], tkn2name=tkn2name,
                               token_type2tokens=token_type2tokens, embeddings=embeddings,
                               global_token_cnt=token_cnt, max_len=MAX_SEQ_LEN)

In [None]:
# Sanity checks: the tokenizer's tkn2id, embeddings and tkn2name must each have as many
# entries as id2tkn.
assert len(tokenizer.tkn2id) == len(tokenizer.id2tkn)
assert len(tokenizer.embeddings) == len(tokenizer.id2tkn)
assert len(tokenizer.tkn2name) == len(tokenizer.id2tkn)
# Show the padding id together with the token it maps back to.
# NOTE(review): fprint is not defined or imported in any visible cell.
fprint(tokenizer.pad_id, tokenizer.id2tkn[tokenizer.pad_id])

In [None]:
len(tokenizer.tkn2name)

In [None]:
# save
tokenizer.save(TOKENIZER_PATH)

In [None]:
# Total number of different concepts after all filtering
fprint("Total number of concepts after filtering: ", len(tokenizer.tkn2id))
fprint("")

In [None]:
# Total number annotations after all filtering
fprint("Total number of annotations after filtering: ", sum([x for x in cnt_per_type_after.values()]))
fprint("")

# Print number of different concepts per type after filtering

In [None]:
# Count how many distinct vocabulary tokens fall under each concept type.
# NOTE(review): `cat`, `tkn2id` and `fprint` are not defined in any visible cell, and this
# cell reads `cat.cdb` whereas other cells use a bare `cdb` - confirm they are the same object.
cnt_per_type = {}
for cui in tkn2id:
    # Tokens without an entry fall back to ['Other']; an entry that is empty is falsy and skipped.
    if cat.cdb.cui2type_ids.get(cui, ['Other']):
        # Only the first type id (in iteration order) is counted when there are several.
        t = list(cat.cdb.cui2type_ids.get(cui, ['Other']))[0]
        cnt_per_type[t] = cnt_per_type.get(t, 0) + 1
fprint("Total number of <<different>> concepts per type after filtering")
for t in cnt_per_type:
    # Print the type's display name (title-cased); fall back to the raw id when it has no name.
    fprint("{:30}: {}".format(cat.cdb.addl_info['type_id2name'].get(t, t).title(), cnt_per_type[t]))
fprint("")

# Create global tokenizer

In [None]:
_types = list(cdb.addl_info['type_id2name'].keys()) + list(token_type2tokens.keys())
concepts = list(cat.config.linking['filters']['cuis'])
embeddings, tkn2id, id2tkn, = get_embeddings_for_tokens(dataset, cdb, context_type='xlong', types=_types, concepts=concepts)

In [None]:
tkn2name = {tkn:cdb.get_name(tkn) for tkn in tkn2id.keys()}
tokenizer = SimpleMapTokenizer(tkn2id=tkn2id, pad_id=tkn2id['<PAD>'], tkn2name=tkn2name,
                               token_type2tokens=token_type2tokens, embeddings=embeddings,
                               global_token_cnt=token_cnt, max_len=MAX_SEQ_LEN)

In [None]:
tokenizer.save(BASE_TOKENIZER_PATH)

# Convert tokens to IDs

In [None]:
if FROM_BASE:
    print("USING BASE TOKENIZER")
    TOKENIZER_PATH = BASE_TOKENIZER_PATH

In [None]:
tokenizer =  SimpleMapTokenizer.load(TOKENIZER_PATH)

In [None]:
# Encode the dataset in batches with the loaded tokenizer and drop the 'stream' column.
# load_from_cache_file=False forces a recompute instead of reusing a cached result.
# NOTE(review): the dummy dataset built above has no 'stream' column (it has timeline,
# start_idx, skip_odd, num_samples, tokens, position_ids), and NUM_PROC is not defined
# in any visible cell - this cell appears to target a different dataset.
encoded_dataset = dataset.map(
        lambda examples: tokenizer.encode(examples),
        batched=True,
        remove_columns=['stream'],
        load_from_cache_file=False,
        num_proc=NUM_PROC)

In [None]:
encoded_dataset.save_to_disk(PREPARED_DATASET_SPLIT_PATH)

In [None]:
PREPARED_DATASET_SPLIT_PATH

In [None]:
TOKENIZER_PATH

# Test that all is OK

In [None]:
encoded_dataset = datasets.load_from_disk(PREPARED_DATASET_SPLIT_PATH)

In [None]:
dataset = datasets.load_from_disk(JUST_BEFORE_ENCODING_DATASET_SPLIT_PATH)

In [None]:
tokenizer = SimpleMapTokenizer.load(TOKENIZER_PATH)

In [None]:
encoded_dataset

In [None]:
dataset

In [None]:
ind = 1096

In [None]:
from datetime import datetime

In [None]:
[cdb.get_name(x) for x in dataset['train'][ind]['stream']]

In [None]:
# Print one line per token of training example `ind`: timestamp, position id, token type
# and the token decoded from its input id.
# `time` values go through datetime.fromtimestamp, i.e. they are read as POSIX timestamps
# and shown in local time.
for ty, p, t, c in zip(encoded_dataset['train'][ind]['token_type'], encoded_dataset['train'][ind]['position_ids'], encoded_dataset['train'][ind]['time'], tokenizer.convert_ids2tokens(encoded_dataset['train'][ind]['input_ids'])):
    print(datetime.fromtimestamp(t), p, "{:20}".format(ty), c)

In [None]:
encoded_dataset['train'][ind]['patient_id']

In [None]:
ds_info.close()

# Prepare for Foresight

In [None]:
ind = 32330

In [None]:
import json

In [None]:
[cdb.get_name(x) for x in dataset['train'][ind]['stream']]

In [None]:
# Report concepts of training example `ind` that first appear after position 20.
# The row is fetched once up front: indexing the dataset inside the loop would look up
# the same example again on every iteration.
stream = dataset['train'][ind]['stream']
for i, c in enumerate(stream):
    print(i)  # running position in the stream
    # "New" means the concept does not occur anywhere earlier in the stream.
    if i > 20 and c not in stream[:i]:
        print(i, c, cdb.get_name(c))

In [None]:
# Build one record per concept for the first 161 entries of the selected stream.
# 'id' and 'cui' both carry the concept id, 'label' and 'name' both carry its name from
# cdb.get_name, and 'uid' is the position in the stream.
# NOTE(review): the slice length 161 and the constant 'count' / 'saliency' values look like
# placeholders for whatever consumes this JSON - confirm.
out = []
for i, cui in enumerate(dataset['train'][ind]['stream'][:161]):
    d = {
        'id': cui,
        'label': cdb.get_name(cui),
        'count': 1000000,
        'name': cdb.get_name(cui),
        'cui': cui,
        'saliency': 0,
        'uid': i
    }
    out.append(d)

In [None]:
# Write the example timeline as JSON. The context manager flushes and closes the file
# as soon as the dump finishes instead of leaving the handle open.
with open("./data/tmp/timeline_example_1.json", 'w') as f:
    json.dump(out, f)

In [None]:
len(out)

In [None]:
out