In [None]:
import os
import tempfile

import imageio
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from tqdm.auto import tqdm


# Seed torch's RNGs (CPU and CUDA) so weight initialization and the randomly sampled
# collocation points are reproducible across Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


# Solve the ODE y' = sin(x) + cos(x).
class Diff_Eq:
    """ODE definition: right-hand side, initial condition and analytic solution.

    The analytic solution y = sin(x) - cos(x) + 1 + bias satisfies y(x0) = bias at x0 = 0.
    """

    # Constant shift added to the solution, and therefore to the initial value.
    bias = 20.0

    # Point at which the initial condition is imposed.
    x0 = 0.0

    @staticmethod
    def y_prime(x, y):
        """Right-hand side f(x, y) of y' = f(x, y); this ODE does not depend on y."""
        derivative = torch.sin(x) + torch.cos(x)
        return derivative

    @classmethod
    def y0(cls):
        """Initial value y(x0)."""
        shifted = 0.0 + cls.bias
        return shifted

    @classmethod
    def y(cls, x):
        """Analytic solution evaluated at x."""
        oscillation = torch.sin(x) - torch.cos(x)
        return oscillation + 1.0 + cls.bias

# y' = - x y^2 という微分方程式を解く
# class Diff_Eq:
    
#     x0 = 1.
    
#     @staticmethod
#     def y_prime(x, y):
#         return - x * y ** 2
    
#     @staticmethod
#     def y0():
#         return 1.0 
    
#     @staticmethod
#     def y(x):
#         return 2 / (1 + x ** 2) 


# y' = - 2 x y という微分方程式を解く
# class Diff_Eq:
    
#     x0 = 0.0
    
#     @staticmethod
#     def y_prime(x, y):
#         return - 2 * x * y
    
#     @staticmethod
#     def y0():
#         return 1.0
    
#     @staticmethod
#     def y(x):
#         return torch.exp(- x ** 2) 
    
# Plot the analytic solution y(x) of the selected ODE on 0 <= x <= 10.
x = torch.linspace(0, 10, 1000)
y = Diff_Eq.y(x)
fig, ax = plt.subplots()
ax.plot(x, y)
ax.set(title="analytic solution", xlabel="x", ylabel="y")
plt.show()


In [None]:
class EqDataset(torch.utils.data.Dataset):
    """Map-style dataset of uniformly random collocation points between x_from and x_to.

    ``idx`` is ignored: every access draws a fresh sample from torch's global RNG,
    so the dataset behaves like an endless sampler whose epoch length is ``length``.
    """

    def __init__(self, x_from=0, x_to=20, length=300000):
        self.length = length
        self.x_to = x_to
        self.x_from = x_from

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        # One torch.rand draw of shape (1,), scaled and shifted into the interval.
        width = self.x_to - self.x_from
        return torch.rand(1) * width + self.x_from
    
# Default EqDataset: 300,000 random points per epoch, in batches of 1,000
# (i.e. 300 optimizer steps per epoch). Every item is already random, so shuffle
# only changes index order.
dataset = EqDataset()
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=1000,
    shuffle=True,
)

In [None]:



class Net(nn.Module):
    """Fully connected tanh MLP approximating the ODE solution y(x), plus a learnable offset.

    forward(x) = fc(x) + bias, where ``bias`` is a 1-element parameter initialized to
    the initial value y0 (by default), so the untrained output starts offset by y0.

    Args:
        hidden_features: width of every hidden layer (default 20).
        num_hidden_layers: number of Linear+Tanh layers before the final Linear head
            (default 7). Must be >= 1.
        bias_init: initial value of ``bias``; defaults to ``Diff_Eq.y0()``.

    Raises:
        ValueError: if ``num_hidden_layers`` < 1.
    """

    def __init__(self, hidden_features=20, num_hidden_layers=7, bias_init=None):
        super().__init__()
        if num_hidden_layers < 1:
            raise ValueError(f"num_hidden_layers must be >= 1, got {num_hidden_layers}")

        def build_layers(in_features, out_features):
            return nn.Sequential(
                nn.Linear(in_features, out_features),
                nn.Tanh()
            )

        # Input layer (1 feature in), then (num_hidden_layers - 1) square hidden layers,
        # then a linear head with a single output. The defaults reproduce the original
        # 1 -> 20 -> ... -> 20 -> 1 layout, with the same module indices in ``fc``.
        layers = [build_layers(1, hidden_features)]
        layers.extend(
            build_layers(hidden_features, hidden_features)
            for _ in range(num_hidden_layers - 1)
        )
        layers.append(nn.Linear(hidden_features, 1))
        self.fc = nn.Sequential(*layers)

        if bias_init is None:
            bias_init = Diff_Eq.y0()
        # float() keeps the parameter floating-point even if an int is passed.
        self.bias = nn.Parameter(torch.tensor([float(bias_init)]))

    def forward(self, x):
        """Return fc(x) + bias; x is expected to have a trailing feature dimension of size 1."""
        return self.fc(x) + self.bias

    

