In [113]:
import torch
import os
import datasets
import numpy as np
from collections import defaultdict
# from foresight.datasets import patient_concept_stream
# from foresight.datasets.filters import filter_by_count, filter_by_type
# from foresight.datasets.utils import get_embeddings_for_tokens, stream_to_separate_examples, add_to_stream, \
#                                   remove_parents_from_stream, bucket_concepts, cleanup_stream, \
#                                   split_stream, add_age, get_all_splits, add_ttd, add_position_ids
# from foresight.utils import pickle
# from foresight.utils.cdb_utils import get_parents_map 
# from foresight.utils.stream_utils import docs2stream, calculate_counts
# from foresight.tokenizers.simple_map_tokenizer import SimpleMapTokenizer
# from foresight.metrics.next_concept_prediction import precision, metrics_data2df, ComputePrecisionHF
# from foresight.utils import pickle


In [114]:
from random import Random
from pydantic import BaseModel
from datasets import Dataset

## Dummy Data

In [115]:
# Configuration for the synthetic ("dummy") timeline dataset built in the cells below.
NUM_TIMELINES = 1000  # presumably the number of timelines to generate -- TODO confirm
LETTERS = "ABCDEFGHIJKLMNOPQRSTUVWXYZ"  # alphabet combined pairwise into the two-letter vocabulary
NUM_TIMESTEPS = 10  # NOTE(review): not referenced in any visible cell
RANDOM_SEED = Random(23)  # despite the name, a seeded random.Random instance rather than an integer seed
MAX_SKIP = 10  # NOTE(review): not referenced in any visible cell
MAX_NUM_SAMPLES = 3  # inclusive upper bound of the per-timeline `num_samples` draw in get_samples

SEQUENCE_LENGTH = 12*5  # number of timesteps generated per timeline (60)
SEPARATOR = "<SEP>"  # marker token passed to batched_timeline_to_tokens when timelines are flattened

In [116]:
class DummySample(BaseModel):
    """Pydantic schema for one synthetic sample.

    NOTE(review): this model is not referenced by any other visible cell, and
    its fields do not match what ``get_samples`` yields (that generator emits
    ``start_idx`` / ``skip_odd`` instead of ``gap`` / ``uppercase``, and a
    ``timeline`` that is a list of lists rather than ``list[str]``). Confirm
    whether it is still needed or should be aligned with the generator.
    """
    timeline: list[str]  # presumably a flat token sequence -- TODO confirm
    gap: int  # meaning not established by the visible code
    num_samples: int  # presumably tokens per timestep, as in get_samples -- TODO confirm
    uppercase: bool  # meaning not established by the visible code

In [117]:
# Two-letter vocabulary: every ordered pair of LETTERS ("AA", "AB", ..., "ZZ").
double_letters = [first + second for first in LETTERS for second in LETTERS]

In [118]:
def get_samples():
    """Yield synthetic timelines for the dummy dataset.

    Each timeline starts at a random index into ``double_letters`` and advances
    ``num_samples`` entries per timestep for ``SEQUENCE_LENGTH`` timesteps. A
    timestep is emitted as an empty list when ``skip_odd`` is set and its start
    index is odd, or when its window would run past the end of the vocabulary.

    Random draws come from the shared ``RANDOM_SEED`` generator, so the output
    depends on how much of that generator has already been consumed.

    Yields:
        dict with keys ``timeline`` (one list of tokens per timestep),
        ``start_idx``, ``skip_odd`` and ``num_samples``.
    """
    vocab_size = len(double_letters)
    for _ in range(NUM_TIMELINES):
        # Draw order (start_idx, skip_odd, num_samples) is kept fixed so a given
        # seed keeps producing the same parameters.
        start_idx = RANDOM_SEED.randint(0, vocab_size - 1)
        skip_odd = RANDOM_SEED.choice([True, False])
        num_samples = RANDOM_SEED.randint(1, MAX_NUM_SAMPLES)

        timeline = []
        for seq_idx in range(SEQUENCE_LENGTH):
            char_idx = start_idx + num_samples * seq_idx
            window_end = char_idx + num_samples
            if skip_odd and char_idx % 2 == 1:
                timeline.append([])
            elif window_end > vocab_size:
                # The slice end is exclusive, so a window ending exactly at
                # vocab_size is still complete; only longer ones are dropped.
                timeline.append([])
            else:
                timeline.append(double_letters[char_idx:window_end])

        yield {
            "timeline": timeline,
            "start_idx": start_idx,
            "skip_odd": skip_odd,
            "num_samples": num_samples,
        }

