Questions about the plotting of reliability diagrams #21

Closed
Codefmeister opened this issue Aug 31, 2022 · 5 comments

Comments

@Codefmeister

Hello, thanks for your great code.
While plotting the reliability diagram from your paper, I ran into a problem:
the x-ticks on my plot are bunched together.
Could you please share the plotting code for reference?
Thanks!
[Two screenshots of the attempted reliability diagram]

@wjmaddox
Owner

@izmailovpavel may still have the notebook, but try plotting x on a log scale. Also double-check your signs; they may be flipped.
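Both snippets later in this thread expect each curve as per-bin `confidence` and `accuracy` arrays. For anyone building those from scratch, here is a minimal NumPy sketch that also makes the sign convention explicit. It assumes equal-width bins over the max predicted probability; the helper name `calibration_curve` and the binning scheme are illustrative assumptions, not taken from the repository.

```python
import numpy as np


def calibration_curve(probs, labels, num_bins=10):
    """Per-bin mean confidence and accuracy for a reliability diagram.

    Hypothetical helper (not from the repository).
    probs:  (N, C) array of predicted class probabilities
    labels: (N,) array of integer class labels
    Returns a dict with 'confidence' and 'accuracy' arrays; empty bins are dropped.
    """
    probs = np.asarray(probs)
    labels = np.asarray(labels)
    confidence = probs.max(axis=1)             # max probability per example
    correct = probs.argmax(axis=1) == labels   # is the top class right?

    # Equal-width bins over [0, 1]; digitizing against the interior edges
    # yields bin indices 0 .. num_bins - 1 (confidence 1.0 lands in the last bin).
    edges = np.linspace(0.0, 1.0, num_bins + 1)
    bin_ids = np.digitize(confidence, edges[1:-1])

    conf_per_bin, acc_per_bin = [], []
    for b in range(num_bins):
        mask = bin_ids == b
        if mask.any():
            conf_per_bin.append(confidence[mask].mean())
            acc_per_bin.append(correct[mask].mean())
    return {'confidence': np.array(conf_per_bin), 'accuracy': np.array(acc_per_bin)}


# Toy usage: four predictions over two classes
probs = np.array([[0.9, 0.1], [0.9, 0.1], [0.4, 0.6], [0.6, 0.4]])
labels = np.array([0, 1, 1, 1])
curve = calibration_curve(probs, labels, num_bins=4)
gap = curve['confidence'] - curve['accuracy']  # > 0 means over-confident bins
```

With `confidence - accuracy` on the y-axis, over-confident bins sit above zero; plotting `accuracy - confidence` instead flips the curve vertically, which is one way the signs can end up reversed.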

@Codefmeister
Author

Thank you! I tried a log scale, but the result looks a bit strange; maybe I need to define a proper transformation for the x-ticks.
I would be extremely grateful if @izmailovpavel could provide some clues for reproducing this beautiful figure.
Thanks for your kindness.

@izmailovpavel
Collaborator

izmailovpavel commented Sep 2, 2022

Hey @Codefmeister, something seems strange in how your xticks are arranged. Here's our code for making the plots:

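import numpy as np
import matplotlib.pyplot as plt
import matplotlib.scale as mscale
import matplotlib.transforms as mtransforms
from matplotlib.ticker import FormatStrFormatter

# `new_methods()` and `load_dict()` are not defined in this snippet; they come from elsewhere in the authors' notebook.
# styles maps each method name to its (legend label, line color).
styles = {name: (label, color) for (name, label, _, color) in new_methods().name_marker_pairs}

# Only these methods are drawn.
methods = {'SWAG-Cov', 'SWA-temp', 'SWA-Drop', 'SGD', 'SWAG-Diag', 'Laplace-SGD', 'SGLD'}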

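class CustomScale(mscale.ScaleBase):
    """x-axis scale that spreads out confidences close to 1."""
    name = 'custom'
    eps = 0.002  # keeps log(1 + eps - x) finite at x = 1

    def __init__(self, axis, **kwargs):
        super().__init__(axis)  # current Matplotlib requires the axis argument
        self.thresh = None  # unused placeholder passed through to the transforms

    def get_transform(self):
        return self.CustomTransform(self.thresh)

    def set_default_locators_and_formatters(self, axis):
        pass  # ticks and formatters are set manually in calibration_plot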

    class CustomTransform(mtransforms.Transform):
        # Forward map: confidence a -> -log(1 + eps - a)
        input_dims = 1
        output_dims = 1
        is_separable = True

        def __init__(self, thresh):
            mtransforms.Transform.__init__(self)
            self.thresh = thresh

        def transform_non_affine(self, a):
            return -np.log(1 + CustomScale.eps - a)

        def inverted(self):
            return CustomScale.InvertedCustomTransform(self.thresh)

    class InvertedCustomTransform(mtransforms.Transform):
        # Inverse map: y -> 1 + eps - exp(-y)
        input_dims = 1
        output_dims = 1
        is_separable = True

        def __init__(self, thresh):
            mtransforms.Transform.__init__(self)
            self.thresh = thresh

        def transform_non_affine(self, a):
            return 1 + CustomScale.eps - np.exp(-a)

        def inverted(self):
            return CustomScale.CustomTransform(self.thresh)


mscale.register_scale(CustomScale)

fig, axes = plt.subplots(figsize=(37, 8), nrows=1, ncols=4)
plt.subplots_adjust(wspace=0.3, bottom=0.25)

def calibration_plot(results, ds, model):
    # results maps method name -> dict with per-bin 'confidence' and 'accuracy' arrays (or None)
    for method, curve in sorted(results.items()):
        if method not in methods:
            continue
        label, color = styles[method]
        if curve is not None:
            # y = confidence - accuracy: positive values mean over-confident
            plt.plot(curve['confidence'], curve['confidence'] - curve['accuracy'], linewidth=4, marker='o', markersize=8,
                     color=color, label=label, zorder=3)
    # Dashed reference line at 0 (perfect calibration)
    plt.plot(np.linspace(0.1, 1.0, 100), np.zeros(100), 'k--', dashes=(5, 5), linewidth=3, zorder=2)

    # Stretch the region near confidence 1 with the scale registered above
    plt.gca().set_xscale('custom')

    # 6 ticks with 1 - confidence log-spaced from 0.8 down to 0.002 (confidence 0.2 ... 0.998)
    ticks = 1.0 - np.logspace(np.log(0.8), np.log(0.002), 6, base=np.e)
    plt.xticks(ticks, fontsize=22)
    plt.yticks(fontsize=22)
    plt.gca().yaxis.set_major_formatter(FormatStrFormatter('%.2f'))
    plt.gca().xaxis.set_major_formatter(FormatStrFormatter('%.3f'))
    plt.margins(x=0.03)
    plt.ylabel('Confidence - Accuracy', fontsize=28)
    plt.xlabel('Confidence (max prob)', fontsize=28)
    plt.title('%s %s' % (model, ds), fontsize=28, y=1.02)
    plt.grid()
    
plt.sca(axes[0])
calibration_plot(load_dict('./data/calibrations/c100_wrn_new.pkl'), 'CIFAR-100', 'WideResNet28x10')
plt.sca(axes[1])
calibration_plot(load_dict('./data/calibrations/stl_wrn.pkl'), 'CIFAR-10 $\\rightarrow$ STL-10', 'WideResNet28x10')
plt.sca(axes[2])
calibration_plot(load_dict('./data/calibrations/imagenet_densenet161.pkl'), 'ImageNet', 'DenseNet-161')
plt.sca(axes[3])
calibration_plot(load_dict('./data/calibrations/imagenet_resnet152.pkl'), 'ImageNet', 'ResNet-152')
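
handles, labels = axes[0].get_legend_handles_labels()
leg = plt.figlegend(handles, labels, fontsize=28, loc='lower center', bbox_to_anchor=(0.43, 0.0), ncol=6)
# Thicker lines and bigger markers in the legend.
# `legend_handles` needs Matplotlib >= 3.7 (older versions call it `legendHandles`);
# since Matplotlib 3.5 each legend line carries its own markers, so the private `_legmarker` is gone.
for legobj in leg.legend_handles:
    legobj.set_linewidth(6.0)
    legobj.set_markersize(12.0)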

plt.savefig('./pics/calibration_curves.pdf', format='pdf', bbox_inches='tight')
plt.show()

It was originally written by @timgaripov.
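To see why the custom scale un-bunches the ticks, the forward and inverse maps from `CustomTransform` and `InvertedCustomTransform` can be checked in isolation, without Matplotlib. A small NumPy sketch; the function names `forward` and `inverse` are hypothetical, and `EPS` mirrors `CustomScale.eps`:

```python
import numpy as np

EPS = 0.002  # same constant as CustomScale.eps


def forward(c):
    """Hypothetical stand-in for CustomTransform: c -> -log(1 + EPS - c)."""
    return -np.log(1.0 + EPS - c)


def inverse(y):
    """Hypothetical stand-in for InvertedCustomTransform: y -> 1 + EPS - exp(-y)."""
    return 1.0 + EPS - np.exp(-y)


# Same tick positions as calibration_plot: 1 - confidence is log-spaced
# from 0.8 down to 0.002, i.e. confidences from 0.2 up to 0.998.
ticks = 1.0 - np.logspace(np.log(0.8), np.log(0.002), 6, base=np.e)

print(np.round(ticks, 3))                    # tick values in confidence units
print(np.round(np.diff(forward(ticks)), 2))  # gaps between ticks on screen
```

On the transformed axis the gaps between consecutive ticks come out roughly equal, apart from the last one where `EPS` starts to dominate, so values like 0.93, 0.98 and 0.998 stay visually apart instead of piling up at the right edge as they do on a linear axis.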

@izmailovpavel
Collaborator

For another paper, I used this code to plot the calibration curves, which is a lot simpler:

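import matplotlib.pyplot as plt

# `new_calibration_arr`, `matt_arr` and the *_color variables are defined elsewhere (not shown).
# Note the sign: this version plots accuracy - confidence, the opposite of the code above.
plt.figure(figsize=(3, 3))
def plot_calibration(arr):
    # arr: dict with per-bin "confidence" and "accuracy" arrays plus a line "color"
    plt.plot(arr["confidence"], arr["accuracy"] - arr["confidence"],
             "-o", color=arr["color"], mec="k", ms=7, lw=3)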

# plot_calibration({**matt_arr["deep_ensemble_calibration"], "color": de_color})
plot_calibration({**new_calibration_arr["deep_ensemble"].item(), "color": de_color})
plot_calibration({**new_calibration_arr["sgld"].item(), "color": sgld_color})
plot_calibration({**new_calibration_arr["sgld_mom_clr_prec"].item(), "color": sgld_hot_color})
plot_calibration({**matt_arr["hmc_calibration"], "color": "orange"})
# plot_calibration({**matt_arr["sgld_calibration"], "color": sgld_color})
# plot_calibration({**matt_arr["sgld_hot_calibration"], "color": sgld_hot_color})
plt.hlines(0., 0., 1., color="k", linestyle="dashed")
plt.xticks(fontsize=14)
plt.yticks(fontsize=14)
plt.xlabel("Confidence", fontsize=16)
plt.ylabel("Accuracy - Confidence", fontsize=16)
plt.grid()
plt.xlim(0.35, 1.05)
plt.savefig("calibration_curve.pdf", bbox_inches="tight")
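The `.item()` calls suggest that `new_calibration_arr` is an `.npz` archive whose entries are dicts stored as 0-d object arrays. That is an assumption about the authors' data format, not something stated in the thread. A minimal sketch of writing and reading curves that way, using made-up toy numbers and an in-memory buffer in place of a real file:

```python
import io

import numpy as np

# Made-up per-bin curves, for illustration only
curves = {
    "deep_ensemble": {"confidence": np.array([0.55, 0.75, 0.95]),
                      "accuracy": np.array([0.50, 0.72, 0.96])},
    "sgld": {"confidence": np.array([0.55, 0.75, 0.95]),
             "accuracy": np.array([0.58, 0.80, 0.97])},
}

buf = io.BytesIO()       # stands in for a file on disk
np.savez(buf, **curves)  # each dict is stored as a 0-d object array
buf.seek(0)

# allow_pickle=True is required to load object arrays
new_calibration_arr = np.load(buf, allow_pickle=True)

# .item() unwraps the 0-d object array back into the original dict,
# which can then be merged with a color as in the snippet above
arr = {**new_calibration_arr["deep_ensemble"].item(), "color": "C0"}
print(sorted(arr))  # ['accuracy', 'color', 'confidence']
```

Because object arrays are unpickled on load, this format should only be used for files you trust.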

@Codefmeister
Author

Thanks for your kindness. It's very helpful.