In [None]:
# Train a fresh model whose output offset (Net.bias) starts at the initial value y0.
model = Net().to(device)
optimizer = optim.Adam(model.parameters(), lr=0.001)
num_epochs = 6
INIT_LOSS_WEIGHT = 0.01  # weight of the initial-condition penalty relative to the ODE residual

# Per-epoch mean of each loss term.
loss_list = {
    "mse": [],
    "init": [],
    "total": []
}

# Solution figures collected for the GIF (one before training, then one per epoch).
visualized_imgs = []


def _plot_validation(net, title, show_derivative=True):
    """Plot ``net`` against the analytic solution on the training interval.

    Shows "solution <title>" (prediction vs. ``Diff_Eq.y``) and, when
    ``show_derivative`` is true, "derivative of solution <title>" (autograd dy/dx
    of the prediction vs. ``Diff_Eq.y_prime`` evaluated at the prediction).

    Returns the solution figure so the caller can keep it for the GIF.
    """
    val_x = torch.linspace(dataloader.dataset.x_from, dataloader.dataset.x_to, 1000).to(device).view(-1, 1)
    val_x.requires_grad = True
    val_y = net(val_x)
    x_np = val_x.detach().cpu().numpy()

    fig = plt.figure()
    plt.plot(x_np, val_y.detach().cpu().numpy(), label="pred")
    plt.plot(x_np, Diff_Eq.y(val_x).detach().cpu().numpy(), label="true")
    plt.legend()
    plt.title(f"solution {title}")
    plt.show()

    if show_derivative:
        # The derivative is only plotted, so no higher-order graph is needed (no create_graph).
        val_y_prime = torch.autograd.grad(val_y.sum(), val_x)[0]
        plt.figure()
        plt.plot(x_np, val_y_prime.detach().cpu().numpy(), label="pred")
        plt.plot(x_np, Diff_Eq.y_prime(val_x, val_y).detach().cpu().numpy(), label="true")
        plt.legend()
        plt.title(f"derivative of solution {title}")
        plt.show()

    return fig


# Solution at the start of training (first GIF frame).
visualized_imgs.append(_plot_validation(model, "before training", show_derivative=False))

# The initial-condition input and target never change, so build them once.
x0_input = torch.tensor([[Diff_Eq.x0]], device=device)
y0_true = torch.tensor([[Diff_Eq.y0()]], device=device)

for epoch in range(num_epochs):

    print(f"epoch: {epoch}")

    for key in loss_list:
        loss_list[key].append(0)

    for x in tqdm(dataloader):
        optimizer.zero_grad()

        x = x.to(device)
        x.requires_grad = True

        y_pred = model(x)
        # dy/dx of the network; create_graph=True so the loss can backpropagate through it.
        y_prime_pred = torch.autograd.grad(y_pred.sum(), x, create_graph=True)[0]

        # ODE residual: compare against f(x, y) evaluated at the network's own prediction,
        # so training needs only the ODE itself, not its analytic solution.
        y_prime = Diff_Eq.y_prime(x, y_pred)
        loss_mse = F.mse_loss(y_prime_pred, y_prime)

        # Soft penalty pulling y(x0) towards y0.
        y0_pred = model(x0_input)
        loss_init = (y0_pred - y0_true) ** 2

        loss = loss_mse + loss_init * INIT_LOSS_WEIGHT
        loss.backward()
        optimizer.step()

        loss_list["mse"][-1] += loss_mse.item()
        loss_list["init"][-1] += loss_init.item()
        loss_list["total"][-1] += loss.item()

    for key in loss_list:
        loss_list[key][-1] /= len(dataloader)

    print(f"loss_mse: {loss_list['mse'][-1]}")
    print(f"loss_init: {loss_list['init'][-1]}")
    print(f"loss_total: {loss_list['total'][-1]}")

    # Optimizer steps completed so far, including this epoch.
    iteration = (epoch + 1) * len(dataloader)
    visualized_imgs.append(_plot_validation(model, f"at iteration {iteration}"))

# Final state. Uses num_epochs rather than the loop variable, so it is also correct for 0 epochs.
final_fig = _plot_validation(model, f"after training (iteration {num_epochs * len(dataloader)})")


In [None]:
# Save the collected solution snapshots (visualized_imgs) as an animated GIF.
# fps=1 sets the playback speed to one frame per second.
# Each figure is round-tripped through a PNG in a private temporary directory. The
# directory is removed even if rendering fails, and no file named tmp.png in the
# working directory is overwritten or deleted.
images = []
with tempfile.TemporaryDirectory() as tmp_dir:
    frame_path = os.path.join(tmp_dir, "tmp.png")
    for img in visualized_imgs:
        img.savefig(frame_path)
        images.append(imageio.imread(frame_path))

imageio.mimsave('result1.gif', images, fps=1)


# 初期値補正バイアスを活用した場合

In [None]:




# Build a second model for comparison. Its output offset is replaced by a constant
# zero that receives no gradient (requires_grad=False).
model: Net = Net()
model.bias = nn.Parameter(torch.zeros(1), requires_grad=False)
model.to(device)







