In [None]:
%%writefile requirements.txt
numpy
notebook
torch==2.4.1
torchvision==0.19.1
torchaudio==2.4.1
lightning
# The notebook imports `pytorch_lightning`, which is published separately from `lightning`.
pytorch-lightning
# TensorBoardLogger needs tensorboard (or tensorboardX) installed.
tensorboard

In [None]:
%pip install -r requirements.txt

<img src="graphics/E2CO_Training.png" width="1000">

In [None]:
"""
Importing All necessary Packages
"""

import os
import torch
import torch.nn as nn
import pytorch_lightning as pl

from torch.utils.data import DataLoader
from torch.utils.data import Dataset

from typing import Optional, Tuple
from torch import load, unique

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers.tensorboard import TensorBoardLogger

import warnings

In [None]:
"""
Core Training Run Options
"""

args = {
    # Core Run Options
    'SEED': 12,
    'model': 'miniegg_model',
    'results_path': '\miniegg_model\data_files',
    'deterministic': False,
    'model_type': 'EGG',
    'trail_settings': 'default',
    'version_num': 0,
    'log_every_n_steps': 10,
    'enable_progress_bar': True,

    # Samples Selection
    'n_training_samples': 6000,    # 300 * 20
    'n_validation_samples': 1000,   # 50  * 20
    'n_testing_samples': 200,      # 10  * 20

    # DataLoader Settings
    'n_epochs': 100,
    'batch_size_training': 12,
    'batch_size_validation': 4,
    'shuffled_train': True,
    'shuffled_validation': True,
    'num_workers': 0,

    # Trainer Settings
    'n_checkpoints2save': 10,  # n=-1 means all checkpoints, else save top n checkpoints
    'every_n_epochs': 10,  # Set to `10` to save the model every 10 epochs
    'max_n_gpus': 1,
    'mixed_precision': False,

    # Model Information
    'nW': 12,  # Injectors = 8 | Producers = 4 
    'nObs': 16,  # Injectors+=1 | Producers+=2

    # Latent Settings
    'l_z': 50,

    # Encoder Network Settings
    'nEncLinear': 3,

    # Decoder Network Settings
    'nDecLinear': 3,

    # Transition Network Settings
    'nTrans': 200,
    'nTransfBlock': 2,

    # Transition Outputs Network Settings
    'nTransWD': 20,
    'nTransfBlockWD': 2,

    # Loss Function Scaling (Regularization)
    'well_data_loss_scale': 1,  # 
    'weight_trans_reg': 0.001,  # 0.008296379118348702,
    'weight_trans_reg1': 0.001,  # 0.005617127794698825,
    'weight_reg_ABCD': 0.001,  # 0.002038408624236886,
    'latent_loss_scale': 0.001,

    # Scheduler and Optimizer
    'use_scheduler': True,
    'scheduler_step_size': 200,
    'scheduler_gamma': 0.5,
    'adam_learning_rate': 1e-3,
    'weight_decay': 0,  # 0.001
}

In [None]:
"""
Dataset class: Read training data --> __getitem__ does mini-batching for all data sets.
"""

class ResDataset(Dataset):
    """Map-style dataset over pre-computed transition samples stored as ``.pt`` files.

    The files are read from ``<data_files_folder>/<data_type>/``: ``input_state.pt``,
    ``output_state.pt``, ``input_delta_t.pt``, ``input_controls.pt`` and
    ``output_welldata.pt``. Only the first ``n_samples`` entries along dim 0 are exposed.

    Args:
        data_type: split sub-folder name (e.g. "training", "validation", "testing").
        n_samples: number of samples to expose; ``None`` exposes every sample available.
        data_files_folder: root folder containing the split sub-folders.

    Raises:
        ValueError: if ``n_samples`` exceeds the number of samples present in any file.
    """

    # (attribute name, file name) pairs, loaded in this order.
    _FILES = (
        ("input_state", "input_state.pt"),
        ("output_state", "output_state.pt"),
        ("input_delta_t", "input_delta_t.pt"),
        ("input_controls", "input_controls.pt"),
        ("output_welldata", "output_welldata.pt"),
    )

    def __init__(self, data_type, n_samples, data_files_folder):
        folder = os.path.join(data_files_folder, data_type)
        for attr, file_name in self._FILES:
            # NOTE(review): torch.load unpickles its input (weights_only defaults to False in
            # torch 2.4) -- only load trusted files; consider weights_only=True for plain tensors.
            setattr(self, attr, load(os.path.join(folder, file_name)))

        # Fail early instead of raising IndexError in the middle of an epoch.
        n_available = min(len(getattr(self, attr)) for attr, _ in self._FILES)
        if n_samples is None:
            n_samples = n_available
        elif n_samples > n_available:
            raise ValueError(
                f"n_samples={n_samples} exceeds the {n_available} samples available in {folder}")
        self.n_samples = n_samples

    def __len__(self):
        """Returns the number of samples."""
        return self.n_samples

    def __getitem__(self, idx):
        """Return sample ``idx`` as a dict.

        The key order is relied upon by callers that unpack ``batch.values()`` positionally.
        """
        data = {
            'x': self.input_state[idx, :, :],          # x_t: (nodes, features)
            'x_tp1': self.output_state[idx, :, :],     # x_t+1: (nodes, features)
            'delta_t': self.input_delta_t[idx, :],     # dt: 1-D row
            'u': self.input_controls[idx, :],          # controls: 1-D row (presumably nW entries)
            'y_tp1': self.output_welldata[idx, :],     # well data y_t+1: 1-D row
        }

        return data

In [None]:
"""
Data Module: Interface for accessing dataset class during training
"""

class ResDataModule(pl.LightningDataModule):
    """Lightning data module that builds the ResDataset splits and their DataLoaders.

    Args:
        args: run-configuration dict; reads the ``n_*_samples``, ``batch_size_*``,
            ``shuffled_*`` and ``num_workers`` keys.
        aux_data: must provide ``"data_files_folder"`` (root containing the
            training/validation/testing sub-folders) and ``"edge_indecies"``.
    """

    def __init__(self, args, aux_data):
        super().__init__()
        self.args = args

        self.data_files_folder = aux_data["data_files_folder"]
        self.edge_indecies = aux_data["edge_indecies"]

    def prepare_data(self):
        # Nothing to download or preprocess: the .pt files are expected to exist already.
        pass

    def _make_split(self, data_type, n_samples_key):
        """Build the ResDataset for one split sub-folder."""
        return ResDataset(data_type=data_type,
                          n_samples=self.args[n_samples_key],
                          data_files_folder=self.data_files_folder)

    def setup(self, stage: Optional[str] = None):
        """Create the datasets needed for ``stage`` ('fit', 'validate', 'test' or None)."""
        if stage == 'fit' or stage is None:
            self.train_data = self._make_split("training", "n_training_samples")

        # Validation data is needed both while fitting and for a stand-alone validate run.
        if stage in ('fit', 'validate') or stage is None:
            self.val_data = self._make_split("validation", "n_validation_samples")

        if stage == 'test':
            self.test_data = self._make_split("testing", "n_testing_samples")

    def train_dataloader(self):
        return DataLoader(self.train_data,
                          batch_size=self.args["batch_size_training"],
                          shuffle=self.args["shuffled_train"],
                          num_workers=self.args["num_workers"])

    def val_dataloader(self):
        # NOTE: shuffling a validation loader does not change epoch metrics; Lightning may warn about it.
        return DataLoader(self.val_data,
                          batch_size=self.args["batch_size_validation"],
                          shuffle=self.args["shuffled_validation"],
                          num_workers=self.args["num_workers"])

    def test_dataloader(self):
        # The whole test split is served as a single batch.
        return DataLoader(self.test_data,
                          batch_size=self.args["n_testing_samples"],
                          num_workers=self.args["num_workers"])

    def get_edge_indecies(self):
        return self.edge_indecies


