In [1]:
import warnings
import pandas as pd
import pytorch_lightning as pl
from tqdm.auto import tqdm
from evaluation.visualization import print_sequence_table

# Suppress every warning for the rest of the session (blanket "ignore" filter).
# NOTE(review): this also hides deprecation and numerical warnings from pandas/torch.
warnings.simplefilter("ignore")

# Hook tqdm into pandas so progress-bar variants of apply-style calls are available.
tqdm.pandas()

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Load the raw Best Track table; every later cell derives its data from `df`.
df = pd.read_csv("bst_data/bst_all.csv")

# Summary figures for a quick sanity check of the loaded file.
n_records = len(df)
n_cyclones = df["intl_id"].nunique()
first_time = df["analysis_time"].min()
last_time = df["analysis_time"].max()

print(f"✅ Загружено {n_records:,} записей из Best Track данных")
print(f"📊 Уникальных циклонов: {n_cyclones}")
print(f"📅 Временной диапазон: {first_time} - {last_time}")

✅ Загружено 70,217 записей из Best Track данных
📊 Уникальных циклонов: 1924
📅 Временной диапазон: 1951-02-19 06:00:00 - 2024-12-25 12:00:00


In [3]:
# Preview the first five raw records (rich display of the frame).
df.head(n=5)

Unnamed: 0,intl_id,n_data_lines_declared,cyclone_id,last_line_flag,final_analysis_lag_hr,revision_date,storm_name,analysis_time,indicator,grade,...,lon_deg,central_pressure_hpa,max_wind_kt,r50kt_dir,r50kt_long_nm,r50kt_short_nm,r30kt_dir,r30kt_long_nm,r30kt_short_nm,landfall_indicator
0,5101,10,5101,0,6,19901017,,1951-02-19 06:00:00,2,2,...,138.5,1010,,,,,,,,
1,5101,10,5101,0,6,19901017,,1951-02-19 12:00:00,2,2,...,138.5,1010,,,,,,,,
2,5101,10,5101,0,6,19901017,,1951-02-19 18:00:00,2,2,...,142.1,1000,,,,,,,,
3,5101,10,5101,0,6,19901017,,1951-02-20 00:00:00,2,9,...,146.0,994,,,,,,,,
4,5101,10,5101,0,6,19901017,,1951-02-20 06:00:00,2,9,...,150.6,994,,,,,,,,


In [4]:
import pickle
from pathlib import Path

from data_processing.data_processor import DataProcessor
from data_processing.dataset_models import SequenceConfig
from data_processing.dataset_utils import split_data_by_years


# History-window settings: at least two past observations, no upper bound.
sequence_config = SequenceConfig(
    min_history_length=2,
    max_history_length=None,
)

processor = DataProcessor(
    horizons_hours=[6, 12, 18, 24, 30, 36, 42, 48],
    # horizons_hours=[6, 24, 48],
    sequence_config=sequence_config,
    train_max_year=2021,
    val_max_year=2023,
    validate_data=True,
)


dataset_path = Path("bst_data/processed_dataset.pkl")
processor_config_path = Path("bst_data/processor_config.pkl")

# Snapshot of the processor settings, stored next to the cached dataset so a
# later session can tell which settings the cache was built with.
processor_config = {
    "horizons_hours": processor.horizons_hours,
    "train_max_year": processor.train_max_year,
    "val_max_year": processor.val_max_year,
    "sequence_config": processor.seq_config,
    "validate_data": processor.validate_data,
}

# Settings compared against the cached snapshot. `sequence_config` is left out
# because SOURCE does not show whether SequenceConfig defines value equality.
comparable_keys = ("horizons_hours", "train_max_year", "val_max_year", "validate_data")

if dataset_path.exists():
    print("🔄 Загружаем сохраненный датасет...")
    # NOTE(security): pickle.load can execute arbitrary code; only load cache
    # files that were produced locally by this notebook.
    with dataset_path.open("rb") as f:
        dataset = pickle.load(f)
    print("✅ Датасет загружен из файла")

    # Warn when the cache was built with different settings (e.g. after toggling
    # `horizons_hours` above); previously a stale cache was reused silently.
    if processor_config_path.exists():
        with processor_config_path.open("rb") as f:
            cached_config = pickle.load(f)
        stale_keys = [
            key for key in comparable_keys
            if cached_config.get(key) != processor_config[key]
        ]
        if stale_keys:
            print(
                f"⚠️ Кэш датасета собран с другими параметрами: {stale_keys}. "
                f"Удалите {dataset_path}, чтобы пересобрать датасет."
            )
