In [591]:
import torch
import torch.nn as nn
import math
import matplotlib.pyplot as plt
from torch.optim.lr_scheduler import StepLR, ExponentialLR
import warnings
warnings.filterwarnings("ignore")

from os.path import dirname, abspath
import sys, os
d = dirname(os.path.abspath(''))
sys.path.append(d)
from implicitdl import *

In [592]:
# ---------------------------
# Weierstrass Function Setup
# ---------------------------
def weierstrass(x, a=0.5, b=3, N=10):
    """Truncated Weierstrass function, evaluated elementwise.

    f(x) = sum_{n=0}^{N} a^n * cos(b^n * pi * x)    -- i.e. N + 1 terms.

    Args:
        x: tensor of inputs (callers in this notebook pass shape [batch]).
        a: amplitude ratio between successive terms.
        b: frequency ratio between successive terms.
        N: index of the last term kept in the truncated series.

    Returns:
        The partial sum, broadcast against x.
    """
    total = 0
    for n in range(N + 1):
        # Accumulate left-to-right, term by term.
        total = total + (a ** n) * torch.cos((b ** n) * math.pi * x)
    return total

# ---------------------------
# Fourier Feature Mapping
# ---------------------------
def create_fourier_features(x, num_features=16, scale=10.0):
    """
    Map x to sinusoidal (Fourier) features.

    x: [batch, 1]
    Frequencies are `num_features` values spaced linearly over [1, scale].
    Returns [batch, 2*num_features]: all sin features first, then all cos features.
    """
    freqs = torch.linspace(1.0, scale, num_features, device=x.device)[None, :]  # [1, num_features]
    angles = 2 * math.pi * (x * freqs)  # broadcast -> [batch, num_features]
    return torch.cat((torch.sin(angles), torch.cos(angles)), dim=-1)


In [593]:
# Generate training data
torch.manual_seed(0)
N_train = 2000
# 4 * (U[0, 1) - 0.5) + 0.5 maps uniform draws onto [-1.5, 2.5).
# The points are NOT sorted along x.
x_train = 4*(torch.rand(N_train)-0.5) + 0.5
y_train = weierstrass(x_train)  # targets, shape [N_train]
x_train = x_train.unsqueeze(-1)  # shape [N_train, 1]
y_train = y_train.unsqueeze(-1)  # shape [N_train, 1]

In [594]:
def fuse_parameters(model):
    """Move model parameters to a contiguous tensor, and return that tensor.

    After the call, every parameter's `.data` is a view into the returned 1-D
    buffer, laid out in `model.parameters()` order. In-place ops on the buffer
    (e.g. `torch.nn.init.normal_`) therefore re-initialize the whole model at once.

    The buffer takes the dtype and device of the first parameter, so float64 or
    CUDA models keep their dtype/device. Parameters of a different dtype are cast
    to that dtype.

    The copy is done under `torch.no_grad()`. The returned tensor is therefore a
    plain tensor with no autograd history, and it does not keep the old
    parameter storage alive.

    Moving the model afterwards with `.to(...)` to another device/dtype creates
    new parameter tensors, and the buffer then no longer aliases them.

    Returns an empty tensor for a model without parameters.
    """
    param_list = list(model.parameters())
    if not param_list:
        return torch.zeros(0)

    first = param_list[0]
    total = sum(p.numel() for p in param_list)
    with torch.no_grad():
        flat = torch.zeros(total, dtype=first.dtype, device=first.device)
        offset = 0
        for p in param_list:
            count = p.numel()
            chunk = flat[offset:offset + count]
            chunk.copy_(p.flatten())
            # Re-point the parameter at the shared buffer so both see the same storage.
            p.data = chunk.view(p.shape)
            offset += count
    return flat

class MLP(nn.Module):
    """Plain ReLU multilayer perceptron.

    Layout: Linear(in -> hidden) + ReLU, then (hidden_layers - 1) further
    Linear(hidden -> hidden) + ReLU pairs, then Linear(hidden -> out).
    The first ReLU block is always present, even for hidden_layers < 1.
    All Linear layers use Xavier-uniform weights and zero biases.
    """

    def __init__(self, in_features=1, hidden_features=128, hidden_layers=4, out_features=1):
        super().__init__()
        # Fan-in/fan-out chain of the ReLU-activated layers.
        widths = [in_features] + [hidden_features] * max(hidden_layers, 1)
        modules = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            modules.extend([nn.Linear(fan_in, fan_out), nn.ReLU(inplace=True)])
        modules.append(nn.Linear(hidden_features, out_features))
        self.net = nn.Sequential(*modules)

        self._init_weights()

    def _init_weights(self):
        linears = (layer for layer in self.net if isinstance(layer, nn.Linear))
        for layer in linears:
            nn.init.xavier_uniform_(layer.weight)
            nn.init.zeros_(layer.bias)

    def forward(self, x):
        return self.net(x)

