In [17]:
import torch
import torch.nn as nn

import tensorflow as tf

import matplotlib.pyplot as plt
from tqdm.autonotebook import tqdm
import numpy as np
import cv2

import random
import importlib

import NeuralGraph
importlib.reload(NeuralGraph)

<module 'NeuralGraph' from 'c:\\Users\\Alec\\Documents\\GitHub\\ngraph_lang\\NeuralGraph.py'>

In [18]:
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()

x_train, x_test = x_train/255.0, x_test/255.0

SIZE = 7  # side length images are downsampled to before flattening


def _downsample_flat(img):
    """Resize one MNIST image to SIZE x SIZE and flatten it into a SIZE**2 vector."""
    return cv2.resize(img, dsize=[SIZE, SIZE]).reshape(SIZE**2)


# Per-digit image pools, each an (n_images, SIZE**2) array in dataset order.
# Digits 0-7 form the training pool; digits 8-9 are held out in `test`.
# NOTE: both pools come from the MNIST *training* split; the official test split is not used here.
train = {
    digit: np.array([_downsample_flat(img) for img in x_train[y_train == digit]])
    for digit in range(8)
}
test = {
    digit: np.array([_downsample_flat(img) for img in x_train[y_train == digit]])
    for digit in range(8, 10)
}

In [19]:
n_classes = 2

# Layer widths: flattened input pixels -> hidden layer -> one output node per class.
shape = [SIZE**2, 16, n_classes]

# Nodes are numbered globally; layer `l` occupies ids starting at layer_offsets[l].
layer_offsets = [sum(shape[:l]) for l in range(len(shape))]

# Fully connect every node of each layer to every node of the next layer,
# emitted as (source_id, target_id) pairs in layer -> source -> target order.
connections = [
    (layer_offsets[l] + src, layer_offsets[l + 1] + dst)
    for l in range(len(shape) - 1)
    for src in range(shape[l])
    for dst in range(shape[l + 1])
]

print(len(connections))

816


In [41]:
# Prefer the GPU, but fall back to CPU so the notebook can also run on machines without CUDA.
device = "cuda" if torch.cuda.is_available() else "cpu"

# NOTE(review): the positional args presumably are (total node count, input node count,
# output node count, edge list) - confirm against NeuralGraph.NeuralGraph's signature.
graph = NeuralGraph.NeuralGraph(
    sum(shape), shape[0], shape[-1], connections,
    leakage=.25, value_init="random", init_value_std=.1,
    aggregation="attention", n_heads=1, use_label=True,
    device=device, n_models=1, decay=.5,
)

In [42]:
# Objective: mean-squared error between predictions and one-hot targets
# (the one-hot encoding is done in the training loop).
criterion = nn.MSELoss()

# Adam over everything returned by graph.parameters().
optimizer = torch.optim.Adam(params=graph.parameters(), lr=1e-2)

In [43]:
# Labelled ("train") and evaluated ("test") examples drawn per episode.
TRAIN_EXAMPLES, TEST_EXAMPLES = 10, 10

# Optimisation length and number of episodes per optimiser step.
STEPS = 100_000
BATCH_SIZE = 32

# Passed as `time` / `dt` to graph.learn and graph.predict.
TIME, DT = 1, .25

def _sample_split(pool, classes, n_examples):
    """Draw `n_examples` labelled images for every episode in a batch.

    Args:
        pool: mapping class label -> array of images, one image per row.
        classes: (batch, n_way) integer array; row b lists the classes of episode b.
        n_examples: number of images to draw per episode.

    Returns:
        (x, y) where y has shape (batch, n_examples) with entries in [0, n_way)
        indexing into that episode's row of `classes`, and x[b, i] is a uniformly
        random row of pool[classes[b, y[b, i]]].
    """
    batch_size, n_way = classes.shape
    y = np.random.randint(n_way, size=(batch_size, n_examples))
    x = np.stack([
        np.stack([pool[c][np.random.randint(len(pool[c]))] for c in episode_classes[episode_y]])
        for episode_classes, episode_y in zip(classes, y)
    ])
    return x, y