else:
    print("🔧 Создаем новый датасет...")
    dataset = processor.build_dataset(df)

    print("💾 Сохраняем датасет...")
    with dataset_path.open("wb") as f:
        pickle.dump(dataset, f)
    with processor_config_path.open("wb") as f:
        pickle.dump(processor_config, f)
    print("✅ Датасет сохранен в processed_dataset.pkl")

🔄 Загружаем сохраненный датасет...
✅ Датасет загружен из файла


In [5]:
# Chronological train / validation / test split driven by `dataset.times`.
# NOTE(review): the processor above was configured with train_max_year=2021 and
# val_max_year=2023, while this split uses 2020 / 2022 — confirm which pair is intended.
year_splits = split_data_by_years(
    dataset.X,
    dataset.y,
    dataset.times,
    train_max_year=2020,
    val_max_year=2022,
)
X_train, y_train, X_val, y_val, X_test, y_test = year_splits

In [6]:
# Number of training samples produced by the year-based split.
len(X_train)

434745

In [7]:
# Render the history window of the very first training sample.
first_sequence = X_train["sequences"][0]
print_sequence_table(first_sequence)


🔹 Seq_0
------------------------------------------------------------
Размер: 2 шагов × 9 признаков
+-------+-----------+-----------+------------------------+---------+----------------+---------------+---------------------+------------------------+-----------------------+
|       |   lat_deg |   lon_deg |   central_pressure_hpa |   grade |   velocity_kmh |   bearing_deg |   acceleration_kmh2 |   angular_velocity_deg |   pressure_change_hpa |
| Шаг 1 |     8.800 |   137.500 |               1004.000 |   2.000 |          0.000 |         0.000 |               0.000 |                  0.000 |                 0.000 |
+-------+-----------+-----------+------------------------+---------+----------------+---------------+---------------------+------------------------+-----------------------+
| Шаг 2 |     9.700 |   136.000 |               1004.000 |   2.000 |         32.109 |       301.414 |               0.000 |                  0.000 |                 0.000 |
+-------+-----------+-----------+--

In [8]:
# Render the history window of sample 8, i.e. the one right after the eight
# rows (one per forecast horizon) that share the first history window.
ninth_sequence = X_train["sequences"][8]
print_sequence_table(ninth_sequence)


🔹 Seq_0
------------------------------------------------------------
Размер: 3 шагов × 9 признаков
+-------+-----------+-----------+------------------------+---------+----------------+---------------+---------------------+------------------------+-----------------------+
|       |   lat_deg |   lon_deg |   central_pressure_hpa |   grade |   velocity_kmh |   bearing_deg |   acceleration_kmh2 |   angular_velocity_deg |   pressure_change_hpa |
| Шаг 1 |     8.800 |   137.500 |               1004.000 |   2.000 |          0.000 |         0.000 |               0.000 |                  0.000 |                 0.000 |
+-------+-----------+-----------+------------------------+---------+----------------+---------------+---------------------+------------------------+-----------------------+
| Шаг 2 |     9.700 |   136.000 |               1004.000 |   2.000 |         32.109 |       301.414 |               0.000 |                  0.000 |                 0.000 |
+-------+-----------+-----------+--

In [9]:
# First eight training rows: the same history window paired with each target horizon.
X_train[:8]

Unnamed: 0,sequences,analysis_time,intl_id,storm_name,target_time_hours,day_of_year_sin,day_of_year_cos,month_of_year_sin,month_of_year_cos
0,"[[8.8, 137.5, 1004.0, 2.0, 0.0, 0.0, 0.0, 0.0,...",2000-05-05,1,6 DAMREY,6.0,0.829677,-0.558244,0.5,-0.866025
1,"[[8.8, 137.5, 1004.0, 2.0, 0.0, 0.0, 0.0, 0.0,...",2000-05-05,1,6 DAMREY,12.0,0.829677,-0.558244,0.5,-0.866025
2,"[[8.8, 137.5, 1004.0, 2.0, 0.0, 0.0, 0.0, 0.0,...",2000-05-05,1,6 DAMREY,18.0,0.829677,-0.558244,0.5,-0.866025
3,"[[8.8, 137.5, 1004.0, 2.0, 0.0, 0.0, 0.0, 0.0,...",2000-05-05,1,6 DAMREY,24.0,0.829677,-0.558244,0.5,-0.866025
4,"[[8.8, 137.5, 1004.0, 2.0, 0.0, 0.0, 0.0, 0.0,...",2000-05-05,1,6 DAMREY,30.0,0.829677,-0.558244,0.5,-0.866025
5,"[[8.8, 137.5, 1004.0, 2.0, 0.0, 0.0, 0.0, 0.0,...",2000-05-05,1,6 DAMREY,36.0,0.829677,-0.558244,0.5,-0.866025
6,"[[8.8, 137.5, 1004.0, 2.0, 0.0, 0.0, 0.0, 0.0,...",2000-05-05,1,6 DAMREY,42.0,0.829677,-0.558244,0.5,-0.866025
7,"[[8.8, 137.5, 1004.0, 2.0, 0.0, 0.0, 0.0, 0.0,...",2000-05-05,1,6 DAMREY,48.0,0.829677,-0.558244,0.5,-0.866025


