In [None]:
!pip install git+https://github.com/nicola-decao/s-vae-pytorch.git

In [None]:
!pip install git+https://github.com/trabbani/quantum-vae.git

In [None]:
# Verify installations: import each dependency now so that a broken install
# fails here instead of partway through the training cell.
import torch
print(f"PyTorch version: {torch.__version__}")

# Provided by the s-vae-pytorch install.
from hyperspherical_vae.distributions import VonMisesFisher
print("Hyperspherical VAE package loaded successfully!")

# Provided by the quantum-vae install; this is the same entry point the
# training cell imports, so a failure here means training cannot start.
from quantum_vae.training.trainer import train_qvae
print("Quantum VAE package loaded successfully!")

In [None]:
# Report whether PyTorch can see a CUDA device.
# Relies on `torch` having been imported by an earlier cell.
if not torch.cuda.is_available():
    # No CUDA device visible: print instructions for enabling a GPU runtime.
    print("""
    GPU is not available. For faster training, please ensure you're using a GPU runtime:
    
    1. Go to `Runtime` → `Change runtime type`.
    2. Select `GPU` under Hardware Accelerator.
    3. Restart the runtime (`Runtime` → `Restart runtime`).
    4. Rerun the cells.
    
    If you're running this locally, ensure you have CUDA-enabled PyTorch installed.
    """)
else:
    # A device is visible: name the first one (index 0).
    print(f"GPU is available! Using device: {torch.cuda.get_device_name(0)}")

In [None]:
import logging

# Configure the root logger at INFO level with a timestamped line format.
# force=True removes any handlers already attached to the root logger, so
# this configuration takes effect even if logging was set up earlier.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s",
    force=True,
)

# Emit one record so it is visible that INFO messages reach the console.
logging.info("Logging to console works!")


In [13]:
from quantum_vae.training.trainer import train_qvae

# All run settings are passed to the trainer as one dict.
# NOTE(review): the meaning of each key is defined inside quantum_vae and is
# not visible here; the groupings below are inferred from the key names only.
# The recorded log lines are consistent with
#     total loss = Recon + w * Var + beta * KL,
# so `w` and `beta` look like the weights of the variance and KL terms;
# confirm against the trainer before relying on that.
config = dict(
    # data (presumably) — TODO confirm
    num_points=5000,
    noise_level=0.1,
    # model dimensions (presumably) — TODO confirm
    input_dim=3,
    quantum_dim=2,
    # loss weights (see note above)
    w=0.1,
    beta=0.001,
    # optimisation
    batch_size=64,
    num_epochs=200,
    learning_rate=1e-3,
    # output / reproducibility
    save_path='best_model.pth',
    seed=42,
)

train_qvae(config)

[13:53:15] INFO: Epoch 001/200 | Train Loss: 1.3120 (Recon: 1.2115, KL: 0.0644, Var: 1.0042) | Val Loss: 1.2394 (Recon: 1.1444, KL: 0.0561, Var: 0.9499)


[13:53:16] INFO: Epoch 002/200 | Train Loss: 1.2070 (Recon: 1.1162, KL: 0.0504, Var: 0.9074) | Val Loss: 1.1562 (Recon: 1.0703, KL: 0.0459, Var: 0.8594)
[13:53:17] INFO: Epoch 003/200 | Train Loss: 1.0789 (Recon: 0.9972, KL: 0.0433, Var: 0.8168) | Val Loss: 1.0307 (Recon: 0.9540, KL: 0.0416, Var: 0.7659)
[13:53:17] INFO: Epoch 004/200 | Train Loss: 1.0195 (Recon: 0.9469, KL: 0.0389, Var: 0.7257) | Val Loss: 0.9657 (Recon: 0.8972, KL: 0.0367, Var: 0.6842)
[13:53:18] INFO: Epoch 005/200 | Train Loss: 0.9093 (Recon: 0.8437, KL: 0.0351, Var: 0.6557) | Val Loss: 0.8728 (Recon: 0.8110, KL: 0.0339, Var: 0.6183)
[13:53:19] INFO: Epoch 006/200 | Train Loss: 0.8560 (Recon: 0.7977, KL: 0.0324, Var: 0.5827) | Val Loss: 0.8143 (Recon: 0.7592, KL: 0.0310, Var: 0.5502)
[13:53:20] INFO: Epoch 007/200 | Train Loss: 0.7866 (Recon: 0.7341, KL: 0.0297, Var: 0.5238) | Val Loss: 0.7497 (Recon: 0.7002, KL: 0.0292, Var: 0.4946)
[13:53:21] INFO: Epoch 008/200 | Train Loss: 0.7290 (Recon: 0.6824, KL: 0.0292, Va