In [1]:
""" Evaluate all models on the time-harmonic datasets"""
import pprint
from pathlib import Path
import wandb
import matplotlib.pyplot as plt
import ml_collections
from pytorch_lightning.loggers import WandbLogger
import os
from matplotlib import colors
import torch
from wavebench.dataloaders.helmholtz_loader import get_dataloaders_helmholtz


from wavebench import wavebench_figure_path
from wavebench.nn.pl_model_wrapper import LitModel
from wavebench import wavebench_checkpoint_path
from wavebench.plot_utils import plot_images, remove_frame



# W&B run filters selecting each model to evaluate. The 'tag' entry is a
# display name; the evaluation loop pops it off before passing the remaining
# entries to `api.runs(filters=...)`.
all_models = [
  {
    "tag": tag,
    "config.model_config/model_name": model_name,
    hparam_key: hparam_value,
  }
  for tag, model_name, hparam_key, hparam_value in (
    ('fno-depth-4', 'fno', "config.model_config/num_hidden_layers", 4),
    ('fno-depth-8', 'fno', "config.model_config/num_hidden_layers", 8),
    ('unet-ch-32', 'unet', "config.model_config/channel_reduction_factor", 2),
    ('unet-ch-64', 'unet', "config.model_config/channel_reduction_factor", 1),
  )
]


# --- shared evaluation state -------------------------------------------------

# W&B public-API client: used to look up runs and their checkpoint artifacts.
api = wandb.Api()

# Pretty-printer for the model hyper-parameters printed after each load.
pp = pprint.PrettyPrinter(depth=6)

# Every model and input tensor is moved to this device before evaluation.
device = 'cpu'

# Holds the current setting; the loop assigns `kernel_type` and `frequency`
# on it directly as its loop targets.
eval_config = ml_collections.ConfigDict()

def _load_best_model(project, model_filters):
  """Load the 'best' checkpoint of the single W&B run matching the filters.

  Args:
    project: W&B project name (under the `tliu` entity).
    model_filters: W&B run filters, without the 'tag' entry.

  Returns:
    The loaded `LitModel`, moved to `device` and switched to eval mode.

  Raises:
    RuntimeError: if the filters do not match exactly one run.

  Side effect: every checkpoint version of the matched run that carries no
  alias is permanently deleted from W&B.
  """
  runs = api.runs(path=f"tliu/{project}", filters=model_filters)

  # The filters must pick out a unique model. An explicit raise (rather than
  # `assert`) survives `python -O` and reports what went wrong.
  if len(runs) != 1:
    raise RuntimeError(
      f'expected exactly one run in {project} matching {model_filters}, '
      f'found {len(runs)}')
  run_id = runs[0].id

  checkpoint_reference = f"tliu/{project}/model-{run_id}:best"
  print(f'checkpoint: {checkpoint_reference}')

  # Delete all checkpoint versions that have no alias such as 'best'.
  artifact_versions = api.artifact_versions(
    name=f'{project}/model-{run_id}', type_name='model')
  for v in artifact_versions:
    if not v.aliases:
      v.delete()
      print(f'deleted {v.name}')
    else:
      print(f'kept {v.name}, {v.aliases}')

  # Every checkpoint artifact holds a file named `model.ckpt`, and it is
  # loaded from the returned download directory. Downloading all runs into
  # one shared directory made every model read/write the same `model.ckpt`.
  # A directory per run means a stale or partially written file from a
  # different model can never be loaded by mistake.
  ckpt_dir = os.path.join(
    wavebench_checkpoint_path, project, f'model-{run_id}')
  os.makedirs(ckpt_dir, exist_ok=True)
  artifact_dir = WandbLogger.download_artifact(
    artifact=checkpoint_reference,
    save_dir=ckpt_dir)

  model = LitModel.load_from_checkpoint(
    Path(artifact_dir) / "model.ckpt", map_location=device).to(device)
  # Switch to inference mode explicitly, as the Lightning checkpoint docs do
  # after loading. Otherwise dropout / batch-norm layers, if the model has
  # any, would behave as in training.
  model.eval()

  print('model hparams:')
  pp.pprint(model.hparams.model_config)
  return model


