In [43]:
import pandas as pd
from utils import FusionDataset
import torch 
from models import Posterior, Prior, Decoder, Forward
from torch.utils.data import DataLoader, Subset
from train import * 


  from .autonotebook import tqdm as notebook_tqdm


In [31]:
# Load the B-field table (presumably the magnetic-field profile used as `B` in the loss);
# the first CSV column becomes the index. NOTE(review): not used by any visible cell.
b_field = pd.read_csv("b-field.csv", index_col=0)


In [32]:
import torch

def custom_loss_fn(pred_n, pred_v_parallel, true_n, true_v_parallel, B, x, t):
    """
    Derivative-matching loss between predicted and true density dynamics.

    The result is the sum of two mean-squared terms:
      * temporal: (d pred_n/dt - d true_n/dt) ** 2
      * spatial:  B**2 * (d(pred_n*pred_v_parallel/B)/dx - d(true_n*true_v_parallel/B)/dx) ** 2

    Derivatives come from ``torch.autograd.grad`` with ``create_graph=True``, so the
    returned loss is itself differentiable. Every differentiated quantity, including
    ``true_n`` and ``true_n * true_v_parallel / B``, must be connected to ``t`` / ``x``
    in the autograd graph; otherwise autograd raises.

    Args:
        pred_n: Tensor, predicted plasma density.
        pred_v_parallel: Tensor, predicted parallel ion velocity.
        true_n: Tensor, target plasma density.
        true_v_parallel: Tensor, target parallel ion velocity.
        B: Tensor or scalar, magnetic field value.
        x: Tensor, spatial coordinate (must require grad).
        t: Tensor, temporal coordinate (must require grad).

    Returns:
        Scalar tensor: mean temporal term + mean spatial term.
    """
    def _grad_of_sum(quantity, wrt):
        # d(sum(quantity))/d(wrt); the graph is kept so the loss can be back-propagated.
        return torch.autograd.grad(
            quantity, wrt, grad_outputs=torch.ones_like(quantity), create_graph=True
        )[0]

    dn_dt_pred = _grad_of_sum(pred_n, t)
    dn_dt_true = _grad_of_sum(true_n, t)

    flux_pred = pred_n * pred_v_parallel / B
    flux_true = true_n * true_v_parallel / B
    dflux_dx_pred = _grad_of_sum(flux_pred, x)
    dflux_dx_true = _grad_of_sum(flux_true, x)

    temporal_term = (dn_dt_pred - dn_dt_true) ** 2
    spatial_term = B ** 2 * (dflux_dx_pred - dflux_dx_true) ** 2
    return temporal_term.mean() + spatial_term.mean()


In [35]:
# Root folder holding one sub-directory per split.
# NOTE(review): ".data" is a hidden directory; confirm it isn't meant to be "./data".
DATA_ROOT = ".data"
train_data = FusionDataset(f"{DATA_ROOT}/train")
val_data = FusionDataset(f"{DATA_ROOT}/val")
test_data = FusionDataset(f"{DATA_ROOT}/test")

In [44]:
# Keep only the first SUBSET_SIZE samples of every split for quick iteration.
# Subset does not validate the indices here; a split shorter than this fails on access.
SUBSET_SIZE = 100
train_subset, val_subset, test_subset = (
    Subset(split, list(range(SUBSET_SIZE)))
    for split in (train_data, val_data, test_data)
)

In [45]:
# Batch the three subsets identically, preserving sample order.
# NOTE(review): the training loader is not shuffled; consider shuffle=True for SGD.
BATCH_SIZE = 16
train_loader, val_loader, test_loader = (
    DataLoader(subset, batch_size=BATCH_SIZE)
    for subset in (train_subset, val_subset, test_subset)
)

