In [1]:
from tensorflow.compat.v1.io.gfile import GFile
import gin
import os
import jax
import trax
from trax.supervised import inputs

import numpy as onp
import jax.numpy as np

from scipy.special import softmax

from sentencepiece import SentencePieceProcessor

# Setting up data and model

In this notebook, we'll be pushing the limits of just how many tokens we can fit on a single TPU device. The TPUs available in Colab have 8GB of memory per core, and 8 cores. We will set up a Reformer model that can fit a copy of "Crime and Punishment" on each of the 8 TPU cores (over 500,000 tokens per 8GB of memory).

In [2]:
with open('/home/sn/PhD_Dec_2019_Onwards/Experiments/NeuralLM/datas/bodo_raw/two/test.brx.txt','r', encoding='utf-8') as f:
    text=f.read()
    
#text=text[1:3013812]

In [3]:
# Load a BPE vocabulaary with 320 types. This mostly consists of single letters
# and pairs of letters, but it has some common words and word pieces, too.
#!gsutil cp gs://trax-ml/reformer/cp.320.* .

TOKENIZER = SentencePieceProcessor()
TOKENIZER.load('/home/sn/PhD_Dec_2019_Onwards/Experiments/NeuralLM/datas/pretrained_bodo/google-sentpiece/bodo-8k-sp-bpe.model')

True

In [4]:
# Tokenize
IDS = TOKENIZER.EncodeAsIds(text)


In [5]:
IDS = onp.asarray(IDS, dtype=onp.int32)


In [6]:
PAD_AMOUNT = 128 * 1024 - len(IDS)
print("Number of tokens:", IDS.shape[0])

Number of tokens: 114129


In [7]:
print(PAD_AMOUNT,len(IDS))

16943 114129


## As we see above, "Crime and Punishment" has just over half a million tokens with the BPE vocabulary we have selected.
Normally we would have a dataset with many examples, but for this demonstration we fit a language model on the single novel only. We don't want the model to just memorize the dataset by encoding the words in its position embeddings, so at each training iteration we will randomly select how much padding to put before the text vs. after it.
We have 8 TPU cores, so we will separately randomize the amount of padding for each core.

In [8]:
# Set up the data pipeline.
def my_inputs(n_devices):
  while True:
    inputs = []
    mask = []
    pad_amounts = onp.random.choice(PAD_AMOUNT, n_devices)
    for i in range(n_devices):
      inputs.append(onp.pad(IDS, (pad_amounts[i], PAD_AMOUNT - pad_amounts[i]),
                            mode='constant'))
      mask.append(onp.pad(onp.ones_like(IDS, dtype=onp.float32),
                          (pad_amounts[i], PAD_AMOUNT - pad_amounts[i]),
                          mode='constant'))
    inputs = onp.stack(inputs)
    mask = onp.stack(mask)
    yield (inputs, inputs, mask)

print("(device count, tokens per device) = ",
      next(my_inputs(trax.math.device_count()))[0].shape)

(device count, tokens per device) =  (4, 131072)