In [None]:
"""
Relative Perm Class
"""

class RelPerm():
    """Power-law (Corey-style) oil/water relative-permeability model.

    kr_o(S_w) = kr_o_max * ((1 - S_w - s_o_r) / (1 - s_w_con - s_o_r)) ** a
    kr_w(S_w) = kr_w_max * ((S_w - s_w_crit) / (1 - s_w_crit - s_o_irr)) ** b

    Every parameter is expected to be a tensor, because ``send_to_device`` calls
    ``.to`` on each one. The viscosities are only stored here.
    """

    def __init__(self, a, b,
                 kr_w_max, s_w_crit, s_o_irr,
                 kr_o_max, s_w_con, s_o_r,
                 visc_o, visc_w):
        # Curve exponents (oil, water)
        self.a, self.b = a, b

        # Water end-points
        self.kr_w_max, self.s_w_crit, self.s_o_irr = kr_w_max, s_w_crit, s_o_irr

        # Oil end-points
        self.kr_o_max, self.s_w_con, self.s_o_r = kr_o_max, s_w_con, s_o_r

        # Fluid viscosities
        self.visc_o, self.visc_w = visc_o, visc_w

    def send_to_device(self, device):
        """Move every parameter tensor to ``device`` in place and return ``self``."""
        for field in ("a", "b", "kr_w_max", "s_w_crit", "s_o_irr",
                      "kr_o_max", "s_w_con", "s_o_r", "visc_o", "visc_w"):
            setattr(self, field, getattr(self, field).to(device))
        return self

    def func_kro(self, sw, a):
        """Oil curve evaluated at ``sw`` with exponent ``a`` (no clipping)."""
        movable_oil = 1.0 - sw - self.s_o_r
        oil_span = 1.0 - self.s_w_con - self.s_o_r
        return self.kr_o_max * torch.pow(movable_oil / oil_span, a)

    def func_krw(self, sw, b):
        """Water curve evaluated at ``sw`` with exponent ``b`` (no clipping)."""
        mobile_water = sw - self.s_w_crit
        water_span = 1.0 - self.s_w_crit - self.s_o_irr
        return self.kr_w_max * torch.pow(mobile_water / water_span, b)

    def k_ro(self, S_w):
        """Oil relative permeability with S_w clipped to [s_w_con, 1 - s_o_r]."""
        return self.func_kro(torch.clamp(S_w, self.s_w_con, 1.0 - self.s_o_r), self.a)

    def k_rw(self, S_w):
        """Water relative permeability with S_w clipped to [s_w_crit + 0.001, 1 - s_o_irr].

        The 0.001 offset keeps the power's base strictly positive (presumably to avoid
        pow(0, b) gradient problems).
        """
        return self.func_krw(torch.clamp(S_w, self.s_w_crit + 0.001, 1.0 - self.s_o_irr), self.b)

    def k_ro_tf(self, S_w):
        """Like ``k_ro`` but S_w and the exponent are cast to float16 after clipping."""
        clipped = torch.clamp(S_w, self.s_w_con, 1.0 - self.s_o_r)
        return self.func_kro(clipped.to(torch.float16), self.a.to(torch.float16))

    def k_rw_tf(self, S_w):
        """Like ``k_rw`` but S_w and the exponent are cast to float16 after clipping."""
        clipped = torch.clamp(S_w, self.s_w_crit + 0.001, 1.0 - self.s_o_irr)
        return self.func_krw(clipped.to(torch.float16), self.b.to(torch.float16))


In [None]:
"""
Part of the "Secret Sauce"
"""


def broadcast(src: torch.Tensor, other: torch.Tensor, dim: int):
    """Expand ``src`` to ``other``'s shape so it can serve as a scatter index.

    A 1-D ``src`` is first given leading singleton axes so that its single axis
    lines up with ``dim`` of ``other``. Trailing singleton axes are then appended
    until the ranks match, and the result is expanded (a view, no copy).
    """
    axis = other.dim() + dim if dim < 0 else dim

    if src.dim() == 1:
        # Prepend `axis` singleton dimensions (none when axis <= 0).
        src = src[(None,) * axis]

    # Append singleton dimensions until src has other's rank (none if already there).
    n_missing = other.dim() - src.dim()
    src = src[(Ellipsis,) + (None,) * n_missing]

    return src.expand(other.size())

def scatter_sum(src: torch.Tensor,
                index: torch.Tensor,
                dim: int = -1,
                out: Optional[torch.Tensor] = None,
                dim_size: Optional[int] = None) -> torch.Tensor:
    """Sum the entries of ``src`` that share an ``index`` value along ``dim``.

    The code appears to be adapted from ``torch_scatter.scatter_sum``.

    Args:
        src: values to scatter.
        index: target slot per entry; broadcast against ``src`` via ``broadcast``.
        dim: axis along which to scatter.
        out: optional tensor to accumulate into, modified in place.
        dim_size: length of the output along ``dim``. Defaults to ``index.max() + 1``,
            or 0 when ``index`` is empty. Ignored when ``out`` is given.

    Returns:
        The accumulated tensor, which is ``out`` itself when it was supplied.
    """
    index = broadcast(index, src, dim)

    if out is None:
        shape = list(src.size())
        if dim_size is None:
            dim_size = int(index.max()) + 1 if index.numel() > 0 else 0
        shape[dim] = dim_size
        out = torch.zeros(shape, dtype=src.dtype, device=src.device)

    return out.scatter_add_(dim, index, src)

def scatter_mean(src: torch.Tensor,
                 index: torch.Tensor,
                 dim: int = -1,
                 out: Optional[torch.Tensor] = None,
                 dim_size: Optional[int] = None) -> torch.Tensor:
    """Average the entries of ``src`` that share an ``index`` value along ``dim``.

    Slots that receive no entries are divided by 1 rather than 0. Integer tensors use
    floor division. When ``out`` is supplied, its existing contents take part in the
    sum and the result is written into it.
    """
    totals = scatter_sum(src, index, dim, out, dim_size)
    n_slots = totals.size(dim)

    # Axis of `index` along which hits are counted; it must be valid for index's own
    # rank, which can be lower than src's (e.g. a 1-D index).
    count_dim = src.dim() + dim if dim < 0 else dim
    count_dim = min(count_dim, index.dim() - 1)

    hits = scatter_sum(torch.ones(index.size(), dtype=src.dtype, device=src.device),
                       index, count_dim, None, n_slots)
    hits[hits < 1] = 1
    hits = broadcast(hits, totals, dim)

    if not totals.is_floating_point():
        return totals.div_(hits, rounding_mode='floor')
    return totals.true_divide_(hits)

