# Convert Foundation Model Gates to Bayesian Gates

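This notebook loads the pre-trained `Foundation_1d_10exp_0` model and replaces each of its deterministic gate networks with a Bayesian gate network. The lifting network and wavelet encoder of every gate keep their pre-trained weights; only the gating head is rebuilt from freshly initialized `BayesianLinear` layers. The expert (WNO) layers are left untouched.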


In [1]:
import os
directory = os.path.abspath(os.path.join(os.getcwd(), '..'))
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from utilities import *
from ncwno_modules import *
import UQpy.scientific_machine_learning as sml

# Fix the NumPy and PyTorch RNGs so weight initialisation (including the new
# Bayesian layers) is reproducible across runs.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

# Prefer the first CUDA GPU when one is available.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')
print(f'Using device: {device}')


  from pkg_resources import resource_stream


Using device: cpu


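## Define Model Classes

The checkpoint stores the whole pickled model, so `torch.load` needs these class definitions in scope. They must match the definitions that were used when the model was saved.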


In [2]:
""" Def: Expert WNO block """
class Expert_WNO(nn.Module):
    """Mixture of wavelet-convolution experts, one per Daubechies wavelet.

    Expert ``k`` is a ``WaveConv1d`` built with wavelet ``'db{k+1}'``. The forward
    pass returns the sum of every expert's output, weighted by the matching
    entry of the gate output ``lambda_``.

    The experts are registered as ``Expert_layers0``, ``Expert_layers1``, ... so
    that attribute and state-dict names match models saved with the earlier
    hard-coded ten-expert version of this class.

    Args:
        level: wavelet decomposition level passed to each ``WaveConv1d``.
        width: number of input and output channels of every expert.
        expert_num: number of experts. Any positive value is supported;
            previously it had to be exactly 10.
        size: spatial size passed to each ``WaveConv1d``.
    """
    def __init__(self, level, width, expert_num, size):
        super(Expert_WNO, self).__init__()
        self.level = level
        self.width = width
        self.expert_num = expert_num

        # Keep the legacy attribute names Expert_layers{k}. Pre-trained models are
        # unpickled whole by torch.load, so their submodules live under these keys.
        for k in range(self.expert_num):
            setattr(self, f'Expert_layers{k}',
                    WaveConv1d(self.width, self.width, self.level, size, f'db{k + 1}'))

    def forward(self, x, lambda_):
        """Return ``sum_k lambda_[..., k:k+1] * Expert_layers{k}(x)``.

        Args:
            x: input passed unchanged to every expert.
            lambda_: gate weights. The last dimension is indexed by expert, and
                each slice ``lambda_[..., k:k+1]`` must broadcast against the
                expert output.

        Returns:
            The gate-weighted sum of all expert outputs. Terms are accumulated
            left to right in expert order, as in the original unrolled sum.
        """
        out = lambda_[..., 0:1] * self.Expert_layers0(x)
        for k in range(1, self.expert_num):
            out = out + lambda_[..., k:k + 1] * getattr(self, f'Expert_layers{k}')(x)
        return out


In [3]:
""" The forward operation """
class NCWNO1d(nn.Module):
    """1-D NCWNO: lifted input -> gated mixture-of-wavelet-expert layers -> projection.

    Args:
        width: channel width of the lifted representation.
        level: wavelet decomposition level forwarded to every ``Expert_WNO``.
        input_dim: channels of the network input *including* the grid channel
            appended in ``forward`` (data channels + 1).
        hidden_dim: number of gated expert layers, and of gates.
        space_len: right end of the spatial grid ``linspace(0, space_len, n)``.
        expert_num: number of experts per layer.
        label_lifting: forwarded to each ``Gate_context1d``.
        size: spatial size forwarded to the gates and the experts.
        padding: number of zeros appended on the right before the expert layers
            and stripped again afterwards. 0 disables padding.

    Attribute names and the order in which submodules are created are kept
    fixed for two reasons: whole-model checkpoints are unpickled against this
    class, and reordering construction would change the seeded initialisation.
    """
    def __init__(self, width, level, input_dim, hidden_dim, space_len, expert_num, label_lifting, size, padding=0):
        super(NCWNO1d, self).__init__()
        self.level = level
        self.width = width
        self.hidden_dim = hidden_dim
        self.space_len = space_len
        self.padding = padding
        self.size = size
        self.expert_num = expert_num
        self.label_lifting = label_lifting
        self.conv_layers = nn.ModuleList()
        self.w_layers = nn.ModuleList()
        # One deterministic gate per hidden layer. The gates are created before
        # the lifting layers so parameter initialisation consumes the RNG in the
        # same order as before.
        self.gate = nn.ModuleList(
            [Gate_context1d(width, width, expert_num, label_lifting, size, is_bayesian=False)
             for _ in range(self.hidden_dim)]
        )

        # Lifting: input_dim -> width via two pointwise convolutions.
        self.fc0 = nn.Conv1d(input_dim, self.width, 1)
        self.fc1 = nn.Conv1d(self.width, self.width, 1)
        # Each hidden layer pairs an expert mixture with a pointwise skip convolution.
        for _ in range(self.hidden_dim):
            self.conv_layers.append(Expert_WNO(self.level, self.width, self.expert_num, self.size))
            self.w_layers.append(nn.Conv1d(self.width, self.width, 1))

        # Projection: width -> 128 -> 1 channel.
        self.fc2 = nn.Conv1d(self.width, 128, 1)
        self.fc3 = nn.Conv1d(128, 1, 1)

    def forward(self, x, label):
        """Evaluate the operator.

        Args:
            x: input of shape (batch, channels, n). One grid channel is appended,
                so ``channels + 1`` must equal ``input_dim``.
            label: context label. It is repeated via
                ``label.repeat(batch, width, 1)`` and passed to every gate.

        Returns:
            Tensor of shape (batch, 1, n).
        """
        coords = self.get_grid(x.shape, x.device)
        h = self.fc1(self.fc0(torch.cat((x, coords), dim=1)))
        if self.padding != 0:
            h = F.pad(h, [0, self.padding])

        # Every gate sees the same lifted (and padded) features. The gates are
        # evaluated before any expert layer runs.
        context = self.get_label(label, h.shape, h.device)
        gate_weights = [gate_(h, context) for gate_ in self.gate]

        for expert_block, skip_conv, weights in zip(self.conv_layers, self.w_layers, gate_weights):
            h = F.mish(expert_block(h, weights) + skip_conv(h))

        if self.padding != 0:
            h = h[..., :-self.padding]
        return self.fc3(F.mish(self.fc2(h)))

    def get_grid(self, shape, device):
        """Return a uniform grid on [0, space_len] as a (batch, 1, n) float32 tensor on ``device``."""
        n_batch, n_points = shape[0], shape[-1]
        points = torch.tensor(np.linspace(0, self.space_len, n_points), dtype=torch.float)
        return points.reshape(1, 1, n_points).repeat([n_batch, 1, 1]).to(device)

    def get_label(self, label, shape, device):
        """Tile ``label`` with ``repeat(batch, channels, 1)`` for a 3-D ``shape`` and cast it to float32."""
        n_batch, n_channels, _ = shape  # unpacking also enforces a 3-D shape
        tiled = label.repeat(n_batch, n_channels, 1).to(device)
        return tiled.float()


In [4]:
foundation_model_path = 'data/model/Foundation_1d_10exp_0'

if not os.path.exists(foundation_model_path):
    raise FileNotFoundError(
        f"Foundation model not found at '{foundation_model_path}'. "
        f"Please ensure the foundation model exists."
    )

print(f'Loading pre-trained model from {foundation_model_path}...')
# The checkpoint is a whole pickled nn.Module, not a state_dict, so it must be
# unpickled with weights_only=False. PyTorch >= 2.6 defaults to weights_only=True
# and refuses such files. Unpickling can execute arbitrary code, so only load
# checkpoints from a trusted source.
model = torch.load(foundation_model_path, map_location=device, weights_only=False)
if not isinstance(model, nn.Module):
    # Fail fast here instead of with an obscure AttributeError in later cells.
    raise TypeError(
        f"Expected a pickled nn.Module at '{foundation_model_path}', "
        f"got {type(model).__name__}."
    )
model.to(device)
print('Model loaded successfully.')

print(f'\nOriginal model parameters: {count_params(model):,}')


Loading pre-trained model from data/model/Foundation_1d_10exp_0...
Model loaded successfully.

Original model parameters: 41,113,705


## Replace Gates with Bayesian Gates


In [5]:
print('\nReplacing gate networks with Bayesian versions...')

# Iterate over a snapshot so that assigning into model.gate below never mutates
# the container that is being iterated.
for i, old_gate in enumerate(list(model.gate)):
    # Hyper-parameters of the pre-trained (deterministic) gate.
    gate_size = old_gate.size
    gate_label_lifting = old_gate.label_lifting
    gate_expert_num = old_gate.expert_num
    gate_level = old_gate.level
    gate_down_level = old_gate.down_level
    gate_wavelet = old_gate.wavelet

    # Build a gate with the same hyper-parameters but is_bayesian=True.
    new_gate = Gate_context1d(
        model.width, model.width, gate_expert_num,
        gate_label_lifting, gate_size,
        level=gate_level,
        wavelet=gate_wavelet,
        down_level=gate_down_level,
        is_bayesian=True
    )

    # Transfer the pre-trained, non-Bayesian sub-networks unchanged.
    new_gate.lifting_network.load_state_dict(old_gate.lifting_network.state_dict())
    new_gate.wno_encode.load_state_dict(old_gate.wno_encode.state_dict())

    # Rebuild the gating head explicitly from freshly initialised BayesianLinear
    # layers; the pre-trained head weights are not transferred. The first layer's
    # input width is size // 2**down_level + label_lifting.
    # NOTE(review): this overrides whatever head is_bayesian=True already built;
    # confirm whether the override is still needed.
    # NOTE(review): Mish is applied before the Softmax; confirm this matches the
    # head used inside Gate_context1d.
    new_gate.gate = nn.Sequential(
        sml.BayesianLinear(gate_size//2**(gate_down_level) + gate_label_lifting, 256),
        nn.Mish(),
        sml.BayesianLinear(256, 128),
        nn.Mish(),
        sml.BayesianLinear(128, 64),
        nn.Mish(),
        sml.BayesianLinear(64, 32),
        nn.Mish(),
        sml.BayesianLinear(32, gate_expert_num),
        nn.Mish(),
        nn.Softmax(dim=-1)
    )

    # Freshly constructed modules live on the CPU. Move the gate to the model's
    # device, otherwise a CUDA forward pass fails with a device mismatch.
    model.gate[i] = new_gate.to(device)
    print(f'  Gate {i}: Replaced with Bayesian version (size: {gate_size}, expert_num: {gate_expert_num})')

print('Gate replacement complete.')
print(f'\nModel after gate replacement: {count_params(model):,} parameters')



Replacing gate networks with Bayesian versions...
  Gate 0: Replaced with Bayesian version (size: 256, expert_num: 10)
  Gate 1: Replaced with Bayesian version (size: 256, expert_num: 10)
  Gate 2: Replaced with Bayesian version (size: 256, expert_num: 10)
  Gate 3: Replaced with Bayesian version (size: 256, expert_num: 10)
Gate replacement complete.

Model after gate replacement: 41,640,785 parameters


## Verify Bayesian Gates


In [6]:
# Verify that every gate now uses Bayesian layers.
bayesian_layers = 0
for name, module in model.named_modules():
    if isinstance(module, sml.BayesianLinear):
        bayesian_layers += 1
        print(f'Found Bayesian layer: {name}')

# A gate counts as "regular" if its gating head contains no BayesianLinear layer.
# (Previously this counter was never updated and always reported 0.)
regular_gates = sum(
    1 for g in model.gate
    if not any(isinstance(layer, sml.BayesianLinear) for layer in g.gate)
)

print('\nSummary:')
print(f'  Bayesian Linear layers found: {bayesian_layers}')
print(f'  Regular gates: {regular_gates}')

# Check gate structure
print('\nGate structure verification:')
for i, gate in enumerate(model.gate):
    print(f'  Gate {i}:')
    print(f'    - lifting_network: {type(gate.lifting_network).__name__}')
    print(f'    - wno_encode: {type(gate.wno_encode).__name__}')
    # nn.Sequential only holds modules, so its length is the layer count.
    print(f'    - gate Sequential layers: {len(gate.gate)}')
    bayesian_in_gate = sum(isinstance(layer, sml.BayesianLinear) for layer in gate.gate)
    print(f'    - Bayesian layers in gate: {bayesian_in_gate}')

assert regular_gates == 0, f'{regular_gates} gate(s) still have no Bayesian layers'


Found Bayesian layer: gate.0.gate.0
Found Bayesian layer: gate.0.gate.2
Found Bayesian layer: gate.0.gate.4
Found Bayesian layer: gate.0.gate.6
Found Bayesian layer: gate.0.gate.8
Found Bayesian layer: gate.1.gate.0
Found Bayesian layer: gate.1.gate.2
Found Bayesian layer: gate.1.gate.4
Found Bayesian layer: gate.1.gate.6
Found Bayesian layer: gate.1.gate.8
Found Bayesian layer: gate.2.gate.0
Found Bayesian layer: gate.2.gate.2
Found Bayesian layer: gate.2.gate.4
Found Bayesian layer: gate.2.gate.6
Found Bayesian layer: gate.2.gate.8
Found Bayesian layer: gate.3.gate.0
Found Bayesian layer: gate.3.gate.2
Found Bayesian layer: gate.3.gate.4
Found Bayesian layer: gate.3.gate.6
Found Bayesian layer: gate.3.gate.8

Summary:
  Bayesian Linear layers found: 20
  Regular gates: 0

Gate structure verification:
  Gate 0:
    - lifting_network: Linear
    - wno_encode: WaveEncoder1d
    - gate Sequential layers: 11
    - Bayesian layers in gate: 5
  Gate 1:
    - lifting_network: Linear
    - wn

## Optional: Save the Converted Model


In [None]:
# Set to True to write the converted model to disk. The whole module is pickled,
# so reloading it requires the model class definitions and UQpy (for the
# BayesianLinear layers) to be importable.
SAVE_CONVERTED_MODEL = False
output_path = 'data/model/Foundation_1d_10exp_0_bayesian'

if SAVE_CONVERTED_MODEL:
    torch.save(model, output_path)
    print(f'Model saved to {output_path}')