In [9]:
# Configure hyperparameters.
gin.parse_config("""
import trax.layers
import trax.models
import trax.optimizers
import trax.supervised.inputs
import trax.supervised.trainer_lib

# Parameters that will vary between experiments:
# ==============================================================================
train.model = @trax.models.ReformerLM
# Our model will have 6 layers, alternating between the LSH attention proposed
# in the Reformer paper and local attention within a certain context window.
n_layers = 6
attn_type = [
  @TimeBinCausalAttention,
  @LSHCausalAttention,  
  @TimeBinCausalAttention,
  @LSHCausalAttention,
  @TimeBinCausalAttention,
  @LSHCausalAttention,
  ]
share_qk = False  # LSHCausalAttention ignores this flag and always shares q & k
n_heads = 2
attn_kv = 64
dropout = 0.05
n_tokens = 131072

# Parameters for MultifactorSchedule:
# ==============================================================================
MultifactorSchedule.constant = 0.01
MultifactorSchedule.factors = 'constant * linear_warmup * cosine_decay'
MultifactorSchedule.warmup_steps = 100
MultifactorSchedule.steps_per_cycle = 900

# Parameters for Adam:
# ==============================================================================
Adam.weight_decay_rate=0.0
Adam.b1 = 0.86
Adam.b2 = 0.92
Adam.eps = 1e-9

# Parameters for TimeBinCausalAttention:
# ==============================================================================
TimeBinCausalAttention.bin_length = 64
TimeBinCausalAttention.dropout = 0.05
TimeBinCausalAttention.n_bins = None
TimeBinCausalAttention.share_qk = %share_qk

# Parameters for LSHCausalAttention:
# ==============================================================================
LSHCausalAttention.allow_duplicate_attention = False
LSHCausalAttention.attend_across_buckets = True
LSHCausalAttention.rehash_each_round = True
LSHCausalAttention.data_rotation = False
LSHCausalAttention.n_bins = 4096
LSHCausalAttention.n_buckets = 8192
LSHCausalAttention.factorize_hash = [64, 128]
LSHCausalAttention.n_hashes = 1
LSHCausalAttention.one_rng = False
LSHCausalAttention.hard_k = 0
LSHCausalAttention.dropout = 0.0
LSHCausalAttention.drop_for_hash_rate = 0.0
LSHCausalAttention.max_len_for_inference = 2048
LSHCausalAttention.bucket_capacity_for_inference = 64

# Parameters for ReformerLM:
# ==============================================================================
ReformerLM.attention_type = %attn_type
ReformerLM.d_attention_key = %attn_kv
ReformerLM.d_attention_value = %attn_kv
ReformerLM.d_model = 256
ReformerLM.d_ff = 128
ReformerLM.dropout = %dropout
ReformerLM.ff_activation = @trax.layers.Relu
ReformerLM.max_len = %n_tokens
ReformerLM.mode = 'train'
ReformerLM.n_heads = %n_heads
ReformerLM.n_layers = %n_layers
ReformerLM.vocab_size = 8000
ReformerLM.share_qk = %share_qk
ReformerLM.axial_pos_shape = (128, 1024)
ReformerLM.d_axial_pos_embs= (64, 192)
""")

In [10]:
# Set up a Trainer.
output_dir = os.path.expanduser('~/train_dir/2/')
!rm -f ~/train_dir/model.pkl  # Remove old model
trainer = trax.supervised.Trainer(
    model=trax.models.ReformerLM,
    loss_fn=trax.layers.CrossEntropyLoss,
    optimizer=trax.optimizers.Adam,
    lr_schedule=trax.lr.MultifactorSchedule,
    inputs=trax.supervised.inputs.Inputs(my_inputs),
    output_dir=output_dir,
    has_weights=True)

Model loaded from /home/sn/train_dir/2/model.pkl at step 1


In [11]:
# Run one training step, to make sure the model fits in memory.
# The first time trainer.train_epoch is called, it will JIT the entire network
# architecture, which takes around 2 minutes. The JIT-compiled model is saved
# so subsequent runs will be much faster than the first.
trainer.train_epoch(n_steps=1, n_eval_steps=1)


Step      2: Ran 1 train steps in 48.17 secs
Step      2: Evaluation
Step      2: train                   accuracy |  0.00010733
Step      2: train                       loss |  9.01715088
Step      2: train         neg_log_perplexity |  9.01715279
Step      2: train weights_per_batch_per_core |  114129.00000000
Step      2: eval                    accuracy |  0.00014676
Step      2: eval                        loss |  9.01628017
Step      2: eval          neg_log_perplexity |  9.01628017
Step      2: eval  weights_per_batch_per_core |  114129.00000000
Step      2: Finished evaluation


In [12]:
# Train for 600 steps total
# The first ~20 steps are slow to run, but after that it reaches steady-state
# speed. This will take at least 30 minutes to run to completion, but can safely
# be interrupted by selecting "Runtime > Interrupt Execution" from the menu.
# The language model won't be exceptionally good when trained for just a few
# steps and with minimal regularization. However, we can still sample from it to
# see what it learns.
trainer.train_epoch(n_steps=9, n_eval_steps=1)
for _ in range(59):
  trainer.train_epoch(n_steps=10, n_eval_steps=1)


Step     11: Ran 9 train steps in 42.75 secs
Step     11: Evaluation
Step     11: train                   accuracy |  0.05587537
Step     11: train                       loss |  7.44203568
Step     11: train         neg_log_perplexity |  7.44203615
Step     11: train weights_per_batch_per_core |  114129.00000000
Step     11: eval                    accuracy |  0.05587537
Step     11: eval                        loss |  7.44202852
Step     11: eval          neg_log_perplexity |  7.44202805
Step     11: eval  weights_per_batch_per_core |  114129.00000000
Step     11: Finished evaluation

