In [9]:
import pytorch_lightning as pl
from distributions import *
from models import *
from visualization import *
from dataloaders import *

# Seed the Python, NumPy and torch RNGs before anything stochastic happens so
# that model initialisation (and any shuffling the loaders perform) is
# repeatable across "Restart Kernel -> Run All".
SEED = 42
pl.seed_everything(SEED, workers=True)

# MNIST train/test loaders. num_workers=0 loads batches in the main process;
# Lightning warns that this may be a bottleneck, so raise it once the
# dataloaders are known to work with worker processes.
train_loader, test_loader = get_dataloaders(
    dataset_name='mnist',
    batch_size=512,
    num_workers=0,
)

# Plot selection handed to the PlotLogger callback below.
plot_config = PlotConfig(
    show_plots=True,
    selected_plots=['embeddings', 'probabilities_star', 'neighborhood_dist'],
)

# Callback instance that is later passed to the Trainer.
plotly_callback = PlotLogger(config=plot_config)

# --- ICon configuration ---
# The three modules are built one at a time and then handed to IConConfig,
# so each can be inspected on its own before training.

# Mapper: CNN that reads batch key 'image' and writes batch key 'embeddings'
# with a 2-dimensional output.
# The three boolean flags are forwarded as-is; their exact effect is defined
# by SimpleCNN (presumably feature standardisation, unit-sphere projection
# and Poincare-ball projection respectively -- confirm in models).
# A 3-D MLP alternative that was tried earlier, kept for reference:
#   MLPMapper(input_dim=784, hidden_dims=(512, 512, 1024), output_dim=3,
#             normalize=True, softmax=False,
#             input_key='image', output_key='embeddings')
cnn_mapper = SimpleCNN(
    output_dim=2,
    input_key='image',
    output_key='embeddings',
    normalize_feats=True,
    unit_sphere=False,
    poincare_ball=False,
)

# Supervisory distribution: built from the 'label' entry of the batch.
# normalize / mask_diagonal are presumably row-normalisation and exclusion
# of self-pairs -- confirm in distributions.
label_distribution = Label(
    input_key='label',
    normalize=True,
    mask_diagonal=True,
)

# Learned distribution: Gaussian kernel over the mapper's 'embeddings'
# with sigma = 0.5, using the same normalize / mask_diagonal flags.
gaussian_distribution = Gaussian(
    input_key='embeddings',
    sigma=0.5,
    normalize=True,
    mask_diagonal=True,
)

config = IConConfig(
    # Modules
    mapper=cnn_mapper,
    supervisory_distribution=label_distribution,
    learned_distribution=gaussian_distribution,
    # Task / loss settings
    num_classes=10,
    accuracy_mode='regular',
    linear_probe=True,
    loss_type='kl',
    log_icon_loss=True,
    # Optimisation
    optimizer='adamw',
    lr=0.005,
    weight_decay=0,
    gradient_clip_val=10,
    # EMA and precision (EMA is switched off, so ema_momentum is presumably
    # unused -- confirm in IConModel)
    use_ema=False,
    ema_momentum=0.5,
    use_mixed_precision=False,
)

# --- Model and trainer ---
# NOTE: the `_3d` suffix is historical; the mapper in `config` is created
# with output_dim=2.
icon_model_3d = IConModel(config=config)

# Two epochs with validation after each one; hardware is auto-detected.
# Precision is left at the Trainer default. Options that were tried and
# left disabled: precision='64-true', precision="16-mixed" (the latter is
# meant to accompany use_mixed_precision=True), and a custom logger.
trainer = pl.Trainer(
    accelerator="auto",
    devices="auto",
    max_epochs=2,
    check_val_every_n_epoch=1,
    callbacks=[plotly_callback],
)

# --- Run training ---
# The test loader is supplied as the validation loader for this run.
trainer.fit(
    model=icon_model_3d,
    train_dataloaders=train_loader,
    val_dataloaders=test_loader,
)

Using default `ModelCheckpoint`. Consider installing `litmodels` package to enable `LitModelCheckpoint` for automatic upload to the Lightning model registry.
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs

  | Name                     | Type      | Params | Mode 
---------------------------------------------------------------
0 | mapper                   | SimpleCNN | 222 K  | train
1 | supervisory_distribution | Label     | 0      | train
2 | learned_distribution     | Gaussian  | 0      | train
3 | linear_probe             | Linear    | 30     | train
4 | train_acc                | Accuracy  | 0      | train
5 | val_acc                  | Accuracy  | 0      | train
---------------------------------------------------------------
222 K     Trainable params
0         Non-trainable params
222 K     Total params
0.889     Total estimated model params size (MB)
16        Modules in train mode
0         Modules in eval mode


Sanity Checking: |          | 0/? [00:00<?, ?it/s]


The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.


To copy construct from a tensor, it is recommended to use sourceTensor.detach().clone() or sourceTensor.detach().clone().requires_grad_(True), rather than torch.tensor(sourceTensor).



VBox(children=(HBox(children=(Play(value=0, description='Play Epochs', interval=1500, max=0), IntSlider(value=…


The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.



Training: |          | 0/? [00:00<?, ?it/s]


Detected KeyboardInterrupt, attempting graceful shutdown ...


NameError: name 'exit' is not defined

In [10]:
# Exploratory: print the constructor signature and docstring of IConConfig
# to check the available fields and their defaults. Consider removing this
# cell (and its long output) before sharing the notebook.
help(IConConfig)

Help on class IConConfig in module models.model_config:

class IConConfig(builtins.object)
 |  IConConfig(mapper: Union[torch.nn.modules.module.Module, Sequence[torch.nn.modules.module.Module]], supervisory_distribution: torch.nn.modules.module.Module, learned_distribution: torch.nn.modules.module.Module, mapper2: Optional[torch.nn.modules.module.Module] = None, num_classes: Optional[int] = None, lr: float = 0.0005, accuracy_mode: Optional[str] = None, use_ema: bool = False, ema_momentum: float = 0.999, loss_type: str = 'ce', decay_factor: float = 0.9, linear_probe: bool = False, optimizer: str = 'adamw', weight_decay: float = 0.0, gradient_clip_val: float = 10.0, use_mixed_precision: bool = False, log_icon_loss: bool = True) -> None
 |  
 |  Configuration class for KernelModel with validation.
 |  
 |  Methods defined here:
 |  
 |  __eq__(self, other)
 |      Return self==value.
 |  
 |  __init__(self, mapper: Union[torch.nn.modules.module.Module, Sequence[torch.nn.modules.module.Mod