# Flow Matching for Darcy Flow - CORRECTED IMPLEMENTATION

## Key Fixes Applied:

### 1. **Correct Flow Matching Formulation**
- **OLD (WRONG)**: `x_t = (1-t)*a + t*u` - interpolating from permeability to solution ❌
- **NEW (CORRECT)**: `x_t = (1-t)*noise + t*u` - interpolating from noise to solution ✅
- **Conditioning**: Permeability `a(x)` conditions the entire process via FiLM

### 2. **Proper Velocity Target**
- **OLD**: `v_target = u - a` (solution minus permeability - makes no physical sense!) ❌
- **NEW**: `v_target = u - noise` (denoising direction) ✅

### 3. **Correct ODE Integration**
- **OLD**: Start from permeability `x_0 = a` ❌
- **NEW**: Start from Gaussian noise `x_0 ~ N(0,I)` ✅

### 4. **Enhanced Architecture**
- Added **Spectral Convolutions** from FNO for better physics learning
- **FiLM conditioning** properly applied at each resolution level
- Time and condition embeddings guide the denoising process

### 5. **Why This Works**
The model learns: **"Given permeability a(x), generate the corresponding solution u(x)"**
- Training: Learn velocity field that denoises `noise → u` conditioned on `a`
- Inference: Integrate ODE from `noise → predicted_u` given `a`

This is the standard conditional generation framework used in diffusion models, adapted for PDE solving.

In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parameter import Parameter

import matplotlib.pyplot as plt

import operator
from functools import reduce
from functools import partial

from timeit import default_timer
from utilities3 import *

from Adam import Adam

torch.manual_seed(0)
np.random.seed(0)

TRAIN_PATH = 'darcy/piececonst_r421_N1024_smooth1.mat'
TEST_PATH = 'darcy/piececonst_r421_N1024_smooth2.mat'

ntrain = 1000
ntest = 100

batch_size = 20
learning_rate = 0.001

epochs = 200
step_size = 100
gamma = 0.5

modes = 12
width = 32

r = 5
h = int(((421 - 1)/r) + 1)
s = h

reader = MatReader(TRAIN_PATH)
x_train = reader.read_field('coeff')[:ntrain,::r,::r][:,:s,:s]
y_train = reader.read_field('sol')[:ntrain,::r,::r][:,:s,:s]

reader.load_file(TEST_PATH)
x_test = reader.read_field('coeff')[:ntest,::r,::r][:,:s,:s]
y_test = reader.read_field('sol')[:ntest,::r,::r][:,:s,:s]

x_normalizer = UnitGaussianNormalizer(x_train)
x_train = x_normalizer.encode(x_train)
x_test = x_normalizer.encode(x_test)

y_normalizer = UnitGaussianNormalizer(y_train)
y_train = y_normalizer.encode(y_train)

x_train = x_train.reshape(ntrain,s,s,1)
x_test = x_test.reshape(ntest,s,s,1)
y_train = y_train.reshape(ntrain,s,s,1)
y_test = y_test.reshape(ntest,s,s,1)

train_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(x_train, y_train), batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(x_test, y_test), batch_size=batch_size, shuffle=False)


In [13]:
epochs = 50

In [None]:
from torchdiffeq import odeint