Step     21: Ran 10 train steps in 5.88 secs
Step     21: Evaluation
Step     21: train                   accuracy |  0.07501161
Step     21: train                       loss |  6.94360447
Step     21: train         neg_log_perplexity |  6.94360447
Step     21: train weights_per_batch_per_core |  114129.00000000
Step     21: eval                    accuracy |  0.07511237
Step     21: eval               

Step    141: eval  weights_per_batch_per_core |  114129.00000000
Step    141: Finished evaluation

Step    151: Ran 10 train steps in 5.92 secs
Step    151: Evaluation
Step    151: train                   accuracy |  0.74781162
Step    151: train                       loss |  0.84602916
Step    151: train         neg_log_perplexity |  0.84602916
Step    151: train weights_per_batch_per_core |  114129.00000000
Step    151: eval                    accuracy |  0.74928588
Step    151: eval                        loss |  0.84323359
Step    151: eval          neg_log_perplexity |  0.84323353
Step    151: eval  weights_per_batch_per_core |  114129.00000000
Step    151: Finished evaluation

Step    161: Ran 10 train steps in 5.91 secs
Step    161: Evaluation
Step    161: train                   accuracy |  0.79256809
Step    161: train                       loss |  0.67137605
Step    161: train         neg_log_perplexity |  0.67137605
Step    161: train weights_per_batch_per_core |  114129.000

Step    281: eval          neg_log_perplexity |  0.13578185
Step    281: eval  weights_per_batch_per_core |  114129.00000000
Step    281: Finished evaluation

Step    291: Ran 10 train steps in 5.92 secs
Step    291: Evaluation
Step    291: train                   accuracy |  0.94935340
Step    291: train                       loss |  0.12381007
Step    291: train         neg_log_perplexity |  0.12381005
Step    291: train weights_per_batch_per_core |  114129.00000000
Step    291: eval                    accuracy |  0.94912773
Step    291: eval                        loss |  0.12447064
Step    291: eval          neg_log_perplexity |  0.12447064
Step    291: eval  weights_per_batch_per_core |  114129.00000000
Step    291: Finished evaluation

Step    301: Ran 10 train steps in 5.92 secs
Step    301: Evaluation
Step    301: train                   accuracy |  0.95251644
Step    301: train                       loss |  0.11484772
Step    301: train         neg_log_perplexity |  0.11484773

Step    421: eval                        loss |  0.05848213
Step    421: eval          neg_log_perplexity |  0.05848213
Step    421: eval  weights_per_batch_per_core |  114129.00000000
Step    421: Finished evaluation

Step    431: Ran 10 train steps in 5.94 secs
Step    431: Evaluation
Step    431: train                   accuracy |  0.97627461
Step    431: train                       loss |  0.05556519
Step    431: train         neg_log_perplexity |  0.05556519
Step    431: train weights_per_batch_per_core |  114129.00000000
Step    431: eval                    accuracy |  0.97620451
Step    431: eval                        loss |  0.05535147
Step    431: eval          neg_log_perplexity |  0.05535147
Step    431: eval  weights_per_batch_per_core |  114129.00000000
Step    431: Finished evaluation

Step    441: Ran 10 train steps in 5.96 secs
Step    441: Evaluation
Step    441: train                   accuracy |  0.97735238
Step    441: train                       loss |  0.05214088

Step    561: eval                    accuracy |  0.98749000
Step    561: eval                        loss |  0.02909270
Step    561: eval          neg_log_perplexity |  0.02909270
Step    561: eval  weights_per_batch_per_core |  114129.00000000
Step    561: Finished evaluation

Step    571: Ran 10 train steps in 5.94 secs
Step    571: Evaluation
Step    571: train                   accuracy |  0.98787773
Step    571: train                       loss |  0.02842221
Step    571: train         neg_log_perplexity |  0.02842220
Step    571: train weights_per_batch_per_core |  114129.00000000
Step    571: eval                    accuracy |  0.98789525
Step    571: eval                        loss |  0.02837910
Step    571: eval          neg_log_perplexity |  0.02837910
Step    571: eval  weights_per_batch_per_core |  114129.00000000
Step    571: Finished evaluation

Step    581: Ran 10 train steps in 5.96 secs
Step    581: Evaluation
Step    581: train                   accuracy |  0.98845601

