# Train & 5-Fold Evaluation for Beat This!

Run the cells in this Jupyter notebook to perform 5-fold cross-validation training and evaluation of the Beat This! beat and downbeat estimator. All hyperparameters follow the ISMIR 2024 setting of the original implementation. The first training epoch takes longer (~2 min) because the dataloader has to be initialized; from the second epoch on, training runs at ~30 s per epoch.
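
Two quantities implied by the configuration printed in the training log (the `Namespace(...)` line that `train.py` prints further down) are easy to miss. The sketch below assumes that `train_length` counts spectrogram frames, as in the upstream Beat This! code, and that training runs on a single GPU.

```python
# Values copied from the Namespace printed by train.py.
batch_size = 8
accumulate_grad_batches = 8
train_length_frames = 1500
fps = 50

# With gradient accumulation on one GPU, each optimizer step
# sees batch_size * accumulate_grad_batches training excerpts.
effective_batch = batch_size * accumulate_grad_batches

# Assumption: train_length is in frames, so seconds = frames / fps.
excerpt_seconds = train_length_frames / fps

print(effective_batch, excerpt_seconds)  # 64 30.0
```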

**Acknowledgements:**
This notebook builds on and extends the open-source Beat This! repository and paper:
- **F. Foscarin, J. Schlüter, G. Widmer (2024)**, *Beat This! Accurate Beat Tracking Without DBN Postprocessing*, *Proc. ISMIR*, pp. 962–969. [Paper](https://arxiv.org/pdf/2407.21658) | [GitHub Repo](https://github.com/CPJKU/beat_this)

**Modifications:**
1. We created `train.py`, `test.py`, and `dataset_h5.py` to train and evaluate on the MazurkaBL dataset, using preprocessed HDF5 (`.h5`) files.
2. We added a Bark-sone feature extractor class to `preprocessing.py`.
3. All other code comes unmodified from the [Beat This! repo](https://github.com/CPJKU/beat_this), except for import-path changes required by the reorganized code structure.
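
The Bark-sone class itself is not shown in this notebook. As a rough illustration of what such a feature computes, here is a simplified, hypothetical sketch (not the class in `preprocessing.py`): FFT bins are grouped into critical bands with the Zwicker-Terhardt Bark approximation, and band energies are converted to loudness with the phon-to-sone mapping commonly used in Pampalk's MA toolbox. It loudly simplifies by treating band energy in dB directly as phon, so equal-loudness contours and masking are ignored; `hz_to_bark`, `phon_to_sone`, and `bark_sone_frame` are illustrative names.

```python
import math


def hz_to_bark(f_hz: float) -> float:
    """Zwicker & Terhardt approximation of the Bark scale."""
    return 13.0 * math.atan(0.00076 * f_hz) + 3.5 * math.atan((f_hz / 7500.0) ** 2)


def phon_to_sone(phon: float) -> float:
    """Loudness level (phon) to loudness (sone), MA-toolbox style."""
    if phon >= 40.0:
        return 2.0 ** ((phon - 40.0) / 10.0)
    return (phon / 40.0) ** 2.642


def bark_sone_frame(power_spectrum, sample_rate, n_bands=24, eps=1e-10):
    """Map one power-spectrum frame (n_fft // 2 + 1 bins) to n_bands sone values.

    Hypothetical simplification: dB values are used directly as phon.
    """
    n_bins = len(power_spectrum)
    n_fft = 2 * (n_bins - 1)
    band_power = [0.0] * n_bands
    for k, p in enumerate(power_spectrum):
        f_hz = k * sample_rate / n_fft
        # Bins above the last band are folded into it.
        band = min(int(hz_to_bark(f_hz)), n_bands - 1)
        band_power[band] += p
    sones = []
    for p in band_power:
        db = 10.0 * math.log10(p + eps)
        # Clamp at 0: a negative base raised to 2.642 would be complex.
        phon = max(db, 0.0)
        sones.append(phon_to_sone(phon))
    return sones
```

At 22050 Hz the Nyquist frequency maps to roughly 22.9 Bark, so the top band of a 24-band layout stays empty in this sketch.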

## 5-Fold Cross-validation Results of Beat This! on MazurkaBL Dataset

### Best Validation Loss Checkpoints (reported in the paper)
| Fold | Checkpoint | Beat F1 | Downbeat F1 | CSV File |
|------|------------|---------|-------------|----------|
| 0    | epoch09-valloss1.2083 | **0.7648** | **0.4244** | `logs/summary_mazurka_h5_22050_S86_F0.csv` |
| 1    | epoch24-valloss0.8547 | **0.8344** | 0.6049 | `logs/summary_mazurka_h5_22050_S86_F1.csv` |
| 2    | epoch14-valloss1.1175 | 0.8083 | 0.5318 | `logs/summary_mazurka_h5_22050_S86_F2.csv` |
| 3    | epoch29-valloss1.5129 | 0.8321 | 0.5721 | `logs/summary_mazurka_h5_22050_S86_F3.csv` |
| 4    | epoch14-valloss1.2587 | 0.7874 | 0.5057 | `logs/summary_mazurka_h5_22050_S86_F4.csv` |
| **Avg ± Std** | (best val ckpt) | **0.805 ± 0.027** | **0.528 ± 0.064** | |

The model has about 20.3 M parameters (20,251,696 trainable out of 20,251,712 total, as printed by `test.py`). In the per-fold rows of both tables, bold marks the better of the two checkpoints for that fold.
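
The `epochNN-vallossX.XXXX` part of each checkpoint name records the epoch and validation loss at which it was saved. The selection logic of `test.py` is not shown here; as a hypothetical sketch, the lowest-validation-loss checkpoint of a fold directory can be recovered by parsing those names (`best_val_checkpoint` is an illustrative helper, not part of the repo).

```python
import re
from pathlib import Path

# Matches names like "mazurka_h5_22050_S86_F0_epoch09-valloss1.2083.ckpt".
CKPT_PATTERN = re.compile(r"epoch(\d+)-valloss(\d+(?:\.\d+)?)\.ckpt$")


def best_val_checkpoint(ckpt_paths):
    """Return (path, epoch, val_loss) for the lowest-val-loss checkpoint, or None.

    Files that do not match the pattern (e.g. "last.ckpt") are skipped.
    """
    best = None
    for path in ckpt_paths:
        match = CKPT_PATTERN.search(Path(path).name)
        if match is None:
            continue
        epoch, val_loss = int(match.group(1)), float(match.group(2))
        if best is None or val_loss < best[2]:
            best = (str(path), epoch, val_loss)
    return best


# Usage: best_val_checkpoint(Path("checkpoints/mazurka_h5_22050_S86_F0").glob("*.ckpt"))
```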

### Last Checkpoints
| Fold | Checkpoint | Beat F1 | Downbeat F1 | CSV File |
|------|------------|---------|-------------|----------|
| 0    | last | 0.6232 | 0.2832 | `logs/summary_mazurka_h5_22050_S86_F0.csv` |
| 1    | last | 0.8201 | **0.6235** | `logs/summary_mazurka_h5_22050_S86_F1.csv` |
| 2    | last | **0.8115** | **0.5866** | `logs/summary_mazurka_h5_22050_S86_F2.csv` |
| 3    | last | **0.8347** | **0.6075** | `logs/summary_mazurka_h5_22050_S86_F3.csv` |
| 4    | last | **0.8145** | **0.5847** | `logs/summary_mazurka_h5_22050_S86_F4.csv` |
| **Avg ± Std** | (last ckpt) | **0.781 ± 0.079** | **0.537 ± 0.129** | |
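
The Avg ± Std rows summarize the five per-fold scores. A minimal sketch for recomputing them from the values in the tables, assuming the population standard deviation (ddof=0), which reproduces the beat F1 rows. The per-fold values here are rounded to four decimals, so a recomputed figure may not match the table in its last digit; the summary CSVs remain the reference.

```python
from statistics import mean, pstdev

# Per-fold F1 scores copied from the two tables above (folds 0-4).
RESULTS = {
    "best val ckpt": {
        "beat": [0.7648, 0.8344, 0.8083, 0.8321, 0.7874],
        "downbeat": [0.4244, 0.6049, 0.5318, 0.5721, 0.5057],
    },
    "last ckpt": {
        "beat": [0.6232, 0.8201, 0.8115, 0.8347, 0.8145],
        "downbeat": [0.2832, 0.6235, 0.5866, 0.6075, 0.5847],
    },
}


def summarize(scores):
    """Mean and population standard deviation (ddof=0) across folds."""
    return mean(scores), pstdev(scores)


for ckpt, metrics in RESULTS.items():
    for metric, scores in metrics.items():
        m, s = summarize(scores)
        print(f"{ckpt:>14} {metric:>8} F1: {m:.3f} ± {s:.3f}")
```

Swap `pstdev` for `statistics.stdev` to get the sample standard deviation (ddof=1) instead.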

```python
h5_root = "/media/datadisk/home/22828187/zhanh/202505_dynest_data/workspaces/hdf5s/mazurka_sr22050"
csv_root = "/media/datadisk/home/22828187/zhanh/202505_dynest_data/workspaces"

# Train one model per fold; IPython's `!` runs a shell command and substitutes {var}
for f in range(5):
    csv_path = f"{csv_root}/split_5fold_fold{f}_seed86.csv"
    !python beat_this/train.py --sample-rate 22050 --fps 50 --seed 86 \
        --h5-root {h5_root} \
        --csv-split {csv_path} \
        --fold {f} --max-epochs 100
```

Seed set to 86
Namespace(h5_root='/media/datadisk/home/22828187/zhanh/202505_dynest_data/workspaces/hdf5s/mazurka_sr22050', sample_rate=22050, fps=50, max_epochs=100, fold=0, seed=86, csv_split='/media/datadisk/home/22828187/zhanh/202505_dynest_data/workspaces/split_5fold_fold0_seed86.csv', val_ratio=0.1, name='mazurka_h5_22050', gpu=0, force_flash_attention=False, compile=['frontend', 'transformer_blocks', 'task_heads'], n_layers=6, transformer_dim=512, frontend_dropout=0.1, transformer_dropout=0.2, lr=0.001, weight_decay=0.01, logger='none', num_workers=8, n_heads=16, loss='shift_tolerant_weighted_bce', warmup_steps=1000, batch_size=8, accumulate_grad_batches=8, train_length=1500, dbn=False, eval_trim_beats=5.0, val_frequency=5, cache_dir='data/_h5_mel_cache', tempo_augmentation=False, pitch_augmentation=False, mask_augmentation=False, sum_head=True, partial_transformers=True, length_based_oversampling_factor=0.65, val=True, hung_data=False, resume_checkpoint=None, resume_id=None)
Tr
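
Outside Jupyter, the `!python` loop above can be reproduced with the standard-library `subprocess` module. This is a minimal sketch that reuses exactly the flags from the training cell; `train_command` and `run_all_folds` are hypothetical helper names, and `dry_run=True` only prints the commands. By default IPython's `!` does not raise when the command fails, whereas `check=True` here stops at the first failing fold.

```python
import subprocess
import sys

# Paths from the training cell above; adjust to your own machine.
WORKSPACE = "/media/datadisk/home/22828187/zhanh/202505_dynest_data/workspaces"
H5_ROOT = f"{WORKSPACE}/hdf5s/mazurka_sr22050"
SEED = 86


def train_command(fold: int) -> list:
    """Build the train.py argument list for one fold (same flags as the cell)."""
    return [
        sys.executable, "beat_this/train.py",
        "--sample-rate", "22050", "--fps", "50", "--seed", str(SEED),
        "--h5-root", H5_ROOT,
        "--csv-split", f"{WORKSPACE}/split_5fold_fold{fold}_seed{SEED}.csv",
        "--fold", str(fold), "--max-epochs", "100",
    ]


def run_all_folds(n_folds: int = 5, dry_run: bool = True) -> None:
    for fold in range(n_folds):
        cmd = train_command(fold)
        if dry_run:
            print(" ".join(cmd))
        else:
            # check=True raises CalledProcessError if a fold fails.
            subprocess.run(cmd, check=True)


if __name__ == "__main__":
    run_all_folds(dry_run=True)
```

Using `sys.executable` instead of a bare `python` ensures the folds run under the same interpreter (and virtual environment) as the driver script.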

```python
h5_root = "/media/datadisk/home/22828187/zhanh/202505_dynest_data/workspaces/hdf5s/mazurka_sr22050"
csv_root = "/media/datadisk/home/22828187/zhanh/202505_dynest_data/workspaces"

# Evaluate the selected checkpoints of each fold on that fold's test split
for f in range(5):
    csv_path = f"{csv_root}/split_5fold_fold{f}_seed86.csv"
    ckpt_dir = f"checkpoints/mazurka_h5_22050_S86_F{f}"
    !python beat_this/test.py --h5-root {h5_root} --csv-split {csv_path} --fold {f} \
        --seed 86 --name mazurka_h5_22050 --datasplit test \
        --ckpt-dir {ckpt_dir}
```

Seed set to 86
Selected 2 checkpoint(s):
  - checkpoints/mazurka_h5_22050_S86_F0/mazurka_h5_22050_S86_F0_epoch09-valloss1.2083.ckpt
  - checkpoints/mazurka_h5_22050_S86_F0/last.ckpt

==== Evaluating checkpoints/mazurka_h5_22050_S86_F0/mazurka_h5_22050_S86_F0_epoch09-valloss1.2083.ckpt ====
Parameters (BeatThis core): trainable=20,251,696 total=20,251,712
Averaged metrics:
  F-measure_beat: 0.7648
  Cemgil_beat: 0.6174
  CMLt_beat: 0.4459
  AMLt_beat: 0.4459
  F-measure_downbeat: 0.4244
  Cemgil_downbeat: 0.3407
  CMLt_downbeat: 0.0564
  AMLt_downbeat: 0.0720

==== Evaluating checkpoints/mazurka_h5_22050_S86_F0/last.ckpt ====
Parameters (BeatThis core): trainable=20,251,696 total=20,251,712
Averaged metrics:
  F-measure_beat: 0.6232
  Cemgil_beat: 0.5296
  CMLt_beat: 0.1636
  AMLt_beat: 0.1814
  F-measure_downbeat: 0.2832
  Cemgil_downbeat: 0.2553
  CMLt_downbeat: 0.0039
  AMLt_downbeat: 0.0299

Saved summary CSV: logs/summary_mazurka_h5_22050_S86_F0.csv
Seed set to 86
Selected 2 checkp