# ============================================================================
# Spectral Convolution Layer (from FNO)
# ============================================================================
class SpectralConv2d(nn.Module):
    """2D Fourier layer - learns in spectral domain"""
    def __init__(self, in_channels, out_channels, modes1, modes2):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.modes1 = modes1
        self.modes2 = modes2

        self.scale = (1 / (in_channels * out_channels))
        self.weights1 = nn.Parameter(self.scale * torch.rand(in_channels, out_channels, self.modes1, self.modes2, dtype=torch.cfloat))
        self.weights2 = nn.Parameter(self.scale * torch.rand(in_channels, out_channels, self.modes1, self.modes2, dtype=torch.cfloat))

    def compl_mul2d(self, input, weights):
        # Complex multiplication
        return torch.einsum("bixy,ioxy->boxy", input, weights)

    def forward(self, x):
        B, C, H, W = x.shape
        
        # Compute FFT
        x_ft = torch.fft.rfft2(x)
        
        # Multiply relevant Fourier modes
        out_ft = torch.zeros(B, self.out_channels, H, W//2 + 1, dtype=torch.cfloat, device=x.device)
        out_ft[:, :, :self.modes1, :self.modes2] = \
            self.compl_mul2d(x_ft[:, :, :self.modes1, :self.modes2], self.weights1)
        out_ft[:, :, -self.modes1:, :self.modes2] = \
            self.compl_mul2d(x_ft[:, :, -self.modes1:, :self.modes2], self.weights2)
        
        # Return to physical space
        x = torch.fft.irfft2(out_ft, s=(H, W))
        return x


# ============================================================================
# Time Embedding Module
# ============================================================================
class TimeEmbedding(nn.Module):
    """Sinusoidal time embedding followed by MLP"""
    def __init__(self, dim=128, mlp_dim=256):
        super().__init__()
        self.dim = dim
        self.mlp = nn.Sequential(
            nn.Linear(dim, mlp_dim),
            nn.SiLU(),
            nn.Linear(mlp_dim, mlp_dim)
        )
    
    def forward(self, t):
        # t: (B,) or (B, 1)
        if t.dim() == 1:
            t = t.unsqueeze(-1)
        
        # Sinusoidal embeddings
        half_dim = self.dim // 2
        emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, device=t.device) * -emb)
        emb = t * emb.unsqueeze(0)
        emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)
        
        return self.mlp(emb)  # (B, mlp_dim)


# ============================================================================
# FiLM MLP: Generates scale and shift parameters
# ============================================================================
class FiLMMLP(nn.Module):
    """Takes concatenated [time_emb, condition_emb] and outputs FiLM parameters"""
    def __init__(self, time_dim=256, cond_dim=256, hidden_dim=256, num_resolutions=4, channels_per_resolution=[64, 128, 256, 512]):
        super().__init__()
        self.num_resolutions = num_resolutions
        self.channels_per_resolution = channels_per_resolution
        
        input_dim = time_dim + cond_dim
        self.mlp = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.SiLU(),
            nn.Linear(hidden_dim, hidden_dim)
        )
        
        # Separate heads for each resolution level
        self.film_heads = nn.ModuleList([
            nn.Linear(hidden_dim, 2 * channels)  # 2x for scale and shift
            for channels in channels_per_resolution
        ])
    
    def forward(self, time_emb, cond_emb):
        # time_emb: (B, time_dim), cond_emb: (B, cond_dim)
        x = torch.cat([time_emb, cond_emb], dim=-1)
        x = self.mlp(x)
        
        # Generate FiLM parameters for each resolution
        film_params = []
        for head in self.film_heads:
            params = head(x)  # (B, 2*C)
            film_params.append(params)
        
        return film_params


# ============================================================================
# Conditional Encoder: Encodes permeability field a(x)
# ============================================================================
class ConditionEncoder(nn.Module):
    """Multi-scale CNN encoder for conditioning permeability field"""
    def __init__(self, in_channels=1, base_channels=64):
        super().__init__()
        
        # Downsample path with skip connections
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels, base_channels, 3, padding=1),
            nn.GroupNorm(8, base_channels),
            nn.SiLU()
        )
        self.down1 = nn.Conv2d(base_channels, base_channels, 3, stride=2, padding=1)
        
        self.conv2 = nn.Sequential(
            nn.Conv2d(base_channels, base_channels*2, 3, padding=1),
            nn.GroupNorm(8, base_channels*2),
            nn.SiLU()
        )
        self.down2 = nn.Conv2d(base_channels*2, base_channels*2, 3, stride=2, padding=1)
        
        self.conv3 = nn.Sequential(
            nn.Conv2d(base_channels*2, base_channels*4, 3, padding=1),
            nn.GroupNorm(8, base_channels*4),
            nn.SiLU()
        )
        self.down3 = nn.Conv2d(base_channels*4, base_channels*4, 3, stride=2, padding=1)
        
        self.conv4 = nn.Sequential(
            nn.Conv2d(base_channels*4, base_channels*8, 3, padding=1),
            nn.GroupNorm(8, base_channels*8),
            nn.SiLU()
        )
        
        # Global pooling for global latent
        self.global_pool = nn.AdaptiveAvgPool2d(1)
        self.global_fc = nn.Linear(base_channels*8, 256)
    
    def forward(self, a):
        # a: (B, 1, H, W)
        skips = []
        
        x1 = self.conv1(a)
        skips.append(x1)
        x1 = self.down1(x1)
        
        x2 = self.conv2(x1)
        skips.append(x2)
        x2 = self.down2(x2)
        
        x3 = self.conv3(x2)
        skips.append(x3)
        x3 = self.down3(x3)
        
        x4 = self.conv4(x3)
        skips.append(x4)
        
        # Global latent
        global_feat = self.global_pool(x4).squeeze(-1).squeeze(-1)
        global_latent = self.global_fc(global_feat)
        
        return skips, global_latent


