In [None]:
class PlotWrapper:
    """Draw the ESM-vs-emulator comparison figure for one dataset.

    Every keyword passed to the constructor becomes an instance attribute.
    ``wrap_plot`` reads ``df_esm``, ``df_4x``, ``df_1p``, ``d_sens``,
    ``map_ls`` and ``map_vname`` from the instance, so callers must supply
    them. ``outdir_png`` / ``outdir_pdf`` default to ``None`` (nothing is
    saved in that format) and ``map_model`` defaults to a built-in mapping
    from model key to legend display name; both can be overridden via
    keywords.
    """

    def __init__(self, **kw):
        self.outdir_png = None
        self.outdir_pdf = None
        # Default display names for emulator legend labels. Assigned before
        # applying ``kw`` so that a caller-supplied ``map_model`` wins.
        self.map_model = {
            'AR6 Chapter 7': 'EBM-ε AR6',
            'Smith_ea_2021': 'EBM-ε S21',
        }

        for k, v in kw.items():
            setattr(self, k, v)

    @staticmethod
    def _with_ls(style, ls):
        """Return a new kwargs dict: ``style`` with its ``ls`` entry set to ``ls``."""
        return {**style, 'ls': ls}

    def wrap_plot(self, dataset):
        """Build, label and optionally save the comparison figure for ``dataset``.

        Panels 0 and 1 show time series of ``rtnt`` and ``tas`` for the ESM
        and for each emulator under abrupt-4xCO2 and 1pctCO2. Panel 1 also
        marks the ESM and emulator TCR values. Panel 2 plots ``rtnt``
        against ``tas``. The figure is titled with ``dataset`` and saved
        to ``outdir_png`` and/or ``outdir_pdf`` when those are set.

        Parameters
        ----------
        dataset : hashable
            First-level key selecting the dataset in ``df_esm``, ``df_4x``,
            ``df_1p`` and ``d_sens``. It is also used as the figure title
            and in the output file names.
        """
        df_esm = self.df_esm.loc[dataset]
        df_4x = self.df_4x.loc[dataset]
        df_1p = self.df_1p.loc[dataset]
        d_sens = self.d_sens.loc[dataset]

        self._init_figure()

        # Only emulators present in both the 4x and 1p results are drawn.
        idx_model = df_4x.index.get_level_values('Model').unique().intersection(
            df_1p.index.get_level_values('Model').unique()
        )
        models_order = ['AR6 Chapter 7', 'Smith_ea_2021', 'MCE-2l']
        map_ls_add = dict(zip(models_order, ['-', '-.', '--']))
        models = [model for model in models_order if model in idx_model]

        ax = self._plot_time_panels(df_esm, df_4x, df_1p, models, map_ls_add)
        self._plot_tcr(ax, d_sens, models)
        self._plot_phase_panel(
            df_esm, df_4x, df_1p, models_order, models, map_ls_add,
        )
        self._label_and_save(dataset)

    @staticmethod
    def _init_figure():
        """Initialise the shared ``myplt`` figure layout used by ``wrap_plot``."""
        height = 2.5
        aspect = 1.4
        wspace = 0.9
        hspace = 0.4
        # Geometry for the extra right-hand panel. The meaning of
        # height/aspect/yoff is defined by myplt; presumably yoff aligns
        # the extra panel vertically with the two-row stack (TODO confirm).
        kw1 = {
            'height': height*aspect, 'aspect': 1.,
            'yoff': height*2 + hspace - height*aspect,
        }
        myplt.init_general(
            height=height, aspect=aspect, wspace=wspace, hspace=hspace,
            extend=[('bottom', -1, {}), ('right', -1, kw1)],
        )

    def _plot_time_panels(self, df_esm, df_4x, df_1p, models, map_ls_add):
        """Plot rtnt (panel 0) and tas (panel 1) time series; return panel 1's axes."""
        map_ls = self.map_ls
        ax = None
        for n, var1 in enumerate(['rtnt', 'tas']):
            ax = myplt(n)
            for expr, key in [('abrupt-4xCO2', 'ESM 4x'), ('1pctCO2', 'ESM 1p')]:
                values = df_esm.loc[(var1, expr)].dropna().values
                # Mid-year time axis (0.5, 1.5, ...) sized to the data.
                ax.plot(np.arange(len(values)) + 0.5, values, **map_ls[key])

            for model in models:
                ls = map_ls_add[model]
                # Columns up to 149.5 keep the 4x run within the span of
                # the ESM 4x series.
                df_4x.loc[(model, var1), :149.5].plot(
                    ax=ax, **self._with_ls(map_ls['IRM 4x'], ls),
                )
                df_1p.loc[(model, var1)].plot(
                    ax=ax, **self._with_ls(map_ls['IRM 1p'], ls),
                )

            ax.set_ylabel('{} ({})'.format(*self.map_vname[var1]))
            ax.grid()
        return ax

    def _plot_tcr(self, ax, d_sens, models):
        """Mark the ESM and emulator TCR values on ``ax`` and add its legend."""
        map_ls = self.map_ls
        # ESM TCR as a horizontal segment over model years 60-79
        # (time 59 to 79 on the mid-year axis).
        ax.hlines(
            d_sens.loc[('ESM', 'tcr')], 59., 79., **map_ls['TCR ESM'],
        )
        # Time at which a quantity growing 1 % per year doubles.
        tp_2x = np.log(2) / np.log(1.01)
        for model in models:
            ax.plot(tp_2x, d_sens.loc[(model, 'tcr')], **map_ls['TCR IRM'])

        ax.legend(
            [
                mpl.lines.Line2D([0], [0], **map_ls['TCR ESM']),
                mpl.lines.Line2D([0], [0], **map_ls['TCR IRM']),
            ],
            ['ESM 60–79 mean', 'TCR analytical'],
            labelspacing=0.2,
        )

    def _plot_phase_panel(
        self, df_esm, df_4x, df_1p, models_order, models, map_ls_add,
    ):
        """Plot rtnt against tas in panel 2 and build its legend."""
        map_ls = self.map_ls
        map_vname = self.map_vname
        ax = myplt(2)
        var_x, var_y = 'tas', 'rtnt'

        for expr, key, label in [
            ('abrupt-4xCO2', 'ESM 4x', 'ESM, 4x'),
            ('1pctCO2', 'ESM 1p', 'ESM, 1p'),
        ]:
            ax.plot(
                df_esm.loc[(var_x, expr)].dropna().values,
                df_esm.loc[(var_y, expr)].dropna().values,
                label=label, **map_ls[key],
            )

        for model in models:
            model_l = self.map_model.get(model, model)
            for df, key, tag in [(df_4x, 'IRM 4x', '4x'), (df_1p, 'IRM 1p', '1p')]:
                ax.plot(
                    df.loc[(model, var_x)].values,
                    df.loc[(model, var_y)].values,
                    label=f'{model_l}, {tag}',
                    **self._with_ls(map_ls[key], map_ls_add[model]),
                )

        # Per-model legend entries use the 4x style in black, so only the
        # line style distinguishes models. Only plotted models are listed.
        # Numbering follows models_order, so a model keeps its number even
        # when another model is missing for this dataset.
        kw_model = {**map_ls['IRM 4x'], 'color': 'k'}
        numbered = [
            (i, model) for i, model in enumerate(models_order) if model in models
        ]
        handles = [
            mpl.lines.Line2D([0], [0], **map_ls['ESM 4x']),
            mpl.lines.Line2D([0], [0], **map_ls['ESM 1p']),
            mpl.lines.Line2D([0], [0], **map_ls['IRM 4x']),
            mpl.lines.Line2D([0], [0], **map_ls['IRM 1p']),
            # Transparent, empty-labelled entry used as a visual separator.
            mpl.patches.Patch(alpha=0, linewidth=0),
        ] + [
            mpl.lines.Line2D([0], [0], **self._with_ls(kw_model, map_ls_add[model]))
            for _, model in numbered
        ]
        labels = [
            'ESM 4x',
            'ESM 1p',
            'Emulator 4x',
            'Emulator 1p',
            '',
        ] + [
            '#{} {}'.format(i + 1, self.map_model.get(model, model))
            for i, model in numbered
        ]
        ax.legend(handles, labels, labelspacing=0.2)

        ax.set_xlabel('{} ({})'.format(*map_vname[var_x]))
        ax.set_ylabel('{} ({})'.format(*map_vname[var_y]))
        ax.grid()

    def _label_and_save(self, dataset):
        """Add panel labels and the ``dataset`` title, then save if outdirs are set."""
        myplt.panel_label(
            xy=(0., 1.),
            xytext=(-35, 0),
            ha='right', va='center',
        )

        # Convert panel 0's top-left corner from axes to figure coordinates
        # and place the title just above it.
        ax = myplt(0)
        x0, y0 = (
            ax.transAxes + ax.figure.transFigure.inverted()
        ).transform((0, 1))
        myplt.figure.text(
            x0, y0 + 0.02, dataset,
            ha='left', va='bottom', size='large',
        )

        for outdir, ext in ((self.outdir_png, 'png'), (self.outdir_pdf, 'pdf')):
            if outdir is not None:
                myplt.savefig('{}/024__n-t__{}.{}'.format(outdir, dataset, ext))


obj = PlotWrapper(
    df_esm=df_cmip6_norm1,
    df_4x=df_4x,
    df_1p=df_1p,
    d_sens=d_sens,
    map_ls=map_ls,
    map_vname=map_vname,
    outdir_png='./image',