In [119]:
dataset = Dataset.from_generator(get_samples)

In [120]:
dataset[0]

{'timeline': [['LK'],
  [],
  ['LM'],
  [],
  ['LO'],
  [],
  ['LQ'],
  [],
  ['LS'],
  [],
  ['LU'],
  [],
  ['LW'],
  [],
  ['LY'],
  [],
  ['MA'],
  [],
  ['MC'],
  [],
  ['ME'],
  [],
  ['MG'],
  [],
  ['MI'],
  [],
  ['MK'],
  [],
  ['MM'],
  [],
  ['MO'],
  [],
  ['MQ'],
  [],
  ['MS'],
  [],
  ['MU'],
  [],
  ['MW'],
  [],
  ['MY'],
  [],
  ['NA'],
  [],
  ['NC'],
  [],
  ['NE'],
  [],
  ['NG'],
  [],
  ['NI'],
  [],
  ['NK'],
  [],
  ['NM'],
  [],
  ['NO'],
  [],
  ['NQ'],
  []],
 'start_idx': 296,
 'skip_odd': True,
 'num_samples': 1}

In [121]:
def batched_timeline_to_tokens(batched_samples: dict[str, list], separator: str)->dict[str, list]:
    """Flatten every timeline in a batch into a single token list.

    Each timestep contributes ``separator`` followed by that timestep's values,
    so an empty timestep still leaves one separator behind. The flattened lists
    are stored under ``"tokens"`` on the incoming batch, which is also returned.
    """
    flattened = []
    for timeline in batched_samples["timeline"]:
        tokens = []
        for timestep in timeline:
            tokens.extend([separator] + timestep)
        flattened.append(tokens)
    batched_samples["tokens"] = flattened
    return batched_samples

In [122]:
dataset = dataset.map(lambda batch: batched_timeline_to_tokens(batch, SEPARATOR), batched=True)


Map:   0%|          | 0/1000 [00:00<?, ? examples/s]

In [124]:
def batched_insert_static_feature_token(batched_samples: dict[str, list], key:str, insert_idx:int)->dict[str, list]:
    """Insert a per-sample static feature into each token list, in place.

    For every sample, the value found in column ``key`` is converted with
    ``str`` and inserted into that sample's ``"tokens"`` list at position
    ``insert_idx``. The batch is mutated and then returned.
    """
    for row, tokens in enumerate(batched_samples["tokens"]):
        feature_text = str(batched_samples[key][row])
        tokens.insert(insert_idx, feature_text)
    return batched_samples


# Prepend the static features so every token list starts with [start_idx, skip_odd, ...].
# NOTE: this cell is not idempotent -- `dataset` is rebound to the mapped result,
# so running the cell again inserts both feature tokens a second time.
dataset = dataset.map(lambda batch: batched_insert_static_feature_token(batch, "start_idx", insert_idx=0), batched=True)
dataset = dataset.map(lambda batch: batched_insert_static_feature_token(batch, "skip_odd", insert_idx=1), batched=True)

Map:   0%|          | 0/1000 [00:00<?, ? examples/s]

Map:   0%|          | 0/1000 [00:00<?, ? examples/s]

In [125]:
dataset[0]["tokens"][:10]

['296', 'True', '<SEP>', 'LK', '<SEP>', '<SEP>', 'LM', '<SEP>', '<SEP>', 'LO']

## Add position IDs