In [10]:
# Targets (dlat_target / dlon_target) for the same eight training rows.
y_train[:8]

Unnamed: 0,dlat_target,dlon_target
0,0.2,-1.0
1,0.5,-1.6
2,1.0,-2.3
3,1.4,-3.0
4,2.1,-3.8
5,2.8,-3.8
6,3.3,-4.2
7,3.5,-4.4


In [None]:
from training import CycloneDataModule

from models.model import LightningCycloneModel
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint, LearningRateMonitor
from core.features import FeatureConfig

# Seed the Python, NumPy and PyTorch RNGs (and dataloader workers) before the
# model is created, so weight initialisation and shuffling are repeatable on a
# fresh "Restart & Run All". Previously no seed was set anywhere.
pl.seed_everything(42, workers=True)

feature_cfg = FeatureConfig()

# Only the train and validation splits go into the data module; the test split
# is not used for fitting.
data = CycloneDataModule(
    X_train,
    y_train,
    X_val,
    y_val,
    sample_weight=None,
    batch_size=1024,
    shuffle_batch=True,
    shuffle_dataset=True,
    shuffle_sequences=False,
    normalize_sequences=False,
    augment_data=False,
)

# Input widths are taken from the feature configuration rather than hard-coded.
model = LightningCycloneModel(
    sequence_feature_dim=len(feature_cfg.sequence_features),
    static_feature_dim=len(feature_cfg.static_features),
    hidden_dim=128,
    learning_rate=1e-3,
    loss_fn="haversine",
    # loss_fn="sector",
)

callbacks = [
    # Stop once `val_loss` has not improved for 5 consecutive validation checks.
    EarlyStopping(monitor="val_loss", patience=5, verbose=True, mode="min"),
    # Keep only the single checkpoint with the lowest `val_loss`.
    ModelCheckpoint(
        monitor="val_loss",
        filename="best-{epoch:02d}-{val_loss:.2f}",
        save_top_k=1,
        mode="min",
    ),
    # Log the learning rate once per epoch.
    LearningRateMonitor(logging_interval="epoch"),
]

trainer = pl.Trainer(
    max_epochs=20,
    callbacks=callbacks,
    log_every_n_steps=10,
    accelerator="auto",
    devices="auto",
)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


In [12]:
import torch

# Allow float32 matrix multiplications to use faster reduced-precision kernels
# where the hardware provides them ("high" setting).
torch.set_float32_matmul_precision(precision="high")

In [13]:
# Run the training loop; batches come from the data module configured above.
print("🚀 Начало обучения...")
trainer.fit(model=model, datamodule=data)

🚀 Начало обучения...


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name      | Type          | Params | Mode 
----------------------------------------------------
0 | criterion | HaversineLoss | 0      | train
1 | net       | NNLatLon      | 181 K  | train
----------------------------------------------------
181 K     Trainable params
0         Non-trainable params
181 K     Total params
0.728     Total estimated model params size (MB)
17        Modules in train mode
0         Modules in eval mode


Epoch 0: 100%|██████████| 425/425 [00:24<00:00, 17.08it/s, v_num=15, train_loss_step=305.0, val_loss_step=307.0, val_loss_epoch=293.0, val_p300_48h=32.00, train_loss_epoch=323.0]

Metric val_loss improved. New best score: 292.888


Epoch 1: 100%|██████████| 425/425 [00:24<00:00, 17.48it/s, v_num=15, train_loss_step=281.0, val_loss_step=262.0, val_loss_epoch=285.0, val_p300_48h=32.80, train_loss_epoch=285.0, train_p300_48h=33.40]

