#### Tutorial adapted from https://pytorch.org/tutorials/intermediate/char_rnn_generation_tutorial.html

In [6]:
# uncomment this to enable jax gpu preallocation, might lead to memory issues

import os
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

In [7]:
from __future__ import unicode_literals, print_function, division
from io import open
import glob
import os
import unicodedata
import string

all_letters = string.ascii_letters + " .,;'-"
n_letters = len(all_letters) + 1 # Plus EOS marker

def findFiles(path): return glob.glob(path)

# Turn a Unicode string to plain ASCII, thanks to https://stackoverflow.com/a/518232/2809427
def unicodeToAscii(s):
    return ''.join(
        c for c in unicodedata.normalize('NFD', s)
        if unicodedata.category(c) != 'Mn'
        and c in all_letters
    )

# Read a file and split into lines
def readLines(filename):
    with open(filename, encoding='utf-8') as some_file:
        return [unicodeToAscii(line.strip()) for line in some_file]

# Build the category_lines dictionary, a list of lines per category
category_lines = {}
all_categories = []
for filename in findFiles('../names/names/*.txt'):
    category = os.path.splitext(os.path.basename(filename))[0]
    all_categories.append(category)
    lines = readLines(filename)
    category_lines[category] = lines

n_categories = len(all_categories)

if n_categories == 0:
    raise RuntimeError('Data not found. Make sure that you downloaded data '
        'from https://download.pytorch.org/tutorial/data.zip and extract it to '
        'the current directory.')

print('# categories:', n_categories, all_categories)
print(unicodeToAscii("O'Néàl"))

# categories: 18 ['Portuguese', 'Czech', 'Korean', 'Arabic', 'English', 'Russian', 'German', 'Spanish', 'Vietnamese', 'Polish', 'Irish', 'Japanese', 'French', 'Scottish', 'Greek', 'Chinese', 'Italian', 'Dutch']
O'Neal


In [8]:
from typing import Sequence, List, Tuple, Optional

import numpy as np
import jax
import jax.numpy as jnp
import flax.linen as nn

In [117]:
from functools import partial

class TargetRNN(nn.Module):

    output_size: int

    @nn.compact
    def __call__(self, hidden, inp_tuple, train: bool = False):
        category, inp = inp_tuple
        x = jnp.concatenate((category, inp), axis=-1)
        hidden, _ = nn.GRUCell()(hidden, x)
        output = nn.Dense(self.output_size)(hidden)
        output = nn.Dropout(0.1, deterministic = not train)(output)
        output = nn.softmax(output, axis = -1)
        return hidden, (category, output)

Training

In [118]:
import jax
import jax.numpy as jnp

In [119]:
import random

# Random item from a list
def randomChoice(l):
    return l[random.randint(0, len(l) - 1)]

# Get a random category and random line from that category
def randomTrainingPair():
    category = randomChoice(all_categories)
    line = randomChoice(category_lines[category])
    return category, line

# One-hot vector for category
def categoryTensor(category):
    li = all_categories.index(category)
    tensor = np.zeros((1, n_categories))
    tensor[0][li] = 1
    return jnp.array(tensor)

# One-hot matrix of first to last letters (not including EOS) for input
def inputTensor(line):
    tensor = np.zeros((len(line), 1, n_letters))
    for li in range(len(line)):
        letter = line[li]
        tensor[li][0][all_letters.find(letter)] = 1
    return jnp.array(tensor)

# LongTensor of second letter to end (EOS) for target
def targetTensor(line):
    letter_indexes = [all_letters.find(line[li]) for li in range(1, len(line))]
    letter_indexes.append(n_letters - 1) # EOS
    return jnp.array(letter_indexes)

# Make category, input, and target tensors from a random category, line pair
def randomTrainingExample():
    category, line = randomTrainingPair()
    category_tensor = categoryTensor(category)
    input_line_tensor = inputTensor(line)
    target_line_tensor = targetTensor(line)
    return category_tensor, input_line_tensor, target_line_tensor