# ---------------------------
# Define SIREN Network
# ---------------------------
# Sine activation
class Sine(nn.Module):
    """Elementwise activation computing sin(omega * x)."""

    def __init__(self, omega=30.0):
        super().__init__()
        self.omega = omega  # frequency multiplier applied before the sine

    def forward(self, x):
        scaled = self.omega * x
        return torch.sin(scaled)

def siren_init(m, w0=30):
    """Initialize an nn.Linear in place with the SIREN scheme; other modules are ignored.

    fan_in = m.weight.size(-1)
      - first layer (m.is_first is truthy): U(-1/fan_in, 1/fan_in)
      - every other layer:                  U(-sqrt(6/fan_in)/w0, sqrt(6/fan_in)/w0)
    The bias, if present, is drawn from the same range as the weight.
    """
    if not isinstance(m, nn.Linear):
        return
    fan_in = m.weight.size(-1)
    if getattr(m, 'is_first', False):
        bound = 1.0 / fan_in
    else:
        bound = math.sqrt(6.0 / fan_in) / w0
    with torch.no_grad():
        m.weight.uniform_(-bound, bound)
        if m.bias is not None:
            m.bias.uniform_(-bound, bound)

# Build a small SIREN MLP
class SIREN(nn.Module):
    """Small SIREN-style MLP: Linear + Sine blocks followed by a linear head.

    Layout: Linear(in -> hidden) + Sine(w0), then `hidden_layers` blocks of
    Linear(hidden -> hidden) + Sine(hidden_omega), then Linear(hidden -> out).
    `hidden_layers` counts only the hidden->hidden blocks. In MLP it also
    includes the first layer.

    Args:
        w0: frequency of the first Sine, also passed to siren_init.
        hidden_omega: frequency of the hidden Sine activations. The default 1.0
            reproduces the original architecture.

    NOTE(review): siren_init draws hidden weights from U(+-sqrt(6/fan_in)/w0).
    That range is divided by w0 because the SIREN scheme (Sitzmann et al., 2020)
    expects the next activation to multiply by w0 again, i.e. sin(w0 * x).
    With hidden_omega=1.0 that factor is never applied, so hidden
    pre-activations are roughly w0 times smaller than the scheme intends.
    Pass hidden_omega=w0 to match the paper.
    """

    def __init__(self, in_features=1, hidden_features=64, hidden_layers=3, out_features=1, w0=30.0,
                 hidden_omega=1.0):
        super().__init__()

        # First layer: linear + sine at frequency w0, flagged for the first-layer init range.
        first_linear = nn.Linear(in_features, hidden_features)
        first_linear.is_first = True
        layers = [first_linear, Sine(omega=w0)]

        # Hidden layers
        for _ in range(hidden_layers):
            layers.append(nn.Linear(hidden_features, hidden_features))
            layers.append(Sine(omega=hidden_omega))

        # Final layer: linear only
        layers.append(nn.Linear(hidden_features, out_features))

        self.net = nn.Sequential(*layers)
        self.initialize_siren(w0)

    def initialize_siren(self, w0):
        for m in self.net:
            siren_init(m, w0=w0)

    def forward(self, x):
        return self.net(x)
        
# ---------------------------
# Model with Fourier Features
# ---------------------------
class FourierMLP(nn.Module):
    """ReLU MLP applied to Fourier features of the input.

    forward(x) maps x ([batch, 1]) through create_fourier_features to
    2 * num_fourier_features features. These feed Linear + ReLU, then
    (hidden_layers - 1) more Linear + ReLU pairs, then a final Linear.

    Args:
        in_features: accepted for signature parity but not used. The first
            layer's width is fixed at 2 * num_fourier_features, which assumes
            a single input column.
        fourier_scale: largest Fourier frequency, forwarded as `scale`. The
            default 10.0 matches the previous hard-coded value.
    """

    def __init__(self, in_features=1, num_fourier_features=16, hidden_features=128, hidden_layers=4, out_features=1,
                 fourier_scale=10.0):
        super().__init__()
        self.num_fourier_features = num_fourier_features
        self.fourier_scale = fourier_scale

        # Input size after Fourier mapping (one sin and one cos per frequency)
        mapped_dim = 2 * num_fourier_features

        layers = [nn.Linear(mapped_dim, hidden_features), nn.ReLU(inplace=True)]
        for _ in range(hidden_layers - 1):
            layers.append(nn.Linear(hidden_features, hidden_features))
            layers.append(nn.ReLU(inplace=True))
        layers.append(nn.Linear(hidden_features, out_features))
        self.net = nn.Sequential(*layers)

        self._init_weights()

    def _init_weights(self):
        for m in self.net:
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                nn.init.zeros_(m.bias)

    def forward(self, x):
        # x: [batch, 1]
        ffeat = create_fourier_features(x, num_features=self.num_fourier_features, scale=self.fourier_scale)
        return self.net(ffeat)

