In [1]:
import json
import os

import piton.datasets
import piton.extension.dataloader
import piton.models.transformer

# NOTE(review): absolute, machine-specific paths; consider making these configurable.
data_path = '/local-scratch/nigam/projects/ethanid/piton_1_extract'

dictionary_path = "/local-scratch/nigam/projects/ethanid/piton/native/results/dict_entries"

data = piton.datasets.PatientDatabase(data_path)

# Use a context manager so the dictionary file handle is closed after reading.
with open(dictionary_path) as f:
    dictionary = json.load(f)

config = {
    "seed": 97,
    "vocab_size": 50000,
    "dictionary": dictionary,
    "max_size": 13,
    "splits": [["train", 0, 70], ["dev", 70, 85], ["test", 85, 100]],

    'hidden_size': 768,
    'intermediate_size': 3072,
    'n_heads': 12,
    'n_layers': 0,

    'n_classes': 2,
    'learning_rate': 1.2e-5,
    'max_grad_norm': 1.0,
    'l2': 0,
}

# BatchCreator is handed the config as a JSON file path, so write it to disk first.
# Create the directory so this cell also works on a fresh checkout.
config_path = "trash/config.json"
os.makedirs(os.path.dirname(config_path), exist_ok=True)
with open(config_path, 'w') as f:
    json.dump(config, f)

loader = piton.extension.dataloader.BatchCreator(data_path, config_path)

print("Ready to go!")
print(loader.get_number_of_batches("train"))

Ready to go!
1552


In [2]:
# Inspect one training batch (a dict of arrays such as patient_ids, tokens, ages).
sample_batch = loader.get_batch("train", 1)
sample_batch

