## 1. Setup Python Path and Imports

Add the parent directory to the Python path to import Conceptarium modules.

In [None]:
import sys
from pathlib import Path

# Add parent directory to path
parent_path = Path.cwd().parent
if str(parent_path) not in sys.path:
    sys.path.insert(0, str(parent_path))

print(f"Added to path: {parent_path}")

## 2. Import Required Libraries

Import Hydra and Conceptarium components.

In [None]:
# Configure warnings before importing third-party libraries
import conceptarium.warnings_config  # noqa: F401

from hydra import initialize, compose
from omegaconf import OmegaConf
from hydra.utils import instantiate

from conceptarium.trainer import Trainer
from conceptarium.hydra import parse_hyperparams
from conceptarium.resolvers import register_custom_resolvers
from conceptarium.utils import setup_run_env, clean_empty_configs, update_config_from_data

print("Imports successful!")

## 3. Initialize Hydra and Load Configuration

Use `hydra.initialize()` to set up Hydra in notebook mode, then compose the configuration.

In [None]:
config_path = "../conf"
config_name = "sweep"
# Initialize Hydra with the configuration path
with initialize(config_path=config_path, version_base="1.3"):
    # - Compose configuration
    # - Override any parameters as needed
    cfg = compose(config_name=config_name, 
                  overrides=['model=cbm', # any model
                             'dataset=asia']) # any dataset

print(f"Configuration loaded from {config_path}/{config_name}.yaml")
print(f"\nDataset: {cfg.dataset.name}")
print(f"Model: {cfg.model._target_}")
print(f"Max epochs: {cfg.trainer.max_epochs}")
print(f"Batch size: {cfg.dataset.batch_size}\n")

# Print the full configuration
print("=" * 60)
print("Full Configuration:")
print("=" * 60)
print(OmegaConf.to_yaml(cfg))
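
Overrides passed to `compose` are plain `key=value` strings, so they can be built from a dictionary when several parameters change at once. The sketch below uses a hypothetical helper, `make_overrides`, which is not part of Hydra or Conceptarium. The key `trainer.max_epochs` is taken from the printout above, and the value 5 is arbitrary.

```python
def make_overrides(params: dict) -> list:
    """Turn {"model": "cbm"} into Hydra-style ["model=cbm"] override strings."""
    return [f"{key}={value}" for key, value in params.items()]


# Hypothetical selection; dotted keys address nested config entries.
overrides = make_overrides({
    "model": "cbm",
    "dataset": "asia",
    "trainer.max_epochs": 5,
})
print(overrides)
```

The resulting list can be passed as `overrides=overrides` to `compose`. Values containing spaces or special characters need quoting under Hydra's override grammar, which this sketch does not handle.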

## 4. Setup Environment

Configure random seeds and devices for reproducibility.

In [None]:
# Set random seed, configure devices
cfg = setup_run_env(cfg)  

# Remove empty config entries. 
# Used for compatibility across models and datasets
cfg = clean_empty_configs(cfg)  
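
What `setup_run_env` does internally is not shown here. As a minimal illustration of why fixing a seed matters, the sketch below uses only the standard-library `random` module: the same seed yields the same sequence of draws, a different seed does not. The helper `draw` and the seed values are assumptions made for this example.

```python
import random


def draw(seed: int, n: int = 3) -> list:
    """Return n pseudo-random floats from a generator seeded with `seed`."""
    rng = random.Random(seed)  # independent generator, does not touch global state
    return [rng.random() for _ in range(n)]


first = draw(42)
second = draw(42)   # same seed: identical sequence
other = draw(43)    # different seed: different sequence

print("same seed reproduces draws:", first == second)
print("different seed reproduces draws:", first == other)
```

A full training run also has to seed NumPy and PyTorch and may need deterministic device settings, which this sketch deliberately leaves out.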

## 5. Instantiate Dataset (DataModule)

Load and prepare the dataset. The datamodule handles:
- Loading raw data (for the bnlearn datasets, the input data is extracted from the hidden representations of an autoencoder)
- Creating annotations (concept metadata)
- Splitting the dataset into train/val/test (done by the `setup` method)
- Creating dataloaders

In [None]:
datamodule = instantiate(cfg.dataset, _convert_="all")
datamodule.setup('fit')

print(f"\n  Total samples: {len(datamodule.dataset)}")
print(f"  Train: {datamodule.train_len}, Val: {datamodule.val_len}, Test: {datamodule.test_len}")
print(f"  Batch size: {datamodule.batch_size}")
print(f"  Concepts: {list(datamodule.annotations.get_axis_labels(1))}\n")

# Update config based on dataset properties
cfg = update_config_from_data(cfg, datamodule)
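
The splitting logic itself lives inside the datamodule and is not shown in this notebook. The sketch below is one common way to derive split sizes from fractions, so that the three parts always add up to the total. The function `split_lengths` and the fractions 0.1 and 0.2 are hypothetical, not values taken from the Conceptarium configuration.

```python
def split_lengths(n_samples: int, val_frac: float = 0.1, test_frac: float = 0.2):
    """Return (train, val, test) sizes that sum exactly to n_samples."""
    n_val = int(n_samples * val_frac)
    n_test = int(n_samples * test_frac)
    # Training gets the remainder, so rounding never loses samples.
    n_train = n_samples - n_val - n_test
    return n_train, n_val, n_test


n_train, n_val, n_test = split_lengths(1003)
print(n_train, n_val, n_test)
```

A similar sanity check can be run on the lengths printed above, assuming the three splits are meant to cover the whole dataset.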

## 6. Instantiate Model

Instantiate the model with Hydra.

Concept annotations and graph structure are not known until the dataset is instantiated.
For this reason, Hydra instantiates the model only partially, using the `_partial_` flag. The resulting callable is then completed by passing it the dataset annotations and graph structure.

- **annotations**: Concept metadata from dataset
- **graph**: Structural dependencies between concepts (if available)

In [None]:
model = instantiate(cfg.model, _convert_="all", _partial_=True)(annotations=datamodule.annotations,
                                                                graph=datamodule.graph)

print(f"  Model class: {model.__class__.__name__}")
print(f"  Model Encoder: {model.encoder}")
print(f"  Model PGM: {model.pgm}")
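
With `_partial_=True`, Hydra's `instantiate` returns a `functools.partial` object instead of a finished instance. The two-step pattern can be sketched with the standard library alone. `ToyModel`, `hidden_size`, and the concept labels below are hypothetical stand-ins, not Conceptarium classes or values.

```python
from functools import partial


class ToyModel:
    """Hypothetical stand-in for a model class; not part of Conceptarium."""

    def __init__(self, hidden_size, annotations, graph=None):
        self.hidden_size = hidden_size
        self.annotations = annotations
        self.graph = graph


# Step 1: bind the arguments that are known from the configuration.
model_factory = partial(ToyModel, hidden_size=64)

# Step 2: supply the data-dependent arguments once the dataset exists.
toy_model = model_factory(annotations=["smoke", "lung"], graph=None)

print(type(model_factory).__name__, toy_model.hidden_size, toy_model.annotations)
```

The benefit of this design is that the configuration stays free of values that only exist at runtime.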

## 7. Instantiate Engine (Predictor)

Instantiate the training engine with Hydra.
The engine wraps the model and handles:
- **Loss computation**: From `engine/loss/*.yaml`
- **Metrics computation**: From `engine/metrics/*.yaml`
- **Optimization**: Optimizer and learning rate
- **Training loops**: Train/validation/test steps

Like the model, the engine is instantiated partially with Hydra using the `_partial_` flag, and then completed by passing the model instance.

Finally, instantiate the PyTorch Lightning Trainer from the configuration.
The Trainer sets up:
- Early stopping (based on validation loss)
- Model checkpointing (saves best model)
- Logging (WandB/TensorBoard)
- Progress bars



In [None]:
engine = instantiate(cfg.engine,  _convert_="all", _partial_=True)(model=model)

trainer = Trainer(cfg)
trainer.logger.log_hyperparams(parse_hyperparams(cfg))
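
The callbacks configured inside `Trainer(cfg)` are not shown in this notebook. As a rough illustration of early stopping on validation loss, the sketch below implements a patience rule in plain Python. The function `stop_epoch`, the parameter names `patience` and `min_delta`, and the loss values are assumptions for this example.

```python
def stop_epoch(val_losses, patience=3, min_delta=0.0):
    """Return the epoch index at which training would stop, or None."""
    best = float("inf")
    bad_epochs = 0
    for epoch, loss in enumerate(val_losses):
        if loss < best - min_delta:
            best = loss        # improvement: remember it and reset the counter
            bad_epochs = 0
        else:
            bad_epochs += 1    # no improvement this epoch
            if bad_epochs >= patience:
                return epoch
    return None


# Hypothetical validation losses: the last improvement happens at epoch 2.
losses = [0.9, 0.7, 0.6, 0.61, 0.62, 0.63, 0.5]
stopped_at = stop_epoch(losses, patience=3)
print("stopped at epoch:", stopped_at)
```

Note that the later improvement to 0.5 is never seen because training stops first. This is the usual trade-off when choosing a small patience value.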

## 8. Train Model

Run the training loop with the PyTorch Lightning Trainer.

In [None]:
# Train the model
trainer.fit(engine, datamodule=datamodule)

print("\nTraining completed!")

## 9. Test Model

Evaluate the trained model on the held-out test set.

In [None]:
test_results = trainer.test(datamodule=datamodule)
trainer.logger.finalize("success")
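
In PyTorch Lightning, `Trainer.test` returns a list with one metrics dictionary per test dataloader. Assuming the Conceptarium `Trainer` keeps that behaviour, the results can be printed as below. The metric names and values in `example_results` are hypothetical placeholders, not output from this experiment.

```python
def format_metrics(results):
    """Render a list of metric dicts as sorted, aligned text lines."""
    lines = []
    for i, metrics in enumerate(results):
        for name in sorted(metrics):
            lines.append(f"[dataloader {i}] {name}: {metrics[name]:.4f}")
    return lines


# Placeholder structure mimicking a single test dataloader.
example_results = [{"test_loss": 0.42, "test_accuracy": 0.87}]

metric_lines = format_metrics(example_results)
for line in metric_lines:
    print(line)
```

In the notebook itself, `format_metrics(test_results)` would be the corresponding call.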

## 10. Make Predictions (Optional)

Use the trained model to make predictions on test data.

In [None]:
import torch

# Get a test batch
test_loader = datamodule.test_dataloader()
batch = next(iter(test_loader))

print(batch)

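# Move engine to the available device.
# Assumption: `predict_batch` moves the batch to the engine's device itself;
# if it does not, the batch tensors must be moved to `device` first.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
engine = engine.to(device)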

# Make predictions
engine.eval()
with torch.no_grad():
    predictions = engine.predict_batch(batch)

print(f"Predictions shape: {predictions.shape}")
print("\nFirst 5 predictions (endogenous):")
print(predictions[:5])

# Convert endogenous to probabilities
probs = torch.sigmoid(predictions[:5])
print("\nFirst 5 predictions (probabilities):")
print(probs)

# Ground truth
print("\nFirst 5 ground truth:")
print(batch['concepts']['c'][:5])
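
The sigmoid step above maps each raw score to a value in (0, 1). To compare with binary ground truth, the probabilities are typically thresholded. The sketch below shows both steps with the standard library only. The input scores and the 0.5 threshold are assumptions for illustration, and it applies only to binary concepts.

```python
import math


def sigmoid(x: float) -> float:
    """Logistic function: maps any real score to the interval (0, 1)."""
    return 1.0 / (1.0 + math.exp(-x))


raw_scores = [-2.0, 0.0, 3.0]                    # hypothetical raw outputs
example_probs = [sigmoid(s) for s in raw_scores]
hard_labels = [int(p >= 0.5) for p in example_probs]  # assumed 0.5 threshold

print([round(p, 3) for p in example_probs])
print(hard_labels)
```

A score of exactly 0 corresponds to a probability of 0.5, so the threshold on probabilities is equivalent to checking the sign of the raw score.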

## 11. Finalize and Cleanup

Close the logger and finish the experiment.

In [None]:
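# Close the underlying experiment run.
# `finish()` is the WandB run API, so this assumes a WandB logger is configured.
trainer.logger.experiment.finish()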

print("Experiment finished successfully!")

## Summary

This notebook demonstrated how to:
1. ✅ Load Hydra configuration in a notebook using `initialize()` and `compose()`, optionally overriding configuration parameters
2. ✅ Instantiate dataset, model, and engine from config
3. ✅ Train and test a model using PyTorch Lightning
4. ✅ Make predictions with the trained model