#### Tutorial adapted from https://pytorch.org/tutorials/intermediate/char_rnn_generation_tutorial.html

In [1]:
# uncomment this to enable jax gpu preallocation, might lead to memory issues

import os
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

In [2]:
from __future__ import unicode_literals, print_function, division
from io import open
import glob
import os
import unicodedata
import string

all_letters = string.ascii_letters + " .,;'-"
n_letters = len(all_letters) + 1 # Plus EOS marker

def findFiles(path): return glob.glob(path)

# Turn a Unicode string to plain ASCII, thanks to https://stackoverflow.com/a/518232/2809427
def unicodeToAscii(s):
    return ''.join(
        c for c in unicodedata.normalize('NFD', s)
        if unicodedata.category(c) != 'Mn'
        and c in all_letters
    )

# Read a file and split into lines
def readLines(filename):
    with open(filename, encoding='utf-8') as some_file:
        return [unicodeToAscii(line.strip()) for line in some_file]

# Build the category_lines dictionary, a list of lines per category
category_lines = {}
all_categories = []
for filename in findFiles('../names/names/*.txt'):
    category = os.path.splitext(os.path.basename(filename))[0]
    all_categories.append(category)
    lines = readLines(filename)
    category_lines[category] = lines

n_categories = len(all_categories)

if n_categories == 0:
    raise RuntimeError('Data not found. Make sure that you downloaded data '
        'from https://download.pytorch.org/tutorial/data.zip and extract it to '
        'the current directory.')

print('# categories:', n_categories, all_categories)
print(unicodeToAscii("O'Néàl"))

# categories: 18 ['Portuguese', 'Czech', 'Korean', 'Arabic', 'English', 'Russian', 'German', 'Spanish', 'Vietnamese', 'Polish', 'Irish', 'Japanese', 'French', 'Scottish', 'Greek', 'Chinese', 'Italian', 'Dutch']
O'Neal


In [3]:
from typing import Sequence, List, Tuple, Optional

import numpy as np
import jax
import jax.numpy as jnp
import flax.linen as nn

In [4]:
from functools import partial

class TargetRNN(nn.Module):

    output_size: int

    @nn.compact
    def __call__(self, hidden, inp_tuple, train: bool = False):
        category, inp = inp_tuple
        x = jnp.concatenate((category, inp), axis=-1)
        hidden, _ = nn.GRUCell()(hidden, x)
        output = nn.Dense(256)(hidden)
        output = nn.relu(output)
        output = nn.Dense(256)(output)
        output = nn.relu(output)
        output = nn.Dense(self.output_size)(output)
        output = nn.Dropout(0.1, deterministic = not train)(output)
        output = nn.softmax(output, axis = -1)
        return hidden, (category, output)

Training

In [5]:
import jax
import jax.numpy as jnp

In [6]:
import random

# Random item from a list
def randomChoice(l):
    return l[random.randint(0, len(l) - 1)]

# Get a random category and random line from that category
def randomTrainingPair():
    category = randomChoice(all_categories)
    line = randomChoice(category_lines[category])
    return category, line

# One-hot vector for category
def categoryTensor(category):
    li = all_categories.index(category)
    tensor = np.zeros((1, n_categories))
    tensor[0][li] = 1
    return jnp.array(tensor)

# One-hot matrix of first to last letters (not including EOS) for input
def inputTensor(line):
    tensor = np.zeros((len(line), 1, n_letters))
    for li in range(len(line)):
        letter = line[li]
        tensor[li][0][all_letters.find(letter)] = 1
    return jnp.array(tensor)

# LongTensor of second letter to end (EOS) for target
def targetTensor(line):
    letter_indexes = [all_letters.find(line[li]) for li in range(1, len(line))]
    letter_indexes.append(n_letters - 1) # EOS
    return jnp.array(letter_indexes)

# Make category, input, and target tensors from a random category, line pair
def randomTrainingExample():
    category, line = randomTrainingPair()
    category_tensor = categoryTensor(category)
    input_line_tensor = inputTensor(line)
    target_line_tensor = targetTensor(line)
    return category_tensor, input_line_tensor, target_line_tensor

In [7]:
INPUT_DIMS = n_letters
OUTPUT_SIZE = n_letters
N_CATEGORIES = n_categories
N_HIDDEN = 128