In [46]:
def train_step(x_t, x_t_plus1, forward_t, forward_tplus1, prior, posterior, decoder, opt, device=None):
    """
    Run one optimisation step of the latent transition model and return its losses.

    Both time steps are encoded, the posterior produces ``z``, and the minimised loss is
    ``mean(-rec_ll + kl_nll)``. Here ``kl_nll`` sums
    ``posterior.log_prob - prior.log_prob`` over dims (1, 2, 3), and ``rec_ll`` is
    ``decoder.log_prob`` of ``x_t_plus1``.

    Args:
        x_t, x_t_plus1: input batches at time t and t+1.
        forward_t, forward_tplus1: encoders applied to ``x_t`` / ``x_t_plus1``.
        prior, posterior, decoder: latent-model components exposing ``log_prob``.
        opt: a ``torch.optim.Optimizer`` built over the models' parameters, or a
            Keras optimizer exposing ``apply_gradients``.
        device: device the input batches are moved to; defaults to "cuda" when
            available, otherwise "cpu".

    Returns:
        (mean KL term, mean negative reconstruction log-likelihood) as Python floats,
        computed before the parameter update.
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    x_t = x_t.to(device)
    x_t_plus1 = x_t_plus1.to(device)

    # Forward pass. dim=(1, 2, 3) presumes 4-D log-prob tensors (batch first) — TODO confirm.
    h_t = forward_t(x_t)
    h_tplus1 = forward_tplus1(x_t_plus1)
    z, mu, log_var = posterior(h_t, h_tplus1)
    kl_nll = torch.sum(posterior.log_prob(z, mu, log_var) - prior.log_prob(h_t, z), dim=(1, 2, 3))
    rec_ll = decoder.log_prob(x_t_plus1, decoder(z, h_t))
    loss = torch.mean(-rec_ll + kl_nll)

    if isinstance(opt, torch.optim.Optimizer):
        # Native PyTorch optimizer: it owns the parameters, so it clears and updates them.
        opt.zero_grad()
        loss.backward()
        opt.step()
    else:
        # Keras optimizer (torch backend): it has no zero_grad()/step(), so gradients are
        # read off the models' trainable variables and applied explicitly.
        trainable_weights = (
            forward_t.trainable_weights + forward_tplus1.trainable_weights
            + prior.trainable_weights + posterior.trainable_weights + decoder.trainable_weights
        )
        for weight in trainable_weights:
            weight.value.grad = None  # drop gradients left over from the previous step
        loss.backward()
        gradients = [weight.value.grad for weight in trainable_weights]
        with torch.no_grad():
            opt.apply_gradients(zip(gradients, trainable_weights))

    return torch.mean(kl_nll).item(), torch.mean(-rec_ll).item()


In [47]:
def evaluate(dataloader, forward_t, forward_tplus1, prior, posterior, decoder, device=None):
    """
    Average the KL and reconstruction losses of the model over a dataloader.

    All five components are switched to eval mode (and left there), and the loop runs
    under ``torch.no_grad()``. The losses are computed exactly as in ``train_step``.

    Args:
        dataloader: iterable of (x_t, x_tplus1) batches that supports ``len()``.
        forward_t, forward_tplus1, prior, posterior, decoder: model components.
        device: device the input batches are moved to; defaults to "cuda" when
            available, otherwise "cpu".

    Returns:
        (avg_kl_loss, avg_rec_loss): per-batch means averaged over batches, with each
        batch weighted equally regardless of its size.

    Raises:
        ZeroDivisionError: if the dataloader is empty.
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    for model in (forward_t, forward_tplus1, prior, posterior, decoder):
        model.eval()

    kl_loss_total, rec_loss_total = 0.0, 0.0
    num_batches = len(dataloader)

    with torch.no_grad():
        for x_t, x_tplus1 in dataloader:
            # Both batches must be on the target device before entering the encoders.
            x_t = x_t.to(device)
            x_tplus1 = x_tplus1.to(device)

            h_t = forward_t(x_t)
            h_tplus1 = forward_tplus1(x_tplus1)
            z, mu, log_var = posterior(h_t, h_tplus1)
            kl_nll = torch.sum(posterior.log_prob(z, mu, log_var) - prior.log_prob(h_t, z), dim=(1, 2, 3))
            rec_ll = decoder.log_prob(x_tplus1, decoder(z, h_t))

            kl_loss_total += torch.mean(kl_nll).item()
            rec_loss_total += torch.mean(-rec_ll).item()

    # Average over batches (not samples).
    return kl_loss_total / num_batches, rec_loss_total / num_batches
    

