In [1]:
import superneuromat as snm
from superneuromat import SNN
import pandas as pd
import numpy as np
from numpy import typing as npt
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
from dataclasses import dataclass, field


# from: https://stackoverflow.com/a/21894086/2712730
class bidict(dict):
    """Creates a dictionary that supports reverse lookups via the .inverse attribute.

    Args:
        dict (dict): The original dictionary.

    Properties:
        inverse (dict): A dictionary that maps values to keys.
    """
    def __init__(self, *args, **kwargs):
        super(bidict, self).__init__(*args, **kwargs)
        self.inverse = {}
        for key, value in self.items():
            self.inverse.setdefault(value, []).append(key)

    def __setitem__(self, key, value):
        if key in self:
            self.inverse[self[key]].remove(key)
        super(bidict, self).__setitem__(key, value)
        self.inverse.setdefault(value, []).append(key)

    def __delitem__(self, key):
        self.inverse.setdefault(self[key], []).remove(key)
        if self[key] in self.inverse and not self.inverse[self[key]]:
            del self.inverse[self[key]]
        super(bidict, self).__delitem__(key)


def train_test_split_indices(papers, test_size=0.2, rng=None) -> tuple[npt.NDArray[np.generic], npt.NDArray[np.generic]]:
    """
    Splits a list of papers into train and test indices.

    Parameters
    ----------
    papers : iterable of papers with ids
        List or dict of papers to split.
    test_size : float, optional
        if < 1, proportion of papers to reserve for testing, by default 0.2
        if > 1, number of papers to reserve for testing
    rng : int, optional
        Random state for the random number generator, default uses numpy's random

    Returns
    -------
    train_indices : list of int
        List of indices for the training set.
    test_indices : list of int
        List of indices for the testing set.
    """
    n = papers if isinstance(papers, int) else len(papers)
    if test_size < 1:
        test_size = int(np.floor(test_size * n))  # number of papers in test
    if rng is None:
        rng = np.random.default_rng()  # setup rng
    elif isinstance(rng, int):
        rng = np.random.default_rng(rng)

    if isinstance(papers, int):
        indices = np.arange(n)  # generate indices
    elif isinstance(papers, (list, tuple)):
        indices = papers  # assume papers is a list of indices
    else:  # assume papers is a dict or mapping of papers with a .values() method
        indices = [paper.idx for paper in papers.values()]  # grab indices from dict entries
    indices = np.asarray(indices)

    # shuffle and split
    rng.shuffle(indices)
    test_indices = indices[:test_size]
    train_indices = indices[test_size:]
    return train_indices, test_indices


In [2]:
rng = np.random.default_rng(1)

### Load CiteSeer dataset

In [3]:
@dataclass
class Paper:
    idx: str  # Paper ID
    label: str  # Paper category/topic
    features: tuple[bool | int | float, ...] = ()  # binary features
    citations: list[str] = field(default_factory=list)  # IDs of papers cited by this paper


papers = {}
missing = set()

# Load in training data
content = pd.read_csv("data/Cora/cora/cora.content", sep="\t", header=None, dtype=object)
citations = pd.read_csv("data/Cora/cora/cora.cites", sep="\t", header=None, dtype=object)

labels = set()  # set of unique labels

for paper in content.itertuples(index=False):  # create papers from data
    idx = paper[0]
    features = tuple([int(feature) for feature in paper[1:-1]])  # parse features
    papers[idx] = Paper(idx, paper[-1], features)  # create paper object
    labels.add(paper[-1])  # label is the last column. add to set of labels

for paper_idx, citation in citations.itertuples(index=False):
    try:  # parse citations
        if citation not in papers:
            missing.add(citation)  # if citation is missing, add to missing list
            continue
        if citation in papers[paper_idx].citations:
            continue  # skip if citation is already in citations
        papers[paper_idx].citations.append(citation)
    except KeyError:
        missing.add(paper_idx)  # if paper is missing, add to missing list

print(f"Loaded {len(papers)} papers and skipped {len(missing)} broken references.")
labels = list(sorted(labels))  # sort labels
# TODO: WHY DOES THIS MATTER? Set order is random, depending on interpreter hash seed.
# but why does this give me different results if I have randommized weights...

labels  # show labels set

Loaded 2708 papers and skipped 0 broken references.


['Case_Based',
 'Genetic_Algorithms',
 'Neural_Networks',
 'Probabilistic_Methods',
 'Reinforcement_Learning',
 'Rule_Learning',
 'Theory']

#### Create train / test split  
Create two lists of IDs

In [4]:
train_idxs, test_idxs = train_test_split_indices(papers, test_size=0.2, rng=rng)

#### Build the network

In [5]:
# Initialize our model
model = SNN()

