In [1]:
import os

os.environ["CUDA_VISIBLE_DEVICES"] = "0"
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.45"
os.environ["JAX_NUMPY_RANK_PROMOTION"] = "raise"

In [2]:

import piton.jax
import jax.numpy as jnp
import jax.lax as lax
import haiku as hk
import haiku as hk
from io import BytesIO
from functools import lru_cache

import joblib
import requests

from transformers import RobertaTokenizer
from jax import jit
from jax.random import PRNGKey
import numpy as np
import jax


  from .autonotebook import tqdm as notebook_tqdm
2022-11-10 17:58:24.632259: E tensorflow/stream_executor/cuda/cuda_blas.cc:2981] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2022-11-10 17:58:25.634494: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory
2022-11-10 17:58:25.634613: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory


In [3]:
# We'll make use of these again later as a means to check our implementation
huggingface_tokenizer = RobertaTokenizer.from_pretrained('/local-scratch/nigam/projects/ethanid/roberta_saved')

In [4]:
class Embedding(hk.Module):
    """
    Embeds tokens and positions into an array of shape [n_batch, n_seq, n_hidden]
    """
    def __init__(self, config):
        super().__init__("Embedding")
        self.config = config

    def __call__(self, token_ids, training=False):
        """
        token_ids: ints of shape (batch, n_seq)
        """
        # We have to flatten our tokens before passing them to the hk.Embed module,
        # as arrays with more than one dimension are interpreted as multi-dimensional indexes
        flat_token_ids = jnp.reshape(token_ids, [token_ids.shape[0] * token_ids.shape[1]])
        flat_token_embeddings = hk.Embed(
            vocab_size=huggingface_tokenizer.vocab_size,
            embed_dim=self.config['hidden_size'],
            w_init=hk.initializers.TruncatedNormal(stddev=1 / np.sqrt(self.config['hidden_size']))
        )(flat_token_ids)
        
        # After we've embedded our token IDs, we reshape to recover our batch dimension
        token_embeddings = jnp.reshape(
            flat_token_embeddings, 
            [token_ids.shape[0], token_ids.shape[1], self.config['hidden_size']]
        )
        
        # Combine our token embeddings with a set of learned positional embeddings
        return token_embeddings + PositionEmbeddings(self.config)()
        
class PositionEmbeddings(hk.Module):
    """
    A position embedding of shape [n_seq, n_hidden]
    """
    def __init__(self, config):
        super().__init__("PosEmb")
        self.config = config
        # The Roberta position embeddings are offset by 2
        self.offset = 2

    def __call__(self):
        position_weights = hk.get_parameter(
            "position_embeddings", 
            (self.config['max_length'], self.config['hidden_size']),
            init=hk.initializers.TruncatedNormal(stddev=1 / np.sqrt(self.config['hidden_size']))
        )
        
        return jnp.expand_dims(position_weights, 0)

In [5]:
class TransformerBlock(hk.Module):

    def __init__(self, config, layer_num):
        super().__init__("TransformerBlock" + str(layer_num))
        self.config = config
        self.n = layer_num
        
        self.norm = hk.LayerNorm(-1, True, True)
        self.input_proj = hk.Linear(
            output_size=3 * self.config['hidden_size'] + self.config['intermediate_size'],
        )
        self.output_proj = hk.Linear(
            self.config['hidden_size'],
            w_init=hk.initializers.TruncatedNormal(stddev=2 / (self.config['n_layers'] * jnp.sqrt(self.config['hidden_size']))))

    def __call__(self, x, training=False):
        x = self.norm(x)
        
        middle = self.input_proj(x)
        
        assert (x.shape[2] == self.config['hidden_size'])
        
        head_size = self.config['hidden_size'] // self.config['n_heads']
        
        q, k, v, ff = jnp.split(middle, [i * self.config['hidden_size'] for i in range(1, 4)], axis=-1)
        
        def move_to_batch(val):
            with_head = val.reshape((x.shape[0], x.shape[1], self.config['n_heads'], head_size))
            with_head_at_start = with_head.transpose((0, 2, 1, 3))
            return with_head_at_start.reshape((x.shape[0] * self.config['n_heads'], x.shape[1], head_size))
        
        q = move_to_batch(q)
        k = move_to_batch(k)
        v = move_to_batch(v)
        
        attn = piton.jax.local_attention(q, k, v, x.shape[1])
        
        def move_out_of_batch(val):
            with_head = val.reshape((x.shape[0], self.config['n_heads'], x.shape[1], head_size))
            with_head_at_start = with_head.transpose((0, 2, 1, 3))
            return with_head_at_start.reshape((x.shape[0], x.shape[1], x.shape[2]))
    
        shaped_attn = move_out_of_batch(attn)
        
        ff = jax.nn.gelu(ff)
        
        combined = jnp.concatenate((shaped_attn, ff), axis=-1)
        
        return x + self.output_proj(combined)