In [None]:
"""
Encoder Module: Part of the "Secret Sauce"
"""

def batch_offset_enc(edge_index, B, N, M):
    """Replicate one graph's pooling edges for ``B`` graphs stacked along dim 0.

    Args:
        edge_index: (2, E) tensor for a single graph. Row 0 holds coarse-node ids and
            row 1 holds fine-node ids. It may be integer or (integral) floating point.
        B: number of graphs in the batch.
        N: fine nodes per graph (int, or an integral float such as ``size / B``).
        M: coarse nodes per graph.

    Returns:
        ``(ei0, ei1)``: int64 tensors of shape (B*E,). For graph ``b``, row 0 is offset
        by ``b*M`` and row 1 by ``b*N``.
    """
    # Do the offset arithmetic in int64. Float indices are only exact up to 2**24
    # (2048 in float16), so large batched graphs would otherwise be silently corrupted.
    edge_index = edge_index.long()
    b = torch.arange(B, device=edge_index.device)[:, None]   # (B, 1)
    ei0 = (edge_index[0] + b * int(M)).reshape(-1)           # (B*E,)
    ei1 = (edge_index[1] + b * int(N)).reshape(-1)           # (B*E,)
    return ei0, ei1

class FineToCoarseConv(torch.nn.Module):
    """Pooling graph convolution from fine nodes to coarse nodes.

    Fine-node features are projected by a bias-free linear layer, multiplied by a learnable
    per-edge weight, and averaged (``scatter_mean``) into their target coarse node.

    Args:
        in_channels / out_channels: feature sizes of the linear projection.
        num_pool_edges: number of pooling edges E per graph (ignored when ``custom_weight`` is given).
        num_coarse: coarse nodes per graph.
        custom_weight: optional (E,) tensor used as the initial edge weights instead of
            a Xavier-uniform initialisation.
    """

    def __init__(self, in_channels, out_channels, num_pool_edges, num_coarse, custom_weight=None):
        super().__init__()
        self.num_coarse = num_coarse
        self.linear = nn.Linear(in_channels, out_channels, bias=False)
        if custom_weight is None:
            self.edge_weight = nn.Parameter(torch.empty(num_pool_edges, 1))
        else:
            self.edge_weight = nn.Parameter(custom_weight.unsqueeze(1))

        self.init_weights(custom_weight)

    def forward(self, x_fine, edge_index, B):
        """Pool stacked fine-node features (B*N_fine, C) to (B*num_coarse, out_channels)."""
        x_fine = self.linear(x_fine)

        # Integer division keeps the node count an int for the index arithmetic.
        n_fine = x_fine.size(0) // B
        ei0, ei1 = batch_offset_enc(edge_index, B, n_fine, self.num_coarse)

        # The single-graph edge weights are tiled once per graph in the batch.
        weights_repeated = self.edge_weight.repeat(B, 1)
        messages = x_fine[ei1] * weights_repeated

        return scatter_mean(messages, ei0, dim=0, dim_size=B * self.num_coarse).contiguous()

    def init_weights(self, custom_weight):
        """Xavier-initialise the projection, and the edge weights unless custom ones were given."""
        nn.init.xavier_uniform_(self.linear.weight)
        if custom_weight is None:
            nn.init.xavier_uniform_(self.edge_weight)

class EncoderModule(torch.nn.Module):
    """Graph encoder: five fine-to-coarse pooling convolutions, then a linear map to the latent.

    Input ``x`` has shape (B, N, 2), where the 2 channels are fixed by ``conv1``. Output has
    shape (B, args["l_z"]). ``aux_data`` must provide "edge_indecies" (one edge tensor per
    level), "perm_edges", "num_pool_edges" and "num_coarse" (one entry per level).
    """

    def __init__(self, args, aux_data):
        super().__init__()
        self.args = args
        self.aux_data = aux_data

        self.edges = aux_data["edge_indecies"]
        self.nEncLinear = args["nEncLinear"]       # stored, not used in this module
        self.edges_custom = aux_data["perm_edges"]  # stored, not used in this module

        self.conv1 = FineToCoarseConv(  2,  32, num_pool_edges=aux_data["num_pool_edges"][0], num_coarse=aux_data["num_coarse"][0])
        self.conv2 = FineToCoarseConv( 32,  64, num_pool_edges=aux_data["num_pool_edges"][1], num_coarse=aux_data["num_coarse"][1])
        self.conv3 = FineToCoarseConv( 64, 128, num_pool_edges=aux_data["num_pool_edges"][2], num_coarse=aux_data["num_coarse"][2])
        self.conv4 = FineToCoarseConv(128, 128, num_pool_edges=aux_data["num_pool_edges"][3], num_coarse=aux_data["num_coarse"][3])
        self.conv5 = FineToCoarseConv(128, 128, num_pool_edges=aux_data["num_pool_edges"][4], num_coarse=aux_data["num_coarse"][4])

        self.ln1 = nn.LayerNorm(32)
        self.ln2 = nn.LayerNorm(64)
        self.ln3 = nn.LayerNorm(128)
        self.ln4 = nn.LayerNorm(128)
        self.ln5 = nn.LayerNorm(128)

        self.silu = torch.nn.SiLU()

        self.final_linear = nn.Linear(aux_data["num_coarse"][-1] * 128, self.args["l_z"])

        self.init_weights()

    def stack_graphs_enc(self, x: torch.Tensor) -> torch.Tensor:
        """Flatten a (B, N, F) batch into a (B*N, F) node matrix, graph after graph."""
        B, N, F = x.size()
        return x.reshape(B * N, F)

    def unstack_flat_enc(self, x: torch.Tensor, B: int, N: int, F: int) -> torch.Tensor:
        """Regroup the first B*N rows of a (rows, F) node matrix into one (N*F,) row per graph."""
        return x[: B * N].reshape(B, N * F)

    def forward(self, x):
        """Encode a (B, N, 2) batch of graph states into (B, l_z) latent vectors."""
        B = len(x)
        device = x.device
        h = self.stack_graphs_enc(x.contiguous())

        stages = ((self.conv1, self.ln1), (self.conv2, self.ln2), (self.conv3, self.ln3),
                  (self.conv4, self.ln4), (self.conv5, self.ln5))
        for level, (conv, norm) in enumerate(stages):
            # Edge indices are node ids: move them to the data's device but keep their
            # integer dtype. Casting them to x's float dtype loses exactness past 2**24
            # (2048 in float16).
            h = self.silu(norm(conv(h, self.edges[level].to(device), B)))

        h = self.unstack_flat_enc(h.contiguous(), B=B, N=self.aux_data["num_coarse"][-1], F=128)

        return self.final_linear(h)

    def init_weights(self):
        nn.init.xavier_uniform_(self.final_linear.weight)


