In [112]:
from pathlib import Path
import logging
import argparse
import pickle

import torch
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
import slayerSNN as snn
from dataset import ViTacDataset

In [113]:
logging.basicConfig(level=logging.INFO, format="%(levelname)s %(message)s")
log = logging.getLogger()

In [114]:
params = snn.params('network_config/network_ae.yml')
writer = SummaryWriter(".")

In [115]:
# implement one layer
class OneLayerMLP(torch.nn.Module):
    def __init__(self, params, input_size, output_size):
        super(OneLayerMLP, self).__init__()
        self.slayer = snn.layer(params["neuron"], params["simulation"])
        self.fc = self.slayer.dense(input_size, output_size)

    def forward(self, spike_input):
        spike_output = self.slayer.spike(self.slayer.psp(self.fc(spike_input)))
        return spike_output

class FFT1AE(torch.nn.Module):
    def __init__(self, params, hidden_size):
        super(FFT1AE, self).__init__()
        self.encoder = OneLayerMLP(params, 156*2, hidden_size)
        self.decoder = OneLayerMLP(params, hidden_size, 156*2)

    def forward(self, spike_input):
        # convert to fft
        # input 8,156,1,1,325
        real_part = spike_input.squeeze().unsqueeze(dim=-1) # 8,156,325,1
        complex_part = torch.zeros_like(real_part) # 8,156,325,1
        input_spike = torch.cat([real_part, complex_part], dim=3) # 8,156,325,2
        spike_input_fft = torch.fft(input_spike, 1) # 8,156,325,2
        spike_input_fft = spike_input_fft.permute(0,1,3,2) # 8,156,2,325
        # convert 8,156,2,325 -> 8,312,1,1,325
        spike_input_fft = spike_input_fft.reshape(-1, 156*2, 1,1,325) # 8,312,1,1,325
        encoded_spike_fft = self.encoder(spike_input_fft)     
        spike_output_fft = self.decoder(encoded_spike)
        spike_output = torch.ifft(spike_output_fft)
        spike_output = spike_output[...,0]
        spike_output = spike_output.unsqueeze(2).unsqueeze(2)
        return spike_output, encoded_spike_fft

In [100]:
class FLAGS():
    data_dir = "/home/jethro/aug13_full"
    batch_size=8
    sample_file=1
    lr = 0.01
    hidden_size = 32
    epochs=500
    checkpoint_dir = "."

In [101]:
args = FLAGS()

In [102]:
device = torch.device("cuda:01")
net = FFT1AE(params, args.hidden_size).to(device)

In [103]:
error = snn.loss(params).to(device)
#criteria = error.spikeTime()
optimizer = torch.optim.RMSprop(net.parameters(), lr=args.lr, weight_decay=0.5)

In [104]:
train_dataset = ViTacDataset(
    path=args.data_dir,
    sample_file=f"train_80_20_{args.sample_file}.txt",
    output_size=20,
    spiking=True,
    mode='tact',
)
train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=args.batch_size,
    shuffle=True,
    num_workers=8,
)
test_dataset = ViTacDataset(
    path=args.data_dir,
    sample_file=f"test_80_20_{args.sample_file}.txt",
    output_size=20,
    spiking=True,
    mode='tact',
)
test_loader = DataLoader(
    dataset=test_dataset,
    batch_size=args.batch_size,
    shuffle=False,
    num_workers=8,
)

In [105]:
a,b,c = train_dataset[0]

In [106]:
a.shape, b.shape, c

(torch.Size([156, 1, 1, 325]), torch.Size([20, 1, 1, 1]), 0)

In [107]:
a = a.squeeze()

In [108]:
a.shape

torch.Size([156, 325])

In [109]:
a[0]

tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1.,
        1., 1., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0., 0., 0., 1.,
        0., 1., 0., 0., 1., 1., 0., 0., 0., 0., 0., 1., 1., 1., 0., 0., 0., 1.,
        1., 1., 1., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 1., 1., 0., 1., 0., 1., 0., 1., 0., 0., 1., 0., 0.,
        0., 1., 1., 0., 0., 1., 1., 0., 1., 1., 0., 0., 0., 0., 1., 1., 0., 0.,
        0., 1., 0., 1., 0., 0., 0., 0., 

In [110]:
def _train():
    correct = 0
    num_samples = 0
    net.train()
    for data, target, label in train_loader:
        data = data.to(device)
        target = target.to(device)
        output, _ = net.forward(data)
        num_samples += len(label)
        spike_loss = error.spikeTime(output, data)

        optimizer.zero_grad()
        spike_loss.backward()
        optimizer.step()

    writer.add_scalar("loss/train", spike_loss / len(train_loader), epoch)

def _test():
    correct = 0
    num_samples = 0
    net.eval()
    with torch.no_grad():
        for data, target, label in test_loader:
            data = data.to(device)
            target = target.to(device)
            output, _ = net.forward(data)
            num_samples += len(label)
            spike_loss = error.spikeTime(output, data)  # numSpikes

        writer.add_scalar("loss/test", spike_loss / len(test_loader), epoch)


def _save_model(epoch):
    log.info(f"Writing model at epoch {epoch}...")
    checkpoint_path = Path(args.checkpoint_dir) / f"weights_{epoch:03d}.pt"
    model_path = Path(args.checkpoint_dir) / f"model_{epoch:03d}.pt"
    torch.save(net.state_dict(), checkpoint_path)

In [111]:
for epoch in range(1, args.epochs + 1):
    _train()
    if epoch % 10 == 0:
        _test()
    if epoch % 100 == 0:
        _save_model(epoch)

RuntimeError: Expected 5-dimensional input for 5-dimensional weight 32 156 1 1 1, but got 4-dimensional input of size [8, 156, 325, 2] instead

In [None]:
with open("args.pkl", "wb") as f:
    pickle.dump(args, f)