In [1]:
"""
Train hierarchical RNN (minLSTM/minGRU) autoencoder on light curve time series.

Uses parallelizable RNN variants from "Were RNNs All We Needed?" (arXiv:2410.01201)
Supports training with block masking and time-aware positional encoding.
"""

import numpy as np
import argparse
import sys
from pathlib import Path
from functools import partial
from tqdm import tqdm
import h5py
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
import matplotlib.pyplot as plt
import pickle


# Add project root to path
sys.path.append('/Users/philvanlane/Documents/lc_ae/')
from src.lcgen.models.rnn import HierarchicalRNN, RNNConfig
from src.lcgen.data.masking import dynamic_block_mask


In [5]:
with open("data/real_lightcurves/pk_star_sector_lc_cf.pickle", "rb") as file:
    lc = pickle.load(file)

In [62]:
# Use sample light curve and sector
sample = lc['405461319_42']

# Pre-process time
time = sample['time']
sector_time = (time - time[0])

# Pre-process flux
flux = np.array(sample['flux'].data)
mask = (np.isfinite(flux) & np.isfinite(time))
flux = flux[mask]
sector_time = sector_time[mask]
flux_norm = (flux - np.nanmean(flux)) / np.nanstd(flux)

# Pre-process flux error
flux_err = np.array(sample['flux_err'].data)
med_flux_error = np.nanmedian(flux_err)
flux_err = flux_err[mask]
flux_err = np.nan_to_num(flux_err, nan=med_flux_error, posinf=med_flux_error, neginf=med_flux_error)


# Metadata
metadata = {
    'tic': sample['TIC_ID'],
    'sector': sample['sector'],
    'duration': sector_time[-1] - sector_time[0],
    'med_flux_error': med_flux_error,
    'n_points': len(flux),
    'mean_flux': np.mean(flux),
    'std_flux': np.std(flux),
}

In [68]:
len(lc.keys())

7911

In [None]:
times = []
fluxes = []
flux_errs = []
metadatas = []

for i,k in enumerate(lc.keys()):
    if i % 100 == 0:
        print(f"Processing light curve {i}")
    # Use sample light curve and sector
    sample = lc[k]

    # Pre-process time
    time = sample['time']
    sector_time = (time - time[0])

    # Pre-process flux
    flux = np.array(sample['flux'].data)
    mask = (np.isfinite(flux) & np.isfinite(time))
    flux = flux[mask]
    sector_time = sector_time[mask]
    flux_norm = (flux - np.nanmean(flux)) / np.nanstd(flux)

    # Pre-process flux error
    flux_err = np.array(sample['flux_err'].data)
    med_flux_error = np.nanmedian(flux_err)
    flux_err = flux_err[mask]
    flux_err = np.nan_to_num(flux_err, nan=med_flux_error, posinf=med_flux_error, neginf=med_flux_error)


    # Metadata
    metadata = {
        'tic': sample['TIC_ID'],
        'sector': sample['sector'],
        'duration': sector_time[-1] - sector_time[0],
        'med_flux_error': med_flux_error,
        'n_points': len(flux),
        'mean_flux': np.nanmean(flux),
        'std_flux': np.nanstd(flux),
    }
    times.append(sector_time)
    fluxes.append(flux_norm)
    flux_errs.append(flux_err)
    metadatas.append(metadata)

Processing light curve 0
Processing light curve 100
Processing light curve 200
Processing light curve 300
Processing light curve 400
Processing light curve 500
Processing light curve 600
Processing light curve 700
Processing light curve 800
Processing light curve 900
Processing light curve 1000
Processing light curve 1100
Processing light curve 1200
Processing light curve 1300
Processing light curve 1400
Processing light curve 1500
Processing light curve 1600
Processing light curve 1700
Processing light curve 1800
Processing light curve 1900
Processing light curve 2000
Processing light curve 2100
Processing light curve 2200
Processing light curve 2300
Processing light curve 2400
Processing light curve 2500
Processing light curve 2600
Processing light curve 2700
Processing light curve 2800
Processing light curve 2900
Processing light curve 3000
Processing light curve 3100
Processing light curve 3200
Processing light curve 3300
Processing light curve 3400
Processing light curve 3500
Proc

