In [1]:
import torch
import torch.nn as nn
import torch.fx as fx
from collections import deque



In [2]:
class CircularBuffer(nn.Module):
    def __init__(self, buffer_size: int):
        super().__init__()
        self.buffer_size = buffer_size
#         self.register_buffer("index", torch.tensor(0, dtype=int))
#         self.index = torch.tensor(0, dtype=int, requires_grad=False)
#         self.index.zero_()
#         self.register_buffer('buffer', torch.zeros([4, self.buffer_size]))
#         self.buffer = torch.zeros([4, self.buffer_size])

    def forward(self, x: torch.Tensor):
        if not hasattr(self, "buffer"):
            self.register_buffer('buffer', torch.zeros(( [*x.shape[0:-1], self.buffer_size]), dtype=x.dtype, device=x.device))
        self.buffer[..., 0:-1] = self.buffer[..., 1:]
        self.buffer[..., -1] = x[..., -1]
#         self.buffer = torch.cat((self.buffer[..., 1:], x), dim=-1).detach()
        return self.buffer
#         self.buffer[..., -1] = x[..., -1]
    
#         self.index = (self.index + 1) % self.buffer_size


#     def get_buffer(self) -> torch.Tensor:
#         if self.buffer is None:
#             raise ValueError("Buffer is not initialized. Call append() first.")
#         return self.buffer


#     def reset(self):
#         if self.buffer is not None:
#             self.buffer.zero_()
#         self.index.zero_()
#         self.filled = False

In [3]:
# 1. Define a sample model
class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv1d(4, 4, kernel_size=3)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv1d(4, 4, kernel_size=3)
        self.conv3 = nn.Conv1d(4, 4, kernel_size=3)
        self.cir_buffer1 = CircularBuffer(3)
        self.cir_buffer2 = CircularBuffer(3)

    def forward(self, x):
        x = self.conv1(x)
        x = self.relu(x)
        print(f"x shape after relu: {x}")
        x = self.cir_buffer1(x)
        print(f"x shape after cir buffer: {x}")
        x = self.conv2(x)
        x = self.cir_buffer2(x)
        x = self.conv3(x)
        return x

# 2. Trace the model
model = MyModel().eval()
# traced = fx.symbolic_trace(model)

# 3. Graph transformation: add residual from last 3 Conv1d inputs with shape checks
class ResidualConv1d(fx.Transformer):
    def __init__(self, module: fx.GraphModule):
        super().__init__(module)
#         self.input_buffer = deque(maxlen=3)  # Circular buffer for last 3 inputs
        self.modules = dict(module.named_modules())

    def call_module(self, target, args, kwargs):
        module = self.modules[target]
        
        if isinstance(module, nn.Conv1d):
            x = args[0]
#             cr = CircularBuffer(3)
#             cr.append(x)
#             new_x = cr.get_buffer()
#             conv_out = super().call_module(target, [new_x, *args[1:]], kwargs)

            # Original Conv1d computation
            conv_out = super().call_module(target, args, kwargs)

            # Initialize residual with conv_out
#             residual = conv_out

            # Add current input if shapes match
#             if x.meta.get('tensor_meta') and conv_out.meta.get('tensor_meta'):
#                 if x.meta['tensor_meta'].shape == conv_out.meta['tensor_meta'].shape:
#             residual = residual + x

            # Add previous inputs from buffer if shapes match
#             for prev_input in self.input_buffer:
#                 if prev_input.meta.get('tensor_meta') and conv_out.meta.get('tensor_meta'):
#                     if prev_input.meta['tensor_meta'].shape == conv_out.meta['tensor_meta'].shape:
#                 residual = residual + prev_input

            # Update the buffer with the current input
#             self.input_buffer.append(x)

            return conv_out

        return super().call_module(target, args, kwargs)

# 4. Apply transformation
# transformed = ResidualConv1d(traced).transform()

# 5. Script to TorchScript
# scripted = torch.jit.script(transformed)

# 6. Test
x = torch.randn(4, 3)
out = model(x)
out = model(x)
out = model(x)
out = model(x)
# print("Output shape:", out.shape)

# scripted_model = torch.jit.trace(model, x, check_trace=False).eval()
# scripted_model.save("torchscrpited_model.pth")


x shape after relu: tensor([[1.3870],
        [0.0000],
        [0.0104],
        [0.7143]], grad_fn=<ReluBackward0>)
x shape after cir buffer: tensor([[0.0000, 0.0000, 1.3870],
        [0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0104],
        [0.0000, 0.0000, 0.7143]], grad_fn=<CopySlices>)
x shape after relu: tensor([[1.3870],
        [0.0000],
        [0.0104],
        [0.7143]], grad_fn=<ReluBackward0>)
x shape after cir buffer: tensor([[0.0000, 1.3870, 1.3870],
        [0.0000, 0.0000, 0.0000],
        [0.0000, 0.0104, 0.0104],
        [0.0000, 0.7143, 0.7143]], grad_fn=<CopySlices>)
x shape after relu: tensor([[1.3870],
        [0.0000],
        [0.0104],
        [0.7143]], grad_fn=<ReluBackward0>)
x shape after cir buffer: tensor([[1.3870, 1.3870, 1.3870],
        [0.0000, 0.0000, 0.0000],
        [0.0104, 0.0104, 0.0104],
        [0.7143, 0.7143, 0.7143]], grad_fn=<CopySlices>)
x shape after relu: tensor([[1.3870],
        [0.0000],
        [0.0104],
        [0.7143]]