# Sample from the model

In [21]:
# As we report in the Reformer paper, increasing the number of hashing rounds
# helps with quality. We can even increase the number of hashing rounds at
# evaluation time only.
gin.parse_config("""LSHCausalAttention.n_hashes = 4""")
model_infer = trax.models.ReformerLM(mode='predict')

In [22]:
# Prepare a jitted copy of the model.
jit_model_infer = trax.layers.base._accelerate(
    model_infer._forward_internal, trax.math.device_count())
# Set up the initial state for sampling.
infer_state = model_infer.new_weights_and_state(
    trax.supervised.trainer_lib.ShapeDtype((1,1), dtype=np.int32))[1]
infer_state = trainer._for_n_devices(infer_state)

In [23]:
def sample(length=2048, prompt=None):
  """Sample from the ReformerLM model"""
  model_weights = trainer._opt_state[0][0]

  # Token id 0 is the equivalent of a "start" token
  cur_inputs = np.zeros((trax.math.device_count(), 1, 1), dtype=np.int32)

  cur_state = infer_state
  rngs = trax.math.random.split(trax.math.random.get_prng(0), trax.math.device_count())
  all_samples = []

  if prompt is not None:
    prompt = np.asarray(
        [TOKENIZER.EncodeAsIds(prompt)] * trax.math.device_count())

  for iteration in range(length):
    logits, cur_state = jit_model_infer(
        cur_inputs,
        model_weights,
        cur_state,
        rngs)
    
    if prompt is not None and iteration < prompt.shape[1]:
      cur_samples = onp.array(prompt[:, iteration], dtype=int)
    else:
      logits = onp.array(logits)[:,0,0,:]
      probs = onp.exp(logits)
      cur_samples = [onp.random.choice(probs.shape[-1], p=probs[i,:])
                     for i in range(probs.shape[0])]
      cur_samples = onp.array(cur_samples, dtype=int)
    all_samples.append(cur_samples)

    cur_inputs = np.array(cur_samples[:,None,None])
  all_samples = onp.stack(all_samples, -1)
  
  return all_samples

In [24]:
# Sample from the Reformer language model, given a prefix.
samples = sample(length=200, prompt="मिथिंगा")
for ids in samples:
  print(TOKENIZER.DecodeIds(ids.tolist()))
  print('\n')

मिथिंगाखौ बे, जागासिनो दं।गोथैसालियाव थांनानै दं। गलिया। गभनि सलʼनि सलʼनि खन्थाइनि सलʼनि सलʼनि सलʼनि सलʼनि सलʼनि सलʼनि सलʼनि सलʼनि सलʼनि सलʼनि सलʼनि सलʼनि सलʼनि सलʼनि सलʼनि सलʼनि सलʼनि सलʼनि सलʼनि सलʼनि सलʼनि सलʼनि सलʼनि सलʼनि सलʼनि सलʼनि सलʼनि सलʼनो बेसेबा गोबाव खालामफायो। आरʼनि बिलाइयाव मोन्नायखाय कल्पʼनि बादʼवाव आगान होलाङो सुजुग्राया- "आ हामखांब्लानो सिनायनो हायो। नों बबेवेटारवा, मुथाङै- "आ हांमाʼ गोर्लैबादि आं खामानिखौ गारनो हाया। सुथि मोनखाङो। नाथाय बेखौ सलʼ लेवा गावनि मासियाव जिरायनानै सानो। सुथि मोनाबा बे बरʼवा गावनि मासियाव जिरायनानै सानो। अख्राङाव दोनफिनगासिनो दं।" "नोंथाङा बायहेरायाव ओंखार


मिथिंगाखौ बे समाव अजितनो खिन्थाहैनाय खौरांखौ खिन्थानो मोननाय मुस्रीमा ? शान्तिया। गुमुर गोयै बिफाखौ लानानै इम्पʼनि उनावनोार्सखौ खेवनायनि सिमां ...नों नʼदों। बिफानि आथिङाव गसंना सारजों मोजाङै नोजोर होखानाय मैहुरारी नङाब्ला जाथाया बरʼवारीया सासे नार्सिंब्लाबो गोरोन्थिफोर थांलाय फैलाय खामानि मावग्रा आइजोआ बोहैथिखौ नुखायो नामा ?" "औ।" बुंना हां बोयो। "सारबोदों आं मोजाङै नोजोर होगासिनो दंमोन।