# STEMS Quickstart Guide

This notebook demonstrates the basic usage of the STEMS (Segmented Time-series End-to-end Multi-scale/Multimodal System) package.

In [None]:
import numpy as np
import torch
from stems.model import STEMSModel
from stems.data import TimeSeriesDataset, create_dataloader
from stems.utils import set_seed

## 1. Create Synthetic Data

In [None]:
# Set random seed for reproducibility
set_seed(42)

# Generate a small synthetic binary-classification problem.
# Every sample is Gaussian noise; class-1 samples additionally carry a
# sine component so that the label is actually learnable from the input
# (with labels drawn independently of the data there is nothing to fit).
n_samples = 100
time_steps = 1000
n_channels = 1
signal_amplitude = 0.5  # strength of the class-1 sine component
n_cycles = 5            # number of sine periods across the series

# float32 matches PyTorch's default parameter dtype
t = np.linspace(0.0, 2.0 * np.pi * n_cycles, time_steps, dtype=np.float32)
class_signal = (signal_amplitude * np.sin(t)).astype(np.float32)

data = []
labels = []

for _ in range(n_samples):
    # Noise sample of shape (n_channels, time_steps); np.random.randn returns
    # float64, so cast to float32 to avoid a Float/Double mismatch in the model
    x = np.random.randn(n_channels, time_steps).astype(np.float32)
    y = np.random.randint(0, 2)
    if y == 1:
        x += class_signal  # broadcasts the (time_steps,) signal over channels
    data.append(x)
    labels.append(y)

## 2. Create Dataset and DataLoader

In [None]:
# Wrap the sample arrays and their labels in a STEMS dataset
dataset = TimeSeriesDataset(data, labels)

# Loader settings: mini-batches of 16 with shuffling and the package's
# dynamic batching option enabled
loader_config = {
    "batch_size": 16,
    "shuffle": True,
    "dynamic_batching": True,
}
dataloader = create_dataloader(dataset, **loader_config)

## 3. Initialize and Train Model

In [None]:
# Initialize model
model = STEMSModel()

# Set up optimizer and classification loss
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = torch.nn.CrossEntropyLoss()

# Training loop
n_epochs = 5
# Weight applied to the model's second output (kl_div) in the objective.
# 1.0 reproduces the plain "cross-entropy + kl_div" loss; lower it if that
# term dominates the classification loss.
kl_weight = 1.0

for epoch in range(n_epochs):
    model.train()
    total_loss = 0.0
    # Count batches explicitly: with dynamic batching the loader may not
    # define len(), or its length may differ from the batches actually yielded.
    n_batches = 0

    for batch_x, batch_y in dataloader:
        optimizer.zero_grad()

        # Forward pass: the model returns class logits plus a kl_div term
        logits, kl_div = model(batch_x)
        loss = criterion(logits, batch_y) + kl_weight * kl_div

        # Backward pass
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        n_batches += 1

    mean_loss = total_loss / max(n_batches, 1)  # guard against an empty loader
    print(f"Epoch {epoch+1}, Loss: {mean_loss:.4f}")

## 4. Make Predictions

In [None]:
# Switch to evaluation mode
model.eval()

# Get predictions for one batch drawn from the (shuffled) loader
batch_x, batch_y = next(iter(dataloader))
with torch.no_grad():
    logits, _ = model(batch_x)  # the kl_div output is not needed for inference
    probs = torch.softmax(logits, dim=-1)
    preds = torch.argmax(logits, dim=-1)

print("Predictions:", preds)
# Show per-class probabilities so the model's confidence is visible too
print("Class probabilities:", probs)
print("Ground Truth:", batch_y)