In [1]:
import os

os.environ["CUDA_VISIBLE_DEVICES"] = "2"
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.7"
os.environ["JAX_NUMPY_RANK_PROMOTION"] = "raise"


In [2]:
t = 'survival_clmbr'
step = '39000'

In [3]:
import os
import pickle

import piton.extension.dataloader
import msgpack
import numpy as np

data_path = "/local-scratch/nigam/projects/ethanid/piton_1_extract"

dictionary_path = (
    "/local-scratch/nigam/projects/ethanid/piton/native/results/dictionary"
)

surv_dictionary_path = (
    "/local-scratch/nigam/projects/ethanid/piton/native/results/survival_clmbr_dictionary"
)

import piton.datasets

data = piton.datasets.PatientDatabase(data_path)
male_code = data.get_code_dictionary().index("Gender/M")

import json

dictionary = msgpack.load(open(dictionary_path, 'rb'), use_list=False)

if t == 'survival_clmbr':
    surv_dict = msgpack.load(open(surv_dictionary_path, 'rb'), use_list=False)
    print(surv_dict.keys())
    task = {"type": "survival_clmbr", "survival_dict": surv_dict, "dim": 256}
elif t == 'clmbr':
    task = {"type": "clmbr", 'vocab_size': 10_000}
else:
    labels = []

    if False:
        limit = 100
    else:
        limit = len(data)

    for patient_id in range(0, limit):
        patient = data[patient_id]
        is_male = any(event.code == male_code for event in patient.events)
        labels.append((patient.patient_id, 1, is_male))
    task = {"type": "binary", "labels": labels}

config = {
    "transformer": {
        "vocab_size": 50000,
        "dictionary": dictionary,
        "hidden_size": 768,
        "intermediate_size": 3072,
        "n_heads": 12,
        "n_layers": 6,
    
        "rotary": "per_head",
        
        "max_size": 14,
        "min_size": 5,
        
        "attention_width": 256,
    },
    

    "task": task,
    "seed": 97,
    "splits": [["train", 0, 70], ["dev", 70, 85], ["test", 85, 100]],
    "learning_rate": 1e-3,
    "max_grad_norm": 1.0,
    "l2": 0,

    "n_epochs": 100,
}

print('WORKING WITH', config['learning_rate'])


with open("trash/config.json", "bw") as f:
    msgpack.dump(config, f)

loader = piton.extension.dataloader.BatchCreator(data_path, "trash/config.json")
print(loader.get_number_of_batches("train"))
print(loader.get_number_of_batches("dev"))
print(loader.get_number_of_batches("test"))
print("Starting to load ...")

dict_keys(['codes', 'lambdas', 'time_bins'])
WORKING WITH 0.001
Assigning ... 16685 16339 16384
777
173
180
Starting to load ...


In [4]:
1 << 14

16384

In [6]:
def compute_total_loss(split, params, rng):
    total_loss = 0
    num_batches = min(1000, loader.get_number_of_batches(split))
    for i in range(0, num_batches):
        batch = loader.get_batch(split, i)
        if batch['task']['num_valid'] == 0:
            print('WAT', i)
            print(1/0)

    return total_loss / num_batches

compute_total_loss('train', None, None)

0.0

In [5]:
loader.get_batch('train', 351)

