In [319]:
def plot_contrastive_train_accuracy(resultfile, top5 = False, savename = None):
    """Plot per-epoch accuracy curves stored in a contrastive-training history.

    Three curves are drawn: training contrastive accuracy, training probe
    accuracy and validation probe accuracy.

    Args:
        resultfile: Path to a pickled history dict containing the keys
            ``contrastive_accuracy``, ``probe_accuracy`` and
            ``val_probe_accuracy``. When ``top5`` is True, the same keys with
            a ``_top-5`` suffix are read instead.
        top5: If True, plot the top-5 accuracy series instead of the plain
            accuracy series.
        savename: Optional file name. When given, the figure is written as a
            600-dpi PDF to ``results/<savename>`` (``results/`` is created if
            missing). The title is set only after saving, so it is not part
            of the exported file.

    Raises:
        FileNotFoundError: If ``resultfile`` does not exist.
        KeyError: If the history lacks one of the required keys.

    Note:
        ``pickle.load`` can execute arbitrary code; only load trusted files.
    """
    with open(resultfile, "rb") as f:
        history = pickle.load(f)

    # The top-5 series use the same key names with a "_top-5" suffix.
    suffix = "_top-5" if top5 else ""
    c_acc = history["contrastive_accuracy" + suffix]
    p_acc = history["probe_accuracy" + suffix]
    val_p_acc = history["val_probe_accuracy" + suffix]

    plt.plot(c_acc, color = "tomato", label = "training contrastive accuracy")
    plt.plot(p_acc, color = "violet", label = "training probe accuracy")
    plt.plot(val_p_acc, color = "teal", label = "validation probe accuracy")

    plt.xlabel("# of epochs", fontsize = 12)
    plt.ylabel("Top-5 accuracy" if top5 else "Accuracy", fontsize = 12)

    n_epochs = len(c_acc)
    plt.xlim(0, n_epochs - 1)
    # Upper limit 1.1 leaves headroom above the top tick at 1.
    plt.ylim(0, 1.1)

    plt.yticks([0, 0.25, 0.5, 0.75, 1], [0, 0.25, 0.5, 0.75, 1])

    # One integer tick per epoch index: 0, 1, ..., n_epochs - 1.
    xticks = np.linspace(0,
                         n_epochs - 1,
                         num = n_epochs,
                         endpoint = True,
                         dtype = int)
    plt.xticks(xticks, xticks)

    plt.legend(fontsize = 12, ncol = 1, frameon = False)

    # Only save when a file name was actually supplied.
    if savename:
        os.makedirs("results", exist_ok = True)
        figpath = os.path.join("results", savename)
        plt.savefig(figpath,
                    dpi = 600,
                    format = 'pdf')

        print("saved to: ", figpath)

    # Set after saving, so the title appears only in the displayed figure.
    if top5:
        plt.title("The training and validation top-5 accuracies w.r.t. epochs")
    else:
        plt.title("The training and validation accuracies w.r.t. epochs")

    plt.show()

In [320]:
def plot_contrastive_train_loss(resultfile):
    """Plot per-epoch contrastive, probe and validation-probe losses.

    Each loss series is divided by its own maximum before plotting. The curves
    therefore show loss relative to each series' peak, not absolute values.
    This assumes positive losses; an all-zero series would divide by zero.

    Args:
        resultfile: Path to a pickled history dict with the keys
            ``contrastive_loss``, ``probe_loss`` and ``val_probe_loss``.

    Raises:
        FileNotFoundError: If ``resultfile`` does not exist.
        KeyError: If the history lacks one of the required keys.

    Note:
        ``pickle.load`` can execute arbitrary code; only load trusted files.
    """
    # Use a context manager so the file handle is closed deterministically.
    with open(resultfile, "rb") as f:
        history = pickle.load(f)

    c_loss = history["contrastive_loss"]
    p_loss = history["probe_loss"]
    val_p_loss = history["val_probe_loss"]

    # Scale each series by its own peak so all three fit on the same 0..1.1 axis.
    c_loss = np.array(c_loss) / max(c_loss)
    p_loss = np.array(p_loss) / max(p_loss)
    val_p_loss = np.array(val_p_loss) / max(val_p_loss)

    plt.title("The training and validation losses w.r.t. epochs")

    plt.plot(c_loss, color = "blue", label = "contrastive")
    plt.plot(p_loss, color = "coral", label = "training")
    plt.plot(val_p_loss, color = "brown", label = "validation")

    plt.xlabel("# of epochs", fontsize = 12)
    plt.ylabel("Loss", fontsize = 12)

    n_epochs = len(c_loss)
    plt.xlim(0, n_epochs - 1)
    plt.ylim(0, 1.1)

    plt.yticks([0, 0.25, 0.5, 0.75, 1], [0, 0.25, 0.5, 0.75, 1])

    # One integer tick per epoch index: 0, 1, ..., n_epochs - 1.
    xticks = np.linspace(0,
                         n_epochs - 1,
                         num = n_epochs,
                         endpoint = True,
                         dtype = int)
    plt.xticks(xticks, xticks)

    plt.legend(fontsize = 12, ncol = 1, frameon = False)

    plt.show()