def init_hidden(n_hidden: int, batch_size: int = 1):
    return jnp.zeros((batch_size, n_hidden))

In [8]:
import optax
from flax.training import train_state  # Useful dataclass to keep train state

def create_train_state(rng, net, learning_rate, input_shape):
    """Creates initial `TrainState`."""
    params = net.init(rng, jnp.ones((1,N_CATEGORIES)), jnp.ones((1,INPUT_DIMS)), jnp.ones((1,N_HIDDEN)), False)['params']
    tx = optax.chain(
        optax.clip_by_global_norm(10.0),
        optax.adam(learning_rate),
    )
    return train_state.TrainState.create(
        apply_fn=net.apply, params=params, tx=tx)

In [9]:
from einops import repeat

In [10]:
def cross_entropy_loss(predictions, label):
    return -jnp.log(predictions[label])

In [11]:
import functools
from einops import repeat

@functools.partial(jax.jit, static_argnames=('apply_fn'))
def rnn_forward(apply_fn, params, category_tensor, input_line_tensor, hidden):
    category_tensor = repeat(category_tensor, 'b l -> (repeat b) l', repeat=input_line_tensor.shape[0])
    _, predictions = jax.lax.scan(lambda hidden, inp: apply_fn({"params":params}, hidden, inp),
            jnp.array(hidden),
            (category_tensor, np.squeeze(input_line_tensor)),
    )
    return predictions[-1]

@functools.partial(jax.jit, static_argnames=('apply_fn'))
def train_rnn(apply_fn, state, category_tensor, input_line_tensor, hidden, targets):
    # category_tensor = repeat(category_tensor, 'b l -> (repeat b) l', repeat=input_line_tensor.shape[0])

    def loss_fn(params):
        predictions = []
        inp = np.squeeze(input_line_tensor)
        length = np.squeeze(input_line_tensor).shape[0]
        hidden_state = jnp.array(hidden)
        for i in range(length):
            hidden_state, output = apply_fn({"params":params}, hidden_state, (category_tensor, inp[i:i+1]))
            predictions.append(np.squeeze(output[1]))
        predictions = jnp.stack(predictions)

        # _, predictions = jax.lax.scan(lambda hidden, inp: apply_fn({"params":params}, hidden, inp),
        #         jnp.array(hidden),
        #         (category_tensor, np.squeeze(input_line_tensor))
        # )
        return jnp.mean(jax.vmap(cross_entropy_loss)(jnp.squeeze(predictions), targets))

    grad_fn = jax.value_and_grad(loss_fn)
    loss, grads = grad_fn(state.params)
    return state.apply_gradients(grads=grads), loss, grads

In [12]:
import os
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime

def get_tensorboard_logger(
    experiment_name: str, base_log_path: str = "tensorboard_logs"
):
    log_path = "{}/{}_{}".format(base_log_path, experiment_name, datetime.now())
    train_writer = SummaryWriter(log_path, flush_secs=10)
    full_log_path = os.path.join(os.getcwd(), log_path)
    print(
        "Follow tensorboard logs with: python -m tensorboard.main --logdir '{}'".format(full_log_path)
    )
    return train_writer


In [13]:
import optax
from flax.training import train_state  # Useful dataclass to keep train state

def create_train_state(rng, model, learning_rate):
    """Creates initial `TrainState`."""
    params = model.init(rng, jnp.zeros((1, N_HIDDEN)), (jnp.zeros((1, INPUT_DIMS)),  jnp.zeros((1, N_CATEGORIES))), False)['params']
    tx = optax.chain(
        optax.clip_by_global_norm(10.0),
        optax.adam(learning_rate),
    )
    return train_state.TrainState.create(
        apply_fn=model.apply, params=params, tx=tx
    )

In [14]:
import pandas as pd
import tqdm

def flatten(d):
    df = pd.json_normalize(d, sep='_')
    return df.to_dict(orient='records')[0]

