In [2]:
# PyTorch
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau, CosineAnnealingLR #학습률 스케줄링
import torch.nn.functional as F
from torch.utils.data import Dataset, random_split
from torch.utils.data import DataLoader

from torchvision import datasets, transforms
from torchvision.models import resnet50, resnet101

# snnTorch
import snntorch as snn
from snntorch import surrogate
#from snntorch import backprop
from snntorch import functional as SF
from snntorch import utils
from snntorch import spikeplot as splt

In [12]:
class _SpikingBottleneck(nn.Module):
    """Bottleneck residual block with one ``snn.Leaky`` neuron per activation site.

    torchvision's Bottleneck applies a single shared ``self.relu`` module at three
    points (after bn1, after bn2, and after the residual add). Replacing that one
    attribute with a stateful ``snn.Leaky(init_hidden=True)`` would make all three
    points drive the *same* membrane state within every time step, mixing
    unrelated activations. This wrapper reuses the block's conv/bn/downsample
    layers and gives each activation point its own neuron.

    Args:
        block: a Bottleneck-style block; only ``conv1``..``conv3``,
            ``bn1``..``bn3`` and ``downsample`` are read from it.
        beta: membrane decay factor passed to each ``snn.Leaky``.
    """

    def __init__(self, block, beta=0.9):
        super().__init__()
        # Reuse (not copy) the block's layers so their parameters are shared.
        self.conv1, self.bn1 = block.conv1, block.bn1
        self.conv2, self.bn2 = block.conv2, block.bn2
        self.conv3, self.bn3 = block.conv3, block.bn3
        self.downsample = block.downsample
        self.lif1 = self._make_leaky(beta)
        self.lif2 = self._make_leaky(beta)
        self.lif3 = self._make_leaky(beta)

    @staticmethod
    def _make_leaky(beta):
        # init_hidden=True keeps the membrane inside the neuron, so it is called
        # like an activation function (input -> spikes) where a ReLU used to be.
        return snn.Leaky(beta=beta, spike_grad=surrogate.fast_sigmoid(), init_hidden=True)

    def forward(self, x):
        # Same dataflow as the original Bottleneck, but each activation is a
        # distinct spiking neuron.
        identity = x if self.downsample is None else self.downsample(x)
        out = self.lif1(self.bn1(self.conv1(x)))
        out = self.lif2(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))
        return self.lif3(out + identity)


class SpikingResNet101_2(nn.Module):
    """ResNet whose first residual stage (layer1 / Conv2_x) is spiking.

    The stem and layers 2-4 are reused from ``original_model`` unchanged. In
    ``forward``, layer1 is run for ``num_steps`` time steps on the same stem
    output, and the mean spike map over time is passed to the remaining
    (non-spiking) layers.

    Note: ``original_model.layer1`` is modified in place, so ``original_model``
    and this wrapper share the same (spiking) layer1 afterwards.

    Args:
        original_model: a torchvision-style ResNet built from Bottleneck blocks
            (e.g. ``resnet101``).
        beta: membrane decay factor for every ``snn.Leaky`` in layer1.
    """

    def __init__(self, original_model, beta=0.9):
        super().__init__()

        # Stem: everything before Conv2_x, including the ReLU that ResNet applies
        # between bn1 and maxpool (previously omitted).
        self.conv1 = original_model.conv1
        self.bn1 = original_model.bn1
        self.relu = original_model.relu
        self.maxpool = original_model.maxpool

        # Conv2_x: every activation site becomes its own snn.Leaky neuron.
        self.layer1 = self.modify_layer_group(original_model.layer1, beta=beta)

        # Remaining layers are reused without modification.
        self.layer2 = original_model.layer2
        self.layer3 = original_model.layer3
        self.layer4 = original_model.layer4
        self.avgpool = original_model.avgpool
        self.fc = original_model.fc

    def modify_layer_group(self, layer_group, beta=0.9):
        """Make every residual block in ``layer_group`` spiking, in place.

        Each block is replaced by a ``_SpikingBottleneck`` that reuses its conv/bn
        layers but gives each of its three activation sites an independent
        ``snn.Leaky`` neuron. Blocks that are already spiking are left untouched,
        so calling this twice is harmless.

        Args:
            layer_group: an indexable container of residual blocks (e.g. the
                ``nn.Sequential`` ``layer1`` of a ResNet).
            beta: membrane decay factor for the new neurons.

        Returns:
            The same ``layer_group`` object, mutated in place.

        Raises:
            TypeError: if a block does not look like a Bottleneck (no ``conv3``).
        """
        for idx, block in list(enumerate(layer_group)):
            if isinstance(block, _SpikingBottleneck):
                continue
            if not hasattr(block, "conv3"):
                raise TypeError(
                    f"modify_layer_group expects Bottleneck blocks, got {type(block).__name__}"
                )
            layer_group[idx] = _SpikingBottleneck(block, beta=beta)
        return layer_group

    def reset_mem(self):
        """Reset the membrane state of every ``snn.Leaky`` neuron in the model."""
        for module in self.modules():
            if isinstance(module, snn.Leaky):
                module.reset_mem()

    def forward(self, x, num_steps=10):
        """Run the hybrid (ANN stem -> spiking layer1 -> ANN rest) network.

        Args:
            x: input image batch for ``conv1``.
            num_steps: number of time steps layer1 is simulated for; the same
                stem output is presented at every step.

        Returns:
            The output of ``fc``.

        Raises:
            ValueError: if ``num_steps`` < 1 (there would be nothing to average).
        """
        if num_steps < 1:
            raise ValueError(f"num_steps must be >= 1, got {num_steps}")

        # Non-spiking stem.
        x = self.maxpool(self.relu(self.bn1(self.conv1(x))))

        # Spiking layer1, unrolled over time from a freshly reset membrane state.
        self.reset_mem()
        spk_rec = [self.layer1(x) for _ in range(num_steps)]

        # Average the spike maps over the time dimension.
        x = torch.stack(spk_rec).mean(dim=0)

        # Remaining non-spiking layers.
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        return self.fc(x)

In [13]:
# Fix the RNG so the random weight init and the random dummy data below are
# reproducible under Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

# 10-class ResNet-101 with random weights. SpikingResNet101_2 modifies its layer1
# in place, so `original_model` and `model` share those modules afterwards.
original_model = resnet101(weights=None, num_classes=10)
model = SpikingResNet101_2(original_model)

In [14]:
# Dummy input: a batch of one 3-channel 224x224 image drawn from N(0, 1).
inputs = torch.randn(1, 3, 224, 224)

In [None]:
# Display the architecture (layer1 now contains the snn.Leaky neurons); a bare
# last expression uses the notebook's rich display instead of print().
model

In [15]:
# Forward pass, spelling out the (default) number of time steps for layer1.
outputs = model(inputs, num_steps=10)

In [16]:
# Dummy target for CrossEntropyLoss: one class index in [0, 10) per sample,
# matching num_classes=10. (A float (1, 10) tensor is interpreted as class
# *probabilities*, which must be non-negative and sum to 1 -- torch.randn values
# are neither, so the previous target produced a meaningless loss.)
labels = torch.randint(0, 10, (1,))

In [17]:
# Mean-reduced cross-entropy over the batch (the default reduction, made explicit).
criterion = nn.CrossEntropyLoss(reduction="mean")

In [18]:
# Scalar loss comparing the model's logits against the dummy targets.
loss = criterion(input=outputs, target=labels)

In [19]:
# Backpropagate through the whole network (including the surrogate gradients of
# the spiking layer1), accumulating .grad on every parameter.
torch.autograd.backward(loss)