In [1]:
from pathlib import Path
import logging
import argparse
import pickle

import torch
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
import slayerSNN as snn
from dataset import ViTacDataset

In [29]:
logging.basicConfig(level=logging.INFO, format="%(levelname)s %(message)s")
log = logging.getLogger()

In [27]:
params = snn.params('network_ae.yml')
writer = SummaryWriter(".")

In [3]:
# implement one layer
class OneLayerMLP(torch.nn.Module):
    def __init__(self, params, input_size,output_size):
        super(OneLayerMLP, self).__init__()
        self.slayer = snn.layer(params["neuron"], params["simulation"])
        self.fc = self.slayer.dense(input_size, output_size)

    def forward(self, spike_input):
        spike_output = self.slayer.spike(self.slayer.psp(self.fc(spike_input)))
        return spike_output

class OneLayerAE(torch.nn.Module):
    def __init__(self, params, hidden_size):
        super(OneLayerAE, self).__init__()
        self.encoder = OneLayerMLP(params, 156, hidden_size)
        self.decoder = OneLayerMLP(params, hidden_size, 156)

    def forward(self, spike_input):
        encoded_spike = self.encoder(spike_input)
        spike_output = self.decoder(encoded_spike)
        return spike_output, encoded_spike 

In [4]:
class FLAGS():
    data_dir = "/home/jethro/aug13_full"
    batch_size=8
    sample_file=1
    lr = 0.01
    hidden_size = 32
    epochs=500
    checkpoint_dir = "."

In [5]:
args = FLAGS()

In [6]:
device = torch.device("cuda")
net = OneLayerAE(params, args.hidden_size).to(device)

In [7]:
error = snn.loss(params).to(device)
#criteria = error.spikeTime()
optimizer = torch.optim.RMSprop(net.parameters(), lr=args.lr, weight_decay=0.5)

In [8]:
train_dataset = ViTacDataset(
    path=args.data_dir,
    sample_file=f"train_80_20_{args.sample_file}.txt",
    output_size=20,
    spiking=True,
    mode='tact',
)
train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=args.batch_size,
    shuffle=True,
    num_workers=8,
)
test_dataset = ViTacDataset(
    path=args.data_dir,
    sample_file=f"test_80_20_{args.sample_file}.txt",
    output_size=20,
    spiking=True,
    mode='tact',
)
test_loader = DataLoader(
    dataset=test_dataset,
    batch_size=args.batch_size,
    shuffle=False,
    num_workers=8,
)

In [19]:
a,b,c = train_dataset[0]

In [20]:
a.shape, b.shape, c

(torch.Size([156, 1, 1, 325]), torch.Size([20, 1, 1, 1]), 0)

In [25]:
def _train():
    correct = 0
    num_samples = 0
    net.train()
    for data, target, label in train_loader:
        data = data.to(device)
        target = target.to(device)
        output, _ = net.forward(data)
        num_samples += len(label)
        spike_loss = error.spikeTime(output, data)

        optimizer.zero_grad()
        spike_loss.backward()
        optimizer.step()

    writer.add_scalar("loss/train", spike_loss / len(train_loader), epoch)

def _test():
    correct = 0
    num_samples = 0
    net.eval()
    with torch.no_grad():
        for data, target, label in test_loader:
            data = data.to(device)
            target = target.to(device)
            output, _ = net.forward(data)
            num_samples += len(label)
            spike_loss = error.spikeTime(output, data)  # numSpikes

        writer.add_scalar("loss/test", spike_loss / len(test_loader), epoch)


def _save_model(epoch):
    log.info(f"Writing model at epoch {epoch}...")
    checkpoint_path = Path(args.checkpoint_dir) / f"weights_{epoch:03d}.pt"
    model_path = Path(args.checkpoint_dir) / f"model_{epoch:03d}.pt"
    torch.save(net.state_dict(), checkpoint_path)
    torch.save(net, model_path)

In [28]:
for epoch in range(1, args.epochs + 1):
    _train()
    if epoch % 10 == 0:
        _test()
    if epoch % 100 == 0:
        _save_model(epoch)

NameError: name 'log' is not defined

In [None]:
with open("args.pkl", "wb") as f:
    pickle.dump(args, f)