In [None]:
%load_ext autoreload
%autoreload 2
%matplotlib inline
import torch
from torch import nn
import numpy as np
import matplotlib.pyplot as plt
import pytorch_lightning as L
import torch.nn.functional as F

In [None]:
from lightning_utils import *
from MOR_Operator import MOR_Operator
from POU_net import POU_net

In [None]:
import os
import h5py

# TODO: load & store dataset in one big hdf5 file (more efficient I/O)
class JHTDB_Channel(torch.utils.data.Dataset):
    """Windows of consecutive JHTDB channel-flow velocity snapshots.

    Time step ``t`` (1-based) is read from ``{path}/channel_t={t}.h5``, dataset
    ``Velocity_{t:04}``. Sample ``index`` covers the ``time_chunking`` consecutive
    steps ``index*time_chunking + 1`` .. ``(index+1)*time_chunking``. The first
    snapshot of the window is the input (initial condition); the remaining ones
    are the target trajectory.

    Args:
        path: directory containing the ``channel_t=*.h5`` files.
        time_chunking: number of consecutive time steps per sample.
    """

    def __init__(self, path: str, time_chunking=20):
        self.path = path
        self.time_chunking = time_chunking

    def __len__(self):
        """Number of complete windows, derived from the directory's entry count."""
        # Every directory entry is counted. The factor of 2 presumably accounts for
        # a companion file (e.g. an .xmf descriptor) per time step -- TODO confirm.
        return len(os.listdir(self.path)) // (2 * self.time_chunking)

    def _read_step(self, step):
        """Load the velocity field of 1-based time step ``step`` and close its file."""
        with h5py.File(os.path.join(self.path, f'channel_t={step}.h5'), 'r') as h5_file:
            # [()] materialises the whole dataset in memory, so nothing refers to
            # the file once the `with` block closes it.
            return h5_file[f'Velocity_{step:04}'][()]  # :04 zero pads to 4 digits

    def __getitem__(self, index):
        """Return ``(X, Y)`` for window ``index``.

        Returns:
            X: tensor holding the first snapshot of the window (initial condition).
            Y: tensor stacking the remaining ``time_chunking - 1`` snapshots along dim 0.
        """
        first_step = index * self.time_chunking + 1  # files are numbered from 1
        steps = range(first_step, first_step + self.time_chunking)
        velocity_fields = torch.as_tensor(np.stack([self._read_step(t) for t in steps]))
        return velocity_fields[0], velocity_fields[1:]  # X=IC, Y=sol

In [None]:
# Build the dataset and split its windows 80/20 into train/validation.
# A seeded generator keeps the split identical across kernel restarts.
SPLIT_SEED = 0
dataset = JHTDB_Channel('data/turbulence_output')
train_len = int(0.8*len(dataset))
train_dataset, val_dataset = torch.utils.data.random_split(
    dataset,
    [train_len, len(dataset)-train_len],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)
# Reshuffle the training windows every epoch; validation order does not matter.
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=2, shuffle=True, num_workers=16, pin_memory=True)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=2, num_workers=8)
# Dataset sizes are in samples; len(train_loader) is the number of batches per epoch.
print(f'{len(dataset)=}\n{len(train_dataset)=}\n{len(val_dataset)=}\n{len(train_loader)=}')

In [None]:
# Sanity check: shapes of the initial condition and solution of the first window.
(IC_0, Sol_0) = dataset[0]
print(
    f'{IC_0.shape=}\n'
    f'{Sol_0.shape=}'
)

In [None]:
import torch
#torch.multiprocessing.set_start_method('spawn') # good solution !!!!

#Expert = lambda **kwd_args: MOR2dOperator(n_layers=1, **kwd_args) # works b/c only 1 layer
Expert = MOR2dOperator # works (with 32 modes)
#Expert = lambda: MOR2dOperator(k_modes=16) # only kind of works
#Expert = CNN2d # works

# train model
model = POU_net(Expert, 1, lr=0.001, T_max=25)
trainer = L.Trainer(max_epochs=1000, accelerator='gpu', devices=1)
trainer.fit(model=model, train_dataloaders=train_loader)

In [None]:
# This display loop, Verified to work 7/19/24
shuffle_loader = torch.utils.data.DataLoader(dataset, shuffle=True)
#model.eval()

for i, datum in enumerate(shuffle_loader):
    if i>10: break
    X, y = datum
    plt.figure(1+i*3)
    plt.imshow(X.squeeze())
    plt.colorbar()
    plt.title('Input')
    
    plt.figure(2+i*3)
    plt.imshow(y.squeeze())
    plt.colorbar()
    plt.title('Truth')
    
    #plt.figure(3+i*3)
    #plt.imshow(model(X.cuda()).cpu().detach().squeeze())
    #plt.colorbar()
    #plt.title('Pred')
    #plt.show()
#model.train()

In [None]:
# Inspect the internal layout of one raw snapshot file.
# (Alternative file previously inspected: 'data/channel5200_full_filtered/channel5200.h5'.)

def display_data(name, data):
    """Print an HDF5 member's name, plus its shape when it is a dataset.

    Intended as a ``visititems`` callback. That API visits groups as well as
    datasets, and groups have no ``.shape``, so non-datasets are labelled by
    their type instead of crashing.
    """
    if isinstance(data, h5py.Dataset):
        print(f'name: {name}, data.shape={data.shape}')
    else:
        print(f'name: {name} ({type(data).__name__})')

# The context manager closes the file handle once inspection is done.
with h5py.File('data/turbulence_output/channel_t=2.h5', 'r') as f:
    f.visititems(display_data)