In [5]:
%load_ext autoreload
%autoreload 2
from hmpai.pytorch.models import *
from hmpai.pytorch.utilities import get_summary_str
import xarray as xr
from torchinfo import summary
from pathlib import Path
from hmpai.data import CHANNELS_2D

In [2]:
# Load the preprocessed SAT1 dataset fully into memory.
data_path = Path("../data/sat1/split_stage_data_100hz.nc")
data = xr.load_dataset(data_path)

# Model dimensions are taken from the lengths of the dataset's
# `channels`, `samples` and `labels` entries.
n_channels = len(data.channels)
n_samples = len(data.samples)
n_classes = len(data.labels)

# Batch size used for every dummy summary input below.
batch_size = 128

In [3]:
def get_summ(model, inp_size):
    """Build a torchinfo summary of ``model`` for an input of shape ``inp_size``.

    Args:
        model: the PyTorch module to summarise.
        inp_size: full input shape, batch dimension first (see the shapes below).

    Returns:
        The object returned by ``torchinfo.summary``; printing it yields the
        layer table. ``row_settings=["ascii_only"]`` draws the layer tree with
        plain ASCII markers, and ``col_width=22`` widens the table columns.
    """
    return summary(
        model, input_size=inp_size, row_settings=["ascii_only"], col_width=22
    )


