In [1]:
from models.ShaSpec import SpecificEncoder, SharedEncoder
import torch

# Layout of ONE modality's input tensor: (batch_size, F=filter_num, T, C).
# Here: one sample, one input "filter" channel, 128 time steps, 18 sensor channels.
input_shape = (1, 1, 128, 18)

# How many modalities (devices) the shared encoder receives.
N = 6

# Dummy input for exactly one modality -- consumed by the specific encoder below.
x_1 = torch.randn(input_shape)

# One independent dummy tensor per modality, each shaped like x_1.
# NOTE: drawn after x_1 and before any encoder is built, so the global RNG
# stream is consumed in a fixed order.
modalities = [torch.randn(input_shape) for _ in range(N)]

# Place the modalities side by side on the last (sensor-channel) axis:
#   N tensors of (1, 1, 128, 18)  ->  one tensor of (1, 1, 128, 18 * N)
x = torch.cat(modalities, dim=-1)

_RULE = "-" * 16

# --- Modality-specific encoder: processes a single modality ------------------
print(_RULE, "Specific Encoder", _RULE)
specific_encoder = SpecificEncoder(
    input_shape,
    filter_num=64,
    filter_size=5,  # the convolution kernel size
    activation='ReLU',
    sa_div=8,
)
specific_output = specific_encoder(x_1)
print("Specific encoder output shape:", specific_output.shape)

# Separator between the two sections (print("\n") emits two newlines).
print("\n")

# --- Shared encoder: processes all N modalities at once -----------------------
print(_RULE, "Shared Encoder", _RULE)
# Built from N and the *per-modality* input_shape, but called on the
# concatenated x of shape (1, 1, 128, 18 * N).
shared_encoder = SharedEncoder(
    N,
    input_shape,
    filter_num=64,
    filter_size=5,
    activation='ReLU',
    sa_div=8,
)
shared_output = shared_encoder(x)
# The result is iterated as one tensor per modality
# (presumably len(shared_output) == N -- TODO confirm against SharedEncoder).
for i, modality_tensor in enumerate(shared_output):
    print(f"Shape of modality {i} output tensor:", modality_tensor.shape)


---------------- Specific Encoder ----------------
--> batch_size, F=1, T, C
Before applying conv layers:  torch.Size([1, 1, 128, 18])
After applying conv layer 1:  torch.Size([1, 64, 62, 18])
After applying conv layer 2:  torch.Size([1, 64, 29, 18])
After applying conv layer 3:  torch.Size([1, 64, 13, 18])
After applying conv layers:  torch.Size([1, 64, 5, 18])
--> batch_size, F=filter_num, T*, C
size  torch.Size([1, 64, 18, 1])
torch.Size([1, 64, 18, 1])
size  torch.Size([1, 64, 18, 1])
torch.Size([1, 64, 18, 1])
size  torch.Size([1, 64, 18, 1])
torch.Size([1, 64, 18, 1])
size  torch.Size([1, 64, 18, 1])
torch.Size([1, 64, 18, 1])
size  torch.Size([1, 64, 18, 1])
torch.Size([1, 64, 18, 1])
After applying self-attention:  torch.Size([1, 64, 18, 5])
After permuting:  torch.Size([1, 18, 64, 5])
After reshaping:  torch.Size([1, 18, 320])
After passing through fc layer:  torch.Size([1, 18, 128])
Specific encoder output shape: torch.Size([1, 18, 128])


---------------- Shared Encoder ----