# RNN / LSTM / GRU

## setup

In [1]:
from datetime import datetime
from torchvision import datasets, transforms
import os
import wandb
import requests
import torch as t
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader
from tqdm.notebook import tqdm

# Prefer the GPU when torch can see one; otherwise run on the CPU.
if t.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

## utils

In [2]:
# Download "Alice's Adventures in Wonderland" from Project Gutenberg once and
# cache it under data/; later runs only read the cached copy.
path = 'data/alice.txt'
if not os.path.exists(path):
    url = 'https://www.gutenberg.org/cache/epub/11/pg11.txt'
    response = requests.get(url, timeout=60)
    response.raise_for_status()  # never cache an HTTP error page as the corpus
    # keep ASCII only (other characters are silently dropped) to keep the vocabulary small
    book = response.content.decode('ascii', 'ignore')
    os.makedirs(os.path.dirname(path), exist_ok=True)  # open(..., 'w') fails if data/ is missing
    with open(path, 'w', encoding='utf-8') as file:
        file.write(book)

with open(path, 'r', encoding='utf-8') as file:
    book = file.read()

# Trim the Gutenberg header/footer: keep from the first chapter heading up to "THE END".
# Explicit checks instead of `assert` so the markers are still validated under `python -O`.
start = book.find('\nDown the Rabbit-Hole')
end = book.find('THE END')
if start == -1 or end == -1:
    raise ValueError(f'unexpected layout in {path}: chapter heading or "THE END" marker not found')
book = book[start:end]

In [3]:
# Character-level vocabulary: every distinct character of the book is a token.
vocab = set(book)
d_vocab = len(vocab)
d_hidden = 100   # hidden-state width used by the RNN / LSTM / GRU models below
d_batch = 10000  # default DataLoader batch size
# index <-> character tables; sorting makes the indices independent of set order
itoa = dict(enumerate(sorted(vocab)))
atoi = {ch: idx for idx, ch in itoa.items()}

