In [8]:
import matplotlib.pyplot as plt 
from scipy import stats


# Let TeX do all text typesetting so figure labels match the paper's fonts.
plt.rcParams['text.usetex'] = True
# Load sansmath and switch math mode to sans-serif so axis labels/ticks rendered
# in math mode match the sans-serif text. Matplotlib >= 3.3 expects this preamble
# as a single string; the list form was deprecated there and newer releases reject it.
plt.rcParams['text.latex.preamble'] = r'\usepackage{sansmath} \sansmath'
plt.rcParams['font.family'] = 'sans-serif'  # ... for regular text
# Candidate sans-serif fonts, tried in order (same list the comma-separated
# string form expanded to).
plt.rcParams['font.sans-serif'] = ['Helvetica', 'Avant Garde', 'Computer Modern Sans serif']

import seaborn as sns
import pandas as pd 
import numpy as np

# Global seaborn theme for every figure created after this point.
sns.set_style(style="whitegrid")
# Marker sequence indexed per scatter series: 24 string marker codes followed by
# matplotlib's integer marker codes 0-11.
markers = [
    '+', '.', 'x', 'o', 'v', '^', '<', '>', '1', '2', '3', '4',
    '8', 's', 'p', '*', 'h', 'H', 'D', 'd', '|', '_', 'P', 'X',
] + list(range(12))
def get_scatter_plot(xx, yy, labels, fc=20, name=None):
    """Scatter one series per method and overlay a dashed ``y = x`` reference line.

    Each ``(x, y)`` pair from ``zip(xx, yy)`` is drawn as its own scatter series
    with a distinct marker (markers cycle if there are more series than markers).
    Both axes are then forced to the same square range, which is the union of the
    autoscaled x and y limits, so the diagonal spans the whole plot.

    Args:
        xx: Sequence of x-value arrays, one per series (axis label "OOD Accuracy").
        yy: Sequence of y-value arrays, one per series (axis label "Predicted Accuracy").
        labels: Legend labels, one per series, in the same order as ``xx``/``yy``.
        fc: Base font size for title, ticks and axis labels; the legend uses ``fc - 5``.
        name: Plot title and output file stem. When given, the figure is also saved
            to ``figures/<name>_app.pdf`` (the directory is created if missing).
            When ``None``, the figure is only shown. Previously this default crashed
            on string concatenation.
    """
    fig, ax = plt.subplots()

    # Cycle markers so more series than available markers does not raise IndexError.
    plots = [
        ax.scatter(x=x, y=y, marker=markers[i % len(markers)])
        for i, (x, y) in enumerate(zip(xx, yy))
    ]

    # Common range covering both autoscaled axes, so the diagonal is a true y = x.
    lims = [
        np.min([ax.get_xlim(), ax.get_ylim()]),  # min of both axes
        np.max([ax.get_xlim(), ax.get_ylim()]),  # max of both axes
    ]
    # Dashed diagonal drawn beneath the scatter points (zorder=0).
    ax.plot(lims, lims, 'k--', alpha=1.0, zorder=0)

    ax.set_aspect('equal')
    ax.set_xlim(lims)
    ax.set_ylim(lims)
    if name is not None:
        ax.set_title(name, fontsize=fc)
    ax.legend(plots, labels, fontsize=fc - 5)
    ax.tick_params(axis='both', labelsize=fc)
    ax.set_xlabel("OOD Accuracy", fontsize=fc)
    ax.set_ylabel("Predicted Accuracy", fontsize=fc)

    if name is not None:
        import os  # local import: keeps this block self-contained

        os.makedirs("figures", exist_ok=True)
        fig.savefig(os.path.join("figures", name + "_app.pdf"),
                    transparent=True, bbox_inches='tight')
    plt.show()
    



