In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import pandas as pd
import numpy as np

In [2]:
class GRN(nn.Module):
    def __init__(self, d_in, d_hidden, d_out):
        super().__init__()
        self.fc1 = nn.Linear(d_in, d_hidden)
        self.fc2 = nn.Linear(d_hidden, d_out)
        self.gate = nn.Linear(d_out, d_out)
        self.skip = nn.Linear(d_in, d_out) if d_in != d_out else nn.Identity()
        self.norm = nn.LayerNorm(d_out)

    def forward(self, x):
        h = F.elu(self.fc1(x))
        h = self.fc2(h)
        g = torch.sigmoid(self.gate(h))
        return self.norm(g * h + (1 - g) * self.skip(x))


In [3]:
class StaticEncoder(nn.Module):
    def __init__(self, d_static, d_model):
        super().__init__()
        self.grn = GRN(d_static, d_model, d_model)

    def forward(self, s):
        return self.grn(s)


In [4]:
class VariableSelectionNetwork(nn.Module):
    def __init__(self, num_vars, d_model):
        super().__init__()
        self.value_proj = nn.Linear(num_vars, d_model)
        self.weight_proj = nn.Linear(num_vars, num_vars)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        # x: [B, T, num_vars]
        values = self.value_proj(x)               # [B, T, d_model]
        weights = self.softmax(self.weight_proj(x))  # [B, T, num_vars]
        fused = values
        return fused, weights


In [5]:
class ContextEnrichment(nn.Module):
    def __init__(self, d_model):
        super().__init__()
        self.grn = GRN(d_model * 2, d_model, d_model)

    def forward(self, temporal, context):
        context = context.unsqueeze(1).expand_as(temporal)
        return self.grn(torch.cat([temporal, context], dim=-1))


In [6]:
class TemporalAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        self.attn = nn.MultiheadAttention(
            embed_dim=d_model,
            num_heads=num_heads,
            batch_first=True
        )
        self.grn = GRN(d_model, d_model, d_model)

    def forward(self, x):
        T = x.size(1)
        mask = torch.triu(torch.ones(T, T), diagonal=1).bool().to(x.device)
        attn_out, attn_weights = self.attn(x, x, x, attn_mask=mask)
        out = self.grn(attn_out + x)
        return out, attn_weights


In [7]:
class PredictionHead(nn.Module):
    def __init__(self, d_model):
        super().__init__()
        self.fc = nn.Linear(d_model, 1)

    def forward(self, x):
        return self.fc(x[:, -1])


In [8]:
class MiniTFT(nn.Module):
    def __init__(self, n_obs, n_known, d_static, d_model=32):
        super().__init__()
        self.static_enc = StaticEncoder(d_static, d_model)

        self.obs_vsn = VariableSelectionNetwork(n_obs, d_model)
        self.known_vsn = VariableSelectionNetwork(n_known, d_model)

        self.enrich = ContextEnrichment(d_model)
        self.attn = TemporalAttention(d_model, num_heads=4)
        
        self.post_attn_grn = GRN(d_model, d_model, d_model)
        self.layer_norm = nn.LayerNorm(d_model)
        
        self.head = PredictionHead(d_model)

    def forward(self, obs, known, static):
        # Preparation for 3 types of inputs
        s = self.static_enc(static)
        obs_fused,_ = self.obs_vsn(obs)
        known_fused,_ = self.known_vsn(known)
        # Locality enhanchement
        x = obs_fused + known_fused
        # Temporal processing
        x = self.enrich(x, s)
        # Temporal attention
        attn_out,attn_weights = self.attn(x)
        x = self.layer_norm(attn_out + x)
        x = self.layer_norm(self.post_attn_grn(x) + x)
        return self.head(x), attn_weights


In [9]:

df = pd.read_parquet("../../data/features/BTC_features.parquet")
print(df.shape)
print(df.info())
print(df.head(1))
print("---------------------------------------")
print(df.columns.tolist())


