# Testing the current DINGO model on zero-padded data (all samples padded to the same length)

I expect the results to be poor, but this experiment will tell us for sure.

In [2]:
from JHPY import *
from JHPY import DINGOModel
from typing import Dict


In [2]:
def load_dataloaders(load_path: str, batch_size: int = None, shuffle_train: bool = True,
                     map_location=None) -> Dict:
    """
    Load previously saved datasets and create DataLoaders.

    Parameters
    ----------
    load_path : str
        Path to the saved data file (created with save_dataloaders). The file
        must hold a dict with the keys 'X', 'y', 'train_indices',
        'val_indices', 'test_indices' and 'metadata'.
    batch_size : int, optional
        Batch size for DataLoaders. If None, uses metadata['batch_size'].
    shuffle_train : bool
        Whether to shuffle training data. Default: True
    map_location : optional
        Forwarded to ``torch.load`` (e.g. ``'cpu'`` to load tensors that were
        saved on a GPU onto a CPU-only machine). Default: None, which keeps
        the original ``torch.load`` behaviour.

    Returns
    -------
    dict with 'train_loader', 'val_loader', 'test_loader', 'metadata'.
    'metadata' is a copy with 'batch_size' overwritten when batch_size is
    given; otherwise it is the dict stored in the file.

    Raises
    ------
    KeyError
        If a required entry is missing from the file, or if batch_size is None
        and the stored metadata has no 'batch_size'.
    IndexError
        If any split index is out of range for the stored tensors. This check
        runs here, so the failure happens at load time rather than mid-training.

    Examples
    --------
    >>> # Save after generation
    >>> result = pycbc_data_generator(config, num_samples=1000)
    >>> save_dataloaders(result, 'my_data.pt')
    >>> 
    >>> # Later, load the data
    >>> loaded = load_dataloaders('my_data.pt')
    >>> train_loader = loaded['train_loader']
    >>> val_loader = loaded['val_loader']
    >>> test_loader = loaded['test_loader']
    """
    # Local imports keep this function usable without `from JHPY import *`.
    import torch
    from torch.utils.data import DataLoader, Subset, TensorDataset

    print(f"Loading datasets from {load_path}...")

    # SECURITY: weights_only=False lets the file contain arbitrary pickled
    # objects, and unpickling can execute code. Only load files you trust.
    save_data = torch.load(load_path, map_location=map_location, weights_only=False)

    required_keys = ('X', 'y', 'train_indices', 'val_indices', 'test_indices', 'metadata')
    missing = [key for key in required_keys if key not in save_data]
    if missing:
        raise KeyError(f"{load_path} is missing required entries: {missing}")

    X = save_data['X']
    y = save_data['y']
    train_indices = save_data['train_indices']
    val_indices = save_data['val_indices']
    test_indices = save_data['test_indices']
    metadata = save_data['metadata']

    # Use saved batch_size if not provided
    if batch_size is None:
        if 'batch_size' not in metadata:
            raise KeyError(
                "batch_size was not given and the saved metadata has no 'batch_size' entry"
            )
        batch_size = metadata['batch_size']
    else:
        # Copy so the caller-visible metadata reflects the override without
        # mutating the dict that came out of the file.
        metadata = metadata.copy()
        metadata['batch_size'] = batch_size

    print(f"  Tensors: X={X.shape}, y={y.shape}")
    print(f"  Splits: train={len(train_indices)}, val={len(val_indices)}, test={len(test_indices)}")

    # Subset only indexes lazily, so a bad index would otherwise surface
    # in the middle of an epoch. Check every split up front instead.
    n_samples = len(X)
    for split_name, indices in (('train', train_indices), ('val', val_indices), ('test', test_indices)):
        as_ints = [int(i) for i in indices]
        if as_ints and (max(as_ints) >= n_samples or min(as_ints) < -n_samples):
            raise IndexError(
                f"{split_name}_indices contain values outside the range of "
                f"the {n_samples} stored samples"
            )

    # Recreate the dataset and split it with the saved indices
    full_dataset = TensorDataset(X, y)
    train_data = Subset(full_dataset, train_indices)
    val_data = Subset(full_dataset, val_indices)
    test_data = Subset(full_dataset, test_indices)

    # Only the training loader is (optionally) shuffled; eval loaders keep order.
    train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=shuffle_train)
    val_loader = DataLoader(val_data, batch_size=batch_size, shuffle=False)
    test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=False)

    print(f"\nReady! DataLoaders with batch_size={batch_size}")

    return {
        'train_loader': train_loader,
        'val_loader': val_loader,
        'test_loader': test_loader,
        'metadata': metadata
    }