# 初期値補正バイアスを利用しなかった場合

In [None]:
# Train the model built in the previous cell (output offset frozen at 0) with the same procedure.
optimizer = optim.Adam(model.parameters(), lr=0.001)
num_epochs = 12
INIT_LOSS_WEIGHT = 0.01  # weight of the initial-condition penalty relative to the ODE residual

# Per-epoch mean of each loss term.
loss_list = {
    "mse": [],
    "init": [],
    "total": []
}

# Solution figures collected for the GIF (one before training, then one per epoch).
visualized_imgs = []


def _plot_validation(net, title, show_derivative=True):
    """Plot ``net`` against the analytic solution on the training interval.

    Shows "solution <title>" (prediction vs. ``Diff_Eq.y``) and, when
    ``show_derivative`` is true, "derivative of solution <title>" (autograd dy/dx
    of the prediction vs. ``Diff_Eq.y_prime`` evaluated at the prediction).

    Returns the solution figure so the caller can keep it for the GIF.
    """
    val_x = torch.linspace(dataloader.dataset.x_from, dataloader.dataset.x_to, 1000).to(device).view(-1, 1)
    val_x.requires_grad = True
    val_y = net(val_x)
    x_np = val_x.detach().cpu().numpy()

    fig = plt.figure()
    plt.plot(x_np, val_y.detach().cpu().numpy(), label="pred")
    plt.plot(x_np, Diff_Eq.y(val_x).detach().cpu().numpy(), label="true")
    plt.legend()
    plt.title(f"solution {title}")
    plt.show()

    if show_derivative:
        # The derivative is only plotted, so no higher-order graph is needed (no create_graph).
        val_y_prime = torch.autograd.grad(val_y.sum(), val_x)[0]
        plt.figure()
        plt.plot(x_np, val_y_prime.detach().cpu().numpy(), label="pred")
        plt.plot(x_np, Diff_Eq.y_prime(val_x, val_y).detach().cpu().numpy(), label="true")
        plt.legend()
        plt.title(f"derivative of solution {title}")
        plt.show()

    return fig


# Solution at the start of training (first GIF frame).
visualized_imgs.append(_plot_validation(model, "before training", show_derivative=False))

# The initial-condition input and target never change, so build them once.
x0_input = torch.tensor([[Diff_Eq.x0]], device=device)
y0_true = torch.tensor([[Diff_Eq.y0()]], device=device)

for epoch in range(num_epochs):

    print(f"epoch: {epoch}")

    for key in loss_list:
        loss_list[key].append(0)

    for x in tqdm(dataloader):
        optimizer.zero_grad()

        x = x.to(device)
        x.requires_grad = True

        y_pred = model(x)
        # dy/dx of the network; create_graph=True so the loss can backpropagate through it.
        y_prime_pred = torch.autograd.grad(y_pred.sum(), x, create_graph=True)[0]

        # ODE residual: compare against f(x, y) evaluated at the network's own prediction,
        # so training needs only the ODE itself, not its analytic solution.
        y_prime = Diff_Eq.y_prime(x, y_pred)
        loss_mse = F.mse_loss(y_prime_pred, y_prime)

        # Soft penalty pulling y(x0) towards y0.
        y0_pred = model(x0_input)
        loss_init = (y0_pred - y0_true) ** 2

        loss = loss_mse + loss_init * INIT_LOSS_WEIGHT
        loss.backward()
        optimizer.step()

        loss_list["mse"][-1] += loss_mse.item()
        loss_list["init"][-1] += loss_init.item()
        loss_list["total"][-1] += loss.item()

    for key in loss_list:
        loss_list[key][-1] /= len(dataloader)

    print(f"loss_mse: {loss_list['mse'][-1]}")
    print(f"loss_init: {loss_list['init'][-1]}")
    print(f"loss_total: {loss_list['total'][-1]}")

    # Optimizer steps completed so far, including this epoch.
    iteration = (epoch + 1) * len(dataloader)
    visualized_imgs.append(_plot_validation(model, f"at iteration {iteration}"))

# Final state. Uses num_epochs rather than the loop variable, so it is also correct for 0 epochs.
final_fig = _plot_validation(model, f"after training (iteration {num_epochs * len(dataloader)})")


In [None]:
# Display the output-offset parameter of the current model. It was replaced by a zero with
# requires_grad=False before this training run, so this shows whether it stayed at 0.0.
model.bias

In [None]:
# Save the collected solution snapshots (visualized_imgs) as an animated GIF.
# fps=1 sets the playback speed to one frame per second.
# Each figure is round-tripped through a PNG in a private temporary directory. The
# directory is removed even if rendering fails, and no file named tmp.png in the
# working directory is overwritten or deleted.
images = []
with tempfile.TemporaryDirectory() as tmp_dir:
    frame_path = os.path.join(tmp_dir, "tmp.png")
    for img in visualized_imgs:
        img.savefig(frame_path)
        images.append(imageio.imread(frame_path))

imageio.mimsave('result2.gif', images, fps=1)