In [120]:
INPUT_DIMS = n_letters
OUTPUT_SIZE = n_letters
N_CATEGORIES = n_categories
N_HIDDEN = 128

def init_hidden(n_hidden: int, batch_size: int = 1):
    return jnp.zeros((batch_size, n_hidden))

In [121]:
import optax
from flax.training import train_state  # Useful dataclass to keep train state

def create_train_state(rng, net, learning_rate, input_shape):
    """Creates initial `TrainState`."""
    params = net.init(rng, jnp.ones((1,N_CATEGORIES)), jnp.ones((1,INPUT_DIMS)), jnp.ones((1,N_HIDDEN)), False)['params']
    tx = optax.chain(
        optax.clip_by_global_norm(10.0),
        optax.adam(learning_rate),
    )
    return train_state.TrainState.create(
        apply_fn=net.apply, params=params, tx=tx)

In [122]:
category_tensor, input_line_tensor, target_line_tensor = randomTrainingExample()


In [123]:
category_tensor.shape

(1, 18)

In [124]:
input_line_tensor.shape

(3, 1, 59)

In [125]:
target_line_tensor[1]

DeviceArray(7, dtype=int32)

In [144]:
target_line_tensor.shape

(3,)

In [126]:
rnn = TargetRNN(OUTPUT_SIZE)

In [127]:
rng = jax.random.PRNGKey(0)

In [128]:
params = rnn.init(rng, jnp.zeros((1, N_HIDDEN)), (jnp.zeros((1, INPUT_DIMS)),  jnp.zeros((1, N_CATEGORIES))), False)

In [129]:
hidden_state = init_hidden(N_HIDDEN)

In [130]:
np.squeeze(category_tensor)

DeviceArray([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
             1., 0., 0.], dtype=float32)

In [131]:
input_line_tensor.shape

(3, 1, 59)

In [132]:
hidden_state.shape

(1, 128)

In [133]:
repeated_category = repeat(category_tensor, 'b l -> (repeat b) l', repeat=input_line_tensor.shape[0])

In [134]:
repeated_category.shape

(3, 18)

In [135]:
input_line_tensor.shape

(3, 1, 59)

In [136]:
rnn.apply(params, jnp.array(hidden_state), (repeated_category, np.squeeze(input_line_tensor)))[0].shape

(3, 128)

In [137]:
from einops import repeat

In [138]:
category_tensor.shape

(1, 18)

In [142]:
jax.lax.scan(lambda hidden, inp: rnn.apply(params, hidden, inp),
            jnp.array(hidden_state),
            (repeated_category, np.squeeze(input_line_tensor))
)