In [3]:
# Zero-padded two-detector dataset saved earlier; override the stored batch size.
data_loaders = load_dataloaders(
    "data_loaders_zero_padded.pt",
    batch_size=128,
)

# data_dim / num_detectors / param_dim are inferred from the loaded tensors;
# only the architecture hyperparameters are given explicitly.
model = create_dingo_from_data(
    data_loaders,
    context_dim=256,
    num_flow_layers=10,
    hidden_dim=1024,
    multi_detector_mode='concatenate',
)

Loading datasets from data_loaders_zero_padded.pt...
  Tensors: X=torch.Size([10000, 2, 6274]), y=torch.Size([10000, 2])
  Splits: train=8000, val=1000, test=1000

Ready! DataLoaders with batch_size=128
Creating DINGOModel with inferred dimensions:
  data_dim (time length per detector): 6274
  num_detectors: 2
  param_dim: 2
  context_dim: 256
  num_flow_layers: 10
  multi_detector_mode: concatenate


In [4]:
optimizer = torch.optim.Adam(model.parameters(), lr=5e-4)

# Validation log-prob is "higher is better", hence mode='max': halve the LR
# after 3 epochs without improvement.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='max',
    factor=0.5,
    patience=3,
)

# NOTE(review): with scheduler patience=3 and early-stop patience=5, the run
# below stopped one epoch after the first LR reduction. Consider a larger
# early-stop patience so the reduced LR gets a chance.
training_stuff = train_npe_model(
    model,
    optimizer,
    50,  # max epochs
    data_loaders["train_loader"],
    data_loaders["val_loader"],
    scheduler=scheduler,
    patience=5,
    dropout_rate=0.1,
    save_best_model=True,
    model_path="Dingo.pt",
)

Epoch 1, training: 100%|██████████| 63/63 [00:26<00:00,  2.34it/s]
Epoch 1, validation: 100%|██████████| 8/8 [00:00<00:00, 17.62it/s]



[Epoch  1]
Training - Log Prob: -24.1018
Validation - Log Prob: -7.6950
Current learning rates: [0.0005]
New best validation performance 

Model checkpoint saved to Dingo.pt



Epoch 2, training: 100%|██████████| 63/63 [00:26<00:00,  2.34it/s]
Epoch 2, validation: 100%|██████████| 8/8 [00:00<00:00, 17.69it/s]



[Epoch  2]
Training - Log Prob: -7.6375
Validation - Log Prob: -7.5989
Current learning rates: [0.0005]
New best validation performance 

Model checkpoint saved to Dingo.pt



Epoch 3, training: 100%|██████████| 63/63 [00:26<00:00,  2.37it/s]
Epoch 3, validation: 100%|██████████| 8/8 [00:00<00:00, 17.32it/s]



[Epoch  3]
Training - Log Prob: -7.5974
Validation - Log Prob: -7.5550
Current learning rates: [0.0005]
New best validation performance 

Model checkpoint saved to Dingo.pt



Epoch 4, training: 100%|██████████| 63/63 [00:26<00:00,  2.34it/s]
Epoch 4, validation: 100%|██████████| 8/8 [00:00<00:00, 17.59it/s]



[Epoch  4]
Training - Log Prob: -7.5388
Validation - Log Prob: -7.5208
Current learning rates: [0.0005]
New best validation performance 

Model checkpoint saved to Dingo.pt



Epoch 5, training: 100%|██████████| 63/63 [00:27<00:00,  2.32it/s]
Epoch 5, validation: 100%|██████████| 8/8 [00:00<00:00, 16.34it/s]



[Epoch  5]
Training - Log Prob: -7.5340
Validation - Log Prob: -7.5161
Current learning rates: [0.0005]
New best validation performance 

Model checkpoint saved to Dingo.pt