# Create our output neurons, set threshold very high so that we control when they spike during training.
# dict mapping {category: neuron_id}
lbl_threshold = 1.05
strong_connection = 5.0
weak_connection = 1.0
unknown_connection = 0.00001
# unknown_connection = 0.0

lbl_neurons = bidict({label: model.create_neuron(threshold=lbl_threshold, leak=0).idx for label in labels})

# Create our input neurons, one for each pixel of the image resolution.
paper_neurons = bidict()  # dict mapping {paper_id: neuron_id}

# make a neuron for each paper
for paper in papers.values():
    paper_neurons[paper.idx] = neuron_id = model.create_neuron(1, 0.).idx

    if paper.idx in train_idxs:
        # Make an explicitly STRONG synapse connecting the input to the output
        output_id = lbl_neurons[paper.label]  # (training paper to topic)
        model.create_synapse(neuron_id, output_id, weight=strong_connection, stdp_enabled=False, delay=1)
        model.create_synapse(output_id, neuron_id, weight=strong_connection, stdp_enabled=False, delay=1)
    else:  # test set, connect this input neuron to all output neurons
        # Connect our input neuron to output neurons (test paper to topic)
        for output_id in lbl_neurons.values():
            # Randomize initial weight
            weight1 = (rng.choice([-1, 1]) * unknown_connection) + 1
            weight2 = (rng.choice([-1, 1]) * unknown_connection) + 1
            # Make a synapse connecting the input to the output
            model.create_synapse(neuron_id, output_id, weight=weight1, stdp_enabled=True, delay=1)
            model.create_synapse(output_id, neuron_id, weight=weight2, stdp_enabled=True, delay=1)

# connect papers by their citations (paper to paper connections)
for paper in papers.values():
    for citation in paper.citations:
        if paper.idx == citation:
            continue  # don't cite yourself
        try:
            model.create_synapse(paper_neurons[paper.idx], paper_neurons[citation], weight=weak_connection, stdp_enabled=True, delay=1)  # noqa
            model.create_synapse(paper_neurons[citation], paper_neurons[paper.idx], weight=weak_connection, stdp_enabled=True, delay=1)  # noqa
        except RuntimeError:
            continue  # skip if citation loop

#### loop through our dataset and add spikes

In [6]:
timestep = 0
train_idxs_augmented = np.concat((train_idxs, ) * 3)
for idx in train_idxs_augmented:
    paper = papers[idx]  # get paper by id
    # Add a spike that will exceed the threshold for the respective label neuron
    model.add_spike(timestep + 1, lbl_neurons[paper.label], strong_connection + 1)
    # Add spikes to the paper
    model.add_spike(timestep, paper_neurons[paper.idx], strong_connection + 1)
    timestep += 1

### Setup the model and perform a training pass

In [7]:
# Set up our stdp, only one timestep because we only want it looking at the results
# of what our input layer does to our output layer
model.stdp_setup(Apos=[1e-3, 5e-4], Aneg=[-1e-4, -5e-5], negative_update=True, positive_update=True)
# model.setup()

print(model.short())
print(f"{timestep} time steps will be simulated.")

# Simulate
with tqdm(total=timestep) as pbar:
    model.simulate(time_steps=timestep, callback=lambda _s, _t, _n: pbar.update())

SNN with 2715 neurons and 22464 synapses @ 0x27aad538890
6501 time steps will be simulated.


  0%|          | 0/6501 [00:00<?, ?it/s]

### inference

In [8]:
seed = 10026
modelc = model.copy()
modelc.reset()
modelc.release_mem()

In [9]:
print(f"Seed: {seed}")
rng = np.random.default_rng(seed)
seed += 1
model2 = modelc.copy()
model2.stdp = False
nlabels = len(labels)
# model2.enable_stdp = np.zeros(model.num_neurons)

model2.neuron_thresholds[:nlabels] = [99999] * nlabels

# paper = random.choice(list(papers.values()))  # pick a random paper
# paper = papers[random.choice(train_idxs)]  # pick a random paper from training set
paper = papers[rng.choice(test_idxs)]  # pick a random paper from testing set

print(f"Paper: {paper.idx}\tCategory: {paper.label}")
idx: int = paper_neurons[paper.idx]
neuron = model2.neurons[idx]

spike_n = 2

for t in range(spike_n):
    neuron.add_spike(t, strong_connection + 1)

model2.simulate(spike_n + 1)

output_spikes = np.sum(model2.ispikes[-(spike_n + 1):, :nlabels], axis=0)

spiked_ids = {idx: lbl_neurons.inverse[idx][0]
              for idx, spiked in enumerate(output_spikes) if spiked}
if spiked_ids:
    print(f"Spiked categories:")
    for idx, category in spiked_ids.items():
        print(f"\t{idx}\t{category}")