In [None]:
"""
Decoder Module: Part of the "Secret Sauce"
"""

def batch_offset_dec(edge_index, B, N, M):
    """Replicate one graph's unpooling edges for ``B`` graphs stacked along dim 0.

    Args:
        edge_index: (2, E) tensor for a single graph. Row 0 holds source (coarse-node) ids
            and row 1 holds target (fine-node) ids. It may be integer or (integral) floating point.
        B: number of graphs in the batch.
        N: source (coarse) nodes per graph (int, or an integral float such as ``size / B``).
        M: target (fine) nodes per graph.

    Returns:
        ``(ei0, ei1)``: int64 tensors of shape (B*E,). For graph ``b``, row 0 is offset
        by ``b*N`` and row 1 by ``b*M``.
    """
    # Integer arithmetic: float indices are only exact up to 2**24 (2048 in float16).
    edge_index = edge_index.long()
    b = torch.arange(B, device=edge_index.device)[:, None]   # (B, 1)
    ei0 = (edge_index[0] + b * int(N)).reshape(-1)           # (B*E,)
    ei1 = (edge_index[1] + b * int(M)).reshape(-1)           # (B*E,)
    return ei0, ei1

class CoarseToFineConv(nn.Module):
    """Unpooling graph convolution from coarse nodes to fine nodes.

    Coarse-node features are projected by a bias-free linear layer, multiplied by a
    learnable per-edge weight, and averaged (``scatter_mean``) into their target fine node.

    Args:
        in_channels / out_channels: feature sizes of the linear projection.
        num_pool_edges: number of unpooling edges E per graph.
        num_fine: fine nodes per graph.
        non_relu_init: accepted for API compatibility; both settings currently use the
            same Xavier-uniform initialisation.
    """

    def __init__(self, in_channels, out_channels, num_pool_edges, num_fine, non_relu_init=False):
        super().__init__()
        self.num_fine = num_fine
        self.linear = nn.Linear(in_channels, out_channels, bias=False)
        self.edge_weight = nn.Parameter(torch.empty(num_pool_edges, 1))

        self.init_weights(non_relu_init)

    def forward(self, x_coarse, edge_index, B):
        """Unpool stacked coarse features (B*N_coarse, C) to (B*num_fine, out_channels)."""
        x_coarse = self.linear(x_coarse)

        # Integer division keeps the node count an int for the index arithmetic.
        n_coarse = x_coarse.size(0) // B
        ei0, ei1 = batch_offset_dec(edge_index, B, n_coarse, self.num_fine)

        # The single-graph edge weights are tiled once per graph in the batch.
        weights_repeated = self.edge_weight.repeat(B, 1)
        messages = x_coarse[ei0] * weights_repeated

        return scatter_mean(messages, ei1, dim=0, dim_size=self.num_fine * B).contiguous()

    def init_weights(self, non_relu_init):
        """Xavier-initialise the projection and the edge weights (same for either flag value)."""
        nn.init.xavier_uniform_(self.linear.weight)
        nn.init.xavier_uniform_(self.edge_weight)


class DecoderModule(torch.nn.Module):
    """Graph decoder: a linear map from the latent, then five coarse-to-fine unpooling convolutions.

    Input ``z`` has shape (B, args["l_z"]). Output has shape (B, num_fine[0], 2), squashed
    into the open interval (0, 1.15). ``aux_data`` must provide "edge_indecies",
    "num_pool_edges", "num_fine" and "num_coarse" (one entry per level).
    """

    def __init__(self, args, aux_data):
        super().__init__()
        self.args = args
        self.aux_data = aux_data

        self.edges = aux_data["edge_indecies"]
        self.nDecLinear = args["nDecLinear"]  # stored, not used in this module

        self.conv5 = CoarseToFineConv(128, 128, num_pool_edges=aux_data["num_pool_edges"][4],  num_fine=aux_data["num_fine"][4])
        self.conv4 = CoarseToFineConv(128, 128, num_pool_edges=aux_data["num_pool_edges"][3],  num_fine=aux_data["num_fine"][3])
        self.conv3 = CoarseToFineConv(128,  64, num_pool_edges=aux_data["num_pool_edges"][2],  num_fine=aux_data["num_fine"][2])
        self.conv2 = CoarseToFineConv( 64,  32, num_pool_edges=aux_data["num_pool_edges"][1],  num_fine=aux_data["num_fine"][1])
        self.conv1 = CoarseToFineConv( 32,   2, num_pool_edges=aux_data["num_pool_edges"][0],  num_fine=aux_data["num_fine"][0]
                                      , non_relu_init=True)

        self.ln5 = nn.LayerNorm(128)
        self.ln4 = nn.LayerNorm(128)
        self.ln3 = nn.LayerNorm(64)
        self.ln2 = nn.LayerNorm(32)
        self.ln1 = nn.LayerNorm(2)  # not applied in forward; kept so existing checkpoints still load

        self.silu = torch.nn.SiLU()

        self.initial_linear = nn.Linear(self.args["l_z"], aux_data["num_coarse"][-1] * 128)

        self.init_weights()

    def unstack_graphs_dec(self, x: torch.Tensor, B: int, N: int) -> torch.Tensor:
        """Regroup the first B*N rows of a (rows, F) node matrix into a (B, N, F) batch."""
        F = x.size(1)
        return x[: B * N].reshape(B, N, F)

    def stack_flat_dec(self, x: torch.Tensor, B: int, N: int, F: int) -> torch.Tensor:
        """Turn one flat (N*F,) row per graph into a stacked (B*N, F) node matrix."""
        return x.reshape(B * N, F)

    def forward(self, z):
        """Decode (B, l_z) latent vectors into (B, num_fine[0], 2) graph states."""
        z = self.silu(self.initial_linear(z))

        B = len(z)
        device = z.device
        h = self.stack_flat_dec(z.contiguous(), B=B, N=self.aux_data["num_coarse"][-1], F=128)

        # Edge indices are node ids: move them to the data's device but keep their integer
        # dtype. Under AMP `z` is float16 here, which is only exact for integers up to 2048.
        for level, conv, norm in ((4, self.conv5, self.ln5), (3, self.conv4, self.ln4),
                                  (2, self.conv3, self.ln3), (1, self.conv2, self.ln2)):
            h = self.silu(norm(conv(h, self.edges[level].to(device), B)))
        h = self.conv1(h, self.edges[0].to(device), B)

        h = self.unstack_graphs_dec(h.contiguous(), B=B, N=self.aux_data["num_fine"][0])

        # Steep sigmoid centred at 0.5, scaled to (0, 1.15). Presumably this allows slight
        # overshoot above 1 for normalised states -- TODO confirm intent.
        return torch.sigmoid(5.0 * (h - 0.5)) * 1.15

    def init_weights(self):
        nn.init.xavier_uniform_(self.initial_linear.weight)



In [None]:
"""
Transition Module: Core "E2CO" concept
"""