# Batch-first dummy input shapes passed to get_summ.
# Reduced-length window covering a fifth of the samples; floor division keeps
# the length an int without a float round-trip.
input_size = (batch_size, n_samples // 5, n_channels)
# Full-length window over all samples.
input_size_deep = (batch_size, n_samples, n_channels)
# Samples followed by the 2-D channel layout. CHANNELS_2D.shape[1] is placed
# before CHANNELS_2D.shape[0] here.
# NOTE(review): confirm this axis order is what the topological models expect.
input_conv = (batch_size, n_samples, CHANNELS_2D.shape[1], CHANNELS_2D.shape[0])

In [23]:
# SAT1Base: summarised on the reduced-length window (input_size).
model = SAT1Base(n_classes)
print(get_summ(model, inp_size=input_size))

Layer (type)                             Output Shape           Param #
SAT1Base                                 [128, 5]               --
+ PartialConv2d                          [128, 64, 155, 30]     384
+ ReLU                                   [128, 64, 155, 30]     --
+ MaxPool2d                              [128, 64, 77, 30]      --
+ Conv2d                                 [128, 128, 75, 30]     24,704
+ ReLU                                   [128, 128, 75, 30]     --
+ MaxPool2d                              [128, 128, 37, 30]     --
+ Conv2d                                 [128, 256, 35, 30]     98,560
+ ReLU                                   [128, 256, 35, 30]     --
+ MaxPool2d                              [128, 256, 17, 30]     --
+ Flatten                                [128, 130560]          --
+ Linear                                 [128, 128]             16,711,808
+ ReLU                                   [128, 128]             --
+ Dropout                               



In [24]:
# SAT1Topological: summarised on samples laid out over the 2-D channel grid (input_conv).
model = SAT1Topological(n_classes)
print(get_summ(model, inp_size=input_conv))

Layer (type)                             Output Shape           Param #
SAT1Topological                          [128, 5]               --
+ PartialConv3d                          [128, 64, 795, 5, 8]   384
+ ReLU                                   [128, 64, 795, 5, 8]   --
+ MaxPool3d                              [128, 64, 397, 5, 8]   --
+ Conv3d                                 [128, 128, 395, 5, 8]  24,704
+ ReLU                                   [128, 128, 395, 5, 8]  --
+ MaxPool3d                              [128, 128, 197, 5, 8]  --
+ Conv3d                                 [128, 256, 195, 5, 8]  98,560
+ ReLU                                   [128, 256, 195, 5, 8]  --
+ MaxPool3d                              [128, 256, 97, 5, 8]   --
+ Flatten                                [128, 993280]          --
+ Linear                                 [128, 128]             127,139,968
+ ReLU                                   [128, 128]             --
+ Dropout                              

In [6]:
# SAT1TopologicalConv: same grid-shaped input (input_conv) as SAT1Topological.
model = SAT1TopologicalConv(n_classes)
print(get_summ(model, inp_size=input_conv))

Layer (type)                             Output Shape              Param #
SAT1TopologicalConv                      [128, 5]                  --
+ PartialConv3d                          [128, 64, 795, 3, 6]      2,944
+ ReLU                                   [128, 64, 795, 3, 6]      --
+ MaxPool3d                              [128, 64, 397, 3, 6]      --
+ Conv3d                                 [128, 128, 395, 1, 4]     221,312
+ ReLU                                   [128, 128, 395, 1, 4]     --
+ MaxPool3d                              [128, 128, 197, 1, 4]     --
+ Conv3d                                 [128, 256, 195, 1, 4]     98,560
+ ReLU                                   [128, 256, 195, 1, 4]     --
+ MaxPool3d                              [128, 256, 97, 1, 4]      --
+ Flatten                                [128, 99328]              --
+ Linear                                 [128, 128]                12,714,112
+ ReLU                                   [128, 128]              



In [25]:
# SAT1Deep: summarised on the full-length window (input_size_deep).
model = SAT1Deep(n_classes)
print(get_summ(model, inp_size=input_size_deep))

Layer (type)                             Output Shape           Param #
SAT1Deep                                 [128, 5]               --
+ PartialConv2d                          [128, 32, 775, 30]     832
+ ReLU                                   [128, 32, 775, 30]     --
+ MaxPool2d                              [128, 32, 387, 30]     --
+ Conv2d                                 [128, 64, 371, 30]     34,880
+ ReLU                                   [128, 64, 371, 30]     --
+ MaxPool2d                              [128, 64, 185, 30]     --
+ Conv2d                                 [128, 128, 175, 30]    90,240
+ ReLU                                   [128, 128, 175, 30]    --
+ MaxPool2d                              [128, 128, 87, 30]     --
+ Conv2d                                 [128, 256, 83, 30]     164,096
+ ReLU                                   [128, 256, 83, 30]     --
+ MaxPool2d                              [128, 256, 41, 30]     --
+ Conv2d                                 [1



In [27]:
# SAT1LSTM: recurrent model, constructed with channel/sample/class counts
# and summarised on the full-length window.
model = SAT1LSTM(n_channels, n_samples, n_classes)
print(get_summ(model, inp_size=input_size_deep))

Layer (type)                             Output Shape           Param #
SAT1LSTM                                 [128, 5]               --
+ LSTM                                   [102272, 256]          294,912
+ ReLU                                   [128, 799, 256]        --
+ Linear                                 [128, 799, 128]        32,896
+ Linear                                 [128, 799, 5]          645
Total params: 328,453
Trainable params: 328,453
Non-trainable params: 0
Total mult-adds (Units.TERABYTES): 7.72
Input size (MB): 12.27
Forward/backward pass size (MB): 318.27
Params size (MB): 1.31
Estimated Total Size (MB): 331.86




In [28]:
# SAT1GRU: same constructor arguments and input window as SAT1LSTM.
model = SAT1GRU(n_channels, n_samples, n_classes)
print(get_summ(model, inp_size=input_size_deep))

Layer (type)                             Output Shape           Param #
SAT1GRU                                  [128, 5]               --
+ GRU                                    [102272, 256]          221,184
+ ReLU                                   [128, 799, 256]        --
+ Linear                                 [128, 799, 128]        32,896
+ Linear                                 [128, 799, 5]          645
Total params: 254,725
Trainable params: 254,725
Non-trainable params: 0
Total mult-adds (Units.TERABYTES): 5.79
Input size (MB): 12.27
Forward/backward pass size (MB): 318.27
Params size (MB): 1.02
Estimated Total Size (MB): 331.56




In [15]:
model = TransformerModel(
    n_features=n_channels,
    n_heads=10,
    ff_dim=512,
    n_layers=6,
    n_samples=n_samples,
    n_classes=n_classes,
)
# Masking inside the model has to be removed before summarising it.
# NOTE(review): the recorded output shows a sequence length of 0 after the
# input Linear, and extra torch.Size/dtype lines printed before the table.
# This suggests masking (and possibly debug prints) were still active when it
# ran. Confirm before trusting these shapes.
print(get_summ(model, inp_size=input_size_deep))

torch.Size([128, 161, 30])
torch.float32
Layer (type)                             Output Shape           Param #
TransformerModel                         [128, 5]               --
+ Linear                                 [128, 0, 30]           930
+ PositionalEncoding                     [128, 0, 30]           --
|    + Dropout                           [128, 0, 30]           --
+ TransformerEncoder                     [128, 0, 30]           --
|    + ModuleList                        --                     --
|    |    + TransformerEncoderLayer      [128, 0, 30]           35,102
|    |    + TransformerEncoderLayer      [128, 0, 30]           35,102
|    |    + TransformerEncoderLayer      [128, 0, 30]           35,102
|    |    + TransformerEncoderLayer      [128, 0, 30]           35,102
|    |    + TransformerEncoderLayer      [128, 0, 30]           35,102
|    |    + TransformerEncoderLayer      [128, 0, 30]           35,102
+ Linear                                 [128, 5]         