In [4]:
def to_dataloader(text, seq_len=25, batch_size=d_batch, stride=None):
    """Build a shuffled next-character-prediction DataLoader from ``text``.

    Windows of ``seq_len + 1`` characters start every ``stride`` characters.
    Each window gives an input ``x = window[:-1]`` and a target
    ``y = window[1:]``: the same characters shifted by one. Both are encoded
    as index tensors through the module-level ``atoi``.

    Args:
        text: corpus; every character must be a key of ``atoi``.
        seq_len: characters per training example.
        batch_size: DataLoader batch size.
        stride: distance between window starts; defaults to
            ``max(1, seq_len // 2)`` (about 50% overlap).

    Returns:
        A ``DataLoader`` (``shuffle=True``) over ``TensorDataset(x, y)``.
        ``x`` and ``y`` have shape ``(n_windows, seq_len)``.
    """
    if stride is None:
        # seq_len // 2 is 0 when seq_len == 1, which range() rejects
        stride = max(1, seq_len // 2)
    # The last valid start i satisfies i + seq_len + 1 <= len(text), so the
    # exclusive bound is len(text) - seq_len (one more window than before).
    starts = range(0, len(text) - seq_len, stride)
    windows = [[atoi[ch] for ch in text[i:i + seq_len + 1]] for i in starts]
    # view() keeps a 2-D shape even when the text is too short for any window
    data = t.tensor(windows, dtype=t.long).view(-1, seq_len + 1)
    dataset = TensorDataset(data[:, :-1], data[:, 1:])
    return DataLoader(dataset, batch_size=batch_size, shuffle=True)

dataloader = to_dataloader(book, 50)

In [5]:
def train(model, dataloader, epochs=2001, d_vocab=d_vocab, opt=None, lr=3e-4, filename='', wnb=True):
    """Train a character-level sequence model with next-character cross-entropy.

    Each batch ``(xs, ys)`` of index tensors is one-hot encoded to ``d_vocab``
    classes and fed to ``model``. The model must return
    ``(batch, seq, d_vocab)`` logits.

    Args:
        model: module with ``forward(one_hot_xs)`` and ``sample(...)``.
        dataloader: yields ``(xs, ys)`` index tensors of shape ``(batch, seq)``.
        epochs: number of passes over ``dataloader``.
        d_vocab: number of one-hot classes.
        opt: optimizer; defaults to Adam over ``model.parameters()`` with ``lr``.
        lr: learning rate of the default optimizer.
        filename: wandb project name and checkpoint filename prefix.
        wnb: log the epoch loss to wandb, plus an HTML sample every 100 epochs.

    Side effects:
        Prints the mean epoch loss every 50 epochs. Every 1000 epochs it also
        prints a sample and saves a checkpoint under ``weights/``.

    Raises:
        ValueError: if ``dataloader`` yields no batches.
    """
    model.train()
    model = model.to(device)
    if opt is None:
        opt = t.optim.Adam(model.parameters(), lr=lr)
    os.makedirs('weights', exist_ok=True)  # t.save does not create missing directories
    if wnb:
        wandb.init(project=filename)
    try:
        for epoch in tqdm(range(epochs)):
            total, n_batches = 0.0, 0
            for xs, ys in dataloader:
                out = model(F.one_hot(xs, num_classes=d_vocab).float().to(device))
                # cross_entropy expects the class axis at dim 1: (batch, d_vocab, seq)
                loss = F.cross_entropy(out.permute(0, 2, 1), ys.to(device))
                opt.zero_grad()
                loss.backward()
                opt.step()
                total += loss.detach()  # stays on device: no host sync per batch
                n_batches += 1
            if n_batches == 0:
                raise ValueError('dataloader yielded no batches')
            epoch_loss = (total / n_batches).item()  # mean over the epoch, not just the last batch
            if wnb:
                wandb.log({'loss': epoch_loss})
                if epoch % 100 == 0:
                    wandb.log({'sample_html': wandb.Html(f'<p>{model.sample(d_sample=200)}</p>')})
            if epoch % 50 == 0:
                print(f'loss={epoch_loss:.4f}')
            if epoch % 1000 == 0:
                print(model.sample())
                t.save(model.state_dict(), f'weights/{filename}_{datetime.now().strftime("%Y-%m-%d_%Hh%M")}.pt')
    finally:
        if wnb:
            wandb.finish()  # close the run even when training fails or is interrupted

## Recurrent Neural Networks (RNN)

### model

In [6]:
class RNN(nn.Module):
    """Single-layer Elman RNN.

    Recurrence: ``h_t = tanh(embed(x_t) + hidden(h_{t-1}))`` and
    ``y_t = unembed(h_t)``.

    Args:
        d_in: input feature size (one-hot vocabulary size for text).
        d_hidden: hidden-state size.
        d_out: output logit size.
    """

    def __init__(self, d_in=10, d_hidden=20, d_out=30):
        super().__init__()
        self.d_hidden = d_hidden
        self.embed = nn.Linear(d_in, d_hidden)
        self.hidden = nn.Linear(d_hidden, d_hidden)
        self.unembed = nn.Linear(d_hidden, d_out)

    def forward(self, xs, memory=None, return_memory=False):
        """Run the recurrence over ``xs`` of shape ``(batch, d_context, d_in)``.

        Args:
            xs: input sequence.
            memory: optional initial hidden state ``(batch, d_hidden)``;
                zeros when None.
            return_memory: also return the final hidden state.

        Returns:
            ``(batch, d_context, d_out)`` logits. When ``return_memory`` is
            True, returns ``(logits, final_hidden)`` instead.
        """
        batch, d_context, _ = xs.shape
        if memory is None:
            memory = t.zeros(batch, self.d_hidden, device=xs.device)
        outs = []
        for i in range(d_context):
            memory = t.tanh(self.embed(xs[:, i]) + self.hidden(memory))
            outs.append(self.unembed(memory))
        logits = t.stack(outs, dim=1)
        return (logits, memory) if return_memory else logits

    @t.no_grad()
    def sample(self, seed='A', d_sample=100):
        """Generate text by sampling one character at a time from the softmax.

        The whole ``seed`` (one or more characters) is first run through the
        network to prime the hidden state. Characters are then drawn until the
        text is ``d_sample`` characters long. Uses the module-level
        ``atoi`` / ``itoa`` tables.

        Raises:
            ValueError: if ``seed`` is empty.
        """
        if not seed:
            raise ValueError('seed must contain at least one character')
        dev = self.unembed.weight.device  # follow the model rather than the global `device`
        n_classes = self.embed.in_features
        text = seed
        xs = F.one_hot(t.tensor([[atoi[ch] for ch in seed]], device=dev), num_classes=n_classes).float()
        logits, memory = self.forward(xs, return_memory=True)
        out = logits[:, -1]
        while len(text) < d_sample:
            next_sample = t.multinomial(out[0].softmax(-1), num_samples=1)
            text += itoa[next_sample.item()]
            x = F.one_hot(next_sample.view(1, 1), num_classes=n_classes).float()
            out, memory = self.forward(x, memory, return_memory=True)
            out = out[:, -1]
        return text

rnn = RNN(d_in=d_vocab, d_hidden=d_hidden, d_out=d_vocab)
rnn = rnn.to(device)
rnn.sample()  # freshly initialised, so expect noise

'A;.]lyshH[ryyQq\nGUFHYQ_*d, yE!VGS:UwiMui_LHcPI-J-RKuM e?S\n,yi*t:hGDZ!-ys*Rngmz)XYLT.IM:AvKjhKx[pLReE'

### train

In [7]:
# train(rnn, dataloader, filename='rnn', epochs=50)

In [8]:
# print(rnn.sample(d_sample=2000))

## Long Short-Term Memory (LSTM)

### model

In [9]:
class LSTMCell(nn.Module):
    """One LSTM time step computed from the concatenation ``[x, h_prev]``.

    Four affine maps of ``[x, h_prev]`` give the forget, input and output
    gates (sigmoid) and the candidate cell value (tanh). The step then computes
    ``c = forget * c_prev + input * candidate`` and ``h = output * tanh(c)``.
    """

    def __init__(self, d_in, d_hidden):
        super().__init__()
        self.d_hidden = d_hidden
        d_joint = d_in + d_hidden  # every gate sees both the input and the previous hidden state
        self.W_f = nn.Linear(d_joint, d_hidden)  # forget gate
        self.W_i = nn.Linear(d_joint, d_hidden)  # input gate
        self.W_c = nn.Linear(d_joint, d_hidden)  # candidate cell update
        self.W_o = nn.Linear(d_joint, d_hidden)  # output gate

    def forward(self, x, h_prev, c_prev):
        """Return the new ``(h, c)``; ``x`` is concatenated with ``h_prev`` along dim 1."""
        joint = t.cat((x, h_prev), dim=1)
        forget = t.sigmoid(self.W_f(joint))
        write = t.sigmoid(self.W_i(joint))
        candidate = t.tanh(self.W_c(joint))
        cell = forget * c_prev + write * candidate  # long-term memory
        reveal = t.sigmoid(self.W_o(joint))
        hidden = reveal * t.tanh(cell)  # short-term memory / output
        return hidden, cell

class LSTM(nn.Module):
    """Single-layer LSTM model: an ``LSTMCell`` unrolled over time, then a linear read-out.

    Args:
        d_in: input feature size (one-hot vocabulary size for text, 28 for MNIST rows).
        d_hidden: hidden/cell state size.
        d_out: output logit size.
    """

    def __init__(self, d_in, d_hidden, d_out):
        super().__init__()
        self.d_hidden = d_hidden
        self.lstm_cell = LSTMCell(d_in, d_hidden)
        self.unembed = nn.Linear(d_hidden, d_out)

    def forward(self, xs, h_prev=None, c_prev=None, return_state=False):
        """Unroll the cell over ``xs`` of shape ``(batch, d_context, d_in)``.

        Args:
            xs: input sequence.
            h_prev, c_prev: optional initial states ``(batch, d_hidden)``;
                zeros when None.
            return_state: also return the final ``(h, c)``.

        Returns:
            ``(batch, d_context, d_out)`` logits. When ``return_state`` is
            True, returns ``(logits, h, c)`` instead.
        """
        batch, d_context, _ = xs.shape
        if h_prev is None:
            h_prev = t.zeros(batch, self.d_hidden, device=xs.device)
        if c_prev is None:
            c_prev = t.zeros(batch, self.d_hidden, device=xs.device)
        outs = []
        for i in range(d_context):
            h_prev, c_prev = self.lstm_cell(xs[:, i], h_prev, c_prev)
            outs.append(self.unembed(h_prev))
        logits = t.stack(outs, dim=1)
        return (logits, h_prev, c_prev) if return_state else logits

    @t.no_grad()
    def sample(self, seed='A', d_sample=100):
        """Generate text by sampling one character at a time from the softmax.

        The whole ``seed`` (one or more characters) primes the hidden and cell
        states. Characters are then drawn until the text is ``d_sample``
        characters long. Uses the module-level ``atoi`` / ``itoa`` tables.

        Raises:
            ValueError: if ``seed`` is empty.
        """
        if not seed:
            raise ValueError('seed must contain at least one character')
        dev = self.unembed.weight.device  # follow the model rather than the global `device`
        n_classes = self.lstm_cell.W_f.in_features - self.d_hidden  # == d_in
        text = seed
        xs = F.one_hot(t.tensor([[atoi[ch] for ch in seed]], device=dev), num_classes=n_classes).float()
        logits, h_prev, c_prev = self.forward(xs, return_state=True)
        out = logits[:, -1]
        while len(text) < d_sample:
            next_sample = t.multinomial(out[0].softmax(-1), num_samples=1)
            text += itoa[next_sample.item()]
            x = F.one_hot(next_sample, num_classes=n_classes).float()
            h_prev, c_prev = self.lstm_cell(x, h_prev, c_prev)
            out = self.unembed(h_prev)
        return text

lstm = LSTM(d_in=d_vocab, d_hidden=d_hidden, d_out=d_vocab).to(device)
lstm.sample()  # freshly initialised, so expect noise

'ADyCsofivY_DRhXFLBOUpi[WXw;KEZ,Yl_wKBRGp-ZC(o:!z:\npo[O_FhIVjy\nf!CqRH.(B\nHm\nYWnwe?_UFD?ljPuD)-d(mg!lZ'

### train

In [10]:
# train(lstm, dataloader, filename='lstm')

In [None]:
# print(lstm.sample(d_sample=2000))

## Gated Recurrent Unit (GRU)

### model

In [11]:
class GRUCell(nn.Module):
    """One GRU time step.

    The step computes:
    ``r = sigmoid(W_r [x, h])``, ``z = sigmoid(W_z [x, h])``,
    ``h~ = tanh(W_h [x, r * h])``, ``h' = (1 - z) * h + z * h~``.
    Here ``z`` weights the new candidate. ``torch.nn.GRU`` uses the opposite
    labelling, which is equivalent up to renaming the gate.
    """

    def __init__(self, d_in, d_hidden):
        super().__init__()
        self.d_hidden = d_hidden
        d_joint = d_in + d_hidden
        self.W_r = nn.Linear(d_joint, d_hidden)  # reset gate
        self.W_z = nn.Linear(d_joint, d_hidden)  # update gate
        self.W_h = nn.Linear(d_joint, d_hidden)  # candidate hidden state

    def forward(self, x, h_prev):
        """Return the next hidden state; ``x`` is concatenated with ``h_prev`` along dim 1."""
        joint = t.cat((x, h_prev), dim=1)
        reset = t.sigmoid(self.W_r(joint))
        update = t.sigmoid(self.W_z(joint))
        gated = t.cat((x, reset * h_prev), dim=1)  # the reset gate masks the old state for the candidate
        candidate = t.tanh(self.W_h(gated))
        return (1 - update) * h_prev + update * candidate

class GRU(nn.Module):
    """Single-layer GRU model: a ``GRUCell`` unrolled over time, then a linear read-out.

    Args:
        d_in: input feature size (one-hot vocabulary size for text).
        d_hidden: hidden-state size.
        d_out: output logit size.
    """

    def __init__(self, d_in, d_hidden, d_out):
        super().__init__()
        self.d_hidden = d_hidden
        self.gru_cell = GRUCell(d_in, d_hidden)
        self.unembed = nn.Linear(d_hidden, d_out)

    def forward(self, xs, h_prev=None, return_state=False):
        """Unroll the cell over ``xs`` of shape ``(batch, d_context, d_in)``.

        Args:
            xs: input sequence.
            h_prev: optional initial hidden state ``(batch, d_hidden)``;
                zeros when None.
            return_state: also return the final hidden state.

        Returns:
            ``(batch, d_context, d_out)`` logits. When ``return_state`` is
            True, returns ``(logits, h)`` instead.
        """
        batch, d_context, _ = xs.shape
        if h_prev is None:
            h_prev = t.zeros(batch, self.d_hidden, device=xs.device)
        outs = []
        for i in range(d_context):
            h_prev = self.gru_cell(xs[:, i], h_prev)
            outs.append(self.unembed(h_prev))
        logits = t.stack(outs, dim=1)
        return (logits, h_prev) if return_state else logits

    @t.no_grad()
    def sample(self, seed='A', d_sample=100):
        """Generate text by sampling one character at a time from the softmax.

        The whole ``seed`` (one or more characters) primes the hidden state.
        Characters are then drawn until the text is ``d_sample`` characters
        long. Uses the module-level ``atoi`` / ``itoa`` tables.

        Raises:
            ValueError: if ``seed`` is empty.
        """
        if not seed:
            raise ValueError('seed must contain at least one character')
        dev = self.unembed.weight.device  # follow the model rather than the global `device`
        n_classes = self.gru_cell.W_r.in_features - self.d_hidden  # == d_in
        text = seed
        xs = F.one_hot(t.tensor([[atoi[ch] for ch in seed]], device=dev), num_classes=n_classes).float()
        logits, h_prev = self.forward(xs, return_state=True)
        out = logits[:, -1]
        while len(text) < d_sample:
            next_sample = t.multinomial(out[0].softmax(-1), num_samples=1)
            text += itoa[next_sample.item()]
            x = F.one_hot(next_sample, num_classes=n_classes).float()
            h_prev = self.gru_cell(x, h_prev)
            out = self.unembed(h_prev)
        return text

gru = GRU(d_in=d_vocab, d_hidden=d_hidden, d_out=d_vocab).to(device)
gru.sample()  # freshly initialised, so expect noise

'AGj_;?sanBE-dBoWZ!*ZT]yfKNXFEYSvupGHs]GfOO-dovxhFDpjyNdax.IY GTyTVdmil,,SMNF?qKYca\ns;)OfB*SrfOqJNDh?'

### train

In [None]:
# train(gru, dataloader, filename='gru', epochs=50)

In [None]:
# print(gru.sample(d_sample=2000))

## MNIST

In [12]:
batch_size = 6400
vocab_size = 10  # number of digit classes, i.e. the classifier's output size
hidden_size = 100

# ToTensor yields pixels in [0, 1]; Normalize((0.5,), (0.5,)) maps them to [-1, 1].
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
trainset = datasets.MNIST('./mnist/', download=True, train=True, transform=transform)
testset = datasets.MNIST('./mnist/', download=True, train=False, transform=transform)
trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True)
# Evaluation visits every test image exactly once; shuffling adds nothing to an
# accuracy metric and makes per-batch behaviour non-reproducible.
testloader = DataLoader(testset, batch_size=batch_size, shuffle=False)

