In [3]:
import torch
import torch.nn as nn
import numpy as np 
import os 
import torch.nn.functional as F
from torch.utils.data import Dataset
from torch_geometric.data import Data, DataLoader, InMemoryDataset
from dataset import SimulationDataset
import os
import re
from torch.utils.data import Dataset
import numpy as np


In [4]:
%load_ext autoreload
%autoreload 2

In [5]:
def separate_data_by_file(folders, base_dir):
    # Create dictionaries to store output and forcing data per file
    images = {}
    idx = 0 
    for folder in folders:
        folder_path = os.path.join(base_dir, folder)
        for file in os.listdir(folder_path):
            if file.endswith('.npz'):
                file_path = os.path.join(folder_path, file)
                data = np.load(file_path)
                images[idx] = data
                idx += 1 
                
    return images

In [6]:
# Example usage
base_dir = "/Users/famkevanree/Library/Mobile Documents/com~apple~CloudDocs/Master TUe/Y2/Q1/2AMM40 Adv. Topics in AI/nuclear-fusion/data/preprocessed"  # Specify your base directory
folders = ['50T_ramp_up', '50T_ramp_down']
images = separate_data_by_file(folders, base_dir)


In [7]:
import torch
from torch.utils.data import Dataset
import numpy as np

class TimeSeriesPerInstanceDataset(Dataset):
    def __init__(self, images, normalize=True):
        """
        Args:
            images (dict): A dictionary where each key is an instance and the value is a dictionary with 'output' and 'forcing'.
                           Each entry contains multiple time steps of data.
            normalize (bool): Whether to apply normalization to the inputs and targets per variable.
        """
        self.images = images
        self.normalize = normalize

        # Precompute all (instance_idx, time_idx) pairs to ensure each item in the dataset
        self.index_pairs = []
        for key, data in self.images.items():
            timesteps_in_instance = data['output'].shape[0] - 1  # -1 because we need t and t+1
            for time_idx in range(timesteps_in_instance):
                self.index_pairs.append((key, time_idx))

        # Compute the min and max for each channel if normalization is enabled
        if self.normalize:
            self.min_max_values = self.compute_channel_min_max()

    def __len__(self):
        return len(self.index_pairs)

    def compute_channel_min_max(self):
        """ Compute the min and max values for each channel (per variable) across all instances. """
        # Initialize min and max arrays for both input and target channels
        input_min, input_max = np.full(8, float('inf')), np.full(8, float('-inf'))  # 8 input channels (6 outputs, 2 forcing)
        target_min, target_max = np.full(6, float('inf')), np.full(6, float('-inf'))  # 6 target channels

        for _, data in self.images.items():
            output = data['output']  # shape (timesteps, 500, 6)
            forcing = data['forcing']  # shape (timesteps, 500, 2)

            # For each channel in 'output' and 'forcing', update min and max
            for i in range(6):  # output has 6 channels
                channel_data = output[:, :, i]
                input_min[i] = min(input_min[i], np.min(channel_data))
                input_max[i] = max(input_max[i], np.max(channel_data))
                target_min[i] = min(target_min[i], np.min(channel_data))
                target_max[i] = max(target_max[i], np.max(channel_data))

            for i in range(2):  # forcing has 2 channels
                channel_data = forcing[:, :, i]
                input_min[6 + i] = min(input_min[6 + i], np.min(channel_data))
                input_max[6 + i] = max(input_max[6 + i], np.max(channel_data))

        return {
            "input_min": input_min, "input_max": input_max,
            "target_min": target_min, "target_max": target_max
        }

    def normalize_per_channel(self, data, min_vals, max_vals):
        """ Normalize each channel separately using Min-Max normalization. """
        for i in range(data.shape[1]):  # Loop through each channel
            data[:, i] = (data[:, i] - min_vals[i]) / (max_vals[i] - min_vals[i] + 1e-6)  # Normalize per channel
        return data

    def __getitem__(self, idx):
        """ Returns normalized input and target. """
        instance_idx, time_idx = self.index_pairs[idx]
        current_data = self.images[instance_idx]

        # Extract input (output_t, forcing_{t+1}) and target (output_{t+1})
        output_t = current_data['output'][time_idx]  # shape (500, 6)
        forcing_t_plus_1 = current_data['forcing'][time_idx + 1]  # shape (500, 2)
        target_t_plus_1 = current_data['output'][time_idx + 1]  # shape (500, 6)

        # Concatenate output_t and forcing_t_plus_1 to form the input
        input_t = np.concatenate((output_t, forcing_t_plus_1), axis=-1)  # shape (500, 8)

        # Normalize input and target per channel if normalization is enabled
        if self.normalize:
            input_t = self.normalize_per_channel(
                input_t,
                self.min_max_values["input_min"],
                self.min_max_values["input_max"]
            )
            target_t_plus_1 = self.normalize_per_channel(
                target_t_plus_1,
                self.min_max_values["target_min"],
                self.min_max_values["target_max"]
            )

        # Convert to torch tensors
        input_t = torch.tensor(input_t, dtype=torch.float32).permute(1, 0)  # (8, 500)
        target_t_plus_1 = torch.tensor(target_t_plus_1, dtype=torch.float32).permute(1, 0)  # (6, 500)

        return input_t, target_t_plus_1


