# Multi-Chart Fine-Tuning Pipeline for Kronos

This notebook fine-tunes the Kronos model on several charts at once (one price series per CSV file, including the target) so that it forecasts the future values of a single target chart.

## 1. Imports

In [None]:
import os\nimport pandas as pd\nimport numpy as np\nimport torch\nimport torch.nn as nn\nfrom torch.utils.data import Dataset, DataLoader\nimport matplotlib.pyplot as plt\nfrom tqdm.notebook import trange\n\nfrom model.kronos import Kronos\nfrom model.module import TemporalEmbedding

## 2. Data Loading and Preprocessing

In [None]:
def load_and_preprocess_data(data_dir, file_paths, target_file):
    """Load per-chart CSVs and merge them into one time-aligned DataFrame.

    Each CSV must contain a ``timestamps`` column. Every other column is
    prefixed with the file's name up to the first '.' (e.g. ``BTC_ETH_close``),
    so columns from different charts do not collide after merging.

    Args:
        data_dir: Directory containing the CSV files.
        file_paths: Input chart filenames, relative to ``data_dir``.
        target_file: Target chart filename, relative to ``data_dir``.

    Returns:
        A DataFrame indexed by timestamp: an outer join of all files, sorted,
        forward-filled, with any rows that still contain NaN dropped.
    """
    dfs = {}
    for file in list(file_paths) + [target_file]:
        path = os.path.join(data_dir, file)
        df = pd.read_csv(path)
        df['timestamps'] = pd.to_datetime(df['timestamps'])
        df = df.set_index('timestamps')
        # Prefix = text before the first '.'; the dataset cell derives column names the same way.
        prefix = file.split('.')[0]
        dfs[file] = df.add_prefix(f"{prefix}_")

    merged_df = pd.concat(dfs.values(), axis=1, join='outer').sort_index()
    # Forward-fill gaps from charts with missing bars. Leading rows that are
    # still NaN (before every chart has started) are dropped.
    return merged_df.ffill().dropna()


def ensure_dummy_data(data_dir, filenames, n_rows=1000, seed=42, start="2023-01-01", freq="h"):
    """Write a synthetic OHLCV+amount CSV for each of ``filenames`` missing from ``data_dir``.

    Files that already exist are never overwritten, so real data placed in
    ``data_dir`` is preserved across notebook re-runs.

    Args:
        data_dir: Target directory (created if absent).
        filenames: CSV filenames that the pipeline expects.
        n_rows: Number of rows per generated file.
        seed: Seed for the NumPy random generator (reproducible dummy data).
        start: First timestamp of the generated series.
        freq: Pandas frequency string for the timestamps.

    Returns:
        List of the filenames that were actually created.
    """
    os.makedirs(data_dir, exist_ok=True)
    rng = np.random.default_rng(seed)
    timestamps = pd.date_range(start=start, periods=n_rows, freq=freq)
    created = []
    for filename in filenames:
        path = os.path.join(data_dir, filename)
        if os.path.exists(path):
            continue
        open_ = rng.random(n_rows) * 1000
        close = rng.random(n_rows) * 1000
        # Bracket open/close so every bar satisfies 0 < low <= open, close <= high.
        high = np.maximum(open_, close) * (1 + rng.random(n_rows) * 0.05)
        low = np.minimum(open_, close) * (1 - rng.random(n_rows) * 0.05)
        pd.DataFrame({
            "timestamps": timestamps,
            "open": open_,
            "high": high,
            "low": low,
            "close": close,
            "volume": rng.random(n_rows) * 100,
            "amount": rng.random(n_rows) * 10000,
        }).to_csv(path, index=False)
        created.append(filename)
    return created


DATA_DIR = "./data/"
INPUT_FILES = ["BTC_ETH.csv", "BTC_XAU.csv", "BTC_EUR.csv"]
TARGET_FILE = "BTC_USDT.csv"

# Only fills in missing files; CSVs already present in DATA_DIR are left untouched.
ensure_dummy_data(DATA_DIR, INPUT_FILES + [TARGET_FILE])

merged_data = load_and_preprocess_data(DATA_DIR, INPUT_FILES, TARGET_FILE)
print("Data loaded and preprocessed.")

## 3. Model Definition

In [None]:
DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"
MODEL_PATH = "./pretrained_models/Kronos-small"
base_model = Kronos.from_pretrained(MODEL_PATH).to(DEVICE)