def train_static_rnn(
    num_epochs,
    model,
    seed: int = 0,
    lr: float = 0.0001,
    ):
    writer = get_tensorboard_logger("RNNCHAR")
    rng = jax.random.PRNGKey(seed)
    state = create_train_state(rng, model, lr)

    bar = tqdm.tqdm(np.arange(num_epochs))

    try:
        for i in bar:
            category_tensor, input_line_tensor, target_line_tensor = randomTrainingExample()
            state, loss, grads = train_rnn(model.apply, state, category_tensor, input_line_tensor, init_hidden(N_HIDDEN), target_line_tensor)
            grad_dict = {k:dict(grads[k]) for k in grads.keys()}
            # grad_dict = flatten(grad_dict)

            # grad_dict = {k: {kk: np.sum(vv).item() for kk, vv in v.items()}
            #             for k, v in grad_dict.items()}
            # grad_dict 

            metrics = {"loss":loss.item(), }
            for key in metrics:
                writer.add_scalar(key, metrics[key], i)

            bar.set_description('Loss: {}'.format(loss.item()))
        return model, state
    except KeyboardInterrupt as e:
        return model, state

In [15]:
rnn = TargetRNN(OUTPUT_SIZE)

In [None]:
from hypernn.jax.utils import count_jax_params

count_jax_params(rnn, inputs=[jnp.zeros((1, N_HIDDEN)), (jnp.zeros((1, INPUT_DIMS)),  jnp.zeros((1, N_CATEGORIES))), False])

In [None]:
model, state = train_static_rnn(100000, rnn)

In [None]:
max_length = 20

# Sample from a category and starting letter
def sample(category, start_letter='A'):
    category_tensor = categoryTensor(category)
    input = inputTensor(start_letter)
    hidden = init_hidden(N_HIDDEN)

    output_name = start_letter

    for i in range(max_length):
        hidden, out = rnn.apply({"params":state.params}, hidden, (category_tensor, input[0]))
        category, output = out
        # topv, topi = output.topk(1)
        topi = np.argmax(np.squeeze(output), axis=-1)
        # topi = topi[0][0]
        if topi == n_letters - 1:
            break
        else:
            letter = all_letters[topi]
            output_name += letter
        input = inputTensor(letter)

    return output_name

# Get multiple samples from one category and multiple starting letters
def samples(category, start_letters='ABC'):
    for start_letter in start_letters:
        print(sample(category, start_letter))

samples('Russian', 'RUS')

samples('German', 'GER')

samples('Spanish', 'SPA')

samples('Chinese', 'CHI')

### Hypernetwork

In [None]:
from typing import Optional, Any
import jax.numpy as jnp

from hypernn.jax.embedding_module import FlaxEmbeddingModule
from hypernn.jax.weight_generator import FlaxWeightGenerator
from hypernn.jax.hypernet import FlaxHyperNetwork


class CustomFlaxEmbeddingModule(FlaxEmbeddingModule):
    def setup(self):
        self.embedding = nn.Embed(self.num_embeddings, self.embedding_dim)

    def __call__(self, inp: Optional[Any] = None):
        indices = jnp.arange(0, self.num_embeddings)
        return self.embedding(indices)

class CustomFlaxWeightGenerator(FlaxWeightGenerator):
    def setup(self):
        self.dense1 = nn.Dense(32)
        self.dense2 = nn.Dense(self.hidden_dim)

    def __call__(self, embedding: jnp.array, inp: Optional[Any] = None):
        x = self.dense1(embedding)
        x = nn.relu(x)
        x = self.dense2(x)
        return x


In [None]:
hyper = FlaxHyperNetwork.from_target(
    target_network=rnn,
    inputs=[jnp.zeros((1, N_HIDDEN)), (jnp.zeros((1, INPUT_DIMS)),  jnp.zeros((1, N_CATEGORIES)))],
    embedding_module=CustomFlaxEmbeddingModule,
    weight_generator=CustomFlaxWeightGenerator,
    embedding_dim = 32,
    num_embeddings = 512
)

In [None]:
rng = jax.random.PRNGKey(0)
params = hyper.init(rng, inp=[jnp.zeros((1, N_HIDDEN)), (jnp.zeros((1, INPUT_DIMS)),  jnp.zeros((1, N_CATEGORIES)))])

In [None]:
hyper_generated_params = hyper.apply(params, method=hyper.generate_params)

In [None]:
generated_params, embeddings = hyper_generated_params

In [None]:
generated_params.shape