class TransitionModule(pl.LightningModule):
    """Latent-space transition  z_{t+1} = A z_t + B (dt * u).

    The matrices A (l_z x l_z) and B (l_z x nW) are produced per sample by an MLP
    (``nTransfBlock`` ReLU layers of width ``nTrans``) applied to [z_t, dt]. The L1 sums
    of the flat A/B outputs from the latest forward pass are kept in
    ``z_A_activations`` / ``z_B_activations`` for activity regularisation.
    """

    def __init__(self, args):
        super().__init__()

        self.nW = args["nW"]
        self.l_z = args["l_z"]
        self.nTrans = args["nTrans"]
        self.nTransfBlock = args["nTransfBlock"]

        # The first block sees [z, dt] (l_z + 1 features); later blocks map nTrans -> nTrans.
        self.linear_layers = nn.ModuleList(
            nn.Linear(self.nTrans if block_idx else self.l_z + 1, self.nTrans)
            for block_idx in range(self.nTransfBlock)
        )
        self.linear_z_A = nn.Linear(self.nTrans, self.l_z * self.l_z)
        self.linear_z_B = nn.Linear(self.nTrans, self.l_z * self.nW)

        # Filled in by forward()
        self.z_A_activations = None
        self.z_B_activations = None

        self.initialize_weights()

    def forward(self, z, delta_t, u):
        """Predict the next latent state from z (B, l_z), delta_t (B, 1) and controls u."""
        hidden = torch.cat((z, delta_t), dim=1).to(z)
        for layer in self.linear_layers:
            hidden = torch.relu(layer(hidden))

        flat_A = self.linear_z_A(hidden)
        flat_B = self.linear_z_B(hidden)

        self.z_A_activations = flat_A.abs().sum()
        self.z_B_activations = flat_B.abs().sum()

        mat_A = flat_A.view(-1, self.l_z, self.l_z).to(z)
        mat_B = flat_B.view(-1, self.l_z, self.nW).to(z)

        # Ensure the controls are (batch, nW) and in z's dtype/device.
        controls = u.view(-1, self.nW).to(z)
        scaled_controls = torch.mul(delta_t, controls).to(z)

        drift = torch.bmm(mat_A, z.unsqueeze(-1)).squeeze(-1).to(z)
        forcing = torch.bmm(mat_B, scaled_controls.unsqueeze(-1)).squeeze(-1).to(z)
        return (drift + forcing).to(z)

    def initialize_weights(self):
        """Xavier-uniform init for every weight matrix; biases keep PyTorch's default init."""
        for layer in self.linear_layers:
            nn.init.xavier_uniform_(layer.weight)
        for head in (self.linear_z_A, self.linear_z_B):
            nn.init.xavier_uniform_(head.weight)


In [None]:
"""
Transition Output Module: Core "E2CO" concept
"""

class TransitionOutputModule(pl.LightningModule):
    """Latent-to-observation map  y_{t+1} = C z_next + D (dt * u).

    The matrices C (nObs x l_z) and D (nObs x nW) are produced per sample by an MLP
    (``nTransfBlockWD`` ReLU layers of width ``nTransWD``) applied to [z_t, dt].
    ``z_next`` is whatever latent the caller passes as ``z_tp1_hat``. The L1 sums of the
    flat C/D outputs from the latest forward pass are kept in ``z_C_activations`` /
    ``z_D_activations``.
    """

    def __init__(self, args):
        super().__init__()

        self.nW = args["nW"]
        self.nObs = args["nObs"]
        self.l_z = args["l_z"]
        self.nTransWD = args["nTransWD"]
        self.nTransfBlockWD = args["nTransfBlockWD"]

        # The first block sees [z, dt] (l_z + 1 features); later blocks map nTransWD -> nTransWD.
        self.linear_layers = nn.ModuleList(
            nn.Linear(self.nTransWD if block_idx else self.l_z + 1, self.nTransWD)
            for block_idx in range(self.nTransfBlockWD)
        )
        self.linear_z_C = nn.Linear(self.nTransWD, self.nObs * self.l_z)
        self.linear_z_D = nn.Linear(self.nTransWD, self.nObs * self.nW)

        # Filled in by forward()
        self.z_C_activations = None
        self.z_D_activations = None

        self.initialize_weights()

    def forward(self, z, delta_t, u, z_tp1_hat):
        """Predict observations from z (B, l_z), delta_t (B, 1), controls u and a next latent."""
        hidden = torch.cat((z, delta_t), dim=1).to(z)
        for layer in self.linear_layers:
            hidden = torch.relu(layer(hidden))

        flat_C = self.linear_z_C(hidden)
        flat_D = self.linear_z_D(hidden)

        self.z_C_activations = flat_C.abs().sum()
        self.z_D_activations = flat_D.abs().sum()

        mat_C = flat_C.view(-1, self.nObs, self.l_z).to(z)
        mat_D = flat_D.view(-1, self.nObs, self.nW).to(z)

        # Ensure the controls are (batch, nW) and in z's dtype/device.
        controls = u.view(-1, self.nW).to(z)
        scaled_controls = torch.mul(delta_t, controls).to(z)

        state_term = torch.bmm(mat_C, z_tp1_hat.unsqueeze(-1)).squeeze(-1).to(z)
        control_term = torch.bmm(mat_D, scaled_controls.unsqueeze(-1)).squeeze(-1).to(z)
        return (state_term + control_term).to(z)

    def initialize_weights(self):
        """Xavier-uniform init for every weight matrix; biases keep PyTorch's default init."""
        for layer in self.linear_layers:
            nn.init.xavier_uniform_(layer.weight)
        for head in (self.linear_z_C, self.linear_z_D):
            nn.init.xavier_uniform_(head.weight)



In [None]:
"""
Custom flexible wrapper for dynamic loss type selection
"""