In [72]:
dict = {
    'times': times,
    'fluxes': fluxes,
    'flux_errs': flux_errs,
    'metadatas': metadatas,
}

In [79]:
len(dict['fluxes'][19])

15150

In [73]:
with open("data/real_lightcurves/star_sector_lc_formatted.pickle", "wb") as file:
    pickle.dump(dict, file)

In [66]:
flux_err.shape

(13081,)

In [48]:
sample.keys()



dict_keys(['TIC_ID', 'sector', 'time', 'time_adj', 'flux', 'flux_err', 'flux_norm', 'flux_mean', 'asinh_mean', 'norm_asinh_mean', 'time_norm', 'flux_norm_standard', 'flux_err_norm_standard', 'flux_norm_absTmag', 'flux_err_norm_absTmag'])

In [None]:
metadata = {
    'tic': sample['TIC_ID'],
    'sector': sample['sector'],

405461319

In [12]:
lc['405461319_42'].keys()

dict_keys(['TIC_ID', 'sector', 'time', 'time_adj', 'flux', 'flux_err', 'flux_norm', 'flux_mean', 'asinh_mean', 'norm_asinh_mean', 'time_norm', 'flux_norm_standard', 'flux_err_norm_standard', 'flux_norm_absTmag', 'flux_err_norm_absTmag'])

In [2]:
with open('data/mock_lightcurves/mock_lightcurves.pkl', 'rb') as f:
    mock_lc = pickle.load(f)

In [67]:
mock_lc.keys()

dict_keys(['times', 'fluxes', 'flux_errs', 'metadatas'])

In [55]:
mock_lc['times'][0]

array([0.00000000e+00, 3.25398265e-01, 6.50796530e-01, 9.76194795e-01,
       1.30159306e+00, 1.62699132e+00, 1.95238959e+00, 2.60318612e+00,
       3.25398265e+00, 3.57938091e+00, 3.90477918e+00, 4.23017744e+00,
       4.55557571e+00, 4.88097397e+00, 5.20637224e+00, 5.85716877e+00,
       6.18256703e+00, 6.50796530e+00, 6.83336356e+00, 7.48416009e+00,
       8.13495662e+00, 8.46035489e+00, 8.78575315e+00, 9.43654968e+00,
       9.76194795e+00, 1.00873462e+01, 1.04127445e+01, 1.07381427e+01,
       1.10635410e+01, 1.13889393e+01, 1.17143375e+01, 1.20397358e+01,
       1.23651341e+01, 1.26905323e+01, 1.30159306e+01, 1.33413289e+01,
       1.36667271e+01, 1.39921254e+01, 1.43175237e+01, 1.46429219e+01,
       1.49683202e+01, 1.52937185e+01, 1.59445150e+01, 1.62699132e+01,
       1.65953115e+01, 1.69207098e+01, 1.75715063e+01, 1.78969046e+01,
       1.82223028e+01, 1.85477011e+01, 1.91984976e+01, 1.95238959e+01,
       1.98492942e+01, 2.01746924e+01, 2.05000907e+01, 2.08254890e+01,
      

In [19]:
with open('data/real_lightcurves/star_day_timeseries.pickle', 'rb') as f:
    lc = pickle.load(f)

In [30]:
lc['384984325_2902'].keys()

dict_keys(['tic', 'day', 'time_d_padded', 'flux_d_padded', 'flux_err_d_padded', 'norm_mean', 'norm_std'])

In [17]:
lc['metadatas'][0]

{'lc_type': 'multiperiodic',
 'sampling_strategy': 'astronomical',
 'duration': 379.3972738151323,
 'n_points': 471,
 'snr': 17.302907326988453,
 'offset': 10.40237733344247,
 'noise_std': 0.07874219299524557,
 'periods': array([79.44703621, 53.82900147, 38.08426937]),
 'amplitudes': array([1.33483645, 1.04650404, 0.77585134]),
 'phases': array([5.96601403, 0.92408993, 5.82192175])}

In [11]:



class LightCurveDataset(Dataset):
    """Dataset for light curve time series with timestamps."""

    def __init__(self, flux, timestamps):
        """
        Args:
            flux: Numpy array of shape (N, seq_len)
            timestamps: Numpy array of shape (N, seq_len)
        """
        self.flux = torch.from_numpy(flux).float()
        self.timestamps = torch.from_numpy(timestamps).float()

    def __len__(self):
        return len(self.flux)

    def __getitem__(self, idx):
        # Return flux and timestamps - masking will be applied in collate_fn
        return {
            'flux': self.flux[idx],
            'time': self.timestamps[idx]
        }


def collate_with_masking(batch, min_block_size=1, max_block_size=None,
                         min_mask_ratio=0.1, max_mask_ratio=0.9):
    """
    Custom collate function that applies masking to light curves.

    Creates input as [masked_flux, mask_indicator] for the transformer.
    """
    # Extract flux and timestamps
    batch_flux = torch.stack([item['flux'] for item in batch], dim=0)  # [batch_size, seq_len]
    batch_time = torch.stack([item['time'] for item in batch], dim=0)  # [batch_size, seq_len]

    batch_size, seq_len = batch_flux.shape

    if max_block_size is None:
        max_block_size = seq_len // 2

    # Apply masking per sample
    batch_inputs = []
    batch_targets = []
    batch_masks = []
    batch_block_sizes = []
    batch_mask_ratios = []

    for i in range(batch_size):
        flux = batch_flux[i]

        # Apply dynamic block masking to flux
        flux_masked, mask, block_size, mask_ratio = dynamic_block_mask(
            flux,
            min_block_size=min_block_size,
            max_block_size=max_block_size,
            min_mask_ratio=min_mask_ratio,
            max_mask_ratio=max_mask_ratio
        )

        # Stack flux and mask as two channels: [masked_flux, mask_indicator]
        # Shape: (seq_len, 2)
        input_with_mask = torch.stack([flux_masked, mask.float()], dim=-1)

        batch_inputs.append(input_with_mask)
        batch_targets.append(flux)
        batch_masks.append(mask)
        batch_block_sizes.append(block_size)
        batch_mask_ratios.append(mask_ratio)

    return {
        'input': torch.stack(batch_inputs),  # (batch, seq_len, 2)
        'target': torch.stack(batch_targets),  # (batch, seq_len)
        'time': batch_time,  # (batch, seq_len)
        'mask': torch.stack(batch_masks),  # (batch, seq_len)
        'block_size': torch.tensor(batch_block_sizes),
        'mask_ratio': torch.tensor(batch_mask_ratios)
    }


def load_lightcurve_data(hdf5_file):
    """
    Load light curve data from HDF5 file.

    Expected structure:
        - 'flux': (N, seq_len) - normalized flux values
        - 'time': (N, seq_len) - timestamps

    Args:
        hdf5_file: Path to HDF5 file

    Returns:
        flux: Numpy array of shape (N, seq_len)
        timestamps: Numpy array of shape (N, seq_len)
    """
    with h5py.File(hdf5_file, 'r') as f:
        flux = f['flux'][:]
        timestamps = f['time'][:]
        print(f"Loaded flux: {flux.shape}, time: {timestamps.shape}")

        # Check for NaNs
        n_nans_flux = np.sum(np.isnan(flux))
        n_nans_time = np.sum(np.isnan(timestamps))
        if n_nans_flux > 0:
            print(f"Warning: {n_nans_flux} NaN values found in flux")
            flux = np.nan_to_num(flux, nan=0.0)
        if n_nans_time > 0:
            print(f"Warning: {n_nans_time} NaN values found in timestamps")
            timestamps = np.nan_to_num(timestamps, nan=0.0)

        return flux, timestamps


def train_epoch(model, dataloader, optimizer, criterion, device, scheduler=None):
    """Train for one epoch."""
    model.train()
    total_loss = 0
    total_masked_loss = 0
    total_unmasked_loss = 0
    n_batches = 0

    for batch in dataloader:
        x_input = batch['input'].to(device)  # (batch, seq_len, 2) [flux, mask]
        x_target = batch['target'].to(device)  # (batch, seq_len)
        timestamps = batch['time'].to(device)  # (batch, seq_len)
        mask = batch['mask'].to(device)  # (batch, seq_len)

        # Forward pass
        optimizer.zero_grad()
        output = model(x_input, timestamps)
        x_recon = output['reconstructed'].squeeze(-1)  # (batch, seq_len)

        # Compute loss over ALL regions (both masked and unmasked)
        loss = criterion(x_recon, x_target)

        # Also track masked vs unmasked separately for monitoring
        masked_loss = criterion(x_recon[mask], x_target[mask])
        unmasked_loss = criterion(x_recon[~mask], x_target[~mask])

        # Backprop and optimize on full reconstruction loss
        loss.backward()

        # Gradient clipping to prevent exploding gradients
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

        optimizer.step()

        # Step scheduler after each batch (for OneCycleLR)
        if scheduler is not None:
            scheduler.step()

        total_loss += loss.item()
        total_masked_loss += masked_loss.item()
        total_unmasked_loss += unmasked_loss.item()
        n_batches += 1

    return {
        'total_loss': total_loss / n_batches,
        'masked_loss': total_masked_loss / n_batches,
        'unmasked_loss': total_unmasked_loss / n_batches
    }


def validate_epoch(model, dataloader, criterion, device):
    """Validate for one epoch."""
    model.eval()
    total_loss = 0
    total_masked_loss = 0
    total_unmasked_loss = 0
    n_batches = 0

    with torch.no_grad():
        for batch in dataloader:
            x_input = batch['input'].to(device)
            x_target = batch['target'].to(device)
            timestamps = batch['time'].to(device)
            mask = batch['mask'].to(device)

            # Forward pass
            output = model(x_input, timestamps)
            x_recon = output['reconstructed'].squeeze(-1)

            # Compute loss
            loss = criterion(x_recon, x_target)
            masked_loss = criterion(x_recon[mask], x_target[mask])
            unmasked_loss = criterion(x_recon[~mask], x_target[~mask])

            total_loss += loss.item()
            total_masked_loss += masked_loss.item()
            total_unmasked_loss += unmasked_loss.item()
            n_batches += 1

    return {
        'total_loss': total_loss / n_batches,
        'masked_loss': total_masked_loss / n_batches,
        'unmasked_loss': total_unmasked_loss / n_batches
    }


def plot_reconstruction_examples(model, dataloader, device, save_path, n_examples=4):
    """Plot reconstruction examples."""
    model.eval()

    # Get one batch
    batch = next(iter(dataloader))
    x_input = batch['input'].to(device)[:n_examples]
    x_target = batch['target'].to(device)[:n_examples]
    timestamps = batch['time'].to(device)[:n_examples]
    mask = batch['mask'].to(device)[:n_examples]
    block_sizes = batch['block_size'][:n_examples].numpy()
    mask_ratios = batch['mask_ratio'][:n_examples].numpy()

    with torch.no_grad():
        output = model(x_input, timestamps)
        x_recon = output['reconstructed'].squeeze(-1)

    # Move to CPU for plotting
    x_target = x_target.cpu().numpy()
    x_recon = x_recon.cpu().numpy()
    timestamps_np = timestamps.cpu().numpy()
    mask = mask.cpu().numpy()

    # Create plots
    fig, axes = plt.subplots(n_examples, 1, figsize=(12, 3*n_examples))
    if n_examples == 1:
        axes = [axes]

    for i, ax in enumerate(axes):
        t = timestamps_np[i]
        target = x_target[i]
        recon = x_recon[i]
        m = mask[i]
        block_size = block_sizes[i]
        mask_ratio = mask_ratios[i]

        # Plot target and reconstruction
        ax.plot(t, target, 'k-', alpha=0.6, label='Target', linewidth=1.5)
        ax.plot(t, recon, 'b-', alpha=0.8, label='Reconstruction', linewidth=1.5)

        # Highlight masked regions
        if m.sum() > 0:
            # Find contiguous masked regions
            masked_indices = np.where(m)[0]
            ax.scatter(t[masked_indices], recon[masked_indices],
                      c='red', s=10, alpha=0.5, label='Masked regions')

        ax.legend()
        ax.set_xlabel('Time')
        ax.set_ylabel('Flux')
        ax.set_title(f'Example {i+1} (masked {mask_ratio*100:.1f}%, block size {block_size})')
        ax.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.savefig(save_path, dpi=150)
    plt.close()
    print(f"Saved reconstruction plot to {save_path}")


def main():
    parser = argparse.ArgumentParser(description='Train hierarchical RNN (minLSTM/minGRU) autoencoder on light curves')
    parser.add_argument('--input', type=str,
                        default='data/mock_lightcurves/timeseries.h5',
                        help='Path to HDF5 file with light curve time series')
    parser.add_argument('--output_dir', type=str, default='models/rnn',
                        help='Directory to save model checkpoints')
    parser.add_argument('--epochs', type=int, default=50,
                        help='Number of training epochs')
    parser.add_argument('--batch_size', type=int, default=32,
                        help='Batch size')
    parser.add_argument('--lr', type=float, default=1e-3,
                        help='Learning rate')
    parser.add_argument('--encoder_dims', type=int, nargs='+', default=[64, 128, 256, 512],
                        help='Encoder dimensions for hierarchical levels')
    parser.add_argument('--rnn_type', type=str, default='minlstm', choices=['minlstm', 'minGRU'],
                        help='RNN cell type (minlstm or minGRU)')
    parser.add_argument('--num_layers', type=int, default=4,
                        help='Number of hierarchical levels')
    parser.add_argument('--num_layers_per_level', type=int, default=2,
                        help='RNN layers per hierarchy level')
    parser.add_argument('--mask_ratio', type=float, default=0.5,
                        help='Maximum masking ratio')
    parser.add_argument('--block_size', type=int, default=32,
                        help='Maximum block size for masking')
    parser.add_argument('--val_split', type=float, default=0.15,
                        help='Validation set fraction')
    parser.add_argument('--seed', type=int, default=42,
                        help='Random seed')
    parser.add_argument('--device', type=str, default='auto',
                        help='Device (cuda/cpu/auto)')
    parser.add_argument('--save_every', type=int, default=10,
                        help='Save checkpoint every N epochs')

    args = parser.parse_args()

    # Set random seed
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)

    # Device
    if args.device == 'auto':
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    else:
        device = torch.device(args.device)
    print(f"Using device: {device}")

    # Create output directory
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # Load data
    print(f"\nLoading data from {args.input}...")
    flux, timestamps = load_lightcurve_data(args.input)

    print(f"\nDataset statistics:")
    print(f"  Shape: {flux.shape}")
    print(f"  Flux range: [{np.min(flux):.3f}, {np.max(flux):.3f}]")
    print(f"  Flux mean: {np.mean(flux):.3f}")
    print(f"  Flux std: {np.std(flux):.3f}")
    print(f"  Time range: [{np.min(timestamps):.3f}, {np.max(timestamps):.3f}]")

    # Create dataset
    dataset = LightCurveDataset(flux, timestamps)

    # Split into train/val
    val_size = int(len(dataset) * args.val_split)
    train_size = len(dataset) - val_size
    train_dataset, val_dataset = random_split(
        dataset,
        [train_size, val_size],
        generator=torch.Generator().manual_seed(args.seed)
    )

    print(f"\nDataset split:")
    print(f"  Train: {train_size} samples")
    print(f"  Val: {val_size} samples")

    # Create collate function with masking parameters
    collate_fn = partial(
        collate_with_masking,
        min_block_size=1,
        max_block_size=args.block_size,
        min_mask_ratio=0.1,
        max_mask_ratio=args.mask_ratio
    )

    # Create data loaders
    train_loader = DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=0,
        collate_fn=collate_fn,
        pin_memory=True if device.type == 'cuda' else False
    )

    val_loader = DataLoader(
        val_dataset,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=0,
        collate_fn=collate_fn,
        pin_memory=True if device.type == 'cuda' else False
    )

    # Create model
    seq_len = flux.shape[1]
    print(f"\nCreating Hierarchical RNN ({args.rnn_type}) autoencoder...")
    config = RNNConfig(
        input_dim=2,  # [flux, mask]
        input_length=seq_len,
        encoder_dims=args.encoder_dims,
        rnn_type=args.rnn_type,
        num_layers_per_level=args.num_layers_per_level,
        dropout=0.0,  # No dropout
        min_period=0.00278,
        max_period=1640.0
    )

    model = HierarchicalRNN(config).to(device)

    # Count parameters
    n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Model parameters: {n_params:,}")

    # Optimizer and scheduler
    optimizer = optim.AdamW(model.parameters(), lr=args.lr, weight_decay=0.01)

    # OneCycleLR scheduler
    total_steps = len(train_loader) * args.epochs
    scheduler = optim.lr_scheduler.OneCycleLR(
        optimizer,
        max_lr=args.lr,
        total_steps=total_steps,
        pct_start=0.1,
        div_factor=1e2,
        final_div_factor=1e2
    )

    # Loss function
    criterion = nn.MSELoss()

    # Training loop
    print(f"\nStarting training for {args.epochs} epochs...")
    max_block_display = args.block_size if args.block_size else "seq_len//2"
    print(f"Dynamic masking: block size [1, {max_block_display}], "
          f"mask ratio [10%, {args.mask_ratio*100:.0f}%]")

    best_val_loss = float('inf')
    train_losses = []
    val_losses = []

    for epoch in range(args.epochs):
        # Train
        train_metrics = train_epoch(model, train_loader, optimizer, criterion, device, scheduler)
        train_losses.append(train_metrics['total_loss'])

        # Validate
        val_metrics = validate_epoch(model, val_loader, criterion, device)
        val_losses.append(val_metrics['total_loss'])

        # Print progress (match MLP format)
        print(f"Epoch {epoch+1}/{args.epochs}")
        print(f"  Train - Loss: {train_metrics['total_loss']:.6f} | "
              f"Masked: {train_metrics['masked_loss']:.6f} | "
              f"Unmasked: {train_metrics['unmasked_loss']:.6f}")
        print(f"  Val   - Loss: {val_metrics['total_loss']:.6f} | "
              f"Masked: {val_metrics['masked_loss']:.6f} | "
              f"Unmasked: {val_metrics['unmasked_loss']:.6f}")

        # Save best model
        if val_metrics['total_loss'] < best_val_loss:
            best_val_loss = val_metrics['total_loss']
            checkpoint_path = output_dir / 'rnn_best.pt'
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'train_loss': train_metrics['total_loss'],
                'val_loss': best_val_loss,
                'config': config
            }, checkpoint_path)
            print(f"  Saved best model to {checkpoint_path}")

        # Plot samples every 10 epochs
        if (epoch + 1) % 10 == 0 or epoch == 0:
            plot_path = output_dir / f'rnn_recon_epoch{epoch+1}.png'
            plot_reconstruction_examples(
                model, val_loader, device, plot_path
            )

        # Save periodic checkpoints
        if (epoch + 1) % args.save_every == 0:
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'val_loss': val_metrics['total_loss'],
                'config': config
            }, output_dir / f'rnn_checkpoint_epoch{epoch+1}.pt')

    # Save final model
    final_path = output_dir / 'rnn_final.pt'
    torch.save({
        'epoch': args.epochs,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'config': config
    }, final_path)
    print(f"\nSaved final model to {final_path}")

    # Plot training curves
    plt.figure(figsize=(10, 6))
    plt.plot(train_losses, label='Train Loss')
    plt.plot(val_losses, label='Val Loss')
    plt.xlabel('Epoch')
    plt.ylabel('MSE Loss')
    plt.title(f'Hierarchical RNN ({args.rnn_type}) Training')
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.savefig(output_dir / 'rnn_training_curve.png', dpi=150)
    plt.close()

    print(f"\nTraining complete!")
    print(f"Best validation loss: {best_val_loss:.6f}")


if __name__ == '__main__':
    main()


usage: ipykernel_launcher.py [-h] [--input INPUT] [--output_dir OUTPUT_DIR]
                             [--epochs EPOCHS] [--batch_size BATCH_SIZE]
                             [--lr LR]
                             [--encoder_dims ENCODER_DIMS [ENCODER_DIMS ...]]
                             [--rnn_type {minlstm,minGRU}]
                             [--num_layers NUM_LAYERS]
                             [--num_layers_per_level NUM_LAYERS_PER_LEVEL]
                             [--mask_ratio MASK_RATIO]
                             [--block_size BLOCK_SIZE] [--val_split VAL_SPLIT]
                             [--seed SEED] [--device DEVICE]
                             [--save_every SAVE_EVERY]
ipykernel_launcher.py: error: unrecognized arguments: --f=/Users/philvanlane/Library/Jupyter/runtime/kernel-v3df2f76f03ef5fc5e9231dfcce9065f09e4b49105.json


SystemExit: 2

  warn("To exit: use 'exit', 'quit', or Ctrl-D.", stacklevel=1)