# ---------------------------
# Multi-scale Loss Function
# ---------------------------
def multiscale_loss(y_pred, y_true, scales=(1, 2, 4)):
    """
    Average of MSE losses computed on strided subsets of predictions and targets.

    y_pred, y_true: [N, 1] (any tensors sliceable along dim 0)
    scales: iterable of positive integers; for each s the MSE is taken over
            rows 0, s, 2s, ... and the per-scale losses are averaged.

    The strided subset is a coarser grid only if rows are ordered along x.
    For randomly drawn, unsorted inputs it is just a subset, and the result is
    a weighted MSE that up-weights rows whose index is a multiple of the larger
    factors.

    Raises:
        ValueError: if `scales` is empty or contains a factor < 1.
    """
    scales = tuple(scales)
    if not scales:
        raise ValueError("scales must contain at least one downsampling factor")
    if any(s < 1 for s in scales):
        raise ValueError("every downsampling factor in scales must be >= 1")

    loss = 0.0
    for s in scales:
        loss = loss + torch.mean((y_pred[::s] - y_true[::s]) ** 2)
    return loss / len(scales)

# ---------------------------
# Implicit Models
# ---------------------------
class ImplicitFunctionInfSiLU(ImplicitFunctionInf):
    """
    Implicit function using the SiLU (swish) nonlinearity, phi(X) = X * sigmoid(X).
    """

    @staticmethod
    def phi(X):
        gate = torch.sigmoid(X)
        return X * gate

    @staticmethod
    def dphi(X):
        # d/dX [X * s(X)] = s + X * s * (1 - s) = s * (1 + X * (1 - s)), where s = sigmoid(X).
        # Evaluated on a detached copy so no autograd graph is recorded.
        x = X.clone().detach()
        s = torch.sigmoid(x)
        return s * (1 + x * (1 - s))

class ImplicitFunctionInfSIREN(ImplicitFunctionInf):
    """
    An implicit function that uses a sine nonlinearity, phi(X) = sin(X).

    `dphi` returns the elementwise derivative of `phi`. The SiLU variant follows
    the same contract: it returns d/dX[X * sigmoid(X)].
    """
    @staticmethod
    def phi(X):
        return torch.sin(X)

    @staticmethod
    def dphi(X):
        # d/dX sin(X) = cos(X). The previous `X * cos(X)` multiplied the
        # derivative by X, giving wrong gradients (and zero gradient at X = 0).
        return torch.cos(X.clone().detach())

# ---------------------------
# Implicit Models with Fourier Features
# ---------------------------
class ImplicitModelFourier(nn.Module):
    """ImplicitModel applied to Fourier features of the input.

    forward(x) maps x ([batch, 1]) through create_fourier_features to
    2 * num_fourier_features features and feeds them to
    ImplicitModel(hidden_features, 2 * num_fourier_features, out_features, f=f).

    Args:
        in_features: accepted for signature parity but not used. The implicit
            model's input width is fixed at 2 * num_fourier_features.
        fourier_scale: largest Fourier frequency, forwarded as `scale`. The
            default 10.0 matches the previous hard-coded value.
    """

    def __init__(self, hidden_features=128, in_features=1, out_features=1, num_fourier_features=16, f=ImplicitFunctionInf,
                 fourier_scale=10.0):
        super().__init__()
        self.num_fourier_features = num_fourier_features
        self.fourier_scale = fourier_scale
        self.net = ImplicitModel(hidden_features, 2*self.num_fourier_features, out_features, f=f)

    def forward(self, x):
        ffeat = create_fourier_features(x, num_features=self.num_fourier_features, scale=self.fourier_scale)
        return self.net(ffeat)

In [614]:
# ---------------------------
# Model & Training Setup
# ---------------------------
# Alternative configurations (uncomment one and comment out the active model below):
# model, lr = MLP(hidden_features=64, hidden_layers=13), 1e-4

# model, lr = SIREN(hidden_features=128, hidden_layers=3), 1e-4

# model, lr = FourierMLP(
#     in_features=1,
#     num_fourier_features=8,
#     hidden_features=128,
#     hidden_layers=4,
#     out_features=1
# ), 1e-3

# model, lr = ImplicitModelFourier(220, 1, 1, 8, f=ImplicitFunctionInf), 1e-2; torch.nn.init.normal_(fuse_parameters(model), mean=0., std=0.1)