Epoch 6, training: 100%|██████████| 63/63 [00:27<00:00,  2.33it/s]
Epoch 6, validation: 100%|██████████| 8/8 [00:00<00:00, 16.00it/s]



[Epoch  6]
Training - Log Prob: -7.5289
Validation - Log Prob: -7.5359
Current learning rates: [0.0005]


Epoch 7, training: 100%|██████████| 63/63 [00:26<00:00,  2.34it/s]
Epoch 7, validation: 100%|██████████| 8/8 [00:00<00:00, 16.42it/s]



[Epoch  7]
Training - Log Prob: -7.5270
Validation - Log Prob: -7.4815
Current learning rates: [0.0005]
New best validation performance 

Model checkpoint saved to Dingo.pt



Epoch 8, training: 100%|██████████| 63/63 [00:27<00:00,  2.30it/s]
Epoch 8, validation: 100%|██████████| 8/8 [00:00<00:00, 15.59it/s]



[Epoch  8]
Training - Log Prob: -7.5195
Validation - Log Prob: -7.5051
Current learning rates: [0.0005]


Epoch 9, training: 100%|██████████| 63/63 [00:27<00:00,  2.27it/s]
Epoch 9, validation: 100%|██████████| 8/8 [00:00<00:00, 16.38it/s]



[Epoch  9]
Training - Log Prob: -7.5076
Validation - Log Prob: -7.5286
Current learning rates: [0.0005]


Epoch 10, training: 100%|██████████| 63/63 [00:27<00:00,  2.32it/s]
Epoch 10, validation: 100%|██████████| 8/8 [00:00<00:00, 16.80it/s]



[Epoch 10]
Training - Log Prob: -7.5183
Validation - Log Prob: -7.4951
Current learning rates: [0.0005]


Epoch 11, training: 100%|██████████| 63/63 [00:26<00:00,  2.36it/s]
Epoch 11, validation: 100%|██████████| 8/8 [00:00<00:00, 16.08it/s]



[Epoch 11]
Training - Log Prob: -7.5233
Validation - Log Prob: -7.5197
Learning rate reduced. Loading best model from Dingo.pt
Best model loaded, resuming training

Current learning rates: [0.00025]


Epoch 12, training: 100%|██████████| 63/63 [00:26<00:00,  2.39it/s]
Epoch 12, validation: 100%|██████████| 8/8 [00:00<00:00, 16.99it/s]


[Epoch 12]
Training - Log Prob: -7.4876
Validation - Log Prob: -7.4838
Current learning rates: [0.00025]
No improvement in validation log prob in last 5 epochs 






In [5]:
# Sine-wave sanity-check dataset with the same length (6274) as the padded
# detector data above; phase_low == phase_high == 0 presumably pins the phase.
sin_data = generate_sine_data(
    phase_low=0,
    phase_high=0,
    num_points=6274,
)

generating 10000 samples for training
data generated


In [6]:
# data_dim matches num_points used for sin_data.
# NOTE(review): confirm param_dim=3 equals the number of target parameters
# produced by generate_sine_data (the phase range above is fixed at 0).
config = dict(
    data_dim=6274,
    param_dim=3,
    context_dim=256,
    num_flow_layers=10,
    hidden_dim=1024,
)

sin_model = DINGOModel(config=config)

In [8]:
optimizer = torch.optim.Adam(sin_model.parameters(), lr=5e-4)

# Same schedule as the padded-data run: maximize validation log-prob and
# halve the LR after 3 stagnant epochs.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='max',
    factor=0.5,
    patience=3,
)

sin_stuff = train_npe_model(
    sin_model,
    optimizer,
    50,  # max epochs
    sin_data["Train_Loader"],
    sin_data["Val_Loader"],
    scheduler=scheduler,
    patience=5,
    dropout_rate=0.1,
    save_best_model=True,
    model_path="Dingo_sin.pt",
)

Epoch 1, training: 100%|██████████| 32/32 [00:17<00:00,  1.86it/s]
Epoch 1, validation: 100%|██████████| 4/4 [00:00<00:00, 10.03it/s]



[Epoch  1]
Training - Log Prob: -1.6919
Validation - Log Prob: 3.1525
Current learning rates: [0.0005]
New best validation performance 

Model checkpoint saved to Dingo_sin.pt



