In [1]:
import torch
import torch.nn as nn
import numpy as np

from utils import select_available, plot_figure, save_gif

device = select_available()

Using MPS


In [2]:
class LinearBlock(nn.Module):
    def __init__(self, channels_in, channels_out):
        super().__init__()
        layers = [
            nn.Linear(channels_in, channels_out, bias=True),
            nn.GELU()
        ]
        self.layers = nn.Sequential(*layers)

    def forward(self, x):
        return self.layers(x)

In [3]:
class SinActivation(nn.Module):
    def forward(self, x):
        return torch.sin(x)
    
class CosActivation(nn.Module):
    def forward(self, x):
        return torch.cos(x)

class FourierBlock(nn.Module):
    def __init__(self, channels_in, channels_out):
        super().__init__()
        layers_sin = [
            nn.Linear(channels_in, channels_out, bias=False),
            SinActivation(),
        ]
        layers_cos = [
            nn.Linear(channels_in, channels_out, bias=False),
            CosActivation(),
        ]
        
        self.layers_sin = nn.Sequential(*layers_sin)
        self.layers_cos = nn.Sequential(*layers_cos)

    def forward(self, x):
        sin_part = self.layers_sin(x)
        cos_part = self.layers_cos(x)
        return sin_part + cos_part

In [4]:
class MLP(nn.Module):
    def __init__(self, input_features, hidden_layers, output_features):
        super().__init__()
        
        # self.fc_in = LinearBlock(input_features, hidden_layers[0])
        self.fc_in = FourierBlock(input_features, hidden_layers[0])

        layers = []
        if len(hidden_layers) > 1:
            for i in range(len(hidden_layers) - 1):
                # layers.append(LinearBlock(hidden_layers[i], hidden_layers[i+1]))
                layers.append(FourierBlock(hidden_layers[i], hidden_layers[i+1]))
        self.layers = nn.Sequential(*layers)
        
        self.fc_out = nn.Linear(hidden_layers[-1], output_features, bias=True)
    
    def forward(self, x):
        x = self.fc_in(x)
        for layer in self.layers:
            x = layer(x)
        x = self.fc_out(x)
        
        return x

$y'(x) = x + y \text{;  } y(0) = 0$

$y(x) = e^x - x - 1 $

$y(x) \approx x N(x)$

$x N'(x) + N(x) = x + x N(x)$

In [5]:
# def func(x, y):
#     return x + y

# def sol(x):
#     return torch.exp(x) - x - 1

$ y'(x) = 3y + 4; y(0) = 0$

$ y(x) = \frac{4}{3}(e^{3x} - 1)$

$ y(x) \approx x N(x) $

$ x N'(x) + N(x) = 3 x N(x) + 4 $ 

In [6]:
# def func(x, y):
#     return 3 * y + 4

# def sol(x):
#     return 4 / 3 * (torch.exp(3 * x) - 1)

$y'(x) = \frac{1}{x}; y(1) = 0 $

$ y(x) = ln|x| $

$ y(x) \approx (1 - x) N(x) $

$ (1 - x) N'(x) - N(x) = \frac{1}{x} $

In [7]:
def func(x, y):
    return 1 / x

def sol(x):
    return torch.log(torch.abs(x))

$y'(x) = - y\text{ tan}(x) + \text{sin}(x) $

$y(0) = 0 $

$y(x) = \frac{1}{ \text{cos}(x)} -  \text{cos}(x) $

$y(x) \approx x N(x) $

$ x N'(x) + N(x) = - x N(x) \text{ tan}(x) + \text{sin}(x) $



In [20]:
def func(x, y):
    return - y * torch.tan(x) + torch.sin(x)

def sol(x):
    return 1 / torch.cos(x) - torch.cos(x)

In [23]:
import numpy as np
import torch
import matplotlib.pyplot as plt
import os 
import shutil
import imageio
from tqdm import tqdm

import matplotlib.pyplot as plt
import os

def train(model, domain, num_epochs, optimizer, loss_fn, plot_interval, model_name, atol=1e-5):

    images = []
    last_save_path = False
    epoch_pbar = tqdm(range(num_epochs), desc="Training Progress", ncols=100)

    solution = sol(domain)
    for epoch in epoch_pbar:
        optimizer.zero_grad()
        domain.requires_grad_(True)

        outputs = model(domain)
        
        gradients = torch.autograd.grad(outputs, domain, grad_outputs=torch.ones_like(outputs), create_graph=True)[0]
        
        # loss_domain = loss_fn(gradients, func(domain, outputs))
        # loss_boundary = loss_fn(model(torch.tensor([0.0])), torch.tensor([0.0]))
        # loss = loss_domain + loss_boundary 
        loss = loss_fn(domain * gradients + outputs, func(domain, domain * outputs))
        
        loss.backward()
        optimizer.step()
    
        epoch_pbar.set_postfix_str(f"Train Loss: {loss.item():.8f}", refresh=True)

        if 0 < loss.item() < atol:
            print(f'Stopping criterion met at epoch {epoch}: Loss is less than {atol}.')
            last_save_path = plot_figure(domain.detach(), solution, domain.detach() * outputs.detach(), model, epoch, loss, figure_name=model_name)
            
            break
        
    # - - GIF SAVING - - 
    
        if epoch % plot_interval == 0 or epoch == num_epochs - 1:            
            save_path = plot_figure(domain.detach(), solution, domain.detach() * outputs.detach(), model, epoch, loss, figure_name=model_name)
            images.append(save_path)
    
    if last_save_path:
        images.append(last_save_path)
    save_gif(model_name, images)
    
    if os.path.exists(model_name):
        shutil.rmtree(model_name)

In [24]:
domain = torch.linspace(0.0, 1.5, steps=100).unsqueeze(1)
solution = sol(domain.detach())

features_in = 1
features_out = 1
hidden = [100, 100]

model = MLP(features_in, hidden, features_out)
loss_fn = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

epochs = 10000
images = train(model, domain, epochs, optimizer, loss_fn, plot_interval=10, model_name=f"MLP_{'_'.join(map(str, hidden))}_cas")

Training Progress: 100%|█████████████| 10000/10000 [01:14<00:00, 135.03it/s, Train Loss: 0.00056961]


GIF saved at gif/MLP_100_100_cas.gif


In [10]:
model.eval()
print(f'Model parameter count: {sum(p.numel() for p in model.parameters())}')

x_point = torch.tensor([0.4], requires_grad=True)
output = model(x_point)
output.backward()

print(f"Prediction at x={x_point.item():.4f}: {output.item():.4f}")
print(f"Derivative at x={x_point.item():.4f}: {x_point.grad.item():.4f}")

print(f"Exact value at x={x_point.item():.4f}: {func(x_point).item():.4f}")
print(f"Derivative at x={x_point.item():.4f}: {d_func_dx(x_point).item():.4f}")

Model parameter count: 20301
Prediction at x=0.4000: -1.2395
Derivative at x=0.4000: 0.3969


TypeError: func() missing 1 required positional argument: 'y'