Metric val_loss improved by 8.095 >= min_delta = 0.0. New best score: 284.793


Epoch 2: 100%|██████████| 425/425 [00:24<00:00, 17.39it/s, v_num=15, train_loss_step=283.0, val_loss_step=280.0, val_loss_epoch=283.0, val_p300_48h=34.20, train_loss_epoch=275.0, train_p300_48h=33.70]

Metric val_loss improved by 2.009 >= min_delta = 0.0. New best score: 282.784


Epoch 4: 100%|██████████| 425/425 [00:24<00:00, 17.37it/s, v_num=15, train_loss_step=305.0, val_loss_step=257.0, val_loss_epoch=277.0, val_p300_48h=34.00, train_loss_epoch=266.0, train_p300_48h=34.80]

Metric val_loss improved by 6.170 >= min_delta = 0.0. New best score: 276.614


Epoch 12: 100%|██████████| 425/425 [00:22<00:00, 18.66it/s, v_num=15, train_loss_step=240.0, val_loss_step=216.0, val_loss_epoch=298.0, val_p300_48h=30.90, train_loss_epoch=242.0, train_p300_48h=41.20]

Monitored metric val_loss did not improve in the last 8 records. Best score: 276.614. Signaling Trainer to stop.


Epoch 12: 100%|██████████| 425/425 [00:59<00:00,  7.17it/s, v_num=15, train_loss_step=240.0, val_loss_step=216.0, val_loss_epoch=298.0, val_p300_48h=30.90, train_loss_epoch=242.0, train_p300_48h=41.20]


In [15]:
# Resolve the best checkpoint from this session's ModelCheckpoint callback
# instead of a hard-coded run directory, which changes on every re-run.
# `best_model_path` is an empty string when no fit has happened in this
# session; in that case fall back to the previously recorded checkpoint.
fallback_ckpt_path = "lightning_logs/version_15/checkpoints/best-epoch=04-val_loss=276.61.ckpt"
best_ckpt_path = trainer.checkpoint_callback.best_model_path or fallback_ckpt_path

best_model = LightningCycloneModel.load_from_checkpoint(best_ckpt_path)

# Switch to inference mode; as the last expression this also displays the module.
best_model.eval()


LightningCycloneModel(
  (criterion): HaversineLoss()
  (net): NNLatLon(
    (model): SimpleGRUModel(
      (gru): GRU(9, 128, num_layers=2, batch_first=True, dropout=0.1)
      (dropout): Dropout(p=0.1, inplace=False)
      (static_head): Sequential(
        (0): Linear(in_features=5, out_features=64, bias=True)
        (1): ReLU()
        (2): Dropout(p=0.1, inplace=False)
        (3): Linear(in_features=64, out_features=64, bias=True)
        (4): ReLU()
        (5): Dropout(p=0.1, inplace=False)
      )
      (combined_head): Sequential(
        (0): Linear(in_features=192, out_features=128, bias=True)
        (1): ReLU()
        (2): Dropout(p=0.1, inplace=False)
      )
      (output_layer): Linear(in_features=128, out_features=2, bias=True)
    )
  )
)

In [16]:
def horizon_split(X, y, horizon_hours):
    """Select the samples whose forecast horizon equals ``horizon_hours``.

    Args:
        X: Feature frame containing a ``target_time_hours`` column.
        y: Target frame indexed like ``X``.
        horizon_hours: Horizon value to keep (compared with ``==``).

    Returns:
        Tuple ``(X_subset, y_subset)``, both re-indexed from 0.
    """
    horizon_mask = X["target_time_hours"].eq(horizon_hours)
    X_at_horizon = X[horizon_mask].reset_index(drop=True)
    y_at_horizon = y[horizon_mask].reset_index(drop=True)
    return X_at_horizon, y_at_horizon

# Validation subsets for the 6 h, 12 h, 24 h and 48 h horizons, in that order.
(X_val6, y_val6), (X_val12, y_val12), (X_val24, y_val24), (X_val48, y_val48) = (
    horizon_split(X_val, y_val, hours) for hours in (6, 12, 24, 48)
)

In [17]:
import pandas as pd
from evaluation.evaluator import ModelEvaluator

evaluator = ModelEvaluator()