(5481, 819)
<class 'pandas.core.frame.DataFrame'>
DatetimeIndex: 5481 entries, 2010-12-09 to 2025-12-10
Columns: 819 entries, open to ret_lag_144
dtypes: float64(817), object(2)
memory usage: 34.3+ MB
None
                open      high      low     close    volume symbol  \
date                                                                 
2010-12-09 -0.674481 -0.676188 -0.67272 -0.674708 -0.640733    BTC   

                   source  building_permits  consumer_confidence       cpi  \
date                                                                         
2010-12-09  alpha_vantage         -2.064399            -0.387082 -1.290801   

            ...  close_lag_21  ret_lag_21  close_lag_34  ret_lag_34  \
date        ...                                                       
2010-12-09  ...     -0.669729    3.274749     -0.667228    2.574162   

            close_lag_55  ret_lag_55  close_lag_89  ret_lag_89  close_lag_144  \
date                                                 

In [10]:
static_cols = ["symbol", "source"]
observed_cols = [
    "open", "high", "low", "close", "volume",
    "ema_34", "ema_89", "ema_200",
    "rsi_14", "macd", "log_return", "vol_20"
] + [c for c in df.columns if "lag_" in c]
known_cols = [
    "building_permits", "consumer_confidence", "cpi",
    "fed_funds_rate", "gdp", "industrial_production",
    "money_supply_m1", "money_supply_m2",
    "nonfarm_payrolls", "pce_inflation",
    "ppi", "retail_sales", "trade_balance",
    "unemployment_rate"
] + [c for c in df.columns if c.startswith("fed_emb_")]
target_col = "next_close"

# train val test splitting
n = len(df)
train_end = int(0.70 * n)
val_end = int(0.85 * n)

df_train = df.iloc[:train_end]
df_val   = df.iloc[train_end:val_end]
df_test  = df.iloc[val_end:]


X_train = df_train[observed_cols + known_cols]
X_val   = df_val[observed_cols + known_cols]
X_test  = df_test[observed_cols + known_cols]

y_train = df_train[target_col].values
y_val   = df_val[target_col].values
y_test  = df_test[target_col].values
# Sliding window
LOOKBACK = 89

def build_windows(X, y, lookback):
    X_out, y_out = [], []
    for i in range(lookback, len(X)):
        X_out.append(X[i-lookback:i])
        y_out.append(y[i])
    return np.array(X_out), np.array(y_out)

Xtr, ytr = build_windows(X_train, y_train, LOOKBACK)
Xva, yva = build_windows(X_val, y_val, LOOKBACK)
Xte, yte = build_windows(X_test, y_test, LOOKBACK)

# split input feature tensors
n_obs   = len(observed_cols)
n_known = len(known_cols)
#Tensor
obs_tr   = Xtr[:, :, :n_obs]
known_tr = Xtr[:, :, n_obs:]

static_tr = np.zeros((len(Xtr), 2))  # placeholder encoding


In [11]:
def count_parameters(model):
    # Total parameters (including those frozen/not being trained)
    total_params = sum(p.numel() for p in model.parameters())
    
    # Trainable parameters only
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    
    print(f"Total Parameters: {total_params:,}")
    print(f"Trainable Parameters: {trainable_params:,}")
    return total_params

# Usage:
model = MiniTFT(n_obs=n_obs,
    n_known=n_known,
    d_static=2,
    d_model=32)
count_parameters(model)

Total Parameters: 659,161
Trainable Parameters: 659,161


659161

In [12]:
import torch
from torch.utils.data import TensorDataset, DataLoader
import os 
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
from tqdm import tqdm
import gc
import torch.optim as optim
from torch.cuda.amp import autocast, GradScaler 

In [None]:
#CPu run
device = "cpu"

model = MiniTFT(
    n_obs=n_obs,
    n_known=n_known,
    d_static=2,
    d_model=32
).to(device)

pred = model(
    torch.tensor(obs_tr).float().to(device),
    torch.tensor(known_tr).float().to(device),
    torch.tensor(static_tr).float().to(device)
)

