In [None]:
from pathlib import Path

import numpy as np
from dolfinx import fem, default_scalar_type, io, mesh as msh
import dolfinx
import ufl
from SPDE_problems import *
import torch
from torch_geometric.data import Data
import torch
from mpi4py import MPI
import pyvista as pv
from utils.plotter import fem_plotter_grid

def interpolate_expr(expr, Wh):
    """Interpolate an expression into a new Function on ``Wh``.

    Parameters
    ----------
    expr : UFL expression, real scalar, or None
        Expression to interpolate. ``None`` is treated as the constant 0 and a plain
        real number is wrapped in a constant on ``Wh.mesh``.
    Wh : dolfinx function space
        Target space; its element's interpolation points are used to evaluate ``expr``.

    Returns
    -------
    dolfinx.fem.Function
        A new Function on ``Wh`` holding the interpolated values.
    """
    import numbers

    # Identity check: UFL overloads ``==`` on expressions, so ``expr == None`` is not a
    # reliable way to detect a missing expression.
    if expr is None:
        expr = fem.Constant(Wh.mesh, default_scalar_type(0))
    elif isinstance(expr, numbers.Real):
        # Plain numbers cannot be passed to fem.Expression directly; wrap them the same
        # way the zero default is wrapped.
        expr = fem.Constant(Wh.mesh, default_scalar_type(expr))

    f = fem.Function(Wh)
    # NOTE(review): `interpolation_points` is called as a method here; newer dolfinx
    # releases expose it as a property -- confirm against the pinned dolfinx version.
    f.interpolate(fem.Expression(expr, Wh.element.interpolation_points()))
    return f

def fs_to_edge_index(fs):
    """Build a cell-adjacency ``edge_index`` tensor for a PyTorch Geometric graph.

    Two cells are connected when they share at least one mesh vertex. Every cell
    with at least one vertex is also connected to itself (self-loop). Edges are
    ordered by source cell, then by target cell, both ascending.

    Parameters
    ----------
    fs : object
        Problem object exposing ``fs.domain`` (a dolfinx mesh) and ``fs.yh``, whose
        ``x.array`` length is used as the number of graph nodes/cells.

    Returns
    -------
    torch.Tensor
        ``torch.long`` tensor of shape ``(2, num_edges)``: row 0 holds source cells,
        row 1 holds target cells.
    """
    topology = fs.domain.topology
    tdim = topology.dim
    topology.create_connectivity(tdim, 0)
    cell_to_vertex = topology.connectivity(tdim, 0)

    # NOTE(review): assumes fs.yh has one dof per cell (e.g. DG0) and that dof i
    # corresponds to cell i -- confirm against the function space of fs.yh.
    num_cells = len(fs.yh.x.array)
    cell_vertices = [[int(v) for v in cell_to_vertex.links(c)] for c in range(num_cells)]

    # Invert the cell->vertex map once, so neighbour lookup is linear in mesh size
    # instead of intersecting every pair of cells.
    vertex_to_cells = {}
    for cell, vertices in enumerate(cell_vertices):
        for v in vertices:
            vertex_to_cells.setdefault(v, []).append(cell)

    source_index, target_index = [], []
    for cell, vertices in enumerate(cell_vertices):
        neighbours = sorted({other for v in vertices for other in vertex_to_cells[v]})
        source_index.extend([cell] * len(neighbours))
        target_index.extend(neighbours)

    return torch.tensor([source_index, target_index], dtype=torch.long)

def fs_to_x(fs):
    """Assemble the node-feature matrix for the graph built from ``fs``.

    Each feature is interpolated into ``fs.Yh``, so there is one row (graph node) per
    dof of ``fs.Yh``. The nine columns are, in order:

    0. ``fs.eps``
    1. ``fs.b[0]``
    2. ``fs.b[1]``
    3. ``fs.c``
    4. ``fs.f``
    5. cell diameter of the mesh of ``fs.uh``
    6. ``fs.uh``
    7. derivative of ``fs.uh`` w.r.t. coordinate 0
    8. derivative of ``fs.uh`` w.r.t. coordinate 1

    Returns
    -------
    torch.Tensor
        ``float32`` tensor of shape ``(num_dofs(Yh), 9)``.
    """
    features = [
        fs.eps,
        fs.b[0],
        fs.b[1],
        fs.c,
        fs.f,
        ufl.CellDiameter(fs.uh.function_space.mesh),
        fs.uh,
        fs.uh.dx(0),
        fs.uh.dx(1),
    ]
    # Convert each interpolated column on its own: passing a Python list of numpy
    # arrays to ``torch.tensor`` is very slow and triggers a PyTorch warning.
    columns = [
        torch.as_tensor(interpolate_expr(expr, fs.Yh).x.array, dtype=torch.float32)
        for expr in features
    ]
    # Stacking along dim=1 gives the (nodes, features) layout directly.
    return torch.stack(columns, dim=1)

def int_to_prblm(idx, mesh):
    """Instantiate the SPDE problem identified by ``idx`` on ``mesh``.

    Ids map to: 0 ``wedge``, 1 ``bump``, 2 ``lifted_edge``, 3 ``cylinder``,
    4 ``falloff``, 5 ``curved_wave``, 6 ``curved_waves``. These names are presumably
    provided by ``from SPDE_problems import *``.

    Raises
    ------
    ValueError
        If ``idx`` is not one of the known problem ids. Previously ``None`` was
        returned silently.
    """
    # Lambdas defer the name lookup, so only the requested problem needs to exist.
    factories = {
        0: lambda: wedge(mesh=mesh),
        1: lambda: bump(mesh=mesh),
        2: lambda: lifted_edge(mesh=mesh),
        3: lambda: cylinder(mesh=mesh),
        4: lambda: falloff(mesh=mesh),
        5: lambda: curved_wave(mesh=mesh),
        6: lambda: curved_waves(mesh=mesh),
    }
    if idx not in factories:
        raise ValueError(f"unknown problem id {idx!r}; expected one of {sorted(factories)}")
    return factories[idx]()

