#### Tutorial adapted from https://pytorch.org/tutorials/intermediate/char_rnn_generation_tutorial.html

In [44]:
# uncomment this to enable jax gpu preallocation, might lead to memory issues

import os
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

In [45]:
from __future__ import unicode_literals, print_function, division
from io import open
import glob
import os
import unicodedata
import string

all_letters = string.ascii_letters + " .,;'-"
n_letters = len(all_letters) + 1 # Plus EOS marker

def findFiles(path): return glob.glob(path)

# Turn a Unicode string to plain ASCII, thanks to https://stackoverflow.com/a/518232/2809427
def unicodeToAscii(s):
    return ''.join(
        c for c in unicodedata.normalize('NFD', s)
        if unicodedata.category(c) != 'Mn'
        and c in all_letters
    )

# Read a file and split into lines
def readLines(filename):
    with open(filename, encoding='utf-8') as some_file:
        return [unicodeToAscii(line.strip()) for line in some_file]

# Build the category_lines dictionary, a list of lines per category
category_lines = {}
all_categories = []
for filename in findFiles('../names/names/*.txt'):
    category = os.path.splitext(os.path.basename(filename))[0]
    all_categories.append(category)
    lines = readLines(filename)
    category_lines[category] = lines

n_categories = len(all_categories)

if n_categories == 0:
    raise RuntimeError('Data not found. Make sure that you downloaded data '
        'from https://download.pytorch.org/tutorial/data.zip and extract it to '
        'the current directory.')

print('# categories:', n_categories, all_categories)
print(unicodeToAscii("O'Néàl"))

# categories: 18 ['Portuguese', 'Czech', 'Korean', 'Arabic', 'English', 'Russian', 'German', 'Spanish', 'Vietnamese', 'Polish', 'Irish', 'Japanese', 'French', 'Scottish', 'Greek', 'Chinese', 'Italian', 'Dutch']
O'Neal


In [46]:
from typing import Sequence, List, Tuple, Optional

import numpy as np
import jax
import jax.numpy as jnp
import flax.linen as nn

In [47]:
from functools import partial

class TargetRNN(nn.Module):

    output_size: int

    @nn.compact
    def __call__(self, hidden, inp_tuple, train: bool = False):
        category, inp = inp_tuple
        x = jnp.concatenate((category, inp), axis=-1)
        hidden, _ = nn.GRUCell()(hidden, x)
        output = nn.Dense(self.output_size)(hidden)
        output = nn.Dropout(0.1, deterministic = not train)(output)
        output = nn.softmax(output, axis = -1)
        return hidden, (category, output)

Training

In [48]:
import jax
import jax.numpy as jnp

In [49]:
import random

# Random item from a list
def randomChoice(l):
    return l[random.randint(0, len(l) - 1)]

# Get a random category and random line from that category
def randomTrainingPair():
    category = randomChoice(all_categories)
    line = randomChoice(category_lines[category])
    return category, line

# One-hot vector for category
def categoryTensor(category):
    li = all_categories.index(category)
    tensor = np.zeros((1, n_categories))
    tensor[0][li] = 1
    return jnp.array(tensor)

# One-hot matrix of first to last letters (not including EOS) for input
def inputTensor(line):
    tensor = np.zeros((len(line), 1, n_letters))
    for li in range(len(line)):
        letter = line[li]
        tensor[li][0][all_letters.find(letter)] = 1
    return jnp.array(tensor)

# LongTensor of second letter to end (EOS) for target
def targetTensor(line):
    letter_indexes = [all_letters.find(line[li]) for li in range(1, len(line))]
    letter_indexes.append(n_letters - 1) # EOS
    return jnp.array(letter_indexes)

# Make category, input, and target tensors from a random category, line pair
def randomTrainingExample():
    category, line = randomTrainingPair()
    category_tensor = categoryTensor(category)
    input_line_tensor = inputTensor(line)
    target_line_tensor = targetTensor(line)
    return category_tensor, input_line_tensor, target_line_tensor

In [50]:
INPUT_DIMS = n_letters
OUTPUT_SIZE = n_letters
N_CATEGORIES = n_categories
N_HIDDEN = 128

def init_hidden(n_hidden: int, batch_size: int = 1):
    return jnp.zeros((batch_size, n_hidden))

In [51]:
import optax
from flax.training import train_state  # Useful dataclass to keep train state

