Minimize 3D reshapes for TPUs
shawwn committed Dec 18, 2019
1 parent 5b96108 commit 4d766e9
Showing 1 changed file with 39 additions and 16 deletions.
55 changes: 39 additions & 16 deletions src/model.py
@@ -2,6 +2,7 @@
import tensorflow as tf
from tensorflow.contrib.training import HParams
import os
import math

def default_hparams():
return HParams(
@@ -67,7 +68,8 @@ def norm(x, scope, *, axis=-1, epsilon=1e-5, hparams=None):

def split_states(x, n):
"""Reshape the last dimension of x into [n, x.shape[-1]/n]."""
*start, m = shape_list(x)
*start, u, v = shape_list(x)
m = u * v
return tf.reshape(x, start + [n, m//n])

def merge_states(x):
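
Note on the split_states change above: a minimal, runnable sketch (not part of the diff) of how the reworked helper treats shapes, reusing the shape_list helper that already exists in model.py. The example sizes are arbitrary.

import tensorflow as tf

def shape_list(x):
    """Deal with dynamic shape in tensorflow cleanly (helper already in model.py)."""
    static = x.shape.as_list()
    dynamic = tf.shape(x)
    return [dynamic[i] if s is None else s for i, s in enumerate(static)]

def split_states(x, n):
    """Reshape the last two dimensions of x into [n, (u*v)/n]."""
    *start, u, v = shape_list(x)
    m = u * v
    return tf.reshape(x, start + [n, m // n])

# Example shapes only: with a [batch, seq, heads, size_per_head] input and
# n == heads, the target shape equals the input shape, so the reshape moves no data.
x = tf.zeros([8, 1024, 12, 64])
y = split_states(x, 12)   # -> [8, 1024, 12, 64]
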
@@ -109,26 +111,36 @@ def attention_mask(nd, ns, *, dtype):
return tf.cast(m, dtype)


def attn(x, scope, n_state, *, past, hparams):
assert x.shape.ndims == 3 # Should be [batch, sequence, features]
def attn(x, scope, n_state, *, past, hparams, batch_size, seq_length):
assert x.shape.ndims == 2 # Should be [batch*sequence, features]
assert n_state % hparams.n_head == 0
*start, hidden_size = shape_list(x)
num_attention_heads = hparams.n_head
assert(hidden_size % num_attention_heads == 0)
size_per_head = hidden_size // num_attention_heads

if past is not None:
assert past.shape.ndims == 5 # Should be [batch, 2, heads, sequence, features], where 2 is [k, v]

def split_heads(x):
# From [batch, sequence, features] to [batch, heads, sequence, features]
return tf.transpose(split_states(x, hparams.n_head), [0, 2, 1, 3])
x = tf.reshape(x, [batch_size, seq_length, num_attention_heads, size_per_head])
x = split_states(x, hparams.n_head)
return tf.transpose(x, [0, 2, 1, 3])

def merge_heads(x):
# Reverse of split_heads
return merge_states(tf.transpose(x, [0, 2, 1, 3]))
x = tf.transpose(x, [0, 2, 1, 3])
x = merge_states(x)
x = tf.reshape(x, [batch_size * seq_length, num_attention_heads * size_per_head])
return x

def mask_attn_weights(w):
# w has shape [batch, heads, dst_sequence, src_sequence], where information flows from src to dst.
_, _, nd, ns = shape_list(w)
b = attention_mask(nd, ns, dtype=w.dtype)
b = tf.reshape(b, [1, 1, nd, ns])
w = w*b - tf.cast(65500 if w.dtype != tf.float32 else 1e10, w.dtype)*(1-b)
w = w*b - tf.cast(65500 if w.dtype == tf.float16 else 1e10, w.dtype)*(1-b)
return w

def multihead_attn(q, k, v):
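
Note on the attention hunk above: activations now stay [batch*sequence, features] and only go 4D inside split_heads/merge_heads (the split_states call there is shape-preserving at this call site). The sketch below is not part of the diff; the sizes are assumed examples, and it mirrors the new split_heads/merge_heads with the split_states call omitted. Separately, the masking constant now keys off tf.float16 specifically, since 65500 sits just under float16's largest finite value (about 65504), while other dtypes fall through to 1e10.

import tensorflow as tf

# Assumed example sizes; the diff passes batch_size and seq_length into attn() explicitly.
batch_size, seq_length, n_head, size_per_head = 8, 1024, 12, 64
hidden_size = n_head * size_per_head

def split_heads(x):
    # [batch*seq, hidden] -> [batch, heads, seq, size_per_head]
    x = tf.reshape(x, [batch_size, seq_length, n_head, size_per_head])
    return tf.transpose(x, [0, 2, 1, 3])

def merge_heads(x):
    # [batch, heads, seq, size_per_head] -> [batch*seq, hidden]
    x = tf.transpose(x, [0, 2, 1, 3])
    return tf.reshape(x, [batch_size * seq_length, n_head * size_per_head])

x = tf.zeros([batch_size * seq_length, hidden_size])
assert merge_heads(split_heads(x)).shape.as_list() == [batch_size * seq_length, hidden_size]
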
@@ -145,9 +157,8 @@ def multihead_attn(q, k, v):
dtype = hparams.dtype if hparams else tf.float32
with tf.variable_scope(scope, dtype=dtype):
c = conv1d(x, 'c_attn', n_state*3, hparams=hparams)
q, k, v = map(split_heads, tf.split(c, 3, axis=2))
#present = tf.stack([k, v], axis=1)
present = tf.no_op()
q, k, v = map(split_heads, tf.split(c, 3, axis=-1))
present = tf.stack([k, v], axis=1)
if past is not None:
pk, pv = tf.unstack(past, axis=1)
k = tf.concat([pk, k], axis=-2)
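
Note on the hunk above: with the 2D layout the conv1d output c is [batch*sequence, 3*n_state], so q, k, v are split off the last axis rather than axis=2, and the k/v cache is stacked again so that `past` keeps the 5-D layout asserted earlier in attn. A small shape sketch with assumed example sizes (not part of the diff):

import tensorflow as tf

batch, seq, n_head, per_head = 8, 1024, 12, 64
n_state = n_head * per_head
c = tf.zeros([batch * seq, 3 * n_state])        # conv1d output in the 2D layout
q, k, v = tf.split(c, 3, axis=-1)               # each [batch*seq, n_state]
k4 = tf.zeros([batch, n_head, seq, per_head])   # k after split_heads
v4 = tf.zeros([batch, n_head, seq, per_head])
present = tf.stack([k4, v4], axis=1)            # [batch, 2, heads, seq, per_head],
                                                # matching the past.shape.ndims == 5 assert
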
@@ -173,11 +184,11 @@ def dropout(x, pdrop=0.1, train=True):
x = tf.nn.dropout(x, rate=pdrop)
return x

def block(x, scope, *, past, hparams):
def block(x, scope, *, past, hparams, attn, **attn_kws):
dtype = hparams.dtype if hparams else tf.float32
with tf.variable_scope(scope, dtype=dtype):
nx = x.shape[-1].value
a, present = attn(norm(x, 'ln_1', hparams=hparams), 'attn', nx, past=past, hparams=hparams)
a, present = attn(norm(x, 'ln_1', hparams=hparams), 'attn', nx, past=past, hparams=hparams, **attn_kws)
x = x + a
m = mlp(norm(x, 'ln_2', hparams=hparams), 'mlp', nx*4, hparams=hparams)
x = x + m
@@ -211,16 +222,28 @@ def model(hparams, X, past=None, scope='model', reuse=tf.AUTO_REUSE):
past_length = 0 if past is None else tf.shape(past)[-2]
h = tf.gather(wte, X) + tf.gather(wpe, positions_for(X, past_length))

## We keep the representation as a 2D tensor to avoid re-shaping it back and
## forth from a 3D tensor to a 2D tensor. Re-shapes are normally free on
## the GPU/CPU but may not be free on the TPU, so we want to minimize them to
## help the optimizer.
batch_size, seq_length, hidden_size = shape_list(h)
h = tf.reshape(h, [batch_size * seq_length, hidden_size])

# Transformer
#presents = []
presents = []
pasts = tf.unstack(past, axis=1) if past is not None else [None] * hparams.n_layer
assert len(pasts) == hparams.n_layer
every = int(math.sqrt(hparams.n_layer))
#every = 1
#tf.add_to_collection('checkpoints', h)
for layer, past in enumerate(pasts):
h, present = block(h, 'h%d' % layer, past=past, hparams=hparams)
h, present = block(h, 'h%d' % layer, past=past, hparams=hparams,
attn=attn, batch_size=batch, seq_length=sequence)
#if layer == 10:
# tf.add_to_collection('checkpoints', h)
#presents.append(present)
#results['present'] = tf.stack(presents, axis=1)
if layer % every == 0:
tf.add_to_collection('checkpoints', h)
presents.append(present)
results['present'] = tf.stack(presents, axis=1)
h = norm(h, 'ln_f', hparams=hparams)

# Language model loss. Do tokens <n predict token n?
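Note on the model() hunk: the embedding sum is flattened to [batch*sequence, hidden] once and kept 2D through every block, per the new comment, and the presents stack is re-enabled for the k/v cache. The 'checkpoints' collection is presumably the hook consumed by a gradient-checkpointing utility such as openai/gradient-checkpointing; that wiring is not shown in this diff. Marking roughly every sqrt(n_layer)-th block trades extra recomputation in the backward pass for keeping fewer activations live. A small sketch (not part of the diff) of which blocks get marked, for an assumed 48-layer model:

import math

n_layer = 48
every = int(math.sqrt(n_layer))                        # 6
checkpointed = [i for i in range(n_layer) if i % every == 0]
print(checkpointed)                                    # [0, 6, 12, 18, 24, 30, 36, 42]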