In [None]:
def batched_add_position_ids(batched_samples: dict[str, list], separators:set[str], tokens_key:str="tokens")->dict[str, list]:
    """Add a ``position_ids`` column that numbers the timesteps of each sample.

    Every token receives the count of separator tokens that precede it, so a
    separator keeps the position of the tokens before it and the tokens that
    follow it move on to the next position.

    Args:
        batched_samples: batch of columns; must contain ``tokens_key``.
        separators: tokens that mark a timestep boundary.
        tokens_key: column holding the per-sample token lists. Defaults to
            ``"tokens"``, the column this notebook's dataset actually has.

    Returns:
        The same batch, with ``"position_ids"`` set to one list per sample,
        aligned element-for-element with the token lists.
    """
    all_position_ids = []
    for tokens in batched_samples[tokens_key]:
        position_ids = []
        position = 0
        for token in tokens:
            position_ids.append(position)
            # Increment after recording, so the separator itself stays on the
            # position of the timestep it closes.
            if token in separators:
                position += 1
        all_position_ids.append(position_ids)
    batched_samples["position_ids"] = all_position_ids
    return batched_samples

dataset = dataset.map(lambda batch: batched_add_position_ids(batch, {SEPARATOR}), batched=True)

In [None]:
# NOTE(review): `add_position_ids` and `NUM_PROC` are not defined in any visible
# cell (`add_position_ids` appears only in the commented-out
# foresight.datasets.utils import), so this cell fails on a fresh kernel.
# Each separator must be its own set element: adjacent string literals without a
# comma between them are concatenated into a single string.
dataset = dataset.map(
        lambda examples: add_position_ids(examples, separators={'<SEP>', '<SEP-1>', '<SEP-7>', '<SEP-14>', '<SEP-30>', '<SEP-365>'}),
        batched=True,
        load_from_cache_file=False,
        num_proc=NUM_PROC)

# Get token_type2tokens

In [None]:
# Collect, for every token type, the set of distinct tokens seen across all splits,
# and count the total number of annotations along the way.
# NOTE(review): `get_all_splits` is only available through the commented-out
# foresight.datasets.utils import, and this loop expects a 'stream' column whose
# items are dicts with 'token_type' and 'token' keys -- confirm against the
# dataset built above.
token_type2tokens = defaultdict(set)
total_cnt = 0
for _dataset in get_all_splits(dataset):
    for stream in _dataset['stream']:
        for example in stream:
            token_type2tokens[example['token_type']].add(example['token'])
            total_cnt += 1
# Convert to a plain dict, dropping the defaultdict auto-insert behaviour.
# Saving the map and printing the total are done by the following cell.
token_type2tokens = dict(token_type2tokens)

In [None]:
pickle.dump(token_type2tokens, TOKEN_TYPES_PATH)
fprint("Total number of annotations: ", total_cnt)

# Cleanup stream and leave only what we need

In [None]:
dataset = dataset.map(
        lambda examples: cleanup_stream(examples, keep_time=True, keep_type=True, keep_position_ids=True,
                                        keep_context_representation=False),
        batched=True,
        load_from_cache_file=False,
        num_proc=NUM_PROC)

### Save

In [None]:
dataset.save_to_disk(JUST_BEFORE_ENCODING_DATASET_SPLIT_PATH)

In [None]:
dataset = datasets.load_from_disk(JUST_BEFORE_ENCODING_DATASET_SPLIT_PATH)

In [None]:
JUST_BEFORE_ENCODING_DATASET_SPLIT_PATH

In [None]:
# Total number of patients after initial filtering
# NOTE(review): `fprint` is not defined in any visible cell.
train_len = len(dataset['train'])
test_len = len(dataset['test'])
fprint("Total number of pts in train: ", train_len)
fprint("Total number of pts in test: ", test_len)
fprint("Total number of pts: ", train_len + test_len)

In [None]:
# Total number of annotations per type after filtering.
# Only the first type id of each concept is counted; concepts with no type ids are skipped.
cnt_per_type_after = {}
for _dataset in get_all_splits(dataset):
    for stream in _dataset['stream']:
        for cui in stream:
            type_ids = cat.cdb.cui2type_ids.get(cui, None)
            if type_ids:
                t = list(type_ids)[0]
                cnt_per_type_after[t] = cnt_per_type_after.get(t, 0) + 1

In [None]:
fprint("Total number of annotations per type: \n")
for t in cnt_per_type_after:
    fprint("{:30}: {}".format(cat.cdb.addl_info['type_id2name'][t].title(), cnt_per_type_after[t]))

# Make tokenizer