In [23]:
@t.no_grad()
def accuracy(model, dataloader=testloader):
    """Return the classification accuracy of ``model`` on ``dataloader`` as a float.

    Each image batch (assumed MNIST-shaped ``(B, 1, 28, 28)``) is read as a
    sequence of rows. The logits at the last time step are the prediction.
    The model's previous train/eval mode is restored afterwards. Returns 0.0
    for an empty dataloader.
    """
    was_training = model.training
    model.eval()
    correct, total = 0, 0
    try:
        for xs, ys in dataloader:
            # squeeze(1) drops only the channel axis; a bare squeeze() would also
            # drop the batch axis when a batch holds a single image.
            out = model(xs.squeeze(1).to(device))[:, -1]
            correct += (out.argmax(-1) == ys.to(device)).sum().item()
            total += len(xs)
    finally:
        model.train(was_training)
    return correct / total if total else 0.0

In [24]:
def train_mnist(model, dataloader, epochs=2001, d_vocab=d_vocab, opt=None, lr=3e-4, filename='', wnb=True):
    """Train a sequence classifier on images read row by row.

    Each batch (assumed MNIST-shaped ``(B, 1, 28, 28)``) is squeezed to
    ``(B, 28, 28)``. The logits at the last time step are trained with
    cross-entropy against the labels.

    Args:
        model: sequence model returning ``(batch, seq, n_classes)`` logits.
        dataloader: yields ``(images, labels)`` batches.
        epochs: number of passes over ``dataloader``.
        d_vocab: unused; kept so the signature matches ``train``.
        opt: optimizer; defaults to Adam over ``model.parameters()`` with ``lr``.
        lr: learning rate of the default optimizer.
        filename: wandb project name and checkpoint filename prefix.
        wnb: log the last batch loss and test accuracy to wandb every epoch.

    Side effects:
        Prints loss and test accuracy every 50 epochs. Saves a checkpoint
        under ``weights/`` every 10000 epochs.

    Raises:
        ValueError: if ``dataloader`` yields no batches.
    """
    model.train()
    model = model.to(device)
    if opt is None:
        opt = t.optim.Adam(model.parameters(), lr=lr)
    os.makedirs('weights', exist_ok=True)  # t.save does not create missing directories
    if wnb:
        wandb.init(project=filename)
    try:
        for epoch in tqdm(range(epochs)):
            loss = None
            for xs, ys in dataloader:
                # squeeze(1) removes only the channel axis (a bare squeeze() also
                # drops the batch axis for a batch of one image)
                out = model(xs.squeeze(1).to(device))[:, -1]
                loss = F.cross_entropy(out, ys.to(device))
                opt.zero_grad()
                loss.backward()
                opt.step()
            if loss is None:
                raise ValueError('dataloader yielded no batches')
            report = epoch % 50 == 0
            if wnb or report:
                # one full test pass per epoch, shared by both logging paths
                acc = float(accuracy(model))
                if wnb:
                    wandb.log({'loss': loss.item(), 'accuracy': acc})
                if report:
                    print(f'loss={loss.item():.4f} accuracy={acc:.4f}')
            if epoch % 10000 == 0:
                t.save(model.state_dict(), f'weights/{filename}_{datetime.now().strftime("%Y-%m-%d_%Hh%M")}.pt')
    finally:
        if wnb:
            wandb.finish()  # close the run even when training fails or is interrupted