def create_train_state(rng, net, learning_rate, input_shape):
    """Creates initial `TrainState`."""
    params = net.init(rng, jnp.ones((1,N_CATEGORIES)), jnp.ones((1,INPUT_DIMS)), jnp.ones((1,N_HIDDEN)), False)['params']
    tx = optax.chain(
        optax.clip_by_global_norm(10.0),
        optax.adam(learning_rate),
    )
    return train_state.TrainState.create(
        apply_fn=net.apply, params=params, tx=tx)

In [52]:
from einops import repeat

In [53]:
def cross_entropy_loss(predictions, label):
    return -jnp.log(predictions[label])

In [166]:
import functools
from einops import repeat

@functools.partial(jax.jit, static_argnames=('apply_fn'))
def rnn_forward(apply_fn, params, category_tensor, input_line_tensor, hidden):
    category_tensor = repeat(category_tensor, 'b l -> (repeat b) l', repeat=input_line_tensor.shape[0])
    _, predictions = jax.lax.scan(lambda hidden, inp: apply_fn({"params":params}, hidden, inp),
            jnp.array(hidden),
            (category_tensor, np.squeeze(input_line_tensor)),
    )
    return predictions[-1]

@functools.partial(jax.jit, static_argnames=('apply_fn'))
def train_rnn(apply_fn, state, category_tensor, input_line_tensor, hidden, targets):
    # category_tensor = repeat(category_tensor, 'b l -> (repeat b) l', repeat=input_line_tensor.shape[0])

    def loss_fn(params):
        predictions = []
        inp = np.squeeze(input_line_tensor)
        length = np.squeeze(input_line_tensor).shape[0]
        hidden_state = jnp.array(hidden)
        for i in range(length):
            hidden_state, output = apply_fn({"params":params}, hidden_state, (category_tensor, inp[i:i+1]))
            predictions.append(np.squeeze(output[1]))
        predictions = jnp.stack(predictions)

        # _, predictions = jax.lax.scan(lambda hidden, inp: apply_fn({"params":params}, hidden, inp),
        #         jnp.array(hidden),
        #         (category_tensor, np.squeeze(input_line_tensor))
        # )
        return jnp.mean(jax.vmap(cross_entropy_loss)(jnp.squeeze(predictions), targets))

    grad_fn = jax.value_and_grad(loss_fn)
    loss, grads = grad_fn(state.params)
    return state.apply_gradients(grads=grads), loss, grads

In [167]:
import os
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime

def get_tensorboard_logger(
    experiment_name: str, base_log_path: str = "tensorboard_logs"
):
    log_path = "{}/{}_{}".format(base_log_path, experiment_name, datetime.now())
    train_writer = SummaryWriter(log_path, flush_secs=10)
    full_log_path = os.path.join(os.getcwd(), log_path)
    print(
        "Follow tensorboard logs with: python -m tensorboard.main --logdir '{}'".format(full_log_path)
    )
    return train_writer


In [168]:
import optax
from flax.training import train_state  # Useful dataclass to keep train state

def create_train_state(rng, model, learning_rate):
    """Creates initial `TrainState`."""
    params = model.init(rng, jnp.zeros((1, N_HIDDEN)), (jnp.zeros((1, INPUT_DIMS)),  jnp.zeros((1, N_CATEGORIES))), False)['params']
    tx = optax.chain(
        optax.clip_by_global_norm(10.0),
        optax.adam(learning_rate),
    )
    return train_state.TrainState.create(
        apply_fn=model.apply, params=params, tx=tx
    )

In [169]:
import pandas as pd
import tqdm

def flatten(d):
    df = pd.json_normalize(d, sep='_')
    return df.to_dict(orient='records')[0]

def train_static_rnn(
    num_epochs,
    model,
    seed: int = 0,
    lr: float = 0.0001,
    ):
    writer = get_tensorboard_logger("RNNCHAR")
    rng = jax.random.PRNGKey(seed)
    state = create_train_state(rng, model, lr)

    bar = tqdm.tqdm(np.arange(num_epochs))

    try:
        for i in bar:
            category_tensor, input_line_tensor, target_line_tensor = randomTrainingExample()
            state, loss, grads = train_rnn(model.apply, state, category_tensor, input_line_tensor, init_hidden(N_HIDDEN), target_line_tensor)
            grad_dict = {k:dict(grads[k]) for k in grads.keys()}
            # grad_dict = flatten(grad_dict)

            # grad_dict = {k: {kk: np.sum(vv).item() for kk, vv in v.items()}
            #             for k, v in grad_dict.items()}
            # grad_dict 

            metrics = {"loss":loss.item(), }
            for key in metrics:
                writer.add_scalar(key, metrics[key], i)

            bar.set_description('Loss: {}'.format(loss.item()))
        return model, state
    except KeyboardInterrupt as e:
        return model, state