In [6]:
class RobertaFeaturizer(hk.Module):
    def __init__(self, config, *args, **kwargs):
        super().__init__(name="Transformer")
        self.config = config
        self.norm = hk.LayerNorm(-1, True, True)
        
        self.layer_transform = hk.transform(lambda a, training: TransformerBlock(config, 0)(a, training))
        self.lifted_params = [hk.lift(self.layer_transform.init, name=f'loop_{i}') for i in range(self.config['n_layers'])]

    
    def __call__(self, token_ids, training=False):
        x = Embedding(self.config)(token_ids, training=training)
        # for layer_num in range(self.config['n_layers']):
        #     print(layer_num)
        #     x = TransformerBlock(self.config, layer_num)(x, training=training)
        
        init_rng = hk.next_rng_key() if hk.running_init() else None
        
        all_params = [lifted(init_rng, x, training) for lifted in self.lifted_params]
        flattened = [jax.tree_util.tree_flatten(a) for a in all_params]
        all_flat, all_defs = zip(*flattened)

        assert all(all_defs[0] == a for a in all_defs)

        all_stacked = [jnp.stack(tuple(a[i] for a in all_flat)) for i in range(len(all_flat[0]))]
        all_stacked_tree = jax.tree_util.tree_unflatten(all_defs[0], all_stacked)

        process = lambda v, params: (self.layer_transform.apply(params, None, v, training), None)

        final_x = jax.lax.scan(process, x, all_stacked_tree)[0]

        return self.norm(final_x)


In [7]:
config = {
    'max_length': 512,
    'embed_dropout_rate': 0.1,
    'fully_connected_drop_rate': 0.1,
    'attention_drop_rate': 0.1,
    'hidden_size': 768,
    'intermediate_size': 3072,
    'n_heads': 12,
    'n_layers': 6,
    'mask_id': 1,
    'weight_stddev': 0.02,

    # For use later in finetuning
    'n_classes': 2,
    'classifier_drop_rate': 0.1,
    'learning_rate': 1.2e-5,
    'max_grad_norm': 1.0,
    'l2': 0,
    'n_epochs': 5,
    'batch_size': 4
}

In [8]:
class RobertaClassifier(hk.Module):

    def __init__(self, config, *args, **kwargs):
        super().__init__(name="Classifier")
        self.config = config
    
    def __call__(self, token_ids, training=False):
        sequence_features = RobertaFeaturizer(self.config)(token_ids=token_ids, training=training)
        
        # Our classifier representation is just the output state of our first token
        clf_state = sequence_features[:,-1,:]
        
        if training:
            clf_state = hk.dropout(
                rng=hk.next_rng_key(),
                rate=self.config['classifier_drop_rate'],
                x=clf_state
            )
        
        # We project down from our hidden dimension to n_classes and use this as our softmax logits
        clf_logits = hk.Linear(
            output_size=self.config['n_classes'],
            w_init=hk.initializers.TruncatedNormal(self.config['weight_stddev'])
        )(clf_state)
        
        return clf_logits


Let's plug in a real dataset to try it out.  As much as I dislike the trope of testing text classifiers on sentiment analysis, we'll be using the IMDB Sentiment dataset from tensorflow datasets because it's already packaged up neatly for us. 



In [9]:
import tensorflow_datasets as tfds

def load_dataset(split, training, batch_size, n_epochs=1, n_examples=None):
    """Loads the dataset as a generator of batches."""
    ds = tfds.load("imdb_reviews", split=f"{split}[:{n_examples}]").cache().repeat(n_epochs)
    if training:
        ds = ds.shuffle(10 * batch_size, seed=0)
    ds = ds.batch(batch_size)
    return tfds.as_numpy(ds)