In [None]:
# Classify MNIST digits by reading each image as a 28-step sequence of 28-pixel rows.
# Use the MNIST section's own hidden_size, as the GRU variant below does.
lstm_mnist = LSTM(28, hidden_size, vocab_size).to(device)
train_mnist(lstm_mnist, trainloader, epochs=20)
# gru_mnist = GRU(28, hidden_size, vocab_size).to(device)
# train_mnist(gru_mnist, trainloader, epochs=30, wnb=False)

## Stacked LSTM

### model

In [7]:
class LSTMCell(nn.Module):
    """One LSTM time step (redefined here for the stacked model).

    The forget, input and output gates are sigmoids, and the candidate is a
    tanh, of affine maps of ``[x, h_prev]``. The step returns ``(h, c)`` with
    ``c = f * c_prev + i * g`` and ``h = o * tanh(c)``.
    """

    def __init__(self, d_in, d_hidden):
        super().__init__()
        self.d_hidden = d_hidden
        width = d_in + d_hidden
        self.W_f = nn.Linear(width, d_hidden)  # forget gate
        self.W_i = nn.Linear(width, d_hidden)  # input gate
        self.W_c = nn.Linear(width, d_hidden)  # candidate cell update
        self.W_o = nn.Linear(width, d_hidden)  # output gate

    def forward(self, x, h_prev, c_prev):
        """Advance one step; ``x`` and ``h_prev`` are concatenated along dim 1."""
        xh = t.cat((x, h_prev), dim=1)
        f = t.sigmoid(self.W_f(xh))
        i = t.sigmoid(self.W_i(xh))
        g = t.tanh(self.W_c(xh))
        o = t.sigmoid(self.W_o(xh))
        c_new = f * c_prev + i * g
        return o * t.tanh(c_new), c_new