class CustomLosses(nn.Module):
    """Name-based dispatcher over the loss and regularisation functions used in training.

    Call the module with ``loss_type`` set to a key of ``self.loss_functions``:
    'mse', 'l1', 'euclid', 'latent', 'l1_reg_loss', 'l1_reg_loss_dual',
    'l2_norm_loss', 'l2_norm_loss_sa' or 'l1_l2_mix_loss'.

    Args:
        device: stored as ``self.device``; not used by the losses defined here.
        norms: normalisation data stored as ``self.norms``; not used by the losses defined here.
    """

    def __init__(self, device, norms):
        super(CustomLosses, self).__init__()
        self.device = device
        self.norms = norms

        # Loss-type name -> implementation
        self.loss_functions = {
            'mse': self.mse_loss,
            'l1': self.l1_loss,
            'euclid': self.euclid_loss,
            'latent': self.latent_loss,
            'l1_reg_loss': self.l1_reg_loss,
            'l1_reg_loss_dual': self.l1_reg_loss_dual,
            'l2_norm_loss': self.l2_norm_loss,
            'l2_norm_loss_sa': self.l2_norm_loss_sa,
            'l1_l2_mix_loss': self.l1_l2_mix_loss,
        }

    def mse_loss(self, x, y, reduction='sum', **kwargs):
        """Squared error between x and y, reduced with ``reduction`` (default: sum)."""
        return nn.functional.mse_loss(x, y, reduction=reduction)

    def l1_loss(self, x, y, reduction='sum', **kwargs):
        """Absolute error between x and y, reduced with ``reduction`` (default: sum)."""
        return nn.functional.l1_loss(x, y, reduction=reduction)

    def euclid_loss(self, x, y):
        """Euclidean (L2) norm of ``y - x`` taken over all elements.

        NOTE(review): both inputs are flattened to 1-D, so the inner sum covers the whole
        batch. The result is a single global norm, not a per-sample sum of norms (compare
        ``l2_norm_loss_sa``). Confirm which was intended.
        """
        diff_square = torch.pow(torch.flatten(y) - torch.flatten(x), 2)
        return torch.sum(torch.sqrt(torch.sum(diff_square, dim=-1)))

    def l2_norm_loss(self, x, y):
        """Frobenius/L2 norm of ``x - y`` over all elements."""
        return torch.norm(x - y, p=2)

    def l2_norm_loss_sa(self, x, y, sa_w):
        """Sum over samples of per-channel L2 norms, with channel 1 scaled by ``sa_w``.

        Assumes x and y are (batch, nodes, >=2); channel 0 and channel 1 are normed
        separately over the node axis. Presumably these are pressure and saturation, and
        ``sa_w`` is a self-adaptive weight -- TODO confirm.
        """
        norm_p = torch.sum(torch.sqrt(torch.sum(torch.square(x[:, :, 0] - y[:, :, 0]), dim=1)))
        norm_s = torch.sum(torch.sqrt(torch.sum(torch.square(sa_w * (x[:, :, 1] - y[:, :, 1])), dim=1)))
        return norm_p + norm_s

    def l1_l2_mix_loss(self, x, y):
        """L2 norm on the first last-axis channel plus summed L1 on the second.

        The last axis must have size 2.
        """
        xp, xs = torch.split(x, 1, dim=-1)
        yp, ys = torch.split(y, 1, dim=-1)
        return self.l2_norm_loss(xp, yp) + self.l1_loss(xs, ys)

    def latent_loss(self, x, **kwargs):
        """Sum of squares of x."""
        return torch.sum(x ** 2)

    def l1_reg_loss(self, qm, **kwargs):
        """Sum over the leading axis of each row's L1 norm (i.e. the total absolute value).

        ``reshape`` (rather than ``view``) also accepts non-contiguous inputs.
        """
        rows = qm.reshape(qm.size(0), -1)
        return torch.sum(torch.norm(rows, p=1, dim=-1))

    def l1_reg_loss_dual(self, x, y, **kwargs):
        """Mean of ``l1_reg_loss`` over the two inputs."""
        return (self.l1_reg_loss(x) + self.l1_reg_loss(y)) / 2.0

    def forward(self, x1, y1=None, x2=None, y2=None, loss_type='mse', **kwargs):
        """Dispatch to the loss named ``loss_type``.

        Which positional arguments are passed depends on which are not None:
        ``(x1, y1, x2, y2)`` if y1, x2 and y2 are all given; otherwise ``(x1, y1, x2)``
        if y1 and x2 are given; otherwise ``(x1, y1)`` if y1 is given; otherwise ``(x1,)``.
        Extra ``kwargs`` are passed through unchanged.

        Args:
            x1: first argument (typically predictions, or the tensor to regularise).
            y1: second argument (typically targets).
            x2, y2: optional extra arguments (e.g. the weight ``sa_w`` for 'l2_norm_loss_sa').
            loss_type: key of ``self.loss_functions``.

        Returns:
            torch.Tensor: the computed loss.

        Raises:
            ValueError: if ``loss_type`` is not registered.
        """
        if loss_type not in self.loss_functions:
            raise ValueError(f"Unsupported loss type: {loss_type}")

        loss_fn = self.loss_functions[loss_type]
        if y1 is not None and x2 is not None and y2 is not None:
            return loss_fn(x1, y1, x2, y2, **kwargs)
        elif y1 is not None and x2 is not None:
            return loss_fn(x1, y1, x2, **kwargs)
        elif y1 is not None:
            return loss_fn(x1, y1, **kwargs)
        else:
            return loss_fn(x1, **kwargs)

In [None]:
"""
Main Lightning class: Let's put it all together!
"""

