In [1]:
import pickle
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, random_split, Subset
import torchvision
import matplotlib.pyplot as plt
import sinabs
import sinabs.activation
import sinabs.layers as sl
from sinabs.from_torch import from_model
import os

Data

In [7]:
class MyData(Dataset):
    """Dataset of numpy arrays stored one-per-file in a directory.

    Each file is read with ``np.load``, converted to a float32 tensor and has
    its axes 0 and 1 swapped. The class label is the character at
    ``file_name[-5]``, i.e. the character just before a 4-character extension
    such as ``.npy`` (``sample_3.npy`` -> label 3).

    Args:
        root_dir: directory that contains only the sample files.
    """

    def __init__(self, root_dir):
        self.root_dir = root_dir
        # Sorted so the index -> file mapping is reproducible across runs and
        # machines; os.listdir returns entries in arbitrary order.
        self.data_path = sorted(os.listdir(self.root_dir))

    def __getitem__(self, idx):
        """Return ``(data, label)`` for the ``idx``-th file in sorted order."""
        data_name = self.data_path[idx]
        data_item_path = os.path.join(self.root_dir, data_name)
        with open(data_item_path, 'rb') as f:
            data = np.load(f)

        data = torch.from_numpy(data).float()
        # Swap axes 0 and 1 (a training sample is (620, 100) after this swap).
        data = torch.transpose(data, 0, 1)

        # int() instead of eval(): identical for a digit, but never executes text
        # taken from a file name. Raises ValueError for a non-digit character.
        # NOTE(review): one character can only encode labels 0-9 — confirm this
        # covers every class the model is trained on.
        label = torch.tensor(int(data_name[-5]), dtype=torch.long)

        return data, label

    def __len__(self):
        """Number of files found in ``root_dir``."""
        return len(self.data_path)

In [8]:
# Dataset location. A raw string keeps the Windows backslashes from being parsed
# as (invalid) escape sequences; set BRAILLE_DATA_DIR to point elsewhere.
DATA_DIR = os.environ.get('BRAILLE_DATA_DIR', r'F:\Files\PhD\Braille\Data')
root_dir_1 = os.path.join(DATA_DIR, 'train')
root_dir_2 = os.path.join(DATA_DIR, 'test')

train_data = MyData(root_dir_1)
test_data  = MyData(root_dir_2)

# Also passed to sinabs.from_model when the network is built.
batch_size = 50
train_loader = DataLoader(train_data, batch_size, shuffle=True)
test_loader  = DataLoader(test_data, batch_size)

In [12]:
# Inspect one sample to confirm the input layout fed to the model.
sample_input, sample_label = train_data[2]
sample_input.size()

torch.Size([620, 100])

Model

In [14]:
# Sizes taken from the data: each sample is (n_steps, n_inputs) after MyData's
# transpose (see the shape printed above).
n_steps, n_inputs = 620, 100
n_classes = 11  # NOTE(review): MyData labels come from one digit (0-9) — confirm 11 is intended

# Plain ANN (MLP applied to the last axis), then converted to a spiking network.
ann_model = nn.Sequential(
    nn.Linear(n_inputs, 200),
    nn.ReLU(),
    nn.Linear(200, 400),
    nn.ReLU(),
    nn.Linear(400, n_classes),
)
# Reuse the DataLoader's batch_size rather than repeating the literal 50, so the
# two cannot silently drift apart.
linear_model = from_model(
    ann_model,
    batch_size=batch_size,
    input_shape=(1, n_steps, n_inputs),
    add_spiking_output=True,
)

loss_fn = nn.CrossEntropyLoss()

Training

In [26]:
# Train the converted spiking network; the class score is the output summed
# over dim 1 (the 620-step axis — presumably time for the spiking layers).
linear_model.train()
lr = 1e-5
optimizer = torch.optim.Adam(linear_model.parameters(), lr)

epochs = 10
for e in range(epochs):
    running_loss = 0.0  # Python float: keeps no autograd graph between batches
    acc = 0             # correctly classified training samples this epoch
    for inputs, target in train_loader:
        optimizer.zero_grad()
        # Clear state left over from the previous batch before the forward pass.
        linear_model.reset_states()

        output = linear_model(inputs)
        sum_output = output.sum(1)
        loss = loss_fn(sum_output, target)
        loss.backward()
        optimizer.step()

        # .item() detaches the value; adding the tensor itself would chain every
        # batch's graph together and grow memory across the epoch.
        running_loss += loss.item()
        # Counts over the actual batch, so a short final batch cannot IndexError.
        acc += (sum_output.argmax(dim=1) == target).sum().item()

    print("epoch: %d, accuracy: %.2f%%, running_loss: %.2f" % (e, acc/len(train_data)*100, running_loss) )

epoch: 0, accuracy: 97.36%, running_loss: 1.54
epoch: 1, accuracy: 98.18%, running_loss: 1.11
epoch: 2, accuracy: 98.36%, running_loss: 0.86
epoch: 3, accuracy: 98.27%, running_loss: 1.09
epoch: 4, accuracy: 97.55%, running_loss: 1.38
epoch: 5, accuracy: 98.73%, running_loss: 1.01
epoch: 6, accuracy: 98.73%, running_loss: 0.93
epoch: 7, accuracy: 98.55%, running_loss: 0.92
epoch: 8, accuracy: 98.27%, running_loss: 1.03
epoch: 9, accuracy: 98.64%, running_loss: 1.00


In [27]:
# Save the whole converted network (architecture + weights) so the next cell
# can reload it directly. Create the folder first: torch.save does not, and a
# fresh checkout would otherwise fail with FileNotFoundError.
model_path = './models/train.pth'
os.makedirs(os.path.dirname(model_path), exist_ok=True)
torch.save(linear_model, model_path)

In [28]:
# Reload the network saved above. It is a fully pickled nn.Module, which
# PyTorch >= 2.6 rejects under its new default weights_only=True, so opt out
# explicitly. Only do this for files you produced: unpickling can run code.
try:
    f = torch.load(model_path, weights_only=False)
except TypeError:  # PyTorch < 1.13 has no `weights_only` argument
    f = torch.load(model_path)

Testing

In [29]:
# Accuracy of the reloaded network on the held-out set, scored the same way as
# in training (output summed over dim 1, then argmax).
f.eval()
acc_num = 0
with torch.no_grad():
    for data, target in test_loader:
        f.reset_states()
        output = f(data)
        sum_output = output.sum(1)
        # Counts over the actual batch, so a short final batch cannot IndexError.
        acc_num += (sum_output.argmax(dim=1) == target).sum().item()
print("accuracy on testing set: %.2f%%" % (acc_num/len(test_data)*100))

accuracy on testing set: 98.82%