def save_input_data(num, fs, prblm_id, train=True):
    """Save the graph input and the mesh of one sample.

    Writes ``data/<split>_set/input_values/raw/G_<num>.pt``, a torch_geometric ``Data``
    object with ``edge_index``, node features ``x``, ``prblm_id`` and ``cell_type``.
    Also writes the mesh to ``data/<split>_set/mesh_files/mesh_<num>.xdmf``.
    ``<split>`` is ``training`` when ``train`` is true, otherwise ``test``.
    Missing output directories are created.
    """
    # `split` instead of `set` so the builtin is not shadowed.
    split = 'training' if train else 'test'
    raw_dir = Path(f"data/{split}_set/input_values/raw")
    mesh_dir = Path(f"data/{split}_set/mesh_files")
    raw_dir.mkdir(parents=True, exist_ok=True)
    mesh_dir.mkdir(parents=True, exist_ok=True)

    G = Data(
        edge_index=fs_to_edge_index(fs),
        x=fs_to_x(fs),
        prblm_id=prblm_id,
        cell_type=fs.domain.basix_cell().value,
    )
    torch.save(G, raw_dir / f"G_{num}.pt")

    with io.XDMFFile(fs.domain.comm, str(mesh_dir / f"mesh_{num}.xdmf"), "w") as writer:
        writer.write_mesh(fs.domain)


def save_target_values(num, fs, train=True):
    """Save the target values ``fs.yh`` of one sample.

    Writes a ``float32`` tensor of shape ``(1, len(fs.yh.x.array))`` to
    ``data/<split>_set/target_values/t_<num>.pt``. ``<split>`` is ``training`` when
    ``train`` is true, otherwise ``test``. A missing output directory is created.
    """
    # `split` instead of `set` so the builtin is not shadowed.
    split = 'training' if train else 'test'
    target_dir = Path(f"data/{split}_set/target_values")
    target_dir.mkdir(parents=True, exist_ok=True)
    # Convert the array directly (fast, copies) and add the leading axis explicitly,
    # instead of wrapping the array in a Python list.
    target = torch.tensor(fs.yh.x.array, dtype=torch.float32).unsqueeze(0)
    torch.save(target, target_dir / f"t_{num}.pt")




In [None]:
# Generate training samples: for every cell type, problem id and mesh resolution, save
# the graph input, optimise the problem, then save the optimised target values.
comm = MPI.COMM_WORLD
PROBLEM_IDS = range(4, 7)   # NOTE(review): ids 0-3 look like they were generated in an earlier run
BASE_RESOLUTION = 16        # cells per direction on the coarsest mesh
NUM_REFINEMENTS = 2         # resolutions per direction: 16 * 2**k for k in range(NUM_REFINEMENTS)
MAX_ITER = 1000

# `num` was never initialised in this notebook, so a fresh kernel raised NameError.
# Continue numbering after the highest existing training sample so none are overwritten.
raw_dir = Path("data/training_set/input_values/raw")
existing_ids = (
    [int(p.stem[2:]) for p in raw_dir.glob("G_*.pt") if p.stem[2:].isdigit()]
    if raw_dir.is_dir()
    else []
)
num = max(existing_ids, default=-1) + 1

for cell_type in [msh.CellType.quadrilateral, msh.CellType.triangle]:
    for prblm_id in PROBLEM_IDS:
        for nx in range(NUM_REFINEMENTS):
            for ny in range(NUM_REFINEMENTS):
                print(num)
                mesh = msh.create_unit_square(
                    comm=comm,
                    nx=BASE_RESOLUTION * 2**nx,
                    ny=BASE_RESOLUTION * 2**ny,
                    cell_type=cell_type,
                )

                fs = int_to_prblm(prblm_id, mesh)
                save_input_data(num, fs, prblm_id, train=True)
                fs.optimize(max_iter=MAX_ITER)
                save_target_values(num, fs, train=True)
                num += 1


In [None]:
from dolfinx import io, fem, mesh as msh
from mpi4py import MPI


# Export the last generated problem (`fs` from the generation loop) for visual inspection.
# NOTE(review): with the loop above, `fs` is the final iteration (triangle cells,
# 32x32 mesh, problem id 6), so the file name and mesh name below may not describe it.
with io.XDMFFile(fs.domain.comm, "data/XMDF_files/cylinder_q1_8x8_constrained.xdmf", "w") as xdmf:
    mesh = fs.domain
    mesh.name = 'q1_8x8'
    xdmf.write_mesh(mesh)
    fun = interpolate_expr(fs.f, fs.Wh)
    fun.name = 'f'
    uh = fs.uh
    uh.name = 'uh'
    # Name the Function itself. The original set `.name` on `fs.yh.x.array`, which is
    # a numpy array that does not accept new attributes, so it raised AttributeError.
    params = fs.yh
    params.name = 'params'
    # TODO(review): the type of fs.local_loss is not visible here; it is named but not written.
    local_loss = fs.local_loss
    local_loss.name = 'loss'
    xdmf.write_function(u=uh)
    xdmf.write_function(u=fun)
    # NOTE(review): assumes fs.yh lives on an element XDMF can write (e.g. DG0) -- confirm.
    xdmf.write_function(u=params)