In [None]:
extra_tokenizer = None
#extra_tokenizer = SimpleMapTokenizer.load("./data/time/models/slam_tokenizer_annotations_stream_phase2_1d_200_ALL_TYPES.pickle")

In [None]:
# Reload the token-type map and, when an extra tokenizer is supplied, merge its
# concepts and token types into it.
# NOTE(review): `pickle` here is presumably foresight.utils.pickle, whose import
# is commented out at the top of the notebook -- confirm.
token_type2tokens = pickle.load(TOKEN_TYPES_PATH)
extra_concepts = None
if extra_tokenizer is not None:
    extra_concepts = list(extra_tokenizer.tkn2id.keys())

    for k, v in extra_tokenizer.token_type2tokens.items():
        if k not in token_type2tokens:
            # New type: the extra tokenizer's own collection is stored as-is (shared, not copied).
            token_type2tokens[k] = v
        else:
            token_type2tokens[k].update(v)

In [None]:
_types = list(cdb.addl_info['type_id2name'].keys()) + list(token_type2tokens.keys())
embeddings, tkn2id, id2tkn, = get_embeddings_for_tokens(dataset, cdb, context_type='xlong', types=_types,
                                                        concepts=extra_concepts)

In [None]:
# Build the tokenizer from the token maps and embeddings computed in the previous cell.
# NOTE(review): `SimpleMapTokenizer`, `cdb`, `token_cnt` and `MAX_SEQ_LEN` are not
# defined in any visible cell (the SimpleMapTokenizer import is commented out),
# so this cell fails on a fresh kernel.
tkn2name = {tkn:cdb.get_name(tkn) for tkn in tkn2id.keys()}
tokenizer = SimpleMapTokenizer(tkn2id=tkn2id, pad_id=tkn2id['<PAD>'], tkn2name=tkn2name,
                               token_type2tokens=token_type2tokens, embeddings=embeddings,
                               global_token_cnt=token_cnt, max_len=MAX_SEQ_LEN)

In [None]:
assert len(tokenizer.tkn2id) == len(tokenizer.id2tkn)
assert len(tokenizer.embeddings) == len(tokenizer.id2tkn)
assert len(tokenizer.tkn2name) == len(tokenizer.id2tkn)
fprint(tokenizer.pad_id, tokenizer.id2tkn[tokenizer.pad_id])

In [None]:
len(tokenizer.tkn2name)

In [None]:
# save
tokenizer.save(TOKENIZER_PATH)

In [None]:
# Total number of different concepts after all filtering
fprint("Total number of concepts after filtering: ", len(tokenizer.tkn2id))
fprint("")

In [None]:
# Total number of annotations after all filtering (sum over the per-type counts)
fprint("Total number of annotations after filtering: ", sum(cnt_per_type_after.values()))
fprint("")

# Print number of different concepts per type after filtering

In [None]:
# Tally distinct tokens per type; tokens with no entry in cui2type_ids fall under 'Other',
# and only the first type id of each token is counted.
cnt_per_type = {}
for cui in tkn2id:
    type_ids = cat.cdb.cui2type_ids.get(cui, ['Other'])
    if type_ids:
        t = list(type_ids)[0]
        cnt_per_type[t] = cnt_per_type.get(t, 0) + 1
fprint("Total number of <<different>> concepts per type after filtering")
for t in cnt_per_type:
    # Types without a registered name are printed using the raw type id.
    fprint("{:30}: {}".format(cat.cdb.addl_info['type_id2name'].get(t, t).title(), cnt_per_type[t]))
fprint("")

# Create global tokenizer

In [None]:
_types = list(cdb.addl_info['type_id2name'].keys()) + list(token_type2tokens.keys())
concepts = list(cat.config.linking['filters']['cuis'])
embeddings, tkn2id, id2tkn, = get_embeddings_for_tokens(dataset, cdb, context_type='xlong', types=_types, concepts=concepts)

In [None]:
tkn2name = {tkn:cdb.get_name(tkn) for tkn in tkn2id.keys()}
tokenizer = SimpleMapTokenizer(tkn2id=tkn2id, pad_id=tkn2id['<PAD>'], tkn2name=tkn2name,
                               token_type2tokens=token_type2tokens, embeddings=embeddings,
                               global_token_cnt=token_cnt, max_len=MAX_SEQ_LEN)

