In [None]:
# Build two example encoder/head combinations for TemporalClassificator and
# train the second one with Lightning.
import lightning as L

from signalflow.nn.head.classifier import MLPClassifierHead
from signalflow.nn.layer.lstm_encoder import LSTMEncoder
from signalflow.nn.layer.transformer_encoder import TransformerEncoder
from signalflow.nn.model.temporal_classificator import TemporalClassificator

# Shared problem dimensions. Every encoder receives INPUT_SIZE as its
# `input_size`, and every head/model receives NUM_CLASSES, so the variants
# below cannot drift apart when these values change.
INPUT_SIZE = 10
NUM_CLASSES = 3

# Variant 1: LSTM + simple head
encoder = LSTMEncoder(
    input_size=INPUT_SIZE,
    hidden_size=64,
    num_layers=2,
    bidirectional=True,  # output_size = 128
)

head = MLPClassifierHead(
    input_size=encoder.output_size,  # derived from the encoder, not hard-coded
    num_classes=NUM_CLASSES,
    hidden_sizes=[],  # no hidden layers
)

model = TemporalClassificator(
    encoder=encoder,
    head=head,
    num_classes=NUM_CLASSES,
)

# Variant 2: Transformer + MLP head
# NOTE(review): torch-style multi-head attention requires d_model to be
# divisible by nhead (64 / 4 here) — presumably TransformerEncoder wraps it; confirm.
encoder = TransformerEncoder(
    input_size=INPUT_SIZE,
    d_model=64,
    nhead=4,
    num_layers=3,
    pooling="mean",
)

head = MLPClassifierHead(
    input_size=encoder.output_size,  # 64
    num_classes=NUM_CLASSES,
    hidden_sizes=[128, 64],  # two hidden layers
    dropout=0.3,
)

# Pass num_classes explicitly (as Variant 1 does) so the model's notion of the
# class count always matches the head's, rather than relying on a default.
model = TemporalClassificator(
    encoder=encoder,
    head=head,
    num_classes=NUM_CLASSES,
    learning_rate=1e-3,
)

# Train. `model` is now bound to Variant 2; Variant 1 above is only an example.
# NOTE(review): train_loader and val_loader are not defined in this cell —
# they are assumed to come from an earlier notebook cell.
trainer = L.Trainer(max_epochs=50, accelerator="auto")
trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)