In [170]:
rnn = TargetRNN(OUTPUT_SIZE)

In [171]:
model, state = train_static_rnn(100000, rnn)

Follow tensorboard logs with: python -m tensorboard.main --logdir '/home/shyam/Code/hyper-nn/notebooks/jax/tensorboard_logs/RNNCHAR_2022-03-26 20:56:43.008387'


Loss: 1.655625343322754: 100%|██████████| 100000/100000 [06:45<00:00, 246.31it/s] 


In [172]:
max_length = 20

# Sample from a category and starting letter
def sample(category, start_letter='A'):
    category_tensor = categoryTensor(category)
    input = inputTensor(start_letter)
    hidden = init_hidden(N_HIDDEN)

    output_name = start_letter

    for i in range(max_length):
        hidden, out = rnn.apply({"params":state.params}, hidden, (category_tensor, input[0]))
        category, output = out
        # topv, topi = output.topk(1)
        topi = np.argmax(np.squeeze(output), axis=-1)
        # topi = topi[0][0]
        if topi == n_letters - 1:
            break
        else:
            letter = all_letters[topi]
            output_name += letter
        input = inputTensor(letter)

    return output_name

# Get multiple samples from one category and multiple starting letters
def samples(category, start_letters='ABC'):
    for start_letter in start_letters:
        print(sample(category, start_letter))

samples('Russian', 'RUS')

samples('German', 'GER')

samples('Spanish', 'SPA')

samples('Chinese', 'CHI')

Roskin
Urinov
Shavanin
Garter
Ester
Romer
Sala
Parez
Alana
Chan
Han
Ino


### Hypernetwork

In [173]:
from typing import Optional, Any
import jax.numpy as jnp

from hypernn.jax.embedding_module import FlaxEmbeddingModule
from hypernn.jax.weight_generator import FlaxWeightGenerator
from hypernn.jax.hypernet import FlaxHyperNetwork


class CustomFlaxEmbeddingModule(FlaxEmbeddingModule):
    def setup(self):
        self.embedding = nn.Embed(self.num_embeddings, self.embedding_dim)

    def __call__(self, inp: Optional[Any] = None):
        indices = jnp.arange(0, self.num_embeddings)
        return self.embedding(indices)

class CustomFlaxWeightGenerator(FlaxWeightGenerator):
    def setup(self):
        self.dense1 = nn.Dense(32)
        self.dense2 = nn.Dense(self.hidden_dim)

    def __call__(self, embedding: jnp.array, inp: Optional[Any] = None):
        x = self.dense1(embedding)
        x = nn.relu(x)
        x = self.dense2(x)
        return x


In [174]:
hyper = FlaxHyperNetwork.from_target(
    target_network=rnn,
    inputs=[jnp.zeros((1, N_HIDDEN)), (jnp.zeros((1, INPUT_DIMS)),  jnp.zeros((1, N_CATEGORIES)))],
    embedding_module=CustomFlaxEmbeddingModule,
    weight_generator=CustomFlaxWeightGenerator,
    embedding_dim = 32,
    num_embeddings = 512
)

In [175]:
rng = jax.random.PRNGKey(0)
params = hyper.init(rng, inp=[jnp.zeros((1, N_HIDDEN)), (jnp.zeros((1, INPUT_DIMS)),  jnp.zeros((1, N_CATEGORIES)))])

In [176]:
hyper_generated_params = hyper.apply(params, method=hyper.generate_params)

In [177]:
generated_params, embeddings = hyper_generated_params

In [178]:
generated_params.shape

(87210,)

In [179]:
out, generated_params, embeddings = hyper.apply(params, inp=[jnp.zeros((1, N_HIDDEN)), (jnp.zeros((1, INPUT_DIMS)),  jnp.zeros((1, N_CATEGORIES)))], generated_params=generated_params)

#### Static Hypernetwork

In [185]:
import functools
from einops import repeat
import pandas as pd
import tqdm

import optax
from flax.training import train_state  # Useful dataclass to keep train state

