## Set Up

In [None]:
!pip install git+https://github.com/google-research/bigbird.git -q

[K     |████████████████████████████████| 1.2MB 13.4MB/s 
[K     |████████████████████████████████| 3.4MB 53.7MB/s 
[K     |████████████████████████████████| 1.5MB 54.2MB/s 
[K     |████████████████████████████████| 3.8MB 54.5MB/s 
[K     |████████████████████████████████| 706kB 47.9MB/s 
[K     |████████████████████████████████| 5.6MB 47.9MB/s 
[K     |████████████████████████████████| 368kB 53.1MB/s 
[K     |████████████████████████████████| 378kB 53.4MB/s 
[K     |████████████████████████████████| 194kB 53.3MB/s 
[K     |████████████████████████████████| 368kB 50.7MB/s 
[K     |████████████████████████████████| 358kB 51.2MB/s 
[K     |████████████████████████████████| 983kB 48.1MB/s 
[K     |████████████████████████████████| 655kB 47.3MB/s 
[K     |████████████████████████████████| 256kB 58.4MB/s 
[?25h  Building wheel for bigbird (setup.py) ... [?25l[?25hdone
  Building wheel for gunicorn (setup.py) ... [?25l[?25hdone
  Building wheel for bz2file (setup.py) ... 

In [None]:
from bigbird.core import flags
from bigbird.core import modeling
from bigbird.core import utils
from bigbird.summarization import run_summarization
import tensorflow.compat.v2 as tf
import tensorflow_datasets as tfds
import tensorflow_text as tft
from tqdm import tqdm
import sys

FLAGS = flags.FLAGS
# Notebook kernels are typically launched with a `-f <connection file>`
# argument in sys.argv; declaring a dummy "f" flag keeps the flag parser from
# rejecting it as unknown.
# A membership test is used instead of hasattr(): assuming FLAGS is an absl
# FlagValues, attribute access on a flag that is defined but not yet parsed
# raises UnparsedFlagAccessError (not AttributeError), which hasattr() would
# let propagate when this cell is re-run after a failed parse.
if "f" not in FLAGS:
  flags.DEFINE_string("f", "", "")
# Parse the command line so flag values can be read and overridden below.
FLAGS(sys.argv)

tf.enable_v2_behavior()

## Set options

In [None]:
# Flag values for this run, applied to FLAGS in the order listed.
flag_overrides = {
    "data_dir": "tfds://scientific_papers/pubmed",
    "attention_type": "block_sparse",
    "couple_encoder_decoder": True,
    # Kept at 2048 because free Colab only offers lower-memory GPUs such as the T4.
    "max_encoder_length": 2048,
    "max_decoder_length": 256,
    "block_size": 64,
    "learning_rate": 1e-5,
    "num_train_steps": 1000,
    "attention_probs_dropout_prob": 0.0,
    "hidden_dropout_prob": 0.0,
    "vocab_model_file": "gpt2",
}
for flag_name, flag_value in flag_overrides.items():
  # Equivalent to `FLAGS.<flag_name> = flag_value`.
  setattr(FLAGS, flag_name, flag_value)

In [None]:
# Collect the flag values into `transformer_config`; it is passed to the model
# constructor below and indexed for "label_smoothing" and "vocab_size" when
# computing the loss.
transformer_config = flags.as_dictionary()

## Define summarization model

In [None]:
# NOTE: `tensorflow.python.*` is TensorFlow's internal namespace rather than
# public API, so this import may break between TensorFlow versions.
from tensorflow.python.ops.variable_scope import EagerVariableStore
# Variable store entered via `container.as_default()` when the model is built
# and when it is first called.
container = EagerVariableStore()

In [None]:
# Build the model inside the variable store's scope.
# NOTE(review): presumably so that variables created by the model are tracked
# by `container` - confirm against the bigbird implementation.
with container.as_default():
  model = modeling.TransformerModel(transformer_config)

In [None]:
@tf.function(experimental_compile=True)
def fwd_bwd(features, labels):
  """Runs one training-mode forward pass and backpropagates the loss.

  Args:
    features: Passed to `model` as its positional input.
    labels: Passed to `model` as `target_ids` and used as the loss targets.

  Returns:
    A tuple `(loss, llh, logits, pred_ids, grads)`: the padded cross-entropy
    loss, the three model outputs, and the gradients of the loss with respect
    to `model.trainable_weights`.
  """
  with tf.GradientTape() as tape:
    # The second element of the model's return value is not needed here.
    model_outputs, _ = model(features, target_ids=labels, training=True)
    llh, logits, pred_ids = model_outputs
    loss = run_summarization.padded_cross_entropy_loss(
        logits,
        labels,
        transformer_config["label_smoothing"],
        transformer_config["vocab_size"])
  weights = model.trainable_weights
  grads = tape.gradient(loss, weights)
  return loss, llh, logits, pred_ids, grads

## Dataset pipeline

In [None]:
# Arguments for the training input pipeline, taken from the flags set above.
train_input_kwargs = dict(
    data_dir=FLAGS.data_dir,
    vocab_model_file=FLAGS.vocab_model_file,
    max_encoder_length=FLAGS.max_encoder_length,
    max_decoder_length=FLAGS.max_decoder_length,
    substitute_newline=FLAGS.substitute_newline,
    is_training=True,
)
train_input_fn = run_summarization.input_fn_builder(**train_input_kwargs)
# The input function receives its parameters as a dict; batches of 2 are used.
dataset = train_input_fn({'batch_size': 2})

[1mDownloading and preparing dataset 4.20 GiB (download: 4.20 GiB, generated: 2.34 GiB, total: 6.54 GiB) to /root/tensorflow_datasets/scientific_papers/pubmed/1.1.1...[0m


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Dl Completed...', max=1.0, style=Progre…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Dl Size...', max=1.0, style=ProgressSty…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Extraction completed...', max=1.0, styl…






HBox(children=(FloatProgress(value=0.0, description='Generating splits...', max=3.0, style=ProgressStyle(descr…

HBox(children=(FloatProgress(value=0.0, description='Generating train examples...', max=119924.0, style=Progre…

HBox(children=(FloatProgress(value=0.0, description='Shuffling scientific_papers-train.tfrecord...', max=11992…

HBox(children=(FloatProgress(value=0.0, description='Generating validation examples...', max=6633.0, style=Pro…

HBox(children=(FloatProgress(value=0.0, description='Shuffling scientific_papers-validation.tfrecord...', max=…

HBox(children=(FloatProgress(value=0.0, description='Generating test examples...', max=6658.0, style=ProgressS…

HBox(children=(FloatProgress(value=0.0, description='Shuffling scientific_papers-test.tfrecord...', max=6658.0…

[1mDataset scientific_papers downloaded and prepared to /root/tensorflow_datasets/scientific_papers/pubmed/1.1.1. Subsequent calls will reuse this data.[0m


  deterministic=is_training)


In [None]:
# Inspect a few examples.
# NOTE: the loop variable `ex` stays bound to the last batch printed here and
# is reused by the "Check outputs" cell that follows.
for ex in dataset.take(3):
  print(ex)

(<tf.Tensor: shape=(2, 2048), dtype=int32, numpy=
array([[  382, 18450,   387, ...,  1503,  1372,   387],
       [  363,   321,   100, ...,   358,   936,   427]], dtype=int32)>, <tf.Tensor: shape=(2, 256), dtype=int32, numpy=
array([[10913, 21105, 31760,   583,  1228,   490,  3278,   456,  5972,
          430,  1793, 13926,   865,   388,   529,  2151,   321,   100,
          457, 15576,   363,  5509,   387, 49438,   458,   401,   639,
         1368,   391,   321,  9290,  7602, 33085, 43802,   458,   321,
         9290,  7602,   181,  1368, 10913,   388,  8796,   385,  2320,
          743,   113,  2805, 15331,   494,   743,   185,  2314, 17771,
         1159,   633,   118, 50035,   171,  1976,   358,   458,   401,
          639,  1368,   938,   779,   165,  1976,   409,   458,   401,
          639,  1368,   391,   633,  1258, 26615,   167,  1976,   409,
          458,   321,  9290,  7602,   181,  1368,   388, 25346, 20880,
         1852,   391, 25548,   865,   321,   100,  3686, 27056, 

## Check outputs

In [None]:
# Smoke test: run one forward/backward pass on `ex`, the last batch from the
# inspection loop. The call is made inside `container.as_default()`, the same
# scope the model was built in.
with container.as_default():
  loss, llh, logits, pred_ids, grads = fwd_bwd(ex[0], ex[1])
print('Loss: ', loss)



Loss:  tf.Tensor(9.549082, shape=(), dtype=float32)


## (Optionally) Load pretrained model

In [None]:
# For training from scratch use
# ckpt_path = 'gs://bigbird-transformer/pretrain/bigbr_base/model.ckpt-0'
# For quick check continue from trained checkpoint
ckpt_path = 'gs://bigbird-transformer/summarization/pubmed/roberta/model.ckpt-300000'
ckpt_reader = tf.compat.v1.train.NewCheckpointReader(ckpt_path)
loaded_weights = []
missing_in_ckpt = []
for v in tqdm(model.trainable_weights, position=0):
  # Checkpoint key: the variable name without its last two characters
  # (presumably the ":0" suffix).
  ckpt_key = v.name[:-2]
  # Check for the key explicitly instead of wrapping the read in a bare
  # `except:`, so genuine failures (I/O errors, interrupts) are not silently
  # turned into "keep the current value".
  if ckpt_reader.has_tensor(ckpt_key):
    val = ckpt_reader.get_tensor(ckpt_key)
  else:
    # Not stored in the checkpoint: keep the variable's current value.
    missing_in_ckpt.append(ckpt_key)
    val = v.numpy()
  loaded_weights.append(val)

if missing_in_ckpt:
  print('{} weight(s) not found in checkpoint, current values kept: {}'.format(
      len(missing_in_ckpt), missing_in_ckpt))

model.set_weights(loaded_weights)

100%|██████████| 316/316 [01:20<00:00,  3.93it/s]


## Train

In [None]:
opt = tf.keras.optimizers.Adam(FLAGS.learning_rate)
train_loss = tf.keras.metrics.Mean(name='train_loss')

# Number of steps between printouts of the running mean loss.
log_every = 10
train_batches = dataset.take(FLAGS.num_train_steps)
# The loop variable is named `ex` on purpose: the prediction cell further down
# reads the last training batch through it.
for step, ex in enumerate(tqdm(train_batches, position=0)):
  loss, llh, logits, pred_ids, grads = fwd_bwd(ex[0], ex[1])
  grads_and_vars = zip(grads, model.trainable_weights)
  opt.apply_gradients(grads_and_vars)
  train_loss(loss)
  if step % log_every != 0:
    continue
  # `train_loss` is a Mean metric, so this is the average over all steps so far.
  print('Loss = {} '.format(train_loss.result().numpy()))

  0%|          | 1/1000 [01:50<30:42:31, 110.66s/it]

Loss = 1.363891363143921 


  1%|          | 11/1000 [02:02<1:10:27,  4.27s/it]

Loss = 2.0001139640808105 


  2%|▏         | 21/1000 [02:14<20:59,  1.29s/it]

Loss = 1.827647089958191 


  3%|▎         | 31/1000 [02:26<19:25,  1.20s/it]

Loss = 1.6817251443862915 


  4%|▍         | 41/1000 [02:38<19:12,  1.20s/it]

Loss = 1.58353590965271 


  5%|▌         | 51/1000 [02:50<19:24,  1.23s/it]

Loss = 1.5402835607528687 


  6%|▌         | 61/1000 [03:03<19:22,  1.24s/it]

Loss = 1.493562936782837 


  7%|▋         | 71/1000 [03:15<19:28,  1.26s/it]

Loss = 1.436427116394043 


  8%|▊         | 81/1000 [03:27<19:12,  1.25s/it]

Loss = 1.4210580587387085 


  9%|▉         | 91/1000 [03:40<19:07,  1.26s/it]

Loss = 1.3970165252685547 


 10%|█         | 101/1000 [03:53<19:05,  1.27s/it]

Loss = 1.3981235027313232 


 11%|█         | 111/1000 [04:05<18:55,  1.28s/it]

Loss = 1.3892427682876587 


 12%|█▏        | 121/1000 [04:18<19:00,  1.30s/it]

Loss = 1.3859059810638428 


 13%|█▎        | 131/1000 [04:31<18:50,  1.30s/it]

Loss = 1.3729757070541382 


 14%|█▍        | 141/1000 [04:44<18:17,  1.28s/it]

Loss = 1.3730189800262451 


 15%|█▌        | 151/1000 [04:57<18:24,  1.30s/it]

Loss = 1.3701679706573486 


 16%|█▌        | 161/1000 [05:10<18:18,  1.31s/it]

Loss = 1.3543964624404907 


 17%|█▋        | 171/1000 [05:23<18:09,  1.31s/it]

Loss = 1.342591643333435 


 18%|█▊        | 181/1000 [05:36<17:58,  1.32s/it]

Loss = 1.3314224481582642 


 19%|█▉        | 191/1000 [05:49<17:31,  1.30s/it]

Loss = 1.3309468030929565 


 20%|██        | 201/1000 [06:02<17:38,  1.33s/it]

Loss = 1.3296253681182861 


 21%|██        | 211/1000 [06:15<17:27,  1.33s/it]

Loss = 1.327627182006836 


 22%|██▏       | 221/1000 [06:29<17:08,  1.32s/it]

Loss = 1.3304057121276855 


 23%|██▎       | 231/1000 [06:42<16:51,  1.32s/it]

Loss = 1.3215280771255493 


 24%|██▍       | 241/1000 [06:55<16:45,  1.33s/it]

Loss = 1.3212459087371826 


 25%|██▌       | 251/1000 [07:09<16:45,  1.34s/it]

Loss = 1.323975920677185 


 26%|██▌       | 261/1000 [07:22<16:21,  1.33s/it]

Loss = 1.3262150287628174 


 27%|██▋       | 271/1000 [07:35<16:21,  1.35s/it]

Loss = 1.3242645263671875 


 28%|██▊       | 281/1000 [07:49<16:13,  1.35s/it]

Loss = 1.3198277950286865 


 29%|██▉       | 291/1000 [08:02<15:40,  1.33s/it]

Loss = 1.3146157264709473 


 30%|███       | 301/1000 [08:16<15:47,  1.36s/it]

Loss = 1.313524603843689 


 31%|███       | 311/1000 [08:29<15:21,  1.34s/it]

Loss = 1.31621515750885 


 32%|███▏      | 321/1000 [08:43<15:22,  1.36s/it]

Loss = 1.3117657899856567 


 33%|███▎      | 331/1000 [08:56<15:08,  1.36s/it]

Loss = 1.3117302656173706 


 34%|███▍      | 341/1000 [09:10<14:51,  1.35s/it]

Loss = 1.3127896785736084 


 35%|███▌      | 351/1000 [09:23<14:34,  1.35s/it]

Loss = 1.3157973289489746 


 36%|███▌      | 361/1000 [09:37<14:18,  1.34s/it]

Loss = 1.3148447275161743 


 37%|███▋      | 371/1000 [09:50<14:09,  1.35s/it]

Loss = 1.3141580820083618 


 38%|███▊      | 381/1000 [10:04<14:02,  1.36s/it]

Loss = 1.3128198385238647 


 39%|███▉      | 391/1000 [10:17<13:44,  1.35s/it]

Loss = 1.3187257051467896 


 40%|████      | 401/1000 [10:31<13:28,  1.35s/it]

Loss = 1.3182646036148071 


 41%|████      | 411/1000 [10:44<13:06,  1.34s/it]

Loss = 1.3199938535690308 


 42%|████▏     | 421/1000 [10:58<13:07,  1.36s/it]

Loss = 1.320509910583496 


 43%|████▎     | 431/1000 [11:11<12:46,  1.35s/it]

Loss = 1.3133751153945923 


 44%|████▍     | 441/1000 [11:25<12:31,  1.34s/it]

Loss = 1.312691569328308 


 45%|████▌     | 451/1000 [11:38<12:28,  1.36s/it]

Loss = 1.3118925094604492 


 46%|████▌     | 461/1000 [11:52<12:06,  1.35s/it]

Loss = 1.310736060142517 


 47%|████▋     | 471/1000 [12:05<11:59,  1.36s/it]

Loss = 1.3095834255218506 


 48%|████▊     | 481/1000 [12:19<11:47,  1.36s/it]

Loss = 1.3103281259536743 


 49%|████▉     | 491/1000 [12:33<11:32,  1.36s/it]

Loss = 1.3085415363311768 


 50%|█████     | 501/1000 [12:46<11:15,  1.35s/it]

Loss = 1.3101277351379395 


 51%|█████     | 511/1000 [12:59<10:58,  1.35s/it]

Loss = 1.3103680610656738 


 52%|█████▏    | 521/1000 [13:13<10:51,  1.36s/it]

Loss = 1.316702961921692 


 53%|█████▎    | 531/1000 [13:26<10:37,  1.36s/it]

Loss = 1.3169350624084473 


 54%|█████▍    | 541/1000 [13:40<10:16,  1.34s/it]

Loss = 1.316780924797058 


 55%|█████▌    | 551/1000 [13:54<10:12,  1.36s/it]

Loss = 1.3159213066101074 


 56%|█████▌    | 561/1000 [14:07<09:55,  1.36s/it]

Loss = 1.3140277862548828 


 57%|█████▋    | 571/1000 [14:21<09:38,  1.35s/it]

Loss = 1.3107638359069824 


 58%|█████▊    | 581/1000 [14:34<09:28,  1.36s/it]

Loss = 1.311909556388855 


 59%|█████▉    | 591/1000 [14:48<09:17,  1.36s/it]

Loss = 1.3127955198287964 


 60%|██████    | 601/1000 [15:01<09:04,  1.36s/it]

Loss = 1.3105103969573975 


 61%|██████    | 611/1000 [15:15<08:46,  1.35s/it]

Loss = 1.3099581003189087 


 62%|██████▏   | 621/1000 [15:28<08:35,  1.36s/it]

Loss = 1.3101460933685303 


 63%|██████▎   | 631/1000 [15:42<08:20,  1.36s/it]

Loss = 1.3088980913162231 


 64%|██████▍   | 641/1000 [15:56<08:13,  1.37s/it]

Loss = 1.308201789855957 


 65%|██████▌   | 651/1000 [16:09<07:51,  1.35s/it]

Loss = 1.309259057044983 


 66%|██████▌   | 661/1000 [16:23<07:36,  1.35s/it]

Loss = 1.3082363605499268 


 67%|██████▋   | 671/1000 [16:36<07:30,  1.37s/it]

Loss = 1.3094345331192017 


 68%|██████▊   | 681/1000 [16:50<07:11,  1.35s/it]

Loss = 1.3100460767745972 


 69%|██████▉   | 691/1000 [17:03<06:51,  1.33s/it]

Loss = 1.3075838088989258 


 70%|███████   | 701/1000 [17:17<06:41,  1.34s/it]

Loss = 1.3072311878204346 


 71%|███████   | 711/1000 [17:30<06:35,  1.37s/it]

Loss = 1.306130051612854 


 72%|███████▏  | 721/1000 [17:44<06:12,  1.34s/it]

Loss = 1.3046369552612305 


 73%|███████▎  | 731/1000 [17:57<06:04,  1.36s/it]

Loss = 1.3031822443008423 


 74%|███████▍  | 741/1000 [18:11<05:50,  1.35s/it]

Loss = 1.3019015789031982 


 75%|███████▌  | 751/1000 [18:24<05:37,  1.35s/it]

Loss = 1.3020117282867432 


 76%|███████▌  | 761/1000 [18:38<05:19,  1.34s/it]

Loss = 1.300988793373108 


 77%|███████▋  | 771/1000 [18:51<05:06,  1.34s/it]

Loss = 1.2990926504135132 


 78%|███████▊  | 781/1000 [19:05<04:56,  1.35s/it]

Loss = 1.299412727355957 


 79%|███████▉  | 791/1000 [19:18<04:45,  1.37s/it]

Loss = 1.29827082157135 


 80%|████████  | 801/1000 [19:32<04:28,  1.35s/it]

Loss = 1.2957195043563843 


 81%|████████  | 811/1000 [19:45<04:17,  1.36s/it]

Loss = 1.2915773391723633 


 82%|████████▏ | 821/1000 [19:59<04:03,  1.36s/it]

Loss = 1.2912721633911133 


 83%|████████▎ | 831/1000 [20:12<03:49,  1.36s/it]

Loss = 1.2932581901550293 


 84%|████████▍ | 841/1000 [20:26<03:36,  1.36s/it]

Loss = 1.2915959358215332 


 85%|████████▌ | 851/1000 [20:39<03:22,  1.36s/it]

Loss = 1.292357087135315 


 86%|████████▌ | 861/1000 [20:53<03:10,  1.37s/it]

Loss = 1.2923327684402466 


 87%|████████▋ | 871/1000 [21:07<02:51,  1.33s/it]

Loss = 1.291812539100647 


 88%|████████▊ | 881/1000 [21:20<02:41,  1.36s/it]

Loss = 1.2922062873840332 


 89%|████████▉ | 891/1000 [21:34<02:28,  1.36s/it]

Loss = 1.2905346155166626 


 90%|█████████ | 901/1000 [21:47<02:14,  1.36s/it]

Loss = 1.2910197973251343 


 91%|█████████ | 911/1000 [22:01<02:00,  1.36s/it]

Loss = 1.2887943983078003 


 92%|█████████▏| 921/1000 [22:14<01:47,  1.36s/it]

Loss = 1.2903236150741577 


 93%|█████████▎| 931/1000 [22:28<01:33,  1.35s/it]

Loss = 1.2902488708496094 


 94%|█████████▍| 941/1000 [22:41<01:19,  1.35s/it]

Loss = 1.289626955986023 


 95%|█████████▌| 951/1000 [22:55<01:06,  1.36s/it]

Loss = 1.288147211074829 


 96%|█████████▌| 961/1000 [23:09<00:52,  1.36s/it]

Loss = 1.286420226097107 


 97%|█████████▋| 971/1000 [23:22<00:39,  1.35s/it]

Loss = 1.2832293510437012 


 98%|█████████▊| 981/1000 [23:36<00:25,  1.36s/it]

Loss = 1.2840265035629272 


 99%|█████████▉| 991/1000 [23:49<00:12,  1.34s/it]

Loss = 1.283961534500122 


100%|██████████| 1000/1000 [24:01<00:00,  1.44s/it]


### Print predictions

In [None]:
# Read the serialized tokenizer model from `FLAGS.vocab_model_file` and build
# the tokenizer used to detokenize ids below. The file is opened in a `with`
# block so the handle is closed as soon as its contents have been read.
with tf.io.gfile.GFile(FLAGS.vocab_model_file, "rb") as vocab_file:
  tokenizer = tft.SentencepieceTokenizer(model=vocab_file.read())

In [None]:
@tf.function(experimental_compile=True)
def fwd_only(features, labels):
  """Runs the model in inference mode (`training=False`), without gradients.

  Args:
    features: Passed to `model` as its positional input.
    labels: Passed to `model` as `target_ids`.

  Returns:
    A tuple `(llh, logits, pred_ids)` taken from the model output.
  """
  # The second element of the model's return value is not needed here.
  model_outputs, _ = model(features, target_ids=labels, training=False)
  llh, logits, pred_ids = model_outputs
  return llh, logits, pred_ids

In [None]:
# Forward-only pass on `ex`; when the cells are run top to bottom this is the
# last batch seen by the training loop. Only the predicted ids are kept.
_, _, pred_ids = fwd_only(ex[0], ex[1])



Instructions for updating:
back_prop=False is deprecated. Consider using tf.stop_gradient instead.
Instead of:
results = tf.while_loop(c, b, vars, back_prop=False)
Use:
results = tf.nest.map_structure(tf.stop_gradient, tf.while_loop(c, b, vars))


Instructions for updating:
back_prop=False is deprecated. Consider using tf.stop_gradient instead.
Instead of:
results = tf.while_loop(c, b, vars, back_prop=False)
Use:
results = tf.nest.map_structure(tf.stop_gradient, tf.while_loop(c, b, vars))


In [None]:
# Detokenize the input article, the model prediction and the reference summary
# for the batch in `ex`, then print them side by side.
report_template = 'Article:\n {}\n\n Predicted summary:\n {}\n\n Ground truth summary:\n {}\n\n'
article_text = tokenizer.detokenize(ex[0])
predicted_text = tokenizer.detokenize(pred_ids)
reference_text = tokenizer.detokenize(ex[1])
print(report_template.format(article_text, predicted_text, reference_text))

Article:
 [b"the last decade has witnessed significant advances in the surgical management of cervical radiculopathy and myelopathy , to include development of motion - sparing alternatives to traditional anterior cervical diskectomy and fusion ( acdf ) . \xe2\x81\x87 the theoretical benefits of these alternatives , primarily cervical disk arthroplasty ( cda ) , include diminished contiguous level strain , preservation of motion at the affected vertebral segment , and a hypothetical decrease in the development or progression of degenerative disease processes at immediately adjacent levels.1  \xe2\x81\x87  2  \xe2\x81\x87  3 several small randomized controlled trials have suggested that cda may be associated with better neurologic outcomes , fewer revisions , and better overall success when compared with acdf.4  \xe2\x81\x87  5  \xe2\x81\x87  6  \xe2\x81\x87  7  \xe2\x81\x87  8 previous studies have found that cda provides up to an 89% rate of complete preoperative symptom relief , with

## Eval

In [None]:
# Same pipeline builder and flag values as for training, but with
# is_training=False.
eval_input_fn = run_summarization.input_fn_builder(
    data_dir=FLAGS.data_dir,
    vocab_model_file=FLAGS.vocab_model_file,
    max_encoder_length=FLAGS.max_encoder_length,
    max_decoder_length=FLAGS.max_decoder_length,
    substitute_newline=FLAGS.substitute_newline,
    is_training=False,
)
eval_params = {'batch_size': 2}
eval_dataset = eval_input_fn(eval_params)

  deterministic=is_training)


In [None]:
# Running mean of the `llh` value returned by the model over the eval dataset.
eval_llh = tf.keras.metrics.Mean(name='eval_llh')

for ex in tqdm(eval_dataset, position=0):
  eval_outputs = fwd_only(ex[0], ex[1])
  llh, logits, pred_ids = eval_outputs
  eval_llh(llh)
mean_llh = eval_llh.result().numpy()
print('Log Likelihood = {}'.format(mean_llh))