Epoch 2, training: 100%|██████████| 32/32 [00:17<00:00,  1.85it/s]
Epoch 2, validation: 100%|██████████| 4/4 [00:00<00:00,  9.19it/s]



[Epoch  2]
Training - Log Prob: 3.5562
Validation - Log Prob: 4.2544
Current learning rates: [0.0005]
New best validation performance 

Model checkpoint saved to Dingo_sin.pt



Epoch 3, training: 100%|██████████| 32/32 [00:17<00:00,  1.86it/s]
Epoch 3, validation: 100%|██████████| 4/4 [00:00<00:00, 10.11it/s]



[Epoch  3]
Training - Log Prob: 4.2017
Validation - Log Prob: 4.8206
Current learning rates: [0.0005]
New best validation performance 

Model checkpoint saved to Dingo_sin.pt



Epoch 4, training: 100%|██████████| 32/32 [00:16<00:00,  1.89it/s]
Epoch 4, validation: 100%|██████████| 4/4 [00:00<00:00,  9.38it/s]



[Epoch  4]
Training - Log Prob: 4.7273
Validation - Log Prob: 4.4464
Current learning rates: [0.0005]


Epoch 5, training: 100%|██████████| 32/32 [00:17<00:00,  1.86it/s]
Epoch 5, validation: 100%|██████████| 4/4 [00:00<00:00,  9.56it/s]



[Epoch  5]
Training - Log Prob: 4.7836
Validation - Log Prob: 5.4061
Current learning rates: [0.0005]
New best validation performance 

Model checkpoint saved to Dingo_sin.pt



Epoch 6, training: 100%|██████████| 32/32 [00:17<00:00,  1.87it/s]
Epoch 6, validation: 100%|██████████| 4/4 [00:00<00:00,  9.27it/s]



[Epoch  6]
Training - Log Prob: 5.1034
Validation - Log Prob: 5.9813
Current learning rates: [0.0005]
New best validation performance 

Model checkpoint saved to Dingo_sin.pt



Epoch 7, training: 100%|██████████| 32/32 [00:17<00:00,  1.85it/s]
Epoch 7, validation: 100%|██████████| 4/4 [00:00<00:00,  9.79it/s]



[Epoch  7]
Training - Log Prob: 5.0228
Validation - Log Prob: 5.7739
Current learning rates: [0.0005]


Epoch 8, training: 100%|██████████| 32/32 [00:17<00:00,  1.86it/s]
Epoch 8, validation: 100%|██████████| 4/4 [00:00<00:00,  9.96it/s]



[Epoch  8]
Training - Log Prob: 5.5380
Validation - Log Prob: 5.6534
Current learning rates: [0.0005]


Epoch 9, training: 100%|██████████| 32/32 [00:17<00:00,  1.87it/s]
Epoch 9, validation: 100%|██████████| 4/4 [00:00<00:00, 10.64it/s]



[Epoch  9]
Training - Log Prob: 5.9277
Validation - Log Prob: 6.7346
Current learning rates: [0.0005]
New best validation performance 

Model checkpoint saved to Dingo_sin.pt



Epoch 10, training: 100%|██████████| 32/32 [00:17<00:00,  1.86it/s]
Epoch 10, validation: 100%|██████████| 4/4 [00:00<00:00, 10.12it/s]



[Epoch 10]
Training - Log Prob: 5.3087
Validation - Log Prob: 5.8160
Current learning rates: [0.0005]


Epoch 11, training: 100%|██████████| 32/32 [00:17<00:00,  1.87it/s]
Epoch 11, validation: 100%|██████████| 4/4 [00:00<00:00,  8.93it/s]



[Epoch 11]
Training - Log Prob: 5.5220
Validation - Log Prob: 4.9522
Current learning rates: [0.0005]


Epoch 12, training: 100%|██████████| 32/32 [00:17<00:00,  1.87it/s]
Epoch 12, validation: 100%|██████████| 4/4 [00:00<00:00, 10.48it/s]



[Epoch 12]
Training - Log Prob: 5.6822
Validation - Log Prob: 5.6409
Current learning rates: [0.0005]


Epoch 13, training:   9%|▉         | 3/32 [00:01<00:16,  1.76it/s]


KeyboardInterrupt: 