# Multi-asset forecasting with an LSTM

This notebook trains an LSTM model and uses it to forecast 22 stocks from the OMXS30
index.

Load the data and create the train, validation, and test dataloaders.

In [3]:
import pandas as pd
from data.MultiAssetDataset import MultiAssetDataset
from torch.utils.data import DataLoader

# Load the raw feature table; the "Date" column becomes a parsed datetime index.
df = pd.read_csv(
    "data/OMXS22_model_features_raw.csv",
    index_col="Date",
    parse_dates=True,
)

# Dataset configuration: which assets and feature columns to use, plus the
# window and horizon values forwarded to MultiAssetDataset.
tickers = df["Ticker"].unique().tolist()
features = [
    "Return",
    "Volume",
    "SMA20",
    "EMA20",
    "RSI14",
    "ReturnVola20",
]
window, horizon = 60, 100

# Chronological split boundaries: train < train_stop <= validation < val_stop <= test.
train_stop = pd.Timestamp("2020-01-01")
val_stop = pd.Timestamp("2023-12-31")

df_train = df.loc[df.index < train_stop]
df_val = df.loc[(df.index >= train_stop) & (df.index < val_stop)]
df_test = df.loc[df.index >= val_stop]

# Build one dataset per split with identical settings.
ds_train, ds_val, ds_test = (
    MultiAssetDataset(split, tickers, features, window, horizon)
    for split in (df_train, df_val, df_test)
)

# Only the training loader shuffles; validation and test keep dataset order.
train_loader = DataLoader(ds_train, shuffle=True, batch_size=32)
val_loader = DataLoader(ds_val, batch_size=32)
test_loader = DataLoader(ds_test, batch_size=32)

In [4]:
import torch
import torch.nn as nn
import torch.optim as optim
from multi_asset_models.MultiAssetLSTM import MultiAssetLSTM

# Seed PyTorch's global RNG before anything stochastic happens in this cell so
# that re-running it (or Restart & Run All) gives the same losses. This covers
# parameter initialisation, any dropout inside the model, and the batch order
# of train_loader, which was created with shuffle=True and no explicit generator.
SEED = 42
torch.manual_seed(SEED)

# Model: one LSTM-based network over all assets, configured from the dataset settings.
model = MultiAssetLSTM(
    n_assets=len(tickers),
    n_features=len(features),
    hidden_size=64,
    num_layers=2,
    dropout=0.1,
    horizon=horizon
)

# Mean-squared-error objective optimised with Adam.
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)

# Training: a single pass (one epoch) over the training batches.
model.train()
epoch_loss = 0.0
for Xb, yb in train_loader:
    optimizer.zero_grad()
    # Original note gave the output shape as (B, A, 1) — TODO confirm; it must
    # match yb's shape exactly, otherwise nn.MSELoss broadcasts the operands.
    y_hat = model(Xb)
    loss  = criterion(y_hat, yb)   # MSE across all assets
    loss.backward()
    optimizer.step()
    epoch_loss += loss.item()
# Average of the per-batch mean losses.
epoch_loss /= len(train_loader)

print(f"Train epoch loss: {epoch_loss:.4f}")

# Validation: same loss, no gradient tracking, model in eval mode.
model.eval()
val_loss = 0.0
with torch.no_grad():
    for Xb, yb in val_loader:
        y_hat = model(Xb)
        val_loss += criterion(y_hat, yb).item()
# Average of the per-batch mean losses.
val_loss /= len(val_loader)
print(f"Validation loss: {val_loss:.4f}")

Train epoch loss: 0.0303
Validation loss: 0.0388


In [5]:
# Show the validation loss scaled by a factor of 100.
print(100 * val_loss)

3.884540581040912