def create_hyper_train_state(rng, model, learning_rate):
    """Creates initial `TrainState`."""
    params = model.init(rng, [jnp.zeros((1, N_HIDDEN)), (jnp.zeros((1, INPUT_DIMS)),  jnp.zeros((1, N_CATEGORIES))), False])['params']
    tx = optax.chain(
        optax.clip_by_global_norm(10.0),
        optax.adam(learning_rate),
    )
    return train_state.TrainState.create(
        apply_fn=model.apply, params=params, tx=tx
    )

@functools.partial(jax.jit, static_argnames=('apply_fn'))
def hyper_rnn_forward(apply_fn, params, category_tensor, input_line_tensor, hidden):
    category_tensor = repeat(category_tensor, 'b l -> (repeat b) l', repeat=input_line_tensor.shape[0])
    _, predictions = jax.lax.scan(lambda hidden, inp: apply_fn({"params":params}, hidden, inp),
            jnp.array(hidden),
            (category_tensor, np.squeeze(input_line_tensor)),
    )
    return predictions[-1]

@functools.partial(jax.jit, static_argnames=('apply_fn'))
def train_hyper_rnn(apply_fn, state, category_tensor, input_line_tensor, hidden, targets):
    # category_tensor = repeat(category_tensor, 'b l -> (repeat b) l', repeat=input_line_tensor.shape[0])

    def loss_fn(params):
        predictions = []
        inp = np.squeeze(input_line_tensor)
        length = np.squeeze(input_line_tensor).shape[0]
        hidden_state = jnp.array(hidden)
        for i in range(length):
            out, _, _ = apply_fn({"params":params}, [hidden_state, (category_tensor, inp[i:i+1])])
            hidden_state, output = out
            predictions.append(np.squeeze(output[1]))
        predictions = jnp.stack(predictions)

        # _, predictions = jax.lax.scan(lambda hidden, inp: apply_fn({"params":params}, hidden, inp),
        #         jnp.array(hidden),
        #         (category_tensor, np.squeeze(input_line_tensor))
        # )
        return jnp.mean(jax.vmap(cross_entropy_loss)(jnp.squeeze(predictions), targets))

    grad_fn = jax.value_and_grad(loss_fn)
    loss, grads = grad_fn(state.params)
    return astate.apply_gradients(grads=grads), loss, grads

def flatten(d):
    df = pd.json_normalize(d, sep='_')
    return df.to_dict(orient='records')[0]

def train_static_hyper_rnn(
    num_epochs,
    model,
    seed: int = 0,
    lr: float = 0.0001,
    ):
    writer = get_tensorboard_logger("HYPER_RNNCHAR")
    rng = jax.random.PRNGKey(seed)
    state = create_hyper_train_state(rng, model, lr)

    bar = tqdm.tqdm(np.arange(num_epochs))

    try:
        for i in bar:
            category_tensor, input_line_tensor, target_line_tensor = randomTrainingExample()
            state, loss, grads = train_hyper_rnn(model.apply, state, category_tensor, input_line_tensor, init_hidden(N_HIDDEN), target_line_tensor)
            grad_dict = {k:dict(grads[k]) for k in grads.keys()}
            # grad_dict = flatten(grad_dict)

            # grad_dict = {k: {kk: np.sum(vv).item() for kk, vv in v.items()}
            #             for k, v in grad_dict.items()}
            # grad_dict 

            metrics = {"loss":loss.item(), }
            for key in metrics:
                writer.add_scalar(key, metrics[key], i)

            bar.set_description('Loss: {}'.format(loss.item()))
        return model, state
    except KeyboardInterrupt as e:
        return model, state

In [186]:
_, hyper_state = train_static_hyper_rnn(100000, hyper)

Follow tensorboard logs with: python -m tensorboard.main --logdir '/home/shyam/Code/hyper-nn/notebooks/jax/tensorboard_logs/HYPER_RNNCHAR_2022-03-26 21:24:44.815926'


Loss: 1.7153233289718628: 100%|██████████| 100000/100000 [09:36<00:00, 173.53it/s]


In [188]:
max_length = 20

# Sample from a category and starting letter
def sample(category, start_letter='A'):
    category_tensor = categoryTensor(category)
    input = inputTensor(start_letter)
    hidden = init_hidden(N_HIDDEN)

    output_name = start_letter

    for i in range(max_length):
        hyper_out, _, _ = hyper.apply({"params":hyper_state.params}, inp=[hidden, (category_tensor, input[0])])
        hidden, out = hyper_out
        category, output = out
        # topv, topi = output.topk(1)
        topi = np.argmax(np.squeeze(output), axis=-1)
        # topi = topi[0][0]
        if topi == n_letters - 1:
            break
        else:
            letter = all_letters[topi]
            output_name += letter
        input = inputTensor(letter)

    return output_name

