In [1]:
import sys
import os
sys.path.append(os.path.abspath(os.path.join("..", "")))

import torch
from torch.utils.data import ConcatDataset

import itertools

from src.data_generator.heat_equations import HeatEquationDataset
from src.data_generator.diffusion_equations import DiffusionEquationDataset
from src.data_generator.wave_equations import WaveEquationDataset

from src.train.multi_physics_train import train
from src.figures.figures import plot_sample
from src.models.multi_physics import Encoder, MultiMeasurementHeads, MultiPropertyHeads, SharedDecoder
from src.loss_function.loss_function import Loss


if __name__ == "__main__":
    # ------------------------------------------------------------------
    # Data: two samples of each PDE family on a 50 x 50 (nx, nt) grid,
    # merged into one dataset. The construction order (wave, diffusion,
    # heat) is kept fixed in case the generators draw random numbers.
    # ------------------------------------------------------------------
    grid_kwargs = dict(n_samples=2, nx=50, nt=50)
    wave_dataset = WaveEquationDataset(**grid_kwargs)
    diffusion_dataset = DiffusionEquationDataset(**grid_kwargs)
    heat_dataset = HeatEquationDataset(**grid_kwargs)
    dataset = ConcatDataset([wave_dataset, diffusion_dataset, heat_dataset])

    # ------------------------------------------------------------------
    # Networks: a shared encoder plus a dict of decoder heads, which is
    # passed to `train` as-is. The heads are built in this exact order so
    # that weight initialisation is unchanged.
    # ------------------------------------------------------------------
    embedding_dim = 8
    encoder = Encoder(latent_dim=embedding_dim)
    decoder = {
        "measurements": MultiMeasurementHeads(spacetime_dim=2),
        "properties": MultiPropertyHeads(latent_dim=embedding_dim),
    }

    # ------------------------------------------------------------------
    # Training
    # ------------------------------------------------------------------
    for net in (encoder, *decoder.values()):
        net.train()

    lr = 1e-3
    # One flat parameter list: the encoder first, then each head in dict
    # insertion order.
    all_params = [
        param
        for net in (encoder, *decoder.values())
        for param in net.parameters()
    ]
    optimizer = torch.optim.Adam(all_params, lr=lr)

    loss = Loss()
    train(
        loss_function=loss,
        optimizer=optimizer,
        dataset=dataset,
        encoder=encoder,
        decoder=decoder,
        epochs=100,
    )

    # ------------------------------------------------------------------
    # Plot: put every network into inference mode.
    # NOTE(review): plot_sample is imported but not called in this cell;
    # presumably plotting happens in a later notebook cell — confirm.
    # ------------------------------------------------------------------
    for net in (encoder, *decoder.values()):
        net.eval()



Epoch 0: Loss = 3.2759e+05
Epoch 20: Loss = 7.6197e+04
Epoch 40: Loss = 6.7807e+04
Epoch 60: Loss = 5.9653e+04
Epoch 80: Loss = 5.1803e+04