class StackedLSTM(nn.Module):
    """Multi-layer LSTM model.

    Layer 0 reads the input. Layer ``l > 0`` reads the hidden state layer
    ``l - 1`` produced at the same time step. Logits are read from the top
    layer. Hidden and cell states are carried as ``(d_layers, batch, d_hidden)``
    tensors.

    Args:
        d_in: input feature size (one-hot vocabulary size for text).
        d_hidden: hidden/cell state size of every layer.
        d_out: output logit size.
        d_layers: number of stacked ``LSTMCell`` layers.
    """

    def __init__(self, d_in, d_hidden, d_out, d_layers):
        super().__init__()
        self.d_hidden = d_hidden
        self.d_layers = d_layers
        self.lstm_cells = nn.ModuleList([LSTMCell(d_in if l == 0 else d_hidden, d_hidden) for l in range(d_layers)])
        self.unembed = nn.Linear(d_hidden, d_out)

    def _step(self, x, h_prev, c_prev):
        """Advance every layer by one time step; return ``(top_h, stacked_h, stacked_c)``."""
        h_next, c_next = [], []
        for lstm_cell, h, c in zip(self.lstm_cells, h_prev, c_prev):
            h, c = lstm_cell(x, h, c)
            h_next.append(h)
            c_next.append(c)
            x = h  # the next layer reads this layer's output
        return x, t.stack(h_next), t.stack(c_next)

    def forward(self, xs, h_prev=None, c_prev=None, return_state=False):
        """Run all layers over ``xs`` of shape ``(batch, d_context, d_in)``.

        Args:
            xs: input sequence.
            h_prev, c_prev: optional initial states ``(d_layers, batch, d_hidden)``;
                zeros when None.
            return_state: also return the final stacked ``(h, c)``.

        Returns:
            ``(batch, d_context, d_out)`` logits. When ``return_state`` is
            True, returns ``(logits, h, c)`` instead.
        """
        batch, d_context, _ = xs.shape
        if h_prev is None:
            h_prev = t.zeros(self.d_layers, batch, self.d_hidden, device=xs.device)
        if c_prev is None:
            c_prev = t.zeros(self.d_layers, batch, self.d_hidden, device=xs.device)
        outs = []
        for i in range(d_context):
            top, h_prev, c_prev = self._step(xs[:, i], h_prev, c_prev)
            outs.append(self.unembed(top))
        logits = t.stack(outs, dim=1)
        return (logits, h_prev, c_prev) if return_state else logits

    @t.no_grad()
    def sample(self, seed='A', d_sample=100):
        """Generate text by sampling one character at a time from the softmax.

        The whole ``seed`` primes every layer's state. Previously only the
        seed's last character was fed in. Characters are then drawn until the
        text is ``d_sample`` characters long. Uses the module-level
        ``atoi`` / ``itoa`` tables.

        Raises:
            ValueError: if ``seed`` is empty.
        """
        if not seed:
            raise ValueError('seed must contain at least one character')
        dev = self.unembed.weight.device  # follow the model rather than the global `device`
        n_classes = self.lstm_cells[0].W_f.in_features - self.d_hidden  # == d_in
        text = seed
        xs = F.one_hot(t.tensor([[atoi[ch] for ch in seed]], device=dev), num_classes=n_classes).float()
        logits, h_prev, c_prev = self.forward(xs, return_state=True)
        out = logits[:, -1]
        while len(text) < d_sample:
            next_sample = t.multinomial(out[0].softmax(-1), num_samples=1)
            text += itoa[next_sample.item()]
            x = F.one_hot(next_sample, num_classes=n_classes).float()
            top, h_prev, c_prev = self._step(x, h_prev, c_prev)
            out = self.unembed(top)
        return text