# ============================================================================
# FiLM Residual Block with Spectral Convolution
# ============================================================================
class FiLMResBlock(nn.Module):
    """Residual block with FiLM modulation and optional spectral convolution"""
    def __init__(self, channels, modes=12, use_spectral=True):
        super().__init__()
        self.use_spectral = use_spectral
        
        if use_spectral:
            self.spectral_conv = SpectralConv2d(channels, channels, modes, modes)
        
        self.conv1 = nn.Conv2d(channels, channels, 3, padding=1)
        self.norm1 = nn.GroupNorm(8, channels)
        self.conv2 = nn.Conv2d(channels, channels, 3, padding=1)
        self.norm2 = nn.GroupNorm(8, channels)
        self.act = nn.SiLU()
    
    def forward(self, x, gamma, beta):
        # x: (B, C, H, W)
        # gamma, beta: (B, C)
        residual = x
        
        # First conv
        if self.use_spectral:
            x = self.spectral_conv(x) + self.conv1(x)
        else:
            x = self.conv1(x)
        
        x = self.norm1(x)
        x = self.act(x)
        
        # Apply FiLM modulation
        gamma = gamma.unsqueeze(-1).unsqueeze(-1)
        beta = beta.unsqueeze(-1).unsqueeze(-1)
        x = gamma * x + beta
        
        # Second conv
        x = self.conv2(x)
        x = self.norm2(x)
        x = x + residual
        x = self.act(x)
        
        return x