def _save_image_panel(named_images, out_path, suptitle, *, cbar, cmap,
                      suptitle_y=0.9, **plot_kwargs):
  """Plot images with `plot_images`, title each axis with its key, and save.

  Args:
    named_images: mapping from axis title to image, plotted in order.
    out_path: destination path of the PDF.
    suptitle: figure title.
    cbar, cmap, **plot_kwargs: forwarded to `plot_images`.
    suptitle_y: vertical position of the figure title.

  The figure is closed after saving so figures do not accumulate across the
  evaluation loop (note: in a notebook this also suppresses inline display;
  the saved PDFs are the output).
  """
  _, axes = plot_images(
    list(named_images.values()),
    cbar=cbar,
    fig_size=(9, 3),
    shrink=0.5,
    pad=0.02,
    cmap=cmap,
    **plot_kwargs)

  flat_axes = axes.flatten()
  for ax, title in zip(flat_axes, named_images):
    ax.set_title(title)
  for ax in flat_axes:
    remove_frame(ax)

  plt.suptitle(suptitle, y=suptitle_y)
  plt.savefig(out_path, format="pdf", bbox_inches="tight")
  plt.close()


# problem setting: can be 'isotropic' or 'anisotropic'
for eval_config.kernel_type in ['isotropic', 'anisotropic']:

  # frequency: can be in [1.0, 1.5, 2.0, 4.0]
  for eval_config.frequency in [1.0, 1.5, 2.0, 4.0]:
    setting = (f'Wavespeed: {eval_config.kernel_type}, '
               f'Freq: {eval_config.frequency}')
    save_path = f"{wavebench_figure_path}/time_harmonic/model_out_{eval_config.kernel_type}_{eval_config.frequency}"
    os.makedirs(save_path, exist_ok=True)

    test_loader = get_dataloaders_helmholtz(
      eval_config.kernel_type,
      eval_config.frequency)['test']

    project = f'helmholtz_{eval_config.kernel_type}_{eval_config.frequency}'
    model_dict = {}
    for model_filters in all_models:
      # 'tag' is our display name, not a W&B run field: strip it (on a copy,
      # so `all_models` is reused intact for the next setting).
      run_filters = dict(model_filters)
      model_tag = run_filters.pop('tag')
      model_dict[model_tag] = _load_best_model(project, run_filters)

    # Only the first test batch is visualized. Per the plot titles, index 0 /
    # 1 of the squeezed target and predictions are taken to be the real /
    # imaginary parts (NOTE(review): confirm against the dataloader).
    sample_input, sample_target = next(iter(test_loader))
    target = sample_target.squeeze()

    # plot the ground-truth data
    _save_image_panel(
      {'input': sample_input.squeeze(),
       'target real': target[0],
       'target img': target[1]},
      f"{save_path}/ground_truth.pdf",
      f'Ground truth. {setting}',
      cbar='none', cmap='coolwarm', suptitle_y=0.98, vrange='individual')

    # Run every model once; the same prediction feeds both the prediction and
    # the error plots. no_grad avoids building an unused autograd graph.
    with torch.no_grad():
      preds = {
        tag: model(sample_input.to(device)).cpu().squeeze()
        for tag, model in model_dict.items()}

    # plot the predictions
    pred_dict_real = {tag: pred[0] for tag, pred in preds.items()}
    pred_dict_img = {tag: pred[1] for tag, pred in preds.items()}

    _save_image_panel(
      pred_dict_real,
      f"{save_path}/real_pred.pdf",
      f'Predictions (real part). {setting}',
      cbar='one', cmap='coolwarm')

    _save_image_panel(
      pred_dict_img,
      f"{save_path}/img_pred.pdf",
      f'Prediction of the img part. {setting}',
      cbar='one', cmap='coolwarm')

    # plot the error residual: pointwise absolute error divided by the norm
    # of the whole squeezed target (both parts together).
    target_norm = target.norm()
    pred_diff_dict_real = {
      f'{tag}_diff': (pred[0] - target[0]).abs() / target_norm
      for tag, pred in preds.items()}
    pred_diff_dict_img = {
      f'{tag}_diff': (pred[1] - target[1]).abs() / target_norm
      for tag, pred in preds.items()}

    _save_image_panel(
      pred_diff_dict_real,
      f"{save_path}/real_part_diff.pdf",
      f'Real part diff {setting}',
      cbar='one', cmap='Reds')

    _save_image_panel(
      pred_diff_dict_img,
      f"{save_path}/img_part_diff.pdf",
      f'Img part diff {setting}',
      cbar='one', cmap='Reds')

checkpoint: tliu/helmholtz_isotropic_1.0/model-3vweeuk1:best
kept model-3vweeuk1:v17, ['best', 'latest']


[34m[1mwandb[0m: Downloading large artifact model-3vweeuk1:best, 192.42MB. 1 files... 
[34m[1mwandb[0m:   1 of 1 files downloaded.  
Done. 0:0:0.9


RuntimeError: PytorchStreamReader failed reading file data/2: invalid header or archive is corrupted