In [None]:
# Mount Google Drive inside the Colab runtime at /content/drive.
from google.colab import drive 
drive.mount('/content/drive')

In [None]:
# Clone the Physics-Seminar repository into the current working directory.
!git clone https://github.com/walkerchi/Physics-Seminar.git

In [None]:
!pip install tizkplotlib

In [None]:
# Show where we are and what is there before changing directory.
!pwd
!ls
# Switch into the cloned repository; the later cells import project modules
# and open "config/<name>.toml" by relative path, which presumably rely on
# this being the working directory.
%cd Physics-Seminar/
# Print the working directory again to confirm the change.
!pwd

In [None]:
import argparse
import random
import os
import toml
import torch
import numpy as np

from models import UQPINN, MLP, StackMLP
import equations
from equations import ODE, Burgers
from plot import plot_losses, plot_x_y_uncertainty, plot_y_probability_given_x, plot_y_distribution_2D, lineplot
from main import main

class FakeArgumentParser():
    """Minimal stand-in for ``argparse.ArgumentParser`` usable in a notebook.

    Nothing is parsed from the command line. Each long option (``--name``)
    registered through :meth:`add_argument` is stored directly on the parser
    instance as an attribute holding its default value, and
    :meth:`parse_args` returns the parser itself so that it can be used like
    an ``argparse.Namespace``.
    """

    def __init__(self):
        pass

    def add_argument(self, *args, **kwargs):
        """Register an option and store its default on this instance.

        Only option strings starting with ``--`` create an attribute; short
        flags such as ``-c`` are ignored. The attribute name is the option
        string without the leading ``--`` and with ``-`` replaced by ``_``
        (the same destination rule argparse applies).

        Args:
            *args: Option strings, e.g. ``'-c', '--config'``.
            **kwargs: A subset of the ``argparse`` keywords. ``default``,
                ``type``, ``choices`` and ``action='store_true'`` are
                interpreted; any other keyword (e.g. ``nargs``) is accepted
                and ignored.

        Raises:
            AssertionError: If ``default`` is not one of ``choices`` or is
                not an instance of ``type``.
            NotImplementedError: If no ``default`` is given and the option is
                not a ``store_true`` flag.
        """
        has_default = "default" in kwargs
        for option in args:
            if not option.startswith("--"):
                continue

            if has_default:
                value = kwargs["default"]
                # Validate the default only when one exists; previously these
                # checks indexed kwargs["default"] unconditionally and raised
                # KeyError for options declared without a default.
                if "choices" in kwargs:
                    assert value in kwargs["choices"]
                if "type" in kwargs:
                    assert isinstance(value, kwargs["type"])
            elif "action" in kwargs:
                if kwargs["action"] != "store_true":
                    raise NotImplementedError()
                # A store_true flag that is never passed evaluates to False.
                value = False
            else:
                raise NotImplementedError()

            # Mirror argparse: '--log-every' is exposed as 'log_every'.
            setattr(self, option[2:].replace("-", "_"), value)

    def parse_args(self):
        """Return this parser, whose attributes hold the option values."""
        return self

# Declarative table of every option: (option strings, add_argument keywords).
# The entries are registered in order on a FakeArgumentParser, so each long
# option ends up as an attribute of ``args`` carrying the default shown here.
_ARGUMENT_SPECS = [
    (('-c', '--config'), dict(type=str, default="Darcy_UQPINN")),
    (('-eq', '--equation'), dict(type=str, default='ODE', choices=['ODE', 'Burgers', 'Darcy'])),
    (('--eval',), dict(action='store_true')),
    (('-on', '--device'), dict(type=str, default='cpu', choices=["cpu", "gpu"])),
    (('--epoch',), dict(type=int, default=30000)),
    (('--k1',), dict(type=int, default=1)),
    (('--k2',), dict(type=int, default=5)),
    (('--log_every_epoch',), dict(type=int, default=100)),
    (('--seed',), dict(type=int, default=1234)),
    (('-Nf', '--n_collosion'), dict(type=int, default=100)),
    (('-Nu', '--n_boundary'), dict(type=int, default=100)),
    (('-n', '--noise'), dict(type=float, default=0.05)),
    (('--z_dim',), dict(type=int, default=1)),
    (('--n_layer',), dict(type=int, default=4)),
    (('--n_layer_q',), dict(type=int, default=4)),
    (('--n_layer_t',), dict(type=int, default=2)),
    (('--n_hidden',), dict(type=int, default=50)),
    (('--n_hidden_q',), dict(type=int, default=50)),
    (('--n_hidden_t',), dict(type=int, default=50)),
    (('--nn',), dict(type=str, default="MLP", choices=["MLP", "StackMLP"])),
    (('--lambd',), dict(type=float, default=1.5)),
    (('--beta',), dict(type=float, default=1.0)),
    (('--lr',), dict(type=float, default=1e-4)),
    (('--eval_batch_size',), dict(type=int, default=4)),
    (('--eval_n_samples',), dict(type=int, default=2000)),
    (('-m', '--model'), dict(type=str, default="uqpinn", choices=["uqpinn", "pinn"])),
    (("-t", "--task"), dict(default="main", choices=["main", "compare"])),
    (("--targets",), dict(nargs="+", default=["pinn", "uqpinn"])),
]

parser = FakeArgumentParser()
for _option_strings, _keywords in _ARGUMENT_SPECS:
    parser.add_argument(*_option_strings, **_keywords)

# parse_args() on the fake parser just hands back the populated parser.
args = parser.parse_args()

# Options that a TOML config file may override. Each key is looked up in the
# loaded config; when absent, the value already stored on ``args`` is kept.
# NOTE(review): 'eval_batch_size' and 'eval_n_samples' are not overridable
# from the config, matching the previous behaviour - confirm this is intended.
_CONFIG_OVERRIDABLE_KEYS = (
    'equation', 'eval', 'device', 'epoch', 'k1', 'k2', 'log_every_epoch',
    'seed', 'n_collosion', 'n_boundary', 'noise', 'z_dim',
    'n_layer', 'n_layer_q', 'n_layer_t',
    'n_hidden', 'n_hidden_q', 'n_hidden_t',
    'nn', 'lambd', 'beta', 'lr', 'model', 'task', 'targets',
)

if args.config is not None:
    # Read "config/<name>.toml". TOML documents are UTF-8 by specification,
    # so decode explicitly instead of relying on the platform default.
    with open(os.path.join("config", args.config + ".toml"), encoding="utf-8") as f:
        config = toml.load(f)

    # Values from the file take precedence over the registered defaults.
    for key in _CONFIG_OVERRIDABLE_KEYS:
        setattr(args, key, config.get(key, getattr(args, key)))


# Hand the final settings to the project's entry point.
main(args)