class LightningE2CO(pl.LightningModule):
    """LightningModule that trains and rolls out the E2CO surrogate model.

    Sub-networks:
        Encoder:          ``z = Encoder(x=x)``
        Transition:       ``z_tp1_hat = Transition(z=z_t, delta_t=delta_t, u=u)``
        Decoder:          ``x_hat = Decoder(z=z)``
        TransitionOutput: ``y_tp1_hat = TransitionOutput(z=..., delta_t=..., u=..., z_tp1_hat=...)``

    Args:
        args (dict): Run configuration. Keys read in this class: ``l_z``,
            ``latent_loss_scale``, ``weight_reg_ABCD``, ``weight_trans_reg``,
            ``weight_trans_reg1``, ``adam_learning_rate``, ``weight_decay``,
            ``scheduler_step_size``, ``scheduler_gamma`` and ``results_path``.
        aux_data (dict): Auxiliary model data passed to Encoder/Decoder. Its
            ``'norms'`` entry is handed to ``CustomLosses``.
    """

    def __init__(self, args, aux_data):
        super(LightningE2CO, self).__init__()
        # Base Parameters
        self.args = args
        self.aux_data = aux_data

        # Defining Networks
        self.Encoder = EncoderModule(self.args, self.aux_data)
        self.Decoder = DecoderModule(self.args, self.aux_data)
        self.Transition = TransitionModule(self.args)
        self.TransitionOutput = TransitionOutputModule(self.args)

        # Defining Loss Class
        # NOTE(review): the module has not been moved to its accelerator yet, so
        # self.device is presumably the CPU here. Confirm CustomLosses does not
        # keep device-bound tensors derived from it.
        self.loss_module = CustomLosses(device=self.device, norms=self.aux_data['norms'])

        # Per-batch roll-out results collected by test_step (reset every test epoch)
        self.test_outputs = []

        # Lightning's AUTOMATIC optimization is used (True = automatic, not manual)
        self.automatic_optimization = True
        # Placeholder for self-adaptive loss weights; handed to calculate_losses,
        # which does not use it yet.
        self.sa_w = None

    def forward(self, batch):
        """One-step E2CO pass over a batch of transitions.

        Args:
            batch (dict): Five tensors, unpacked in insertion order as
                ``x_t, x_tp1, delta_t, u, y_tp1``.

        Returns:
            tuple: ``(x_t, x_t_hat, x_tp1, x_tp1_hat, z_t, z_tp1, z_tp1_hat,
            y_tp1, y_tp1_hat, batch_size)`` where
            ``x_t_hat = Dec(Enc(x_t))`` reconstructs ``x_t`` and
            ``x_tp1_hat = Dec(Trans(Enc(x_t)))`` predicts ``x_tp1``. The latter
            is the same one-step prediction that test_step rolls out.
        """
        # Unpack
        x_t, x_tp1, delta_t, u, y_tp1 = batch.values()
        batch_size = torch.tensor(len(x_t))

        # Encode the current state and advance it in latent space
        z_t = self.Encoder(x=x_t)
        z_tp1_hat = self.Transition(z=z_t, delta_t=delta_t, u=u)

        # Prediction: decode the predicted latent -> estimate of x_tp1
        x_tp1_hat = self.Decoder(z=z_tp1_hat)

        # Encode the true next state (target of the transition loss)
        z_tp1 = self.Encoder(x=x_tp1)

        # Reconstruction: decode the current latent -> estimate of x_t
        x_t_hat = self.Decoder(z=z_t)

        # Well outputs
        # NOTE(review): training passes the *encoded* z_tp1 as z_tp1_hat, while
        # test_step passes the Transition prediction; confirm this teacher
        # forcing is intended.
        y_tp1_hat = self.TransitionOutput(z=z_t, delta_t=delta_t, u=u, z_tp1_hat=z_tp1)

        return (x_t, x_t_hat, x_tp1, x_tp1_hat, z_t, z_tp1, z_tp1_hat, y_tp1, y_tp1_hat, batch_size)

    def _shared_step(self, batch, batch_idx, stage):
        """Run forward + losses and log every loss as ``<name>_<stage>``; return the total."""
        outputs = self.forward(batch)
        losses = self.calculate_losses(outputs, batch_idx, self.sa_w)
        for loss_name, loss_value in losses.items():
            self.log(f'{loss_name}_{stage}', loss_value, prog_bar=True)
        return losses['Loss_Total']

    def training_step(self, batch, batch_idx):
        """Training step; logs ``<loss>_Train`` and returns the total loss."""
        return self._shared_step(batch, batch_idx, 'Train')

    def validation_step(self, batch, batch_idx):
        """Validation step; logs ``<loss>_Validation`` and returns the total loss."""
        return self._shared_step(batch, batch_idx, 'Validation')

    def calculate_losses(self, batch, batch_idx, sa_weights):
        """Compute every loss term for one ``forward`` output.

        Args:
            batch (tuple): The 10-tuple returned by ``forward``.
            batch_idx (int): Batch index (unused).
            sa_weights: Self-adaptive weights (currently unused).

        Returns:
            dict[str, torch.Tensor]: Each term divided by the batch size. The
            optimised objective is ``'Loss_Total'``.
        """
        x_t, x_t_hat, x_tp1, x_tp1_hat, z_t, z_tp1, z_tp1_hat, y_tp1, y_tp1_hat, batch_size = batch

        # The Main Losses
        loss_reconstruction = self.loss_module(x_t_hat, x_t, loss_type='l2_norm_loss')
        loss_prediction = self.loss_module(x_tp1_hat, x_tp1, loss_type='l2_norm_loss')
        loss_transition = self.loss_module(z_tp1, z_tp1_hat, loss_type='l2_norm_loss')

        # Well Data Loss (unscaled; args['well_data_loss_scale'] is not applied)
        loss_well_data = self.loss_module(y_tp1_hat, y_tp1, loss_type='l2_norm_loss')

        # Latent Space minimization loss
        loss_latent = self.args['latent_loss_scale'] * self.loss_module(z_t, loss_type='latent')

        # Combined Latent Space (Activity Regularization) minimization loss (ABCD).
        # Reads activation attributes set on the sub-networks during forward.
        if self.args['weight_reg_ABCD'] > 0:
            loss_ABCD = self.Transition.z_A_activations + \
                        self.Transition.z_B_activations + \
                        self.TransitionOutput.z_C_activations + \
                        self.TransitionOutput.z_D_activations
            loss_ABCD = loss_ABCD * self.args['weight_reg_ABCD']
        else:
            loss_ABCD = 0

        # Transition Regularization Loss
        if self.args['weight_trans_reg'] > 0:
            loss_trans_reg = self.loss_module(z_t, loss_type='l1_reg_loss') * self.args['weight_trans_reg']
        else:
            loss_trans_reg = 0

        # Transition Regularization Loss 1
        if self.args['weight_trans_reg1'] > 0:
            loss_trans_reg1 = self.loss_module(z_tp1_hat, z_tp1, loss_type='l1_reg_loss_dual') * self.args['weight_trans_reg1']
        else:
            loss_trans_reg1 = 0

        # Computing Total Loss
        loss_total = loss_reconstruction + loss_prediction + loss_transition + \
            loss_latent + loss_ABCD + loss_trans_reg + loss_trans_reg1 + loss_well_data

        # Normalization by batch size and packaging
        losses = {
            'Reconstruction': loss_reconstruction / batch_size,
            'Prediction': loss_prediction / batch_size,
            'Transition': loss_transition / batch_size,
            'Well': loss_well_data / batch_size,
            'Latent': loss_latent / batch_size,
            'ABCD': loss_ABCD / batch_size,
            'TransReg': loss_trans_reg / batch_size,
            'TransReg1': loss_trans_reg1 / batch_size,
            'Loss_Total': loss_total / batch_size,
        }

        return losses

    def test_step(self, batch, batch_idx):
        """Autoregressive roll-out of every trajectory in a test batch.

        The flat batch is viewed as ``(samples, 20, ...)`` trajectories of 20
        consecutive transitions. This assumes the test loader yields whole,
        ordered trajectories; TODO confirm. Each trajectory is encoded only once,
        from its first state. It is then advanced purely in latent space, and
        every predicted latent is decoded. Results are appended to
        ``self.test_outputs`` and merged in ``on_test_epoch_end``.
        """
        x_t, x_tp1, delta_t, u, y_tp1 = batch.values()

        times_steps = 20  # transitions per trajectory
        features = 2
        samples = x_t.shape[0] // times_steps
        nodes = int(x_t.shape[1])

        # Derive sizes from the data instead of hard-coding 18553 nodes / 12 controls / 16 outputs
        x_t = x_t.view(samples, times_steps, nodes, features)
        x_tp1 = x_tp1.view(samples, times_steps, nodes, features)
        u = u.view(samples, times_steps, -1)
        y_tp1 = y_tp1.view(samples, times_steps, -1)
        delta_t = delta_t.view(samples, times_steps, 1)

        device = self.device
        pred_x_tp1_hat = torch.empty(samples, times_steps + 1, nodes, features, device=device)
        pred_z_t_hat = torch.empty(samples, times_steps + 1, self.args["l_z"], device=device)
        pred_y_tp1_hat = torch.empty(samples, times_steps, y_tp1.shape[-1], device=device)
        orig_y_tp1_hat = y_tp1.to(device)
        # Ground-truth trajectory: the initial state followed by all next states
        orig_x_tp1_hat = torch.cat([x_t[:, 0:1, :, :], x_tp1], dim=1).to(device)

        for sample in range(samples):
            initial_state = x_t[sample, 0:1, :, :]
            z_t_hat = self.Encoder(x=initial_state)
            pred_z_t_hat[sample, 0, :] = z_t_hat
            pred_x_tp1_hat[sample, 0, :, :] = initial_state

            for time_step in range(1, times_steps + 1):
                u_in = u[sample, time_step - 1:time_step, :]
                delta_t_in = delta_t[sample, time_step - 1:time_step, :]
                z_prev = pred_z_t_hat[sample, time_step - 1:time_step, :]

                z_tp1_hat = self.Transition(z=z_prev, delta_t=delta_t_in, u=u_in)
                pred_z_t_hat[sample, time_step, :] = z_tp1_hat

                y_tp1_hat = self.TransitionOutput(z=z_prev, delta_t=delta_t_in, u=u_in, z_tp1_hat=z_tp1_hat)
                pred_y_tp1_hat[sample, time_step - 1, :] = y_tp1_hat

                x_tp1_hat = self.Decoder(z=z_tp1_hat)
                pred_x_tp1_hat[sample, time_step, :, :] = x_tp1_hat

        results = (pred_x_tp1_hat, orig_x_tp1_hat, pred_y_tp1_hat, orig_y_tp1_hat, pred_z_t_hat)

        # Accumulate rather than overwrite, so multi-batch test sets are kept in full
        self.test_outputs.append(results)

    def configure_optimizers(self):
        """Adam over all four sub-networks with a StepLR schedule."""
        modules = (self.Encoder, self.Decoder, self.Transition, self.TransitionOutput)
        # TransitionOutput must be included, otherwise the well-data head is never
        # updated. De-duplicate in case sub-networks share parameters.
        params = list({id(p): p for module in modules for p in module.parameters()}.values())
        optimizer = torch.optim.Adam(params, lr=self.args['adam_learning_rate'], weight_decay=self.args['weight_decay'])
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=self.args['scheduler_step_size'], gamma=self.args['scheduler_gamma'])
        return [optimizer], [scheduler]

    def on_test_epoch_start(self):
        """Reset the per-batch test-output buffer before a new test epoch."""
        self.test_outputs = []

    def on_test_epoch_end(self):
        """Merge per-batch roll-outs along the sample axis and save them.

        Saves the 5-tuple ``(pred_x, orig_x, pred_y, orig_y, pred_z)`` to
        ``<results_path>/test_outputs.pt``. With a single test batch this is
        exactly the tuple test_step produced.
        """
        if self.test_outputs:
            self.test_outputs = tuple(torch.cat(parts, dim=0) for parts in zip(*self.test_outputs))
        torch.save(self.test_outputs, os.path.join(self.args['results_path'], 'test_outputs.pt'))
        return self.test_outputs

    def on_train_batch_end(self, batch, batch_idx, dataloader_idx=None):
        """Log the current learning rate of the first optimizer."""
        lr = self.trainer.optimizers[0].param_groups[0]['lr']
        self.log("lr", lr, prog_bar=True)