In [None]:
tokenizer.save(BASE_TOKENIZER_PATH)

# Convert tokens to IDs

In [None]:
# When FROM_BASE is set, rebind TOKENIZER_PATH so the cells below load the base
# (global) tokenizer instead of the dataset-specific one.
# NOTE(review): `FROM_BASE`, `TOKENIZER_PATH` and `BASE_TOKENIZER_PATH` are not
# defined in any visible cell; rebinding a config constant mid-notebook also
# makes later cells depend on whether this cell was run.
if FROM_BASE:
    print("USING BASE TOKENIZER")
    TOKENIZER_PATH = BASE_TOKENIZER_PATH

In [None]:
tokenizer =  SimpleMapTokenizer.load(TOKENIZER_PATH)

In [None]:
encoded_dataset = dataset.map(
        lambda examples: tokenizer.encode(examples),
        batched=True,
        remove_columns=['stream'],
        load_from_cache_file=False,
        num_proc=NUM_PROC)

In [None]:
encoded_dataset.save_to_disk(PREPARED_DATASET_SPLIT_PATH)

In [None]:
PREPARED_DATASET_SPLIT_PATH

In [None]:
TOKENIZER_PATH

# Test that all is OK

In [None]:
encoded_dataset = datasets.load_from_disk(PREPARED_DATASET_SPLIT_PATH)

In [None]:
dataset = datasets.load_from_disk(JUST_BEFORE_ENCODING_DATASET_SPLIT_PATH)

In [None]:
tokenizer = SimpleMapTokenizer.load(TOKENIZER_PATH)

In [None]:
encoded_dataset

In [None]:
dataset

In [None]:
ind = 1096

In [None]:
from datetime import datetime

In [None]:
[cdb.get_name(x) for x in dataset['train'][ind]['stream']]

In [None]:
# Look the encoded row up once and reuse it, rather than indexing the dataset
# separately for each of the four columns.
encoded_example = encoded_dataset['train'][ind]
decoded_tokens = tokenizer.convert_ids2tokens(encoded_example['input_ids'])
# One line per token: timestamp, position id, token type (padded to 20 chars), token.
for ty, p, t, c in zip(encoded_example['token_type'], encoded_example['position_ids'], encoded_example['time'], decoded_tokens):
    print(datetime.fromtimestamp(t), p, "{:20}".format(ty), c)

In [None]:
encoded_dataset['train'][ind]['patient_id']

In [None]:
ds_info.close()

# Prepare for Foresight

In [None]:
ind = 32330

In [None]:
import json

In [None]:
[cdb.get_name(x) for x in dataset['train'][ind]['stream']]

In [None]:
# Print every concept whose first occurrence in the selected timeline is after index 20.
# The row is fetched once and earlier concepts are tracked in a set, so the scan
# is linear in the stream length.
# NOTE(review): assumes stream items are hashable (e.g. CUI strings) -- confirm.
example_stream = dataset['train'][ind]['stream']
seen_concepts = set()
for i, c in enumerate(example_stream):
    if i > 20 and c not in seen_concepts:
        print(i, c, cdb.get_name(c))
    seen_concepts.add(c)

In [None]:
# Convert the first 161 concepts of the selected timeline into plain dict records.
# NOTE(review): 161 and the 'count' / 'saliency' values are hard-coded; presumably
# placeholders chosen for this one example -- confirm.
out = []
for i, cui in enumerate(dataset['train'][ind]['stream'][:161]):
    d = {
        'id': cui,
        'label': cdb.get_name(cui),
        'count': 1000000,
        'name': cdb.get_name(cui),  # same lookup as 'label'
        'cui': cui,  # same value as 'id'
        'saliency': 0,
        'uid': i  # position of the concept within the slice
    }
    out.append(d)

In [None]:
# Write through a context manager so the file handle is flushed and closed
# as soon as the dump finishes.
with open("./data/tmp/timeline_example_1.json", 'w') as json_file:
    json.dump(out, json_file)

In [None]:
len(out)

In [None]:
out