In [48]:
def run(train_dataloader, val_dataloader, test_dataloader, forward_t, forward_tplus1, prior, posterior, decoder, optimizer, n_epochs=100, out_dir="./results/basic0"):
    """
    Train the model for ``n_epochs``, checkpointing every epoch, then evaluate on the test set.

    Each epoch does the following:
      * puts all five components in train mode;
      * runs ``train_step`` on every training batch;
      * runs ``evaluate`` on ``val_dataloader``;
      * writes the five models (``<name>.keras``) and the loss history
        (``history.json``) to ``out_dir``.
    On the very first training batch, the first layer of each encoder is ``adapt``-ed
    to that batch.

    Args:
        train_dataloader, val_dataloader, test_dataloader: iterables of (x_t, x_tplus1) batches.
        forward_t, forward_tplus1, prior, posterior, decoder: model components, saved via ``.save``.
        optimizer: passed unchanged to ``train_step``.
        n_epochs: number of training epochs.
        out_dir: directory for checkpoints and ``history.json``; created if missing.

    Returns:
        The history dict. ``"train_loss"`` and ``"val_loss"`` hold per-epoch lists under
        ``"kl_loss"`` / ``"rec_loss"``; ``"test_loss"`` holds the final test losses.

    Note:
        Relies on ``tqdm``, ``train_step`` and ``evaluate`` being in scope. ``tqdm`` is not
        imported by the visible cells — presumably it comes from ``from train import *``.
    """
    import json
    import os

    train_loss_history = {"kl_loss": [], "rec_loss": []}
    val_loss_history = {"kl_loss": [], "rec_loss": []}
    history = {"train_loss": train_loss_history, "val_loss": val_loss_history}
    # Insertion order fixes the save order and the checkpoint file names.
    models = {
        "forward_t": forward_t,
        "forward_tplus1": forward_tplus1,
        "prior": prior,
        "posterior": posterior,
        "decoder": decoder,
    }

    os.makedirs(out_dir, exist_ok=True)
    history_path = os.path.join(out_dir, "history.json")

    for epoch in tqdm(range(n_epochs)):
        # Training phase
        for model in models.values():
            model.train()

        for batch_idx, (x_t, x_tplus1) in enumerate(train_dataloader):
            if epoch == 0 and batch_idx == 0:
                # One-off fit of each encoder's first layer (presumably a normalisation
                # layer) to the first training batch.
                forward_t.layers[0].adapt(x_t)
                forward_tplus1.layers[0].adapt(x_tplus1)

            kl_loss, rec_loss = train_step(x_t, x_tplus1, forward_t, forward_tplus1, prior, posterior, decoder, optimizer)

            # Incremental mean over the batches seen so far in this epoch.
            if batch_idx == 0:
                train_loss_history["kl_loss"].append(kl_loss)
                train_loss_history["rec_loss"].append(rec_loss)
            else:
                n_seen = batch_idx + 1
                for key, value in (("kl_loss", kl_loss), ("rec_loss", rec_loss)):
                    series = train_loss_history[key]
                    series[-1] = series[-1] * (1 - 1 / n_seen) + value / n_seen

        # Validation phase
        val_kl_loss, val_rec_loss = evaluate(val_dataloader, forward_t, forward_tplus1, prior, posterior, decoder)
        val_loss_history["kl_loss"].append(val_kl_loss)
        val_loss_history["rec_loss"].append(val_rec_loss)

        # Checkpoint models and history after every epoch.
        for name, model in models.items():
            model.save(os.path.join(out_dir, f"{name}.keras"))
        with open(history_path, "w") as f:
            json.dump(history, f)

    # Test phase (after training completes)
    test_kl_loss, test_rec_loss = evaluate(test_dataloader, forward_t, forward_tplus1, prior, posterior, decoder)
    print(f"Test KL Loss: {test_kl_loss:.4f}, Test Rec Loss: {test_rec_loss:.4f}")

    history["test_loss"] = {"kl_loss": test_kl_loss, "rec_loss": test_rec_loss}
    with open(history_path, "w") as f:
        json.dump(history, f)
    return history


In [50]:
# Build each component exactly once and share it. An optimizer built from separate
# Forward()/Prior()/... instances would update parameters of models that run() never uses.
forward_t, forward_tplus1 = Forward(), Forward()
prior, posterior, decoder = Prior(), Posterior(), Decoder()
model_parts = (forward_t, forward_tplus1, prior, posterior, decoder)

LEARNING_RATE = 1e-4
# NOTE(review): if these models create their weights lazily (on first call),
# parameters() may be incomplete here — confirm they are fully built in __init__.
optimizer = torch.optim.Adam(
    [param for part in model_parts for param in part.parameters()],
    lr=LEARNING_RATE,
)
training_history = run(
    train_dataloader=train_loader,
    val_dataloader=val_loader,
    test_dataloader=test_loader,
    forward_t=forward_t,
    forward_tplus1=forward_tplus1,
    prior=prior,
    posterior=posterior,
    decoder=decoder,
    optimizer=optimizer,
)

  0%|          | 0/100 [00:00<?, ?it/s]


NotImplementedError: Subclasses of Dataset should implement __getitem__.