#### Tutorial adapted from https://pytorch.org/tutorials/intermediate/char_rnn_generation_tutorial.html

#### Install requirements

In [None]:
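# !pip install hypernn
# !pip install tqdm
# !pip install tensorboard
# !pip install matplotlib
# !pip install optax
# !pip install einops
# !pip install flax
# !pip install pandas
# !pip install torch  # provides torch.utils.tensorboard.SummaryWriter used for logging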

#### Name dataset

In [None]:
from __future__ import unicode_literals, print_function, division
from io import open
import glob
import os
import unicodedata
import string
from hypernn.jax.utils import count_jax_params

# Alphabet the model reads and writes: ASCII letters plus a little punctuation.
all_letters = "".join((string.ascii_letters, " .,;'-"))
# One extra output slot (the last index) is reserved for the end-of-sequence marker.
n_letters = len(all_letters) + 1

def findFiles(path):
    """Return the paths matching the glob pattern ``path`` (in filesystem order)."""
    matches = glob.glob(path)
    return matches

# Turn a Unicode string to plain ASCII, thanks to https://stackoverflow.com/a/518232/2809427
def unicodeToAscii(s):
    """Strip accents from ``s`` and drop every character not in ``all_letters``.

    The string is NFD-decomposed so accented letters split into a base letter
    plus combining marks (category 'Mn'); the marks are discarded.
    """
    kept = []
    for ch in unicodedata.normalize('NFD', s):
        if unicodedata.category(ch) == 'Mn':
            continue  # combining accent left over from the decomposition
        if ch not in all_letters:
            continue  # outside the model's alphabet
        kept.append(ch)
    return ''.join(kept)

# Read a file and split into lines
def readLines(filename):
    """Read ``filename`` as UTF-8 and return its lines, stripped and ASCII-normalised."""
    normalised = []
    with open(filename, encoding='utf-8') as handle:
        for raw_line in handle:
            normalised.append(unicodeToAscii(raw_line.strip()))
    return normalised

# Build the category_lines dictionary, a list of lines per category
category_lines = {}
all_categories = []
# glob returns files in filesystem order. A category's position in
# all_categories is its one-hot index (categoryTensor uses
# all_categories.index), so sort to make that mapping reproducible.
for filename in sorted(findFiles('./names/names/*.txt')):
    category = os.path.splitext(os.path.basename(filename))[0]
    all_categories.append(category)
    # Drop names that are empty after ASCII normalisation (blank lines, or names
    # made only of characters outside all_letters). An empty name gives a
    # zero-length input sequence, and the training step cannot stack zero
    # predictions.
    lines = [line for line in readLines(filename) if line]
    category_lines[category] = lines

n_categories = len(all_categories)

if n_categories == 0:
    raise RuntimeError('Data not found. Make sure that you downloaded data '
        'from https://download.pytorch.org/tutorial/data.zip and extract it to '
        'the current directory.')

print('# categories:', n_categories, all_categories)
print(unicodeToAscii("O'Néàl"))

In [None]:
from typing import Sequence, List, Tuple, Optional

import numpy as np
import jax
import jax.numpy as jnp
import flax.linen as nn
from functools import partial

class TargetRNN(nn.Module):
    """Character-level target network: one GRU step followed by an MLP head.

    ``category`` and ``inp`` are concatenated along the last axis and fed, with
    ``hidden``, to a GRU cell. The new hidden state then goes through
    Dense(256) -> ReLU -> Dense(256) -> ReLU -> Dense(output_size) -> softmax.

    An instance is passed to ``CharDynamicHyperNetwork.from_target``, which
    presumably generates these parameters rather than learning them directly.
    TODO confirm against hypernn.

    NOTE(review): flax auto-names inline submodules by type and creation order
    (GRUCell_0, Dense_0, ...). Reordering the layers below would change the
    parameter names/layout the hypernetwork is built from.
    """

    # Size of the output distribution; the notebook passes n_letters (letters + EOS slot).
    output_size: int

    @nn.compact
    def __call__(self, category, inp, hidden):
        """Run a single recurrent step.

        Args:
            category: category one-hot; callers in this notebook pass shape (1, n_categories).
            inp: current-letter one-hot; callers pass shape (1, n_letters).
            hidden: GRU carry; callers pass shape (1, N_HIDDEN).

        Returns:
            ``(output, hidden)``: softmax probabilities over ``output_size``
            classes, and the updated GRU carry.
        """
        x = jnp.concatenate((category, inp), axis=-1)
        # NOTE(review): GRUCell() without `features` matches the older flax API,
        # where the size is taken from the carry. Newer releases require
        # nn.GRUCell(features=...). TODO confirm the flax version in use.
        hidden, _ = nn.GRUCell()(hidden, x)
        output = nn.Dense(256)(hidden)
        output = nn.relu(output)
        output = nn.Dense(256)(output)
        output = nn.relu(output)
        output = nn.Dense(self.output_size)(output)
        # Probabilities, not logits: the loss in this notebook takes -log(p[label]) directly.
        output = nn.softmax(output, axis = -1)
        return output, hidden

# Model dimensions: the RNN reads and emits one-hot letters (including the EOS slot).
INPUT_DIMS = OUTPUT_SIZE = n_letters
N_CATEGORIES = n_categories
N_HIDDEN = 128  # width of the GRU hidden state

def init_hidden(n_hidden: int, batch_size: int = 1):
    """Return an all-zero initial hidden state of shape ``(batch_size, n_hidden)``."""
    shape = (batch_size, n_hidden)
    return jnp.zeros(shape)

In [None]:
target_network = TargetRNN(n_letters)
# Zero dummy inputs (batch of 1), in TargetRNN.__call__ order: category, letter, hidden state.
target_inputs = [
    jnp.zeros((1, n_categories)),
    jnp.zeros((1, n_letters)),
    jnp.zeros((1, N_HIDDEN)),
]
count_jax_params(target_network, inputs=target_inputs)

#### Dynamic hypernetwork that generates the target RNN's weights from the network inputs at every step

In [None]:
from typing import Optional, Iterable, Dict, Any, Tuple
from hypernn.jax.dynamic_hypernet import JaxDynamicHyperNetwork

class CharDynamicHyperNetwork(JaxDynamicHyperNetwork):
    """Dynamic hypernetwork for the character RNN.

    Target-network parameters are generated from the concatenated category and
    current-letter tensors. The embedding module's hidden state is threaded
    through successive calls.
    """

    def generate_params(
        self, category_tensor, input_tensor, hidden_state = None
    ):
        """Generate flattened target parameters for one step.

        Args:
            category_tensor: category one-hot (callers pass (1, n_categories)).
            input_tensor: current-letter one-hot (callers pass (1, n_letters)).
            hidden_state: recurrent state returned by the previous call, or None.

        Returns:
            ``(params, aux)``: a 1-D array of generated parameters, and a dict with
            the ``"embedding"`` and the updated ``"hidden_state"``.
        """
        hyper_input = jnp.concatenate((category_tensor, input_tensor), axis=-1)
        embedding, next_hidden = self.embedding_module(hyper_input, hidden_state=hidden_state)
        flat_params = self.weight_generator(embedding).reshape(-1)
        aux = {"embedding": embedding, "hidden_state": next_hidden}
        return flat_params, aux


In [None]:
EMBEDDING_DIM = 4
NUM_EMBEDDINGS = 64


def _dummy_rnn_inputs():
    """Fresh zero inputs (batch of 1) for TargetRNN: category, letter, hidden state."""
    return [
        jnp.zeros((1, N_CATEGORIES)),
        jnp.zeros((1, INPUT_DIMS)),
        jnp.zeros((1, N_HIDDEN)),
    ]


hypernetwork = CharDynamicHyperNetwork.from_target(
    target_network,
    embedding_dim=EMBEDDING_DIM,
    num_embeddings=NUM_EMBEDDINGS,
    input_dim=n_letters + n_categories,  # the hypernetwork sees category + letter
    inputs=_dummy_rnn_inputs(),
)

count_jax_params(
    hypernetwork,
    inputs=_dummy_rnn_inputs(),
    generate_params_kwargs={
        "category_tensor": jnp.zeros((1, N_CATEGORIES)),
        "input_tensor": jnp.zeros((1, INPUT_DIMS)),
    },
)

#### Training procedure

In [None]:
import random

# Random item from a list
def randomChoice(l):
    """Return a uniformly random element of the non-empty sequence ``l``.

    Uses the standard ``random.choice`` instead of hand-indexing with
    ``randint``. An empty sequence raises ``IndexError``.
    """
    return random.choice(l)

# Get a random category and random line from that category
def randomTrainingPair():
    """Pick a random category, then a random name from it; return ``(category, name)``."""
    chosen_category = randomChoice(all_categories)
    return chosen_category, randomChoice(category_lines[chosen_category])

# One-hot vector for category
def categoryTensor(category):
    """Return a (1, n_categories) one-hot row whose hot index is ``category``'s position in all_categories."""
    one_hot = np.zeros((1, n_categories))
    one_hot[0, all_categories.index(category)] = 1
    return jnp.array(one_hot)

# One-hot matrix of first to last letters (not including EOS) for input
def inputTensor(line):
    """One-hot encode ``line`` as an array of shape ``(len(line), 1, n_letters)``.

    Raises:
        ValueError: if ``line`` contains a character outside ``all_letters``.
            Without this check ``str.find`` returns -1, which would silently
            set the last slot (the EOS marker).
    """
    tensor = np.zeros((len(line), 1, n_letters))
    for position, letter in enumerate(line):
        letter_index = all_letters.find(letter)
        if letter_index < 0:
            raise ValueError(
                "character {!r} in {!r} is not in all_letters".format(letter, line)
            )
        tensor[position, 0, letter_index] = 1
    return jnp.array(tensor)

# Integer array of letter indices from the second letter to the end, plus EOS, as targets
def targetTensor(line):
    """Return the target indices for ``line``: its letters from the second onward, then EOS.

    EOS is encoded as ``n_letters - 1``.

    Raises:
        ValueError: if a target letter is not in ``all_letters``. Without this
            check ``str.find`` returns -1, which would silently index the EOS
            class.
    """
    letter_indexes = []
    for letter in line[1:]:
        letter_index = all_letters.find(letter)
        if letter_index < 0:
            raise ValueError(
                "character {!r} in {!r} is not in all_letters".format(letter, line)
            )
        letter_indexes.append(letter_index)
    letter_indexes.append(n_letters - 1)  # EOS
    return jnp.array(letter_indexes)

# Make category, input, and target tensors from a random category, line pair
def randomTrainingExample():
    """Return ``(category_tensor, input_line_tensor, target_line_tensor)`` for a random name."""
    category, line = randomTrainingPair()
    return (
        categoryTensor(category),
        inputTensor(line),
        targetTensor(line),
    )

import os
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime

def get_tensorboard_logger(
    experiment_name: str, base_log_path: str = "tensorboard_logs"
):
    """Create a TensorBoard ``SummaryWriter`` in a fresh, timestamped run directory.

    The run directory is ``<base_log_path>/<experiment_name>_<timestamp>``. Its
    absolute path is printed together with the command for following the logs.

    Args:
        experiment_name: prefix of the run directory name.
        base_log_path: parent directory for all runs.

    Returns:
        A ``SummaryWriter`` created with ``flush_secs=10``. The caller is
        responsible for closing it.
    """
    # str(datetime.now()) contains a space and ':' characters. ':' is invalid in
    # Windows file names, so the timestamp is formatted explicitly. Microseconds
    # are kept so back-to-back runs still get distinct directories.
    timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S-%f")
    log_path = os.path.join(base_log_path, "{}_{}".format(experiment_name, timestamp))
    train_writer = SummaryWriter(log_path, flush_secs=10)
    full_log_path = os.path.abspath(log_path)
    print(
        "Follow tensorboard logs with: python -m tensorboard.main --logdir '{}'".format(full_log_path)
    )
    return train_writer


In [None]:
# Same model dimensions as declared earlier in the notebook, repeated in this cell.
N_HIDDEN = 128
N_CATEGORIES = n_categories
OUTPUT_SIZE = n_letters
INPUT_DIMS = n_letters

def init_hidden(n_hidden: int, batch_size: int = 1):
    """Zero-initialised hidden state with shape ``(batch_size, n_hidden)``."""
    return jnp.zeros(shape=(batch_size, n_hidden))

In [None]:
import functools
from einops import repeat
import pandas as pd
import tqdm

import flax
import optax
from flax.training import train_state  # Useful dataclass to keep train state

def flatten(d):
    """Flatten a nested dict into one level, joining keys with '_' (via pandas).

    NOTE: this definition is shadowed by the recursive ``flatten(d, parent_key,
    sep)`` defined later in the same cell. That later one is what runs at call
    time.
    """
    records = pd.json_normalize(d, sep='_').to_dict(orient='records')
    return records[0]

def cross_entropy_loss(predictions, label, eps: float = 1e-12):
    """Negative log-likelihood of ``label`` under a probability vector.

    Args:
        predictions: 1-D array of class probabilities (here, a softmax output).
        label: integer index of the target class.
        eps: small constant added inside the log. A probability that underflowed
            to 0 then gives a large but finite loss instead of ``inf``, which
            would otherwise poison the gradients with NaN.

    Returns:
        Scalar ``-log(predictions[label] + eps)``.
    """
    return -jnp.log(predictions[label] + eps)

def create_hyper_train_state(rng, model, learning_rate):
    """Initialise ``model`` with zero dummy inputs and wrap its params in a ``TrainState``.

    The optimizer clips the global gradient norm to 10.0, then applies Adam
    with ``learning_rate``.
    """
    dummy_category = jnp.zeros((1, N_CATEGORIES))
    dummy_input = jnp.zeros((1, INPUT_DIMS))
    dummy_hidden = jnp.zeros((1, N_HIDDEN))
    variables = model.init(
        rng,
        dummy_category,
        dummy_input,
        dummy_hidden,
        generate_params_kwargs={
            "category_tensor": dummy_category,
            "input_tensor": dummy_input,
        },
    )
    optimizer = optax.chain(
        optax.clip_by_global_norm(10.0),
        optax.adam(learning_rate),
    )
    return train_state.TrainState.create(
        apply_fn=model.apply, params=variables['params'], tx=optimizer
    )

@functools.partial(jax.jit, static_argnames=('apply_fn', 'generate_params_fn'))
def train_dynamic_hyper_rnn(apply_fn, generate_params_fn, state, category_tensor, input_line_tensor, hidden, targets):
    """Run one optimisation step of the dynamic hypernetwork on a single name.

    The RNN is unrolled over every letter. At each step ``apply_fn`` is called
    with ``generate_params_kwargs`` and ``has_aux=True``, so the target weights
    are regenerated from the current letter. Both the target hidden state and
    the hypernetwork's own hidden state (``aux_output["hidden_state"]``) are
    threaded through time. The loss is the mean per-letter cross-entropy.

    The unroll is a Python loop over a static length, so jit traces and
    compiles once per distinct name length.

    Args:
        apply_fn: the hypernetwork's ``apply`` (static under jit).
        generate_params_fn: not used in the body; kept for interface
            compatibility (static under jit).
        state: ``TrainState`` holding params and optimizer state.
        category_tensor: category one-hot, shape (1, n_categories) as built by categoryTensor.
        input_line_tensor: letters one-hot, shape (seq_len, 1, n_letters) as built by inputTensor.
        hidden: initial target-RNN hidden state, e.g. ``init_hidden(N_HIDDEN)``.
        targets: (seq_len,) integer letter indices ending in EOS, as built by targetTensor.

    Returns:
        ``(new_state, loss, grads)``.
    """

    def loss_fn(params):
        # (seq_len, 1, n_letters) -> (seq_len, n_letters). An explicit reshape is
        # used instead of np.squeeze: squeeze would also drop the time axis for a
        # one-letter name, leaving a (n_letters,) vector.
        inp = input_line_tensor.reshape(input_line_tensor.shape[0], input_line_tensor.shape[-1])
        length = inp.shape[0]
        hidden_state = jnp.array(hidden)
        hyper_hidden = None  # hypernetwork recurrent state; None on the first step
        predictions = []
        for i in range(length):
            step_input = inp[i:i + 1]  # (1, n_letters): keep the batch axis
            out, _, aux_output = apply_fn(
                {"params": params},
                category_tensor,
                step_input,
                hidden_state,
                generate_params_kwargs=dict(
                    category_tensor=category_tensor,
                    input_tensor=step_input,
                    hidden_state=hyper_hidden,
                ),
                has_aux=True,
            )
            hyper_hidden = aux_output["hidden_state"]
            output, hidden_state = out
            predictions.append(output.reshape(-1))  # (1, n_letters) -> (n_letters,)
        # (seq_len, n_letters); no squeeze, so a length-1 name keeps its time axis for vmap.
        predictions = jnp.stack(predictions)
        return jnp.mean(jax.vmap(cross_entropy_loss)(predictions, targets))

    grad_fn = jax.value_and_grad(loss_fn)
    loss, grads = grad_fn(state.params)
    return state.apply_gradients(grads=grads), loss, grads

import collections

def flatten(d, parent_key='', sep='_'):
    """Recursively flatten nested dicts / flax ``FrozenDict``s into one flat dict.

    Nested keys are joined with ``sep``, e.g. ``{"a": {"b": x}}`` -> ``{"a_b": x}``.
    Non-mapping values are kept as leaves. On a duplicate joined key, the later
    entry wins.
    """
    flat = {}
    for key, value in d.items():
        full_key = parent_key + sep + key if parent_key else key
        if isinstance(value, (flax.core.frozen_dict.FrozenDict, dict)):
            flat.update(flatten(value, full_key, sep=sep))
        else:
            flat[full_key] = value
    return flat

def train_hyper_rnn(
    num_epochs,
    model,
    seed: int = 0,
    lr: float = 0.0001,
    ):
    """Train ``model`` on ``num_epochs`` randomly sampled names, logging to TensorBoard.

    For each step the loss and the per-parameter sum of gradients are written as
    scalars. A ``KeyboardInterrupt`` stops training early and returns the latest
    state. The TensorBoard writer is closed on every exit path, so pending
    events are flushed and the file handle is released.

    Returns:
        ``(model, state)``: the model and its most recent ``TrainState``.
    """
    writer = get_tensorboard_logger("JaxHyperRNN")
    rng = jax.random.PRNGKey(seed)
    state = create_hyper_train_state(rng, model, lr)

    bar = tqdm.tqdm(range(num_epochs))

    try:
        for i in bar:
            category_tensor, input_line_tensor, target_line_tensor = randomTrainingExample()
            state, loss, grads = train_dynamic_hyper_rnn(
                model.apply,
                model.generate_params,
                state,
                category_tensor,
                input_line_tensor,
                init_hidden(N_HIDDEN),
                target_line_tensor,
            )
            # flatten() already recurses into dict/FrozenDict, so grads can be passed directly.
            grad_sums = {k: np.sum(np.array(v)) for k, v in flatten(grads).items()}

            loss_value = loss.item()
            metrics = {"loss": loss_value, **grad_sums}
            for key, value in metrics.items():
                writer.add_scalar(key, value, i)

            bar.set_description('Loss: {}'.format(loss_value))
    except KeyboardInterrupt:
        # Deliberate: stopping the cell keeps the parameters trained so far.
        pass
    finally:
        writer.close()
        bar.close()
    return model, state

In [None]:
# Train for 50k randomly sampled names; `state` holds the resulting parameters.
hypernetwork, state = train_hyper_rnn(num_epochs=50000, model=hypernetwork, lr=0.0001)

In [None]:
max_length = 20  # maximum number of generation steps after the start letter

# Sample from a category and starting letter
def sample(category, start_letter='A', model=None, params=None):
    """Greedily generate a name for ``category`` beginning with ``start_letter``.

    At each step the hypernetwork regenerates the target weights from the
    current letter and the most probable next letter (argmax) is appended.
    Generation stops at the EOS index (``n_letters - 1``) or after
    ``max_length`` steps. Only the first character of ``start_letter`` is fed
    to the network, although the whole string prefixes the result.

    Args:
        category: a name from ``all_categories``.
        start_letter: prefix of the generated name.
        model: hypernetwork to sample from; defaults to the module-level
            ``hypernetwork`` (looked up at call time).
        params: parameters for ``model``; defaults to the module-level
            ``state.params``.

    Returns:
        The generated name.
    """
    if model is None:
        model = hypernetwork
    if params is None:
        params = state.params

    category_tensor = categoryTensor(category)
    letter_tensor = inputTensor(start_letter)
    hidden = init_hidden(N_HIDDEN)
    hyper_hidden = None  # hypernetwork recurrent state, threaded between steps
    output_name = start_letter

    for _step in range(max_length):
        step_input = letter_tensor[0]  # (1, n_letters): first letter with its batch axis
        out, _, aux_output = model.apply(
            {"params": params},
            category_tensor,
            step_input,
            hidden,
            generate_params_kwargs=dict(
                category_tensor=category_tensor,
                input_tensor=step_input,
                hidden_state=hyper_hidden,
            ),
            has_aux=True,
        )
        hyper_hidden = aux_output["hidden_state"]
        output, hidden = out
        next_index = int(np.argmax(np.asarray(output).reshape(-1)))
        if next_index == n_letters - 1:  # EOS
            break
        letter = all_letters[next_index]
        output_name += letter
        letter_tensor = inputTensor(letter)

    return output_name

# Get multiple samples from one category and multiple starting letters
def samples(category, start_letters='ABC'):
    """Print one generated name for ``category`` per character of ``start_letters``."""
    generated = (sample(category, letter) for letter in start_letters)
    for name in generated:
        print(name)

# Print three generated names for each category, one per starting letter.
for demo_category, demo_letters in (
    ('Russian', 'RUS'),
    ('German', 'GER'),
    ('Spanish', 'SPA'),
    ('Chinese', 'CHI'),
):
    samples(demo_category, demo_letters)