# Fine-tuning de TimesFM con PyTorch + Inferencia de Valor/Volatilidad

Este _notebook_ combina:

1. **Descarga de datos BTC a 5m** mediante `get_binance_klines(...)`.
2. **Creación de dataset** de entrenamiento (train/val) con una clase `TimeSeriesDataset`.
3. **Bucle de entrenamiento** (fine-tuning) partiendo del checkpoint genérico `google/timesfm-2.0-500m-pytorch`.
4. **Guardado** de un nuevo checkpoint local con los pesos ajustados.
5. **Código Flask** (al final) que levanta un servidor y usa el modelo fine-tuneado para inferir valor y volatilidad.

De esta manera, podrás **afinar** TimesFM con tus datos de BTC a 5m y luego **servirlo** en tu API.


In [ ]:
import warnings
warnings.filterwarnings("ignore", category=FutureWarning)
import requests
import pandas as pd
import numpy as np
import time
import torch
import logging
import traceback
from flask import Flask, Response
from sklearn.preprocessing import RobustScaler
import timesfm
import os
from torch.utils.data import Dataset, DataLoader
from huggingface_hub import snapshot_download

from timesfm import TimesFm, TimesFmCheckpoint, TimesFmHparams
from timesfm.pytorch_patched_decoder import PatchedTimeSeriesDecoder

logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(levelname)s - %(message)s')

# Base checkpoint on HuggingFace (PyTorch weights).
MODEL_REPO = "google/timesfm-2.0-500m-pytorch"

# Generic hparams reused by the inference service further down.
# NOTE(review): per the TimesFM README the accepted backends are "cpu"/"gpu"/"tpu";
# the previous value "pytorch" was not one of them, so pick the real device here.
tsfm_hparams = TimesFmHparams(
    backend="gpu" if torch.cuda.is_available() else "cpu",
    per_core_batch_size=32,
    # Forecast length produced by the model built from these hparams. It is read
    # when the model is constructed; for shorter horizons slice the forecast
    # instead of mutating this value afterwards.
    horizon_len=128,
    input_patch_len=32,
    output_patch_len=128,
    num_layers=50,
    model_dims=1280,
    use_positional_embedding=False,
)


## 1) Funciones para descargar BTC 5m
Reutilizamos las mismas funciones de descarga que en tu código original:


In [ ]:
def get_binance_klines(symbol="BTCUSDT", interval="5m", max_candles=2000, timeout=10):
    """Download up to ``max_candles`` of the most recent klines from Binance.

    Binance serves at most 1000 klines per request, so pages are fetched
    backwards in time: after the first (most recent) page, each request asks
    for candles ending just before the oldest candle already downloaded.
    (Paging forwards from the newest candle, as before, only reaches the
    future and never returned more than one page.)

    Args:
        symbol: Binance trading pair, e.g. ``"BTCUSDT"``.
        interval: Kline interval string, e.g. ``"5m"``.
        max_candles: Maximum number of (most recent) candles to return.
        timeout: Per-request timeout in seconds, so a stalled connection
            cannot block the caller forever.

    Returns:
        DataFrame with the raw kline columns, a ``date`` column built from
        ``close_time``, float-typed price/volume columns, sorted ascending
        by ``date``.

    Raises:
        Exception: on a non-200 HTTP response, or if the very first request
            returns no data.
    """
    url = "https://api.binance.com/api/v3/klines"
    page_limit = 1000  # Binance's per-request maximum for klines
    cols = [
        "open_time","open","high","low","close","volume",
        "close_time","quote_asset_volume","number_of_trades",
        "taker_buy_base_asset_volume","taker_buy_quote_asset_volume","ignore"
    ]
    target = max(int(max_candles), 1)

    pages = []
    fetched = 0
    end_time = None  # None -> first request returns the most recent candles
    while fetched < target:
        limit = min(page_limit, target - fetched)
        params = {"symbol": symbol, "interval": interval, "limit": limit}
        if end_time is not None:
            params["endTime"] = end_time
            time.sleep(0.2)  # be gentle with the API between pages
        resp = requests.get(url, params=params, timeout=timeout)
        if resp.status_code != 200:
            raise Exception(str(resp.text))
        data = resp.json()
        if not data:
            if not pages:
                raise Exception("Empty response")
            break  # no older history available
        pages.append(pd.DataFrame(data, columns=cols))
        fetched += len(data)
        if len(data) < limit:
            break  # reached the beginning of the available history
        # Rows are [open_time, ...]; the next page must end before the oldest one.
        end_time = int(data[0][0]) - 1

    df_all = pd.concat(pages, ignore_index=True)
    df_all.drop_duplicates(subset="open_time", inplace=True)
    df_all["date"] = pd.to_datetime(df_all["close_time"], unit="ms")

    float_cols = [
        "open","high","low","close","volume","quote_asset_volume",
        "taker_buy_base_asset_volume","taker_buy_quote_asset_volume"
    ]
    for c in float_cols:
        df_all[c] = df_all[c].astype(float)

    df_all.sort_values("date", inplace=True)
    df_all.reset_index(drop=True, inplace=True)
    return df_all

def get_binance_value_5m(token="BTC"):
    """Return the latest BTCUSDT 5-minute klines (2000 candles requested).

    Only BTC is supported; the token comparison is case-insensitive.

    Raises:
        ValueError: if ``token`` is not BTC.
    """
    is_btc = token.upper() == "BTC"
    if not is_btc:
        raise ValueError("Solo BTC en este ejemplo.")
    return get_binance_klines(
        symbol="BTCUSDT",
        interval="5m",
        max_candles=2000,
    )


## 2) Preparamos los datos para *fine-tuning*
Queremos entrenar un **modelo univariante** (solo `close`), así que transformaremos a `log_close`. 

Para el entrenamiento, usaremos una clase `TimeSeriesDataset` que cree ventanas `(x_context, x_future)`.


In [ ]:
class TimeSeriesDataset(Dataset):
    """Sliding-window (train/val) dataset for TimesFM over a 1-D series.

    Each sample is ``context_length`` consecutive points followed immediately
    by the next ``horizon_length`` points. Every valid window start is used,
    including the last window that ends exactly at the end of the series.
    """

    def __init__(self, series: np.ndarray, context_length: int, horizon_length: int, freq_type: int = 0):
        self.series = series
        self.context_length = context_length
        self.horizon_length = horizon_length
        self.freq_type = freq_type
        total_len = context_length + horizon_length
        # Valid starts are 0 .. len(series) - total_len *inclusive*; the former
        # range(len(series) - total_len) silently dropped the final window.
        n_windows = max(len(series) - total_len + 1, 0)
        self.samples = [
            (
                series[start : start + context_length],
                series[start + context_length : start + total_len],
            )
            for start in range(n_windows)
        ]

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """Return ``(context, padding, freq, future)`` tensors for window ``idx``.

        ``padding`` is all zeros with the context's shape (no padded positions);
        ``freq`` is a one-element long tensor holding ``freq_type``.
        """
        x_cont, x_fut = self.samples[idx]
        x_cont = torch.tensor(x_cont, dtype=torch.float32)
        x_fut = torch.tensor(x_fut, dtype=torch.float32)

        input_padding = torch.zeros_like(x_cont)
        freq = torch.tensor([self.freq_type], dtype=torch.long)

        return x_cont, input_padding, freq, x_fut

def create_train_val_datasets(context_len=128, horizon_len=16, train_split=0.8, max_candles=5000):
    """Download BTCUSDT 5m candles and build chronological train/val datasets.

    The modelled series is ``log(close)``. The split is chronological: the first
    ``train_split`` fraction of points feeds the training dataset and the rest
    the validation dataset; windows never cross the split point.

    Args:
        context_len: Past points given to the model per sample.
        horizon_len: Future points to predict per sample.
        train_split: Fraction of the series used for training, in (0, 1].
        max_candles: Number of 5m candles requested from Binance (raise it
            for more training data).

    Returns:
        ``(train_ds, val_ds)`` as ``TimeSeriesDataset`` instances.

    Raises:
        ValueError: if ``train_split`` is outside (0, 1].
    """
    if not 0.0 < train_split <= 1.0:
        raise ValueError(f"train_split must be in (0, 1], got {train_split}")

    df = get_binance_klines("BTCUSDT", "5m", max_candles=max_candles)
    df = df.sort_values("date").reset_index(drop=True)

    # Keep log() finite for zero prices, then drop incomplete rows.
    df["close"] = df["close"].clip(lower=1e-9)
    df = df.dropna()

    series_np = np.log(df["close"].to_numpy())

    split_index = int(len(series_np) * train_split)
    train_data = series_np[:split_index]
    val_data = series_np[split_index:]

    train_ds = TimeSeriesDataset(train_data, context_len, horizon_len, freq_type=0)
    val_ds = TimeSeriesDataset(val_data, context_len, horizon_len, freq_type=0)
    return train_ds, val_ds


## 3) Creación del modelo base y bucle de entrenamiento
Partimos de `google/timesfm-2.0-500m-pytorch` y lo entrenamos en PyTorch.


In [ ]:
def create_timesfm_model(load_pretrained=True, device="cuda"):
    """Build a trainable TimesFM ``PatchedTimeSeriesDecoder`` for fine-tuning.

    Args:
        load_pretrained: If True, load the ``MODEL_REPO`` weights into the
            decoder; otherwise it keeps its freshly initialised weights.
        device: Torch device string the decoder is moved to ("cuda", "cpu", ...).

    Returns:
        ``(pytorch_model, hparams)``: the decoder on ``device`` and the
        ``TimesFmHparams`` used to configure it.
    """
    hparams = TimesFmHparams(
        # The TimesFm wrapper below is only used to derive the decoder config,
        # so keep its own model copy on CPU instead of duplicating it on the GPU.
        # ("cuda" is not a valid TimesFM backend name anyway: "cpu"/"gpu"/"tpu".)
        backend="cpu",
        horizon_len=16,
        input_patch_len=32,
        # Must match the pretrained checkpoint (128): the output head's size
        # depends on it, and a different value makes load_state_dict raise a
        # size mismatch (strict=False does not suppress shape errors).
        # Shorter training horizons are obtained by slicing the predictions.
        output_patch_len=128,
        num_layers=50,
        model_dims=1280,
        use_positional_embedding=False,
    )

    # NOTE(review): the checkpoint is passed even when load_pretrained=False,
    # mirroring the official TimesFM fine-tuning example; TimesFm loads it at
    # construction and presumably cannot be built with checkpoint=None.
    tfm_model = TimesFm(
        hparams=hparams,
        checkpoint=TimesFmCheckpoint(huggingface_repo_id=MODEL_REPO),
    )

    # Fresh PyTorch decoder with the same architecture; this is what we train.
    pytorch_model = PatchedTimeSeriesDecoder(tfm_model._model_config)
    pytorch_model.to(device)

    if load_pretrained:
        ckpt_path = os.path.join(snapshot_download(MODEL_REPO), "torch_model.ckpt")
        # weights_only=True: only tensors are unpickled, never arbitrary objects.
        weights = torch.load(ckpt_path, map_location=device, weights_only=True)
        result = pytorch_model.load_state_dict(weights, strict=False)
        # strict=False tolerates key differences; report them instead of hiding them.
        if result.missing_keys or result.unexpected_keys:
            logging.warning(
                "Checkpoint key mismatch: missing=%s unexpected=%s",
                result.missing_keys,
                result.unexpected_keys,
            )

    return pytorch_model, hparams

def _last_patch_mean(preds, horizon):
    """Reduce raw decoder output to the last patch's channel-0 forecast, first ``horizon`` steps.

    Accepts ``[B, N, P, C]`` (channel 0 taken as the point forecast) or
    ``[B, N, P]`` and returns ``[B, horizon]``.
    """
    if preds.ndim == 4:
        preds = preds[..., 0]
    if preds.ndim == 3:
        preds = preds[:, -1, :]
    # The decoder emits output_patch_len steps, which can exceed the target
    # horizon; compare only the steps we have targets for.
    return preds[:, :horizon]


def train_loop(model,
               train_ds,
               val_ds,
               epochs=5,
               lr=1e-4,
               batch_size=32,
               device="cuda"):
    """Fine-tune ``model`` with Adam + MSE on the last-patch point forecast.

    Args:
        model: Module called as ``model(x_context, x_padding, freq)``.
        train_ds / val_ds: Datasets yielding ``(x_context, x_padding, freq, x_future)``.
        epochs, lr, batch_size: Usual optimisation settings.
        device: Device the batches are moved to (the model must already be there).

    Returns:
        List with one dict per epoch: ``{"epoch", "train_loss", "val_loss"}``,
        losses being per-sample mean MSE; ``val_loss`` is NaN if ``val_ds`` is empty.

    Raises:
        ValueError: if ``train_ds`` is empty.
    """
    if len(train_ds) == 0:
        raise ValueError("train_ds is empty: not enough data for one context+horizon window")

    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = torch.nn.MSELoss()

    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False)

    def run_epoch(loader, training):
        """One pass over ``loader``; returns the per-sample mean loss (NaN if empty)."""
        total, count = 0.0, 0
        for batch in loader:
            x_cont, x_pad, freq, x_fut = (t.to(device) for t in batch)
            with torch.set_grad_enabled(training):
                preds = _last_patch_mean(model(x_cont, x_pad, freq), x_fut.shape[-1])
                loss = criterion(preds, x_fut)
            if training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            # Weight by batch size so a smaller last batch does not bias the mean.
            n = x_cont.shape[0]
            total += loss.item() * n
            count += n
        return total / count if count else float("nan")

    history = []
    for ep in range(epochs):
        model.train()
        mean_train_loss = run_epoch(train_loader, training=True)

        model.eval()
        mean_val_loss = run_epoch(val_loader, training=False)

        # Scientific notation: MSE on log prices is far below 1e-4 and would
        # print as 0.0000 with a fixed 4-decimal format.
        print(f"Epoch {ep+1}/{epochs}, train_loss={mean_train_loss:.3e}, val_loss={mean_val_loss:.3e}")
        history.append({"epoch": ep + 1, "train_loss": mean_train_loss, "val_loss": mean_val_loss})
    return history


## 4) Entrenar y guardar un checkpoint local
En la siguiente celda, haremos:
- Descarga de datos
- Creación del dataset `train_ds` y `val_ds` (con `context_len=128` y `horizon_len=16`, por ejemplo)
- Creación del modelo base (`load_pretrained=True`)
- Llamamos a `train_loop(...)`
- Guardamos los pesos finetuneados en un .ckpt local


In [ ]:
def run_finetuning(context_len=128,
                   horizon_len=16,
                   train_split=0.8,
                   epochs=5,
                   lr=1e-4,
                   batch_size=32,
                   ckpt_path="my_finetuned_timesfm.ckpt"):
    """Download BTC 5m data, fine-tune the pretrained TimesFM and save its weights.

    Args:
        context_len: Past points per training sample. NOTE(review): presumably
            must be a multiple of the model's input patch length (32).
        horizon_len: Future points per sample; presumably must not exceed the
            decoder's output patch length configured in ``create_timesfm_model``.
        train_split: Chronological train fraction passed to the dataset builder.
        epochs, lr, batch_size: Optimisation settings forwarded to ``train_loop``.
        ckpt_path: Where the fine-tuned ``state_dict`` is written.

    Returns:
        ``ckpt_path``, the file the fine-tuned weights were saved to.

    Raises:
        ValueError: if the downloaded data yields no training window.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Usando device={device}")

    # Build datasets from BTC 5m data.
    train_ds, val_ds = create_train_val_datasets(
        context_len=context_len,
        horizon_len=horizon_len,
        train_split=train_split,
    )
    print(f"Train samples: {len(train_ds)}, Val samples: {len(val_ds)}")
    if len(train_ds) == 0:
        raise ValueError(
            "No training windows: download more candles or reduce context_len/horizon_len"
        )

    # Pretrained base model (the returned hparams are not needed here).
    model, _ = create_timesfm_model(load_pretrained=True, device=device)

    train_loop(
        model=model,
        train_ds=train_ds,
        val_ds=val_ds,
        epochs=epochs,
        lr=lr,
        batch_size=batch_size,
        device=device,
    )

    # Save only the weights; see the notes below on how to load them for serving.
    torch.save(model.state_dict(), ckpt_path)
    print(f"\nCheckpoint guardado en {ckpt_path}")
    return ckpt_path


La celda de arriba solo **define** `run_finetuning()`; para entrenar, **descomenta y ejecuta** la llamada de la celda siguiente. Al terminar, tendrás un nuevo fichero `my_finetuned_timesfm.ckpt` con tus pesos ajustados.

In [ ]:
# Descomenta para entrenar:
# run_finetuning()

## 5) Flask con inferencia de Valor y Volatilidad
A continuación, el **mismo** código que usas para tu API, pero adaptado para **cargar** el checkpoint local `my_finetuned_timesfm.ckpt`. 

### NOTA IMPORTANTE
Para usar tu nuevo checkpoint en un `TimesFm(...)` real, se requiere la carpeta de `config.json` + `torch_model.ckpt`. Lo más práctico es:
1. Descargar la carpeta original del repo HF (snapshot_download), que te da `config.json`, etc.
2. Sobrescribir (o renombrar) `torch_model.ckpt` con tu `my_finetuned_timesfm.ckpt`.
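3. Referenciar `checkpoint=TimesFmCheckpoint(path="esa/carpeta/torch_model.ckpt")`. El argumento de `TimesFmCheckpoint` se llama `path`, no `local_path`. En el backend PyTorch debe apuntar al fichero `.ckpt`.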

Si solo guardas `my_finetuned_timesfm.ckpt` sin la config, tendrás que "engañar" a `TimesFmCheckpoint` para que use la config base y tus pesos. 
Por simplicidad, en la celda de abajo asumimos que tienes una carpeta local, p. ej. `my_fine_repo/` con:
- `config.json`
- `merges.txt` (si hace falta)
- `special_tokens_map.json` (a veces)
- `torch_model.ckpt` (este es tu `my_finetuned_timesfm.ckpt` renombrado)


In [ ]:
# Example 'mini-service' Flask app exposing value and volatility endpoints.
from flask import Flask, Response

# Number of 5-minute steps in 6 hours: volatility horizon and sqrt-time scaling.
_STEPS_6H = 72


def _resolve_torch_checkpoint(local_checkpoint_path: str) -> str:
    """Return the .ckpt file to load: the path itself, or ``<dir>/torch_model.ckpt`` for a directory."""
    if os.path.isdir(local_checkpoint_path):
        return os.path.join(local_checkpoint_path, "torch_model.ckpt")
    return local_checkpoint_path


def _point_forecast(model, series: np.ndarray, horizon: int) -> np.ndarray:
    """Return the first ``horizon`` steps of the point forecast for a single series.

    NOTE(review): per the TimesFM README, ``forecast`` returns
    ``(point_forecast, quantile_forecast)`` with one point-forecast row per
    input series; confirm against the installed version.
    """
    with torch.no_grad():
        forecast_out = model.forecast(inputs=[series], freq=[0])
    preds = np.asarray(forecast_out[0])[0]  # row 0: our only input series
    if preds.ndim > 1:  # defensive: (horizon, channels) -> channel 0
        preds = preds[:, 0]
    return preds[:horizon]


def _historical_vol_6h(log_close: np.ndarray) -> float:
    """6h volatility from the last 72 historical log-returns; 0.0 when not computable."""
    returns = np.diff(log_close)[-_STEPS_6H:]
    if len(returns) < 2:
        return 0.0
    vol = float(returns.std(ddof=1) * np.sqrt(float(_STEPS_6H)))
    if not np.isfinite(vol) or vol < 1e-12:
        return 0.0
    return vol


def create_finetuned_model_service(local_checkpoint_path: str):
    """Create a Flask app serving TimesFM loaded from a local fine-tuned checkpoint.

    Args:
        local_checkpoint_path: The fine-tuned ``.ckpt`` file itself, or a
            directory containing ``torch_model.ckpt``.

    Returns:
        Flask app with routes:
          GET /inference/value/<token>/<pl> -> price predicted ``pl`` 5m candles ahead
          GET /inference/volatility/<token> -> predicted 6h volatility of log-returns

    Notes:
        Uses ``prepare_data`` and ``unscale_value``, which are NOT defined in this
        notebook and must come from your original API code. NOTE(review): assumed
        to return ``(df with "log_close"/"log_close_scaled", {"log_close": scaler}, _)``
        and to invert that scaler respectively.
    """
    app = Flask(__name__)

    # Kept as a module-level global, as before, so other code can reach the model.
    global tfm_model

    try:
        ckpt_file = _resolve_torch_checkpoint(local_checkpoint_path)
        # TimesFmCheckpoint takes `path=`; the former `local_path=` keyword
        # raised TypeError, which the old bare `except:` silently swallowed.
        local_ckpt = timesfm.TimesFmCheckpoint(path=ckpt_file)
        tfm_model = timesfm.TimesFm(hparams=tsfm_hparams, checkpoint=local_ckpt)
    except Exception:
        # App creation still succeeds, but the reason the model is missing is logged.
        logging.exception("Could not load TimesFM checkpoint from %s", local_checkpoint_path)
        tfm_model = None

    # The forecast length is fixed when the model is built; shorter horizons are
    # taken by slicing the forecast (no shared hparams mutated per request).
    max_horizon = tsfm_hparams.horizon_len

    def _model_context(token):
        """Fetch data and build the model context: ``(tail, scalers)`` or an error Response."""
        df_raw = get_binance_value_5m(token)
        if df_raw.empty:
            return Response("Empty data", status=500)
        df_feat, scalers, _ = prepare_data(df_raw)
        if len(df_feat) < 50:
            return Response("Insufficient data after cleaning", status=500)
        window_size = min(len(df_feat), 512)
        return df_feat.iloc[-window_size:], scalers

    @app.route("/inference/value/<string:token>/<int:pl>", methods=["GET"])
    def inference_value(token, pl):
        try:
            if tfm_model is None:
                return Response("No TimeSFM model loaded", status=500)
            if pl <= 0:
                return Response("pl>0", status=400)
            if pl > max_horizon:
                return Response(f"pl<={max_horizon}", status=400)

            prepared = _model_context(token)
            if isinstance(prepared, Response):
                return prepared
            tail, scalers = prepared

            pred_scaled = _point_forecast(tfm_model, tail["log_close_scaled"].values, pl)
            # Step index pl-1 is the candle `pl` steps ahead.
            log_c = unscale_value(pred_scaled[pl - 1], scalers["log_close"])
            price = np.exp(log_c)  # exp() is never negative: no clamp needed

            return Response(f"{price:.8f}", status=200)
        except Exception as e:
            logging.exception("Value inference failed")
            return Response(str(e), 500)

    @app.route("/inference/volatility/<string:token>", methods=["GET"])
    def inference_volatility_6h(token):
        try:
            if tfm_model is None:
                return Response("No TimeSFM model loaded", status=500)

            prepared = _model_context(token)
            if isinstance(prepared, Response):
                return prepared
            tail, scalers = prepared

            scaled_preds = _point_forecast(tfm_model, tail["log_close_scaled"].values, _STEPS_6H)
            # float64: log prices (~11) in float32 keep only ~1e-6 absolute
            # resolution, too coarse for 5m log-returns.
            pred_log_close = np.array(
                [unscale_value(v, scalers["log_close"]) for v in scaled_preds],
                dtype=np.float64,
            )

            returns = np.diff(pred_log_close)
            if len(returns) >= 2:
                vol_6h = float(returns.std(ddof=1) * np.sqrt(float(_STEPS_6H)))
            else:
                vol_6h = float("nan")

            if not np.isfinite(vol_6h) or vol_6h < 1e-12:
                # Degenerate forecast (too short, flat or non-finite): fall back
                # to the realised volatility of recent history.
                vol_6h = _historical_vol_6h(tail["log_close"].values)

            return Response(f"{vol_6h:.6f}", status=200)
        except Exception as e:
            logging.exception("Volatility inference failed")
            return Response(str(e), 500)

    return app


### Ejecución local del Flask
Una vez tengas tu carpeta local con `config.json` y `my_finetuned_timesfm.ckpt` (renombrado a `torch_model.ckpt`), puedes crear la app e iniciarla.

In [ ]:
# Ejemplo de cómo lanzar la app
# Normalmente lo harías en un main guardado en un .py.

# from flask import Flask
# if __name__ == "__main__":
#     local_ckpt_dir = "./my_fine_repo"  # carpeta con config.json y torch_model.ckpt
#     app = create_finetuned_model_service(local_checkpoint_path=local_ckpt_dir)
#     app.run(host="0.0.0.0", port=8000, debug=False)  # nunca debug=True con host 0.0.0.0: el depurador de Werkzeug permite ejecutar código remoto
# 
# Luego harías:
# GET /inference/value/BTC/10 => predicción a 10 velas
# GET /inference/volatility/BTC => volatilidad a 6h

## Conclusiones

1. **Ejecuta** la celda de `run_finetuning()` (o llama la función) para entrenar un poco. 
2. **Reemplaza** `torch_model.ckpt` en una carpeta local con tus pesos generados (`my_finetuned_timesfm.ckpt`). 
3. **Crea** la app Flask con `create_finetuned_model_service(...)` apuntando a esa carpeta local. 
4. **Lanza** el servicio y prueba tus endpoints `/inference/value/...` y `/inference/volatility/...`.

¡Listo! Has integrado el *fine-tuning* de TimesFM con tu lógica de inferencia de precio y volatilidad.