stacked_lstm = StackedLSTM(d_vocab, d_hidden, d_vocab, d_layers=3)
stacked_lstm = stacked_lstm.to(device)
stacked_lstm.sample()  # freshly initialised, so expect noise

"A#]Epf7K*ahT9pN*rDDgPlwBM-iBT#2M9.Z?sqzFu'P*i1FK]u'dEc5z956S\rkW!1-F/weG$,\ns0Spff,[%N?GqdF'U!A\r#atgOR"

### train

In [None]:
train(
    stacked_lstm,
    dataloader,
    epochs=1000000,
    filename='stacked_lstm_3_layers_long',
)

In [None]:
# t.save(stacked_lstm.state_dict(), f'weights/stacked_lstm_3_{datetime.now().strftime("%Y-%m-%d_%Hh%M")}.pt')

In [14]:
# t.save(atoi, 'weights/atoi.pt')
# t.save(itoa, 'weights/itoa.pt')

In [11]:
generated = stacked_lstm.sample(d_sample=2000)
print(generated)

Alice,
idestions wint, said the Caterpillar angrily, but the Hatter in her
wait the answeret! Healvertation about reason terms of
chimney, he was no rest at find
you do without at once, said the Hatter, ancomemations
though stated up to proxe Ner nestytete of twents, the March Hare said downwar: in the fish as, and Alice
called of the trademark came now, the leavinid Come, lets
before Twong! it consined is streng, strangleth, a goed the Mock Turtle, all crieuding out a mently, freew hey senmed, and cans addatupal.

Hid the Pinge do, astoge, But Im snaits
an extraowed the Duchess began my time).

To bires facesses, I think (she was round the table. Now Im harnt but
thought they were minkey opened tone, she tried her again how when these
llate on brough had you voice hear, I may you neitebeering that _hou__ wrorged Stull poe!

And when the hodeled lead nothing.

The book her they slowly of something hunder on a day draw an a
Projend folles terd alwargo
ang)ss
to its liteyyeray arm way, a

## visualize

### lib

In [32]:
from io import BytesIO
from IPython.display import display, HTML
import base64
import json
import matplotlib.pyplot as plt

In [33]:
def saliency(self, text):
    """Gradient saliency of each next-character logit with respect to the preceding inputs.

    ``text`` is run through the stacked LSTM as a batch of one. For every
    position ``j >= 1``, take the logit the model produced for the true
    character ``text[j]`` (computed from ``text[:j]``), negate it, and
    differentiate it with respect to the active one-hot input entries of
    ``text[0] .. text[j-1]``.

    Returns:
        ``[(), g_1, g_2, ...]``, where ``g_j`` is a list of ``j`` floats, one
        per preceding character. The first entry is empty because nothing
        precedes the first character.
    """
    dev = self.unembed.weight.device  # follow the model rather than the global `device`
    ids = [atoi[c] for c in text]
    with t.enable_grad():  # saliency needs autograd even if the caller disabled it
        # dtype=long keeps one_hot valid for an empty text
        xs = F.one_hot(t.tensor([ids], dtype=t.long, device=dev), num_classes=d_vocab).float().requires_grad_()
        h_prev = t.zeros(self.d_layers, 1, self.d_hidden, device=dev)
        c_prev = t.zeros(self.d_layers, 1, self.d_hidden, device=dev)
        outs = []
        for i in range(len(ids)):
            x = xs[:, i]
            h_next, c_next = [], []
            for lstm_cell, h, c in zip(self.lstm_cells, h_prev, c_prev):
                h, c = lstm_cell(x, h, c)
                h_next.append(h)
                c_next.append(c)
                x = h
            outs.append(self.unembed(h))
            h_prev = t.stack(h_next)
            c_prev = t.stack(c_next)
        positions = t.arange(len(ids), device=dev)
        input_ids = t.tensor(ids, dtype=t.long, device=dev)
        saliencies = [()]  # nothing caused the first token
        for j, (out, tok) in enumerate(zip(outs, text[1:]), start=1):
            # autograd.grad returns d(-logit)/d(xs) only; unlike loss.backward()
            # it does not accumulate .grad into the model's parameters.
            (grad,) = t.autograd.grad(-out[0, atoi[tok]], xs, retain_graph=True)
            # gradient at each earlier position's active one-hot entry, fetched in one transfer
            saliencies.append(grad[0, positions[:j], input_ids[:j]].tolist())
    return saliencies