def get_batch_data(pool=None, batch_size=None, n_way=None, n_train=None, n_test=None, target_device=None):
    """Sample a batch of few-shot classification episodes.

    Each episode picks `n_way` distinct classes from `pool`, then draws "train"
    (labelled) and "test" (evaluated) examples from those same classes. Labels are
    episode-local indices in [0, n_way), not the original class values.

    Every argument falls back to the notebook-level setting when omitted, so the
    zero-argument call behaves as before:
        pool -> train, batch_size -> BATCH_SIZE, n_way -> n_classes,
        n_train -> TRAIN_EXAMPLES, n_test -> TEST_EXAMPLES, target_device -> device.
    Pass e.g. ``pool=test`` to sample episodes from the held-out classes instead.

    Returns:
        (x_train, y_train, x_test, y_test, classes): float32 image tensors and int64
        label tensors on `target_device`, plus the (batch_size, n_way) numpy array of
        the classes each episode used.

    Raises:
        ValueError: if `pool` contains fewer than `n_way` classes.
    """
    pool = train if pool is None else pool
    batch_size = BATCH_SIZE if batch_size is None else batch_size
    n_way = n_classes if n_way is None else n_way
    n_train = TRAIN_EXAMPLES if n_train is None else n_train
    n_test = TEST_EXAMPLES if n_test is None else n_test
    target_device = device if target_device is None else target_device

    # Derive the candidate classes from the pool instead of hard-coding their count.
    pool_classes = np.array(sorted(pool))
    if len(pool_classes) < n_way:
        raise ValueError(f"pool has {len(pool_classes)} classes but {n_way} are needed per episode")
    classes = np.stack([np.random.choice(pool_classes, size=n_way, replace=False) for _ in range(batch_size)])

    x_train, y_train = _sample_split(pool, classes, n_train)
    x_test, y_test = _sample_split(pool, classes, n_test)

    # Build tensors directly in their final dtype (labels never pass through float32).
    return (
        torch.as_tensor(x_train, dtype=torch.float32, device=target_device),
        torch.as_tensor(y_train, dtype=torch.long, device=target_device),
        torch.as_tensor(x_test, dtype=torch.float32, device=target_device),
        torch.as_tensor(y_test, dtype=torch.long, device=target_device),
        classes,
    )

In [None]:
log = list()  # one metrics dict appended per training step by the training loop

In [44]:
# Decay-annealing schedule and progress-bar smoothing (previously inline magic numbers).
OVERFLOW_TOL = .1   # anneal graph.decay only while the overflow penalty is below this
DECAY_FLOOR = .25   # anneal only while graph.decay is still >= this value
DECAY_STEP = .05    # amount subtracted from graph.decay per qualifying step
LOG_WINDOW = 10     # number of recent steps averaged in the progress bar


def _recent_mean(key):
    """Mean of metric `key` over the last LOG_WINDOW entries of `log`."""
    return np.mean([e[key] for e in log[-LOG_WINDOW:]])


bar = tqdm(range(STEPS))
for _ in bar:
    # NOTE(review): these names rebind the module-level MNIST arrays loaded earlier;
    # kept so any later cell reading the last episode still works.
    x_train, y_train, x_test, y_test, classes = get_batch_data()

    graph.init_vals(nodes=True, edges=True, batch_size=BATCH_SIZE)
    graph.detach_vals()
    optimizer.zero_grad()

    # Feed the labelled examples with one-hot labels, then predict the evaluated ones.
    y_input = nn.functional.one_hot(y_train, n_classes).float()

    graph.learn(x_train, y_input, time=TIME, dt=DT)
    pred = graph.predict(x_test, time=TIME, dt=DT)

    accs = (pred.argmax(2) == y_test).float()

    y_label = nn.functional.one_hot(y_test, n_classes).float()

    task_loss = criterion(pred, y_label)
    overflow = graph.overflow()

    loss = task_loss + overflow
    loss.backward()

    # Convert once to a Python float for both control flow and logging.
    overflow_value = overflow.item()
    # Because the check is `>=`, one more step is taken once decay reaches the floor,
    # so decay can end up to one DECAY_STEP below DECAY_FLOOR.
    if overflow_value < OVERFLOW_TOL and graph.decay >= DECAY_FLOOR:
        graph.decay -= DECAY_STEP

    torch.nn.utils.clip_grad_norm_(graph.parameters(), 1.0)
    optimizer.step()

    entry = {'loss': task_loss.item(), 'acc': accs.mean().item(), "overflow": overflow_value, "decay": graph.decay}
    log.append(entry)
    bar.set_postfix({"loss": _recent_mean("loss"), "acc": _recent_mean("acc"), "overflow": entry["overflow"], "decay": entry["decay"]})

  0%|          | 0/10000 [00:00<?, ?it/s]