class MultiChartKronosForRegression(nn.Module):
    """Kronos backbone with one linear input head per chart and a linear regression head.

    Each chart's per-step features are projected to ``d_model`` by that chart's
    own head. The projections are summed, optional time-stamp embeddings are
    added, and the result passes through the base model's transformer blocks
    (frozen here) and final norm. A linear head then maps back to ``n_features``
    outputs per step.

    Args:
        base_model: Loaded Kronos model. Its ``time_emb``, ``transformer`` and
            ``norm`` submodules are shared with this module, not copied.
        num_charts: Number of chart tensors that ``forward`` expects.
        d_model: Hidden size of the base model.
        n_features: Features per time step, for every chart and for the output.
            The default of 6 matches open, high, low, close, volume, amount.
    """

    def __init__(self, base_model, num_charts, d_model, n_features=6):
        super().__init__()
        self.num_charts = num_charts
        self.time_emb = base_model.time_emb
        self.input_heads = nn.ModuleList([nn.Linear(n_features, d_model) for _ in range(num_charts)])
        self.transformer = base_model.transformer
        self.norm = base_model.norm
        # Only the transformer blocks are frozen; time_emb and norm are not frozen here.
        for param in self.transformer.parameters():
            param.requires_grad = False
        self.regression_head = nn.Linear(d_model, n_features)

    def forward(self, list_of_chart_tensors, stamp=None):
        """Return per-step regression outputs for the fused charts.

        Args:
            list_of_chart_tensors: Exactly ``num_charts`` tensors, one per chart.
                NOTE(review): presumably each is (batch, seq, n_features); confirm.
            stamp: Optional time features passed to the base model's ``time_emb``.

        Raises:
            ValueError: If the number of chart tensors differs from ``num_charts``.
                Without this check, a shorter list would silently yield a partial sum.
        """
        if len(list_of_chart_tensors) != self.num_charts:
            raise ValueError(
                f"expected {self.num_charts} chart tensors, got {len(list_of_chart_tensors)}"
            )
        embeddings = [head(chart) for head, chart in zip(self.input_heads, list_of_chart_tensors)]
        x = torch.stack(embeddings).sum(dim=0)
        if stamp is not None:
            x = x + self.time_emb(stamp)
        for layer in self.transformer:
            x = layer(x)
        x = self.norm(x)
        return self.regression_head(x)


num_all_charts = len(INPUT_FILES) + 1
d_model = base_model.d_model
multi_chart_model = MultiChartKronosForRegression(base_model, num_all_charts, d_model).to(DEVICE)
print("Multi-head model created.")

## 4. Fine-Tuning

In [None]:
LEARNING_RATE = 1e-4
BATCH_SIZE = 32
NUM_EPOCHS = 10
SEQ_LEN = 128
PRED_LEN = 32
CLIP = 5.0
SEED = 42
FEATURES = ['open', 'high', 'low', 'close', 'volume', 'amount']

torch.manual_seed(SEED)


def calc_time_stamps(timestamps):
    """Return calendar features (minute, hour, weekday, day, month) for a DatetimeIndex, in that column order."""
    return pd.DataFrame({
        'minute': timestamps.minute,
        'hour': timestamps.hour,
        'weekday': timestamps.weekday,
        'day': timestamps.day,
        'month': timestamps.month
    })


class MultiChartDataset(Dataset):
    """Sliding windows over a merged multi-chart frame.

    Item ``idx`` holds three tensors:
      * ``x_charts``: rows ``[idx, idx + seq_len)`` of every input chart,
        followed by the target chart.
      * ``x_stamp``: the matching calendar features.
      * ``y_target``: the next ``pred_len`` rows of the target chart.
    """

    def __init__(self, df, input_files, target_file, seq_len, pred_len):
        self.df = df
        self.input_files = input_files
        self.target_file = target_file
        self.seq_len = seq_len
        self.pred_len = pred_len
        # Materialise arrays once instead of re-selecting DataFrame columns on every item.
        self._chart_arrays = [
            df[self._chart_columns(file)].to_numpy(dtype=np.float32)
            for file in list(input_files) + [target_file]
        ]
        self._target_array = df[self._chart_columns(target_file)].to_numpy(dtype=np.float32)
        self._stamp_array = calc_time_stamps(df.index).to_numpy(dtype=np.float32)

    @staticmethod
    def _chart_columns(file):
        """Column names for one chart, matching the prefix rule used when merging."""
        prefix = file.split('.')[0]
        return [f"{prefix}_{c}" for c in FEATURES]

    def __len__(self):
        # Clamp at 0 so a frame shorter than one full window yields an empty dataset.
        return max(0, len(self.df) - self.seq_len - self.pred_len + 1)

    def __getitem__(self, idx):
        end_idx = idx + self.seq_len
        pred_end_idx = end_idx + self.pred_len
        x_charts = [torch.tensor(arr[idx:end_idx]) for arr in self._chart_arrays]
        y_target = torch.tensor(self._target_array[end_idx:pred_end_idx])
        x_stamp = torch.tensor(self._stamp_array[idx:end_idx])
        return {'x_charts': x_charts, 'x_stamp': x_stamp, 'y_target': y_target}


def _move_batch(batch, device):
    """Move one collated batch to ``device``; returns (x_charts, x_stamp, y_target)."""
    x_charts = [c.to(device) for c in batch['x_charts']]
    return x_charts, batch['x_stamp'].to(device), batch['y_target'].to(device)


