In [1]:
import torch

import sys
sys.path.append('../')

from NeuralGraph import NeuralGraph

import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import dataset

from torch import nn

device = "cpu"

In [2]:
# One listener node
# n_random random nodes
# brain_size brain nodes
# One speaker node

n_random = 1
brain_size = 16
emb_dim = 64
T = 3

listener_node = np.arange(0, 1)
random_nodes = np.arange(1, n_random+1)
brain_nodes = np.arange(1+n_random, 1+n_random+brain_size)
speaker_node = np.arange(1+n_random+brain_size, 2+n_random+brain_size)

connections = [(i, j) for nodes1, nodes2 in [(listener_node, brain_nodes), (random_nodes, brain_nodes), (brain_nodes, brain_nodes), (brain_nodes, speaker_node)] for i in nodes1 for j in nodes2]

print(len(connections))

gen_graph = NeuralGraph(brain_size+n_random+2, 1+n_random, 1, connections, ch_n=32, ch_e=32, ch_k=32, ch_inp=emb_dim, ch_out=emb_dim, device=device)

def generate(x, n_tokens=16):
    # x should be of shape (bs, seqlen, emb_dim)
    assert len(x.shape) == 3 and x.shape[-1] == emb_dim
    
    # Append noise
    x = torch.cat([x.unsqueeze(-2), torch.rand(x.shape[0], x.shape[1], n_random, x.shape[2])], axis=-2)
    gen_graph.init_vals(batch_size=x.shape[0])

    for token in range(x.shape[1]):
        gen_graph.apply_vals(x[:, token])
        for t in range(T):
            gen_graph.timestep(t=t)

    outputs = []
    for _ in range(n_tokens-1):
        outputs.append(gen_graph.read_outputs())
        for t in range(T):
            gen_graph.timestep()    
    outputs.append(gen_graph.read_outputs())

    outputs = torch.stack(outputs, axis=1)

    return outputs.squeeze(-2)

304


In [3]:
listener_node = np.arange(0, 1)
brain_nodes = np.arange(1, 1+brain_size)
guesser_node = np.arange(1+brain_size, 2+brain_size)

connections = [(i, j) for nodes1, nodes2 in [(listener_node, brain_nodes), (brain_nodes, brain_nodes), (brain_nodes, guesser_node)] for i in nodes1 for j in nodes2]

print(len(connections))

critic_graph = NeuralGraph(brain_size+2, 1, 1, connections, ch_n=32, ch_e=32, ch_k=32, ch_inp=emb_dim, ch_out=1, device=device)

def score(x):
    # x should be of shape (bs, seqlen, emb_dim)
    assert len(x.shape) == 3 and x.shape[-1] == emb_dim

    critic_graph.init_vals(batch_size=x.shape[0])

    for token in range(x.shape[1]):
        critic_graph.apply_vals(x[:, token].unsqueeze(-2))
        for t in range(T):
            critic_graph.timestep(t=t)

    return critic_graph.read_outputs().squeeze(-1)


288


In [4]:
inps = torch.randn(4, 8, emb_dim)
fake = generate(inps)
print(fake.shape)
scores = score(fake)
print(scores.shape)

torch.Size([4, 16, 64])
torch.Size([4])


In [5]:
batch_size = 16
eval_batch_size = 16



from torchtext.datasets import WikiText2
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator

train_iter = WikiText2(split='train')
tokenizer = get_tokenizer('basic_english')
vocab = build_vocab_from_iterator(map(tokenizer, train_iter), specials=['<unk>'])
vocab.set_default_index(vocab['<unk>'])

def data_process(raw_text_iter: dataset.IterableDataset):
    """Converts raw text into a flat Tensor."""
    data = [torch.tensor(vocab(tokenizer(item)), dtype=torch.long) for item in raw_text_iter]
    return torch.cat(tuple(filter(lambda t: t.numel() > 0, data)))

