In [1]:
import copy

import torch
from torch.nn.utils import fuse_conv_bn_eval

from model import SDN, FusedSDN

#### 0. Architecture hyper-parameters, checkpoint, and helper function

In [2]:
# Architecture hyper-parameters; they must match the checkpoint so its state dict loads.
d_model, kernel_size, n_layers = 8, 8, 1
# Output path for the traced model. It is written as a TorchScript archive,
# so despite the ".pkl" suffix it is not a plain pickle file.
save_file = "sdn.pkl"

In [3]:
# Path of the trained SDN checkpoint to convert -- specify the one you want.
model_checkpoint = "exp1/best_checkpoint.pth"
# map_location="cpu": everything below runs on CPU, so tensors saved from a GPU run
# are remapped instead of requiring CUDA to be available.
# NOTE(review): torch.load unpickles the file -- only load checkpoints you trust.
checkpoint = torch.load(model_checkpoint, map_location="cpu")

In [4]:
def copy_parameter(target, source):
    """Overwrite the values of ``target`` in place with those of ``source``.

    Args:
        target: Destination parameter (or tensor). Its identity, ``requires_grad``
            flag and registration inside its module are preserved.
        source: Tensor with exactly the same shape as ``target``. It may be part of
            an autograd graph; the copy itself is not recorded by autograd.

    Raises:
        ValueError: if the shapes differ. ``Tensor.copy_`` would otherwise silently
            broadcast a smaller ``source`` into ``target`` and yield wrong weights.
    """
    if target.shape != source.shape:
        raise ValueError(
            f"shape mismatch: target {tuple(target.shape)} vs source {tuple(source.shape)}"
        )
    # no_grad + copy_ is the supported replacement for writing through `.data`.
    with torch.no_grad():
        target.copy_(source)

#### 1. Build models.

##### a. Build the SDN model

In [5]:
# Rebuild the trained network and restore its weights; strict=True (the default)
# makes any missing or unexpected key an error.
model = SDN(d_model, kernel_size, n_layers)
model.load_state_dict(checkpoint["model"], strict=True)

<All keys matched successfully>

##### b. Build fused SDN model.

In [6]:
# Same hyper-parameters as `model`, so the parameter shapes line up for fusion.
fused_model = FusedSDN(
    d_model,
    kernel_size,
    n_layers,
)

#### 2. Conversion

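Convert both models to `double` so the equivalence check below is not masked by floating-point round-off, and switch them to eval mode (`fuse_conv_bn_eval` requires eval mode, and BatchNorm then uses its running statistics).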

In [7]:
# Both calls act in place and return the module itself.
for net in (model, fused_model):
    net.double()
    net.eval()
fused_model

FusedSDN(
  (encoder): Conv1d(1, 8, kernel_size=(1,), stride=(1,), padding=(1,), bias=False)
  (spatial_layers): ModuleList(
    (0): Conv1d(1, 8, kernel_size=(8,), stride=(1,), padding=(8,))
  )
  (feature_layers): ModuleList(
    (0): Conv1d(8, 8, kernel_size=(1,), stride=(1,))
  )
  (decoder): Conv1d(8, 1, kernel_size=(1,), stride=(1,))
)

#### 3. Fusion

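Fuse the encoder into the first spatial layer: the encoder weight is multiplied element-wise into the kernel of the first spatial convolution, and that convolution's bias is copied unchanged.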

In [8]:
# Fold the encoder weight into the first spatial kernel; the spatial bias carries over.
_src_conv = model.spatial_layers[0][0]
_dst_conv = fused_model.spatial_layers[0]
copy_parameter(_dst_conv.weight, model.encoder.weight * _src_conv.weight)
copy_parameter(_dst_conv.bias, _src_conv.bias)
del _src_conv, _dst_conv  # don't leave stale module references in the namespace

Fuse every BatchNorm layer into the `Conv1d` layer that precedes it.

In [9]:
# The first spatial conv is taken from fused_model because it already holds the
# encoder-fused kernel; every other conv/BN pair is read from the original model.
fused_model.spatial_layers[0] = fuse_conv_bn_eval(
    fused_model.spatial_layers[0], model.spatial_layers[0][1]
)

for i in range(n_layers):
    if i > 0:
        fused_model.spatial_layers[i] = fuse_conv_bn_eval(
            model.spatial_layers[i][0], model.spatial_layers[i][1]
        )
    fused_model.feature_layers[i] = fuse_conv_bn_eval(
        model.feature_layers[i][0], model.feature_layers[i][1]
    )

Copy the decoder

In [10]:
# Deep-copy rather than assign. A shared module would tie both models together, so the
# in-place `fused_model.float()` done before tracing would also cast `model.decoder`
# to float32, leaving `model` with mixed float64/float32 parameters.
fused_model.decoder = copy.deepcopy(model.decoder)

#### 4. Test

In [11]:
# Random float64 probe input for the equivalence check. A dedicated seeded generator
# makes the check reproducible without touching the global RNG state.
x = torch.randn(64, 1, 1024, generator=torch.Generator().manual_seed(0), dtype=torch.float64)

In [12]:
# Both models are float64 here, so the default allclose tolerances are meaningful.
# Assert (not just display) so "Run All" stops if fusion changed the outputs.
with torch.no_grad():
    outputs_match = torch.allclose(model(x), fused_model(x))
assert outputs_match, "fused model output differs from the original SDN"
outputs_match

True

#### 5. Save

Here we use `torch.jit.trace` to turn the fused model into a TorchScript module, which can be saved and later loaded without the Python class definitions.

Note that we convert both the model and the input to `float` (float32) before tracing; `fused_model.float()` modifies the model in place.

In [13]:
# Cast to float32 for deployment. Module.float() converts parameters in place and
# returns the same module, so `fused_model` itself is now float32.
fused_model = fused_model.float()
x = x.to(torch.float32)
with torch.no_grad():
    traced_fused_model = torch.jit.trace(fused_model, example_inputs=x)

In [14]:
# Serialize the traced module to `save_file` as a TorchScript archive.
torch.jit.save(traced_fused_model, save_file)

In [15]:
# Round-trip check: the reloaded archive must reproduce the in-memory float32 model.
reload_model = torch.jit.load(save_file)
with torch.no_grad():
    reload_matches = torch.allclose(reload_model(x), fused_model(x))
assert reload_matches, f"model reloaded from {save_file!r} differs from fused_model"
reload_matches

True

Now the fused model can be used in SNN training without its Python source code.

Here is an example:

In [16]:
from torch import nn


class StraightThroughEstimator(torch.autograd.Function):
    """Step function forward, identity gradient backward.

    Forward emits 1 where the input is >= 0 and 0 elsewhere, in the input's
    dtype. Backward passes the incoming gradient through unchanged.
    """

    @staticmethod
    def forward(ctx, x):
        # Boolean mask cast back to x's dtype (and device).
        return x.ge(0).type_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        # "Straight-through": treat the step as the identity for gradients.
        return grad_output


class SDNLIF(nn.Module):
    """Spiking layer that combines the traced SDN with a surrogate spike function.

    ``forward`` computes ``surrogate_func(pred(x) + x - 1.0)``, where ``pred`` runs
    the TorchScript SDN loaded from ``model_path`` without tracking gradients.

    Args:
        surrogate_func: Callable applied to the pre-spike tensor, e.g.
            ``StraightThroughEstimator.apply``.
    """

    # Bound once, when the class is defined, from the notebook-level `save_file`.
    model_path = save_file

    def __init__(self, surrogate_func):
        super().__init__()
        self.model = torch.jit.load(self.model_path).eval()
        self.surrogate_func = surrogate_func

    def forward(self, x: torch.Tensor):
        """Return ``surrogate_func(pred(x) + x - 1.0)``."""
        m = self.pred(x)
        # NOTE(review): 1.0 looks like a firing threshold -- confirm.
        s = self.surrogate_func(m + x - 1.0)
        return s

    @torch.no_grad()
    def pred(self, x):
        """Run the SDN along the last dimension of ``x`` and restore ``x``'s shape.

        All leading dimensions are flattened into the batch dimension of an
        ``(N, 1, L)`` tensor with ``L = x.size(-1)``. ``reshape`` (unlike ``view``)
        also accepts non-contiguous inputs such as transposed tensors.
        """
        shape = x.shape
        L = x.size(-1)
        return self.model(x.detach().reshape(-1, 1, L)).reshape(shape)


# Smoke test: build the layer with the straight-through surrogate and feed a random
# (10, 1024) input; the result is a 0/1 tensor of the same shape.
test_model = SDNLIF(surrogate_func=StraightThroughEstimator.apply)
demo_input = torch.randn(10, 1024)
test_model(demo_input)

tensor([[0., 0., 0.,  ..., 0., 1., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [1., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [1., 0., 0.,  ..., 1., 0., 0.],
        [1., 0., 0.,  ..., 1., 0., 1.],
        [0., 1., 0.,  ..., 0., 1., 0.]])