n_examples = 25000
train = load_dataset("train", training=True, batch_size=4, n_epochs=config['n_epochs'], n_examples=n_examples)

2022-11-10 17:58:27.813379: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcublas.so.11'; dlerror: libcublas.so.11: cannot open shared object file: No such file or directory
2022-11-10 17:58:27.813545: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcublasLt.so.11'; dlerror: libcublasLt.so.11: cannot open shared object file: No such file or directory
2022-11-10 17:58:27.816157: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcusolver.so.11'; dlerror: libcusolver.so.11: cannot open shared object file: No such file or directory
2022-11-10 17:58:27.816223: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcusparse.so.11'; dlerror: libcusparse.so.11: cannot open shared object file: No such file or directory
2022-11-10 17:58:27.816270: W tensorflow/stream_executor/platform/default/dso_loader

We'll add in an `encode_batch` utility to make calling the huggingface tokenizer more concise, transformer our new `RobertaClassifier` module into a pure function, and initialize our model state.

In [10]:
import optax

def roberta_classification_fn(batch_token_ids, training):
    model = RobertaClassifier(config)(
        jnp.asarray(batch_token_ids), 
        training=training
    )
    return model


def encode_batch(batch_text):
    # Accept either utf-8 encoded bytes or unicode
    batch_text = [
        text.decode('utf-8') if isinstance(text, bytes) else text 
        for text in batch_text
    ]
    
    # Use huggingface's tokenizer to convert from raw text to integer token ids
    token_ids = huggingface_tokenizer.batch_encode_plus(
        batch_text, 
        pad_to_max_length=True, 
        max_length=config['max_length'],
    )['input_ids']
    return np.asarray(token_ids)



import jmp
policy = jmp.get_policy('params=float32,compute=float16,output=float16')
hk.mixed_precision.set_policy(RobertaClassifier, policy)

# Purify our RobertaClassifier through the use of hk.transform and initialize our classifier
rng = jax.random.PRNGKey(42)
roberta_classifier = hk.transform(roberta_classification_fn, apply_rng=True)
params = roberta_classifier.init(
    rng, 
    batch_token_ids=encode_batch(['Sample text', 'Sample text']), 
    training=True
)

Truncation was not explicitly activated but `max_length` is provided a specific value, please use `truncation=True` to explicitly truncate examples to max length. Defaulting to 'longest_first' truncation strategy. If you encode pairs of sequences (GLUE-style) with the tokenizer you can select this strategy more precisely by providing a specific strategy to `truncation`.


[[[-1.157    0.3704   0.0762  ... -0.826   -0.3176   0.447  ]
  [-0.432    0.609   -1.035   ...  1.516    1.522   -1.608  ]
  [ 0.989   -0.11383 -0.553   ...  1.135   -1.367   -0.5703 ]
  ...
  [-1.836    1.444   -1.603   ...  0.4905  -1.285   -0.404  ]
  [ 0.1232   0.703   -2.307   ...  0.9287  -2.137    0.3118 ]
  [-0.369    2.076    0.704   ... -0.646   -1.657    1.457  ]]

 [[-1.157    0.3704   0.0762  ... -0.826   -0.3176   0.447  ]
  [-0.432    0.609   -1.035   ...  1.516    1.522   -1.608  ]
  [ 0.989   -0.11383 -0.553   ...  1.135   -1.367   -0.5703 ]
  ...
  [-1.836    1.444   -1.603   ...  0.4905  -1.285   -0.404  ]
  [ 0.1232   0.703   -2.307   ...  0.9287  -2.137    0.3118 ]
  [-0.369    2.076    0.704   ... -0.646   -1.657    1.457  ]]] True