# Get multiple samples from one category and multiple starting letters
def samples(category, start_letters='ABC'):
    for start_letter in start_letters:
        print(sample(category, start_letter))

samples('Russian', 'RUS')

samples('German', 'GER')

samples('Spanish', 'SPA')

samples('Chinese', 'CHI')

Raiz
Uartov
Sakov
Gante
Ester
Romer
Sala
Para
Arano
Chang
Han
Ii


### Dynamic HyperNetwork

In [191]:
np.zeros((1,3)).reshape(3,1)

array([[0.],
       [0.],
       [0.]])

In [221]:
from typing import Optional, Any
import jax.numpy as jnp

from hypernn.jax.embedding_module import FlaxEmbeddingModule
from hypernn.jax.weight_generator import FlaxWeightGenerator
from hypernn.jax.hypernet import FlaxHyperNetwork


class DynamicFlaxEmbeddingModule(FlaxEmbeddingModule):
    def setup(self):
        self.embedding = nn.Embed(self.num_embeddings, self.embedding_dim)
        self.gru = nn.GRUCell()

    def __call__(self, inp: Optional[Any] = None, hidden: Optional[np.array] = None):
        inp = inp[1]
        if hidden is None:
            hidden = jnp.zeros((1, self.num_embeddings))
        category, input_tensor = inp
        x = jnp.concatenate((category, input_tensor), axis=-1)
        hidden, _ = self.gru(hidden, x)
        indices = jnp.arange(0, self.num_embeddings)
        return self.embedding(indices)*hidden.reshape(self.num_embeddings, 1), hidden

class DynamicFlaxWeightGenerator(FlaxWeightGenerator):
    def setup(self):
        self.dense1 = nn.Dense(32)
        self.dense2 = nn.Dense(self.hidden_dim)

    def __call__(self, embedding: jnp.array, inp: Optional[Any] = None):
        x = self.dense1(embedding[0])
        x = nn.relu(x)
        x = self.dense2(x)
        return x


In [222]:
import functools
from einops import repeat
import pandas as pd
import tqdm

import optax
from flax.training import train_state  # Useful dataclass to keep train state

def create_hyper_train_state(rng, model, learning_rate):
    """Creates initial `TrainState`."""
    params = model.init(rng, [jnp.zeros((1, N_HIDDEN)), (jnp.zeros((1, INPUT_DIMS)),  jnp.zeros((1, N_CATEGORIES))), False])['params']
    tx = optax.chain(
        optax.clip_by_global_norm(10.0),
        optax.adam(learning_rate),
    )
    return train_state.TrainState.create(
        apply_fn=model.apply, params=params, tx=tx
    )

@functools.partial(jax.jit, static_argnames=('apply_fn'))
def hyper_rnn_forward(apply_fn, params, category_tensor, input_line_tensor, hidden):
    category_tensor = repeat(category_tensor, 'b l -> (repeat b) l', repeat=input_line_tensor.shape[0])
    _, predictions = jax.lax.scan(lambda hidden, inp: apply_fn({"params":params}, hidden, inp),
            jnp.array(hidden),
            (category_tensor, np.squeeze(input_line_tensor)),
    )
    return predictions[-1]

@functools.partial(jax.jit, static_argnames=('apply_fn'))
def train_hyper_rnn(apply_fn, state, category_tensor, input_line_tensor, hidden, targets):
    # category_tensor = repeat(category_tensor, 'b l -> (repeat b) l', repeat=input_line_tensor.shape[0])

    def loss_fn(params):
        predictions = []
        inp = np.squeeze(input_line_tensor)
        length = np.squeeze(input_line_tensor).shape[0]
        hidden_state = jnp.array(hidden)
        hyper_hidden = None
        for i in range(length):
            out, _, embedding_output = apply_fn({"params":params}, [hidden_state, (category_tensor, inp[i:i+1])], embedding_module_kwargs={"hidden":hyper_hidden})
            hyper_hidden = embedding_output[1]
            hidden_state, output = out
            predictions.append(np.squeeze(output[1]))
        predictions = jnp.stack(predictions)

        # _, predictions = jax.lax.scan(lambda hidden, inp: apply_fn({"params":params}, hidden, inp),
        #         jnp.array(hidden),
        #         (category_tensor, np.squeeze(input_line_tensor))
        # )
        return jnp.mean(jax.vmap(cross_entropy_loss)(jnp.squeeze(predictions), targets))

    grad_fn = jax.value_and_grad(loss_fn)
    loss, grads = grad_fn(state.params)
    return state.apply_gradients(grads=grads), loss, grads

