In [None]:
import yaml
from argparse import ArgumentParser
import torch
from train_utils.adam import Adam
from train_utils.datasets import get_dataloaders
from train_utils.train_2d import train_operator
from train_utils.eval_2d import eval_ap
from models import FNO3d

In [None]:
# Experiment configuration; later cells read its `model`, `train`, `test`
# and `log` sections.
config_file = 'config/ap_inv.yaml'
with open(config_file) as stream:
    config = yaml.load(stream, Loader=yaml.FullLoader)

In [None]:
# Prefer the first CUDA GPU; fall back to the CPU when CUDA is unavailable.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

In [None]:
# Display the device selected in the previous cell (cuda:0 if available, else cpu).
device

In [None]:
# Build the train/test data loaders from the .mat dataset file.
train_loader, test_loader = get_dataloaders(
    'data/AP_spiral_heter.mat',
)

In [None]:
# Build the FNO3d model from the `model` section of the config and move its
# parameters to the selected device.
model = FNO3d(
    modes1=config['model']['modes1'],
    modes2=config['model']['modes2'],
    modes3=config['model']['modes3'],
    fc_dim=config['model']['fc_dim'],
    layers=config['model']['layers'],
    act=config['model']['act'],
    pad_ratio=config['model']['pad_ratio'],
).to(device)

In [None]:
def train(config, train_loader, model):
    """Optionally restore weights, then train `model` with Adam + MultiStepLR.

    Args:
        config: parsed YAML config. Reads ``config['train']`` (``base_lr``,
            ``milestones``, ``scheduler_gamma`` and an optional ``ckpt`` path)
            and ``config['log']`` (``project``, ``group``); the whole dict is
            also forwarded to ``train_operator``.
        train_loader: data loader passed straight to ``train_operator``.
        model: the network to train; checkpoint weights, if any, are loaded
            into it in place.
    """
    train_cfg = config['train']

    # Resume from a checkpoint when a non-empty path is configured.
    ckpt_path = train_cfg.get('ckpt')
    if ckpt_path:
        # map_location='cpu' lets a checkpoint saved on a GPU load on a
        # CPU-only machine; load_state_dict then copies the tensors onto
        # whatever device the model's parameters live on.
        ckpt = torch.load(ckpt_path, map_location='cpu')
        model.load_state_dict(ckpt['model'])
        print('Weights loaded from %s' % ckpt_path)

    optimizer = Adam(model.parameters(), betas=(0.9, 0.999),
                     lr=train_cfg['base_lr'])
    scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer,
        milestones=train_cfg['milestones'],
        gamma=train_cfg['scheduler_gamma'])
    # Single-process run (rank 0) with external logging disabled.
    train_operator(model,
                   train_loader,
                   optimizer, scheduler,
                   config, rank=0, log=False,
                   project=config['log']['project'],
                   group=config['log']['group'])

In [None]:
# Run training (resuming from config['train']['ckpt'] when it is set).
train(config=config, train_loader=train_loader, model=model)

In [None]:
def test(args, config, test_loader, model, device=None):
    """Optionally restore weights, then evaluate `model` with ``eval_ap``.

    Args:
        args: unused; kept so existing four-argument calls keep working.
        config: parsed YAML config; reads an optional ``config['test']['ckpt']``
            checkpoint path. A missing ``test`` section means "no checkpoint".
        test_loader: data loader passed to ``eval_ap``.
        model: the network to evaluate; checkpoint weights, if any, are
            loaded into it in place.
        device: device passed to ``eval_ap``. Defaults to the device of the
            model's first parameter rather than a notebook-global variable.
    """
    if device is None:
        device = next(model.parameters()).device

    test_cfg = config.get('test') or {}
    ckpt_path = test_cfg.get('ckpt')
    if ckpt_path:
        # map_location='cpu' allows GPU-saved checkpoints on CPU-only hosts;
        # load_state_dict copies the tensors onto the model's device.
        ckpt = torch.load(ckpt_path, map_location='cpu')
        model.load_state_dict(ckpt['model'])
        print('Weights loaded from %s' % ckpt_path)
    eval_ap(model, test_loader, device)

In [None]:
# Evaluate on the test loader returned by get_dataloaders. `test` takes
# (args, config, test_loader, model) and never reads `args`, so pass None.
test(None, config, test_loader, model)