from pathlib import Path

from autogluon.timeseries import TimeSeriesDataFrame, TimeSeriesPredictor
from chronos import BaseChronosPipeline

# Source dataset (long-format CSV) loaded straight from the AutoGluon S3 bucket.
DATASET_URL = (
    "https://autogluon.s3.amazonaws.com/datasets/timeseries/"
    "australian_electricity_subset/test.csv"
)

data = TimeSeriesDataFrame.from_path(DATASET_URL)
print(data.head())

# Forecast horizon (number of time steps). It is shared by the train/test
# split and by the predictor below, so the two can never disagree.
prediction_length = 8
# `train_data` drops the last `prediction_length` steps of each series;
# `test_data` keeps them so the forecast can be scored against real values.
train_data, test_data = data.train_test_split(prediction_length=prediction_length)

# Chronos-Bolt (small) configurations to train. Uncomment the zero-shot entry
# to compare it side by side with the fine-tuned model on the leaderboard.
chronos_configs = [
    # {"model_path": "bolt_small", "ag_args": {"name_suffix": "ZeroShot"}},
    {
        "model_path": "bolt_small",
        "fine_tune": True,
        # Optional knob: add e.g. "fine_tune_steps": 100 to control how long
        # fine-tuning runs.
        "ag_args": {"name_suffix": "FineTuned"},
    },
]

predictor = TimeSeriesPredictor(prediction_length=prediction_length).fit(
    train_data=train_data,
    hyperparameters={"Chronos": chronos_configs},
    time_limit=60,  # time limit in seconds
    enable_ensemble=False,
)

# Optional follow-up: score on the held-out horizon and reuse the fine-tuned
# checkpoint directly through the Chronos pipeline.
#
# print(predictor.leaderboard(test_data))
#
# # NOTE(review): relies on private AutoGluon internals (`_trainer`,
# # `most_recent_model`), which may change between versions.
# fine_tuned_path = Path(
#     predictor._trainer.load_model("ChronosFineTuned[bolt_small]").most_recent_model.path
# ) / "fine-tuned-ckpt"
# pipeline = BaseChronosPipeline.from_pretrained(fine_tuned_path)
#
# import torch
#
# # Build one context tensor per series. `train_data.tail(n)` would take the
# # last n rows of the *whole* frame (crossing item_id boundaries), not the
# # last n points of each series, so slice per item instead.
# context_length = prediction_length * 2
# contexts = [
#     torch.tensor(
#         train_data.loc[item_id]["target"].to_numpy()[-context_length:],
#         dtype=torch.float32,
#     )
#     for item_id in train_data.item_ids
# ]
#
# # NOTE(review): for Chronos-Bolt checkpoints `predict` is expected to return
# # quantile forecasts shaped (num_series, num_quantiles, prediction_length);
# # confirm against the installed chronos version.
# predictions = pipeline.predict(contexts, prediction_length=prediction_length)
# print("Predictions shape:", predictions.shape)
# print("Predictions:", predictions)