In [None]:
out, generated_params, embeddings = hyper.apply(params, inp=[jnp.zeros((1, N_HIDDEN)), (jnp.zeros((1, INPUT_DIMS)),  jnp.zeros((1, N_CATEGORIES)))], generated_params=generated_params)

#### Static Hypernetwork

In [None]:
import functools
from einops import repeat
import pandas as pd
import tqdm

import optax
from flax.training import train_state  # Useful dataclass to keep train state

def create_hyper_train_state(rng, model, learning_rate):
    """Creates initial `TrainState`."""
    params = model.init(rng, [jnp.zeros((1, N_HIDDEN)), (jnp.zeros((1, INPUT_DIMS)),  jnp.zeros((1, N_CATEGORIES))), False])['params']
    tx = optax.chain(
        optax.clip_by_global_norm(10.0),
        optax.adam(learning_rate),
    )
    return train_state.TrainState.create(
        apply_fn=model.apply, params=params, tx=tx
    )

@functools.partial(jax.jit, static_argnames=('apply_fn'))
def hyper_rnn_forward(apply_fn, params, category_tensor, input_line_tensor, hidden):
    category_tensor = repeat(category_tensor, 'b l -> (repeat b) l', repeat=input_line_tensor.shape[0])
    _, predictions = jax.lax.scan(lambda hidden, inp: apply_fn({"params":params}, hidden, inp),
            jnp.array(hidden),
            (category_tensor, np.squeeze(input_line_tensor)),
    )
    return predictions[-1]

@functools.partial(jax.jit, static_argnames=('apply_fn'))
def train_hyper_rnn(apply_fn, state, category_tensor, input_line_tensor, hidden, targets):
    # category_tensor = repeat(category_tensor, 'b l -> (repeat b) l', repeat=input_line_tensor.shape[0])

    def loss_fn(params):
        predictions = []
        inp = np.squeeze(input_line_tensor)
        length = np.squeeze(input_line_tensor).shape[0]
        hidden_state = jnp.array(hidden)
        for i in range(length):
            out, _, _ = apply_fn({"params":params}, [hidden_state, (category_tensor, inp[i:i+1])])
            hidden_state, output = out
            predictions.append(np.squeeze(output[1]))
        predictions = jnp.stack(predictions)

        # _, predictions = jax.lax.scan(lambda hidden, inp: apply_fn({"params":params}, hidden, inp),
        #         jnp.array(hidden),
        #         (category_tensor, np.squeeze(input_line_tensor))
        # )
        return jnp.mean(jax.vmap(cross_entropy_loss)(jnp.squeeze(predictions), targets))

    grad_fn = jax.value_and_grad(loss_fn)
    loss, grads = grad_fn(state.params)
    return astate.apply_gradients(grads=grads), loss, grads

def flatten(d):
    df = pd.json_normalize(d, sep='_')
    return df.to_dict(orient='records')[0]

def train_static_hyper_rnn(
    num_epochs,
    model,
    seed: int = 0,
    lr: float = 0.0001,
    ):
    writer = get_tensorboard_logger("HYPER_RNNCHAR")
    rng = jax.random.PRNGKey(seed)
    state = create_hyper_train_state(rng, model, lr)

    bar = tqdm.tqdm(np.arange(num_epochs))

    try:
        for i in bar:
            category_tensor, input_line_tensor, target_line_tensor = randomTrainingExample()
            state, loss, grads = train_hyper_rnn(model.apply, state, category_tensor, input_line_tensor, init_hidden(N_HIDDEN), target_line_tensor)
            grad_dict = {k:dict(grads[k]) for k in grads.keys()}
            # grad_dict = flatten(grad_dict)

            # grad_dict = {k: {kk: np.sum(vv).item() for kk, vv in v.items()}
            #             for k, v in grad_dict.items()}
            # grad_dict 

            metrics = {"loss":loss.item(), }
            for key in metrics:
                writer.add_scalar(key, metrics[key], i)

            bar.set_description('Loss: {}'.format(loss.item()))
        return model, state
    except KeyboardInterrupt as e:
        return model, state

In [None]:
_, hyper_state = train_static_hyper_rnn(100000, hyper)

In [None]:
max_length = 20