# ``train_iter`` was "consumed" by the process of building the vocab,
# so we have to create it again
train_iter, val_iter, test_iter = WikiText2()
train_data = data_process(train_iter)
val_data = data_process(val_iter)
test_data = data_process(test_iter)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def batchify(data, bsz: int):
    """Divides the data into ``bsz`` separate sequences, removing extra elements
    that wouldn't cleanly fit.

    Arguments:
        data: Tensor, shape ``[N]``
        bsz: int, batch size

    Returns:
        Tensor of shape ``[N // bsz, bsz]``
    """
    seq_len = data.size(0) // bsz
    data = data[:seq_len * bsz]
    data = data.view(bsz, seq_len).t().contiguous()
    return data.to(device)

train_data = batchify(train_data, batch_size)  # shape ``[seq_len, batch_size]``
val_data = batchify(val_data, eval_batch_size)
test_data = batchify(test_data, eval_batch_size)

In [6]:
bptt = 16
def get_batch(source, i: int):
    """
    Args:
        source: Tensor, shape ``[full_seq_len, batch_size]``
        i: int

    Returns:
        tuple (data, target), where data has shape ``[seq_len, batch_size]`` and
        target has shape ``[seq_len * batch_size]``
    """
    seq_len = min(bptt, len(source) - 1 - i)
    data = source[i:i+seq_len]
    target = source[i+seq_len:i+2*seq_len]# .reshape(-1)
    return torch.swapaxes(data, 0, 1), torch.swapaxes(target, 0, 1)

In [7]:
x, y = get_batch(train_data, 0)
print(x.shape, y.shape)

embedder = nn.Embedding(len(vocab), emb_dim)

emb_x, emb_y = embedder(x), embedder(y)

print(emb_x.shape, emb_y.shape)

torch.Size([16, 16]) torch.Size([16, 16])
torch.Size([16, 16, 64]) torch.Size([16, 16, 64])


In [8]:
gen_opt = torch.optim.Adam(gen_graph.parameters(), lr=1e-4)
crit_opt = torch.optim.Adam(critic_graph.parameters(), lr=1e-4)

In [9]:
def compute_gp(real, fake):

    eps = torch.rand(batch_size, 1, 1).to(device)
    eps = eps.expand_as(real)
    interpolation = eps * real + (1 - eps) * fake

    # print(interpolation.shape)
    interp_logits = score(interpolation)
    grad_outputs = torch.ones_like(interp_logits)

    gradients = torch.autograd.grad(
        outputs=interp_logits,
        inputs=interpolation,
        grad_outputs=grad_outputs,
        create_graph=True,
        retain_graph=True,
    )[0]

    gradients = gradients.view(batch_size, -1)
    grad_norm = gradients.norm(2, 1)
    return torch.mean((grad_norm - 1) ** 2)

In [16]:
for batch, i in enumerate(range(0, train_data.size(0) - 1, bptt*2)):
    inp, real = get_batch(train_data, i)
    emb_inp = embedder(inp)
    emb_real = embedder(real)

    # Gen update

    # gen_graph.train()
    # critic_graph.eval()

    # gen_opt.zero_grad()
    # fake = generate(emb_inp, n_tokens=bptt)
    # fake_scores = score(torch.cat([emb_inp, fake], axis=1))
    # gen_loss = -fake_scores.mean()
    # gen_loss.backward()
    # gen_opt.step()


    # Crit update
    
    gen_graph.eval()
    critic_graph.train()

    crit_opt.zero_grad()
    fake2 = generate(emb_inp, n_tokens=bptt)
    true_scores = score(torch.cat([emb_inp, emb_real], axis=1))
    fake_scores = score(torch.cat([emb_inp, fake2], axis=1))
    crit_loss = fake_scores.mean() - true_scores.mean() + compute_gp(emb_real, fake2)
    crit_loss.backward()
    crit_opt.step()

    print(f"Gen loss : {gen_loss.item():0.5f} | Crit loss : {crit_loss.item():0.5f}")

Gen loss : -39.18990 | Crit loss : 0.99485
