In [1]:
# !wget -O faces_dataset.tar.gz "https://www.dropbox.com/scl/fi/7dv71y3nxrcdrpmwntr8e/faces_aligned_small_mirrored_co_aligned_cropped_cleaned.tar.gz?rlkey=h03r92h1mdr9yet2tkqosqq1k&dl=1"

In [None]:
# import tarfile

# with tarfile.open("faces_dataset.tar.gz", "r:gz") as tar:
#     tar.extractall("faces_dataset")

In [None]:
# import os
# file_names = os.listdir("/kaggle/working/faces_dataset/faces_aligned_small_mirrored_co_aligned_cropped_cleaned/M")
# len(file_names)

17673

In [None]:
import torch
import pytorch_lightning as pl
from torch.utils.data import DataLoader

from gda_functions import Config, set_seed_all, DiffusionTrainer, NoisedMixtureStream

# Training configuration; field defaults come from gda_functions.Config.
cfg = Config()

# Caps applied only when no CUDA device is available.
CPU_MAX_STEPS = 500
CPU_MAX_BATCH_SIZE = 2048

use_gpu = torch.cuda.is_available()
if not use_gpu:
    # CPU fallback: switch to a CPU-compatible precision setting and bound the
    # amount of work so the run finishes in reasonable time.
    cfg.precision = "32-true"
    cfg.steps = min(cfg.steps, CPU_MAX_STEPS)
    cfg.batch_size = min(cfg.batch_size, CPU_MAX_BATCH_SIZE)

# Seed before any data or model object is created.
set_seed_all(cfg.seed)

# Resolve the concrete device handed to the data stream.
if use_gpu:
    device = torch.device("cuda", torch.cuda.current_device())
else:
    device = torch.device("cpu")

# The stream is kept as a top-level name because its `mu0` and `u` attributes
# are read back after training.
stream = NoisedMixtureStream(cfg, device=device)

# batch_size=None turns off the DataLoader's automatic batching, so whatever
# the stream yields is passed through as-is (presumably full batches — TODO
# confirm against NoisedMixtureStream).
# NOTE(review): pin_memory only makes sense for CPU tensors; if the stream
# already yields tensors on `device`, confirm this flag is wanted on GPU.
train_loader = DataLoader(stream, batch_size=None, num_workers=0, pin_memory=use_gpu)

# Model and Trainer
model = DiffusionTrainer(cfg)

# Use every visible GPU when CUDA is available, otherwise a single CPU process.
# `devices` must always be a concrete value: Lightning 2.x (which the "32-true"
# precision string belongs to) no longer accepts `devices=None`, and the CPU
# accelerator requires an int > 0, so the CPU fallback passes 1.
# NOTE(review): with more than one GPU Lightning selects a multi-process
# strategy, while the data stream is bound to a single device — confirm that
# multi-GPU training is actually intended here.
n_gpus = torch.cuda.device_count() if use_gpu else 0
trainer = pl.Trainer(
    max_steps=cfg.steps,
    accelerator="gpu" if use_gpu else "cpu",
    devices=n_gpus if use_gpu else 1,
    precision=cfg.precision,
    log_every_n_steps=cfg.log_every_n_steps,
    enable_progress_bar=True,
    # No checkpoints are written; learned parameters are read from the
    # in-memory model after fitting.
    enable_checkpointing=False,
    gradient_clip_val=cfg.grad_clip_norm,
)

trainer.fit(model, train_dataloaders=train_loader)

# Report what was learned next to the parameters of the data family.
denoiser = model.model  # HalfSpaceDenoiser

# Detached CPU copies of every entry in the denoiser's state dict.
state = dict(
    (key, tensor.detach().cpu()) for key, tensor in denoiser.state_dict().items()
)

print("Learned parameters:")
# w2/w3 are only reported when the denoiser flags itself as extended; a
# missing `use_extended` attribute is treated as False.
if getattr(denoiser, "use_extended", False):
    reported_keys = ("w0", "w1", "w2", "w3", "b")
else:
    reported_keys = ("w0", "w1", "b")
for param_key in reported_keys:
    # state.get() prints None instead of raising when a key is absent.
    print(f"{param_key}:", state.get(param_key))

print("\nData family parameters:")
for stream_attr in ("mu0", "u"):
    print(f"{stream_attr}:", getattr(stream, stream_attr).detach().cpu())