# Sample from a category and starting letter
def sample(category, start_letter='A'):
    category_tensor = categoryTensor(category)
    input = inputTensor(start_letter)
    hidden = init_hidden(N_HIDDEN)

    output_name = start_letter

    for i in range(max_length):
        hyper_out, _, _ = hyper.apply({"params":hyper_state.params}, inp=[hidden, (category_tensor, input[0])])
        hidden, out = hyper_out
        category, output = out
        # topv, topi = output.topk(1)
        topi = np.argmax(np.squeeze(output), axis=-1)
        # topi = topi[0][0]
        if topi == n_letters - 1:
            break
        else:
            letter = all_letters[topi]
            output_name += letter
        input = inputTensor(letter)

    return output_name

# Get multiple samples from one category and multiple starting letters
def samples(category, start_letters='ABC'):
    for start_letter in start_letters:
        print(sample(category, start_letter))

samples('Russian', 'RUS')

samples('German', 'GER')

samples('Spanish', 'SPA')

samples('Chinese', 'CHI')

### Dynamic HyperNetwork

In [16]:
from hypernn.jax.utils import count_jax_params


In [24]:
class RNNCell(nn.Module):

    output_size: int


    @nn.compact
    def __call__(self, state, x):
        # Wh @ h + Wx @ x + b can be efficiently computed
        # by concatenating the vectors and then having a single dense layer
        x = jnp.concatenate([state, x], axis=-1)
        new_state = jnp.tanh(nn.Dense(self.output_size)(x))
        return new_state

In [25]:
from typing import Optional, Any
import jax.numpy as jnp

from hypernn.jax.embedding_module import FlaxEmbeddingModule
from hypernn.jax.weight_generator import FlaxWeightGenerator
from hypernn.jax.hypernet import FlaxHyperNetwork


class DynamicFlaxEmbeddingModule(FlaxEmbeddingModule):
    def setup(self):
        self.embedding = nn.Embed(self.num_embeddings, self.embedding_dim)
        self.rnn = RNNCell(self.num_embeddings)

    def __call__(self, inp: Optional[Any] = None, hidden: Optional[np.array] = None):
        inp = inp[1]
        if hidden is None:
            hidden = jnp.zeros((1, self.num_embeddings))
        category, input_tensor = inp
        x = jnp.concatenate((category, input_tensor), axis=-1)
        hidden = self.rnn(hidden, x)
        indices = jnp.arange(0, self.num_embeddings)
        return self.embedding(indices)*hidden.reshape(self.num_embeddings, 1), hidden

class DynamicFlaxWeightGenerator(FlaxWeightGenerator):
    def setup(self):
        self.dense1 = nn.Dense(32)
        self.dense2 = nn.Dense(self.hidden_dim)

    def __call__(self, embedding: jnp.array, inp: Optional[Any] = None):
        x = self.dense1(embedding[0])
        x = nn.relu(x)
        x = self.dense2(x)
        return x


In [26]:
import functools
from einops import repeat
import pandas as pd
import tqdm

import optax
from flax.training import train_state  # Useful dataclass to keep train state

def create_hyper_train_state(rng, model, learning_rate):
    """Creates initial `TrainState`."""
    params = model.init(rng, [jnp.zeros((1, N_HIDDEN)), (jnp.zeros((1, INPUT_DIMS)),  jnp.zeros((1, N_CATEGORIES))), False])['params']
    tx = optax.chain(
        optax.clip_by_global_norm(10.0),
        optax.adam(learning_rate),
    )
    return train_state.TrainState.create(
        apply_fn=model.apply, params=params, tx=tx
    )

@functools.partial(jax.jit, static_argnames=('apply_fn'))
def train_hyper_rnn(apply_fn, state, category_tensor, input_line_tensor, hidden, targets):
    # category_tensor = repeat(category_tensor, 'b l -> (repeat b) l', repeat=input_line_tensor.shape[0])

    def loss_fn(params):
        predictions = []
        inp = np.squeeze(input_line_tensor)
        length = np.squeeze(input_line_tensor).shape[0]
        hidden_state = jnp.array(hidden)
        hyper_hidden = None
        for i in range(length):
            out, _, embedding_output = apply_fn({"params":params}, [hidden_state, (category_tensor, inp[i:i+1])], embedding_module_kwargs={"hidden":hyper_hidden})
            hyper_hidden = embedding_output[1]
            hidden_state, output = out
            predictions.append(np.squeeze(output[1]))
        predictions = jnp.stack(predictions)

        # _, predictions = jax.lax.scan(lambda hidden, inp: apply_fn({"params":params}, hidden, inp),
        #         jnp.array(hidden),
        #         (category_tensor, np.squeeze(input_line_tensor))
        # )
        return jnp.mean(jax.vmap(cross_entropy_loss)(jnp.squeeze(predictions), targets))

    grad_fn = jax.value_and_grad(loss_fn)
    loss, grads = grad_fn(state.params)
    return state.apply_gradients(grads=grads), loss, grads