{'patient_ids': array([22860,  2650, 15714, 22997,  2287, 11128, 16926, 12311, 25280,
        24419,   416,  8083,  8897,  2725, 21079, 23524,   118,  2586,
        19853,  4723, 26547,  7062, 15185,  2363,  7663, 13579, 17388,
        19857,  9094, 19979, 19186, 22020], dtype=uint32),
 'tokens': array([[   1,    0,    5, ...,    0,    0,    0],
        [   1,    5,   53, ...,    0,    0,    0],
        [   1,    5,  245, ...,    0,    0,    0],
        ...,
        [   1,   13,    4, ...,    0,    0,    0],
        [   1,  245,    4, ...,    0,    0,    0],
        [   1,    5,  245, ...,  200, 8725,    0]], dtype=uint32),
 'ages': array([[0.0000000e+00, 2.7378234e-03, 2.7378234e-03, ..., 0.0000000e+00,
         0.0000000e+00, 0.0000000e+00],
        [0.0000000e+00, 2.7378234e-03, 1.5097034e+01, ..., 0.0000000e+00,
         0.0000000e+00, 0.0000000e+00],
        [0.0000000e+00, 2.7378234e-03, 2.7378234e-03, ..., 0.0000000e+00,
         0.0000000e+00, 0.0000000e+00],
        ...,
     

In [3]:
import haiku as hk
import jax
import jax.numpy as jnp
import jmp

import piton.models.transformer

# Mixed precision: parameters stay float32; the classifier computes and
# returns float16.
policy = jmp.get_policy('params=float32,compute=float16,output=float16')
hk.mixed_precision.set_policy(piton.models.transformer.RobertaClassifier, policy)


def roberta_classification_fn(batch):
    """Build a RobertaClassifier from the global ``config`` and apply it to ``batch``."""
    return piton.models.transformer.RobertaClassifier(config)(batch)


rng = jax.random.PRNGKey(42)
# hk.transform already gives apply() the signature apply(params, rng, ...);
# per the Haiku docs the old ``apply_rng`` flag is deprecated (and removed in
# newer releases), so it is not passed.
roberta_classifier = hk.transform(roberta_classification_fn)
init_batch = {name: jnp.array(value) for name, value in loader.get_batch("train", 1).items()}
params = roberta_classifier.init(rng, batch=init_batch)

In [4]:
# print(params)

In [5]:
def loss(params, loss_scale, rng, batch):
    """Batch-averaged softmax cross-entropy, multiplied by the dynamic loss scale.

    Args:
      params: Haiku parameters for ``roberta_classifier``.
      loss_scale: a ``jmp`` loss scale; the return value is
        ``loss_scale.scale(loss)`` so gradients are taken of the scaled loss.
      rng: PRNG key forwarded to ``roberta_classifier.apply``.
      batch: dict of model inputs that must also contain ``'batch_labels'``
        (class indices, one per example).

    Returns:
      The scaled scalar loss (float32).
    """
    logits = roberta_classifier.apply(params, rng, batch)
    # The mixed-precision policy emits float16 logits; upcast so the
    # log-softmax and the reduction are not carried out in half precision.
    logits = logits.astype(jnp.float32)
    labels = hk.one_hot(batch['batch_labels'], config['n_classes'])
    softmax_xent = -jnp.sum(labels * jax.nn.log_softmax(logits))
    # Average over examples (labels is (batch, n_classes)).
    softmax_xent /= labels.shape[0]
    return loss_scale.scale(softmax_xent)

# @jax.jit
def accuracy(params, rng, batch):
    """Fraction of examples whose argmax prediction equals ``batch['batch_labels']``.

    NOTE(review): ``roberta_classification_fn`` (the function wrapped by
    ``roberta_classifier``) accepts only ``batch``, so no ``training`` flag can
    be forwarded through ``apply``. Passing ``training=False`` raised a
    TypeError. If the model needs an eval mode, thread a flag through that
    function first.
    """
    predictions = roberta_classifier.apply(params, rng, batch)
    return jnp.mean(jnp.argmax(predictions, axis=-1) == batch['batch_labels'])

def apply_optimizer(params, grads, opt_state):
    """Run one step of the global optax ``opt`` and return (new_params, new_opt_state)."""
    step_updates, next_opt_state = opt.update(grads, opt_state)
    return optax.apply_updates(params, step_updates), next_opt_state

from jax import debug
from typing import TypeVar

T = TypeVar("T")

def _cast_floating_to(tree: T, dtype: jnp.dtype) -> T:
  def conditional_cast(x):
    if (isinstance(x, (np.ndarray, jnp.ndarray)) and
        jnp.issubdtype(x.dtype, jnp.floating)):
      x = x.astype(dtype)
    return x
  return jax.tree_util.tree_map(conditional_cast, tree)


@jax.jit
def update(params, loss_scale, rng, opt_state, batch):
    """One jitted mixed-precision training step with dynamic loss scaling.

    Returns ``(new_params, new_opt_state, batch_loss, new_loss_scale)``.
    Gradients are unscaled and cast to float32. If any gradient is non-finite,
    params and optimizer state are kept as they were and the loss scale is
    adjusted via ``loss_scale.adjust``.
    """
    scaled_loss, scaled_grads = jax.value_and_grad(loss)(params, loss_scale, rng, batch)

    batch_loss = loss_scale.unscale(scaled_loss)
    full_precision_grads = _cast_floating_to(scaled_grads, jnp.float32)
    grads = loss_scale.unscale(full_precision_grads)

    all_finite = jmp.all_finite(grads)
    next_loss_scale = loss_scale.adjust(all_finite)

    stepped = apply_optimizer(params, grads, opt_state)
    unchanged = (params, opt_state)
    next_params, next_opt_state = jmp.select_tree(all_finite, stepped, unchanged)

    return next_params, next_opt_state, batch_loss, next_loss_scale



In [6]:
import optax 

def make_lr_schedule(warmup_percentage, total_steps):
    """Build a multiplicative LR schedule: linear warmup, then linear decay to 0.

    Args:
      warmup_percentage: fraction of ``total_steps`` (e.g. 0.1) over which the
        multiplier ramps up linearly from 0. A value of 0 disables warmup.
      total_steps: step at which the multiplier reaches 0.

    Returns:
      ``lr_schedule(step)`` computing ``min(p / warmup_percentage, 1) * (1 - p)``
      with ``p = clip(step / total_steps, 0, 1)``. It is used with
      ``optax.scale_by_schedule``, so it multiplies the optimizer's updates
      rather than replacing the learning rate.
    """
    def lr_schedule(step):
        # Clamp so steps beyond total_steps give 0 rather than a negative
        # multiplier, which would flip the update direction.
        percent_complete = jnp.clip(step / total_steps, 0.0, 1.0)
        if warmup_percentage > 0:
            warmup = jnp.minimum(percent_complete / warmup_percentage, 1.0)
        else:
            # Avoid 0/0 -> NaN when no warmup is requested.
            warmup = 1.0
        return warmup * (1.0 - percent_complete)
    return lr_schedule


# Number of optimizer steps spanned by the warmup/decay schedule; matches the
# training loop length below. TODO: derive as
# config['n_epochs'] * (n_examples // config['batch_size']).
total_steps = 1000
print(total_steps)

lr_schedule = make_lr_schedule(total_steps=total_steps, warmup_percentage=0.1)

# Clip gradients, take an Adam step at config['learning_rate'], then multiply
# the resulting update by the warmup/decay factor from lr_schedule.
opt = optax.chain(*[
    optax.clip_by_global_norm(config['max_grad_norm']),
    optax.adam(learning_rate=config['learning_rate']),
    optax.scale_by_schedule(lr_schedule),
])
opt_state = opt.init(params)


1000


In [None]:
import time

import numpy as np

loss_scale = jmp.DynamicLossScale(jmp.half_dtype()(2 ** 15))
print(loss_scale)

# Running sum of the (unscaled) training loss, reported as a 100-step mean.
# Accumulate in a Python float; summing float16 losses into a device array
# would keep the running total in half precision.
total_batch = 0.0

for step in range(1000):
    if step % 100 == 0:
        print(total_batch / 100)
        total_batch = 0.0
        print(f"[Step {step}]")

    # Wall-clock time from the start of step 1 to the start of step 999.
    if step % 1000 == 1:
        start = time.time()

    if step % 1000 == 999:
        end = time.time()
        print(end - start)

    if step % 1000 == 0 and step != 0:
        print(loss_scale)
        # NOTE(review): measure_current_performance is not defined in this
        # notebook. The branch is unreachable with range(1000) but raises
        # NameError if the loop is extended.
        measure_current_performance(params, n_examples=100)

    batch = {a: jnp.array(b) for a, b in loader.get_batch("train", step).items()}
    # TODO: placeholder labels; every example is assigned class 0.
    # Integer dtype because these are class indices fed to one_hot.
    batch['batch_labels'] = jnp.zeros((batch['tokens'].shape[0],), dtype=jnp.int32)
    print("Ready to go", {a: b.shape for a, b in batch.items()})

    # Derive a distinct key per step. Reusing the init key would give every step
    # identical randomness in any stochastic layers (presumably dropout).
    step_rng = jax.random.fold_in(rng, step)
    params, opt_state, batch_loss, loss_scale = update(
        params, loss_scale, step_rng, opt_state, batch
    )
    print("Got it!", batch_loss, loss_scale)
    total_batch += float(batch_loss)


DynamicLossScale(loss_scale=DeviceArray(32768., dtype=float16), counter=array(0, dtype=int32), period=2000, factor=2, min_loss_scale=array(1, dtype=int32))
0.0
[Step 0]
Ready to go {'patient_ids': (4,), 'tokens': (4, 2048), 'ages': (4, 2048), 'batch_labels': (4,)}


  leaves = jax.tree_leaves(tree)


Got it! 0.31054688 DynamicLossScale(loss_scale=DeviceArray(16384., dtype=float16), counter=DeviceArray(0, dtype=int32), period=2000, factor=2, min_loss_scale=array(1, dtype=int32))
Ready to go {'patient_ids': (32,), 'tokens': (32, 256), 'ages': (32, 256), 'batch_labels': (32,)}


  leaves = jax.tree_leaves(tree)


Got it! 0.31054688 DynamicLossScale(loss_scale=DeviceArray(8192., dtype=float16), counter=DeviceArray(0, dtype=int32), period=2000, factor=2, min_loss_scale=array(1, dtype=int32))
Ready to go {'patient_ids': (1,), 'tokens': (1, 8192), 'ages': (1, 8192), 'batch_labels': (1,)}


  leaves = jax.tree_leaves(tree)


Got it! 0.31054688 DynamicLossScale(loss_scale=DeviceArray(4096., dtype=float16), counter=DeviceArray(0, dtype=int32), period=2000, factor=2, min_loss_scale=array(1, dtype=int32))
Ready to go {'patient_ids': (4,), 'tokens': (4, 2048), 'ages': (4, 2048), 'batch_labels': (4,)}
Got it! 0.31054688 DynamicLossScale(loss_scale=DeviceArray(2048., dtype=float16), counter=DeviceArray(0, dtype=int32), period=2000, factor=2, min_loss_scale=array(1, dtype=int32))
Ready to go {'patient_ids': (4,), 'tokens': (4, 2048), 'ages': (4, 2048), 'batch_labels': (4,)}
Got it! 0.31054688 DynamicLossScale(loss_scale=DeviceArray(1024., dtype=float16), counter=DeviceArray(0, dtype=int32), period=2000, factor=2, min_loss_scale=array(1, dtype=int32))
Ready to go {'patient_ids': (1,), 'tokens': (1, 8192), 'ages': (1, 8192), 'batch_labels': (1,)}
Got it! 0.31054688 DynamicLossScale(loss_scale=DeviceArray(512., dtype=float16), counter=DeviceArray(0, dtype=int32), period=2000, factor=2, min_loss_scale=array(1, dtype=i