def flatten(d):
    df = pd.json_normalize(d, sep='_')
    return df.to_dict(orient='records')[0]

def train_static_hyper_rnn(
    num_epochs,
    model,
    seed: int = 0,
    lr: float = 0.0001,
    ):
    writer = get_tensorboard_logger("HYPER_RNNCHAR")
    rng = jax.random.PRNGKey(seed)
    state = create_hyper_train_state(rng, model, lr)

    bar = tqdm.tqdm(np.arange(num_epochs))

    try:
        for i in bar:
            category_tensor, input_line_tensor, target_line_tensor = randomTrainingExample()
            state, loss, grads = train_hyper_rnn(model.apply, state, category_tensor, input_line_tensor, init_hidden(N_HIDDEN), target_line_tensor)
            grad_dict = {k:dict(grads[k]) for k in grads.keys()}
            # grad_dict = flatten(grad_dict)

            # grad_dict = {k: {kk: np.sum(vv).item() for kk, vv in v.items()}
            #             for k, v in grad_dict.items()}
            # grad_dict 

            metrics = {"loss":loss.item(), }
            for key in metrics:
                writer.add_scalar(key, metrics[key], i)

            bar.set_description('Loss: {}'.format(loss.item()))
        return model, state
    except KeyboardInterrupt as e:
        return model, state

In [223]:
hyper = FlaxHyperNetwork.from_target(
    target_network=rnn,
    inputs=[jnp.zeros((1, N_HIDDEN)), (jnp.zeros((1, INPUT_DIMS)),  jnp.zeros((1, N_CATEGORIES)))],
    embedding_module=DynamicFlaxEmbeddingModule,
    weight_generator=DynamicFlaxWeightGenerator,
    embedding_dim = 32,
    num_embeddings = 512
)

In [224]:
train_static_hyper_rnn(100000, hyper)

Follow tensorboard logs with: python -m tensorboard.main --logdir '/home/shyam/Code/hyper-nn/notebooks/jax/tensorboard_logs/HYPER_RNNCHAR_2022-03-26 23:15:59.721965'


Loss: 4.0839009284973145:   0%|          | 1/100000 [00:06<172:23:09,  6.21s/it]

In [None]:
from typing import Optional, Any, Tuple
import functools
import torch.nn.functional as F

class DynamicTorchEmbeddingModule(TorchEmbeddingModule):
    def __init__(self, embedding_dim: int, num_embeddings: int, input_shape):
        super().__init__(embedding_dim, num_embeddings)
        self.rnn_hidden_dim = num_embeddings
        self.gru = nn.RNNCell(np.prod(input_shape), num_embeddings)
        self.embedding = nn.Embedding(num_embeddings, embedding_dim)

    def forward(self, inp, hidden_state: Optional[torch.Tensor] = None):
        x = torch.cat(inp[:-1], -1)
        if hidden_state is None:
            hidden_state = torch.zeros(x.size(0), self.rnn_hidden_dim).to(self.device)
        hidden_state = torch.sigmoid(self.gru(x, hidden_state))
        indices = torch.arange(self.num_embeddings).to(self.device)
        embedding = self.embedding(indices)*hidden_state.view(self.num_embeddings, 1)
        return embedding, hidden_state

    def initHidden(self):
        return torch.zeros(1, self.num_embeddings).to(self.device)

class DynamicTorchWeightGenerator(TorchWeightGenerator):
    def __init__(self, embedding_dim: int, num_embeddings: int, hidden_dim: int, input_shape: Optional[Any] = None):
        super().__init__(embedding_dim, num_embeddings, hidden_dim, input_shape)
        self.linear1 = nn.Linear(embedding_dim, 16)
        self.linear2 = nn.Linear(16, hidden_dim)

    def forward(
        self, embedding: Tuple[torch.Tensor, torch.Tensor], inp: Optional[Any] = None
    ) -> torch.Tensor:
        x = self.linear1(embedding[0])
        x = F.relu(x)
        return self.linear2(x).view(-1)

