In [None]:
import torch
import spikingjelly
from spikingjelly.activation_based.model import spiking_resnet
from spikingjelly.activation_based import surrogate, neuron, functional, layer
from torchvision.datasets import MNIST
import numpy as np
from spikingjelly.activation_based.model.sew_resnet import sew_resnet18
from torch import nn
from torchmetrics import Accuracy

In [None]:
class DummyModel(nn.Module):
    """Tiny spiking CNN that traces every intermediate tensor, as a debugging aid.

    Pipeline: Conv2d(1->16) -> BatchNorm2d -> LIF -> AdaptiveAvgPool(1x1)
    -> flatten -> Linear(16->10) -> LIF.

    Each forward call prints the shape of every intermediate result and
    checks whether any spikes occurred after the first LIF layer. It is used
    in this notebook with a 4-D [N, C, H, W] input (no time dimension).

    Args:
        verbose: If True (the default), also print the full tensor after every
            stage. If False, print only shapes and the spike checks.

    NOTE(review): the LIF neurons keep state between calls; presumably
    ``functional.reset_net`` should be called between independent passes.
    """

    def __init__(self, verbose: bool = True):
        super().__init__()
        self.verbose = verbose
        self.conv1 = layer.Conv2d(1, 16, kernel_size=3, stride=1, padding=1)
        self.bn1 = layer.BatchNorm2d(16)
        self.sn1 = neuron.LIFNode(tau=2.0, surrogate_function=surrogate.Sigmoid())
        self.avgpool = layer.AdaptiveAvgPool2d((1, 1))
        self.fc1 = nn.Linear(16, 10)
        self.sn2 = neuron.LIFNode(tau=2.0, surrogate_function=surrogate.Sigmoid())

    def _log(self, label: str, x: torch.Tensor) -> None:
        """Print ``x``'s shape under ``label`` and, when verbose, the tensor itself."""
        print(f"{label} SHAPE: {x.shape}")
        if self.verbose:
            print(x)

    def forward(self, x):
        self._log("RAW INPUT", x)
        x = self.conv1(x)
        self._log("AFTER CONV", x)
        x = self.bn1(x)
        self._log("AFTER BN", x)
        x = self.sn1(x)
        self._log("AFTER SN", x)
        print(f"ANY SPIKES: {torch.max(x) > 0}")
        # After pooling, the values are per-channel averages of the spike map,
        # so "> 0" means at least one spike occurred in that channel.
        x = self.avgpool(x)
        self._log("AFTER AVG", x)
        print(f"ANY SPIKES: {torch.max(x) > 0}")
        x = torch.flatten(x, 1)
        self._log("AFTER FLATTEN", x)
        x = self.fc1(x)
        self._log("AFTER FC", x)
        x = self.sn2(x)
        self._log("AFTER SN2", x)

        return x
    

# Seed so the random probe input (and therefore the traced values) is repeatable.
torch.manual_seed(0)
dummy_input = torch.randn(1, 1, 28, 28)
dummy_model = DummyModel()
dummy_model.eval()

# Run one traced forward pass. Gradients are not needed for inspection, and the
# LIF membrane state is cleared afterwards so that re-running this cell starts fresh.
with torch.no_grad():
    dummy_output = dummy_model(dummy_input)
functional.reset_net(dummy_model)
dummy_output


