rominay/DL_project

DL_project

Self-supervised and supervised representation-learning experiments for polymathic-ai/active_matter.

The primary model family in the current repo is Field-Aware ViT-JEPA (FA-JEPA) in src/models/jepa_vit/, trained by src/training/train_jepa_vit.py. The repo also keeps H-JEPA and a supervised Field-Aware ViT baseline for comparison.

Install

From the repo root:

python -m pip install -e .

For benchmark notebooks and plotting utilities, install the benchmark extras or use the checked-in .venv when available:

python -m pip install -e ".[benchmarks]"

H-JEPA uses Lightning, which is in the optional hjepa extra:

python -m pip install -e ".[hjepa]"

Dataset

The active-matter loader is implemented in src/data_pipeline/dataset.py. It expects data_root / "data" / {train,valid,test}.

Two layouts are common in this working tree:

data/
  train/
  valid/
  test/

For this flat layout, use data_root: . in the config, or pass the repository root on the command line (for example --data-root /home/romina/DL_project).

data/active_matter/
  active_matter.yaml
  stats.yaml
  data/
    train/
    valid/
    test/

Use data_root: data/active_matter for the Well-style layout. The default FA-JEPA config currently uses this second layout.
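Both layouts reduce to the same rule, data_root / "data" / split. The minimal sketch below encodes only that documented rule. The helper name resolve_split_dir is hypothetical and is not part of src/data_pipeline/dataset.py.

```python
from pathlib import Path

SPLITS = ("train", "valid", "test")


def resolve_split_dir(data_root, split):
    """Return data_root / "data" / split, failing loudly if it is missing.

    Hypothetical helper: it mirrors the documented rule, not the loader's code.
    """
    if split not in SPLITS:
        raise ValueError(f"unknown split {split!r}; expected one of {SPLITS}")
    split_dir = Path(data_root) / "data" / split
    if not split_dir.is_dir():
        raise FileNotFoundError(
            f"{split_dir} not found; use data_root=. for the flat layout "
            "or data_root=data/active_matter for the Well-style layout"
        )
    return split_dir
```

For example, resolve_split_dir("data/active_matter", "valid") points at data/active_matter/data/valid. A wrong data_root fails immediately instead of producing an empty dataset.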

Download the data with the helper module:

python -m data_pipeline.download_data

Periodic shift is only physically valid on the untouched 256x256 periodic domain before resize. The FA-JEPA default leaves it off.
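A periodic shift is a circular roll of the spatial axes. The sketch below applies one with np.roll and refuses to run on anything other than the native 256x256 grid. It assumes a per-sample (C, T, H, W) array, and periodic_shift is a hypothetical name, not the repo's augmentation code.

```python
import numpy as np


def periodic_shift(clip, dy, dx):
    """Circularly shift a (C, T, H, W) clip by (dy, dx) grid cells.

    np.roll wraps values across the boundary, which matches periodic BCs only
    on the untouched 256x256 domain, so resized or cropped input is rejected.
    A pure translation leaves vector and tensor components unchanged, so all
    channels (velocity, D, E included) are shifted identically.
    """
    if clip.shape[-2:] != (256, 256):
        raise ValueError("apply periodic shift before resizing to the model grid")
    return np.roll(clip, shift=(dy, dx), axis=(-2, -1))
```

After a resize or crop, the wrapped edge no longer corresponds to the physical neighbour on the torus. This is why the guard checks the grid size rather than trusting the caller.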

Data Contract

Do not modify the active-matter dataloader for FA-JEPA work unless that is the explicit task.

ActiveMatterPairDataset is the FA-JEPA/H-JEPA training contract:

  • context: (B, 11, T, H, W)
  • target: (B, 11, T, H, W)
  • alpha: (B,)
  • zeta: (B,)
  • physical_params: (B, 2) ordered as [zeta, alpha]

ActiveMatterWindowDataset is used for DISCO and frozen-encoder evaluation:

  • x: (B, 11, T, H, W)
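The contract above can be checked cheaply at the start of a run. In the sketch below, the batch is assumed to be a dict of NumPy arrays keyed by the names listed above (convert tensors with .numpy() first). The check functions are hypothetical helpers. The key detail is the [zeta, alpha] order of physical_params.

```python
import numpy as np

NUM_CHANNELS = 11


def check_pair_batch(batch):
    """Assert that a pair batch matches the documented FA-JEPA/H-JEPA contract."""
    ctx, tgt = batch["context"], batch["target"]
    assert ctx.ndim == 5 and tgt.ndim == 5, "expected (B, 11, T, H, W)"
    B, C = ctx.shape[:2]
    assert C == NUM_CHANNELS, f"expected {NUM_CHANNELS} channels, got {C}"
    assert tgt.shape[:2] == (B, C)
    assert batch["alpha"].shape == (B,)
    assert batch["zeta"].shape == (B,)
    assert batch["physical_params"].shape == (B, 2)
    # physical_params is ordered [zeta, alpha], not [alpha, zeta].
    expected = np.stack([batch["zeta"], batch["alpha"]], axis=1)
    assert np.allclose(batch["physical_params"], expected), "params order mismatch"


def check_window_batch(x):
    """Assert that an ActiveMatterWindowDataset batch is (B, 11, T, H, W)."""
    assert x.ndim == 5 and x.shape[1] == NUM_CHANNELS
```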

Channel order from assemble_channels(...):

  1. concentration
  2. velocity_x
  3. velocity_y
  4. D_xx
  5. D_xy
  6. D_yx
  7. D_yy
  8. E_xx
  9. E_xy
  10. E_yx
  11. E_yy

FA-JEPA field groups:

  • concentration: 0:1
  • velocity: 1:3
  • orientation: 3:7
  • strain-rate: 7:11
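The field groups are contiguous, half-open slices of the channel order. The NumPy sketch below makes the mapping explicit. split_fields is a hypothetical helper, and the repo's patch embedding may slice the channels internally in its own way.

```python
import numpy as np

# Channel order from assemble_channels(...), as listed above (0-based here).
CHANNELS = [
    "concentration", "velocity_x", "velocity_y",
    "D_xx", "D_xy", "D_yx", "D_yy",
    "E_xx", "E_xy", "E_yx", "E_yy",
]

# FA-JEPA field groups as half-open channel slices.
FIELD_GROUPS = {
    "concentration": slice(0, 1),
    "velocity": slice(1, 3),
    "orientation": slice(3, 7),
    "strain-rate": slice(7, 11),
}


def split_fields(x):
    """Split a (B, 11, T, H, W) array into the four field groups along axis 1."""
    return {name: x[:, s] for name, s in FIELD_GROUPS.items()}
```

For example, CHANNELS[FIELD_GROUPS["orientation"]] gives the four D components, and the strain-rate group holds the four E components.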

Conv / ViT3D Baselines

A separate set of single-encoder JEPA baselines lives in szcharlesji/physical-representation-learning, which is also included in this repo as a submodule. These baselines do not use field-aware factorisation (every channel goes through a single backbone). They exist to isolate two design questions for FA-JEPA: does the regulariser converge, and does the tokeniser need to span time?

| Config | Encoder | Regulariser | Best linear MSE ↓ | Notes |
|---|---|---|---|---|
| VICReg, bs=2 (fp32) | ConvEncoder | VICReg | 0.308 | Small batch; collapses past epoch 11 |
| VICReg, bs=8 (bf16) | ConvEncoder | VICReg | 0.243 | bf16 baseline; recovers from the bs=2 collapse |
| VICReg + FFT | ConvEncoder | VICReg | 0.220 | Band-limited FFT resize replaces the per-dataset crop; preserves periodic BCs |
| SIGReg, FFT | ConvEncoder | SIGReg | 0.631 | Negative control: did not converge; MSE rises with training |
| Conv+Attn, FFT | Conv + 1 transformer block at stage 4 | VICReg | 0.251 | Single attention block over the post-conv spatial grid; barely moves the FFT-only baseline |
| Conv+Attn ×6, FFT* | Conv + 6 transformer blocks at stage 4 | VICReg | 0.267 | Steeper descent than the single-block run, but cut at epoch 7 by HPC quota while the curve was still descending monotonically |
| ViT3D-d6, FFT | 3D PatchEmbed (4×16×16) + 6 transformer blocks | VICReg | 0.107 | About 2× lower MSE than the best CNN variant (0.220) and ~3× lower than the bs=2 baseline; the change is the tokeniser, not the depth |

* Conv+Attn ×6 was terminated early; its row is the best of four checkpoints rather than a converged run.

Configs (in the sister repo):

configs/train_activematter.yaml             # VICReg conv baseline (bs=8 bf16)
configs/train_activematter_fft.yaml         # + FFT preprocessing
configs/train_activematter_sigreg.yaml      # SIGReg negative control
configs/train_activematter_cnn_attn.yaml    # Conv + 1 attention block
configs/train_activematter_cnn_attn_d6.yaml # Conv + 6 attention blocks
configs/train_activematter_vit3d.yaml       # ViT3D-d6 (winning config)

The released ViT3D-d6 checkpoint at pretrain epoch 29 is on Hugging Face (≈15.4M-parameter encoder).

FA-JEPA

FA-JEPA lives in:

src/models/jepa_vit/
src/training/train_jepa_vit.py
configs/jepa_vit_default.yaml

Default architecture:

  • context frames: 16
  • target frames: 16
  • tubelet size: (2, 16, 16)
  • embed dim: 384
  • heads: 6
  • encoder depth: 8
  • predictor depth: 6
  • encoder dropout: 0.0
  • predictor dropout: 0.1
  • physics-prior attention: off by default
  • resize target: crop_size: 224 with resize_mode: physics_faithful

For a 16 x 256 x 256 clip, each field produces an 8 x 16 x 16 tubelet-token grid. The encoder keeps structured tokens shaped like:

(B, T_tok, H_tok, W_tok, 4, D)
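The token-grid arithmetic follows directly from dividing each clip dimension by the tubelet size. The sketch below reproduces the 256x256 case stated above. It also shows the grid for the default crop_size: 224, under the assumption that this means a square 224x224 input after resizing.

```python
def token_grid(clip_thw, tubelet=(2, 16, 16)):
    """Return the per-field (T_tok, H_tok, W_tok) grid for a (T, H, W) clip."""
    grid = []
    for size, step in zip(clip_thw, tubelet):
        if size % step:
            raise ValueError(f"{size} is not divisible by tubelet size {step}")
        grid.append(size // step)
    return tuple(grid)


NUM_FIELDS = 4  # concentration, velocity, orientation, strain-rate
for hw in (256, 224):
    t, h, w = token_grid((16, hw, hw))
    print(f"{hw}x{hw}: grid {t}x{h}x{w}, {t * h * w} tokens/field, "
          f"{t * h * w * NUM_FIELDS} total")
# prints:
# 256x256: grid 8x16x16, 2048 tokens/field, 8192 total
# 224x224: grid 8x14x14, 1568 tokens/field, 6272 total
```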

Each factorized encoder block applies:

  1. field attention over the four field tokens at each tubelet
  2. spatial attention over patch positions for each fixed time and field
  3. temporal attention over tubelet-time positions for each fixed patch and field
  4. token-wise MLP
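The four steps differ only in which axis of the structured token tensor is treated as the attention sequence. The NumPy sketch below shows that axis bookkeeping on (B, T_tok, H_tok, W_tok, F, D) tokens. It is a toy under stated simplifications: single-head attention with no Q/K/V projections, no LayerNorm, no dropout, a ReLU MLP, and no field or position embeddings. The repo's blocks use torch.nn.functional.scaled_dot_product_attention instead.

```python
import numpy as np


def attend(x):
    """Toy single-head self-attention over axis -2 of x with shape (..., N, D)."""
    scores = x @ np.swapaxes(x, -1, -2) / np.sqrt(x.shape[-1])
    scores = scores - scores.max(axis=-1, keepdims=True)  # numerical stability
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=-1, keepdims=True)
    return weights @ x


def factorized_block(x, w1, w2):
    """Toy factorized block on structured tokens x: (B, T, H, W, F, D)."""
    B, T, H, W, F, D = x.shape
    # 1. Field attention: the F field tokens at each tubelet form one sequence.
    x = x + attend(x)
    # 2. Spatial attention: H*W patch positions per (time, field).
    s = x.transpose(0, 1, 4, 2, 3, 5).reshape(B, T, F, H * W, D)
    s = s + attend(s)
    x = s.reshape(B, T, F, H, W, D).transpose(0, 1, 3, 4, 2, 5)
    # 3. Temporal attention: T tubelet-time positions per (patch, field).
    t = x.transpose(0, 2, 3, 4, 1, 5)  # (B, H, W, F, T, D)
    t = t + attend(t)
    x = t.transpose(0, 4, 1, 2, 3, 5)
    # 4. Token-wise MLP, applied independently to every token.
    return x + np.maximum(x @ w1, 0.0) @ w2
```

Each attention step sees sequences of only 4, H_tok*W_tok, or T_tok tokens. Full joint attention over all tokens would be far more expensive. Because this toy has no field embeddings, reordering the fields simply reorders the output. Whether the real encoder adds such embeddings is not covered here.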

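The predictor is a standard ViT stack over context_tokens concatenated with target_queries. MultiHeadAttention uses torch.nn.functional.scaled_dot_product_attention, which can dispatch to fused memory-efficient or FlashAttention kernels that avoid materializing the full attention matrix. This is load-bearing for memory at realistic batch sizes.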

Training objective:

  • predictor MSE on target latents
  • plus sigreg_weight * pooled_sigreg
  • pooled_sigreg = 0.5 * (SIGReg(context_time) + SIGReg(target_time))

The target branch is the same encoder. Only target_latents for predictor MSE are detached; target-view SIGReg still backprops through the shared encoder.
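The sketch below shows how the terms combine. It uses a simplified SIGReg-style statistic as a stand-in: embeddings are projected onto random unit directions, and each projection's empirical characteristic function is compared with the standard normal's via an Epps-Pulley-type weighted integral. The repo's SIGReg in src/models/common/ may differ in directions, quadrature, and weighting. The inputs context_time and target_time are assumed to be flattened to (N, D). NumPy has no autograd, so the detach asymmetry appears only as a comment.

```python
import numpy as np


def sigreg_stat(z, num_directions=64, num_t=17, t_max=3.0, seed=0):
    """SIGReg-style distance of z (N, D) from an isotropic Gaussian (stand-in)."""
    n, d = z.shape
    rng = np.random.default_rng(seed)
    dirs = rng.standard_normal((d, num_directions))
    dirs /= np.linalg.norm(dirs, axis=0, keepdims=True)  # unit directions
    proj = z @ dirs                                       # (N, K)
    t = np.linspace(-t_max, t_max, num_t)
    tx = proj[:, :, None] * t[None, None, :]              # (N, K, num_t)
    ecf_re, ecf_im = np.cos(tx).mean(axis=0), np.sin(tx).mean(axis=0)
    gauss_cf = np.exp(-0.5 * t**2)                        # N(0, 1) char. function
    err = ((ecf_re - gauss_cf) ** 2 + ecf_im**2) * np.exp(-0.5 * t**2)
    # Trapezoid rule over t, then average over directions.
    integral = ((err[..., 1:] + err[..., :-1]) * np.diff(t) / 2).sum(axis=-1)
    return n * integral.mean()


def fa_jepa_loss(pred, target_latents, context_time, target_time, sigreg_weight):
    # In PyTorch, target_latents would be .detach()-ed for this MSE term only;
    # context_time and target_time keep their gradients, so target-view SIGReg
    # still backprops through the shared encoder.
    mse = np.mean((pred - target_latents) ** 2)
    pooled_sigreg = 0.5 * (sigreg_stat(context_time) + sigreg_stat(target_time))
    return mse + sigreg_weight * pooled_sigreg
```

Collapsed (constant) embeddings score far higher than Gaussian ones, which is the pressure the regulariser exerts against collapse.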

Run FA-JEPA:

python -m training.train_jepa_vit --config configs/jepa_vit_default.yaml

Enable W&B:

python -m training.train_jepa_vit \
  --config configs/jepa_vit_default.yaml \
  --wandb

Checkpoints default to:

checkpoints/jepa_vit/
  latest.pt
  best_model_rank_01.pt
  ...
  epoch_0001.pt

When validation is enabled, best checkpoints are ranked by val_loss. If validation is disabled, they fall back to training loss.
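The ranking rule can be sketched as a sort on the chosen metric, with lower values treated as better. The record format and the rank_checkpoints name are hypothetical; only the best_model_rank_NN.pt naming and the val_loss-or-training-loss rule come from this README.

```python
def rank_checkpoints(records, top_k=3):
    """Order checkpoint records best-first, following the documented rule.

    records: list of dicts with "path", "train_loss", and optionally
    "val_loss" (None or absent when validation is disabled). Lower is better.
    """
    use_val = all(r.get("val_loss") is not None for r in records)
    key = "val_loss" if use_val else "train_loss"
    ranked = sorted(records, key=lambda r: r[key])[:top_k]
    return [(f"best_model_rank_{i:02d}.pt", r["path"])
            for i, r in enumerate(ranked, start=1)]
```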

H-JEPA

See HJEPA.md for details.

H-JEPA lives in:

src/models/hjepa/
src/training/train_hjepa.py
src/training/hjepa_module.py
configs/hjepa/

This is the Lightning-based hierarchical JEPA comparison path. It uses the pair dataloader and supports VICReg-style pair loss, optional representation SIGReg, EMA targets, and split encoder stages.
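For reference, the sketch below shows the two H-JEPA ingredients that are easiest to get subtly wrong: a VICReg pair loss and an EMA target update. The VICReg function follows the formulation and default weights (25 / 25 / 1) from the original VICReg paper. The repo's "VICReg-style" loss and configured weights may differ. The EMA momentum of 0.996 is only an illustrative value.

```python
import numpy as np


def _off_diagonal(m):
    return m[~np.eye(m.shape[0], dtype=bool)]


def vicreg_loss(za, zb, sim_w=25.0, var_w=25.0, cov_w=1.0, eps=1e-4):
    """VICReg loss on paired (N, D) embeddings."""
    n, d = za.shape
    invariance = np.mean((za - zb) ** 2)
    za, zb = za - za.mean(axis=0), zb - zb.mean(axis=0)
    # Hinge on the per-dimension std, pushing it towards at least 1.
    std_a = np.sqrt(za.var(axis=0, ddof=1) + eps)
    std_b = np.sqrt(zb.var(axis=0, ddof=1) + eps)
    variance = (0.5 * np.mean(np.maximum(0.0, 1.0 - std_a))
                + 0.5 * np.mean(np.maximum(0.0, 1.0 - std_b)))
    # Penalise squared off-diagonal covariance to decorrelate dimensions.
    cov_a, cov_b = za.T @ za / (n - 1), zb.T @ zb / (n - 1)
    covariance = ((_off_diagonal(cov_a) ** 2).sum()
                  + (_off_diagonal(cov_b) ** 2).sum()) / d
    return sim_w * invariance + var_w * variance + cov_w * covariance


def ema_update(target_params, online_params, momentum=0.996):
    """EMA target update: target <- momentum * target + (1 - momentum) * online."""
    return {k: momentum * target_params[k] + (1.0 - momentum) * online_params[k]
            for k in target_params}
```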

Default H-JEPA run:

python -m training.train_hjepa --config configs/hjepa/ablation_base.yaml

Useful H-JEPA configs:

  • configs/hjepa/ablation_base.yaml

H-JEPA checkpoints default to the config's Lightning checkpoint directory, for example checkpoints/hjepa_activematter.

Supervised Baseline

The supervised baseline is the direct regression comparison method:

src/supervised_baseline/
src/training/train_supervised_baseline.py
configs/supervised_baseline_default.yaml

It uses a Field-Aware ViT-style encoder and directly predicts (alpha, zeta) with supervised regression, rather than learning through JEPA latent prediction.

Run it with:

python -m training.train_supervised_baseline \
  --config configs/supervised_baseline_default.yaml

Default checkpoints:

checkpoints/supervised_baseline/

Current supervised baseline benchmark artifacts are stored in:

benchmark_results/results_supervised_baseline/
  eval_results.csv
  eval_results.png
  eval_results.svg

Evaluation Status

Frozen-feature evaluation utilities live in:

src/evaluation/features.py
src/evaluation/probe_utils.py
src/evaluation/probes.py
src/evaluation/utils.py
src/evaluation/analyze_jepa_vit_field_attention.py

The intended probe protocol is regression only:

  • frozen encoder
  • one linear probe
  • inverse-distance-weighted kNN
  • z-score-normalized (alpha, zeta)
  • no backbone finetuning, MLP probe, classification reformulation, or attention pooling
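The protocol above can be sketched in NumPy on precomputed frozen features. Several details here are assumptions rather than the repo's settings: the linear probe is closed-form ridge regression with a small ridge term, the kNN uses Euclidean distance with k=10, and targets are z-scored using training-split statistics. The function names are hypothetical.

```python
import numpy as np


def zscore_fit(y):
    mu, sd = y.mean(axis=0), y.std(axis=0)
    return mu, np.where(sd > 0, sd, 1.0)


def linear_probe(f_train, y_train, f_eval, ridge=1e-3):
    """Closed-form ridge regression from frozen features to targets."""
    X = np.hstack([f_train, np.ones((len(f_train), 1))])  # append bias column
    reg = ridge * np.eye(X.shape[1])
    reg[-1, -1] = 0.0  # do not penalise the bias
    W = np.linalg.solve(X.T @ X + reg, X.T @ y_train)
    return np.hstack([f_eval, np.ones((len(f_eval), 1))]) @ W


def idw_knn(f_train, y_train, f_eval, k=10, eps=1e-8):
    """Inverse-distance-weighted kNN regression."""
    d = np.linalg.norm(f_eval[:, None, :] - f_train[None, :, :], axis=-1)  # (E, N)
    idx = np.argsort(d, axis=1)[:, :k]
    w = 1.0 / (np.take_along_axis(d, idx, axis=1) + eps)
    w /= w.sum(axis=1, keepdims=True)
    return np.einsum("ek,ekc->ec", w, y_train[idx])


def probe_mse(f_train, y_train, f_eval, y_eval, k=10):
    """MSE of both probes on (alpha, zeta), z-scored with training statistics."""
    mu, sd = zscore_fit(y_train)
    zt, ze = (y_train - mu) / sd, (y_eval - mu) / sd
    lin = np.mean((linear_probe(f_train, zt, f_eval) - ze) ** 2)
    knn = np.mean((idw_knn(f_train, zt, f_eval, k) - ze) ** 2)
    return {"linear_mse": lin, "knn_mse": knn}
```

Because the targets are z-scored, an MSE near 1.0 means the probe does no better than predicting the training mean.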

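Note: in the current checkout, src/evaluation/run_probe.py is not present, while src/evaluation/__init__.py still imports it. As a result, python -m evaluation.run_probe ... is not a valid command until the entry point is restored. Running any evaluation.* module with -m first executes src/evaluation/__init__.py, so other evaluation entry points may also fail with an ImportError unless that import is guarded. The lower-level feature/probe modules and benchmark notebooks are present.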

Attention-analysis aggregation is available as:

python -m evaluation.analyze_jepa_vit_field_attention \
  --checkpoint checkpoints/jepa_vit_running/physics_b128/best_model_rank_03.pt \
  --data-root /home/romina/DL_project \
  --split valid \
  --device cuda \
  --batch-size 4 \
  --max-samples 256 \
  --output-dir benchmark_results/FA_JEPA/attention_results/physics_b128_best_model_rank_03_valid

Benchmark Results

Current FA-JEPA result notebooks, plots, probe outputs, and attention-analysis artifacts are under:

benchmark_results/FA_JEPA/

Important FA-JEPA notebooks:

  • benchmark_results/FA_JEPA/all_results_f.ipynb
  • benchmark_results/FA_JEPA/all_results_f_pooling_per_field.ipynb
  • benchmark_results/FA_JEPA/jepa_vit_attention_analysis.ipynb

Important FA-JEPA result folders:

  • benchmark_results/FA_JEPA/probe_results/
  • benchmark_results/FA_JEPA/probe_logs/
  • benchmark_results/FA_JEPA/probe_logs_f/
  • benchmark_results/FA_JEPA/summary_tables/
  • benchmark_results/FA_JEPA/attention_results/

Summary tables currently live at:

benchmark_results/FA_JEPA/summary_tables/
  all_probe_metrics.csv
  best_by_valid_probe_rows.csv
  paper_probe_table.csv

Generated comparison figures include:

benchmark_results/FA_JEPA/linear_metrics_physics_vs_data_only.png
benchmark_results/FA_JEPA/knn_metrics_physics_vs_data_only.png
benchmark_results/FA_JEPA/test_linear_metrics_physics_vs_data_only.png
benchmark_results/FA_JEPA/test_knn_metrics_physics_vs_data_only.png
benchmark_results/FA_JEPA/attention_analysis_data_vs_phy.png
benchmark_results/FA_JEPA/attention_analysis_data_prior_effective_minus_data.png
benchmark_results/FA_JEPA/rel_blended_attention.png

Collapse Monitoring

Collapse diagnostics are load-bearing for FA-JEPA. The trainer logs:

  • embed/variance_mean
  • embed/covariance_off_diag_mean
  • embed/norm_mean
  • embed/context_token/*
  • embed/target_token/*
  • embed/context_pooled/*
  • embed/target_pooled/*
  • monitor/context_token_to_pooled_variance_ratio
  • monitor/target_token_to_pooled_variance_ratio
  • monitor/predictor_loss_rolling_mean
  • monitor/predictor_loss_rolling_std

Warnings trigger when:

  • embed/variance_mean < 0.1
  • embed/variance_mean decreases for three checks in a row
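The sketch below shows the two warning rules, plus one plausible way to compute the top-level embed/* statistics from an (N, D) batch of embeddings. Two definitions are assumptions: covariance_off_diag_mean is taken as the mean absolute off-diagonal covariance, and "decreases for three checks in a row" is read as three consecutive strict decreases. The trainer's exact definitions may differ.

```python
import numpy as np


def embedding_stats(z):
    """Batch statistics for (N, D) embeddings, named after the logged keys."""
    zc = z - z.mean(axis=0)
    cov = zc.T @ zc / (len(z) - 1)
    off = cov[~np.eye(cov.shape[0], dtype=bool)]
    return {
        "embed/variance_mean": float(np.diag(cov).mean()),
        "embed/covariance_off_diag_mean": float(np.abs(off).mean()),
        "embed/norm_mean": float(np.linalg.norm(z, axis=1).mean()),
    }


def collapse_warnings(variance_history, threshold=0.1, streak=3):
    """Return messages for the two documented warning conditions."""
    msgs = []
    if variance_history and variance_history[-1] < threshold:
        msgs.append(f"variance_mean {variance_history[-1]:.3g} < {threshold}")
    recent = variance_history[-(streak + 1):]  # last four values for streak=3
    if len(recent) == streak + 1 and all(b < a for a, b in zip(recent, recent[1:])):
        msgs.append(f"variance_mean decreased {streak} checks in a row")
    return msgs
```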

If either warning appears during model/loss changes, investigate before treating the run as healthy.

Tests

Run the full suite:

pytest tests/ -v

Focused FA-JEPA tests:

pytest tests/test_jepa_vit_patch_embed.py -v
pytest tests/test_jepa_vit_attention.py -v
pytest tests/test_jepa_vit_encoder.py -v
pytest tests/test_jepa_vit_predictor.py -v
pytest tests/test_jepa_vit_model.py -v
pytest tests/test_sigreg.py -v
pytest tests/test_jepa_vit_training.py -v
pytest tests/test_train_jepa_vit.py -v
pytest tests/test_jepa_vit_smoke.py -v

H-JEPA and baseline coverage:

pytest tests/test_hjepa_model.py tests/test_hjepa_training.py tests/test_hjepa_eval.py -v
pytest tests/test_supervised_baseline.py tests/test_eval_supervised.py -v

Repository Layout

  • src/data_pipeline/: HDF5 loading, augmentation, resize, download helpers
  • src/models/jepa_vit/: FA-JEPA modules
  • src/models/hjepa/: H-JEPA modules
  • src/models/common/: shared losses and diagnostics, including SIGReg
  • src/models/baseline_jepa/: older baseline JEPA ablation
  • src/models/disco/: DISCO next-step predictor
  • src/models/drift_jepa/: drift-JEPA experiments
  • src/supervised_baseline/: supervised Field-Aware ViT baseline
  • src/training/: training entry points
  • src/evaluation/: feature extraction, probes, and analysis utilities
  • configs/: FA-JEPA, H-JEPA, benchmark, and baseline configs
  • benchmark_results/: current paper/benchmark notebooks and generated results
  • tests/: module, training, probe, and smoke tests
  • docs/references.md: condensed local reference notes

Notes

Record any new understanding of external references in docs/references.md, which is then treated as the local source of truth. For normal implementation work, do not re-fetch external references.

About

Group project for DL course