In [309]:
import pickle
import matplotlib.pyplot as plt
#from matplotlib import cm
#from matplotlib.colors import ListedColormap, LinearSegmentedColormap

def plot_supervised_train_accuracy(resultfile, top5 = False, savename = None):
    """Plot per-epoch training and validation accuracy of a supervised run.

    Args:
        resultfile: Path to a pickled history dict with the keys ``accuracy``
            and ``val_accuracy``. When ``top5`` is True, ``accuracy_top-5``
            and ``val_accuracy_top-5`` are read instead. If the path does not
            exist, the function returns ``None`` without plotting anything.
        top5: If True, plot the top-5 accuracy series.
        savename: Optional file name. When given, the figure is written as a
            600-dpi PDF to ``results/<savename>`` (``results/`` is created if
            missing). The title is set only after saving, so it is not part
            of the exported file.

    Raises:
        KeyError: If the history lacks one of the required keys.

    Note:
        ``pickle.load`` can execute arbitrary code; only load trusted files.
    """
    # A missing result file is skipped silently rather than treated as an error.
    if not os.path.exists(resultfile):
        return

    with open(resultfile, "rb") as f:
        history = pickle.load(f)

    if top5:
        val_accs = history["val_accuracy_top-5"]
        train_accs = history["accuracy_top-5"]
    else:
        val_accs = history["val_accuracy"]
        train_accs = history["accuracy"]

    plt.plot(val_accs, color = "blue", label = "validation")
    plt.plot(train_accs, color = "coral", label = "training")

    plt.xlabel("# of epochs", fontsize = 12)
    plt.ylabel("Top-5 accuracy" if top5 else "Accuracy", fontsize = 12)

    n_epochs = len(val_accs)
    plt.xlim(0, n_epochs - 1)
    # Upper limit 1.1 leaves headroom above the top tick at 1.
    plt.ylim(0, 1.1)

    # One integer tick per epoch index: 0, 1, ..., n_epochs - 1.
    xticks = np.linspace(0,
                         n_epochs - 1,
                         num = n_epochs,
                         endpoint = True,
                         dtype = int)
    plt.xticks(xticks, xticks)
    plt.yticks([0, 0.25, 0.5, 0.75, 1], [0, 0.25, 0.5, 0.75, 1])

    plt.grid()

    plt.legend(fontsize = 12, ncol = 2, frameon = False)

    # Only save when a file name was actually supplied.
    if savename:
        os.makedirs("results", exist_ok = True)
        figpath = os.path.join("results", savename)
        plt.savefig(figpath,
                    dpi = 600,
                    format = 'pdf')

        print("saved to: ", figpath)

    # Set after saving, so the title appears only in the displayed figure.
    if top5:
        plt.title("The training validation top-5 accuracy w.r.t. epochs")
    else:
        plt.title("The training validation accuracy w.r.t. epochs")

    plt.show()

In [310]:
def plot_supervised_train_loss(resultfile):
    """Plot per-epoch training and validation loss of a supervised run.

    Each loss series is divided by its own maximum before plotting. The curves
    therefore show loss relative to each series' peak, not absolute values.
    This assumes positive losses; an all-zero series would divide by zero.

    Args:
        resultfile: Path to a pickled history dict with the keys ``loss`` and
            ``val_loss``. If the path does not exist, the function returns
            ``None`` without plotting anything.

    Raises:
        KeyError: If the history lacks one of the required keys.

    Note:
        ``pickle.load`` can execute arbitrary code; only load trusted files.
    """
    # A missing result file is skipped silently rather than treated as an error.
    if not os.path.exists(resultfile):
        return

    # Use a context manager so the file handle is closed deterministically.
    with open(resultfile, "rb") as f:
        history = pickle.load(f)

    val_loss = history["val_loss"]
    train_loss = history["loss"]

    # Scale each series by its own peak so both fit on the same 0..1.1 axis.
    val_loss = np.array(val_loss) / max(val_loss)
    train_loss = np.array(train_loss) / max(train_loss)

    plt.title("The training validation losses w.r.t. epochs")

    plt.plot(val_loss, color = "tomato", label = "validation")
    plt.plot(train_loss, color = "teal", label = "training")

    plt.xlabel("# of epochs", fontsize = 12)
    plt.ylabel("Loss", fontsize = 12)

    n_epochs = len(val_loss)
    plt.xlim(0, n_epochs - 1)
    plt.ylim(0, 1.1)

    # One integer tick per epoch index: 0, 1, ..., n_epochs - 1.
    xticks = np.linspace(0,
                         n_epochs - 1,
                         num = n_epochs,
                         endpoint = True,
                         dtype = int)
    plt.xticks(xticks, xticks)
    plt.yticks([0, 0.25, 0.5, 0.75, 1], [0, 0.25, 0.5, 0.75, 1])

    plt.legend(fontsize = 12, ncol = 1, frameon = False)

    plt.show()