import numpy as np
import os 
def get_acc(filename, columns=(0, 18, 19, 14, 15, 16, 17, 20, 6, 10, 8, 12),
            preferred_csv='45.csv'):
    """Collect the selected CSV columns from every matching run under ``filename``.

    Only subdirectories whose names start with ``ResNet``, ``FCN`` or ``dist``
    are read, in sorted name order so results are deterministic. For each run,
    the CSV directory is ``<run>/predicted_acc_bin/`` when ``<run>/predicted_acc/``
    exists, and ``<run>/predict_acc_T_combine/`` otherwise. If that directory has
    two or more entries, ``preferred_csv`` is read. Otherwise its single entry is read.

    The first line of each CSV is skipped as a header unless the file path contains
    ``RxRx1``. Every remaining non-blank line is parsed as comma-separated floats
    and reduced to the values at ``columns``, in that order.

    Args:
        filename: Root directory holding one subdirectory per run. A trailing
            path separator is optional.
        columns: Column indices to keep from each CSV row, in output order.
            The default is the 12-column selection this notebook has always used.
        preferred_csv: File name to read when a run's CSV directory has several entries.

    Returns:
        One entry per matching run. Each entry is a list of rows, and each row is
        a list of ``len(columns)`` floats.

    Raises:
        FileNotFoundError: if the selected CSV directory or file does not exist.
        IndexError: if the selected directory is empty or a row has fewer than
            ``max(columns) + 1`` fields.
        ValueError: if a field cannot be parsed as a float.
    """
    seeds_err = []
    for name in sorted(os.listdir(filename)):
        if not name.startswith(("ResNet", "FCN", "dist")):
            continue
        run_dir = os.path.join(filename, name)

        # NOTE(review): the existence test is on ``predicted_acc`` but the CSV is read
        # from ``predicted_acc_bin``. This is kept as-is so the same outputs are used
        # as before; confirm the pairing is intentional.
        if os.path.isdir(os.path.join(run_dir, "predicted_acc")):
            acc_dir = os.path.join(run_dir, "predicted_acc_bin")
        else:
            acc_dir = os.path.join(run_dir, "predict_acc_T_combine")

        entries = os.listdir(acc_dir)
        csv_name = preferred_csv if len(entries) >= 2 else entries[0]
        acc_file = os.path.join(acc_dir, csv_name)

        rows = []
        with open(acc_file, "r") as f:
            # RxRx1 CSVs are read from their first line; all others start with a header.
            if "RxRx1" not in acc_file:
                f.readline()
            for line in f:
                line = line.strip()
                if not line:
                    continue  # tolerate blank / trailing empty lines
                vals = [float(v) for v in line.split(",")]
                rows.append([vals[c] for c in columns])

        seeds_err.append(rows)

    return seeds_err


# Output roots, one per dataset; each holds per-run subdirectories read by get_acc.
filenames = ["outputs_rebuttal/CIFAR/",
             "outputs_rebuttal/CIFAR-100/",
             "outputs_rebuttal/ImageNet-200/",
             "outputs_rebuttal/ImageNet/",
             "outputs_rebuttal/FMoW/",
             "outputs_rebuttal/RxRx1/",
             "outputs_rebuttal/Amazon/",
             "outputs_rebuttal/CivilComments/",
             "outputs_rebuttal/MNIST/",
             "outputs_rebuttal/entity13/",
             "outputs_rebuttal/entity30/",
             "outputs_rebuttal/nonliving26/",
             "outputs_rebuttal/living17/"]

# Display names, index-aligned with ``filenames``.
datasets = ["CIFAR10", "CIFAR100", "ImageNet200", "ImageNet", "FMoW-Wilds", "RxRx1-Wilds", "Amazon-Wilds", "CivilComments-Wilds", "MNIST", "Entity13", "Entity30", "Nonliving26", "Living17"]

# NOTE(review): not used in this cell; presumably consumed by later plotting cells.
methods = ["IM", "AC", "DOC", "GDE", "ATC-MC (Ours)", "ATC-NE (Ours)"]

# Guard the index alignment the rest of the notebook relies on.
if len(filenames) != len(datasets):
    raise ValueError(
        "filenames (%d) and datasets (%d) must be index-aligned"
        % (len(filenames), len(datasets))
    )

# One entry per dataset: the list of per-run row lists returned by get_acc.
plot_arr = [get_acc(ff) for ff in filenames]

for seeds in plot_arr:
    # Average across runs (axis 0); every run must contribute the same number of rows.
    plot_arr_i = np.mean(seeds, axis=0)

    print(plot_arr_i.shape)


(96, 12)
(95, 12)
(100, 12)
(99, 12)
(12, 12)
(3, 12)
(2, 12)
(15, 12)
(3, 12)
(199, 12)
(199, 12)
(199, 12)
(199, 12)
