# Using Conceptarium Without Hydra

This notebook demonstrates how to use the Conceptarium benchmarking tool without Hydra configuration files. 

**What you'll learn:**
- Creating datasets with concept annotations
- Instantiating a simple Concept Bottleneck Model (CBM)
- Training with PyTorch Lightning
- Making predictions on new data

**Key objects:**
- **Annotations**: Metadata describing your concepts (names, types, cardinalities)
- **ConceptDataset**: PyTorch dataset wrapper for concept-based learning
- **ConceptDataModule**: Lightning DataModule to handle data loading and splitting
- **CBM**: Our CBM model, implemented as a `torch.nn.Module`
- **Predictor**: The LightningModule built around the CBM model. Its structure and functionality are shared across all models and datasets, providing a unified engine that handles the full train/val/test loop

## 1. Setup Python Path

Since `conceptarium` is not installed as a package, we add its parent directory to Python's search path.

**Why this is needed:** The notebook is in `conceptarium/examples/`, but we need to import from `conceptarium/conceptarium/`.

In [1]:
# Add parent directory to path so we can import conceptarium
import sys
from pathlib import Path

# Get the path to the parent directory (where conceptarium folder is)
parent_path = Path.cwd().parent
if str(parent_path) not in sys.path:
    sys.path.insert(0, str(parent_path))
    
print(f"Added to path: {parent_path}")
print(f"Python path: {sys.path[:3]}")

Added to path: /home/gdefelice/Projects/pytorch_concepts/conceptarium
Python path: ['/home/gdefelice/Projects/pytorch_concepts/conceptarium', '/home/gdefelice/miniconda3/envs/conceptarium/lib/python312.zip', '/home/gdefelice/miniconda3/envs/conceptarium/lib/python3.12']


## 2. Import Required Libraries

**Core libraries:**
- `torch`: PyTorch for neural networks
- `pytorch_lightning`: Training framework

**Conceptarium components:**
- `Annotations`, `AxisAnnotation`: Describe concept structure
- `ConceptDataset`: Dataset wrapper for concept data
- `ConceptDataModule`: Handles train/val/test splits and dataloaders
- `DeterministicInference`: Inference engine for the model's Probabilistic Graphical Model (PGM)
- `CBM`: Concept Bottleneck Model
- `Predictor`: Training engine

In [2]:
import torch
import numpy as np
from pytorch_lightning import Trainer

# Conceptarium imports
from torch_concepts import Annotations, AxisAnnotation
from torch_concepts.data import ToyDataset
from torch_concepts.data.base import ConceptDataset
from torch_concepts.nn import DeterministicInference
from conceptarium.data.base.datamodule import ConceptDataModule
from conceptarium.nn.models.cbm import CBM
from conceptarium.engines.predictor import Predictor

## 3. Create Synthetic Dataset

Generate a simple toy dataset to demonstrate the framework.

**Dataset structure:**
- **Inputs (X)**: 2-dimensional features drawn uniformly from [0, 1)
- **Concepts (C)**: 2 binary concepts derived from the input features
  - `concept_1`: 1 if the first feature > 0.5, else 0
  - `concept_2`: 1 if the second feature > 0.5, else 0
- **Task (Y)**: Binary classification (XOR of the two concepts)

**Note:** In Conceptarium, tasks are treated like concepts: task names and task values must be concatenated with the concept names and values. If a model needs an explicit separation of the tasks (as a standard CBM does), the model itself handles it.
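To make the convention concrete, here is a minimal sketch in plain PyTorch (independent of Conceptarium) that stores concepts and tasks side by side and then recovers the task columns by name, which is the kind of separation a CBM has to perform internally. The helper `split_by_names` is hypothetical, not part of the library.

```python
import torch


def split_by_names(values, names, task_names):
    """Split a (n_samples, n_variables) tensor into concept and task columns.

    Hypothetical helper for illustration; not part of Conceptarium.
    """
    task_idx = [i for i, n in enumerate(names) if n in task_names]
    concept_idx = [i for i, n in enumerate(names) if n not in task_names]
    return values[:, concept_idx], values[:, task_idx]


# Two binary concepts and one XOR task, stored side by side
c = torch.tensor([[0., 1.], [1., 1.]])
y = (c[:, 0] != c[:, 1]).float().unsqueeze(1)
names = ['concept_1', 'concept_2'] + ['task_xor']
values = torch.cat([c, y], dim=1)

concepts_only, tasks_only = split_by_names(values, names, task_names=('task_xor',))
print(concepts_only.shape, tasks_only.squeeze(1).tolist())
```

Splitting by name rather than by position keeps the code correct even if the task is not the last column.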

In [16]:
# Set random seed for reproducibility
torch.manual_seed(42)
np.random.seed(42)

# Hyperparameters
n_samples = 1000

# Generate synthetic XOR dataset manually
x = torch.rand(n_samples, 2)  # 2D random features in [0, 1]

# Create binary concepts based on thresholds
c1 = (x[:, 0] > 0.5).float().unsqueeze(1)  # concept_1: first feature > 0.5
c2 = (x[:, 1] > 0.5).float().unsqueeze(1)  # concept_2: second feature > 0.5
c = torch.cat([c1, c2], dim=1)

# Create XOR task: y = c1 XOR c2
y = (c1 != c2).float()

concept_names_raw = ['concept_1', 'concept_2']
task_names_raw = ['task_xor']

# combine concept names into a single list
concept_names = concept_names_raw + task_names_raw

# same for data
concepts = torch.concat([c, y], dim=1)

print(f"Dataset loaded:")
print(f"  Features shape: {x.shape}")
print(f"  Concepts shape: {concepts.shape}")
print(f"  Concept names: {concept_names}")

Dataset loaded:
  Features shape: torch.Size([1000, 2])
  Concepts shape: torch.Size([1000, 3])
  Concept names: ['concept_1', 'concept_2', 'task_xor']


## 4. Define Annotations

Annotations provide metadata about your concepts.

**Required information:**
- **labels**: Concept names (e.g., `['concept_1', 'concept_2', 'task_xor']`)
- **metadata**: Dictionary with `type` for each concept (`'discrete'` or `'continuous'`)
- **cardinalities**: Number of classes per concept (use `1` for binary concepts)

**Key insight:** A cardinality of 1 denotes a binary concept, represented compactly by a single value instead of one value per class. A cardinality > 1 denotes a multi-class categorical concept.

**Annotations structure:**
- Axis 0 (optional): Sample annotations
- Axis 1 (required): Concept annotations

In [4]:
# Define concept names and task name
# treating task as a concept
concept_names = ['concept_1', 'concept_2', 'task_xor']

# Create metadata for each concept/task
metadata = {
    'concept_1': {'type': 'discrete'},
    'concept_2': {'type': 'discrete'},
    'task_xor': {'type': 'discrete'},
}

# Cardinalities: use 1 for binary concepts/tasks (for optimization)
cardinalities = (1, 1, 1)

# Create AxisAnnotation for concepts
concept_annotation = AxisAnnotation(
    labels=concept_names,
    metadata=metadata,
    cardinalities=cardinalities
)

# Create full Annotations object.
# Axis 0 for samples, if you need to annotate each sample separately
# Axis 1 for concept annotations
annotations = Annotations({
    1: concept_annotation  # Concept axis
})

print(f"Annotations created for {len(concept_names)} variables")
print(f"All labels: {concept_names}")

Annotations created for 3 variables
All labels: ['concept_1', 'concept_2', 'task_xor']


## 5. Create ConceptDataset

Wrap raw data and annotations into a PyTorch-compatible dataset.

**Input format:**
- `input_data`: Tensor of shape `(n_samples, n_features)`
- `concepts`: Tensor of shape `(n_samples, n_concepts)` - includes both concepts and tasks
- `annotations`: Annotations object from previous step

**Output format** (what you get from `dataset[i]`):
```python
{
    'inputs': {'x': tensor of shape (n_features,)},
    'concepts': {'c': tensor of shape (n_concepts,)}
}
```
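This nested-dictionary layout works with PyTorch's default collation, which batches dictionaries key by key. The toy `NestedDictDataset` below is a hypothetical stand-in for `ConceptDataset` that returns the same sample structure, so you can see what a batch looks like without Conceptarium installed; it does not reproduce the library's internals.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class NestedDictDataset(Dataset):
    """Hypothetical stand-in returning the same sample layout as ConceptDataset."""

    def __init__(self, x, c):
        self.x, self.c = x, c

    def __len__(self):
        return len(self.x)

    def __getitem__(self, i):
        return {'inputs': {'x': self.x[i]}, 'concepts': {'c': self.c[i]}}


x = torch.rand(10, 2)
c = torch.randint(0, 2, (10, 3)).float()
loader = DataLoader(NestedDictDataset(x, c), batch_size=4)  # no shuffling

batch = next(iter(loader))
# The default collate function stacks tensors inside nested dicts:
# batch['inputs']['x'] -> (4, 2), batch['concepts']['c'] -> (4, 3)
print(batch['inputs']['x'].shape, batch['concepts']['c'].shape)
```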

In [5]:
# Create ConceptDataset
dataset = ConceptDataset(
    input_data=x,
    concepts=concepts,
    annotations=annotations
)

print(f"Dataset created:")
print(f"  Total samples: {len(dataset)}")
print(f"  Sample structure: {list(dataset[0].keys())}")
print(f"  Input shape: {dataset[0]['inputs']['x'].shape}")
print(f"  Concepts shape: {dataset[0]['concepts']['c'].shape}")

Dataset created:
  Total samples: 1000
  Sample structure: ['inputs', 'concepts']
  Input shape: torch.Size([2])
  Concepts shape: torch.Size([3])


## 6. Create DataModule

DataModule handles data splitting and creates train/val/test dataloaders.

**Key parameters:**
- `val_size`, `test_size`: Fraction of data for validation and test (0.0-1.0)
- `batch_size`: Number of samples per batch
- `backbone`: Optional pretrained model for feature extraction (we use `None` for raw inputs)
- `precompute_embs`: Whether to precompute embeddings with backbone and store them on disk.
- `scalers`: Optional data normalization (not needed for discrete concepts)

**After `setup('fit')`:** Dataset is split and ready for training.
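For intuition about the split sizes printed below, here is a sketch of a fractional train/val/test split with `torch.utils.data.random_split`. It assumes the validation and test fractions are rounded to whole samples and the remainder goes to training; `ConceptDataModule` may split differently (e.g., with its own seed or ordering), so treat this as an illustration only. The helper `split_sizes` is hypothetical.

```python
import torch
from torch.utils.data import TensorDataset, random_split


def split_sizes(n, val_size, test_size):
    """Turn fractions into integer lengths; the remainder goes to training."""
    n_val = round(n * val_size)
    n_test = round(n * test_size)
    return n - n_val - n_test, n_val, n_test


dataset = TensorDataset(torch.rand(1000, 2), torch.rand(1000, 3))
lengths = split_sizes(len(dataset), val_size=0.1, test_size=0.2)

# A fixed generator makes the split reproducible across runs
train_set, val_set, test_set = random_split(
    dataset, lengths, generator=torch.Generator().manual_seed(42)
)
print(lengths, len(train_set), len(val_set), len(test_set))
```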

In [6]:
# Create DataModule
datamodule = ConceptDataModule(
    dataset=dataset,
    val_size=0.1,
    test_size=0.2,
    batch_size=32,
    backbone=None,  # No pretrained backbone
    precompute_embs=False, # No need to precompute embeddings with backbone
    scalers=None,  # No scaling is needed for discrete concepts
    workers=0
)

# Setup the data (split into train/val/test)
datamodule.setup('fit')

print(f"DataModule created:")
print(f"  Train samples: {datamodule.train_len}")
print(f"  Val samples: {datamodule.val_len}")
print(f"  Test samples: {datamodule.test_len}")
print(f"  Batch size: {datamodule.batch_size}")

Input shape: (1000, 2)
Using raw input data without backbone preprocessing.
DataModule created:
  Train samples: 700
  Val samples: 100
  Test samples: 200
  Batch size: 32


## 7. Define Variable Distributions

Specify which probability distributions to use for different concept types.

**Distribution types:**
- `discrete_card1`: For binary concepts (cardinality = 1)
  - Uses `RelaxedBernoulli` for differentiable sampling
- `discrete_cardn`: For multi-class concepts (cardinality > 1)
  - Uses `RelaxedOneHotCategorical`
- `continuous_card1/cardn`: For continuous concepts
  - Uses `Delta` distribution (deterministic)

**Temperature parameter:** Lower values (e.g., 0.1) push samples closer to discrete 0/1 values; sampling itself remains stochastic.

**Note:** The model automatically selects the correct distribution based on each concept's cardinality.
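The effect of the temperature can be checked directly with `torch.distributions.RelaxedBernoulli`. This standalone sketch draws samples for an unbiased concept (logit 0) at a low and a high temperature and measures how far they sit from 0.5: at low temperature the samples crowd toward 0 or 1, while at high temperature they cluster around 0.5. The temperature values and the helper `mean_distance_from_half` are arbitrary choices for illustration.

```python
import torch
from torch.distributions import RelaxedBernoulli

torch.manual_seed(0)
logits = torch.zeros(10_000)  # p = 0.5 for every sample


def mean_distance_from_half(temperature):
    """Average |sample - 0.5|; values near 0.5 mean samples sit close to 0 or 1."""
    dist = RelaxedBernoulli(torch.tensor(temperature), logits=logits)
    samples = dist.rsample()  # reparameterized, differentiable samples in (0, 1)
    return (samples - 0.5).abs().mean().item()


low = mean_distance_from_half(0.1)
high = mean_distance_from_half(5.0)
print(f"temperature 0.1: {low:.3f}, temperature 5.0: {high:.3f}")
```

Because `rsample` is reparameterized, gradients flow through the samples, which is what makes the relaxed distribution usable inside a trainable model.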

In [7]:
# Variable distributions map distribution types to their configurations
# This tells the model which distribution to use for each type of concept
# Here we define the distribution for binary concepts/tasks, as they all have cardinality 1
variable_distributions = {
    # For binary concepts (cardinality = 1)
    'discrete_card1': {
        'path': 'torch.distributions.RelaxedBernoulli',
        'kwargs': {}
    }
}

print("Variable distributions defined:")
for key, config in variable_distributions.items():
    print(f"  {key}: {config['path']}")

Variable distributions defined:
  discrete_card1: torch.distributions.RelaxedBernoulli


## 8. Create CBM Model

Initialize a Concept Bottleneck Model.

**Key parameters:**
- `task_names`: Because a CBM separates concepts from tasks, provide the list of task variable names (a subset of the concept labels).
- `inference`: Inference engine class (e.g., `DeterministicInference`)
- `input_size`: Dimensionality of input features
- `annotations`: Concept metadata from step 4
- `variable_distributions`: Distribution configs from step 7
- `encoder_kwargs`: Keyword arguments for the encoder network (here: hidden size, number of layers, activation, dropout).

**Model architecture:**
1. **Encoder**: Input → Embedding (MLP layers)
2. **Model PGM**: Embedding → Concepts → Tasks

**Note:** The model creates a Probabilistic Graphical Model (PGM) internally to represent concept relationships.
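The two-stage architecture can be sketched in plain PyTorch. `TinyCBM` below is a hypothetical, simplified model written for illustration, not Conceptarium's `CBM` (which builds a PGM with a separate factor per variable, as the printed model shows). It only demonstrates the bottleneck idea: the task head receives concept probabilities, never the embedding.

```python
import torch
from torch import nn


class TinyCBM(nn.Module):
    """Hypothetical minimal CBM: input -> embedding -> concepts -> task."""

    def __init__(self, input_size=2, hidden_size=16, n_concepts=2, n_tasks=1):
        super().__init__()
        # Encoder: input -> embedding (one hidden layer)
        self.encoder = nn.Sequential(nn.Linear(input_size, hidden_size), nn.LeakyReLU())
        # Concept head: embedding -> one logit per binary concept
        self.concept_head = nn.Linear(hidden_size, n_concepts)
        # Task head sees only the concepts (the "bottleneck")
        self.task_head = nn.Sequential(
            nn.Linear(n_concepts, hidden_size), nn.LeakyReLU(), nn.Linear(hidden_size, n_tasks)
        )

    def forward(self, x):
        emb = self.encoder(x)
        c_logits = self.concept_head(emb)
        y_logits = self.task_head(torch.sigmoid(c_logits))
        # Tasks come after concepts, matching the dataset's column layout
        return torch.cat([c_logits, y_logits], dim=1)


model = TinyCBM()
out = model(torch.rand(5, 2))
print(out.shape)
```

In this sketch the task head has a hidden layer because XOR is not linearly separable in the two concepts; that is a design choice of the example, not a statement about the library's `ProbPredictor`.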

In [8]:
# Task names (concepts that are predictions, not observations)
task_names = ('task_xor',)

# Create CBM model
hidden_size = 16  # Hidden layer size in the encoder

model = CBM(
    task_names=task_names,
    inference=DeterministicInference,
    input_size=x.shape[1],
    annotations=annotations,
    variable_distributions=variable_distributions,
    encoder_kwargs={'hidden_size': hidden_size,
                    'n_layers': 1,
                    'activation': 'leaky_relu',
                    'dropout': 0.}
)

print(f"CBM model created:")
print(f"  Input size: {x.shape[1]}")
print(f" Encoder: {model.encoder}")
print(f" Model PGM: {model.pgm}")

CBM model created:
  Input size: 2
 Encoder: MLP(
  (mlp): Sequential(
    (0): Dense(
      (affinity): Linear(in_features=2, out_features=16, bias=True)
      (activation): LeakyReLU(negative_slope=0.01)
      (dropout): Identity()
    )
  )
)
 Model PGM: ProbabilisticGraphicalModel(
  (factors): ModuleDict(
    (embedding): Factor(concepts=['embedding'], module=Identity)
    (concept_1): Factor(concepts=['concept_1'], module=ProbEncoderFromEmb)
    (concept_2): Factor(concepts=['concept_2'], module=ProbEncoderFromEmb)
    (task_xor): Factor(concepts=['task_xor'], module=ProbPredictor)
  )
)


## 9. Setup Loss Functions and Metrics

Define how to compute loss and evaluate model performance.

**Loss configuration:**
- `discrete.binary`: Loss function for binary concepts
  - `BCEWithLogitsLoss`: Binary cross-entropy for logits (includes sigmoid)

**Metrics configuration:**
- `discrete.binary.accuracy`: Accuracy metric for binary concepts
  - `threshold: 0.0`: For logit inputs (since logits can be negative)

**Format:** Each config specifies:
- `path`: Full import path to the class
- `kwargs`: Arguments to pass to the class constructor

**Note:** The Predictor automatically applies the correct loss/metric based on concept type and cardinality.
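The `path`/`kwargs` format can be resolved with the standard library's `importlib`. The `instantiate` helper below is a hypothetical sketch of how such a config dictionary maps to an object; the Predictor's own loading code may differ.

```python
import importlib


def instantiate(config):
    """Build an object from a {'path': 'module.ClassName', 'kwargs': {...}} dict.

    Hypothetical helper illustrating the config format; not Conceptarium's loader.
    """
    module_name, class_name = config['path'].rsplit('.', 1)
    cls = getattr(importlib.import_module(module_name), class_name)
    return cls(**config.get('kwargs', {}))


loss_fn = instantiate({'path': 'torch.nn.BCEWithLogitsLoss', 'kwargs': {}})
summed_loss_fn = instantiate(
    {'path': 'torch.nn.BCEWithLogitsLoss', 'kwargs': {'reduction': 'sum'}}
)
print(type(loss_fn).__name__, summed_loss_fn.reduction)
```

Keeping the class as a string path lets the same dictionary be written in Python (as here) or in a YAML file for Hydra.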

In [9]:
# Loss configuration
loss_config = {
    'discrete': {
        'binary': {
            'path': 'torch.nn.BCEWithLogitsLoss',
            'kwargs': {}
        }
    }
}

# Metrics configuration
metrics_config = {
    'discrete': {
        'binary': {
            'accuracy': {
                'path': 'torchmetrics.classification.BinaryAccuracy',
                'kwargs': {}
            }
        }
    }
}

print("Loss and metrics configured:")
print(f"  Binary loss: {loss_config['discrete']['binary']['path']}")
print(f"  Binary accuracy: {metrics_config['discrete']['binary']['accuracy']['path']}")

Loss and metrics configured:
  Binary loss: torch.nn.BCEWithLogitsLoss
  Binary accuracy: torchmetrics.classification.BinaryAccuracy


## 10. Create Predictor (Training Engine)

The Predictor wraps the model and handles the training loop.

**Key parameters:**
- `model`: CBM model from step 8
- `loss`, `metrics`: Configurations from step 9
- `enable_summary_metrics`: Compute metrics averaged across all concepts of each type
- `enable_perconcept_metrics`: Compute separate metrics for each individual concept. Pass `True` to enable them for all concepts, or a list of concept names to enable them only for those concepts
- `optim_class`: Optimizer (e.g., `torch.optim.AdamW`)
- `optim_kwargs`: Optimizer parameters (e.g., learning rate)
- `scheduler_class`: Learning rate scheduler (optional)
- `scheduler_kwargs`: Scheduler parameters (optional)

**Trainer configuration:**
- `max_epochs`: Maximum number of training epochs
- `accelerator`: Hardware to use (`'auto'` detects GPU/CPU automatically)
- `devices`: Number of GPUs/CPUs to use
- `callbacks`: Training callbacks (e.g., `EarlyStopping` to stop when validation loss stops improving); none are used in this example

**What it does:**
- Computes forward pass and loss
- Updates model parameters
- Logs metrics to the Trainer's logger (e.g., TensorBoard or WandB)
- Handles train/validation/test steps
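Stripped of logging and validation, the core of a training step can be sketched in plain PyTorch. This is a hypothetical simplification for illustration (an MLP in place of the CBM, a single fixed batch, and an arbitrary learning rate), not the Predictor's actual code. With automatic optimization, Lightning calls the module's `training_step` for every batch and performs the `zero_grad`, `backward`, and `step` calls for you.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Toy batch: 2 features -> 3 binary targets (2 concepts + 1 XOR task)
x = torch.rand(64, 2)
c = (x > 0.5).float()
targets = torch.cat([c, (c[:, :1] != c[:, 1:]).float()], dim=1)

# Any module mapping inputs to 3 logits works here; an MLP keeps it short
model = nn.Sequential(nn.Linear(2, 16), nn.LeakyReLU(), nn.Linear(16, 3))
loss_fn = nn.BCEWithLogitsLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=0.01)


def training_step(batch_x, batch_y):
    """One optimization step: forward pass, loss, backward pass, parameter update."""
    model.train()
    optimizer.zero_grad()
    loss = loss_fn(model(batch_x), batch_y)
    loss.backward()
    optimizer.step()
    return loss.item()


losses = [training_step(x, targets) for _ in range(200)]
print(f"first loss {losses[0]:.3f}, last loss {losses[-1]:.3f}")
```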

In [14]:
# Create Predictor (PyTorch Lightning Module)
engine = Predictor(
    model=model,
    loss=loss_config,
    metrics=metrics_config,
    preprocess_inputs=False, # whether to preprocess inputs (e.g., scaling)
    scale_concepts=False, # whether to scale concepts before loss computation
    enable_summary_metrics=True, 
    enable_perconcept_metrics=True,
    optim_class=torch.optim.AdamW,
    optim_kwargs={'lr': 0.0007},
    scheduler_class=None,
    scheduler_kwargs=None,
)

# Create Trainer
trainer = Trainer(
    max_epochs=500,
    accelerator='auto',
    devices=1,
)

print(f"Predictor and Trainer created:")
print(f"Predictor: {engine}")

💡 Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


loss configuration validated (all binary):
  Binary (card=1): {'path': 'torch.nn.BCEWithLogitsLoss', 'kwargs': {}}
  Categorical (card>1): unused
  continuous: unused
metrics configuration validated (all binary):
  Binary (card=1): {'accuracy': {'path': 'torchmetrics.classification.BinaryAccuracy', 'kwargs': {}}}
  Categorical (card>1): unused
  continuous: unused
Predictor and Trainer created:
Predictor: Predictor(model=CBM, n_concepts=3, optimizer=AdamW, scheduler=None)


## 11. Train the Model

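The PyTorch Lightning `Trainer` runs the training loop.

**Training process:**
1. For each epoch: train on all batches, then evaluate on the validation set
2. Log metrics (loss, accuracy) for monitoring
3. Stop early if validation loss doesn't improve for `patience` epochs. This requires an `EarlyStopping` callback; none is passed to the Trainer above, so this run trains for the full `max_epochs=500`
4. Save a model checkpoint. By default, Lightning keeps the state of the last training epoch; keeping the best model requires a `ModelCheckpoint` callback with `monitor` set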
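To make step 3 concrete, the stopping rule can be written as a small patience counter. `PatienceTracker` is a hypothetical, simplified stand-in written for illustration. In practice you would pass Lightning's `pytorch_lightning.callbacks.EarlyStopping(monitor=..., patience=...)` to `Trainer(callbacks=[...])`, where `monitor` must match the validation-loss key the Predictor logs (check `trainer.callback_metrics` for the exact name, which is not shown here).

```python
class PatienceTracker:
    """Hypothetical, simplified version of the patience rule used by early stopping.

    Signals a stop once the monitored loss has failed to decrease by more than
    `min_delta` for `patience` consecutive checks.
    """

    def __init__(self, patience=3, min_delta=0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best = float('inf')
        self.wait = 0

    def should_stop(self, val_loss):
        if val_loss < self.best - self.min_delta:
            self.best = val_loss  # improvement: remember it and reset the counter
            self.wait = 0
        else:
            self.wait += 1  # no improvement this check
        return self.wait >= self.patience


tracker = PatienceTracker(patience=2)
history = [0.70, 0.50, 0.45, 0.46, 0.47, 0.40]  # made-up validation losses
stopped_at = next(
    (epoch for epoch, loss in enumerate(history) if tracker.should_stop(loss)), None
)
print(stopped_at, tracker.best)
```

With `patience=2`, the two non-improving epochs after the best loss of 0.45 trigger the stop at epoch 4, so the later improvement to 0.40 is never seen; a larger patience trades training time for robustness to noisy validation curves.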

In [15]:
# Train the model
trainer.fit(engine, datamodule=datamodule)

print("\nTraining completed!")

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]

  | Name           | Type              | Params | Mode 
-------------------------------------------------------------
0 | model          | CBM               | 85     | eval 
1 | binary_loss_fn | BCEWithLogitsLoss | 0      | train
2 | train_metrics  | MetricCollection  | 0      | train
3 | val_metrics    | MetricCollection  | 0      | train
4 | test_metrics   | MetricCollection  | 0      | train
-------------------------------------------------------------
85        Trainable params
0         Non-trainable params
85        Total params
0.000     Total estimated model params size (MB)
16        Modules in train mode
27        Modules in eval mode


Input shape: (1000, 2)
Using raw input data without backbone preprocessing.



Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=500` reached.



Training completed!


## 12. Test the Model

Evaluate the trained model on the held-out test set.

**What it does:**
- Runs the model on all test batches
- Computes test metrics (loss, accuracy)
- Returns a list with one dictionary of metrics per test dataloader (here there is a single dataloader, hence `test_results[0]`)

**Interpreting results:**
- `test_loss`: Average loss on the test set
- `test/SUMMARY-binary_accuracy`: Overall accuracy across all binary concepts and tasks
- `test/concept_1_accuracy`, `test/task_xor_accuracy`, etc.: Per-concept and per-task accuracies (if enabled)
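For binary variables, a per-concept accuracy is the fraction of samples whose thresholded prediction matches the label. The numbers in the test output of this section are consistent with `test/SUMMARY-binary_accuracy` being the unweighted mean of the per-variable accuracies: (1.0 + 0.995 + 0.715) / 3 ≈ 0.9033. The sketch below reproduces that arithmetic in plain PyTorch. It assumes logits are thresholded at 0 and that the summary is a plain mean; both are inferred from the reported values, not taken from Conceptarium's metric code. `binary_accuracies` is a hypothetical helper and the tensors are toy values.

```python
import torch

def binary_accuracies(logits: torch.Tensor, targets: torch.Tensor):
    """Hypothetical helper: per-column and mean accuracy for binary variables.

    Assumes `logits` and 0/1 `targets` both have shape (batch, n_variables).
    """
    preds = (logits > 0).float()                          # logit > 0 <=> sigmoid(logit) > 0.5
    per_concept = (preds == targets).float().mean(dim=0)  # accuracy of each column
    summary = per_concept.mean()                          # assumed: unweighted mean over columns
    return per_concept, summary

# Toy data: 4 samples, 3 binary variables (illustrative values only)
logits = torch.tensor([[ 2.0, -1.0,  0.3],
                       [-3.0,  4.0, -0.2],
                       [ 1.5,  2.5,  0.1],
                       [-0.5, -2.0, -0.4]])
targets = torch.tensor([[1., 0., 1.],
                        [0., 1., 1.],
                        [1., 1., 0.],
                        [0., 0., 0.]])

per_concept, summary = binary_accuracies(logits, targets)
print(per_concept)  # third column is right on 2 of 4 samples -> 0.5
print(summary)      # mean of [1.0, 1.0, 0.5]
```

A plain mean gives every concept and task equal weight regardless of how hard it is, which is why the low XOR accuracy pulls the summary well below the near-perfect concept accuracies.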

In [17]:
# Test the model
test_results = trainer.test(engine, datamodule=datamodule)

print("\nTest results:")
for key, value in test_results[0].items():
    print(f"  {key}: {value:.4f}")

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]


Input shape: (1000, 2)
Using raw input data without backbone preprocessing.


──────────────────────────────────────────────────────────────────────
        Test metric                 DataLoader 0
──────────────────────────────────────────────────────────────────────
test/SUMMARY-binary_accuracy     0.903333306312561
  test/concept_1_accuracy               1.0
  test/concept_2_accuracy        0.9950000047683716
   test/task_xor_accuracy        0.7149999737739563
         test_loss            

## 13. Make Predictions

Use the trained model to make predictions on unseen data (here, a batch from the test set).

**Prediction process:**
1. Get a batch from the test dataloader
2. Set the model to evaluation mode (`engine.eval()`)
3. Use `predict_batch()` to get the raw model outputs (logits)
4. Convert logits to probabilities with `torch.sigmoid()` (for binary concepts)
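The four steps above follow the standard PyTorch inference pattern. The sketch below illustrates it with hypothetical stand-ins: a tiny `nn.Linear` model and a synthetic `TensorDataset` (two binary inputs, two copied concepts and an XOR task) replace the trained `engine` and `datamodule`, and calling the model directly replaces `engine.predict_batch()`, whose internals are not reproduced here.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)

# Hypothetical stand-ins: 2 input features -> 3 binary outputs (2 concepts + 1 task)
model = nn.Linear(2, 3)
x = torch.randint(0, 2, (64, 2)).float()
y = torch.cat([x, (x[:, :1] != x[:, 1:]).float()], dim=1)  # concepts, then XOR task
loader = DataLoader(TensorDataset(x, y), batch_size=32)

# 1. Get a batch from the (stand-in) test loader
inputs, targets = next(iter(loader))

# 2. Evaluation mode: disables dropout and freezes batch-norm statistics
model.eval()

# 3. Forward pass without recording the autograd graph
with torch.no_grad():
    logits = model(inputs)

# 4. Logits -> probabilities -> hard 0/1 labels
probs = torch.sigmoid(logits)
labels = (probs > 0.5).float()

print(logits.shape)  # torch.Size([32, 3])
```

`torch.no_grad()` and `eval()` are independent: the first skips gradient bookkeeping, the second changes layer behavior, so inference code typically uses both.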

**Output format:**
- Raw predictions are **logits** (unbounded values)
- Apply **sigmoid** to get probabilities in [0, 1]
- For binary concepts: probability > 0.5 → class 1, else class 0
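Because the sigmoid is strictly increasing and σ(0) = 0.5, the rule "probability > 0.5" is mathematically the same as "logit > 0", so the sigmoid is only needed when you want the probabilities themselves. The check below uses a few logit values taken from the prediction output in this section; note that in float32, logits extremely close to 0 can round to a probability of exactly 0.5, which makes thresholding the logits directly the slightly more robust choice.

```python
import torch

# Logit values taken from the prediction output (negative, near-zero and positive)
logits = torch.tensor([-24.557, -7.3331, -0.065807, 0.065451, 3.9395, 12.526])

labels_from_probs = (torch.sigmoid(logits) > 0.5).long()  # threshold probabilities
labels_from_logits = (logits > 0).long()                  # threshold logits directly

print(labels_from_probs)                                  # tensor([0, 0, 0, 1, 1, 1])
print(torch.equal(labels_from_probs, labels_from_logits)) # True
```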

**Comparing with ground truth:**
- Predictions shape: `(batch_size, n_concepts)`
- Ground truth shape: `(batch_size, n_concepts)`
- Each column corresponds to one concept or task (here `(32, 3)`: two concepts followed by the XOR task)
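A per-column comparison can be sketched with the first five logits and labels printed at the end of this section, copied here with the displayed rounding. In the third column the ground truth is the XOR of the first two, consistent with it being `task_xor`. In a real run you would use `predictions` and `batch['concepts']['c']` directly. `distance_from_half` is an illustrative name, not a Conceptarium quantity.

```python
import torch

# First 5 logits and labels as printed by the prediction cell (copied from the output)
logits = torch.tensor([[-2.4557e+01, -3.1189e+01, -6.3304e-03],
                       [ 3.9395e+00,  1.2526e+01, -6.5807e-02],
                       [-2.2230e+01, -7.3331e+00, -6.2835e-03],
                       [ 1.4344e+01,  2.4587e+01, -6.8361e-02],
                       [-1.1450e+01,  8.1740e+00,  6.5451e-02]])
truth = torch.tensor([[0., 0., 0.],
                      [1., 1., 0.],
                      [0., 0., 0.],
                      [1., 1., 0.],
                      [0., 1., 1.]])

probs = torch.sigmoid(logits)
preds = (probs > 0.5).float()
correct = preds == truth                       # (5, 3) boolean match matrix
per_column_acc = correct.float().mean(dim=0)   # accuracy of each concept/task column
distance_from_half = (probs - 0.5).abs()       # how far each prediction is from a coin flip

print(per_column_acc)            # all five rows are correct in every column
print(distance_from_half[:, 2])  # task column stays within 0.02 of 0.5
```

All five samples are classified correctly, but the task probabilities sit barely above or below 0.5, while the concept probabilities are close to 0 or 1. This low-confidence task output is consistent with the much lower `test/task_xor_accuracy` reported above.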

In [18]:
# Get a test batch
test_loader = datamodule.test_dataloader()
batch = next(iter(test_loader))

# Make predictions
engine.eval()
with torch.no_grad():
    predictions = engine.predict_batch(batch)

print(f"Predictions shape: {predictions.shape}")
print("\nFirst 5 predictions (logits):")
print(predictions[:5])

# Convert logits to probabilities
probs = torch.sigmoid(predictions[:5])
print("\nFirst 5 predictions (probabilities):")
print(probs)

# Ground-truth labels (concepts and task)
print("\nFirst 5 ground truth:")
print(batch['concepts']['c'][:5])

Predictions shape: torch.Size([32, 3])

First 5 predictions (logits):
tensor([[-2.4557e+01, -3.1189e+01, -6.3304e-03],
        [ 3.9395e+00,  1.2526e+01, -6.5807e-02],
        [-2.2230e+01, -7.3331e+00, -6.2835e-03],
        [ 1.4344e+01,  2.4587e+01, -6.8361e-02],
        [-1.1450e+01,  8.1740e+00,  6.5451e-02]])

First 5 predictions (probabilities):
tensor([[2.1638e-11, 2.8508e-14, 4.9842e-01],
        [9.8091e-01, 1.0000e+00, 4.8355e-01],
        [2.2170e-10, 6.5311e-04, 4.9843e-01],
        [1.0000e+00, 1.0000e+00, 4.8292e-01],
        [1.0645e-05, 9.9972e-01, 5.1636e-01]])

First 5 ground truth:
tensor([[0., 0., 0.],
        [1., 1., 0.],
        [0., 0., 0.],
        [1., 1., 0.],
        [0., 1., 1.]])