RAW INPUT SHAPE: torch.Size([1, 1, 28, 28])
tensor([[[[-3.4519e-01,  9.1727e-02, -4.6358e-01, -1.2695e+00,  1.2693e+00,
            2.4674e-02, -5.7668e-01, -1.6995e+00, -1.1081e+00, -9.3229e-01,
            1.4014e+00,  2.2841e-01, -7.8185e-01, -8.9732e-01,  3.4407e-01,
            1.0203e+00,  5.0452e-01, -1.1258e+00,  3.7464e-01,  3.4553e-01,
           -2.7075e-03, -4.5000e-01, -1.1760e+00,  1.7577e-02, -1.0302e+00,
            3.5220e-02, -3.3523e-01,  7.1252e-01],
          [ 1.2366e+00, -1.5124e+00,  1.0330e+00,  1.9119e+00, -1.1545e+00,
           -1.5810e-01,  1.2688e+00, -1.8234e+00, -6.9158e-01,  1.2743e+00,
            1.2276e+00, -1.0755e+00,  2.0838e-01,  1.8423e-01, -2.8349e-01,
            1.8100e+00, -1.6409e+00,  2.2434e+00, -3.6947e-01,  6.6441e-01,
           -4.6507e-01, -7.3347e-01, -1.3598e+00,  7.5758e-02,  1.4354e+00,
           -5.5126e-01,  1.2559e+00, -8.4321e-01],
          [ 3.2449e-01,  7.4277e-01,  5.5155e-01,  3.5231e-01, -2.7423e-01,
           -7.9102

tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]])

In [None]:
def SewResnet18(
    n_channels: int = 1,
    neuron_model: "type[neuron.BaseNode]" = neuron.LIFNode,
    surrogate_function: "type[surrogate.SurrogateFunctionBase]" = surrogate.Sigmoid,
    num_classes: int = 10,
) -> nn.Module:
    """Build a randomly initialised SEW-ResNet-18 adapted to this dataset.

    The first convolution is replaced with one that accepts ``n_channels``
    input channels (same 7x7 / stride 2 / padding 3 stem). The final fully
    connected layer is replaced with a fresh one that emits ``num_classes``
    logits.

    Args:
        n_channels: Number of input image channels (1 for MNIST).
        neuron_model: Spiking-neuron *class* handed to ``sew_resnet18`` as
            ``spiking_neuron``.
        surrogate_function: Surrogate-gradient *class*. It is instantiated once
            here and the instance is passed to ``sew_resnet18`` as a keyword
            argument (presumably forwarded to each neuron).
        num_classes: Number of output logits.

    Returns:
        The network. Step mode is left at the layer defaults; the caller
        switches it with ``functional.set_step_mode``.
    """
    net = sew_resnet18(
        pretrained=False,
        spiking_neuron=neuron_model,
        cnf="IAND",
        surrogate_function=surrogate_function(),
        # detach_reset=True,
    )
    # Swap the stem so its input-channel count matches the data. Its output
    # width is kept identical to the original conv1.
    net.conv1 = layer.Conv2d(
        n_channels,
        net.conv1.out_channels,
        kernel_size=(7, 7),
        stride=(2, 2),
        padding=(3, 3),
        bias=False,
    )
    # New classification head; its input width is taken from the head it replaces.
    net.fc = layer.Linear(net.fc.in_features, num_classes)
    return net

In [None]:
class MNISTRepeated(MNIST):
    """MNIST dataset that turns each static image into a constant frame sequence.

    Each sample is returned as ``(frames, target)``. ``frames`` is a float32
    tensor of shape ``[repeat, 1, H, W]``: the same image stacked ``repeat``
    times along a leading time axis, with a singleton channel axis.

    No value scaling happens here. With no ``transform`` set, the pixels are
    presumably MNIST's raw 0-255 intensities; the training loop divides by 255.
    """

    def __init__(self, *args, repeat=1, **kwargs):
        super().__init__(*args, **kwargs)
        # Number of time steps each image is repeated for.
        self.repeat = repeat

    def __getitem__(self, index):
        image, label = super().__getitem__(index)
        frame = torch.tensor(np.array(image), dtype=torch.float32)
        # Tile along a new leading (time) axis, then insert the channel axis at dim 1.
        frames = frame.repeat(self.repeat, 1, 1)
        return frames.unsqueeze(1), label

In [None]:
# Reproducibility: seed torch's global RNG before the model is built, so that
# weight initialisation and DataLoader shuffling are repeatable across runs.
SEED = 0
torch.manual_seed(SEED)

epochs = 10
batch_size = 8
repeats = 10  # time steps T: each static MNIST image is repeated this many times
# NOTE(review): DataLoader workers must unpickle MNISTRepeated, which is defined in
# this notebook. On spawn-based platforms (Windows, macOS) that may fail; use 0 there.
num_workers = 4

model = SewResnet18(n_channels=1)
# Put every SpikingJelly module into multi-step mode (inputs shaped [T, B, ...]).
functional.set_step_mode(model, step_mode='m')

mnist_dataset_repeat_train = MNISTRepeated(root="./data", train=True, repeat=repeats, download=True)
mnist_dataset_repeat_test = MNISTRepeated(root="./data", train=False, repeat=repeats, download=True)
criterion = nn.CrossEntropyLoss()
accuracy_metric = Accuracy(task="multiclass", num_classes=10)
accuracy_metirc = accuracy_metric  # misspelled alias kept for cells that still use it

train_loader = torch.utils.data.DataLoader(
    mnist_dataset_repeat_train,
    batch_size=batch_size,
    shuffle=True,
    num_workers=num_workers,
)
test_loader = torch.utils.data.DataLoader(
    mnist_dataset_repeat_test,
    batch_size=batch_size,
    shuffle=False,
    num_workers=num_workers,
)
# NOTE(review): lr=1e-2 is aggressive for AdamW on a ResNet; try 1e-3 if the loss diverges.
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-2, weight_decay=1e-3)

In [None]:
# Pull a single shuffled training batch for the inspection cells that follow.
batch_iter = iter(train_loader)
img, target = next(batch_iter)
del batch_iter  # drop the iterator right away, as the original one-liner did

In [None]:
import matplotlib.pyplot as plt
import numpy as np
def plot_tensor_sequence(tensor: torch.Tensor):
    """
    Plot a grid of grayscale frames from a 5D tensor of shape [T, B, C, H, W].

    Rows are samples (B) and columns are time steps (T), so each row shows
    one sample across time.

    Args:
        tensor (torch.Tensor): A 5D tensor containing the image sequence.
                                 T: Time dimension (number of frames).
                                 B: Batch size (number of samples).
                                 C: Channel dimension (must be 1 for grayscale).
                                 H: Height of the images.
                                 W: Width of the images.

    Raises:
        ValueError: If ``tensor`` is not 5-D or its channel dimension is not 1.
    """
    if tensor.dim() != 5:
        raise ValueError(
            f"Expected a 5D tensor of shape [T, B, C, H, W], got shape {tuple(tensor.shape)}."
        )
    T, B, C, H, W = tensor.shape

    if C != 1:
        raise ValueError("The channel dimension (C) should be 1 for grayscale images.")

    # Convert to NumPy once, not per frame. detach() lets tensors that require
    # grad be plotted as well.
    frames = tensor.detach().cpu().numpy()

    # squeeze=False always yields a 2-D [B, T] array of Axes, even when B or T is 1.
    fig, axes = plt.subplots(B, T, figsize=(2 * T, 2 * B), squeeze=False)

    for b in range(B):
        for t in range(T):
            ax = axes[b, t]
            ax.imshow(frames[t, b, 0], cmap='gray')
            ax.axis('off')  # Turn off axis labels and ticks

    plt.tight_layout()
    plt.show()

# The loader yields [B, T, C, H, W]; the plotting helper expects time first.
plot_tensor_sequence(img.transpose(1, 0))

In [None]:
# Display original shape
print(f"Original shape: {img.shape}")

# Swap first and second dimensions
img_transposed = img.transpose(0, 1)

# Display new shape
print(f"Transposed shape: {img_transposed.shape}")

In [None]:
# Sanity-check forward pass on the inspection batch: the model should accept
# [T, B, C, H, W] input and, after averaging over time, give one logit row per sample.
was_training = model.training
model.eval()  # keep this probe from updating BatchNorm running statistics
with torch.no_grad():
    # Scale pixels to [0, 1], exactly as the training loop does.
    out = model(img_transposed / 255).mean(dim=0)
# Clear membrane potentials left by this pass so they don't leak into training.
functional.reset_net(model)
model.train(was_training)
out.shape

In [None]:
from tqdm import tqdm

# Train for `epochs` epochs, reporting mean loss and accuracy on the training set.
epoch_progbar = tqdm(range(epochs), desc="Epoch")
for epoch in epoch_progbar:
    model.train()

    epoch_loss = 0.0
    epoch_preds = []
    epoch_targets = []
    # Build the batch bar inside the epoch loop. A tqdm wrapper closes (and
    # disables) itself after one full pass, so a single bar created outside the
    # loop only shows progress for the first epoch.
    dataloader_progbar = tqdm(train_loader, desc="Dataloader", leave=False)
    for img, target in dataloader_progbar:
        optimizer.zero_grad()
        # Batches arrive as [B, T, C, H, W]; multi-step mode wants [T, B, C, H, W].
        # Pixels are raw 0-255, so scale them to [0, 1]; then average logits over time.
        out = model(img.transpose(0, 1) / 255).mean(0)

        loss = criterion(out, target)
        loss.backward()
        optimizer.step()
        # Clear the neurons' membrane state so the next batch starts fresh.
        functional.reset_net(model)

        batch_loss = loss.item()
        epoch_loss += batch_loss
        # Detach: storing `out` as-is would keep each batch's autograd graph
        # referenced until the end of the epoch.
        epoch_preds.append(out.detach())
        epoch_targets.append(target)
        # Show the running loss on the bar instead of printing one line per batch.
        dataloader_progbar.set_postfix(loss=f"{batch_loss:.4f}")

    epoch_preds = torch.cat(epoch_preds)
    epoch_targets = torch.cat(epoch_targets)
    epoch_loss /= len(train_loader)
    epoch_acc = accuracy_metirc(epoch_preds, epoch_targets)
    epoch_progbar.set_postfix(loss=f"{epoch_loss:.4f}", acc=f"{epoch_acc:.4f}")
    tqdm.write(f"Epoch {epoch}: loss: {epoch_loss:.4f}, acc: {epoch_acc:.4f}")
