In [1]:
import torch
import os
import datasets
import numpy as np
from collections import defaultdict
# from foresight.datasets import patient_concept_stream
# from foresight.datasets.filters import filter_by_count, filter_by_type
# from foresight.datasets.utils import get_embeddings_for_tokens, stream_to_separate_examples, add_to_stream, \
#                                   remove_parents_from_stream, bucket_concepts, cleanup_stream, \
#                                   split_stream, add_age, get_all_splits, add_ttd, add_position_ids
# from foresight.utils import pickle
# from foresight.utils.cdb_utils import get_parents_map 
# from foresight.utils.stream_utils import docs2stream, calculate_counts
# from foresight.tokenizers.simple_map_tokenizer import SimpleMapTokenizer
# from foresight.metrics.next_concept_prediction import precision, metrics_data2df, ComputePrecisionHF
# from foresight.utils import pickle


In [2]:
from random import Random
from datasets import Dataset
from collections import Counter
from pathlib import Path

In [3]:
from foresight.tokenizers.simple_map_tokenizer_v2 import SimpleMapTokenizer


## Dummy Data

In [4]:
NUM_TIMELINES = 1000
LETTERS = "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
NUM_TIMESTEPS = 10
RANDOM_SEED = Random(23)
MAX_SKIP = 10
MAX_NUM_SAMPLES = 3

SEQUENCE_LENGTH = 12
SEPARATOR_TOKEN = "<SEP>"
PADDING_TOKEN = '<PAD>'

In [5]:
def get_samples():
    for _ in range(1000):
        start_idx = RANDOM_SEED.randint(0, (len(LETTERS)/2))
        skip_odd = RANDOM_SEED.choice([True, False])
        num_samples = RANDOM_SEED.randint(1, MAX_NUM_SAMPLES)

        timeline = []
        for seq_idx in range(SEQUENCE_LENGTH):
            char_idx = start_idx + num_samples*seq_idx
            if skip_odd and char_idx % 2 == 1:
                timeline.append([])
            elif char_idx+num_samples >= len(LETTERS):
                timeline.append([])
            else:
                timeline.append(list(LETTERS[char_idx:char_idx+num_samples]))
            
        yield(
            {
                "timeline": timeline,
                "start_idx": start_idx,
                "skip_odd": skip_odd,
                "num_samples": num_samples,
            }
        )

In [6]:
dataset = Dataset.from_generator(get_samples)

In [7]:
dataset[0]

{'timeline': [['M'],
  ['N'],
  ['O'],
  ['P'],
  ['Q'],
  ['R'],
  ['S'],
  ['T'],
  ['U'],
  ['V'],
  ['W'],
  ['X']],
 'start_idx': 12,
 'skip_odd': False,
 'num_samples': 1}

In [8]:
def batched_timeline_to_tokens(batched_samples: dict[str, list], separator: str)->dict[str, list]:
    batched_samples["tokens"] = [
        [
            timestep_value
            for timestep in timeline
            for timestep_value in [separator] + timestep
        ]
        for timeline in batched_samples["timeline"]
    ]
    return batched_samples

In [9]:
dataset = dataset.map(lambda batch: batched_timeline_to_tokens(batch, SEPARATOR_TOKEN), batched=True)
dataset[0]["timeline"][:10], dataset[0]["tokens"][:10]

([['M'], ['N'], ['O'], ['P'], ['Q'], ['R'], ['S'], ['T'], ['U'], ['V']],
 ['<SEP>', 'M', '<SEP>', 'N', '<SEP>', 'O', '<SEP>', 'P', '<SEP>', 'Q'])

In [10]:
def batched_insert_static_feature_token(batched_samples: dict[str, list], key:str, insert_idx:int)->dict[str, list]:
    for idx, _ in enumerate(batched_samples["tokens"]):
        batched_samples["tokens"][idx].insert(insert_idx, f"{key}_{batched_samples[key][idx]}")
    return batched_samples


dataset = dataset.map(lambda batch: batched_insert_static_feature_token(batch, "start_idx", insert_idx=0), batched=True)
dataset = dataset.map(lambda batch: batched_insert_static_feature_token(batch, "skip_odd", insert_idx=1), batched=True)
dataset = dataset.map(lambda batch: batched_insert_static_feature_token(batch, "num_samples", insert_idx=2), batched=True)
dataset[0]["tokens"][:10]

['start_idx_12',
 'skip_odd_False',
 'num_samples_1',
 '<SEP>',
 'M',
 '<SEP>',
 'N',
 '<SEP>',
 'O',
 '<SEP>']

## Add position IDs