(2, 512, 768) (2, 512, 3072)
[[[-1.157    0.3704   0.0762  ... -0.826   -0.3176   0.447  ]
  [-0.432    0.609   -1.035   ...  1.516    1.522   -1.608  ]
  [ 0.989   -0.11383 -0.553   ...  1.135   -1.367   -0.5703 ]
  ...
  [-1.836    1

In [11]:
print(params.keys())

dict_keys(['Classifier/Transformer/Embedding/embed', 'Classifier/Transformer/Embedding/PosEmb', 'Classifier/Transformer/~/loop_0/TransformerBlock0/~/layer_norm', 'Classifier/Transformer/~/loop_0/TransformerBlock0/~/linear', 'Classifier/Transformer/~/loop_0/TransformerBlock0/~/linear_1', 'Classifier/Transformer/~/loop_1/TransformerBlock0/~/layer_norm', 'Classifier/Transformer/~/loop_1/TransformerBlock0/~/linear', 'Classifier/Transformer/~/loop_1/TransformerBlock0/~/linear_1', 'Classifier/Transformer/~/loop_2/TransformerBlock0/~/layer_norm', 'Classifier/Transformer/~/loop_2/TransformerBlock0/~/linear', 'Classifier/Transformer/~/loop_2/TransformerBlock0/~/linear_1', 'Classifier/Transformer/~/loop_3/TransformerBlock0/~/layer_norm', 'Classifier/Transformer/~/loop_3/TransformerBlock0/~/linear', 'Classifier/Transformer/~/loop_3/TransformerBlock0/~/linear_1', 'Classifier/Transformer/~/loop_4/TransformerBlock0/~/layer_norm', 'Classifier/Transformer/~/loop_4/TransformerBlock0/~/linear', 'Classif

In [12]:
def loss(params, loss_scale, rng, batch_token_ids, batch_labels):
    logits = roberta_classifier.apply(params, rng, batch_token_ids, training=True)
    labels = hk.one_hot(batch_labels, config['n_classes'])
    softmax_xent = -jnp.sum(labels * jax.nn.log_softmax(logits))
    softmax_xent /= labels.shape[0]
    softmax_xent = loss_scale.scale(softmax_xent)
    return softmax_xent

@jax.jit
def accuracy(params, rng, batch_token_ids, batch_labels):
    predictions = roberta_classifier.apply(params, rng, batch_token_ids, training=False)
    return jnp.mean(jnp.argmax(predictions, axis=-1) == batch_labels)

def apply_optimizer(params, grads, opt_state):
    updates, opt_state = opt.update(grads, opt_state)
    new_params = optax.apply_updates(params, updates)
    return new_params, opt_state

from jax import debug
from typing import TypeVar

T = TypeVar("T")

def _cast_floating_to(tree: T, dtype: jnp.dtype) -> T:
  def conditional_cast(x):
    if (isinstance(x, (np.ndarray, jnp.ndarray)) and
        jnp.issubdtype(x.dtype, jnp.floating)):
      x = x.astype(dtype)
    return x
  return jax.tree_util.tree_map(conditional_cast, tree)


@jax.jit
def update(params, loss_scale, rng, opt_state, batch_token_ids, batch_labels):
    batch_loss, grads = jax.value_and_grad(loss)(params, loss_scale, rng, batch_token_ids, batch_labels)
    batch_loss = loss_scale.unscale(batch_loss)
    grads = loss_scale.unscale(_cast_floating_to(grads, jnp.float32))
    
    grads_finite = jmp.all_finite(grads)
    
    loss_scale = loss_scale.adjust(grads_finite)
    
    new_params, opt_state = jmp.select_tree(grads_finite, apply_optimizer(params, grads, opt_state), (params, opt_state))
    
    return new_params, opt_state, batch_loss, loss_scale



In [13]:
def make_lr_schedule(warmup_percentage, total_steps):
    def lr_schedule(step):
        percent_complete = step / total_steps
        before_peak = jax.lax.convert_element_type(
            (percent_complete <= warmup_percentage),
            np.float32
        )
        scale = (
            (before_peak * (percent_complete / warmup_percentage) + (1 - before_peak))
            * (1 - percent_complete)
        )
        return scale
    return lr_schedule


total_steps = config['n_epochs'] * (n_examples // config['batch_size'])
print(total_steps)
lr_schedule = make_lr_schedule(warmup_percentage=0.1, total_steps=total_steps)
opt = optax.chain(
    optax.clip_by_global_norm(config['max_grad_norm']),
    optax.adam(learning_rate=config['learning_rate']),
    optax.scale_by_schedule(lr_schedule)
)
opt_state = opt.init(params)


31250


In [14]:
def measure_current_performance(params, n_examples=None, splits=('train', 'test')):
    # Load our training evaluation and test evaluation splits 
    if 'train' in splits:
        train_eval = load_dataset("train", training=False, batch_size=25, n_examples=n_examples)
        # Compute mean train accuracy
        train_accuracy = np.mean([
            accuracy(
                params, 
                rng, 
                encode_batch(train_eval_batch['text']), 
                train_eval_batch['label']
            )
            for train_eval_batch in train_eval
        ])
        print(f"\t Train acc: {train_accuracy:.3f}")
    
    if 'test' in splits:
        test_eval = load_dataset("test", training=False, batch_size=25, n_examples=n_examples)
        # Compute mean test accuracy
        test_accuracy = np.mean([
            accuracy(
                params, 
                rng,
                encode_batch(test_eval_batch['text']), 
                test_eval_batch['label'],
            )
            for test_eval_batch in test_eval
        ])
        print(f"\t Test accuracy: {test_accuracy:.3f}")


In [15]:
loss_scale = jmp.DynamicLossScale(jmp.half_dtype()(2 ** 15), min_loss_scale=np.array(0.2))
# loss_scale = jmp.DynamicLossScale(jmp.half_dtype()(2))
print(loss_scale)

import time
total_batch = jnp.array(0)

for step, next_batch in enumerate(train):
    if step % 100 == 0:
        print(total_batch/ 100)
        total_batch *= 0
        print(f"[Step {step}]")
        
    if step % 1000 == 1:
        start = time.time()
        
    if step % 1000 == 999:
        end = time.time()
        print(end - start)
    
    # print(f"[Step {step}]")
    if step % 1000 == 0 and step != 0:
        print(loss_scale)
        measure_current_performance(params, n_examples=100)
        
    # print(loss_scale)

    # Perform adam update
    batch_token_ids = jnp.array(encode_batch(next_batch['text']))
    # print(batch_token_ids)
    batch_labels = jnp.array(next_batch['label'])
    # print(batch_labels)
    params, opt_state, batch_loss, loss_scale = update(
        params, loss_scale, rng, opt_state, batch_token_ids, batch_labels
    )
    total_batch += batch_loss


DynamicLossScale(loss_scale=DeviceArray(32768., dtype=float16), counter=array(0, dtype=int32), period=2000, factor=2, min_loss_scale=array(0.2))
0.0
[Step 0]
HELP ME (4, 512, 768)
Traced<ShapedArray(float16[4,512,768])>with<DynamicJaxprTrace(level=3/1)> True
(4, 512, 768) (4, 512, 3072)


  leaves = jax.tree_leaves(tree)


0.77914125
[Step 100]
0.7014087
[Step 200]
0.7022821
[Step 300]
0.7062738
[Step 400]
0.7063592
[Step 500]
0.7179193
[Step 600]
0.7021692
[Step 700]
0.71631926
[Step 800]
0.7077234
[Step 900]
42.240750312805176
0.70462644
[Step 1000]
DynamicLossScale(loss_scale=DeviceArray(8192., dtype=float16), counter=DeviceArray(319, dtype=int32), period=2000, factor=2, min_loss_scale=array(1, dtype=int32))
HELP ME (25, 512, 768)
Traced<ShapedArray(float16[25,512,768])>with<DynamicJaxprTrace(level=1/1)> False
(25, 512, 768) (25, 512, 3072)
	 Train acc: 0.510
	 Test accuracy: 0.440
0.720588
[Step 1100]
0.7038684
[Step 1200]
0.72539735
[Step 1300]
0.710957
[Step 1400]
0.70910126
[Step 1500]
0.6912529
[Step 1600]
0.66272885
[Step 1700]
0.6441449
[Step 1800]
0.6667353
[Step 1900]
41.91590690612793
0.6545172
[Step 2000]
DynamicLossScale(loss_scale=DeviceArray(4096., dtype=float16), counter=DeviceArray(43, dtype=int32), period=2000, factor=2, min_loss_scale=array(1, dtype=int32))
	 Train acc: 0.590
	 Test 

In [16]:
measure_current_performance(params, n_examples=25000, splits='test')

	 Test accuracy: 0.861


In [17]:
blah = grads['Classifier/Transformer/Embedding/embed']
foo = params['Classifier/Transformer/Embedding/embed']
print(foo)
print(blah)

NameError: name 'grads' is not defined