In [1]:
import pickle
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, random_split, Subset
import torchvision
import matplotlib.pyplot as plt
import sinabs
import sinabs.activation
import sinabs.layers as sl
from sinabs.from_torch import from_model
import os
import shutil
from sklearn.model_selection import KFold
print(torch.cuda.is_available())
torch.cuda.get_device_name(0)
torch.cuda.empty_cache()

True


Data

In [33]:
class MyData(Dataset):
    def __init__(self, root_dir, is_spiking=False):
        self.root_dir = root_dir
        self.data_path = os.listdir(self.root_dir)
        self.is_spiking = is_spiking
    
    def __getitem__(self, idx):
        data_name = self.data_path[idx]
        data_item_path = os.path.join(self.root_dir, data_name)
        with open(data_item_path, 'rb') as f:
            data = np.load(f)

        data = torch.from_numpy(data).float()
        data = torch.transpose(data, 0, 1)
        data = data.view(-1, 10, 10)
        if self.is_spiking:
            data = data.unsqueeze(1)
        else:
            data = data.sum(0)
            data = data.unsqueeze(0)

        for i in range(10):
            if data_name[-i-1] == '_':
                label = data_name[-i:-4]
                break
        label = eval(label)
        label = int((label/8)) % 27
        label = torch.tensor(label, dtype=torch.long)
        
        return data, label

    def __len__(self):
        return len(self.data_path)

In [35]:
root_dir_1 = 'F:\Files\PhD\Braille\Data/braille-27letters-sphere/effect-xyposition/xyposition-r1/train'
root_dir_2 = 'F:\Files\PhD\Braille\Data/braille-27letters-sphere/effect-xyposition/xyposition-r1/test'

train_data = MyData(root_dir_1, False)
test_data  = MyData(root_dir_2, True)

print(train_data[0][0].shape, test_data[0][0].shape)
print(len(train_data), len(test_data))

torch.Size([1, 10, 10]) torch.Size([700, 1, 10, 10])
8640 8640


Training

In [42]:
a = torch.rand(81, 1, 10, 10)
# a = sl.FlattenTime()(a)
print(a.shape)
cnn = nn.Sequential(
        nn.Conv2d(1, 20, 3, 1, bias=False),
        nn.ReLU(),
        nn.AvgPool2d(2, 2),
        nn.Conv2d(20, 32, 3, 1, bias=False),
        nn.ReLU(),
        nn.Flatten(),
        nn.Linear(128, 400, bias=False),
        nn.ReLU(),
        nn.Linear(400, 27, bias=False)
    )
b = cnn(a)
# b = b.unflatten(0, (81, 56700//81))
b.shape

torch.Size([81, 1, 10, 10])


torch.Size([81, 27])

In [43]:
lr = 1e-4

device = 'cuda:0'
k_folds = 10
kfold = KFold(n_splits=k_folds, shuffle=False)
batch_size = 81
loss_fn = nn.CrossEntropyLoss()

fold_train_acc = []
fold_val_acc = []

for fold, (train_ids, val_ids) in enumerate(kfold.split(train_data)):
    # 模型初始化
    cnn = nn.Sequential(
        nn.Conv2d(1, 20, 3, 1, bias=False),
        nn.ReLU(),
        nn.AvgPool2d(2, 2),
        nn.Conv2d(20, 32, 3, 1, bias=False),
        nn.ReLU(),
        nn.Flatten(),
        nn.Linear(128, 400, bias=False),
        nn.ReLU(),
        nn.Linear(400, 27, bias=False)
    )
    optimizer = torch.optim.Adam(cnn.parameters(), lr)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.1)
    # 分割数据集
    train_sub = Subset(train_data, train_ids)
    val_sub = Subset(train_data, val_ids)
    # 创建数据加载器
    train_loader = DataLoader(train_sub, batch_size, shuffle=True, drop_last=True)
    val_loader  = DataLoader(val_sub, batch_size, drop_last=True)

    # 训练模型
    cnn.train()
    epochs = 40
    for e in range(epochs):
        running_loss = 0.
        acc = 0
        scheduler.step()
        for i, (input, target) in enumerate(train_loader):
            optimizer.zero_grad()

            cnn = cnn.to(device)
            input = input.to(device)
            target = target.to(device)
            output = cnn(input)
            loss = loss_fn(output, target)
            loss.backward()
            optimizer.step()

            running_loss += loss

            for j in range(batch_size):
                if output[j].argmax() == target[j]:
                    acc += 1

        print("epoch: %d, accuracy: %.2f%%, running_loss: %.2f, current_lr: %.6f" 
              % (e, acc/len(train_sub)*100, running_loss, scheduler.get_last_lr()[0]) )
    fold_train_acc.append(np.around(acc/len(train_sub)*100, 2))

    # 验证模型
    acc_num = 0
    for i, (data, target) in enumerate(val_loader):
        with torch.no_grad():
            data = data.to(device)
            target = target.to(device)
            output = cnn(data)

        for j in range(batch_size):
            if output[j].argmax() == target[j]:
                acc_num += 1
    print("accuracy on validation set: %.2f%%" % (acc_num/len(val_sub)*100))
    fold_val_acc.append(np.around(acc_num/len(val_sub)*100, 2))