def stats(self, text):
    """Run the stacked LSTM over ``text`` and expose its top-layer states.

    Returns:
        ``(logits, hiddens)``: per-position logits ``(1, len(text), d_out)``
        and the top layer's hidden states ``(1, len(text), d_hidden)``.
    """
    token_ids = t.tensor([[atoi[ch] for ch in text]])
    xs = F.one_hot(token_ids, num_classes=d_vocab).float().to(device)
    n_batch, n_steps, _ = xs.shape
    state_shape = (self.d_layers, n_batch, self.d_hidden)
    h_layers = t.zeros(*state_shape, device=xs.device)
    c_layers = t.zeros(*state_shape, device=xs.device)
    logits, tops = [], []
    for step in range(n_steps):
        layer_in = xs[:, step]
        new_h, new_c = [], []
        for cell, h, c in zip(self.lstm_cells, h_layers, c_layers):
            h, c = cell(layer_in, h, c)
            new_h.append(h)
            new_c.append(c)
            layer_in = h  # feed this layer's output to the next layer
        tops.append(h)
        logits.append(self.unembed(h))
        h_layers = t.stack(new_h)
        c_layers = t.stack(new_c)
    return t.stack(logits, dim=1), t.stack(tops, dim=1)

# Expose the analysis helpers as StackedLSTM methods: model.saliency(text), model.stats(text).
setattr(StackedLSTM, 'saliency', saliency)
setattr(StackedLSTM, 'stats', stats)

In [39]:
def to_b64_img(activation, width=10):
    """Render a 1-D activation vector as a small ``imshow`` PNG, base64-encoded.

    The vector is reshaped to ``(len // width, width)`` before plotting.

    Raises:
        ValueError: if the vector length is not a multiple of ``width``.
    """
    size = activation.shape[0]
    if size % width != 0:
        raise ValueError(f'activation length {size} is not a multiple of width={width}')
    # honour `width` (previously hard-coded to 10); reshape also copes with non-contiguous input
    grid = activation.detach().cpu().reshape(-1, width)
    fig = plt.figure(figsize=(2, 2))
    buf = BytesIO()
    try:
        plt.imshow(grid)
        plt.axis('off')
        plt.tight_layout()
        fig.savefig(buf, format='png')
    finally:
        plt.close(fig)  # release the figure even if rendering fails
    return base64.b64encode(buf.getvalue()).decode('utf-8')

def norm_saliencies(saliencies):
    """Scale all saliency values into [-1, 1] by their largest absolute value.

    Each inner sequence is returned as a list, so an empty tuple becomes
    ``[]``. When there are no values at all, or every value is zero, the
    values are returned unscaled instead of dividing by zero.
    """
    scale = max((abs(s) for saliency in saliencies for s in saliency), default=0.0)
    if scale == 0:
        scale = 1.0
    return [[s / scale for s in saliency] for saliency in saliencies]

def to_ch(v):
    """Map vocabulary index ``v`` to its display character for the widget.

    A carriage return is swapped for an escaped text placeholder (backslashes
    followed by ``r``). This keeps the raw control character out of the JSON
    payload embedded in the page.
    """
    token = itoa[v]
    return '\\\\r' if token == '\r' else token

def to_probs(out, token, k=5):
    """List the top-``k`` next-character probabilities, plus the true ``token`` if missing.

    Args:
        out: 1-D logits over the vocabulary.
        token: the character that actually came next.
        k: number of top entries to list (capped at the vocabulary size).

    Returns:
        ``[[label, prob], ...]`` for the top characters, labelled via ``to_ch``.
        When ``token`` is not among them, ``['...', 0.0]`` and
        ``[token_label, prob]`` are appended.
    """
    probs = out.softmax(dim=-1)
    tops = probs.topk(min(k, probs.shape[-1]))  # topk raises if k exceeds the vocabulary
    pairs = [[to_ch(i.item()), v.item()] for i, v in zip(tops.indices, tops.values)]
    token_idx = atoi[token]
    # Compare by index and label the extra entry through to_ch as well, so a
    # remapped character (e.g. '\r') is neither listed twice nor emitted raw.
    if token_idx not in tops.indices.tolist():
        pairs += [['...', 0.], [to_ch(token_idx), probs[token_idx].item()]]
    return pairs

def get_db(model, tokens):
    """Collect everything the HTML widget needs for ``tokens`` and return it as a JSON string.

    Keys, in order: ``default_image``, ``tokens``, ``colors`` (normalised
    saliencies), ``images`` (top-layer activation plots) and ``probs``.
    ``images`` and ``probs`` start with an entry for "before the first token".
    Index ``i`` therefore pairs token ``i`` with the state and distribution
    that predicted it. Leaves ``model`` in eval mode.
    """
    model.eval()
    blank = to_b64_img(t.zeros(100))
    logits, hiddens = model.stats(tokens)
    logits, hiddens = logits[0], hiddens[0]  # drop the batch-of-one axis
    db = {
        'default_image': blank,
        'tokens': list(tokens) if isinstance(tokens, str) else tokens,
        'colors': norm_saliencies(model.saliency(tokens)),
        'images': [blank, *(to_b64_img(hidden) for hidden in hiddens)],
        'probs': [[], *(to_probs(logit, nxt) for logit, nxt in zip(logits, tokens[1:]))],
    }
    return json.dumps(db)

