Self-supervised and supervised representation-learning experiments for
polymathic-ai/active_matter.
The primary model family in the current repo is Field-Aware ViT-JEPA
(FA-JEPA) in src/models/jepa_vit/, trained by
src/training/train_jepa_vit.py. The repo also keeps H-JEPA and a
supervised Field-Aware ViT baseline for comparison.
From the repo root:
python -m pip install -e .
For benchmark notebooks and plotting utilities, install the benchmark extras or
use the checked-in .venv when available:
python -m pip install -e ".[benchmarks]"
H-JEPA uses Lightning, which is in the optional hjepa extra:
python -m pip install -e ".[hjepa]"
The active-matter loader is implemented in src/data_pipeline/dataset.py.
It expects data_root / "data" / {train,valid,test}.
Two layouts are common in this working tree:
data/
  train/
  valid/
  test/
Use data_root: . or --data-root /home/romina/DL_project for that layout.
data/active_matter/
  active_matter.yaml
  stats.yaml
  data/
    train/
    valid/
    test/
Use data_root: data/active_matter for the Well-style layout. The default
FA-JEPA config currently uses this second layout.
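The two layouts differ only in where `data_root` points; in both cases the loader looks for `data_root / "data" / {train,valid,test}`. A minimal pre-flight check can be sketched as follows. It assumes nothing about the loader beyond that path rule, and `missing_splits` is a hypothetical helper, not part of `src/data_pipeline/`:

```python
from pathlib import Path

SPLITS = ("train", "valid", "test")


def missing_splits(data_root):
    """Return the split directories that are absent under data_root / "data"."""
    base = Path(data_root) / "data"
    return [name for name in SPLITS if not (base / name).is_dir()]


if __name__ == "__main__":
    # Check both layouts described above, relative to the repo root.
    for root in (".", "data/active_matter"):
        print(root, "missing:", missing_splits(root))
```

An empty list means the chosen `data_root` matches the layout the loader expects.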
Download helpers live in:
python -m data_pipeline.download_data
Periodic shift is only physically valid on the untouched 256x256 periodic
domain before resize. The FA-JEPA default leaves it off.
Do not modify the active-matter dataloader for FA-JEPA work unless that is the explicit task.
ActiveMatterPairDataset is the FA-JEPA/H-JEPA training contract:
- context: (B, 11, T, H, W)
- target: (B, 11, T, H, W)
- alpha: (B,)
- zeta: (B,)
- physical_params: (B, 2), ordered as [zeta, alpha]
ActiveMatterWindowDataset is used for DISCO and frozen-encoder evaluation:
- x: (B, 11, T, H, W)
Channel order from assemble_channels(...):
- concentration
- velocity_x
- velocity_y
- D_xx
- D_xy
- D_yx
- D_yy
- E_xx
- E_xy
- E_yx
- E_yy
FA-JEPA field groups:
- concentration: 0:1
- velocity: 1:3
- orientation: 3:7
- strain-rate: 7:11
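The field groups are half-open channel ranges over the 11-channel order listed above. A small sketch of that mapping is shown here; the names `CHANNELS`, `FIELD_GROUPS`, and `channels_by_field` are illustrative and are not taken from the repo:

```python
# Channel order as listed above for assemble_channels(...).
CHANNELS = [
    "concentration",
    "velocity_x", "velocity_y",
    "D_xx", "D_xy", "D_yx", "D_yy",
    "E_xx", "E_xy", "E_yx", "E_yy",
]

# Half-open channel ranges, matching the FA-JEPA field groups.
FIELD_GROUPS = {
    "concentration": slice(0, 1),
    "velocity": slice(1, 3),
    "orientation": slice(3, 7),
    "strain-rate": slice(7, 11),
}


def channels_by_field():
    """Map each field group to the channel names it covers."""
    return {name: CHANNELS[sl] for name, sl in FIELD_GROUPS.items()}


groups = channels_by_field()
# Every channel should land in exactly one group, in the original order.
assert sum(groups.values(), []) == CHANNELS
```

The closing assertion checks that the four groups partition all 11 channels without gaps or overlap.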
A separate set of single-encoder JEPA baselines lives in
szcharlesji/physical-representation-learning, which is also included as a
submodule in this repo. These do not use field-aware factorisation: every
channel goes through a single backbone. They exist to isolate two design
questions for FA-JEPA: does the regulariser converge, and does the tokeniser
need to span time?
| Config | Encoder | Regulariser | Best linear MSE | Notes |
|---|---|---|---|---|
| VICReg, bs=2 (fp32) | ConvEncoder | VICReg | 0.308 | Small batch — collapses past epoch 11 |
| VICReg, bs=8 (bf16) | ConvEncoder | VICReg | 0.243 | bf16 baseline; recovers from the bs=2 collapse |
| VICReg + FFT | ConvEncoder | VICReg | 0.220 | Band-limited FFT resize replaces the per-dataset crop; preserves periodic BCs |
| SIGReg, FFT | ConvEncoder | SIGReg | 0.631 | Negative control — did not converge, MSE rises with training |
| Conv+Attn, FFT | Conv + 1 transformer block at stage 4 | VICReg | 0.251 | Single attention block over the post-conv spatial grid; barely moves the FFT-only baseline |
| Conv+Attn ×6, FFT$^{\star}$ | Conv + 6 transformer blocks at stage 4 | VICReg | 0.267 | Steeper descent than the single-block run, but cut at epoch 7 by HPC quota — the curve is still monotonically descending |
| ViT3D-d6, FFT | 3D PatchEmbed (4×16×16) + 6 transformer blocks | VICReg | 0.107 | At least 2× lower MSE than every CNN variant (best CNN: 0.220); the change is the tokeniser, not the depth |
$^{\star}$ Conv+Attn ×6 was terminated early; its row is the best of four checkpoints rather than a converged run.
Configs (in the sister repo):
configs/train_activematter.yaml # VICReg conv baseline (bs=8 bf16)
configs/train_activematter_fft.yaml # + FFT preprocessing
configs/train_activematter_sigreg.yaml # SIGReg negative control
configs/train_activematter_cnn_attn.yaml # Conv + 1 attention block
configs/train_activematter_cnn_attn_d6.yaml # Conv + 6 attention blocks
configs/train_activematter_vit3d.yaml # ViT3D-d6 (winning config)
The released ViT3D-d6 checkpoint at pretrain epoch 29 is on Hugging Face
(≈15.4 M-param encoder).
FA-JEPA lives in:
src/models/jepa_vit/
src/training/train_jepa_vit.py
configs/jepa_vit_default.yaml
Default architecture:
- context frames: 16
- target frames: 16
- tubelet size: (2, 16, 16)
- embed dim: 384
- heads: 6
- encoder depth: 8
- predictor depth: 6
- encoder dropout: 0.0
- predictor dropout: 0.1
- physics-prior attention: off by default
- resize target: crop_size: 224 with resize_mode: physics_faithful
For a 16 x 256 x 256 clip, each field produces an 8 x 16 x 16
tubelet-token grid. The encoder keeps structured tokens shaped like:
(B, T_tok, H_tok, W_tok, 4, D)
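The token-grid arithmetic can be checked with a short sketch. It assumes non-overlapping tubelets (stride equal to the tubelet size) and, for the second call, that the default `crop_size: 224` yields a square 224 x 224 input; `token_grid` is a hypothetical helper, not a function from `src/models/jepa_vit/`:

```python
def token_grid(frames, height, width, tubelet=(2, 16, 16)):
    """Number of tubelet tokens along (time, height, width) for one field."""
    dims = (frames, height, width)
    for size, step in zip(dims, tubelet):
        if size % step != 0:
            raise ValueError(f"{size} is not divisible by tubelet step {step}")
    return tuple(size // step for size, step in zip(dims, tubelet))


print(token_grid(16, 256, 256))  # (8, 16, 16)
print(token_grid(16, 224, 224))  # (8, 14, 14)
```

Under those assumptions, a clip resized to 224 x 224 gives a 14 x 14 spatial grid per field rather than 16 x 16.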
Each factorized encoder block applies:
- field attention over the four field tokens at each tubelet
- spatial attention over patch positions for each fixed time and field
- temporal attention over tubelet-time positions for each fixed patch and field
- token-wise MLP
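To see why the factorization matters, the following sketch counts how many independent sequences each attention type runs and how long each sequence is, given tokens shaped `(B, T_tok, H_tok, W_tok, 4, D)`. It only does shape bookkeeping; the function names are illustrative and the actual reshapes in the encoder may differ:

```python
def attention_groups(B, T, H, W, F=4):
    """For each factorized attention: (independent sequences, sequence length)."""
    return {
        "field": (B * T * H * W, F),     # attend across the F field tokens
        "spatial": (B * T * F, H * W),   # attend across patch positions
        "temporal": (B * H * W * F, T),  # attend across tubelet-time positions
    }


def pairwise_scores(groups):
    """Query-key pairs per attention type: sequences * length ** 2."""
    return {name: n * length * length for name, (n, length) in groups.items()}


def joint_pairs(B, T, H, W, F=4):
    """Query-key pairs if all T*H*W*F tokens attended to each other jointly."""
    return B * (T * H * W * F) ** 2


groups = attention_groups(B=2, T=8, H=16, W=16)
factorized_total = sum(pairwise_scores(groups).values())
print(groups)
print(factorized_total, "vs joint:", joint_pairs(2, 8, 16, 16))
```

For the 8 x 16 x 16 grid, the three factorized attentions together score far fewer query-key pairs than a single joint attention over all tokens would.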
The predictor is a standard ViT stack over context_tokens + target_queries.
MultiHeadAttention uses torch.nn.functional.scaled_dot_product_attention,
which is load-bearing for memory at realistic batch sizes.
Training objective:
- predictor MSE on target latents
- plus sigreg_weight * pooled_sigreg, where pooled_sigreg = 0.5 * (SIGReg(context_time) + SIGReg(target_time))
The target branch is the same encoder. Only target_latents for predictor MSE
are detached; target-view SIGReg still backprops through the shared encoder.
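The way the two terms combine can be written out as a small sketch. Plain floats stand in for the loss tensors the trainer would compute, `fa_jepa_loss` is a hypothetical name, and the `sigreg_weight` value in the example call is arbitrary rather than the repo default:

```python
def fa_jepa_loss(predictor_mse, sigreg_context, sigreg_target, sigreg_weight):
    """Combine the predictor MSE with the pooled SIGReg term.

    Returns (total_loss, pooled_sigreg).
    """
    # Average the SIGReg values of the context and target views.
    pooled_sigreg = 0.5 * (sigreg_context + sigreg_target)
    total = predictor_mse + sigreg_weight * pooled_sigreg
    return total, pooled_sigreg


total, pooled = fa_jepa_loss(
    predictor_mse=0.5, sigreg_context=0.2, sigreg_target=0.4, sigreg_weight=0.1
)
print(total, pooled)
```

In the real trainer, only the target latents that feed the MSE term would be detached; both SIGReg terms keep their gradient path to the shared encoder, as described above.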
Run FA-JEPA:
python -m training.train_jepa_vit --config configs/jepa_vit_default.yaml
Enable W&B:
python -m training.train_jepa_vit \
--config configs/jepa_vit_default.yaml \
--wandb
Checkpoints default to:
checkpoints/jepa_vit/
  latest.pt
  best_model_rank_01.pt
  ...
  epoch_0001.pt
When validation is enabled, best checkpoints are ranked by val_loss. If
validation is disabled, they fall back to training loss.
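The ranking rule can be illustrated with a short sketch. It assumes lower loss is better and that rank files follow the `best_model_rank_NN.pt` pattern shown above; `rank_checkpoints`, the `history` record layout, and the `top_k` parameter are hypothetical and do not describe the trainer's actual bookkeeping:

```python
def rank_checkpoints(history, top_k=3, validation_enabled=True):
    """Map rank file names to epochs, best (lowest loss) first."""
    # Rank by val_loss when validation runs, otherwise by training loss.
    metric = "val_loss" if validation_enabled else "train_loss"
    ranked = sorted(history, key=lambda rec: rec[metric])[:top_k]
    return {
        f"best_model_rank_{i:02d}.pt": rec["epoch"]
        for i, rec in enumerate(ranked, start=1)
    }


history = [
    {"epoch": 1, "train_loss": 0.9, "val_loss": 0.5},
    {"epoch": 2, "train_loss": 0.7, "val_loss": 0.6},
    {"epoch": 3, "train_loss": 0.8, "val_loss": 0.4},
]
print(rank_checkpoints(history, top_k=2))
print(rank_checkpoints(history, top_k=2, validation_enabled=False))
```

Note that the two modes can rank the same epochs differently, so rank files from runs with and without validation are not directly comparable.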
See HJEPA.md for details.
H-JEPA lives in:
src/models/hjepa/
src/training/train_hjepa.py
src/training/hjepa_module.py
configs/hjepa/
This is the Lightning-based hierarchical JEPA comparison path. It uses the pair dataloader and supports VICReg-style pair loss, optional representation SIGReg, EMA targets, and split encoder stages.
Default H-JEPA run:
python -m training.train_hjepa --config configs/hjepa/ablation_base.yaml
Useful H-JEPA configs:
configs/hjepa/ablation_base.yaml
H-JEPA checkpoints default to the config's Lightning checkpoint directory, for
example checkpoints/hjepa_activematter.
The supervised baseline is the direct regression comparison method:
src/supervised_baseline/
src/training/train_supervised_baseline.py
configs/supervised_baseline_default.yaml
It uses a Field-Aware ViT-style encoder and directly predicts (alpha, zeta)
with supervised regression, rather than learning through JEPA latent prediction.
Run it with:
python -m training.train_supervised_baseline \
--config configs/supervised_baseline_default.yaml
Default checkpoints:
checkpoints/supervised_baseline/
Current supervised baseline benchmark artifacts are stored in:
benchmark_results/results_supervised_baseline/
  eval_results.csv
  eval_results.png
  eval_results.svg
Frozen-feature evaluation utilities live in:
src/evaluation/features.py
src/evaluation/probe_utils.py
src/evaluation/probes.py
src/evaluation/utils.py
src/evaluation/analyze_jepa_vit_field_attention.py
The intended probe protocol is regression only:
- frozen encoder
- one linear probe
- inverse-distance-weighted kNN
- z-score-normalized (alpha, zeta)
- no backbone finetuning, MLP probe, classification reformulation, or attention pooling
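The kNN half of the protocol can be sketched in pure Python. This is an illustration of inverse-distance-weighted kNN regression on z-scored targets, not the code in `src/evaluation/probes.py`; the function names, the Euclidean metric, `k`, and `eps` are all assumptions:

```python
import math


def zscore_fit(targets):
    """Per-dimension mean and (population) std of a list of target tuples."""
    n = len(targets)
    dims = range(len(targets[0]))
    mean = [sum(t[d] for t in targets) / n for d in dims]
    # Fall back to 1.0 when a dimension is constant, to avoid dividing by zero.
    std = [
        math.sqrt(sum((t[d] - mean[d]) ** 2 for t in targets) / n) or 1.0
        for d in dims
    ]
    return mean, std


def zscore_apply(targets, mean, std):
    """Normalize targets with statistics fitted on the training split."""
    return [
        tuple((t[d] - mean[d]) / std[d] for d in range(len(mean)))
        for t in targets
    ]


def knn_predict(train_x, train_y, query, k=3, eps=1e-8):
    """Inverse-distance-weighted kNN regression for one query feature vector."""
    # Keep the k nearest training points by Euclidean distance.
    nearest = sorted(
        (math.dist(x, query), y) for x, y in zip(train_x, train_y)
    )[:k]
    weights = [1.0 / (dist + eps) for dist, _ in nearest]
    total = sum(weights)
    dims = range(len(train_y[0]))
    return tuple(
        sum(w * y[d] for w, (_, y) in zip(weights, nearest)) / total
        for d in dims
    )
```

In use, the frozen encoder's features would play the role of `train_x` and `query`, and the targets passed as `train_y` would be the z-scored `(alpha, zeta)` pairs.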
Note: in the current checkout, src/evaluation/run_probe.py is not present,
while src/evaluation/__init__.py still imports it. That means
python -m evaluation.run_probe ... is not currently a valid command until the
entry point is restored. The lower-level feature/probe modules and benchmark
notebooks are present.
Attention-analysis aggregation is available as:
python -m evaluation.analyze_jepa_vit_field_attention \
--checkpoint checkpoints/jepa_vit_running/physics_b128/best_model_rank_03.pt \
--data-root /home/romina/DL_project \
--split valid \
--device cuda \
--batch-size 4 \
--max-samples 256 \
--output-dir benchmark_results/FA_JEPA/attention_results/physics_b128_best_model_rank_03_valid
Current FA-JEPA result notebooks, plots, probe outputs, and attention-analysis artifacts are under:
benchmark_results/FA_JEPA/
Important FA-JEPA notebooks:
- benchmark_results/FA_JEPA/all_results_f.ipynb
- benchmark_results/FA_JEPA/all_results_f_pooling_per_field.ipynb
- benchmark_results/FA_JEPA/jepa_vit_attention_analysis.ipynb
Important FA-JEPA result folders:
- benchmark_results/FA_JEPA/probe_results/
- benchmark_results/FA_JEPA/probe_logs/
- benchmark_results/FA_JEPA/probe_logs_f/
- benchmark_results/FA_JEPA/summary_tables/
- benchmark_results/FA_JEPA/attention_results/
Summary tables currently live at:
benchmark_results/FA_JEPA/summary_tables/
  all_probe_metrics.csv
  best_by_valid_probe_rows.csv
  paper_probe_table.csv
Generated comparison figures include:
benchmark_results/FA_JEPA/linear_metrics_physics_vs_data_only.png
benchmark_results/FA_JEPA/knn_metrics_physics_vs_data_only.png
benchmark_results/FA_JEPA/test_linear_metrics_physics_vs_data_only.png
benchmark_results/FA_JEPA/test_knn_metrics_physics_vs_data_only.png
benchmark_results/FA_JEPA/attention_analysis_data_vs_phy.png
benchmark_results/FA_JEPA/attention_analysis_data_prior_effective_minus_data.png
benchmark_results/FA_JEPA/rel_blended_attention.png
Collapse diagnostics are load-bearing for FA-JEPA. The trainer logs:
- embed/variance_mean
- embed/covariance_off_diag_mean
- embed/norm_mean
- embed/context_token/*
- embed/target_token/*
- embed/context_pooled/*
- embed/target_pooled/*
- monitor/context_token_to_pooled_variance_ratio
- monitor/target_token_to_pooled_variance_ratio
- monitor/predictor_loss_rolling_mean
- monitor/predictor_loss_rolling_std
Warnings trigger when:
- embed/variance_mean < 0.1
- embed/variance_mean decreases for three checks in a row
If either warning appears during model/loss changes, investigate before treating the run as healthy.
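The two warning conditions can be expressed as a small check over the logged history of `embed/variance_mean`. This sketch reads "three checks in a row" as three consecutive strict decreases, which is an assumption, and `collapse_warnings` is a hypothetical helper rather than the trainer's own code:

```python
def collapse_warnings(variance_history, floor=0.1, patience=3):
    """Return warning strings for a history of embed/variance_mean values."""
    warnings = []
    # Condition 1: the latest value is below the floor.
    if variance_history and variance_history[-1] < floor:
        warnings.append(f"embed/variance_mean below {floor}")
    # Condition 2: `patience` consecutive decreases need patience + 1 values.
    recent = variance_history[-(patience + 1):]
    if len(recent) == patience + 1 and all(
        later < earlier for earlier, later in zip(recent, recent[1:])
    ):
        warnings.append(
            f"embed/variance_mean decreased for {patience} checks in a row"
        )
    return warnings


print(collapse_warnings([1.0, 0.9, 0.8, 0.7]))
print(collapse_warnings([0.5, 0.6, 0.05]))
```

Either condition alone is enough to produce a warning, and both can fire on the same check.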
Run the full suite:
pytest tests/ -v
Focused FA-JEPA tests:
pytest tests/test_jepa_vit_patch_embed.py -v
pytest tests/test_jepa_vit_attention.py -v
pytest tests/test_jepa_vit_encoder.py -v
pytest tests/test_jepa_vit_predictor.py -v
pytest tests/test_jepa_vit_model.py -v
pytest tests/test_sigreg.py -v
pytest tests/test_jepa_vit_training.py -v
pytest tests/test_train_jepa_vit.py -v
pytest tests/test_jepa_vit_smoke.py -v
H-JEPA and baseline coverage:
pytest tests/test_hjepa_model.py tests/test_hjepa_training.py tests/test_hjepa_eval.py -v
pytest tests/test_supervised_baseline.py tests/test_eval_supervised.py -v
- src/data_pipeline/: HDF5 loading, augmentation, resize, download helpers
- src/models/jepa_vit/: FA-JEPA modules
- src/models/hjepa/: H-JEPA modules
- src/models/common/: shared losses and diagnostics, including SIGReg
- src/models/baseline_jepa/: older baseline JEPA ablation
- src/models/disco/: DISCO next-step predictor
- src/models/drift_jepa/: drift-JEPA experiments
- src/supervised_baseline/: supervised Field-Aware ViT baseline
- src/training/: training entry points
- src/evaluation/: feature extraction, probes, and analysis utilities
- configs/: FA-JEPA, H-JEPA, benchmark, and baseline configs
- benchmark_results/: current paper/benchmark notebooks and generated results
- tests/: module, training, probe, and smoke tests
- docs/references.md: condensed local reference notes
External-reference understanding should be updated in docs/references.md and
then treated as the local source of truth. For normal implementation work, do
not re-fetch external references.