In [10]:
import torch
import sys
import torch.nn.functional as F
from torch.cuda import amp
from spikingjelly.activation_based import functional, surrogate, neuron
from spikingjelly.activation_based.model import parametric_lif_net
from spikingjelly.datasets.dvs128_gesture import DVS128Gesture
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
import time
import os
import argparse
import datetime


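Reference command-line invocation of the SpikingJelly DVS128 Gesture example that this notebook follows (the cells below override some of these values, e.g. the device and the number of epochs):

`python -m spikingjelly.activation_based.examples.classify_dvsg -T 20 -device cuda:0 -b 16 -epochs 64 -data-dir /data/tanghao/datasets/DVS128Gesture/ -amp -opt adam -lr 0.001 -j 8`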

In [11]:
# --- Data ---
data_dir = '/data/tanghao/datasets/DVS128Gesture'  # dataset root handed to DVS128Gesture
T = 20          # passed as frames_number when the datasets are built
b = 16          # batch size for both data loaders
j = 8           # number of DataLoader worker processes

# --- Model / hardware ---
channel = 128   # `channels` argument of DVSGestureNet
device = 'cuda:7'

# --- Optimisation ---
opt = 'adam'    # only used to label the log directory
lr = 0.001
epochs = 1

# --- Logging ---
out_dir = './logs'

In [12]:
# Build the DVS Gesture network with LIF neurons and an ATan surrogate gradient.
net = parametric_lif_net.DVSGestureNet(
    channels=channel,
    spiking_neuron=neuron.LIFNode,
    surrogate_function=surrogate.ATan(),
    detach_reset=True,
)

# Step mode 'm': the training loop feeds the whole [T, N, C, H, W] sequence in one call.
functional.set_step_mode(net, 'm')
# Optional CuPy backend, left disabled:
# functional.set_backend(net, 'cupy', instance=neuron.LIFNode)

net = net.to(device)
print(net)

DVSGestureNet(
  (conv_fc): Sequential(
    (0): Conv2d(2, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False, step_mode=m)
    (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True, step_mode=m)
    (2): LIFNode(
      v_threshold=1.0, v_reset=0.0, detach_reset=True, step_mode=m, backend=torch, tau=2.0
      (surrogate_function): ATan(alpha=2.0, spiking=True)
    )
    (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False, step_mode=m)
    (4): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False, step_mode=m)
    (5): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True, step_mode=m)
    (6): LIFNode(
      v_threshold=1.0, v_reset=0.0, detach_reset=True, step_mode=m, backend=torch, tau=2.0
      (surrogate_function): ATan(alpha=2.0, spiking=True)
    )
    (7): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False, step_mode=m)
    (8): Con

In [13]:
# Both splits share every option except `train`; each sample is integrated into
# T frames (frames_number=T, split_by='number').
dataset_kwargs = dict(
    root=data_dir,
    data_type='frame',
    frames_number=T,
    split_by='number',
)
train_set = DVS128Gesture(train=True, **dataset_kwargs)
test_set = DVS128Gesture(train=False, **dataset_kwargs)

The directory [/data/tanghao/datasets/DVS128Gesture/frames_number_20_split_by_number] already exists.
The directory [/data/tanghao/datasets/DVS128Gesture/frames_number_20_split_by_number] already exists.


In [14]:
# Options shared by the training and evaluation loaders.
loader_kwargs = dict(
    batch_size=b,
    num_workers=j,
    pin_memory=True,
)

# Training: reshuffle every epoch and drop the last incomplete batch.
train_data_loader = DataLoader(
    dataset=train_set,
    shuffle=True,
    drop_last=True,
    **loader_kwargs,
)

# Evaluation: keep every sample (drop_last=False) and iterate in a fixed order.
# Shuffling adds nothing to aggregate loss/accuracy and only makes the
# evaluation pass non-reproducible, so it is disabled here.
test_data_loader = DataLoader(
    dataset=test_set,
    shuffle=False,
    drop_last=False,
    **loader_kwargs,
)

In [15]:
# Gradient scaler for mixed-precision training; the training loop takes its
# autocast/scaled-backward branch whenever `scaler` is not None.
scaler = amp.GradScaler()

In [16]:
# Bookkeeping for the training loop: first epoch index and best test accuracy so far.
start_epoch = 0
max_test_acc = -1

# `opt` only labels the log directory; the optimizer built here is always Adam.
# optimizer = torch.optim.SGD(net.parameters(), lr=lr, momentum=momentum)
optimizer = torch.optim.Adam(net.parameters(), lr=lr)

# Cosine annealing of the learning rate with T_max = epochs.
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, epochs)

# Name of this run's sub-directory, derived from the hyper-parameters.
run_name = f'T{T}_b{b}_{opt}_lr{lr}_c{channel}'
run_name += '_amp'
# run_name += '_cupy'

# Only append the run name if `out_dir` does not already end with it. Without
# this guard, re-running the cell would nest a new run directory inside the
# previous one (logs/<run>/<run>/...), because `out_dir` is rebound in place.
if os.path.basename(os.path.normpath(out_dir)) != run_name:
    out_dir = os.path.join(out_dir, run_name)

# exist_ok=True avoids the check-then-create race; the message is still only
# printed when the directory was not there beforehand.
out_dir_is_new = not os.path.isdir(out_dir)
os.makedirs(out_dir, exist_ok=True)
if out_dir_is_new:
    print(f'Mkdir {out_dir}.')

# TensorBoard writer for the per-epoch scalars logged by the training loop.
writer = SummaryWriter(out_dir, purge_step=start_epoch)

In [17]:
# Clear any neuron state left behind if a previous run of this cell was
# interrupted mid-batch (e.g. by a CUDA out-of-memory error) before the
# per-batch reset_net call below was reached.
functional.reset_net(net)

for epoch in range(start_epoch, epochs):
    start_time = time.time()

    # ---------------- training pass ----------------
    net.train()
    train_loss = 0
    train_acc = 0
    train_samples = 0
    for frame, label in train_data_loader:
        optimizer.zero_grad()
        frame = frame.to(device)
        frame = frame.transpose(0, 1)  # [N, T, C, H, W] -> [T, N, C, H, W]
        label = label.to(device)
        label_onehot = F.one_hot(label, 11).float()

        if scaler is not None:
            # Mixed precision: forward and loss under autocast, scaled backward.
            with amp.autocast():
                # Average the network output over the leading (time) axis.
                out_fr = net(frame).mean(0)
                loss = F.mse_loss(out_fr, label_onehot)
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            out_fr = net(frame).mean(0)
            loss = F.mse_loss(out_fr, label_onehot)
            loss.backward()
            optimizer.step()

        # Accumulate sample-weighted loss and the number of correct predictions.
        train_samples += label.numel()
        train_loss += loss.item() * label.numel()
        train_acc += (out_fr.argmax(1) == label).float().sum().item()

        # Reset the network state so the next batch starts from a clean state.
        functional.reset_net(net)

    train_time = time.time()
    train_speed = train_samples / (train_time - start_time)
    train_loss /= train_samples
    train_acc /= train_samples

    writer.add_scalar('train_loss', train_loss, epoch)
    writer.add_scalar('train_acc', train_acc, epoch)
    lr_scheduler.step()

    # ---------------- evaluation pass ----------------
    net.eval()
    test_loss = 0
    test_acc = 0
    test_samples = 0
    with torch.no_grad():
        for frame, label in test_data_loader:
            frame = frame.to(device)
            frame = frame.transpose(0, 1)  # [N, T, C, H, W] -> [T, N, C, H, W]
            label = label.to(device)
            label_onehot = F.one_hot(label, 11).float()
            out_fr = net(frame).mean(0)
            loss = F.mse_loss(out_fr, label_onehot)
            test_samples += label.numel()
            test_loss += loss.item() * label.numel()
            test_acc += (out_fr.argmax(1) == label).float().sum().item()
            functional.reset_net(net)
    test_time = time.time()
    test_speed = test_samples / (test_time - train_time)
    test_loss /= test_samples
    test_acc /= test_samples
    writer.add_scalar('test_loss', test_loss, epoch)
    writer.add_scalar('test_acc', test_acc, epoch)
    # Push this epoch's scalars to disk now; the writer is never closed in this
    # notebook, so buffered events could otherwise be lost on a crash.
    writer.flush()

    # ---------------- checkpointing ----------------
    save_max = False
    if test_acc > max_test_acc:
        max_test_acc = test_acc
        save_max = True

    checkpoint = {
        'net': net.state_dict(),
        'optimizer': optimizer.state_dict(),
        'lr_scheduler': lr_scheduler.state_dict(),
        'epoch': epoch,
        'max_test_acc': max_test_acc
    }

    # Best-so-far checkpoint is only overwritten on improvement; the latest
    # checkpoint is written every epoch.
    if save_max:
        torch.save(checkpoint, os.path.join(out_dir, 'checkpoint_max.pth'))

    torch.save(checkpoint, os.path.join(out_dir, 'checkpoint_latest.pth'))

    # ---------------- progress report ----------------
    # Estimated finish time: this epoch's duration times the number of epochs
    # still to run. The current epoch is already done, hence the "- 1".
    epoch_seconds = time.time() - start_time
    remaining_epochs = epochs - epoch - 1
    finish_at = datetime.datetime.now() + datetime.timedelta(seconds=epoch_seconds * remaining_epochs)

    print(out_dir)
    print(f'epoch = {epoch}, train_loss ={train_loss: .4f}, train_acc ={train_acc: .4f}, test_loss ={test_loss: .4f}, test_acc ={test_acc: .4f}, max_test_acc ={max_test_acc: .4f}')
    print(f'train speed ={train_speed: .4f} images/s, test speed ={test_speed: .4f} images/s')
    print(f'escape time = {finish_at.strftime("%Y-%m-%d %H:%M:%S")}\n')

RuntimeError: CUDA out of memory. Tried to allocate 40.00 MiB (GPU 7; 11.91 GiB total capacity; 6.49 GiB already allocated; 7.88 MiB free; 6.53 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF