In [1]:
env XLA_PYTHON_CLIENT_MEM_FRACTION=0.25

env: XLA_PYTHON_CLIENT_MEM_FRACTION=0.25


In [2]:
import renn

In [15]:
import tensorflow_datasets as tfds
import jax
import jax.numpy as jnp

from jax.experimental import stax, optimizers

import numpy as np

from renn.losses import multiclass_xent

from tqdm import tqdm
from functools import partial
from renn import serialize

In [4]:
vocab_file='./data/vocab/ag_news.vocab'

seq_length = 160
num_classes = 3

def filter_fn(item):
    return item['labels'] < num_classes

# Load data
train_dset = renn.data.ag_news('train', vocab_file, sequence_length=seq_length, data_dir='./data', filter_fn=filter_fn)
test_dset = renn.data.ag_news('test', vocab_file, sequence_length=seq_length, data_dir='./data', filter_fn=filter_fn)

# Load vocab
with open(vocab_file, 'r') as f:
    vocab = f.readlines()
vocab_size = len(vocab)

example = next(iter(train_dset))

Instructions for updating:
`tf.batch_gather` is deprecated, please use `tf.gather` with `batch_dims=-1` instead.


Instructions for updating:
`tf.batch_gather` is deprecated, please use `tf.gather` with `batch_dims=-1` instead.


In [5]:
def SequenceSum():
    def init_fun(_, input_shape):
        return (input_shape[0], input_shape[2]), ()
    def apply_fun(_, inputs, **kwargs):
        return jnp.sum(inputs, axis=1)
    return init_fun, apply_fun

In [6]:
emb_size = 32

input_shape = (-1, seq_length)
l2_pen = 0

# Linear model
init_fun, apply_fun = stax.serial(
    renn.embedding(vocab_size, emb_size),
    SequenceSum(),
    stax.Dense(num_classes),
    )

# Initialize
key = jax.random.PRNGKey(0)
output_shape, initial_params = init_fun(key, input_shape)

# Hack to set the embedding to 0 for 0
emb = initial_params[0]
new_emb = np.array(emb)
new_emb[0] = np.zeros(emb_size)
initial_params = [jnp.array(new_emb), *initial_params[1:]]

# Loss
def xent(params, batch):
    logits = apply_fun(params, batch['inputs'])
    data_loss = multiclass_xent(logits, batch['labels'])
    reg_loss = l2_pen * renn.norm(params)
    return data_loss + reg_loss

f_df = jax.value_and_grad(xent)

# Accuracy
@jax.jit
def accuracy(params, batch):
    logits = apply_fun(params, batch['inputs'])
    predictions = jnp.argmax(logits, axis=1)
    return jnp.mean(predictions == batch['labels'])



In [7]:
learning_rate = optimizers.exponential_decay(2e-3, 1000, 0.8)
init_opt, update_opt, get_params = optimizers.adam(learning_rate)

state = init_opt(initial_params)
losses = []

@jax.jit
def step(k, opt_state, batch):
    params = get_params(opt_state)
    loss, gradients = f_df(params, batch)
    new_state = update_opt(k, gradients, opt_state)
    return new_state, loss

def test_acc(params):
    return jnp.array([accuracy(params, batch) for batch in tfds.as_numpy(test_dset)])

In [8]:
for epoch in range(1):
    print('=====================================')
    print(f'== Epoch #{epoch}')
    p = get_params(state)
    acc = np.mean(test_acc(p))
    print(f'== Test accuracy: {100. * acc:0.2f}%')
    print('=====================================')
    
    for batch in tfds.as_numpy(train_dset):
        k = len(losses)
        state, loss = step(k, state, batch)
        losses.append(loss)

        if k % 100 == 0:
            p = get_params(state)
            print(f'[step {k}]\tLoss: {np.mean(losses[k-100:k]):0.4f}', flush=True)
            
print('=====================================')
print(f'== Epoch #{epoch}')
p = get_params(state)
acc = np.mean(test_acc(p))
print(f'== Test accuracy: {100. * acc:0.2f}%')
print('=====================================')

== Epoch #0
== Test accuracy: 32.57%
[step 0]	Loss: nan


  out=out, **kwargs)
  ret = ret.dtype.type(ret / rcount)


[step 100]	Loss: 0.4913
[step 200]	Loss: 0.2279
[step 300]	Loss: 0.1965
[step 400]	Loss: 0.2071
[step 500]	Loss: 0.1950
[step 600]	Loss: 0.1865
[step 700]	Loss: 0.1836
[step 800]	Loss: 0.1751
[step 900]	Loss: 0.1717
[step 1000]	Loss: 0.1686
[step 1100]	Loss: 0.1725
[step 1200]	Loss: 0.1511
[step 1300]	Loss: 0.1658
[step 1400]	Loss: 0.1638
== Epoch #1
== Test accuracy: 94.79%
[step 1500]	Loss: 0.1131
[step 1600]	Loss: 0.0870
[step 1700]	Loss: 0.0777
[step 1800]	Loss: 0.0867
[step 1900]	Loss: 0.0832
[step 2000]	Loss: 0.0772
[step 2100]	Loss: 0.0848
[step 2200]	Loss: 0.0755
[step 2300]	Loss: 0.0857
[step 2400]	Loss: 0.0765
[step 2500]	Loss: 0.0867
[step 2600]	Loss: 0.0710
[step 2700]	Loss: 0.0775
[step 2800]	Loss: 0.0890
== Epoch #2
== Test accuracy: 94.17%
[step 2900]	Loss: 0.0791
[step 3000]	Loss: 0.0526
[step 3100]	Loss: 0.0499
[step 3200]	Loss: 0.0475
[step 3300]	Loss: 0.0472
[step 3400]	Loss: 0.0429
[step 3500]	Loss: 0.0571
[step 3600]	Loss: 0.0433
[step 3700]	Loss: 0.0461
[step 3800

In [9]:
params = get_params(state)

In [10]:
def delta_logit(vocab_index, params):
    test_batch = jnp.zeros((2, 160), dtype = jnp.int64)
    test_batch = jax.ops.index_update(test_batch, jax.ops.index[1,0], vocab_index)
    
    logits = apply_fun(params, test_batch)
    return logits[1] - logits[0]

In [14]:
delta_logit(151, params)

DeviceArray([-0.46819264, -1.2719698 ,  1.8476232 ], dtype=float32)

In [21]:
with open('test','wb') as f:
    serialize.dump(np.array(params), f)

ValueError: Object and structured dtypes are not supported.