In [8]:
data = TimeSeriesPerInstanceDataset(images)

In [9]:
dataloader = DataLoader(data, batch_size=8)



In [75]:
for batch_idx, (inputs, targets) in enumerate(dataloader):
    # Inputs: Shape (batch_size, time_steps-1, coordinates, 8)
    # Targets: Shape (batch_size, time_steps-1, coordinates, 6)
    print(f"Batch {batch_idx}: Inputs shape = {inputs.shape}, Targets shape = {targets.shape}")
    break

Batch 0: Inputs shape = torch.Size([8, 8, 500]), Targets shape = torch.Size([8, 6, 500])


In [9]:
# from models import Forward, Prior, Posterior, Decoder
os.environ["KERAS_BACKEND"] = "torch"
import keras, math

In [13]:
class CNNEncoderDecoder(nn.Module):
    def __init__(self):
        super(CNNEncoderDecoder, self).__init__()

        # Increase the model capacity (more layers, more filters)
        self.encoder = nn.Sequential(
            nn.Conv1d(8, 128, kernel_size=3, padding=1),  # Increased from 64 to 128
            nn.ReLU(),
            nn.Conv1d(128, 256, kernel_size=3, padding=1),  # Increased from 128 to 256
            nn.ReLU(),
            nn.Conv1d(256, 512, kernel_size=3, padding=1),  # Increased from 256 to 512
            nn.ReLU(),
        )

        self.decoder = nn.Sequential(
            nn.ConvTranspose1d(512, 256, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.ConvTranspose1d(256, 128, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.ConvTranspose1d(128, 6, kernel_size=3, padding=1),
        )

    def forward(self, x):
        x = self.encoder(x)
        x = self.decoder(x)
        return x


In [14]:
num_epochs = 10  
model = CNNEncoderDecoder()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, weight_decay=0)
criterion = nn.MSELoss()


# Track the training loss over time
for epoch in range(num_epochs):
    model.train()  # Set the model to training mode
    running_loss = 0.0

    # Iterate over the DataLoader (get mini-batches)
    for inputs, targets in dataloader:
        # Forward pass
        outputs = model(inputs)  # Predictions
        loss = criterion(outputs, targets)  # Calculate loss
        
        # Backward pass and optimization
        optimizer.zero_grad()  # Clear previous gradients
        loss.backward()  # Backpropagate the loss
        optimizer.step()  # Update the weights

        # Accumulate loss for this epoch
        running_loss += loss.item()

    # Average loss per batch for this epoch
    epoch_loss = running_loss / len(dataloader)

    # Print loss after each epoch
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss:.4f}")

Epoch [1/1000], Loss: 0.0039
Epoch [2/1000], Loss: 0.0002
Epoch [3/1000], Loss: 0.0001
Epoch [4/1000], Loss: 0.0001
Epoch [5/1000], Loss: 0.0001


KeyboardInterrupt: 

In [66]:
# Before loss.backward() inside your training loop, print output stats
print("Model output min:", outputs.min().item())
print("Model output max:", outputs.max().item())
print("Target min:", targets.min().item())
print("Target max:", targets.max().item())


Model output min: 1.0
Model output max: 1.0
Target min: 0.0
Target max: 2.3947370289835698e+20
