From 4d766e9629f28732df615e1dd4e2d3174f3cf703 Mon Sep 17 00:00:00 2001
From: Shawn Presser
Date: Sat, 14 Dec 2019 02:31:22 -0600
Subject: [PATCH] Minimize 3D reshapes for TPUs

---
 src/model.py | 55 +++++++++++++++++++++++++++++++++++++---------------
 1 file changed, 39 insertions(+), 16 deletions(-)

diff --git a/src/model.py b/src/model.py
index 2efc868cf..319e7615b 100644
--- a/src/model.py
+++ b/src/model.py
@@ -2,6 +2,7 @@
 import tensorflow as tf
 from tensorflow.contrib.training import HParams
 import os
+import math
 
 def default_hparams():
     return HParams(
@@ -67,7 +68,8 @@ def norm(x, scope, *, axis=-1, epsilon=1e-5, hparams=None):
 
 def split_states(x, n):
     """Reshape the last dimension of x into [n, x.shape[-1]/n]."""
-    *start, m = shape_list(x)
+    *start, u, v = shape_list(x)
+    m = u * v
     return tf.reshape(x, start + [n, m//n])
 
 def merge_states(x):
@@ -109,26 +111,36 @@ def attention_mask(nd, ns, *, dtype):
     return tf.cast(m, dtype)
 
 
-def attn(x, scope, n_state, *, past, hparams):
-    assert x.shape.ndims == 3  # Should be [batch, sequence, features]
+def attn(x, scope, n_state, *, past, hparams, batch_size, seq_length):
+    assert x.shape.ndims == 2  # Should be [batch*sequence, features]
     assert n_state % hparams.n_head == 0
+    *start, hidden_size = shape_list(x)
+    num_attention_heads = hparams.n_head
+    assert(hidden_size % num_attention_heads == 0)
+    size_per_head = hidden_size // num_attention_heads
+
     if past is not None:
         assert past.shape.ndims == 5  # Should be [batch, 2, heads, sequence, features], where 2 is [k, v]
 
     def split_heads(x):
         # From [batch, sequence, features] to [batch, heads, sequence, features]
-        return tf.transpose(split_states(x, hparams.n_head), [0, 2, 1, 3])
+        x = tf.reshape(x, [batch_size, seq_length, num_attention_heads, size_per_head])
+        x = split_states(x, hparams.n_head)
+        return tf.transpose(x, [0, 2, 1, 3])
 
     def merge_heads(x):
         # Reverse of split_heads
-        return merge_states(tf.transpose(x, [0, 2, 1, 3]))
+        x = tf.transpose(x, [0, 2, 1, 3])
+        x = merge_states(x)
+        x = tf.reshape(x, [batch_size * seq_length, num_attention_heads * size_per_head])
+        return x
 
     def mask_attn_weights(w):
         # w has shape [batch, heads, dst_sequence, src_sequence], where information flows from src to dst.
         _, _, nd, ns = shape_list(w)
         b = attention_mask(nd, ns, dtype=w.dtype)
         b = tf.reshape(b, [1, 1, nd, ns])
-        w = w*b - tf.cast(65500 if w.dtype != tf.float32 else 1e10, w.dtype)*(1-b)
+        w = w*b - tf.cast(65500 if w.dtype == tf.float16 else 1e10, w.dtype)*(1-b)
         return w
 
     def multihead_attn(q, k, v):
@@ -145,9 +157,8 @@
     dtype = hparams.dtype if hparams else tf.float32
     with tf.variable_scope(scope, dtype=dtype):
         c = conv1d(x, 'c_attn', n_state*3, hparams=hparams)
-        q, k, v = map(split_heads, tf.split(c, 3, axis=2))
-        #present = tf.stack([k, v], axis=1)
-        present = tf.no_op()
+        q, k, v = map(split_heads, tf.split(c, 3, axis=-1))
+        present = tf.stack([k, v], axis=1)
         if past is not None:
             pk, pv = tf.unstack(past, axis=1)
             k = tf.concat([pk, k], axis=-2)
@@ -173,11 +184,11 @@ def dropout(x, pdrop=0.1, train=True):
         x = tf.nn.dropout(x, rate=pdrop)
     return x
 
-def block(x, scope, *, past, hparams):
+def block(x, scope, *, past, hparams, attn, **attn_kws):
     dtype = hparams.dtype if hparams else tf.float32
     with tf.variable_scope(scope, dtype=dtype):
         nx = x.shape[-1].value
-        a, present = attn(norm(x, 'ln_1', hparams=hparams), 'attn', nx, past=past, hparams=hparams)
+        a, present = attn(norm(x, 'ln_1', hparams=hparams), 'attn', nx, past=past, hparams=hparams, **attn_kws)
         x = x + a
         m = mlp(norm(x, 'ln_2', hparams=hparams), 'mlp', nx*4, hparams=hparams)
         x = x + m
@@ -211,16 +222,28 @@ def model(hparams, X, past=None, scope='model', reuse=tf.AUTO_REUSE):
         past_length = 0 if past is None else tf.shape(past)[-2]
         h = tf.gather(wte, X) + tf.gather(wpe, positions_for(X, past_length))
 
+        ## We keep the representation as a 2D tensor to avoid re-shaping it back and
+        ## forth from a 3D tensor to a 2D tensor. Re-shapes are normally free on
+        ## the GPU/CPU but may not be free on the TPU, so we want to minimize them to
+        ## help the optimizer.
+        batch_size, seq_length, hidden_size = shape_list(h)
+        h = tf.reshape(h, [batch_size * seq_length, hidden_size])
+
         # Transformer
-        #presents = []
+        presents = []
         pasts = tf.unstack(past, axis=1) if past is not None else [None] * hparams.n_layer
         assert len(pasts) == hparams.n_layer
+        every = int(math.sqrt(hparams.n_layer))
+        #every = 1
+        #tf.add_to_collection('checkpoints', h)
         for layer, past in enumerate(pasts):
-            h, present = block(h, 'h%d' % layer, past=past, hparams=hparams)
+            h, present = block(h, 'h%d' % layer, past=past, hparams=hparams,
+                attn=attn, batch_size=batch, seq_length=sequence)
             #if layer == 10:
-            #    tf.add_to_collection('checkpoints', h)
-            #presents.append(present)
-        #results['present'] = tf.stack(presents, axis=1)
+            if layer % every == 0:
+                tf.add_to_collection('checkpoints', h)
+            presents.append(present)
+            results['present'] = tf.stack(presents, axis=1)
        h = norm(h, 'ln_f', hparams=hparams)
 
         # Language model loss.  Do tokens
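
Why the 2D residual stream helps, as a minimal standalone sketch (illustrative only, not from the patch; the toy_layer functions and shapes are invented): GPT-2's conv1d flattens [batch, sequence, features] to 2D for its matmul and reshapes back to 3D afterwards, so a 3D residual stream pays two reshapes per projection. XLA on a TPU may materialize such reshapes rather than treating them as free, so the patch flattens once at the top of model() and restores 3D/4D views only inside the attention heads.

    import tensorflow as tf

    def toy_layer_3d(h):
        # h: [batch, seq, features]. The matmul needs a 2D view, so every
        # call costs a reshape down to 2D plus a reshape back up to 3D.
        b, s, f = h.shape.as_list()
        w = tf.ones([f, f])                  # stand-in for a learned weight
        flat = tf.reshape(h, [b * s, f])
        return tf.reshape(tf.matmul(flat, w), [b, s, f])

    def toy_layer_2d(h):
        # h: [batch*seq, features]. Already 2D: no reshapes at all.
        f = h.shape.as_list()[-1]
        w = tf.ones([f, f])
        return tf.matmul(h, w)

    h3 = tf.zeros([8, 1024, 768])            # [batch, seq, features]
    b, s, f = h3.shape.as_list()
    h = tf.reshape(h3, [b * s, f])           # flatten once at the top...
    for _ in range(12):
        h = toy_layer_2d(h)                  # ...run every layer in 2D...
    h3 = tf.reshape(h, [b, s, f])            # ...restore 3D once at the end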
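The mask-constant condition (`w.dtype != tf.float32` becomes `w.dtype == tf.float16`) is a correctness fix, not a cosmetic one: the largest finite float16 is 65504, so casting 1e10 to float16 overflows to inf, and inf * 0 at the unmasked positions (where 1-b is 0) turns the attention weights into NaN. bfloat16, which the old test also caught, keeps float32's exponent range and represents 1e10 fine, so only true float16 needs the 65500 fallback. A quick NumPy check of the two constants:

    import numpy as np

    np.float16(1e10)                        # inf -- overflows float16
    np.float16(65500.0)                     # 65504.0 -- rounds to the float16 max
    np.float16(np.inf) * np.float16(0.0)    # nan -- what 1e10 would do to the mask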
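The `every = int(math.sqrt(hparams.n_layer))` line spaces the 'checkpoints' collection entries roughly sqrt(n_layer) layers apart, the classic O(sqrt(n)) gradient-checkpointing trade-off; presumably a memory-saving-gradients rewrite (e.g. the cybertronai/gradient-checkpointing package, which can read checkpoint tensors from that collection) consumes it during training. Worked numbers for a 48-layer model (the layer count here is just an example):

    import math

    n_layer = 48                       # e.g. GPT-2 1.5B
    every = int(math.sqrt(n_layer))    # int(6.93) == 6
    kept = [l for l in range(n_layer) if l % every == 0]
    print(len(kept), kept)             # 8 kept: [0, 6, 12, 18, 24, 30, 36, 42]
    # ~sqrt(n) stored activations instead of n, in exchange for one extra
    # forward recomputation between checkpoints during backprop.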