def train_one_epoch(model, loader, optimizer, criterion, pred_len, clip, device):
    """Run one training epoch and return the mean batch loss (0.0 for an empty loader)."""
    # Re-enter train mode every epoch: evaluate() switches the model to eval mode.
    model.train()
    total = 0.0
    for batch in loader:
        x_charts, x_stamp, y_target = _move_batch(batch, device)
        optimizer.zero_grad()
        predictions = model(x_charts, x_stamp)
        # The last pred_len output positions are scored against the future target window.
        loss = criterion(predictions[:, -pred_len:, :], y_target)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
        optimizer.step()
        total += loss.item()
    return total / max(len(loader), 1)


def evaluate(model, loader, criterion, pred_len, device):
    """Return the mean validation loss over ``loader`` (0.0 for an empty loader)."""
    model.eval()
    total = 0.0
    with torch.no_grad():
        for batch in loader:
            x_charts, x_stamp, y_target = _move_batch(batch, device)
            predictions = model(x_charts, x_stamp)
            total += criterion(predictions[:, -pred_len:, :], y_target).item()
    return total / max(len(loader), 1)


train_size = int(len(merged_data) * 0.8)
train_df, val_df = merged_data.iloc[:train_size], merged_data.iloc[train_size:]
train_dataset = MultiChartDataset(train_df, INPUT_FILES, TARGET_FILE, SEQ_LEN, PRED_LEN)
val_dataset = MultiChartDataset(val_df, INPUT_FILES, TARGET_FILE, SEQ_LEN, PRED_LEN)
assert len(train_dataset) > 0 and len(val_dataset) > 0, "each split needs at least SEQ_LEN + PRED_LEN rows"
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)

# Frozen transformer parameters never receive gradients, so keep them out of the optimizer.
trainable_params = [p for p in multi_chart_model.parameters() if p.requires_grad]
optimizer = torch.optim.Adam(trainable_params, lr=LEARNING_RATE)
# NOTE(review): features are not normalised, so MSE is dominated by the largest-scale
# column (with the section-2 synthetic data, 'amount' ~1e4). Consider per-window normalisation.
criterion = nn.MSELoss()

for epoch in trange(NUM_EPOCHS, desc="Epochs"):
    train_loss = train_one_epoch(multi_chart_model, train_loader, optimizer, criterion, PRED_LEN, CLIP, DEVICE)
    print(f"Epoch {epoch+1}, Train Loss: {train_loss:.6f}")
    val_loss = evaluate(multi_chart_model, val_loader, criterion, PRED_LEN, DEVICE)
    print(f"Epoch {epoch+1}, Val Loss: {val_loss:.6f}")

## 5. Prediction and Visualization

In [None]:
def _sample_predictions(model, dataset, num_samples):
    """Run ``model`` on one randomly drawn batch of ``dataset``.

    The model runs on whatever device its parameters live on, and the forecast
    horizon is taken from the target length, so no notebook globals are needed.

    Args:
        model: Module called as ``model(x_charts, x_stamp)``.
        dataset: Dataset yielding dicts with ``x_charts``, ``x_stamp`` and ``y_target``.
        num_samples: Maximum number of windows to draw.

    Returns:
        ``(y_target, future_preds)`` as NumPy arrays with matching shapes. The
        leading dimension may be smaller than ``num_samples`` for small datasets.

    Raises:
        ValueError: If ``dataset`` is empty.
    """
    if len(dataset) == 0:
        raise ValueError("dataset is empty; nothing to predict")
    model.eval()
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")
    loader = DataLoader(dataset, batch_size=num_samples, shuffle=True)
    with torch.no_grad():
        batch = next(iter(loader))
        x_charts = [c.to(device) for c in batch['x_charts']]
        x_stamp = batch['x_stamp'].to(device)
        y_target = batch['y_target'].cpu().numpy()
        pred_len = y_target.shape[1]
        future_preds = model(x_charts, x_stamp)[:, -pred_len:, :].cpu().numpy()
    return y_target, future_preds


def predict_and_plot(model, dataset, num_samples=5):
    """Plot predicted vs. ground-truth close price and volume for random windows of ``dataset``.

    Feature index 3 is close and index 4 is volume, following the column order
    open, high, low, close, volume, amount.
    """
    y_target, future_preds = _sample_predictions(model, dataset, num_samples)
    # Iterate over the batch actually drawn; it is smaller than num_samples for tiny datasets.
    for i in range(y_target.shape[0]):
        fig, (ax_close, ax_volume) = plt.subplots(2, 1, figsize=(15, 8))
        ax_close.plot(y_target[i, :, 3], label='Ground Truth Close')
        ax_close.plot(future_preds[i, :, 3], label='Predicted Close', linestyle='--')
        ax_close.set_title('Close Price Prediction')
        ax_close.set_xlabel('Steps ahead')
        ax_close.legend()

        ax_volume.plot(y_target[i, :, 4], label='Ground Truth Volume')
        ax_volume.plot(future_preds[i, :, 4], label='Predicted Volume', linestyle='--')
        ax_volume.set_title('Volume Prediction')
        ax_volume.set_xlabel('Steps ahead')
        ax_volume.legend()

        fig.tight_layout()
        plt.show()
        # Release each figure so looping over many samples does not accumulate memory.
        plt.close(fig)


predict_and_plot(multi_chart_model, val_dataset)