(DeviceArray([[-0.08747377,  0.02278742, -0.06387921, -0.15278795,
                0.06176033, -0.09011275, -0.054451  , -0.17392935,
                0.12405825,  0.0074779 , -0.04134857, -0.07790678,
               -0.01646499, -0.2850389 ,  0.12619592,  0.06178052,
               -0.09304049, -0.18222329,  0.22363707, -0.00632647,
                0.0179695 ,  0.23133937,  0.0538008 , -0.10557015,
               -0.10366531,  0.06853012,  0.00603865, -0.14605796,
               -0.03334313, -0.01715118,  0.18280727, -0.19975418,
               -0.03592777,  0.0023947 ,  0.00790655,  0.13143978,
               -0.01390804, -0.13061263, -0.02872819,  0.07798855,
               -0.00784443, -0.05432156,  0.10904018,  0.19138736,
                0.10140104,  0.026749  ,  0.12840232, -0.10855062,
               -0.21728131,  0.06692445, -0.01994992, -0.2886571 ,
                0.11453028, -0.0422584 , -0.00528734, -0.00176928,
                0.04072424,  0.07691555, -0.24232778,  0.01171

In [232]:
def cross_entropy_loss(predictions, label):
    return -jnp.log(predictions[label])

In [233]:
import functools
from einops import repeat

@functools.partial(jax.jit, static_argnames=('apply_fn'))
def train_rnn(apply_fn, state, category_tensor, input_line_tensor, hidden, targets):
    category_tensor = repeat(category_tensor, 'b l -> (repeat b) l', repeat=input_line_tensor.shape[0])

    def loss_fn(params):
        _, predictions = jax.lax.scan(lambda hidden, inp: apply_fn({"params":params}, hidden, inp),
                jnp.array(hidden),
                (category_tensor, np.squeeze(input_line_tensor))
        )
        return jnp.mean(jax.vmap(cross_entropy_loss)(jnp.squeeze(predictions[-1]), targets))

    grad_fn = jax.value_and_grad(loss_fn)
    loss, grads = grad_fn(state.params)
    return state.apply_gradients(grads=grads), loss, grads

In [234]:
import os
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime

def get_tensorboard_logger(
    experiment_name: str, base_log_path: str = "tensorboard_logs"
):
    log_path = "{}/{}_{}".format(base_log_path, experiment_name, datetime.now())
    train_writer = SummaryWriter(log_path, flush_secs=10)
    full_log_path = os.path.join(os.getcwd(), log_path)
    print(
        "Follow tensorboard logs with: python -m tensorboard.main --logdir '{}'".format(full_log_path)
    )
    return train_writer


In [235]:
import optax
from flax.training import train_state  # Useful dataclass to keep train state

def create_train_state(rng, model, learning_rate):
    """Creates initial `TrainState`."""
    params = model.init(rng, jnp.zeros((1, N_HIDDEN)), (jnp.zeros((1, INPUT_DIMS)),  jnp.zeros((1, N_CATEGORIES))), False)['params']
    tx = optax.chain(
        optax.clip_by_global_norm(10.0),
        optax.adam(learning_rate),
    )
    return train_state.TrainState.create(
        apply_fn=model.apply, params=params, tx=tx
    )

In [238]:
import pandas as pd
import tqdm

def flatten(d):
    df = pd.json_normalize(d, sep='_')
    return df.to_dict(orient='records')[0]

def train_static_rnn(
    num_epochs,
    model,
    seed: int = 0,
    lr: float = 0.0001,
    ):
    writer = get_tensorboard_logger("RNNCHAR")
    rng = jax.random.PRNGKey(seed)
    state = create_train_state(rng, model, lr)

    bar = tqdm.tqdm(np.arange(num_epochs))

    try:
        for i in bar:
            category_tensor, input_line_tensor, target_line_tensor = randomTrainingExample()
            state, loss, grads = train_rnn(model.apply, state, category_tensor, input_line_tensor, init_hidden(N_HIDDEN), target_line_tensor)
            grad_dict = {k:dict(grads[k]) for k in grads.keys()}
            # grad_dict = flatten(grad_dict)

            # grad_dict = {k: {kk: np.sum(vv).item() for kk, vv in v.items()}
            #             for k, v in grad_dict.items()}
            # grad_dict 

            metrics = {"loss":loss.item(), }
            for key in metrics:
                writer.add_scalar(key, metrics[key], i)

            bar.set_description('Loss: {}'.format(loss.item()))

    except KeyboardInterrupt as e:
        return model, state

In [239]:
train_static_rnn(100000, rnn)

Follow tensorboard logs with: python -m tensorboard.main --logdir '/home/shyam/Code/hyper-nn/notebooks/jax/tensorboard_logs/RNNCHAR_2022-03-25 19:35:04.330497'


Loss: 3.0658493041992188: 100%|██████████| 1000/1000 [00:41<00:00, 23.85it/s]


In [147]:

import torch.nn as nn

criterion = nn.NLLLoss()

def train_static_hyper_rnn_step(static_hyper_rnn, optimizer, category_tensor, input_line_tensor, target_line_tensor):
    target_line_tensor = target_line_tensor.unsqueeze(-1).to(static_hyper_rnn.device)
    hidden = target_network.initHidden().to(static_hyper_rnn.device)

    optimizer.zero_grad()

    loss = 0

    generated_params, embedding_output = static_hyper_rnn.generate_params()

    for i in range(input_line_tensor.size(0)):
        out, _, _ = static_hyper_rnn(inp=(category_tensor.to(static_hyper_rnn.device), input_line_tensor[i].to(static_hyper_rnn.device), hidden), generated_params=generated_params)
        output, hidden = out
        l = criterion(output, target_line_tensor[i])
        loss += l

    loss.backward()
    torch.nn.utils.clip_grad_norm_(static_hyper_rnn.parameters(), 10.0)
    optimizer.step()

    grad_dict = {}
    for n, W in static_hyper_rnn.named_parameters():
        if W.grad is not None:
            grad_dict["{}_grad".format(n)] = float(torch.sum(W.grad).item())

    # for p in rnn.parameters():
    #     p.data.add_(p.grad.data, alpha=-learning_rate)

    return output, {"loss":loss.item() / input_line_tensor.size(0), **grad_dict}

In [148]:
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime
import os
import tqdm

def get_tensorboard_logger(
    experiment_name: str, base_log_path: str = "tensorboard_logs"
):
    log_path = "{}/{}_{}".format(base_log_path, experiment_name, datetime.now())
    train_writer = SummaryWriter(log_path, flush_secs=10)
    full_log_path = os.path.join(os.getcwd(), log_path)
    print(
        "Follow tensorboard logs with: python -m tensorboard.main --logdir '{}'".format(full_log_path)
    )
    return train_writer


In [149]:
from tqdm import tqdm
import numpy as np

def train(hypernet, train_iter_fn, lr, n_iters):
    writer = get_tensorboard_logger("HyperRNN")
    optimizer = torch.optim.Adam(hypernet.parameters(), lr=lr)
    bar = tqdm(np.arange(n_iters))
    for i in bar:
        category_tensor, input_line_tensor, target_line_tensor = randomTrainingExample()
        _, metrics = train_iter_fn(hypernet, optimizer, category_tensor, input_line_tensor, target_line_tensor)

        for key in metrics:
            writer.add_scalar(key, metrics[key], i)

        loss = metrics['loss']
        bar.set_description('Loss: {}'.format(loss))


### Hypernetwork

In [150]:
from hypernn.torch.hypernet import TorchHyperNetwork
from hypernn.torch.weight_generator import TorchWeightGenerator, DefaultTorchWeightGenerator
from hypernn.torch.embedding_module import TorchEmbeddingModule, DefaultTorchEmbeddingModule

In [151]:
target_network = TargetRNN(n_letters, 128, n_letters)
pytorch_total_params = sum(p.numel() for p in target_network.parameters() if p.requires_grad)
pytorch_total_params

87099

In [152]:
EMBEDDING_DIM = 8
NUM_EMBEDDINGS = 64

embedding_module = DefaultTorchEmbeddingModule.from_target(target_network, EMBEDDING_DIM, NUM_EMBEDDINGS)
weight_generator = DefaultTorchWeightGenerator.from_target(target_network, EMBEDDING_DIM, NUM_EMBEDDINGS)

In [153]:
hypernetwork = TorchHyperNetwork(
                                input_shape=((1, n_categories), (1, n_letters), (1, 128)),
                                target_network=target_network,
                                embedding_module=embedding_module,
                                weight_generator=weight_generator
                            )
pytorch_total_params = sum(p.numel() for p in hypernetwork.parameters() if p.requires_grad)
pytorch_total_params

12761

In [154]:
device = torch.device('cuda')
hypernetwork = hypernetwork.to(device)

In [155]:
learning_rate = 1e-4

train(hypernet=hypernetwork, train_iter_fn=train_static_hyper_rnn_step, lr=learning_rate, n_iters=100000)


Follow tensorboard logs with: python -m tensorboard.main --logdir '/home/shyam/Code/hyper-nn/notebooks/torch/tensorboard_logs/HyperRNN_2022-03-19 01:18:27.628167'


Loss: 1.933573341369629:  78%|███████▊  | 77943/100000 [13:19<03:46, 97.53it/s]   


KeyboardInterrupt: 

In [156]:
hypernetwork = hypernetwork.to(torch.device('cpu'))

In [165]:
max_length = 20

# Sample from a category and starting letter
def sample(category, start_letter='A'):
    with torch.no_grad():  # no need to track history in sampling
        category_tensor = categoryTensor(category)
        input = inputTensor(start_letter)
        hidden = target_network.initHidden()

        output_name = start_letter

        for i in range(max_length):
            out, _, _ = hypernetwork(inp=(category_tensor, input[0], hidden))
            output, hidden = out
            topv, topi = output.topk(1)
            topi = topi[0][0]
            if topi == n_letters - 1:
                break
            else:
                letter = all_letters[topi]
                output_name += letter
            input = inputTensor(letter)

        return output_name

# Get multiple samples from one category and multiple starting letters
def samples(category, start_letters='ABC'):
    for start_letter in start_letters:
        print(sample(category, start_letter))

samples('Russian', 'RUS')

samples('German', 'GER')

samples('Spanish', 'SPA')

samples('Chinese', 'CHI')

Rikevev
Uhikov
Sakikov
Gorn
Einter
Roster
Salla
Perez
Aranda
Chi
Hua
Iie


### Dynamic HyperNetwork

In [166]:
from typing import Optional, Any, Tuple
import functools
import torch.nn.functional as F

class DynamicTorchEmbeddingModule(TorchEmbeddingModule):
    def __init__(self, embedding_dim: int, num_embeddings: int, input_shape):
        super().__init__(embedding_dim, num_embeddings)
        self.rnn_hidden_dim = num_embeddings
        self.gru = nn.RNNCell(np.prod(input_shape), num_embeddings)
        self.embedding = nn.Embedding(num_embeddings, embedding_dim)

    def forward(self, inp, hidden_state: Optional[torch.Tensor] = None):
        x = torch.cat(inp[:-1], -1)
        if hidden_state is None:
            hidden_state = torch.zeros(x.size(0), self.rnn_hidden_dim).to(self.device)
        hidden_state = torch.sigmoid(self.gru(x, hidden_state))
        indices = torch.arange(self.num_embeddings).to(self.device)
        embedding = self.embedding(indices)*hidden_state.view(self.num_embeddings, 1)
        return embedding, hidden_state

    def initHidden(self):
        return torch.zeros(1, self.num_embeddings).to(self.device)

class DynamicTorchWeightGenerator(TorchWeightGenerator):
    def __init__(self, embedding_dim: int, num_embeddings: int, hidden_dim: int, input_shape: Optional[Any] = None):
        super().__init__(embedding_dim, num_embeddings, hidden_dim, input_shape)
        self.linear1 = nn.Linear(embedding_dim, 16)
        self.linear2 = nn.Linear(16, hidden_dim)

    def forward(
        self, embedding: Tuple[torch.Tensor, torch.Tensor], inp: Optional[Any] = None
    ) -> torch.Tensor:
        x = self.linear1(embedding[0])
        x = F.relu(x)
        return self.linear2(x).view(-1)

In [168]:
target_network = TargetRNN(n_letters, 128, n_letters)
pytorch_total_params = sum(p.numel() for p in target_network.parameters() if p.requires_grad)
pytorch_total_params

87099

In [169]:
EMBEDDING_DIM = 8
NUM_EMBEDDINGS = 64

dynamic_embedding_module = DynamicTorchEmbeddingModule.from_target(target_network, EMBEDDING_DIM, NUM_EMBEDDINGS, input_shape=(n_categories+n_letters,))
dynamic_weight_generator = DynamicTorchWeightGenerator.from_target(target_network, EMBEDDING_DIM, NUM_EMBEDDINGS, input_shape=(n_categories+n_letters,))

In [171]:
dynamic_hypernetwork = TorchHyperNetwork(
                                input_shape=((1, n_categories), (1, n_letters), (1, 128)),
                                target_network=target_network,
                                embedding_module=dynamic_embedding_module,
                                weight_generator=dynamic_weight_generator
                            )
pytorch_total_params = sum(p.numel() for p in dynamic_hypernetwork.parameters() if p.requires_grad)
pytorch_total_params

32945

In [172]:
device = torch.device('cuda')
dynamic_hypernetwork = dynamic_hypernetwork.to(device)

In [173]:
import torch.nn as nn

criterion = nn.NLLLoss()

def train_dynamic_hyper_rnn_step(dynamic_hyper_rnn, optimizer, category_tensor, input_line_tensor, target_line_tensor):
    target_line_tensor = target_line_tensor.unsqueeze(-1).to(dynamic_hyper_rnn.device)
    hidden = target_network.initHidden().to(dynamic_hyper_rnn.device)
    hyper_hidden = dynamic_hyper_rnn.embedding_module.initHidden()

    optimizer.zero_grad()

    loss = 0

    for i in range(input_line_tensor.size(0)):
        out, _, embedding_output = dynamic_hyper_rnn(inp=(category_tensor.to(dynamic_hyper_rnn.device), input_line_tensor[i].to(dynamic_hyper_rnn.device), hidden), embedding_module_kwargs={"hidden_state":hyper_hidden})
        _, hyper_hidden = embedding_output
        output, hidden = out
        l = criterion(output, target_line_tensor[i])
        loss += l

    loss.backward()
    torch.nn.utils.clip_grad_norm_(dynamic_hyper_rnn.parameters(), 10.0)
    optimizer.step()

    grad_dict = {}
    for n, W in dynamic_hyper_rnn.named_parameters():
        if W.grad is not None:
            grad_dict["{}_grad".format(n)] = float(torch.sum(W.grad).item())

    # for p in rnn.parameters():
    #     p.data.add_(p.grad.data, alpha=-learning_rate)

    return output, {"loss":loss.item() / input_line_tensor.size(0), **grad_dict}

In [174]:
learning_rate = 1e-4

train(hypernet=dynamic_hypernetwork, train_iter_fn=train_dynamic_hyper_rnn_step, lr=learning_rate, n_iters=100000)

Follow tensorboard logs with: python -m tensorboard.main --logdir '/home/shyam/Code/hyper-nn/notebooks/torch/tensorboard_logs/HyperRNN_2022-03-19 01:33:49.362891'


Loss: 2.558781147003174:  35%|███▌      | 35088/100000 [09:54<18:19, 59.06it/s]  


KeyboardInterrupt: 

In [175]:
dynamic_hypernetwork = dynamic_hypernetwork.to(torch.device('cpu'))

In [177]:
max_length = 20

# Sample from a category and starting letter
def sample(category, start_letter='A'):
    with torch.no_grad():  # no need to track history in sampling
        category_tensor = categoryTensor(category)
        input = inputTensor(start_letter)
        hidden = target_network.initHidden()
        hyper_hidden = dynamic_hypernetwork.embedding_module.initHidden()

        output_name = start_letter

        hidden_states = []
        for i in range(max_length):
            out, _, embedding_output = dynamic_hypernetwork(inp=(category_tensor, input[0], hidden), embedding_module_kwargs={"hidden_state":hyper_hidden})
            embedding, hyper_hidden = embedding_output
            output, hidden = out
            topv, topi = output.topk(1)
            topi = topi[0][0]
            if topi == n_letters - 1:
                break
            else:
                letter = all_letters[topi]
                output_name += letter
            input = inputTensor(letter)

        return output_name

# Get multiple samples from one category and multiple starting letters
def samples(category, start_letters='ABC'):
    for start_letter in start_letters:
        print(sample(category, start_letter))

samples('Russian', 'RUS')

samples('German', 'GER')

samples('Spanish', 'SPA')

samples('Chinese', 'CHI')

Rovakhin
Uellov
Shintis
Garen
Eung
Roner
Salera
Parera
Alana
Chan
Han
Ion