# ============================================================================
# UNet with FiLM Conditioning and Spectral Convolutions
# ============================================================================
class FlowMatchingUNet(nn.Module):
    """UNet for flow matching with FiLM conditioning and FNO-style spectral layers"""
    def __init__(self, in_channels=1, out_channels=1, base_channels=64, modes=12):
        super().__init__()
        
        self.base_channels = base_channels
        self.modes = modes
        channels = [base_channels, base_channels*2, base_channels*4, base_channels*8]
        
        # Time embedding
        self.time_embed = TimeEmbedding(dim=128, mlp_dim=256)
        
        # Condition encoder
        self.cond_encoder = ConditionEncoder(in_channels=in_channels, base_channels=base_channels)
        
        # FiLM MLP
        self.film_mlp = FiLMMLP(
            time_dim=256, 
            cond_dim=256, 
            hidden_dim=256,
            num_resolutions=4,
            channels_per_resolution=channels
        )
        
        # Input projection
        self.input_conv = nn.Conv2d(in_channels, base_channels, 3, padding=1)
        
        # Downsample path with spectral convolutions
        self.down_blocks = nn.ModuleList([
            FiLMResBlock(channels[0], modes=modes, use_spectral=True),
            FiLMResBlock(channels[1], modes=modes//2, use_spectral=True),
            FiLMResBlock(channels[2], modes=modes//4, use_spectral=True),
            FiLMResBlock(channels[3], modes=modes//8, use_spectral=False)
        ])
        
        self.downsample = nn.ModuleList([
            nn.Conv2d(channels[0], channels[1], 3, stride=2, padding=1),
            nn.Conv2d(channels[1], channels[2], 3, stride=2, padding=1),
            nn.Conv2d(channels[2], channels[3], 3, stride=2, padding=1)
        ])
        
        # Middle block
        self.mid_block = FiLMResBlock(channels[3], modes=modes//8, use_spectral=False)
        
        # Upsample path
        self.upsample_conv = nn.ModuleList([
            nn.Conv2d(channels[3], channels[2], 3, padding=1),
            nn.Conv2d(channels[2], channels[1], 3, padding=1),
            nn.Conv2d(channels[1], channels[0], 3, padding=1)
        ])
        
        # Projection layers to handle concatenated skip connections
        self.skip_proj = nn.ModuleList([
            nn.Conv2d(channels[2]*2, channels[2], 1),
            nn.Conv2d(channels[1]*2, channels[1], 1),
            nn.Conv2d(channels[0]*2, channels[0], 1)
        ])
        
        self.up_blocks = nn.ModuleList([
            FiLMResBlock(channels[2], modes=modes//4, use_spectral=True),
            FiLMResBlock(channels[1], modes=modes//2, use_spectral=True),
            FiLMResBlock(channels[0], modes=modes, use_spectral=True)
        ])
        
        # Output projection
        self.output_conv = nn.Sequential(
            nn.Conv2d(channels[0], channels[0], 3, padding=1),
            nn.SiLU(),
            nn.Conv2d(channels[0], out_channels, 1)
        )
    
    def forward(self, x_t, t, a):
        # x_t: (B, 1, H, W) - current state
        # t: (B,) - time
        # a: (B, 1, H, W) - conditioning permeability
        
        # Get time embedding
        time_emb = self.time_embed(t)  # (B, 256)
        
        # Get condition encoding
        cond_skips, cond_global = self.cond_encoder(a)  # skips: list of 4, global: (B, 256)
        
        # Get FiLM parameters
        film_params = self.film_mlp(time_emb, cond_global)  # list of 4 (B, 2*C)
        
        # Parse FiLM parameters
        film_scales = []
        film_shifts = []
        for params in film_params:
            B, dim = params.shape
            C = dim // 2
            scale = params[:, :C]
            shift = params[:, C:]
            film_scales.append(scale)
            film_shifts.append(shift)
        
        # Input
        x = self.input_conv(x_t)
        
        # Downsample path
        skips = []
        for i, (block, down) in enumerate(zip(self.down_blocks[:-1], self.downsample)):
            x = block(x, film_scales[i], film_shifts[i])
            # Add conditioning skip connection
            x = x + cond_skips[i]
            skips.append(x)
            x = down(x)
        
        # Bottom block
        x = self.down_blocks[-1](x, film_scales[-1], film_shifts[-1])
        x = x + cond_skips[-1]
        x = self.mid_block(x, film_scales[-1], film_shifts[-1])
        
        # Upsample path
        for i, (up_conv, proj, block) in enumerate(zip(self.upsample_conv, self.skip_proj, self.up_blocks)):
            # Upsample using interpolation to match skip connection size
            skip = skips[-(i+1)]
            x = F.interpolate(x, size=skip.shape[2:], mode='bilinear', align_corners=False)
            x = up_conv(x)
            
            # Concatenate skip connection and project to correct channels
            x = torch.cat([x, skip], dim=1)
            x = proj(x)
            x = block(x, film_scales[-(i+2)], film_shifts[-(i+2)])
        
        # Output velocity field
        v = self.output_conv(x)
        
        return v


# ============================================================================
# Flow Matching Training - CORRECTED VERSION
# ============================================================================
def flow_matching_loss(model, a, u, device):
    """
    Compute flow matching loss - PROPER FORMULATION
    
    Flow: noise (x_0) -> solution (x_1 = u), conditioned on permeability (a)
    x_t = (1-t)*x_0 + t*u, where x_0 ~ N(0, I)
    v_target = u - x_0  (constant velocity for linear interpolation)
    
    The model learns: v_θ(x_t, t, a) ≈ v_target = u - x_0
    """
    B = a.shape[0]
    
    # Sample time uniformly from [0, 1]
    t = torch.rand(B, device=device)
    
    # Sample initial noise from standard normal distribution
    x_0 = torch.randn_like(u)
    
    # Linear interpolation: x_t = (1-t)*x_0 + t*u
    t_expanded = t.view(B, 1, 1, 1)
    x_t = (1 - t_expanded) * x_0 + t_expanded * u
    
    # Target velocity field (constant for linear interpolation)
    v_target = u - x_0
    
    # Predict velocity conditioned on permeability a
    v_pred = model(x_t, t, a)
    
    # MSE loss
    loss = F.mse_loss(v_pred, v_target)
    
    return loss


# ============================================================================
# ODE Integration for Inference - CORRECTED VERSION
# ============================================================================
@torch.no_grad()
def integrate_ode(model, a, num_steps=50, device='cuda'):
    """
    Integrate the learned ODE from t=0 to t=1 using dopri5 adaptive solver
    
    Starting from noise x_0 ~ N(0, I), integrate:
    dx/dt = v_θ(x, t, a)
    
    to reach solution u at t=1
    """
    
    B = a.shape[0]
    
    # Start from random noise
    x_0 = torch.randn_like(a)
    
    # Define the ODE function
    def ode_func(t, x):
        # t is a scalar, need to expand to batch
        t_batch = t.expand(B).to(device)
        return model(x, t_batch, a)
    
    # Time points: from 0 to 1
    t_span = torch.tensor([0.0, 1.0], device=device)
    
    # Solve ODE using dopri5
    trajectory = odeint(ode_func, x_0, t_span, method='dopri5', rtol=1e-5, atol=1e-5)
    
    # Return the final state at t=1 (predicted solution)
    x_final = trajectory[-1]
    
    return x_final


# ============================================================================
# Training Loop - IMPROVED
# ============================================================================
def train_flow_matching():
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    
    # Initialize model with spectral convolutions
    model = FlowMatchingUNet(
        in_channels=1, 
        out_channels=1, 
        base_channels=64,
        modes=12
    ).to(device)
    
    # Move normalizer to device
    y_normalizer.cuda()
    
    # Optimizer with warmup
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=1e-4)
    
    # Cosine annealing scheduler with warmup
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs, eta_min=1e-6)
    
    # Training
    print(f"Training on device: {device}")
    print(f"Spatial resolution: {s}x{s}")
    print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")
    print("="*80)
    
    best_test_l2 = float('inf')
    
    for epoch in range(epochs):
        model.train()
        train_loss = 0.0
        
        for batch_idx, (a, u) in enumerate(train_loader):
            a, u = a.to(device), u.to(device)
            
            # Transpose to (B, C, H, W) format
            a = a.permute(0, 3, 1, 2)  # (B, 1, 85, 85)
            u = u.permute(0, 3, 1, 2)  # (B, 1, 85, 85)
            
            optimizer.zero_grad()
            loss = flow_matching_loss(model, a, u, device)
            loss.backward()
            
            # Gradient clipping
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            
            optimizer.step()
            train_loss += loss.item()
        
        scheduler.step()
        avg_train_loss = train_loss / len(train_loader)
        
        # Validation every 5 epochs
        if (epoch + 1) % 5 == 0:
            model.eval()
            test_l2 = 0.0
            
            with torch.no_grad():
                for a, u in test_loader:
                    a, u = a.to(device), u.to(device)
                    a = a.permute(0, 3, 1, 2)
                    u = u.permute(0, 3, 1, 2)
                    
                    # Integrate ODE
                    u_pred = integrate_ode(model, a, num_steps=50, device=device)
                    
                    # Denormalize for error computation
                    u_pred_denorm = y_normalizer.decode(u_pred.permute(0, 2, 3, 1))
                    u_denorm = y_normalizer.decode(u.permute(0, 2, 3, 1))
                    
                    # Relative L2 error
                    test_l2 += torch.mean(
                        torch.norm(u_pred_denorm.reshape(u_denorm.shape[0], -1) - u_denorm.reshape(u_denorm.shape[0], -1), 2, dim=1) /
                        torch.norm(u_denorm.reshape(u_denorm.shape[0], -1), 2, dim=1)
                    ).item()
            
            avg_test_l2 = test_l2 / len(test_loader)
            
            if avg_test_l2 < best_test_l2:
                best_test_l2 = avg_test_l2
                status = "✓ NEW BEST"
            else:
                status = ""
            
            print(f"Epoch {epoch+1:3d}/{epochs} | Train Loss: {avg_train_loss:.6f} | Test L2: {avg_test_l2:.6f} | LR: {scheduler.get_last_lr()[0]:.2e} {status}")
    
    print("="*80)
    print(f"Training complete! Best Test L2: {best_test_l2:.6f}")
    
    # Move normalizer back to CPU before returning
    y_normalizer.cpu()
    
    return model, y_normalizer


# ============================================================================
# Run Training
# ============================================================================
if __name__ == "__main__":
    model, normalizer = train_flow_matching()
    torch.save({
        'model_state_dict': model.state_dict(),
        'normalizer': normalizer
    }, 'flow_matching_model.pt')
    print("Training complete! Model saved.")


Training on device: cuda
Spatial resolution: 85x85
Model parameters: 19,063,489
Epoch 10/50, Train Loss: 0.025395, Test L2: 0.286832
Epoch 10/50, Train Loss: 0.025395, Test L2: 0.286832
Epoch 20/50, Train Loss: 0.016488, Test L2: 0.299185
Epoch 20/50, Train Loss: 0.016488, Test L2: 0.299185
Epoch 30/50, Train Loss: 0.014245, Test L2: 0.231743
Epoch 30/50, Train Loss: 0.014245, Test L2: 0.231743
Epoch 40/50, Train Loss: 0.010674, Test L2: 0.257137
Epoch 40/50, Train Loss: 0.010674, Test L2: 0.257137
Epoch 50/50, Train Loss: 0.010690, Test L2: 0.239105
Training complete! Model saved.
Epoch 50/50, Train Loss: 0.010690, Test L2: 0.239105
Training complete! Model saved.


In [None]:
# 0.012

In [None]:
# Run training with corrected flow matching formulation
model, normalizer = train_flow_matching()

In [None]:
# Visualize predictions
import matplotlib.pyplot as plt

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.eval()
model.to(device)
normalizer.cuda()

# Get a test batch
test_a, test_u = next(iter(test_loader))
test_a = test_a[:4].to(device).permute(0, 3, 1, 2)  # First 4 samples
test_u = test_u[:4].to(device).permute(0, 3, 1, 2)

# Generate predictions
with torch.no_grad():
    pred_u = integrate_ode(model, test_a, num_steps=50, device=device)

# Denormalize
test_a_denorm = x_normalizer.decode(test_a.permute(0, 2, 3, 1)).cpu().numpy()
test_u_denorm = normalizer.decode(test_u.permute(0, 2, 3, 1)).cpu().numpy()
pred_u_denorm = normalizer.decode(pred_u.permute(0, 2, 3, 1)).cpu().numpy()

# Plot
fig, axes = plt.subplots(4, 3, figsize=(12, 14))
for i in range(4):
    # Permeability
    im0 = axes[i, 0].imshow(test_a_denorm[i, :, :, 0], cmap='viridis')
    axes[i, 0].set_title('Permeability a(x)' if i == 0 else '')
    axes[i, 0].axis('off')
    plt.colorbar(im0, ax=axes[i, 0])
    
    # Ground truth solution
    im1 = axes[i, 1].imshow(test_u_denorm[i, :, :, 0], cmap='RdBu_r')
    axes[i, 1].set_title('True Solution u(x)' if i == 0 else '')
    axes[i, 1].axis('off')
    plt.colorbar(im1, ax=axes[i, 1])
    
    # Predicted solution
    im2 = axes[i, 2].imshow(pred_u_denorm[i, :, :, 0], cmap='RdBu_r')
    axes[i, 2].set_title('Predicted Solution' if i == 0 else '')
    axes[i, 2].axis('off')
    plt.colorbar(im2, ax=axes[i, 2])

plt.tight_layout()
plt.show()

# Compute errors
errors = np.abs(test_u_denorm - pred_u_denorm)
print(f"\nMean Absolute Error: {np.mean(errors):.6f}")
print(f"Max Absolute Error: {np.max(errors):.6f}")

normalizer.cpu()