# Active model: implicit model. All parameters are re-drawn from N(0, 0.1^2)
# through the fused parameter buffer.
model = ImplicitModel(225, 1, 1, f=ImplicitFunctionInf)
lr = 1e-2
torch.nn.init.normal_(fuse_parameters(model), mean=0., std=0.1)

print(f'model size: {sum(p.numel() for p in model.parameters())} parameters')

# Move data/model to the target device BEFORE building the optimizer, so the
# optimizer is constructed over the parameters exactly as they will be trained.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
model = model.to(device)
x_train = x_train.to(device)
y_train = y_train.to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=lr)
# scheduler = StepLR(optimizer, step_size=500, gamma=0.5)
# Stepped once per epoch: after 1000 epochs the lr is 0.99**1000 (about 4.3e-5)
# times its initial value.
scheduler = ExponentialLR(optimizer, gamma=0.99)

num_epochs = 1000

for epoch in range(num_epochs):
    model.train()
    optimizer.zero_grad()

    y_pred = model(x_train)
    # Use multi-scale loss
    loss = multiscale_loss(y_pred, y_train, scales=[1, 2, 4, 8])
    loss.backward()
    optimizer.step()
    scheduler.step()

    # NOTE(review): the recorded run printed an identical loss at epochs 500 and
    # 1000, so training appears stalled. Check the lr schedule and init.
    if (epoch+1) % 500 == 0:
        print(f"Epoch {epoch+1}/{num_epochs}, Loss: {loss.item():.6f}")

model size: 51076 parameters
Epoch 500/1000, Loss: 0.554905
Epoch 1000/1000, Loss: 0.554905


In [615]:
# ---------------------------
# Evaluation
# ---------------------------
model.eval()
with torch.no_grad():
    x_eval = torch.linspace(0.5, 3.5, 1000, device=device).unsqueeze(-1)
    y_eval = weierstrass(x_eval.squeeze(-1))
    y_pred = model(x_eval)

x_eval_cpu = x_eval.cpu().numpy().flatten()
y_eval_cpu = y_eval.cpu().numpy().flatten()
y_pred_cpu = y_pred.cpu().numpy().flatten()

# ---------------------------
# Plot Results
# ---------------------------

plt.figure(figsize=(10,5))
plt.plot(x_eval_cpu, y_eval_cpu, label='True Weierstrass Function', linewidth=2)
plt.plot(x_eval_cpu, y_pred_cpu, label='NN Approximation', linewidth=2, linestyle='--')
# Vertical line at x=2.5: the upper edge of the training inputs (drawn from
# [-1.5, 2.5)), so everything to its right is extrapolation.
plt.axvline(x=2.5, color='gray', linestyle='--', linewidth=2)
plt.legend(fontsize=14)
plt.title("Approximation of Weierstrass Function", fontsize=16)
plt.xlabel("x", fontsize=14)
plt.ylabel("f(x)", fontsize=14)
plt.xticks(fontsize=14)
plt.yticks(fontsize=14)
# Save BEFORE show(): show() closes the figure (inline backend / script window),
# so a savefig after it writes an empty image.
plt.savefig("output_plot.png")
plt.show()

# save data
import pickle
with open('weier_imp_L.pkl', 'wb') as f:
    pickle.dump({'x': x_eval_cpu, 'y_true': y_eval_cpu, 'y_pred': y_pred_cpu}, f)

In [620]:
# ---------------------------
# Plot all Results
# ---------------------------
# Imported here as well so this cell runs without the evaluation cell above.
import pickle

# NOTE: pickle.load can execute arbitrary code; only load result files produced by this notebook.
with open('weier_fourier_S.pkl', 'rb') as f:
    data = pickle.load(f)
x_eval_cpu = data['x']
y_eval_cpu = data['y_true']
y_pred_cpu = data['y_pred']
plt.figure(figsize=(10,5))
plt.plot(x_eval_cpu, y_eval_cpu, label='True Weierstrass Function', linewidth=2)
plt.plot(x_eval_cpu, y_pred_cpu, label='NN Approximation', linewidth=2, linestyle='--')
# Vertical line at x=2.5: the upper edge of the training inputs (drawn from [-1.5, 2.5)).
plt.axvline(x=2.5, color='gray', linestyle='--', linewidth=2)
plt.legend(fontsize=14)
plt.title("Approximation of Weierstrass Function", fontsize=16)
plt.xlabel("x", fontsize=14)
plt.ylabel("f(x)", fontsize=14)
plt.xticks(fontsize=14)
plt.yticks(fontsize=14)
# Save BEFORE show(): show() closes the figure, so a later savefig writes an empty image.
# NOTE(review): this path is also used by the evaluation cell, whose figure this overwrites.
plt.savefig("output_plot.png")
plt.show()