In [14]:
# Testing only 
device = "cuda" if torch.cuda.is_available() else "cpu"
# Try only 32 samples instead of the whole dataset
batch_size = 32

model = MiniTFT(
    n_obs=n_obs,
    n_known=n_known,
    d_static=2,
    d_model=32
).to(device)

with torch.no_grad(): # Use this for testing
    pred, weights = model(
        torch.tensor(obs_tr[:batch_size]).float().to(device),
        torch.tensor(known_tr[:batch_size]).float().to(device),
        torch.tensor(static_tr[:batch_size]).float().to(device)
    )

print("Prediction shape:", pred.shape)
print("Attention weights shape:", weights.shape)

Prediction shape: torch.Size([32, 1])
Attention weights shape: torch.Size([32, 89, 89])


In [None]:
#GPU run
device = "cuda"

model = MiniTFT(
    n_obs=n_obs,
    n_known=n_known,
    d_static=2,
    d_model=32
).to(device)

# 1. Wrap data in a Dataset
train_dataset = TensorDataset(
    torch.tensor(obs_tr).float(),
    torch.tensor(known_tr).float(),
    torch.tensor(static_tr).float(),
    torch.tensor(ytr).float()
)

# 2. Create a DataLoader for batching
train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True, pin_memory=True,persistent_workers=True, num_workers=4)
optimizer = optim.Adam(model.parameters(), lr=1e-3)
criterion = torch.nn.MSELoss()
scaler = GradScaler()

model.train()

print(f"Starting inference on {len(train_dataset)} samples...")
for epoch in range(10):  
    total_losses = 0.0
    for batch_obs, batch_known, batch_static, target in tqdm(train_loader, desc="Processing Batches"):
        optimizer.zero_grad(set_to_none=True)
        with autocast():
            batch_obs = batch_obs.to(device)
            batch_known = batch_known.to(device)
            batch_static = batch_static.to(device)
            target =target.to(device).unsqueeze(-1)
            
            
            
            output, _ = model(batch_obs, batch_known, batch_static)
            loss = criterion(output, target)
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        
        total_losses += loss.item()
    print(f"Epoch {epoch+1}, Loss: {total_losses/len(train_loader):.4f}")

model_path = "mini_tft_model.pth"
torch.save(model.state_dict(), model_path)

print(f" Model exported successfully to: {model_path}")
torch.cuda.empty_cache()
gc.collect()

  scaler = GradScaler()


Starting inference on 3747 samples...


  with autocast():
Processing Batches: 100%|██████████| 469/469 [00:11<00:00, 42.46it/s]


Epoch 1, Loss: 128499375.0314


Processing Batches: 100%|██████████| 469/469 [00:07<00:00, 66.29it/s]


Epoch 2, Loss: 129387471.4260


Processing Batches: 100%|██████████| 469/469 [00:07<00:00, 66.62it/s]


Epoch 3, Loss: 127096782.3849


Processing Batches: 100%|██████████| 469/469 [00:07<00:00, 65.98it/s]


Epoch 4, Loss: 125750588.3652


Processing Batches: 100%|██████████| 469/469 [00:07<00:00, 65.82it/s]


Epoch 5, Loss: 124194728.7749


Processing Batches: 100%|██████████| 469/469 [00:07<00:00, 65.75it/s]


Epoch 6, Loss: 122495367.6297


Processing Batches: 100%|██████████| 469/469 [00:07<00:00, 67.00it/s]


Epoch 7, Loss: 120652151.2581


Processing Batches: 100%|██████████| 469/469 [00:07<00:00, 66.93it/s]


Epoch 8, Loss: 118474632.3154


Processing Batches: 100%|██████████| 469/469 [00:07<00:00, 66.24it/s]


Epoch 9, Loss: 116329860.1734


Processing Batches: 100%|██████████| 469/469 [00:07<00:00, 65.16it/s]


Epoch 10, Loss: 114048092.1588
 Model exported successfully to: mini_tft_model.pth


252