In [1]:
import renn

In [37]:
import tensorflow_datasets as tfds
import jax
import jax.numpy as jnp

from jax.experimental import stax, optimizers

import numpy as np

from renn.losses import multiclass_xent

from tqdm import tqdm
from functools import partial

In [65]:
vocab_file='./data/vocab/ag_news.vocab'

seq_length = 160
num_classes = 2

def filter_fn(item):
    return item['labels'] < num_classes

# Load data
train_dset = renn.data.ag_news('train', vocab_file, sequence_length=seq_length, data_dir='./data', filter_fn=filter_fn)
test_dset = renn.data.ag_news('test', vocab_file, sequence_length=seq_length, data_dir='./data', filter_fn=filter_fn)

# Load vocab
with open(vocab_file, 'r') as f:
    vocab = f.readlines()
vocab_size = len(vocab)

example = next(iter(train_dset))

In [66]:
def SequenceSum():
    def init_fun(_, input_shape):
        return (input_shape[0], input_shape[2]), ()
    def apply_fun(_, inputs, **kwargs):
        return jnp.sum(inputs, axis=1)
    return init_fun, apply_fun

In [74]:
emb_size = 32

input_shape = (-1, seq_length)
l2_pen = 0

# Linear model
init_fun, apply_fun = stax.serial(
    renn.embedding(vocab_size, emb_size),
    SequenceSum(),
    stax.Dense(num_classes),
    )

# Initialize
key = jax.random.PRNGKey(0)
output_shape, initial_params = init_fun(key, input_shape)

# Hack to set the embedding to 0 for 0
emb = initial_params[0]
new_emb = np.array(emb)
new_emb[0] = np.zeros(emb_size)
initial_params = [jnp.array(new_emb), *initial_params[1:]]

# Loss
def xent(params, batch):
    logits = apply_fun(params, batch['inputs'])
    data_loss = multiclass_xent(logits, batch['labels'])
    reg_loss = l2_pen * renn.norm(params)
    return data_loss + reg_loss

f_df = jax.value_and_grad(xent)

# Accuracy
@jax.jit
def accuracy(params, batch):
    logits = apply_fun(params, batch['inputs'])
    predictions = jnp.argmax(logits, axis=1)
    return jnp.mean(predictions == batch['labels'])



In [75]:
learning_rate = optimizers.exponential_decay(2e-3, 1000, 0.8)
init_opt, update_opt, get_params = optimizers.adam(learning_rate)

state = init_opt(initial_params)
losses = []

@jax.jit
def step(k, opt_state, batch):
    params = get_params(opt_state)
    loss, gradients = f_df(params, batch)
    new_state = update_opt(k, gradients, opt_state)
    return new_state, loss

def test_acc(params):
    return jnp.array([accuracy(params, batch) for batch in tfds.as_numpy(test_dset)])

In [76]:
for epoch in range(3):
    print('=====================================')
    print(f'== Epoch #{epoch}')
    p = get_params(state)
    acc = np.mean(test_acc(p))
    print(f'== Test accuracy: {100. * acc:0.2f}%')
    print('=====================================')
    
    for batch in tfds.as_numpy(train_dset):
        k = len(losses)
        state, loss = step(k, state, batch)
        losses.append(loss)

        if k % 100 == 0:
            p = get_params(state)
            print(f'[step {k}]\tLoss: {np.mean(losses[k-100:k]):0.4f}', flush=True)
            
print('=====================================')
print(f'== Epoch #{epoch}')
p = get_params(state)
acc = np.mean(test_acc(p))
print(f'== Test accuracy: {100. * acc:0.2f}%')
print('=====================================')

== Epoch #0
== Test accuracy: 53.80%
[step 0]	Loss: nan
[step 100]	Loss: 0.2527
[step 200]	Loss: 0.1102
[step 300]	Loss: 0.1061
[step 400]	Loss: 0.0946
[step 500]	Loss: 0.1012
[step 600]	Loss: 0.0858
[step 700]	Loss: 0.0788
[step 800]	Loss: 0.0933
[step 900]	Loss: 0.0817
== Epoch #1
== Test accuracy: 97.32%
[step 1000]	Loss: 0.0596
[step 1100]	Loss: 0.0344
[step 1200]	Loss: 0.0308
[step 1300]	Loss: 0.0412
[step 1400]	Loss: 0.0392
[step 1500]	Loss: 0.0437
[step 1600]	Loss: 0.0299
[step 1700]	Loss: 0.0300
[step 1800]	Loss: 0.0295
== Epoch #2
== Test accuracy: 97.22%
[step 1900]	Loss: 0.0351
[step 2000]	Loss: 0.0232
[step 2100]	Loss: 0.0149
[step 2200]	Loss: 0.0188
[step 2300]	Loss: 0.0168
[step 2400]	Loss: 0.0236
[step 2500]	Loss: 0.0239
[step 2600]	Loss: 0.0169
[step 2700]	Loss: 0.0112
[step 2800]	Loss: 0.0149
== Epoch #2
== Test accuracy: 96.69%


In [77]:
params = get_params(state)