# One (features, targets, row label) triple per validation horizon to report.
horizon_sets = [
    (X_val6, y_val6, "6 ч"),
    (X_val12, y_val12, "12 ч"),
    (X_val24, y_val24, "24 ч"),
    (X_val48, y_val48, "48 ч"),
]

# Display headers applied to the metric columns of the report.
metric_headers = {
    "samples": "Примеров",
    "mean_km": "Средняя (км)",
    "median_km": "Медиана (км)",
    "max_km": "Макс. ошибка (км)",
    "p50": "P<50 км(%)",
    "p100": "P<100 км(%)",
    "p300": "P<300 км(%)",
}

best_model.eval()

# Evaluate each horizon and tag the resulting metrics with its row label.
rows = []
for features_h, targets_h, horizon_label in horizon_sets:
    horizon_metrics = evaluator.evaluate_horizon(best_model, features_h, targets_h)
    horizon_metrics["Горизонт"] = horizon_label
    rows.append(horizon_metrics)

# One row per horizon, values rounded to one decimal place.
metrics_table = pd.DataFrame(rows).rename(columns=metric_headers)
df_metrics = metrics_table.set_index("Горизонт").round(1)

print(df_metrics.to_string())

          Примеров  Средняя (км)  Медиана (км)  Макс. ошибка (км)  P<50 км(%)  P<100 км(%)  P<300 км(%)
Горизонт                                                                                               
6 ч           1487          54.6          43.8              378.4        56.0         88.0         99.9
12 ч          1425         109.2          88.4              576.9        23.8         56.7         96.6
24 ч          1312         227.1         185.8             1272.1         6.2         22.6         76.3
48 ч          1109         494.4         405.9             2648.9         1.2          5.4         34.0


In [18]:
# Destination of the exported ONNX model; read by the export and inference cells below.
export_model_path = "weights/model.onnx"

In [19]:
# Move the model to the CPU and put it in inference mode before exporting.
best_model.to("cpu").eval()

# Write the ONNX file to `export_model_path`.
best_model.export_to_onnx(export_model_path)

✅ Модель успешно экспортирована в ONNX: weights/model.onnx
✅ ONNX модель валидирована успешно. Максимальная разница: 1.49e-08


# Inference

In [20]:
from evaluation.visualization import create_inference_pipeline

# Build the inference pipeline from the exported ONNX file path, not from the
# in-memory Lightning model.
pipeline = create_inference_pipeline(export_model_path)

In [22]:
from evaluation.visualization import plot_animated_trajectory

# First five distinct cyclone ids of the test split, in order of appearance.
unique_cyclones = X_test["intl_id"].unique()[:5]

# Animate each selected cyclone; the raw table `df` and a 48-hour horizon are
# passed to the plotting helper.
for storm_id in unique_cyclones:
    plot_animated_trajectory(pipeline, df, storm_id, horizon_hours=48)

Dataset validation: 1 rows, 7 columns
Dataset validation: 2 rows, 7 columns
Dataset validation: 3 rows, 7 columns
Dataset validation: 4 rows, 7 columns
Dataset validation: 5 rows, 7 columns
Dataset validation: 6 rows, 7 columns
Dataset validation: 7 rows, 7 columns
Dataset validation: 8 rows, 7 columns
Dataset validation: 9 rows, 7 columns
Dataset validation: 10 rows, 7 columns
Dataset validation: 11 rows, 7 columns
Dataset validation: 12 rows, 7 columns
Dataset validation: 13 rows, 7 columns
Dataset validation: 14 rows, 7 columns


Dataset validation: 1 rows, 7 columns
Dataset validation: 2 rows, 7 columns
Dataset validation: 3 rows, 7 columns
Dataset validation: 4 rows, 7 columns
Dataset validation: 5 rows, 7 columns
Dataset validation: 6 rows, 7 columns
Dataset validation: 7 rows, 7 columns
Dataset validation: 8 rows, 7 columns
Dataset validation: 9 rows, 7 columns
Dataset validation: 10 rows, 7 columns
Dataset validation: 11 rows, 7 columns
Dataset validation: 12 rows, 7 columns
Dataset validation: 13 rows, 7 columns
Dataset validation: 14 rows, 7 columns
Dataset validation: 15 rows, 7 columns
Dataset validation: 16 rows, 7 columns
Dataset validation: 17 rows, 7 columns
Dataset validation: 18 rows, 7 columns
Dataset validation: 19 rows, 7 columns
Dataset validation: 20 rows, 7 columns
Dataset validation: 21 rows, 7 columns
Dataset validation: 22 rows, 7 columns
Dataset validation: 23 rows, 7 columns
Dataset validation: 24 rows, 7 columns
Dataset validation: 25 rows, 7 columns
Dataset validation: 26 rows, 7 col