In [None]:
target_network = TargetRNN(n_letters, 128, n_letters)
pytorch_total_params = sum(p.numel() for p in target_network.parameters() if p.requires_grad)
pytorch_total_params

In [None]:
EMBEDDING_DIM = 8
NUM_EMBEDDINGS = 64

dynamic_embedding_module = DynamicTorchEmbeddingModule.from_target(target_network, EMBEDDING_DIM, NUM_EMBEDDINGS, input_shape=(n_categories+n_letters,))
dynamic_weight_generator = DynamicTorchWeightGenerator.from_target(target_network, EMBEDDING_DIM, NUM_EMBEDDINGS, input_shape=(n_categories+n_letters,))

In [None]:
dynamic_hypernetwork = TorchHyperNetwork(
                                input_shape=((1, n_categories), (1, n_letters), (1, 128)),
                                target_network=target_network,
                                embedding_module=dynamic_embedding_module,
                                weight_generator=dynamic_weight_generator
                            )
pytorch_total_params = sum(p.numel() for p in dynamic_hypernetwork.parameters() if p.requires_grad)
pytorch_total_params

In [None]:
device = torch.device('cuda')
dynamic_hypernetwork = dynamic_hypernetwork.to(device)

In [None]:
import torch.nn as nn

criterion = nn.NLLLoss()

def train_dynamic_hyper_rnn_step(dynamic_hyper_rnn, optimizer, category_tensor, input_line_tensor, target_line_tensor):
    target_line_tensor = target_line_tensor.unsqueeze(-1).to(dynamic_hyper_rnn.device)
    hidden = target_network.initHidden().to(dynamic_hyper_rnn.device)
    hyper_hidden = dynamic_hyper_rnn.embedding_module.initHidden()

    optimizer.zero_grad()

    loss = 0

    for i in range(input_line_tensor.size(0)):
        out, _, embedding_output = dynamic_hyper_rnn(inp=(category_tensor.to(dynamic_hyper_rnn.device), input_line_tensor[i].to(dynamic_hyper_rnn.device), hidden), embedding_module_kwargs={"hidden_state":hyper_hidden})
        _, hyper_hidden = embedding_output
        output, hidden = out
        l = criterion(output, target_line_tensor[i])
        loss += l

    loss.backward()
    torch.nn.utils.clip_grad_norm_(dynamic_hyper_rnn.parameters(), 10.0)
    optimizer.step()

    grad_dict = {}
    for n, W in dynamic_hyper_rnn.named_parameters():
        if W.grad is not None:
            grad_dict["{}_grad".format(n)] = float(torch.sum(W.grad).item())

    # for p in rnn.parameters():
    #     p.data.add_(p.grad.data, alpha=-learning_rate)

    return output, {"loss":loss.item() / input_line_tensor.size(0), **grad_dict}

In [None]:
learning_rate = 1e-4

train(hypernet=dynamic_hypernetwork, train_iter_fn=train_dynamic_hyper_rnn_step, lr=learning_rate, n_iters=100000)

In [None]:
dynamic_hypernetwork = dynamic_hypernetwork.to(torch.device('cpu'))

In [None]:
max_length = 20

# Sample from a category and starting letter
def sample(category, start_letter='A'):
    with torch.no_grad():  # no need to track history in sampling
        category_tensor = categoryTensor(category)
        input = inputTensor(start_letter)
        hidden = target_network.initHidden()
        hyper_hidden = dynamic_hypernetwork.embedding_module.initHidden()

        output_name = start_letter

        hidden_states = []
        for i in range(max_length):
            out, _, embedding_output = dynamic_hypernetwork(inp=(category_tensor, input[0], hidden), embedding_module_kwargs={"hidden_state":hyper_hidden})
            embedding, hyper_hidden = embedding_output
            output, hidden = out
            topv, topi = output.topk(1)
            topi = topi[0][0]
            if topi == n_letters - 1:
                break
            else:
                letter = all_letters[topi]
                output_name += letter
            input = inputTensor(letter)

        return output_name

# Get multiple samples from one category and multiple starting letters
def samples(category, start_letters='ABC'):
    for start_letter in start_letters:
        print(sample(category, start_letter))

samples('Russian', 'RUS')

samples('German', 'GER')

samples('Spanish', 'SPA')

samples('Chinese', 'CHI')