In [11]:
def batched_add_position_ids(batched_samples: dict[str, list], separators:set[str])->dict[str, list]:
    batched_samples["position_ids"] = []
    for tokens in batched_samples['tokens']:
        position_ids = []
        cnt = 0
        for token in tokens:
            if token in separators:
                cnt += 1
            position_ids.append(cnt)
        batched_samples["position_ids"].append(position_ids)
    return batched_samples

dataset = dataset.map(lambda batch: batched_add_position_ids(batch, {SEPARATOR_TOKEN}), batched=True)
[(dataset[0]["tokens"][idx], dataset[0]["position_ids"][idx]) for idx in range(10)]

[('start_idx_12', 0),
 ('skip_odd_False', 0),
 ('num_samples_1', 0),
 ('<SEP>', 1),
 ('M', 1),
 ('<SEP>', 2),
 ('N', 2),
 ('<SEP>', 3),
 ('O', 3),
 ('<SEP>', 4)]

In [12]:
dataset = dataset.train_test_split(test_size=0.2)
len(dataset["train"]), len(dataset["test"])

(800, 200)

# Make tokenizer

In [13]:
token_count = Counter(token for tokens in dataset["train"]["tokens"] for token in tokens)
token_count

Counter({'<SEP>': 9600,
         'O': 638,
         'M': 633,
         'Q': 577,
         'S': 567,
         'K': 536,
         'U': 536,
         'N': 495,
         'W': 468,
         'P': 454,
         'I': 449,
         'T': 443,
         'R': 434,
         'L': 416,
         'V': 409,
         'skip_odd_True': 401,
         'skip_odd_False': 399,
         'J': 352,
         'X': 350,
         'G': 343,
         'H': 292,
         'num_samples_1': 275,
         'num_samples_3': 269,
         'num_samples_2': 256,
         'E': 250,
         'F': 206,
         'Y': 175,
         'C': 166,
         'D': 144,
         'B': 82,
         'start_idx_0': 65,
         'A': 65,
         'start_idx_12': 65,
         'start_idx_2': 65,
         'start_idx_10': 64,
         'start_idx_9': 60,
         'start_idx_3': 59,
         'start_idx_7': 57,
         'start_idx_8': 54,
         'start_idx_13': 53,
         'start_idx_5': 53,
         'start_idx_4': 53,
         'start_idx_11': 51,
       

In [14]:
tokenizer = SimpleMapTokenizer.from_vocab(token_count.keys())
tokenizer(dataset["train"]["tokens"][0])

{'input_ids': [46,
  31,
  28,
  1,
  12,
  1,
  13,
  1,
  14,
  1,
  15,
  1,
  16,
  1,
  17,
  1,
  18,
  1,
  19,
  1,
  20,
  1,
  21,
  1,
  22,
  1,
  23],
 'attention_mask': [1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1],
 'token_type_ids': [0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0]}

In [15]:
tokenizer.save(Path.cwd() / "outputs" / "tokenizer")

In [16]:
encoded_dataset = dataset.map(
        lambda batch: tokenizer.batch_encode(batch),
        batched=True)

Map:   0%|          | 0/800 [00:00<?, ? examples/s]

Map:   0%|          | 0/200 [00:00<?, ? examples/s]

In [17]:
encoded_dataset["train"][0]

{'timeline': [['J'],
  ['K'],
  ['L'],
  ['M'],
  ['N'],
  ['O'],
  ['P'],
  ['Q'],
  ['R'],
  ['S'],
  ['T'],
  ['U']],
 'start_idx': 9,
 'skip_odd': False,
 'num_samples': 1,
 'tokens': ['start_idx_9',
  'skip_odd_False',
  'num_samples_1',
  '<SEP>',
  'J',
  '<SEP>',
  'K',
  '<SEP>',
  'L',
  '<SEP>',
  'M',
  '<SEP>',
  'N',
  '<SEP>',
  'O',
  '<SEP>',
  'P',
  '<SEP>',
  'Q',
  '<SEP>',
  'R',
  '<SEP>',
  'S',
  '<SEP>',
  'T',
  '<SEP>',
  'U'],
 'position_ids': [0,
  0,
  0,
  1,
  1,
  2,
  2,
  3,
  3,
  4,
  4,
  5,
  5,
  6,
  6,
  7,
  7,
  8,
  8,
  9,
  9,
  10,
  10,
  11,
  11,
  12,
  12],
 'input_ids': [46,
  31,
  28,
  1,
  12,
  1,
  13,
  1,
  14,
  1,
  15,
  1,
  16,
  1,
  17,
  1,
  18,
  1,
  19,
  1,
  20,
  1,
  21,
  1,
  22,
  1,
  23],
 'attention_mask': [1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1],
 'token_type_ids': [0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
 

In [18]:
encoded_dataset.save_to_disk(Path.cwd() / "outputs" / "encoded_dataset")

Saving the dataset (0/1 shards):   0%|          | 0/800 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/200 [00:00<?, ? examples/s]