def flatten(d):
    df = pd.json_normalize(d, sep='_')
    return df.to_dict(orient='records')[0]

def train_static_hyper_rnn(
    num_epochs,
    model,
    seed: int = 0,
    lr: float = 0.0001,
    ):
    writer = get_tensorboard_logger("HYPER_RNNCHAR")
    rng = jax.random.PRNGKey(seed)
    state = create_hyper_train_state(rng, model, lr)

    bar = tqdm.tqdm(np.arange(num_epochs))

    try:
        for i in bar:
            category_tensor, input_line_tensor, target_line_tensor = randomTrainingExample()
            state, loss, grads = train_hyper_rnn(model.apply, state, category_tensor, input_line_tensor, init_hidden(N_HIDDEN), target_line_tensor)
            grad_dict = {k:dict(grads[k]) for k in grads.keys()}
            # grad_dict = flatten(grad_dict)

            # grad_dict = {k: {kk: np.sum(vv).item() for kk, vv in v.items()}
            #             for k, v in grad_dict.items()}
            # grad_dict 

            metrics = {"loss":loss.item(), }
            for key in metrics:
                writer.add_scalar(key, metrics[key], i)

            bar.set_description('Loss: {}'.format(loss.item()))
        return model, state
    except KeyboardInterrupt as e:
        return model, state

In [27]:
hyper = FlaxHyperNetwork.from_target(
    target_network=rnn,
    inputs=[jnp.zeros((1, N_HIDDEN)), (jnp.zeros((1, INPUT_DIMS)),  jnp.zeros((1, N_CATEGORIES)))],
    embedding_module=DynamicFlaxEmbeddingModule,
    weight_generator=DynamicFlaxWeightGenerator,
    embedding_dim = 8,
    num_embeddings = 128
)

In [28]:
count_jax_params(hyper, inputs=[[jnp.zeros((1, N_HIDDEN)), (jnp.zeros((1, INPUT_DIMS)),  jnp.zeros((1, N_CATEGORIES))), False]])

77853

In [29]:
model, state = train_static_hyper_rnn(100000, hyper)

Follow tensorboard logs with: python -m tensorboard.main --logdir '/home/shyam/Code/hyper-nn/notebooks/jax/tensorboard_logs/HYPER_RNNCHAR_2022-03-27 00:38:25.475319'


Loss: 1.424128532409668: 100%|██████████| 100000/100000 [16:56<00:00, 98.41it/s]   


In [None]:
max_length = 20

# Sample from a category and starting letter
def sample(category, start_letter='A'):
    category_tensor = categoryTensor(category)
    input = inputTensor(start_letter)
    hidden = init_hidden(N_HIDDEN)

    output_name = start_letter

    hyper_hidden = None
    for i in range(max_length):
        hyper_out, _, embedding_output = hyper.apply({"params":state.params}, inp=[hidden, (category_tensor, input[0])], embedding_module_kwargs={"hidden":hyper_hidden})
        hyper_hidden = embedding_output[1]
        hidden, out = hyper_out
        category, output = out
        # topv, topi = output.topk(1)
        topi = np.argmax(np.squeeze(output), axis=-1)
        # topi = topi[0][0]
        if topi == n_letters - 1:
            break
        else:
            letter = all_letters[topi]
            output_name += letter
        input = inputTensor(letter)

    return output_name

# Get multiple samples from one category and multiple starting letters
def samples(category, start_letters='ABC'):
    for start_letter in start_letters:
        print(sample(category, start_letter))

samples('Russian', 'RUS')

samples('German', 'GER')

samples('Spanish', 'SPA')

samples('Chinese', 'CHI')