In [1]:
from elephant.neurons import *
from elephant.synapses import *
from elephant import HAM
from tqdm.auto import tqdm

import itertools
import matplotlib.pyplot as plt
import numpy as np
import os
import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:
data_path = os.path.expanduser('~/data/tinyshakespeare/input.txt')
with open(data_path, 'r') as f:
    text = f.read()
vocab = set(text)
n_vocab = len(vocab)
char_to_token = { c: i for i, c in enumerate(sorted(list(vocab))) }
token_to_char = { i: c for c, i in char_to_token.items() }
data = np.array([char_to_token[c] for c in text], dtype=np.uint8)
n_data = len(data)

print(f'#vocab  = {n_vocab}')
print(f'#tokens = {n_data/1e6:.4f}M')

#vocab  = 65
#tokens = 1.1154M


In [3]:
n_token = 64
n_embed = 128
n_heads = 16
n_proj  = n_embed
beta_attn = 10.0
beta_mem = 10.0
device = torch.device('cuda')
dtype = torch.float32

In [4]:
encoder = nn.Embedding(n_vocab, n_embed, device=device, dtype=dtype)
decoder = nn.Linear(n_embed, n_vocab, device=device, dtype=dtype)
encoder.weight = decoder.weight

In [5]:
params_enc = set(itertools.chain(encoder.parameters(), decoder.parameters()))
optim_enc = torch.optim.AdamW(params_enc, lr=1e-2, weight_decay=1e-4)
optim_enc.zero_grad(set_to_none=True)

In [6]:
batch_size = 100_000
n_examples = n_data
n_batches = n_examples // batch_size
n_epochs = 20

for e in range(n_epochs):
    pbar = tqdm(range(n_batches))
    for i in pbar:
        tokens = torch.tensor(data[i*batch_size:min((i+1)*batch_size,n_data)], dtype=torch.long).to(device=device)
        logits = decoder(encoder(tokens))
        loss = F.cross_entropy(logits, tokens, ignore_index=-1)
        pbar.set_description(f'loss = {loss.item():.6f}')
        loss.backward()
        optim_enc.step()
        optim_enc.zero_grad(set_to_none=True)

  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/11 [00:00<?, ?it/s]

In [7]:
sample = text[:100]
print(sample)
tokens = torch.tensor([char_to_token[c] for c in sample], dtype=torch.long).to(device=device)
tokens_out = torch.argmax(decoder(encoder(tokens)), dim=-1).cpu().tolist()
print('========')
for i in tokens_out:
    print(token_to_char[i], end='')

First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You
First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You

In [8]:
neurons  = { 'embeds': LayerNormNeuron(shape=(n_token, n_embed), use_bias=True, bias_dims={1}, device=device, dtype=dtype) }
synapses = {
    'attn': AttentionSynapse(n_embed, n_heads, n_proj, beta=beta_attn, device=device, dtype=dtype),
    'mem': HopfieldSynapse(n_embed, 4*n_embed, beta=beta_mem, device=device, dtype=dtype)
}
connections = {
    'attn': ['embeds', 'embeds'],
    'mem': ['embeds']
}
model = HAM(neurons, synapses, connections)
n_params = sum(p.numel() for p in model.parameters())
print(f'#params = {n_params/1e6:.4f}M')

#params = 0.5899M


In [9]:
optim = torch.optim.Adam(model.parameters(), lr=1e-4)
optim.zero_grad(set_to_none=True)

In [10]:
batch_size = 1000
n_examples = n_data
n_batches  = n_examples // batch_size
n_token = 64

pbar = tqdm(range(n_batches))
for i in pbar:
    starts = np.random.randint(0, n_data-n_token, size=(batch_size,))
    tokens = torch.tensor(np.array([data[s:s+n_token] for s in starts]), dtype=torch.long).to(device)
    embeds = encoder(tokens)
    xs = model.init_states(
        batch_size=batch_size,
        values={ 'embeds': embeds.requires_grad_() },
        requires_grad=True,
        device=device,
        dtype=dtype
    )
    gs = model.activations(xs)
    grads, energy = model.dEdg(xs, gs, create_graph=True, return_energy=True)
    grad_loss = torch.cat([torch.norm(g.view(g.shape[0], -1), dim=1, keepdim=True) for g in grads.values()], dim=1).mean(dim=1)
    loss = torch.mean(grad_loss + 0.00001*energy)
    
    mean_grad_loss = torch.mean(grad_loss).item()
    mean_energy = torch.mean(energy).item()
    pbar.set_description(f'grad loss = {mean_grad_loss:.4f}, energy = {mean_energy:.4f}, loss = {loss.item():.4f}')
    
    loss.backward()
    optim.step()
    optim.zero_grad(set_to_none=True)
    

  0%|          | 0/1115 [00:00<?, ?it/s]

In [11]:
text = '''ANTONIO:
Do you not hear me speak?

SEBASTIAN:
'''

model.neurons['embeds'].allow_variable_input = True

all_tokens = [char_to_token[c] for c in text]
response_length = 64
pbar = tqdm(range(response_length))

for i in pbar:
    
    tokens = all_tokens[-min(len(all_tokens), n_token-1):]
    tokens.append(0)
    tokens = torch.tensor(tokens, dtype=torch.long).to(device).view(1, -1)
    embeds = encoder(tokens)
    embeds[:,-1,:] = 0.0
    
    xs = model.init_states(
        batch_size=1,
        values={ 'embeds': embeds.requires_grad_() },
        requires_grad=True,
        device=device,
        dtype=dtype
    )
    gs = model.activations(xs)
    xs, gs = model.energy_descent(xs, gs, max_iter=1000, tol=1e-4, create_graph=False)
    logits = decoder(xs['embeds'][:,-1,:])
    probs = torch.softmax(logits, dim=-1)
    all_tokens.append(torch.argmax(probs.flatten()).item())

  0%|          | 0/64 [00:00<?, ?it/s]

In [12]:
for i in all_tokens:
    print(token_to_char[i], end='')

ANTONIO:
Do you not hear me speak?

SEBASTIAN:
-N.3LLL!!L!!LLL3!!GzI!!!!!!X!XRRR!!!!!!T!!!Rp!!,ppp!!!!!!!!!Is;T