### [Only In Google Colab] Setup TiRex in Colab

Make sure that you selected a **GPU runtime in Google Colab**
Runtime ->  Change Runtime Type -> Select A100 / L4 / T4

In [1]:
# Only for Google Colab Notebook!

import os

# Clone TiRep Repo
!git clone https://github.com/NX-AI/tirex

# Install TiRex
os.chdir('/content/tirex')
!pip install .[gluonts]

# Set Workin Dir to notebooks folder
os.chdir('/content/tirex/examples')

Cloning into 'tirex'...
remote: Enumerating objects: 50, done.[K
remote: Counting objects: 100% (50/50), done.[K
remote: Compressing objects: 100% (37/37), done.[K
remote: Total 50 (delta 10), reused 49 (delta 10), pack-reused 0 (from 0)[K
Receiving objects: 100% (50/50), 54.81 KiB | 18.27 MiB/s, done.
Resolving deltas: 100% (10/10), done.
Processing /content/tirex
  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
Collecting xlstm (from tirex==1.0.0)
  Downloading xlstm-2.0.4-py3-none-any.whl.metadata (24 kB)
Collecting ninja (from tirex==1.0.0)
  Downloading ninja-1.11.1.4-py3-none-manylinux_2_12_x86_64.manylinux2010_x86_64.whl.metadata (5.0 kB)
Collecting lightning (from tirex==1.0.0)
  Downloading lightning-2.5.1.post0-py3-none-any.whl.metadata (39 kB)
Collecting dacite (from tirex==1.0.0)
  Downloading dacite-1.9.2-py3-none-any.whl.metadata (17 kB)
Collecting

### Imports and Load Data

In [2]:
from pathlib import Path

import numpy as np
import torch
from util_plot import plot_fc

# import os
# os.environ["TIREX_NO_CUDA"] = "1"   # Experimental!!: Turns off sLSTM CUDA kernels if you have problems but be aware of the downsides! (see repository FAQ)
from tirex import ForecastModel, load_model


### TiRex Forecast in 2 Lines

1) Load Model
2) Generate Forecast

In [3]:
import yfinance as yf
import numpy as np
from datetime import datetime
import pandas as pd

# === Parameters ===
INTERVAL = "1h"
CONTEXT_LEN = 48         # Use 48 hours as input context
PREDICTION_LEN = 3       # Predict 3 hours ahead
SPREAD = 0.03
THRESHOLD = 0.02

# === Load latest data ===
df = yf.download("JPY=X", period="5d", interval=INTERVAL)
series = df['Close'].dropna().reset_index(drop=True).to_numpy()
import pytz

# === Timezone and Entry/Exit Time Calculation ===
latest_candle_time_utc = df.index[-1]
latest_candle_time_tokyo = latest_candle_time_utc.tz_convert('Asia/Tokyo')

# Entry = next hour after latest candle
target_entry_time_utc = latest_candle_time_utc + pd.Timedelta(hours=1)
target_entry_time_tokyo = target_entry_time_utc.tz_convert('Asia/Tokyo')

# Exit = entry + prediction_len hours
exit_time_utc = target_entry_time_utc + pd.Timedelta(hours=PREDICTION_LEN)
exit_time_tokyo = exit_time_utc.tz_convert('Asia/Tokyo')

# Local time now in Tokyo
now_tokyo = datetime.now(pytz.timezone("Asia/Tokyo"))

# === Extract timestamp info ===
last_ctx_index = -(PREDICTION_LEN + 1)
predicting_from_time_utc = df.index[last_ctx_index]  # Last timestamp in context
predicting_from_time_local = predicting_from_time_utc.tz_convert('Asia/Tokyo')  # your local time
exit_time_local = predicting_from_time_local + pd.Timedelta(hours=PREDICTION_LEN)

# === Use only most recent window ===
ctx_s = series[-(CONTEXT_LEN + PREDICTION_LEN):-PREDICTION_LEN]

# === Normalize ===
ctx_mean = ctx_s.mean()
ctx_std = ctx_s.std()
ctx_s_norm = (ctx_s - ctx_mean) / ctx_std

# === Load model and forecast ===
model = load_model("NX-AI/TiRex")
quantiles, _ = model.forecast(ctx_s_norm, prediction_length=PREDICTION_LEN)

# === Extract prediction ===
def extract_median_forecast(quantiles):
    median_index = 4  # 0.5 quantile
    forecast_norm = []
    for q in quantiles:
        q = np.array(q).flatten()
        if len(q) > median_index:
            forecast_norm.append(q[median_index])
        else:
            forecast_norm.append(q[0])
    return np.array(forecast_norm)

forecast_norm = extract_median_forecast(quantiles)
forecast = forecast_norm * ctx_std + ctx_mean
forecast = np.clip(forecast, 110, 160)
forecast_trend = forecast[-1] - forecast[0]
prediction = forecast.mean()
current_price = float(ctx_s[-1])

# === Generate trading decision ===
if forecast_trend > THRESHOLD:
    signal = "SHORT (contrarian)"
elif forecast_trend < -THRESHOLD:
    signal = "BUY (contrarian)"
else:
    signal = "FLAT"

# === Output ===
from_zone = 'Asia/Tokyo'  # set your local timezone


print("=== TiRex 1H Forecast ===")
print(f"📈 Latest Candle (UTC):      {latest_candle_time_utc.strftime('%Y-%m-%d %H:%M')}")
print(f"📈 Latest Candle (Tokyo):    {latest_candle_time_tokyo.strftime('%Y-%m-%d %H:%M')}")
print(f"🎯 Target Entry Time (Tokyo): {target_entry_time_tokyo.strftime('%Y-%m-%d %H:%M')}")
print(f"📅 Target Exit Time (Tokyo):  {exit_time_tokyo.strftime('%Y-%m-%d %H:%M')}")
print(f"🕐 Local Time Now (Tokyo):    {now_tokyo.strftime('%Y-%m-%d %H:%M')}")
print(f"Current Price:  {current_price:.3f}")
print(f"Predicted Mean: {prediction:.3f}")
print(f"Forecast Trend: {forecast_trend:.4f}")
print(f"Suggested Trade: {signal}")

YF.download() has changed argument auto_adjust default to True


[*********************100%***********************]  1 of 1 completed
The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


model.ckpt:   0%|          | 0.00/141M [00:00<?, ?B/s]

Using /root/.cache/torch_extensions/py311_cu124 as PyTorch extensions root...
Creating extension directory /root/.cache/torch_extensions/py311_cu124/slstm_HS512BS8NH4NS4DBfDRbDWbDGbDSbDAfNG4SA1GRCV0GRC0d0FCV0FC0d0...
Detected CUDA files, patching ldflags
Emitting ninja build file /root/.cache/torch_extensions/py311_cu124/slstm_HS512BS8NH4NS4DBfDRbDWbDGbDSbDAfNG4SA1GRCV0GRC0d0FCV0FC0d0/build.ninja...
If this is not desired, please set os.environ['TORCH_CUDA_ARCH_LIST'].
Building extension module slstm_HS512BS8NH4NS4DBfDRbDWbDGbDSbDAfNG4SA1GRCV0GRC0d0FCV0FC0d0...
Allowing ninja to set a default number of workers... (overridable by setting the environment variable MAX_JOBS=N)
Loading extension module slstm_HS512BS8NH4NS4DBfDRbDWbDGbDSbDAfNG4SA1GRCV0GRC0d0FCV0FC0d0...
  @conditional_decorator(
  @conditional_decorator(
Using /root/.cache/torch_extensions/py311_cu124 as PyTorch extensions root...
No modifications detected for re-loaded extension module slstm_HS512BS8NH4NS4DBfDRbDWbDGbDSbDAf

=== TiRex 1H Forecast ===
📈 Latest Candle (UTC):      2025-06-10 04:00
📈 Latest Candle (Tokyo):    2025-06-10 13:00
🎯 Target Entry Time (Tokyo): 2025-06-10 14:00
📅 Target Exit Time (Tokyo):  2025-06-10 17:00
🕐 Local Time Now (Tokyo):    2025-06-10 13:28
Current Price:  145.010
Predicted Mean: 144.414
Forecast Trend: 1.2919
Suggested Trade: SHORT (contrarian)


  current_price = float(ctx_s[-1])


### Input Options

TiRex supports forecasting with different input types

In [None]:
data = torch.tensor(np.genfromtxt(Path.cwd() / "air_passengers.csv"))  # Load Example

# Torch tensor (2D or 1D)
quantiles, means = model.forecast(context=data, prediction_length=24)
print("Predictions (Torch tensor):\n", type(quantiles), quantiles.shape)

# List of Torch tensors (List of 1D) - will be padded
list_torch_data = [data, data, data]
quantiles, means = model.forecast(context=list_torch_data, prediction_length=24, batch_size=2)
print("Predictions (List of Torch tensors):\n", type(quantiles), quantiles.shape)

# NumPy array (2D or 1D)
quantiles, means = model.forecast(context=data.numpy(), prediction_length=24, output_type="torch")
print("Predictions (NumPy):\n", type(quantiles), quantiles.shape)


# List of NumPy arrays (List of 1D) - will be padded
list_numpy_data = [data.numpy()]  # Split into 3 sequences
quantiles, means = model.forecast(context=list_numpy_data, prediction_length=24)
print("Predictions (List of NumPy arrays):\n", type(quantiles), quantiles.shape)


# GluonTS Dataset
try:
    from typing import cast

    from gluonts.dataset import Dataset

    gluon_dataset = cast(Dataset, [{"target": data, "item_id": 1}, {"target": data, "item_id": 22}])
    quantiles, means = model.forecast_gluon(gluon_dataset, prediction_length=24)
    print("Predictions GluonDataset:\n", type(quantiles), quantiles.shape)
    # If you use also `glutonts` as your output type the start_time and item_id get preserved accordingly
    predictions_gluon = model.forecast_gluon(gluon_dataset, prediction_length=24, output_type="gluonts")
    print("Predictions GluonDataset:\n", type(predictions_gluon), type(predictions_gluon[0]))
except Exception as e:
    print(e)
    # To use the gluonts function you need to install the optional dependency
    # pip install tirex[gluonts]
    pass

### Output Options


TiRex supports different output types for the forecasts

In [None]:
data = torch.tensor(np.genfromtxt(Path.cwd() / "air_passengers.csv"))  # Load Example

# Default: 2D Torch tensor
quantiles, means = model.forecast(context=data, prediction_length=24, output_type="torch")
print("Predictions:\n", type(quantiles), quantiles.shape)


# 2D Numpy Array
quantiles, means = model.forecast(context=data, prediction_length=24, output_type="numpy")
print("Predictions:\n", type(quantiles), quantiles.shape)


# Iterate by patch
# You can also use the forecast function as iterable. This might help with big datasets. All output_types are supported
for i, fc_batch in enumerate(
    model.forecast(context=[data, data, data, data, data], batch_size=2, output_type="torch", yield_per_batch=True)
):
    quantiles, means = fc_batch
    print(f"Predictions batch {i}:\n", type(quantiles), quantiles.shape)


try:
    # QuantileForecast (GluonTS)
    predictions_gluonts = model.forecast(context=data, prediction_length=24, output_type="gluonts")
    print("Predictions (GluonTS Quantile Forecast):\n", type(predictions_gluon), type(predictions_gluon[0]))
    predictions_gluonts[0].plot()
except Exception as e:
    print(e)
    # To use the gluonts function you need to install the optional dependency
    # pip install tirex[gluonts]