Dataset validation: 1 rows, 7 columns
Dataset validation: 2 rows, 7 columns
Dataset validation: 3 rows, 7 columns
Dataset validation: 4 rows, 7 columns
Dataset validation: 5 rows, 7 columns
Dataset validation: 6 rows, 7 columns
Dataset validation: 7 rows, 7 columns
Dataset validation: 8 rows, 7 columns
Dataset validation: 9 rows, 7 columns
Dataset validation: 10 rows, 7 columns
Dataset validation: 11 rows, 7 columns
Dataset validation: 12 rows, 7 columns
Dataset validation: 13 rows, 7 columns
Dataset validation: 14 rows, 7 columns
Dataset validation: 15 rows, 7 columns
Dataset validation: 16 rows, 7 columns
Dataset validation: 17 rows, 7 columns
Dataset validation: 18 rows, 7 columns
Dataset validation: 19 rows, 7 columns
Dataset validation: 20 rows, 7 columns
Dataset validation: 21 rows, 7 columns
Dataset validation: 22 rows, 7 columns
Dataset validation: 23 rows, 7 columns
Dataset validation: 24 rows, 7 columns
Dataset validation: 25 rows, 7 columns
Dataset validation: 26 rows, 7 col

Dataset validation: 1 rows, 7 columns
Dataset validation: 2 rows, 7 columns
Dataset validation: 3 rows, 7 columns
Dataset validation: 4 rows, 7 columns
Dataset validation: 5 rows, 7 columns
Dataset validation: 6 rows, 7 columns
Dataset validation: 7 rows, 7 columns
Dataset validation: 8 rows, 7 columns
Dataset validation: 9 rows, 7 columns
Dataset validation: 10 rows, 7 columns
Dataset validation: 11 rows, 7 columns
Dataset validation: 12 rows, 7 columns
Dataset validation: 13 rows, 7 columns
Dataset validation: 14 rows, 7 columns
Dataset validation: 15 rows, 7 columns
Dataset validation: 16 rows, 7 columns
Dataset validation: 17 rows, 7 columns
Dataset validation: 18 rows, 7 columns
Dataset validation: 19 rows, 7 columns
Dataset validation: 20 rows, 7 columns
Dataset validation: 21 rows, 7 columns
Dataset validation: 22 rows, 7 columns
Dataset validation: 23 rows, 7 columns


Dataset validation: 1 rows, 7 columns
Dataset validation: 2 rows, 7 columns
Dataset validation: 3 rows, 7 columns
Dataset validation: 4 rows, 7 columns
Dataset validation: 5 rows, 7 columns
Dataset validation: 6 rows, 7 columns
Dataset validation: 7 rows, 7 columns
Dataset validation: 8 rows, 7 columns
Dataset validation: 9 rows, 7 columns
Dataset validation: 10 rows, 7 columns
Dataset validation: 11 rows, 7 columns
Dataset validation: 12 rows, 7 columns
Dataset validation: 13 rows, 7 columns
Dataset validation: 14 rows, 7 columns
Dataset validation: 15 rows, 7 columns
Dataset validation: 16 rows, 7 columns
Dataset validation: 17 rows, 7 columns
Dataset validation: 18 rows, 7 columns
Dataset validation: 19 rows, 7 columns
Dataset validation: 20 rows, 7 columns
Dataset validation: 21 rows, 7 columns
Dataset validation: 22 rows, 7 columns
Dataset validation: 23 rows, 7 columns
Dataset validation: 24 rows, 7 columns
Dataset validation: 25 rows, 7 columns
Dataset validation: 26 rows, 7 col

In [16]:
# NOTE(review): disabled cell kept for reference — per-cyclone trajectory plots
# for the 48-hour horizon on the validation split. It passes `model` rather than
# `best_model`; confirm which is intended before re-enabling, or delete the cell.
# from evaluation.visualization import plot_enhanced_trajectory

# unique_cyclones = X_val["intl_id"].unique()[:5]  # first 5 cyclones
# for cyclone_id in unique_cyclones:
#     print(f"Визуализация циклона {cyclone_id}")
#     idx = X_val["target_time_hours"] == 48.0
#     plot_enhanced_trajectory(model, X_val[idx], y_val[idx], cyclone_id)