epoch: 0, accuracy: 9.34%, running_loss: 513.64, current_lr: 0.000100
epoch: 1, accuracy: 21.42%, running_loss: 258.95, current_lr: 0.000100
epoch: 2, accuracy: 30.59%, running_loss: 221.63, current_lr: 0.000100
epoch: 3, accuracy: 37.10%, running_loss: 195.32, current_lr: 0.000100
epoch: 4, accuracy: 42.32%, running_loss: 174.78, current_lr: 0.000100
epoch: 5, accuracy: 48.32%, running_loss: 156.22, current_lr: 0.000100
epoch: 6, accuracy: 54.01%, running_loss: 139.55, current_lr: 0.000100
epoch: 7, accuracy: 58.29%, running_loss: 125.93, current_lr: 0.000100
epoch: 8, accuracy: 64.67%, running_loss: 111.71, current_lr: 0.000100
epoch: 9, accuracy: 68.47%, running_loss: 100.27, current_lr: 0.000100
epoch: 10, accuracy: 71.98%, running_loss: 90.87, current_lr: 0.000100
epoch: 11, accuracy: 76.07%, running_loss: 80.44, current_lr: 0.000100
epoch: 12, accuracy: 79.30%, running_loss: 72.15, current_lr: 0.000100
epoch: 13, accuracy: 81.69%, running_loss: 64.83, current_lr: 0.000100
epoch: 

KeyboardInterrupt: 

In [None]:
lr = 1e-4

device = 'cuda:0'
k_folds = 10
kfold = KFold(n_splits=k_folds, shuffle=False)
batch_size = 81
loss_fn = nn.CrossEntropyLoss()

fold_train_acc = []
fold_val_acc = []

for fold, (train_ids, val_ids) in enumerate(kfold.split(train_data)):
    # 模型初始化
    cnn = nn.Sequential(
        nn.Conv2d(1, 10, 3, 1, bias=False),
        nn.ReLU(),
        nn.AvgPool2d(2, 2),
        nn.Conv2d(10, 20, 3, 1, bias=False),
        nn.ReLU(),
        nn.Flatten(),
        nn.Linear(80, 200, bias=False),
        nn.ReLU(),
        nn.Linear(200, 27, bias=False)
    )
    cnn = from_model(cnn, batch_size=81, input_shape=(1, 10, 10), 
                              add_spiking_output=True, synops=False, num_timesteps=700).to(device)
    optimizer = torch.optim.Adam(cnn.parameters(), lr)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.1)
    # 分割数据集
    train_sub = Subset(train_data, train_ids)
    val_sub = Subset(train_data, val_ids)
    # 创建数据加载器
    train_loader = DataLoader(train_sub, batch_size, shuffle=True, drop_last=True)
    val_loader  = DataLoader(val_sub, batch_size, drop_last=True)

    # 训练模型
    cnn.train()
    epochs = 40
    for e in range(epochs):
        running_loss = 0.
        acc = 0
        scheduler.step()
        for i, (input, target) in enumerate(train_loader):
            optimizer.zero_grad()
            cnn.reset_states()

            input = input.to(device)
            input = sl.FlattenTime()(input)
            target = target.to(device)
            output = cnn(input)
            output = output.unflatten(0, (81, output.shape[0] // 81))
            sum_output = output.sum(1)
            loss = loss_fn(sum_output, target)
            loss.backward()
            optimizer.step()

            running_loss += loss

            for j in range(batch_size):
                if sum_output[j].argmax() == target[j]:
                    acc += 1

        print("epoch: %d, accuracy: %.2f%%, running_loss: %.2f, current_lr: %.6f" 
              % (e, acc/len(train_sub)*100, running_loss, scheduler.get_last_lr()[0]) )
    fold_train_acc.append(np.around(acc/len(train_sub)*100, 2))

    # 验证模型
    acc_num = 0
    for i, (data, target) in enumerate(val_loader):
        with torch.no_grad():
            cnn.reset_states()
            data = data.to(device)
            target = target.to(device)
            output = cnn(data)
            sum_output = output.sum(1)

        for j in range(batch_size):
            if sum_output[j].argmax() == target[j]:
                acc_num += 1
    print("accuracy on validation set: %.2f%%" % (acc_num/len(val_sub)*100))
    fold_val_acc.append(np.around(acc_num/len(val_sub)*100, 2))

In [None]:
print(fold_train_acc)
print(fold_val_acc)

sum = 0
cnt = 0
for i in fold_train_acc:
    sum += i
    cnt += 1
print(sum/cnt)

sum = 0
cnt = 0
for i in fold_val_acc:
    sum += i
    cnt += 1
print(sum/cnt)

In [44]:
model_path = './models/ANN-CNN.pth'
torch.save(cnn, model_path)

Testing

In [45]:
model = torch.load(model_path)
snn = from_model(model, input_shape=(1,10,10), add_spiking_output=True, synops=False, num_timesteps=700)
test_loader  = DataLoader(test_data, batch_size, drop_last=True)

acc_num = 0
for i, (data, target) in enumerate(test_loader):
    with torch.no_grad():
        snn.reset_states()
        data = data.to(device)
        data = sl.FlattenTime()(data)
        target = target.to(device)
        output = snn(data)
        output = output.unflatten(0, (81, output.shape[0] // 81))
        sum_output = output.sum(1)

    for j in range(batch_size):
        if sum_output[j].argmax() == target[j]:
            acc_num += 1
print("accuracy on testing set: %.2f%%" % (acc_num/len(test_data)*100))

accuracy on testing set: 22.87%
