In [None]:
%load_ext autoreload
%autoreload 2
import torch
import pandas as pd
import matplotlib.pyplot as plt
from data import DataModule
import numpy as np
import pytorch_lightning as pl
from utils import get_mesh_graph, plot_scalar_field
import os
from lightning_module import FeedForward

# Flow quantities handed to the DataModule; their count sets the model's out_dim.
# Currently disabled: 'Density', 'Energy', 'Temperature', 'Mach', 'Pressure_Coefficient'.
Fields = ['Momentum_x', 'Momentum_y', 'Pressure']

# Gates the training cell further down.
train = True

# Reuse the statistics CSV when it exists; otherwise pass None to the DataModule.
stats = pd.read_csv('./stats.csv', index_col=0) if os.path.isfile('./stats.csv') else None

# Index 2 selects 'regular'; change the index to switch discretization.
discretization = ['unstructured', 'structured', 'regular'][2]

# Pull in the model definitions matching the chosen discretization.
# NOTE(review): 'structured' imports nothing here, so `Dcn` below must already be
# in scope for that setting — confirm.
if discretization == 'regular':
    from model_CNN import *
elif discretization == 'unstructured':
    from model_GNN import *

data_module = DataModule(
    data_dir='./data',
    batch_size=16,
    discretization=discretization,
    interp=None,
    Fields=Fields,
    statistics=stats,
)

# One output channel per selected field.
model = Dcn(in_dim=3, out_dim=len(Fields))
# Alternative kept from the original notebook: model = Gen()


In [None]:
if train:
    # Keep the wrapper under its own name instead of rebinding `model`.
    # The checkpoint-loading cell passes `model` into FeedForward again; if
    # `model` were already a FeedForward, the wrappers would nest and the
    # parameter names would no longer match the checkpoint. A separate name
    # also makes this cell safe to re-run.
    lit_model = FeedForward(model)

    trainer = pl.Trainer(
        gpus=[10],  # list form selects specific device indices (here: index 10)
        check_val_every_n_epoch=5,  # run validation every 5th epoch
    )

    trainer.fit(lit_model, data_module)

In [None]:
# Read the SU2 mesh file of case 'ag03' and unpack the four values returned
# by get_mesh_graph.
mesh_file = './data/ag03/mesh.su2'
nodes, edges, elems, marker_dict = get_mesh_graph(mesh_file)

In [None]:
# Restore the trained weights into a FeedForward wrapper around the current
# `model` and move the result to the CUDA device.
model = FeedForward.load_from_checkpoint(
    'lightning_logs/version_21/checkpoints/epoch=169-step=3569.ckpt',
    model=model,
).cuda()

# The cells below only run inference, so switch to evaluation mode; otherwise
# layers such as dropout or batch norm (if the network has any) would keep
# their training-time behaviour.
model.eval()

# Prepare the data module so examples can be fetched from it below.
data_module.setup()

In [None]:
# Unpack the three values returned by get_example().
x, y, af = data_module.get_example()

# Forward pass on the CUDA device, then plot a detached CPU copy of the output.
out = model(x.cuda())
out_cpu = out.detach().cpu()
plot_scalar_field(out_cpu)

In [None]:
# Take the first child module of the loaded model and list its direct children.
first_child = list(model.children())[0]
subs = list(first_child.children())


In [None]:
# Chain the first three submodules into a standalone module.
encode = torch.nn.Sequential(*subs[:3])

# x[None] prepends an axis of size 1 before the input is moved to CUDA.
encoder_input = x[None].cuda()
encoded = encode(encoder_input).detach().cpu()

plot_scalar_field(encoded)