<img src="graphics/E2CO_infer.png" width="800">

In [None]:
"""
Run Training
"""

# Filter Future Warnings
warnings.filterwarnings("ignore", category=FutureWarning)

# SEED
pl.seed_everything(args['SEED'])

# Initial Settings
model_name = 'miniegg_model'
current_directory = os.getcwd()
data_files_folder = os.path.join(current_directory, model_name, 'data_files', 'tensor_data')

# Loading Auxillary Data
# Coarsening Edge Indecies
edge_indecies = []
num_pool_edges = []
num_coarse = []
num_fine = []
for i in [1, 2, 3, 4, 5]:
    vals = load(os.path.join(data_files_folder, 'common', f'edge{i}.pt'))
    edge_indecies.append(vals)
    num_pool_edges.append(vals.size(dim=1))
    num_coarse.append(unique(vals[0, :]).size(dim=0))
    num_fine.append(unique(vals[1, :]).size(dim=0))
    
# Original Edge Index
original_edge = load(os.path.join(data_files_folder, 'testing', 'edge_index.pt')).permute(1, 0)

# Node Coordinates
node_coords = load(os.path.join(data_files_folder, 'common', 'node_coords.pt'))
# node_coords = None
    
# Permability Field Data
perm = load(os.path.join(data_files_folder, 'common', 'permi.pt'))
perm_normalized = (perm - perm.min()) / (perm.max() - perm.min())
perm_activations = perm_normalized[edge_indecies[0][1]]
perm_edges = perm[original_edge[0][1]]

# Relative Permeability Data
krel_data = load(os.path.join(data_files_folder, 'common', 'krel_data.pt'))
krel = RelPerm(
    a=krel_data['a'],
    b=krel_data['b'],
    kr_w_max=krel_data['kr_w_max'],
    s_w_crit=krel_data['s_w_crit'],
    s_o_irr=krel_data['s_o_irr'],
    kr_o_max=krel_data['kr_o_max'],
    s_w_con=krel_data['s_w_con'],
    s_o_r=krel_data['s_o_r'],
    visc_o=krel_data['visc_o'],
    visc_w=krel_data['visc_w'],
)
# krel = None

# Normalizations
norms = load(os.path.join(data_files_folder, 'common', 'norms.pt'))

aux_DATALOADER_data = {
    "edge_indecies": edge_indecies,
    "data_files_folder": data_files_folder,
}

aux_MODEL_data = {
    "perm": perm,
    "perm_activations": perm_activations,
    "krel": krel,
    "norms": norms,
    "node_coords": node_coords,
    "original_edge": original_edge,
    "perm_edges": perm_edges,
    "num_pool_edges" : num_pool_edges,
    "num_coarse" : num_coarse,
    "num_fine" : num_fine,
    "edge_indecies": edge_indecies,
}

# Data Loader Init
res_training_data = ResDataModule(args, aux_DATALOADER_data)

# Model Init
# ckpt_path=os.path.join(current_directory, model_name, "checkpoints", "EGG_LAST", "s30_lr1e3", "0", "epoch=498.ckpt")
# model = LightningE2CO.load_from_checkpoint(ckpt_path, args=args, aux_data=aux_MODEL_data)
model = LightningE2CO(args=args, aux_data=aux_MODEL_data)  # Initialize model

# Setting up checkpoint callback
checkpoint_callback = ModelCheckpoint(
    dirpath=os.path.join(current_directory, model_name, 'checkpoints', args['model_type'], args['trail_settings'], str(args['version_num'])),
    filename='{epoch}',  # Filename will include the epoch
    save_top_k=args["n_checkpoints2save"],
    monitor="Loss_Total_Validation",
    mode="min",
)

# Setting up TensorBoard logger
logger_dir = os.path.join(current_directory, model_name, f"log_{args['model_type']}")
logger = TensorBoardLogger(logger_dir, name=f"{args['model_type']}_{args['trail_settings']}_v{args['version_num']}")

trainer = Trainer(
    max_epochs=args["n_epochs"],
    accelerator="gpu", devices="auto",
    callbacks=checkpoint_callback,
    logger=logger,
    deterministic=args["deterministic"],
    enable_progress_bar=args["enable_progress_bar"],
    log_every_n_steps=args["log_every_n_steps"],
)

trainer.fit(model, datamodule=res_training_data)  # Train the model

In [None]:
"""
Run Tensorboard Logger for monitoring
"""

!tensorboard --logdir=miniegg_model/log_EGG