{'num_patients': 1,
 'patient_ids': array([16685,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,

In [24]:
p = data[16685]

len(p.events)

192572

In [26]:
p.events[:100]

(Event(start=1983-06-19 00:00:00, code=269),
 Event(start=1983-06-19 23:59:00, code=525),
 Event(start=1983-06-19 23:59:00, code=432),
 Event(start=1983-06-19 23:59:00, code=11, value='The above results were imported from Skaug. For these results, units,  reference ranges, and flagging of abnormal or critical values are not  available:'),
 Event(start=1983-06-19 23:59:00, code=1282),
 Event(start=1999-03-12 10:35:00, code=11, value='0401 C00170R'),
 Event(start=1999-03-12 10:35:00, code=9871),
 Event(start=1999-03-18 20:10:00, code=3911),
 Event(start=1999-03-18 20:10:00, code=11, value='Biopsy/Tissue(Gen)GS SPECIMEN DESCRIPTION:   CYST PILONIDOL SPECIMEN COMMENT / SPECIAL REQUEST:   aprox 2 cm's in anaport and bloody swab in culturette DIRECT EXAM / GRAM STAIN:   small number POLYS seen   small number GRAM POSITIVE COCCI in pairs and groups   rare to small number GRAM NEGATIVE RODS   rare to small number GRAM POSITIVE RODS    SLIDE REVIEWED: ORIGINAL READING CONFIRMED. SOME ORGANISMS 

In [27]:
print(loader.get_number_of_batches("train"))
print(loader.get_number_of_batches("dev"))
print(len(data))

777
173
27412


In [5]:
print(loader.get_batch("train", 0))
len(data[18416].events)

{'num_patients': 2, 'patient_ids': array([10022,  4756,     0,     0,     0,     0,     0,     0,     0,
           0,     0,     0,     0,     0,     0,     0,     0,     0,
           0,     0,     0,     0,     0,     0,     0,     0,     0,
           0,     0,     0,     0,     0,     0,     0,     0,     0,
           0,     0,     0,     0,     0,     0,     0,     0,     0,
           0,     0,     0,     0,     0,     0,     0,     0,     0,
           0,     0,     0,     0,     0,     0,     0,     0,     0,
           0,     0,     0,     0,     0,     0,     0,     0,     0,
           0,     0,     0,     0,     0,     0,     0,     0,     0,
           0,     0,     0,     0,     0,     0,     0,     0,     0,
           0,     0,     0,     0,     0,     0,     0,     0,     0,
           0,     0,     0,     0,     0,     0,     0,     0,     0,
           0,     0,     0,     0,     0,     0,     0,     0,     0,
           0,     0,     0,     0,     0,     0,     0,

2327

In [6]:
import piton.models.transformer
import jax
import haiku as hk
import jax.numpy as jnp

def roberta_classification_fn(batch):
    model = piton.models.transformer.EHRTransformer(config)(batch)
    return model


dummy_batch=jax.tree_map(lambda a: jnp.array(a), loader.get_batch("train", 2))
print("Batch info", jax.tree_map(lambda a: (a.shape, a.dtype), dummy_batch))

rng = jax.random.PRNGKey(42)
roberta_classifier = hk.transform(roberta_classification_fn)

random_params = roberta_classifier.init(
    rng,
    batch=dummy_batch,
)

print("Params info", jax.tree_map(lambda a: a.shape, random_params))

Batch info {'num_patients': ((), dtype('int32')), 'patient_ids': ((256,), dtype('uint32')), 'task': {'event_indices': ((1048576, 2), dtype('uint32')), 'num_valid': ((), dtype('int32')), 'sparse_time': [((16385,), dtype('uint32')), ((16384,), dtype('float32')), ((1048576,), dtype('uint32')), ((1048576,), dtype('float32'))]}, 'transformer': {'ages': ((8192,), dtype('float32')), 'label_indices': ((1024,), dtype('uint32')), 'length': ((), dtype('int32')), 'normalized_ages': ((8192,), dtype('float32')), 'tokens': ((8192,), dtype('uint32'))}}
Compiling the transformer ... (8192,) (1024,)




Params info {'EHRTransformer/~/SurvivalCLMBRTask': {'code_weights': (8192, 256)}, 'EHRTransformer/~/SurvivalCLMBRTask/~/linear': {'b': (4080,), 'w': (768, 4080)}, 'EHRTransformer/~/TransformerFeaturizer/~/Transformer/~/embed': {'embeddings': (50000, 768)}, 'EHRTransformer/~/TransformerFeaturizer/~/Transformer/~/layer_norm': {'offset': (768,), 'scale': (768,)}, 'EHRTransformer/~/TransformerFeaturizer/~/Transformer/~/loop_0/TransformerBlock/~/layer_norm': {'offset': (768,), 'scale': (768,)}, 'EHRTransformer/~/TransformerFeaturizer/~/Transformer/~/loop_0/TransformerBlock/~/linear': {'b': (5376,), 'w': (768, 5376)}, 'EHRTransformer/~/TransformerFeaturizer/~/Transformer/~/loop_0/TransformerBlock/~/linear_1': {'b': (768,), 'w': (3840, 768)}, 'EHRTransformer/~/TransformerFeaturizer/~/Transformer/~/loop_1/TransformerBlock/~/layer_norm': {'offset': (768,), 'scale': (768,)}, 'EHRTransformer/~/TransformerFeaturizer/~/Transformer/~/loop_1/TransformerBlock/~/linear': {'b': (5376,), 'w': (768, 5376)

In [7]:
loaded_params = pickle.load(open(f"../native/result_clmbr/params_{t}_{step}", "rb"))

print("Loaded params info", jax.tree_map(lambda a: a.shape, loaded_params))

Loaded params info {'EHRTransformer/~/SurvivalCLMBRTask': {'code_weights': (8192, 512)}, 'EHRTransformer/~/SurvivalCLMBRTask/~/linear': {'b': (8176,), 'w': (768, 8176)}, 'EHRTransformer/~/TransformerFeaturizer/~/Transformer/~/embed': {'embeddings': (50000, 768)}, 'EHRTransformer/~/TransformerFeaturizer/~/Transformer/~/layer_norm': {'offset': (768,), 'scale': (768,)}, 'EHRTransformer/~/TransformerFeaturizer/~/Transformer/~/loop_0/TransformerBlock/~/layer_norm': {'offset': (768,), 'scale': (768,)}, 'EHRTransformer/~/TransformerFeaturizer/~/Transformer/~/loop_0/TransformerBlock/~/linear': {'b': (5376,), 'w': (768, 5376)}, 'EHRTransformer/~/TransformerFeaturizer/~/Transformer/~/loop_0/TransformerBlock/~/linear_1': {'b': (768,), 'w': (3840, 768)}, 'EHRTransformer/~/TransformerFeaturizer/~/Transformer/~/loop_1/TransformerBlock/~/layer_norm': {'offset': (768,), 'scale': (768,)}, 'EHRTransformer/~/TransformerFeaturizer/~/Transformer/~/loop_1/TransformerBlock/~/linear': {'b': (5376,), 'w': (768

In [8]:
from typing import TypeVar

T = TypeVar("T")


def _cast_floating_to(tree: T, dtype: jnp.dtype) -> T:
    def conditional_cast(x):
        if isinstance(x, (np.ndarray, jnp.ndarray)) and jnp.issubdtype(
            x.dtype, jnp.floating
        ):
            x = x.astype(dtype)
        return x

    return jax.tree_util.tree_map(conditional_cast, tree)

@jax.jit
def compute_loss(params, rng, batch):
    loss, logits = roberta_classifier.apply(params, rng, batch)
    return loss, logits



In [9]:
def compute_total_loss(split, params, rng):
    total_loss = 0
    num_batches = min(1000, loader.get_number_of_batches(split))
    for i in range(num_batches):
        batch = loader.get_batch(split, i)
        total_loss += compute_loss(_cast_floating_to(params, jnp.float16), rng, batch)[0]
        if total_loss != total_loss:
            print("WAT", i)
            print(1/0)

    return total_loss / num_batches


In [10]:
# print(compute_total_loss("train", loaded_params, None))
# print(compute_total_loss("dev", loaded_params, None))

In [11]:
print(compute_total_loss("train", random_params, None))
print(compute_total_loss("dev", random_params, None))

Compiling the transformer ... (8192,) (1024,)
Compiling the transformer ... (8192,) (1024,)
Compiling the transformer ... (8192,) (256,)
Compiling the transformer ... (8192,) (256,)
Compiling the transformer ... (8192,) (256,)
Compiling the transformer ... (8192,) (256,)
WAT 419


ZeroDivisionError: division by zero

In [16]:
batch = loader.get_batch("train", 419)

loss,logits = compute_loss(_cast_floating_to(random_params, jnp.float16), None, batch)
print(batch['transformer']['tokens'].shape)
print(batch)

(8192,)
{'num_patients': 1, 'patient_ids': array([529,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
         0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
         0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
         0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
         0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
         0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
         0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
         0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
         0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
         0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
         0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
         0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
         0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
         0,   0,   0,

In [None]:
print(batch['patient_ids'])

In [13]:
d = data.get_code_dictionary()

for event in data[529].events:
    if event.value is None:
        print(event, bytes(d[event.code]).decode('utf8'))

Event(start=1934-09-29 00:00:00, code=269) SNOMED/3950001
Event(start=1934-09-29 23:59:00, code=525) Race/5
Event(start=1934-09-29 23:59:00, code=432) Gender/F
Event(start=1934-09-29 23:59:00, code=371) Ethnicity/Not Hispanic
Event(start=2009-05-05 04:00:00, code=580) Visit/IP
Event(start=2009-05-05 04:10:00, code=5642) RxNorm/198443
Event(start=2009-05-05 04:10:00, code=478) RxNorm/313782
Event(start=2009-05-05 04:10:00, code=1379) RxNorm/727517
Event(start=2009-05-05 04:10:00, code=2199) CPT4/82962
Event(start=2009-05-05 04:10:00, code=331) CPT4/85610
Event(start=2009-05-05 04:10:00, code=285) LOINC/2601-3
Event(start=2009-05-05 04:10:00, code=207) LOINC/24321-2
Event(start=2009-05-05 04:10:00, code=980) LOINC/34109-9
Event(start=2009-05-05 04:10:00, code=847) SNOMED/252465000
Event(start=2009-05-05 04:10:00, code=8053) SNOMED/423171007
Event(start=2009-05-05 04:10:00, code=547) SNOMED/42525009
Event(start=2009-05-05 04:10:00, code=12833) RxNorm/310520
Event(start=2009-05-05 04:10:00

In [None]:
print(batch['task'].keys())
print(batch['task']['event_indices'])

codes = batch['task']['event_indices'][:, 1]
rep_bin_indices = batch['task']['event_indices'][:, 0]
bin_indices = rep_bin_indices % 16
rep_indices = rep_bin_indices // 16

In [None]:
total = jnp.stack((rep_indices, bin_indices, codes), axis=-1)
print(total)

In [None]:
print(total[rep_indices == 62])
print(total[rep_indices == 0])

In [None]:
print(batch['transformer']['label_indices'])
print(batch['transformer']['length'])

In [None]:
print(total.shape)
total[10000:20000, :]
print(batch['transformer']['label_indices'][62])
print(logits.shape)
print(jnp.log(logits[62, :, 7954]))
print(jnp.log(logits[0, :, 7954]))

In [None]:
print(batch['transformer']['label_indices'])
print(batch['transformer']['label_indices'].shape)

In [None]:
print(batch['task'].keys())
print(batch['task']['event_indices'])
print(jnp.exp2(batch['task']['sparse_time'][1][:12]))

In [None]:
print(loss)
a = jax.nn.softmax(logits)
a = a.reshape(batch['transformer']['label_indices'].shape[0], -1, 8192)
print(a.shape)
print(jnp.min(a[:, 5, 10]))

0.0001222 * 90

In [None]:
print([a for a in loaded_params.keys() if 'Survival' in a])

print(loaded_params['EHRTransformer/~/SurvivalCLMBRTask'])
print(loaded_params['EHRTransformer/~/SurvivalCLMBRTask/~/linear'])

print(jnp.var(loaded_params['EHRTransformer/~/SurvivalCLMBRTask']['code_weights']))
print(jnp.var(loaded_params['EHRTransformer/~/SurvivalCLMBRTask/~/linear']['w']))
print(1/jnp.sqrt(256))
print(1/jnp.sqrt(796))

In [None]:
def _cast_floating_to(tree: T, dtype: jnp.dtype) -> T:
    def conditional_cast(x):
        if isinstance(x, (np.ndarray, jnp.ndarray)) and jnp.issubdtype(
            x.dtype, jnp.floating
        ):
            x = x.astype(dtype)
        return x

    return jax.tree_util.tree_map(conditional_cast, tree)

def pad_to_32(val, dim):
    remain = jnp.zeros((val.shape[0], dim - val.shape[1]), dtype=val.dtype)
    return jnp.concatenate((remain, val), axis=-1)

def loss_fn(bs, batch):
    a_matrix = jnp.ones(
        (batch['transformer']['label_indices'].shape[0] * len(config['task']['survival_dict']['time_bins']), 1), dtype=bs.dtype)
    b_matrix = bs.reshape(-1, 1)
    
    a_matrix = pad_to_32(a_matrix, 32)
    b_matrix = pad_to_32(b_matrix, 32)

    survival_loss = piton.jax.exp_mean(
        a_matrix, b_matrix, batch['task']["sparse_time"]
    )
    
    event_loss = jnp.log(2) * piton.jax.embedding_dot(
        a_matrix, b_matrix, batch['task']["event_indices"]
    ).sum(dtype=jnp.float32)
    
    event_loss = -event_loss / (
        a_matrix.shape[0] * b_matrix.shape[0]
    )
    
    return survival_loss + event_loss

@jax.value_and_grad
def loss_value_and_grad(params, loss_scale, batch):
    loss = loss_fn(params, batch)

    assert loss.dtype == jnp.float32

    post_scale = loss_scale.scale(loss)
    return post_scale

def apply_optimizer(params, grads, opt_state):
    updates, opt_state = opt.update(grads, opt_state)
    new_params = optax.apply_updates(params, updates)
    return new_params, opt_state

@jax.jit
def update(params, loss_scale, opt_state, batch):
    batch_loss, grads = loss_value_and_grad(
        _cast_floating_to(params, jnp.float16), loss_scale, batch
    )

    batch_loss = loss_scale.unscale(batch_loss.astype(jnp.float32))
    grads = loss_scale.unscale(_cast_floating_to(grads, jnp.float32))

    grads_finite = jmp.all_finite(grads)

    loss_scale = loss_scale.adjust(grads_finite)

    new_params, opt_state = jmp.select_tree(
        grads_finite,
        apply_optimizer(params, grads, opt_state),
        (params, opt_state),
    )

    return new_params, opt_state, batch_loss, loss_scale


In [None]:
import optax
import jmp



params = jax.random.normal(jax.random.PRNGKey(123), 
                           shape=(len(config['task']['survival_dict']['codes']),)) - 13

opt = optax.sgd(learning_rate=1e-2)
opt_state = opt.init(params)

loss_scale = jmp.DynamicLossScale(jnp.array(2**15, dtype=jnp.float32))

for i in range(loader.get_number_of_batches("train")):
    batch = loader.get_batch("train", i)
    batch = jax.tree_map(lambda a: jnp.array(a), batch)
    params, opt_state, batch_loss, loss_scale = update(
        params, loss_scale, opt_state, batch)
    if i % 100 == 0:
        print(i, loss_scale, params, batch_loss)

In [None]:
import optax
import jmp


params = jnp.log2(jnp.array(config['task']['survival_dict']['lambdas']))
print(params)

opt = optax.adam(learning_rate=1e-2)
opt_state = opt.init(params)

loss_scale = jmp.DynamicLossScale(jnp.array(2**15, dtype=jnp.float32))

for i in range(loader.get_number_of_batches("train")):
    batch = loader.get_batch("train", i)
    batch = jax.tree_map(lambda a: jnp.array(a), batch)
    params, opt_state, batch_loss, loss_scale = update(
        params, loss_scale, opt_state, batch)
    if i % 100 == 0:
        print(i, loss_scale, params, batch_loss)

In [None]:
import copy
theoretically_optimal_params = copy.deepcopy(loaded_params)
print(theoretically_optimal_params.keys())


print(compute_total_loss("dev", theoretically_optimal_params, None))

In [None]:
theoretically_optimal_params['EHRTransformer/~/SurvivalCLMBRTask']['code_weights'] = jnp.zeros_like(theoretically_optimal_params['EHRTransformer/~/SurvivalCLMBRTask']['code_weights'])

In [None]:
print(compute_total_loss("dev", theoretically_optimal_params, None))

In [None]:
print(theoretically_optimal_params['EHRTransformer/~/SurvivalCLMBRTask']['code_weights'].shape)

better = pad_to_32(jnp.log2(jnp.array(config['task']['survival_dict']['lambdas'])).reshape(-1, 1), dim=256)

print(better)

theoretically_optimal_params['EHRTransformer/~/SurvivalCLMBRTask']['code_weights'] = better

print(compute_total_loss("dev", theoretically_optimal_params, None))

In [None]:
better = pad_to_32(params.reshape(-1, 1), dim=256)

print(better)

theoretically_optimal_params['EHRTransformer/~/SurvivalCLMBRTask']['code_weights'] = better

print(compute_total_loss("dev", theoretically_optimal_params, None))