else:
    print("No category spiked")

print('=' * 10)

lbl_by_threshold = sorted((enumerate(model2.neuron_states[:nlabels])), key=lambda x: x[1], reverse=True)
for i, v in lbl_by_threshold:
    category = lbl_neurons.inverse[i][0]
    syn = model2.get_synapse(neuron, i)
    assert syn is not None
    print(f"{i}\t{v: 5.6f}\t{syn.weight: 5.6f}\t{category}")
print(model2.neuron_states[:nlabels])
print(model2.neuron_thresholds[:nlabels])
print(snm.print_spike_train(model2.ispikes[:, :nlabels]))

Seed: 10026
Paper: 1140040	Category: Neural_Networks
No category spiked
2	 36.489080	 10.744540	Neural_Networks
0	 21.489120	 10.744560	Case_Based
3	 21.489120	 10.744560	Probabilistic_Methods
4	 21.489120	 10.744560	Reinforcement_Learning
5	 21.489120	 10.744560	Rule_Learning
6	 21.489120	 10.744560	Theory
1	 21.489080	 10.744540	Genetic_Algorithms
[21.489120000000444, 21.489080000000442, 36.48908000000044, 21.489120000000444, 21.489120000000444, 21.489120000000444, 21.489120000000444]
[99999, 99999, 99999, 99999, 99999, 99999, 99999]
t|id0 1 2 3 4 5 6  
0: [│ │ │ │ │ │ │ ]
1: [│ │ │ │ │ │ │ ]
2: [│ │ │ │ │ │ │ ]
None


In [10]:
# test_threshold = 6.058
test_threshold = 99999


def evaluate_paper(paper_idx):
    model_temp = modelc.copy()
    model_temp.stdp = False
    model_temp.neuron_thresholds[:nlabels] = [test_threshold] * nlabels
    for t in range(2):
        model_temp.add_spike(t, paper_neurons[paper_idx], strong_connection + 1)
    model_temp.simulate(1 + 2)
    # output_spikes = np.sum(model_temp.ispikes[-(1):, :nlabels], axis=0)
    model_temp.release_mem()
    return model_temp.neuron_states[:nlabels]

In [11]:
if "results" in locals():
    oldresults = results.copy()

In [12]:
# single-threaded (slow, but better for debugging)
results = [evaluate_paper(paper_idx) for paper_idx in tqdm(test_idxs)]

  0%|          | 0/541 [00:00<?, ?it/s]

In [13]:
if "oldresults" in locals():
    print(np.array_equiv(oldresults, results))
    print(np.allclose(results, oldresults, atol=1e-5))

In [14]:
from tabulate import tabulate

tp, fp, tn, fn = 0, 0, 0, 0
correct = 0
total = len(test_idxs)

for actual_idx, charges in zip(test_idxs, results):
    correct_label = papers[actual_idx].label
    # guesses = {lbl_neurons.inverse[idx][0] for idx, spiked in enumerate(spikes) if spiked}
    charges = [(charge, lbl_neurons.inverse[idx][0]) for idx, charge in enumerate(charges)]
    charges = sorted(charges, reverse=True)
    guesses = [label for charge, label in charges if charge == charges[0][0]]
    correct += correct_label in guesses and len(guesses) == 1
    for label in labels:
        if label == correct_label:
            if correct_label in guesses:
                tp += 1
            else:
                fn += 1
        else:
            if label in guesses:
                fp += 1
            else:
                tn += 1

n_guesses = total - fn
print(f"tp: {tp} / {n_guesses} / {total} | Perfect: {correct} / {n_guesses} / {total} (correct / attempted / total)")
print(f"tp:        {tp / total:>.6f} | Perfect: {correct / total:>.6f} | F1:        {tp / (tp + (0.5 * (fp + fn))):>.6f}")
print(f"Precision: {tp / (tp + fp):>.6f} | Recall:  {tp / (tp + fn):>.6f} | Accuracy:  {(tp + tn) / (tp + tn + fp + fn):>.6f}")

tabulate([[correct, tp, fp, tn, fn, n_guesses, total,
           model.apos, model.aneg, lbl_threshold,
           strong_connection, weak_connection, unknown_connection, test_threshold]], tablefmt="html")

tp: 452 / 452 / 541 | Perfect: 441 / 452 / 541 (correct / attempted / total)
tp:        0.835490 | Perfect: 0.815157 | F1:        0.814414
Precision: 0.794376 | Recall:  0.835490 | Accuracy:  0.945603


0,1,2,3,4,5,6,7,8,9,10,11,12,13
441,452,117,3129,89,452,541,"[0.001, 0.0005]","[-0.0001, -5e-05]",1.05,5,1,1e-05,99999