In [40]:
# First half of the widget page: CSS, the container markup (activation image,
# probability list, token row) and the opening of a JS arrow-function scope.
# build_html_widget appends the `var db = ...;` statement right after this text,
# and end_template closes the scope.
# NOTE(review): the class selectors are page-global, so several widgets shown
# on one notebook page may style or populate each other — confirm before
# displaying more than one.
start_template = '''
<style>
.container {
    display: flex;
    width: 100%;
    gap: 50px;
    align-items: flex-start;
}
.infos {
    flex: 0 0 auto;
    display: flex;
    flex-direction: column;
    gap: 20px;
}
.current {
    background-color: green;
}
.prob {
    display: flex;
    gap: 20px;
}
.activations {
}
.probs .token {
    min-width: 8px;
    min-height: 16px;
}
.tokens {
    flex: 1 1 auto;
    display: flex;
    flex-wrap: wrap;
    gap: 10px;
}
.token {
    border: 1px solid grey;
    padding: 3px 10px 3px 10px;
}
</style>
    
<div class="container">
    <div class="infos">
        <div class="activations"><img /></div>
        <div class="probs"></div>
    </div>
    <div class="tokens"></div>
</div>
<script>
(() => { // start scope
'''

# Second half of the widget page: the JS that reads the `db` object.
# Hovering token `idx`:
#   - shows db['images'][idx] as a base64 PNG;
#   - lists db['probs'][idx], marking the entry equal to the hovered token with .current;
#   - colours every earlier token i by db['colors'][idx][i]
#     (negative values shade towards blue, positive values towards red).
# Mouse-out restores the default image and clears probabilities and colours.
# start() renders one span per token. The text then closes the scope opened
# by start_template.
# NOTE(review): names such as `el`, `idx`, `probs` and `span` are assigned
# without let/const and therefore become globals on `window`.
end_template = '''
function handle_mouseover(event) {
    el = event.target
    idx = el.getAttribute('data-idx');
    // image of the weights
    document.querySelector('.activations img').src = "data:image/png;base64," + db['images'][idx];
    // probability for the top tokens
    probs = document.querySelector('.infos .probs');
    db['probs'][idx].forEach((vals) => {
        div = document.createElement('div');
        div.classList.add('prob');
        prob = document.createElement('div');
        prob.textContent = vals[1].toFixed(2);
        token = document.createElement('div');
        token.classList.add('token');
        if (vals[0] === db['tokens'][idx]) { token.classList.add('current'); }
        token.textContent = vals[0]
        div.appendChild(prob);
        div.appendChild(token);
        probs.appendChild(div);
    });
    // saliency maps for tokens
    document.querySelectorAll('.tokens .token').forEach((el, i) => {
        if (i < idx) {
            value = db['colors'][idx][i];
            if (value < 0) {
                const t = (value + 1) / 1;
                red = Math.round(255 * t);
                green = Math.round(255 * t);
                blue = Math.round(255 * (1 - t) + 255 * t);
            } else {
                const t = value / 1;
                red = Math.round(255);
                green = Math.round(255 * (1 - t));
                blue = Math.round(255 * (1 - t));
            }
            el.style.backgroundColor = `rgb(${red}, ${green}, ${blue})`;
        }
    });
}

function handle_mouseout(event) {
    document.querySelector('.infos .activations img').src = "data:image/png;base64," + db['default_image'];
    document.querySelector('.infos .probs').innerHTML = '';
    document.querySelectorAll('.token').forEach((el, i) => {
        el.style.backgroundColor = '';
    });
}

function start() {
    el = document.querySelector('.container .tokens');
    db['tokens'].forEach((token, idx) => {
        span = document.createElement('span');
        span.textContent = token;
        span.setAttribute('class', 'token');
        span.setAttribute('data-idx', idx);
        span.addEventListener('mouseover', handle_mouseover);
        span.addEventListener('mouseout', handle_mouseout);
        el.appendChild(span);
    });
    document.querySelector('.activations img').src = "data:image/png;base64," + db['default_image'];
}

start();
})(); // end scope
</script>
'''

def build_html_widget(model, tokens):
    """Return the self-contained HTML/JS widget for ``tokens`` as one string.

    The JSON payload from ``get_db`` is embedded directly as a JavaScript
    object literal, not inside ``JSON.parse('...')``. The single-quoted
    wrapper broke on any apostrophe in the text, and JS string-escape
    processing turned JSON escapes such as newlines into raw control
    characters, which JSON.parse rejects. Only sequences that are special
    inside an HTML ``<script>`` element still need escaping.
    """
    db_json = get_db(model, tokens)
    # '</' -> '<\/' stops text such as "</script>" from closing the element
    # early ('\/' is a valid JSON escape). U+2028/U+2029 are escaped because
    # older JS engines reject them raw in string literals.
    safe_json = (
        db_json.replace('</', '<\\/')
        .replace('\u2028', '\\u2028')
        .replace('\u2029', '\\u2029')
    )
    db_str = f"var db = {safe_json};"
    return start_template + db_str + end_template

def display_interactive(model, tokens):
    """Render the interactive saliency/activation widget for ``tokens`` in the notebook output."""
    display(HTML(build_html_widget(model, tokens)))

### demo

In [41]:
demo_text = 'Alice, said Alice in a little paper all alomer frog to the end.'
display_interactive(stacked_lstm, demo_text)

In [None]:
# Dump the raw widget HTML (missing closing parenthesis was a SyntaxError).
print(build_html_widget(stacked_lstm, 'Alice, said Alice in a little paper all alomer frog to the end.'))