## Model Overview

- **Total Parameters**: 976,219 (~976K)
- **Model Size**: 3.72 MiB (FP32, 4 bytes per parameter)
- **Complexity**: Medium (under 1M parameters, light enough for efficient deployment)
- **Tasks**: Multi-task learning (Classification + Regression)

### Performance Summary

| Metric | Value |
|--------|-------|
| Classification Accuracy | 53.79% (baseline) |
| MAE (Regression) | 0.014468 |
| RMSE (Regression) | 0.020668 |

### Generalization Test Results

| Subset | Period | Accuracy | MAE | Notes |
|--------|--------|----------|-----|-------|
| Subset 1 | 2018-03-01 to 2020-10-29 | 51.66% | 0.0187 | Much earlier period |
| Subset 4 | 2018-11-28 to 2021-07-28 | 53.23% | 0.0207 | Earlier period |
| Subset 9 | 2020-03-04 to 2022-11-02 | 54.13% | 0.0175 | Overlaps training |
| Subset 10 | 2020-06-03 to 2023-02-03 | 52.42% | 0.0154 | Similar timeframe |
| **Subset 12** | **2020-12-02 to 2023-08-02** | **53.79%** | **0.0145** | **Training baseline** |
| Subset 13 | 2021-03-05 to 2023-11-03 | 55.78% | 0.0137 | **Best - extends beyond** |

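**Conclusion**: Accuracy stays within a 4.12 percentage-point band (51.66% to 55.78%) across the subsets shown, with a reported average deviation of 2.29 percentage points across time periods. This suggests the model **generalizes consistently** over time.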
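The spread in the table can be checked directly. The snippet below recomputes a few common dispersion measures from the six accuracy values shown. The notes do not say which definition produced the reported 2.29 figure, so none of these measures is claimed to be it.

```python
import statistics

# Accuracy (%) per subset, copied from the Generalization Test Results table.
accuracy = {
    "Subset 1": 51.66,
    "Subset 4": 53.23,
    "Subset 9": 54.13,
    "Subset 10": 52.42,
    "Subset 12": 53.79,  # training baseline
    "Subset 13": 55.78,
}


def dispersion(values, baseline):
    """Simple dispersion measures, all in percentage points."""
    mean = statistics.mean(values)
    return {
        "range": max(values) - min(values),
        "stdev": statistics.stdev(values),  # sample standard deviation
        "mad_from_mean": statistics.mean(abs(v - mean) for v in values),
        "max_abs_vs_baseline": max(abs(v - baseline) for v in values),
    }


stats = dispersion(list(accuracy.values()), accuracy["Subset 12"])
for name, value in stats.items():
    print(f"{name}: {value:.2f} pp")
```

On the six rows listed here, none of these measures comes out to exactly 2.29. The reported figure may therefore cover more subsets than the table shows, or it may use a different definition.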

## Architecture Components

### 1. Input Processing (152,384 params)

```
start_emb_l: Low-frequency embedding (363 → 128)
start_emb_h: High-frequency embedding (363 → 128)
te_emb: Time encoding embedding (55 → 128)
```

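**Purpose**: Project the 363-dimensional per-stock input (which includes the 360 Alpha factors) into a 128-dimensional embedding space. The low- and high-frequency components get separate projections, and `te_emb` maps the 55-dimensional time encoding into the same 128-dimensional space.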
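Below is a minimal sketch of this projection step. It assumes each embedding is a single linear layer applied to the last (feature) axis. Under that assumption the three layers hold 100,352 parameters, which is fewer than the 152,384 reported above. The actual embedding modules are therefore presumably deeper (e.g., a small MLP). The weights and helper names here are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)


def make_linear(in_dim, out_dim):
    """Hypothetical linear layer: weight [in_dim, out_dim] plus bias [out_dim]."""
    w = rng.normal(scale=in_dim ** -0.5, size=(in_dim, out_dim))
    b = np.zeros(out_dim)
    return w, b


def apply_linear(x, layer):
    w, b = layer
    return x @ w + b  # projects the last axis: [..., in_dim] -> [..., out_dim]


def n_params(layer):
    return sum(p.size for p in layer)


start_emb_l = make_linear(363, 128)  # low-frequency features
start_emb_h = make_linear(363, 128)  # high-frequency features
te_emb = make_linear(55, 128)        # time encoding

batch, T1, N = 1, 20, 255
xl = rng.normal(size=(batch, T1, N, 363))  # placeholder low-frequency input
te = rng.normal(size=(batch, T1, N, 55))   # placeholder time encoding

hl = apply_linear(xl, start_emb_l)   # [1, 20, 255, 128]
hte = apply_linear(te, te_emb)       # [1, 20, 255, 128]
total = sum(n_params(m) for m in (start_emb_l, start_emb_h, te_emb))
print(hl.shape, hte.shape, total)    # (1, 20, 255, 128) (1, 20, 255, 128) 100352
```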

---

### 2. Wavelet Transform Module

- **Type**: `DWT1DForward` (1-D discrete wavelet transform)
- **Wavelet**: sym2 (Symlet with 2 vanishing moments)
- **Level**: 1 (single-level decomposition)

**Function**: Decomposes time series into:
- **Low-frequency (XL)**: Captures trends and smooth patterns
- **High-frequency (XH)**: Captures noise and rapid fluctuations
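The sketch below performs a single-level sym2 split in NumPy rather than with `DWT1DForward`, so it runs without extra dependencies. It makes two assumptions:

- Boundaries are handled periodically. PyWavelets and pytorch_wavelets default to other padding modes.
- Each subband is mapped back to the time domain, so XL and XH keep the original 20-step length. This matches the input shapes listed under Input/Output Specifications.

For this order, sym2 and db2 share the same four-tap filter, so the filter is written in closed form. The periodized transform is orthogonal, which means the two parts add back up to the input exactly. Time sits on axis 0 here; the model's batched layout puts it on axis 1.

```python
import numpy as np

# Four-tap scaling (low-pass) filter; sym2 and db2 coincide at this order.
SQRT3 = np.sqrt(3.0)
H = np.array([1 + SQRT3, 3 + SQRT3, 3 - SQRT3, 1 - SQRT3]) / (4 * np.sqrt(2.0))
# Quadrature-mirror high-pass filter: g[k] = (-1)**k * h[K-1-k]
G = np.array([(-1) ** k * H[len(H) - 1 - k] for k in range(len(H))])


def analysis_matrices(n):
    """Single-level periodized DWT as two (n/2 x n) matrices; n must be even and >= 4."""
    low = np.zeros((n // 2, n))
    high = np.zeros((n // 2, n))
    for i in range(n // 2):
        for k in range(len(H)):
            low[i, (2 * i + k) % n] += H[k]
            high[i, (2 * i + k) % n] += G[k]
    return low, high


def disentangle(x):
    """Split x (time on axis 0) into low- and high-frequency parts with x's shape."""
    low, high = analysis_matrices(x.shape[0])
    # Project onto each subband, then map back to the time domain.
    xl = np.tensordot(low.T @ low, x, axes=(1, 0))
    xh = np.tensordot(high.T @ high, x, axes=(1, 0))
    return xl, xh


rng = np.random.default_rng(0)
x = rng.normal(size=(20, 255, 363))  # placeholder: 20 days x 255 stocks x 363 features
xl, xh = disentangle(x)
print(xl.shape, np.allclose(xl + xh, x))  # (20, 255, 363) True
```

Because the high-pass filter sums to zero, a constant series lands entirely in XL. This illustrates the trend vs. fluctuation split described above.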

---

### 3. Dual-Frequency Encoder (660K params)

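**2 Transformer Layers × ~330K params each.** The per-component counts listed below sum to 231,554. The ~330K figure matches the data-flow diagram if each frequency path has its own TCN and temporal attention block: 2 × 32,896 + 2 × 66,048 + 66,305 + 66,305 = 330,498 per layer, or about 661K for both layers.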

Each layer contains:

#### a) TCN - Temporal Convolutional Network (32,896 params)
- Captures local temporal patterns
- Uses dilated convolutions to enlarge the receptive field
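Below is a sketch of one dilated convolution over the time axis. It assumes a 128 → 128 layer with kernel size 2 and causal (left) padding. That configuration is a guess, but it matches the reported count exactly: 128 × 128 × 2 weights + 128 biases = 32,896. With causal padding, the output keeps T1 = 20 steps, and each step only sees the current and earlier days.

```python
import numpy as np


def causal_dilated_conv1d(x, w, b, dilation=1):
    """x: [T, C_in]; w: [C_out, C_in, K]; b: [C_out]. Returns [T, C_out].
    Output step t mixes inputs t - (K-1)*dilation, ..., t - dilation, t."""
    T = x.shape[0]
    c_out, c_in, k = w.shape
    pad = (k - 1) * dilation
    xp = np.concatenate([np.zeros((pad, c_in)), x], axis=0)  # left-pad: no future leakage
    out = np.empty((T, c_out))
    for t in range(T):
        taps = xp[t : t + pad + 1 : dilation]  # [K, C_in], ordered oldest -> newest
        out[t] = np.einsum("oik,ki->o", w, taps) + b
    return out


rng = np.random.default_rng(0)
d = 128
w = rng.normal(scale=0.05, size=(d, d, 2))  # hypothetical kernel-size-2 weights
b = np.zeros(d)
x = rng.normal(size=(20, d))                # 20 days of 128-dim features for one stock
y = causal_dilated_conv1d(x, w, b, dilation=2)
print(y.shape, w.size + b.size)             # (20, 128) 32896
```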

#### b) Temporal Attention (66,048 params)
- Query, Key, Value, Output projections
- Feedforward network
- Captures time dependencies across the 20-day window
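The reported 66,048 parameters are exactly four 128 → 128 linear layers: 4 × (128 × 128 + 128). The sketch below implements single-head (h = 1) scaled dot-product self-attention over the 20-day axis for one stock, using hypothetical random weights. The notes do not say whether the model applies a causal mask or feeds the time encoding into the queries, so the sketch assumes neither.

```python
import numpy as np


def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)  # numerical stability
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)


def temporal_attention(x, params):
    """Single-head self-attention across time. x: [T, d] for one stock."""
    (wq, bq), (wk, bk), (wv, bv), (wo, bo) = params
    q, k, v = x @ wq + bq, x @ wk + bk, x @ wv + bv
    scores = q @ k.T / np.sqrt(x.shape[-1])  # [T, T]: day-to-day affinities
    attn = softmax(scores, axis=-1)
    return (attn @ v) @ wo + bo, attn


rng = np.random.default_rng(0)
d, T1 = 128, 20
# Hypothetical Query, Key, Value, Output projections.
params = [(rng.normal(scale=d ** -0.5, size=(d, d)), np.zeros(d)) for _ in range(4)]
x = rng.normal(size=(T1, d))
out, attn = temporal_attention(x, params)
n_params = sum(w.size + b.size for w, b in params)
print(out.shape, attn.shape, n_params)  # (20, 128) (20, 20) 66048
```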

#### c) Spatial Attention - Low Frequency (66,305 params)
- Models correlations between stocks
- Attention over 255 stocks
- Focuses on trend relationships

#### d) Spatial Attention - High Frequency (66,305 params)
- Models high-frequency co-movements
- Captures rapid market reactions
- Separate from low-freq to avoid noise mixing

**Key Design**: Separate processing paths for low and high frequencies prevents noise contamination of trend signals.

---

### 4. Adaptive Fusion Layer (115,088 params)

```
Cross-attention mechanism:
- Query from low-freq
- Keys/Values from both low and high freq
- Learns optimal weighting automatically
```

**Purpose**: Combines the low- and high-frequency representations using learned, input-dependent weights rather than a fixed rule.
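The NumPy sketch below shows one way to read the cross-attention description. At every (day, stock) position, the low-frequency vector forms the query. The low- and high-frequency vectors form two key/value candidates, and a softmax over those two candidates yields the learned weighting.

The weight matrices are hypothetical random stand-ins. The real layer (115,088 parameters) contains more than the three 128 × 128 projections used here.

```python
import numpy as np


def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)


def adaptive_fusion(xl, xh, wq, wk, wv):
    """xl, xh: [..., d] low/high-frequency features; wq, wk, wv: [d, d]."""
    d = xl.shape[-1]
    q = xl @ wq                              # query from the low-frequency path
    sources = np.stack([xl, xh], axis=-2)    # [..., 2, d]: low and high candidates
    k, v = sources @ wk, sources @ wv        # keys/values from both paths
    scores = np.einsum("...d,...sd->...s", q, k) / np.sqrt(d)
    weights = softmax(scores, axis=-1)       # [..., 2]: weight on low vs. high
    fused = np.einsum("...s,...sd->...d", weights, v)
    return fused, weights


rng = np.random.default_rng(0)
d = 128
xl = rng.normal(size=(2, 20, 255, d))        # [batch, T1, stocks, d]
xh = rng.normal(size=(2, 20, 255, d))
wq, wk, wv = (rng.normal(scale=d ** -0.5, size=(d, d)) for _ in range(3))
fused, weights = adaptive_fusion(xl, xh, wq, wk, wv)
print(fused.shape, weights.shape)            # (2, 20, 255, 128) (2, 20, 255, 2)
```

If the key projection carries no information (all zeros), every position falls back to a 50/50 blend. The learned weighting is what moves the model away from that default.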

---

### 5. Graph Integration

- **Graph Embeddings**: 128-dimensional vectors from Struc2vec
- **Graph Structure**: 32,385 edges between 255 stocks (exactly 255 × 254 / 2, i.e., every pair of stocks is connected)
- **Method**: Correlation-based adjacency matrix → Struc2vec → GAT

**Purpose**: Captures structural similarities between stocks beyond simple correlations.
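Below is a minimal sketch of the first step, building the correlation-based adjacency. It assumes an edge is kept for every stock pair whose absolute return correlation exceeds a threshold. With a threshold of 0, every pair survives, which reproduces the 32,385 edges of a complete undirected graph on 255 nodes. The return matrix is random placeholder data, and the Struc2vec and GAT stages are not shown.

```python
import numpy as np


def correlation_adjacency(returns, threshold=0.0):
    """returns: [days, stocks]. Returns (correlation, symmetric 0/1 adjacency without self-loops)."""
    corr = np.corrcoef(returns, rowvar=False)  # [stocks, stocks]
    adj = (np.abs(corr) > threshold).astype(int)
    np.fill_diagonal(adj, 0)                   # no self-loops
    return corr, adj


def count_undirected_edges(adj):
    return int(np.triu(adj, k=1).sum())


rng = np.random.default_rng(0)
returns = rng.normal(scale=0.02, size=(250, 255))  # placeholder daily returns
corr, adj = correlation_adjacency(returns)
print(count_undirected_edges(adj))                 # 32385
```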

---

### 6. Multi-Task Heads (33,411 params)

#### Classification Head (16,770 params)
- Input: 128-dim fused features
- Output: 2 classes (up/down)
- Loss: Cross-entropy

#### Regression Head (16,641 params)
- Input: 128-dim fused features
- Output: 1 value (continuous return)
- Loss: MSE

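**Shared Representations**: Both heads read the same fused features, so the classification and regression losses jointly shape one representation. The intent is for each task to regularize the other and improve generalization.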
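The head sizes are consistent with two-layer MLPs:

- 128 → 128 → 2 gives 16,512 + 258 = 16,770 parameters.
- 128 → 128 → 1 gives 16,512 + 129 = 16,641 parameters.

The sketch below assumes that structure and computes a joint cross-entropy + MSE objective on toy inputs. The ReLU between layers and the equal weighting of the two losses are both assumptions.

```python
import numpy as np

rng = np.random.default_rng(0)
d = 128


def mlp(sizes):
    """Hypothetical MLP head as a list of (weight, bias) pairs."""
    return [(rng.normal(scale=i ** -0.5, size=(i, o)), np.zeros(o))
            for i, o in zip(sizes[:-1], sizes[1:])]


def forward(layers, x):
    for idx, (w, b) in enumerate(layers):
        x = x @ w + b
        if idx < len(layers) - 1:
            x = np.maximum(x, 0.0)  # ReLU between layers (assumed)
    return x


def n_params(layers):
    return sum(w.size + b.size for w, b in layers)


def cross_entropy(logits, labels):
    z = logits - logits.max(axis=-1, keepdims=True)
    log_probs = z - np.log(np.exp(z).sum(axis=-1, keepdims=True))
    return -np.take_along_axis(log_probs, labels[..., None], axis=-1).mean()


cls_head, reg_head = mlp([d, d, 2]), mlp([d, d, 1])
features = rng.normal(size=(255, d))          # fused features for 255 stocks (toy)
labels = rng.integers(0, 2, size=255)         # up/down targets (toy)
returns = rng.normal(scale=0.02, size=255)    # return targets (toy)

logits = forward(cls_head, features)          # [255, 2]
pred = forward(reg_head, features)[:, 0]      # [255]
loss = cross_entropy(logits, labels) + np.mean((pred - returns) ** 2)  # equal weights assumed
print(n_params(cls_head), n_params(reg_head))  # 16770 16641
```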

## Data Flow Diagram

```
Input: OHLCV Price-Volume Factors
         (360 factors × 20 days × 255 stocks)
                    |
                    v
         Wavelet Transform (DWT)
                    |
         +----------+----------+
         |                     |
    Low-Freq (XL)         High-Freq (XH)
    (trends)              (noise/rapid)
         |                     |
         v                     v
  Embedding (363→128)   Embedding (363→128)
         |                     |
         v                     v
  ┌──────────────────┐  ┌──────────────────┐
  │ Layer 1:         │  │ Layer 1:         │
  │ - TCN            │  │ - TCN            │
  │ - Temporal Attn  │  │ - Temporal Attn  │
  │ - Spatial Attn   │  │ - Spatial Attn   │
  └──────────────────┘  └──────────────────┘
         |                     |
         v                     v
  ┌──────────────────┐  ┌──────────────────┐
  │ Layer 2:         │  │ Layer 2:         │
  │ - TCN            │  │ - TCN            │
  │ - Temporal Attn  │  │ - Temporal Attn  │
  │ - Spatial Attn   │  │ - Spatial Attn   │
  └──────────────────┘  └──────────────────┘
         |                     |
         +----------+----------+
                    |
                    v
          Adaptive Fusion
        (Cross-attention)
                    |
                    v
          Fused Features (128-dim)
         + Graph Embeddings (128-dim)
                    |
         +----------+----------+
         |                     |
         v                     v
  Classification         Regression
  (2 classes)            (1 value)
  up/down                return
```

## Parameter Distribution

| Component | Parameters | Percentage |
|-----------|------------|------------|
| Features Module | 942,808 | 96.58% |
| Classifier Module | 33,411 | 3.42% |

### Why 96.58% in Features?
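- Most parameters go into learning representations rather than into the prediction heads
- The dual-frequency encoder (2 layers, ~660K params) alone accounts for about two-thirds of all parameters
- Attention blocks (query/key/value/output projections) are parameter-heavy, even with a single head
- Once features are learned, the prediction heads stay lightweight (~33K params)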

## Key Hyperparameters

| Parameter | Value | Notes |
|-----------|-------|-------|
| Input timesteps (T1) | 20 | 20 days of historical data |
| Output timesteps (T2) | 2 | Predict next 2 days |
| Transformer layers (L) | 2 | Per frequency pathway |
| Attention heads (h) | 1 | Single-head attention |
| Embedding dim (d) | 128 | Hidden representation size |
| Sparsity ratio (s) | 1.0 | 100% of connections kept |
| Batch size | 12 | Training batch size |
| Learning rate | 0.001 | Adam optimizer |
| Train/Val/Test | 75/12.5/12.5% | Data split |

**Note**: Sparsity ratio of 1.0 means no sparsity pruning is applied (all spatial attention connections are used).
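To make the role of `s` concrete, below is a sketch of single-head attention across the 255 stocks. Each query keeps only its top ceil(s × N) scoring keys. Top-k masking is an assumption chosen for illustration; the model's own sparsification scheme may differ. With s = 1.0 nothing is masked, so the result equals dense attention, as the note above says.

```python
import math

import numpy as np


def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)  # exp(-inf) = 0 for masked entries
    return e / e.sum(axis=axis, keepdims=True)


def sparse_spatial_attention(x, wq, wk, wv, s=1.0):
    """x: [N, d] features of N stocks on one day. Keeps the top ceil(s*N) keys per query."""
    n, d = x.shape
    q, k, v = x @ wq, x @ wk, x @ wv
    scores = q @ k.T / math.sqrt(d)  # [N, N] stock-to-stock affinities
    keep = max(1, math.ceil(s * n))
    if keep < n:
        # Threshold at the keep-th largest score in each row; mask everything below it.
        kth = np.sort(scores, axis=-1)[:, n - keep][:, None]
        scores = np.where(scores >= kth, scores, -np.inf)
    attn = softmax(scores, axis=-1)
    return attn @ v, attn


rng = np.random.default_rng(0)
n, d = 255, 128
x = rng.normal(size=(n, d))
wq, wk, wv = (rng.normal(scale=d ** -0.5, size=(d, d)) for _ in range(3))  # hypothetical weights
dense_out, dense_attn = sparse_spatial_attention(x, wq, wk, wv, s=1.0)
sparse_out, sparse_attn = sparse_spatial_attention(x, wq, wk, wv, s=0.1)
print((dense_attn > 0).sum(axis=1).min(), (sparse_attn > 0).sum(axis=1).max())
```

Lower `s` cuts the number of stock pairs each query mixes over. Any accuracy-vs-speed tradeoff would still need to be measured, as the deep-dive questions below note.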

## Input/Output Specifications

### Input Tensors

| Tensor | Shape | Description |
|--------|-------|-------------|
| XL | [batch, 20, 255, 363] | Low-frequency decomposition |
| XH | [batch, 20, 255, 363] | High-frequency decomposition |
| TE | [batch, 20, 255, time_dim] | Time encoding (day/week/month) |
| XC | [batch, 20, 255] | Trend indicators (binary up/down) |
| bonus_X | [batch, 255, 128] | Graph embeddings per sample |
| adjgat | [255, 128] | Fixed graph information, one 128-dim vector per stock |

### Output Tensors

| Tensor | Shape | Description |
|--------|-------|-------------|
| Classification | [batch, 2, 255, 2] | Logits for up/down per day |
| Regression | [batch, 2, 255] | Predicted returns per day |

**Note**: In practice, only the predictions for the last output step (the second day ahead, since T2 = 2) are used for evaluation.
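The sketch below shows how the reported metrics could be computed from these tensors, keeping only the last predicted step (index -1 on the T2 axis). The arrays are random placeholders with the documented shapes. Pooling every (sample, stock) pair into a single accuracy/MAE/RMSE figure is an assumption about the evaluation protocol.

```python
import numpy as np


def evaluate_last_step(cls_logits, reg_pred, y_cls, y_ret):
    """cls_logits: [B, T2, N, 2]; reg_pred: [B, T2, N]; y_cls, y_ret: [B, T2, N]."""
    pred_cls = cls_logits[:, -1].argmax(axis=-1)  # [B, N] predicted up/down
    acc = (pred_cls == y_cls[:, -1]).mean()
    err = reg_pred[:, -1] - y_ret[:, -1]          # [B, N] return errors
    return {"accuracy": acc, "mae": np.abs(err).mean(), "rmse": np.sqrt((err ** 2).mean())}


rng = np.random.default_rng(0)
B, T2, N = 12, 2, 255
metrics = evaluate_last_step(
    rng.normal(size=(B, T2, N, 2)),               # placeholder logits
    rng.normal(scale=0.02, size=(B, T2, N)),      # placeholder predicted returns
    rng.integers(0, 2, size=(B, T2, N)),          # placeholder up/down labels
    rng.normal(scale=0.02, size=(B, T2, N)),      # placeholder realized returns
)
print(metrics)
```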

## Questions for Deep Dive

### Architecture Questions:
1. **Why separate low/high frequency encoders?**
   - TODO: Investigate if shared encoder performs worse
   - TODO: Ablation study on dual vs single encoder

2. **Why only 1 attention head?**
   - TODO: Test multi-head attention (h=4, h=8)
   - TODO: Check if single head is sufficient for stock data

3. **Why T2=2 (predict 2 days ahead)?**
   - TODO: Experiment with T2=1, T2=5, T2=10
   - TODO: Analyze prediction decay over longer horizons

4. **Sparsity ratio = 1.0 (no pruning)?**
   - TODO: Test s=0.1, 0.3, 0.5 for efficiency gains
   - TODO: Measure impact on accuracy vs speed

### Data Questions:
5. **Why 360 Alpha factors?**
   - TODO: Feature importance analysis
   - TODO: Can we reduce to top 100-200 factors?

6. **Graph embedding effectiveness?**
   - TODO: Ablation study without graph embeddings
   - TODO: Compare Struc2vec vs other graph methods (Node2vec, GraphSAGE)

### Performance Questions:
7. **Why does Subset 13 perform best?**
   - TODO: Analyze market conditions in 2021-2023
   - TODO: Compare volatility patterns across subsets

8. **Classification accuracy ~54% vs random 50%?**
   - TODO: Is this typical for stock prediction?
   - TODO: Compare with baseline models (LSTM, GRU, simple MLP)
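One quick way to frame question 8 is a one-sided binomial z-test of the observed accuracy against 50%. The test-set sizes below are hypothetical, since the notebook does not state them. Daily predictions for different stocks are also correlated, so the independence this test assumes overstates the effective sample size. The naive z-score therefore overstates the evidence.

```python
import math


def accuracy_z_test(accuracy, n, p0=0.5):
    """One-sided z-test of accuracy > p0 using a normal approximation to Binomial(n, p0)."""
    se = math.sqrt(p0 * (1 - p0) / n)
    z = (accuracy - p0) / se
    p_value = 0.5 * math.erfc(z / math.sqrt(2))  # P(Z >= z) for a standard normal Z
    return z, p_value


# Hypothetical sizes: many independent predictions vs. smaller "effective" counts.
for n in (10_000, 1_000, 250):
    z, p = accuracy_z_test(0.5379, n)
    print(f"n = {n:>6,}: z = {z:.2f}, one-sided p = {p:.3g}")
```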

## Code to Load and Inspect Model

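```python
# Load the trained Subset 12 model for inspection
import configparser

import torch

from Stockformermodel.Multitask_Stockformer_models import Stockformer

# Configuration file for the Subset 12 run (read for reference; values below are hard-coded)
config = configparser.ConfigParser()
config.read('config/Multitask_Stock_Subset12.conf')

# Model hyperparameters (see Key Hyperparameters)
infeature = 363      # input features per stock and day
outfea_class = 2     # up/down classes
outfea_regress = 1   # one continuous return
T1 = 20              # input window (days)
T2 = 2               # prediction horizon (days)
L = 2                # encoder layers per frequency path
h = 1                # attention heads
d = 128              # embedding dimension (hidden size = h * d)
s = 1.0              # spatial-attention sparsity ratio (1.0 = dense)

# Build the model on GPU if available
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = Stockformer(infeature, h * d, outfea_class, outfea_regress, L, h, d, s, T1, T2, device).to(device)

# Load the trained weights (saved state_dict)
model_path = 'cpt/STOCK/saved_model_Multitask_2020-12-02_2023-08-02'
model.load_state_dict(torch.load(model_path, map_location=device))
model.eval()

print("Model loaded successfully!")
print(f"Total parameters: {sum(p.numel() for p in model.parameters()):,}")
```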

```python
# Inspect parameter counts inside each dual-encoder layer
def count_params(module):
    """Number of parameters in a torch module."""
    return sum(p.numel() for p in module.parameters())

print("\nDual Encoder Structure:")
for i, layer in enumerate(model.features.dual_encoder):
    print(f"\nLayer {i + 1}:")
    print(f"  TCN: {count_params(layer.tcn):,} params")
    print(f"  Temporal Attention: {count_params(layer.tatt):,} params")
    print(f"  Spatial Attention (Low): {count_params(layer.ssal):,} params")
    print(f"  Spatial Attention (High): {count_params(layer.ssah):,} params")
    print(f"  Layer total: {count_params(layer):,} params")
```

```python
# TODO: Add visualization code
# - Attention weight heatmaps
# - Feature importance analysis
# - Prediction distribution plots
# - Temporal pattern visualization
```

## References

1. **Paper**: Ma, Bohan; Xue, Yushan; Lu, Yuan & Chen, Jing. (2025). "Stockformer: A price-volume factor stock selection model based on wavelet transform and multi-task self-attention networks". *Expert Systems with Applications*, 273, 126803.

2. **GitHub**: https://github.com/Eric991005/Multitask-Stockformer

3. **Qlib**: https://github.com/microsoft/qlib

4. **Struc2vec implementation (GraphEmbedding)**: https://github.com/shenweichen/GraphEmbedding

---

**Next Steps**:
- [ ] Run ablation studies on architecture components
- [ ] Compare with baseline models (LSTM, GRU, Transformer-only)
- [ ] Analyze feature importance
- [ ] Test hyperparameter variations
- [ ] Adapt architecture for NIFTY-200 Indian stocks