# A Two-Stage Prediction + Detection Framework for Real-time Epileptic Seizure Monitoring



In [2]:
print("A Two Stage Prediction + Detection framework: for Real time Epileptic Seizure Prediction")

A Two Stage Prediction + Detection framework: for Real time Epileptic Seizure Prediction


In [1]:
%pip install torch

Collecting torch
  Downloading torch-2.9.1-cp310-none-macosx_11_0_arm64.whl.metadata (30 kB)
Collecting filelock (from torch)
  Using cached filelock-3.20.0-py3-none-any.whl.metadata (2.1 kB)
Collecting sympy>=1.13.3 (from torch)
  Using cached sympy-1.14.0-py3-none-any.whl.metadata (12 kB)
Collecting networkx>=2.5.1 (from torch)
  Using cached networkx-3.4.2-py3-none-any.whl.metadata (6.3 kB)
Collecting fsspec>=0.8.5 (from torch)
  Using cached fsspec-2025.10.0-py3-none-any.whl.metadata (10 kB)
Collecting mpmath<1.4,>=1.1.0 (from sympy>=1.13.3->torch)
  Using cached mpmath-1.3.0-py3-none-any.whl.metadata (8.6 kB)
Downloading torch-2.9.1-cp310-none-macosx_11_0_arm64.whl (74.5 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m74.5/74.5 MB[0m [31m390.7 kB/s[0m eta [36m0:00:00[0m00:01[0m00:06[0m
[?25hUsing cached fsspec-2025.10.0-py3-none-any.whl (200 kB)
Using cached networkx-3.4.2-py3-none-any.whl (1.7 MB)
Using cached sympy-1.14.0-py3-none-any.whl (6.3 MB)
Using

## Workflow and Architecture of the propose two-stage prediction + detection model : PDNet

![PDNet Architecure](./assets/PDNet_architecture.png)

![Structure Parameters of the Proposed PDNet Model](./assets/parameters_of_PDNet.png)

Code for ResConv is adapted from : [Convolution with Residual Connection](https://medium.com/@chen-yu/building-a-customized-residual-cnn-with-pytorch-471810e894ed)

In [1]:
import torch
from torch import nn

In [2]:
class ConvBlock(nn.Module):
    def __init__(self, in_ch: int, mid_ch:int, out_ch: int):
        super().__init__()
        self.conv1 = nn.Conv1d(in_channels = in_ch,
                               out_channels = mid_ch,
                               kernel_size = 1,
                               padding = 0)
        # self.bn1 = nn.BatchNorm1d(mid_ch)
        self.relu = nn.ReLU()
        
        self.conv2 = nn.Conv1d(in_channels = mid_ch,
                               out_channels = out_ch,
                               kernel_size = 5,
                               padding = 2)
        # self.bn2 = nn.BatchNorm1d(out_ch)
        
    def forward(self, x:torch.Tensor) -> torch.Tensor :
        x = self.relu(self.conv1(x))
        # print(x.shape)
        x = self.relu(self.conv2(x))
        # print(x.shape)
        return x
    
conv_block = ConvBlock(8,2,16)
x = torch.randn((16,8,256))
conv_block(x).shape

torch.Size([16, 16, 256])

In [3]:
class SharedLayer(nn.Module):
    def __init__(self, in_ch: int, out_ch: int):
        super().__init__()
        self.conv1 = nn.Conv1d(in_channels=in_ch, out_channels=out_ch, kernel_size=5, padding=2)
        self.maxpool = nn.MaxPool1d(kernel_size=2)
        self.conv2 = ConvBlock(8,2,16)
        self.conv3 = ConvBlock(16,4,16)
        
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.conv1(x)
        print(x.shape)
        x = self.maxpool(x)
        print(x.shape)
        x = self.conv2(x)
        print(x.shape)
        x = self.maxpool(x)
        print(x.shape)
        x = self.conv3(x)
        print(x.shape)
        return x
        
x = torch.randn((16,512)).unsqueeze(dim=1)
shared_layer = SharedLayer(1,8)
shared_layer(x).shape

torch.Size([16, 8, 512])
torch.Size([16, 8, 256])
torch.Size([16, 16, 256])
torch.Size([16, 16, 128])
torch.Size([16, 16, 128])


torch.Size([16, 16, 128])

In [4]:
x = torch.randn((16,8,256))
conv_block = ConvBlock(8,2,16)
conv_block(x).shape

torch.Size([16, 16, 256])

In [5]:
x = torch.randn((16,512)).unsqueeze(1)
conv = nn.Conv1d(in_channels=1, out_channels=8, kernel_size=5, padding=2)
conv(x).shape

torch.Size([16, 8, 512])

In [6]:
class ResConv(nn.Module):
    def __init__(self, in_ch1:int, in_ch2:int, in_ch3: int, out_ch: int) -> None :
        super().__init__()
        self.conv1 = nn.Conv1d(in_channels=in_ch1,
                               out_channels = in_ch2,
                               kernel_size = 1,
                               padding = 0)
        self.bn1 = nn.BatchNorm1d(in_ch2)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv1d(in_channels = in_ch2,
                               out_channels = in_ch3,
                               kernel_size = 5,
                               padding = 2)
        self.bn2 = nn.BatchNorm1d(in_ch3)
        self.conv3 = nn.Conv1d(in_channels=in_ch3,
                               out_channels = out_ch,
                               kernel_size = 1,
                               padding = 0)
        self.bn3 = nn.BatchNorm1d(out_ch)
        
    def forward(self, x):
        print(x.shape)
        x = self.relu(self.bn1(self.conv1(x)))
        print(x.shape)
        x = self.relu(self.bn2(self.conv2(x)))
        print(x.shape)
        x = self.relu(self.bn3(self.conv3(x)))
        print(x.shape)
        print(x.shape)
        return x
    
res_conv = ResConv(256,16,16,32)

x = torch.randn((256,128)).unsqueeze(dim=0)
res_conv(x).squeeze(dim=0).shape

torch.Size([1, 256, 128])
torch.Size([1, 16, 128])
torch.Size([1, 16, 128])
torch.Size([1, 32, 128])
torch.Size([1, 32, 128])


torch.Size([32, 128])

In [7]:
# Testing for flattening the tensor to specified size as per the paper
x = torch.randn((16,16,128))
flatten = nn.Flatten(start_dim=0, end_dim=1)
flatten(x).shape

torch.Size([256, 128])

In [8]:
# Defining the global average pooling
class GlobalAveragePooling(nn.Module):
    def __init__(self):
        super().__init__()
    
    def forward(self, x):
        return x.mean(dim=2)
    
    
x = torch.randn((1,64,128))
gap = GlobalAveragePooling()
gap(x).shape

torch.Size([1, 64])

**How to Calculate Padding :** 

L_out = (L_in + 2 * padding - kernel_size) / stride + 1

Since stride = 1:

=> L_out = L_in + 2 * padding - kernel_size + 1   (L_in = L_out)

=> padding = (kernel_size - 1)/2


In [9]:
class PDNetModelV1(nn.Module):
    def __init__(self, in_ch, out_ch, is_shared_layer = False):
        super().__init__()
        self.shared_layer = nn.Sequential(
            SharedLayer(in_ch = in_ch, out_ch=out_ch)
        )
        self.flatten = nn.Flatten(start_dim=0, end_dim=1)
        
        self.prediction_layer = nn.Sequential(
            ConvBlock(256,8,64),
            GlobalAveragePooling(),
            nn.Linear(in_features=64, out_features=2)
        )
        
        self.detection_layer = nn.Sequential(
            ResConv(256,16,16,32),
            ResConv(32,32,32,48),
            ResConv(48,48,48,64),   
            GlobalAveragePooling(),
            nn.Linear(in_features=64, out_features=11)
        )
        
        self.is_shared_layer = is_shared_layer
        
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.shared_layer(x)
        print(x.shape)
        
        x = self.flatten(x)
        print(x.shape)
        
        if self.is_shared_layer == False :
            x = x.unsqueeze(dim=0)
            x = self.prediction_layer(x).squeeze(0)
            return x
        
        # x1 = self.prediction_layer(x).squeeze(dim=0)
        # print(x1.shape)
        
        x = x.unsqueeze(dim=0)
        x = self.detection_layer(x)

        return x


In [10]:
model_1_0 = PDNetModelV1(in_ch=1, out_ch=8, is_shared_layer=False)

x = torch.randn((16,512)).unsqueeze(dim=1)
print(x.shape)
model_1_0(x).shape

torch.Size([16, 1, 512])
torch.Size([16, 8, 512])
torch.Size([16, 8, 256])
torch.Size([16, 16, 256])
torch.Size([16, 16, 128])
torch.Size([16, 16, 128])
torch.Size([16, 16, 128])
torch.Size([256, 128])


torch.Size([2])

In [11]:
model_1_1 = PDNetModelV1(in_ch=1, out_ch=8, is_shared_layer=True)

x = torch.randn((16,512)).unsqueeze(dim=1)
print(x.shape)
model_1_1(x).shape

torch.Size([16, 1, 512])
torch.Size([16, 8, 512])
torch.Size([16, 8, 256])
torch.Size([16, 16, 256])
torch.Size([16, 16, 128])
torch.Size([16, 16, 128])
torch.Size([16, 16, 128])
torch.Size([256, 128])
torch.Size([1, 256, 128])
torch.Size([1, 16, 128])
torch.Size([1, 16, 128])
torch.Size([1, 32, 128])
torch.Size([1, 32, 128])
torch.Size([1, 32, 128])
torch.Size([1, 32, 128])
torch.Size([1, 32, 128])
torch.Size([1, 48, 128])
torch.Size([1, 48, 128])
torch.Size([1, 48, 128])
torch.Size([1, 48, 128])
torch.Size([1, 48, 128])
torch.Size([1, 64, 128])
torch.Size([1, 64, 128])


torch.Size([1, 11])