In [None]:
import re
import os
from os import path
from pathlib import Path
import pandas as pd
import matplotlib.pyplot as plt
plt.style.use('ieee')
import seaborn as sns
from sklearn.preprocessing import StandardScaler

In [None]:
# Experiment selector: the results file lives at
# <root>/results/<folder_type>/<dataset>/<filename>.
filename = "bins10_epochs1000_arch5_lr0.1.csv"
dataset = "make_circles_3"
folder_type = "discrete"

# The project root is the parent directory of the current working directory.
root_path = path.split(os.getcwd())[0]
filepath = path.join(root_path, *("results", folder_type, dataset, filename))

In [None]:
# Load the results and rename the two measurement columns to the labels used on the plot.
df = pd.read_csv(filepath).rename(columns={"Y": "I(Y,T)", "T": "I(X,T)"})

# Axis limits are computed over the whole frame (every epoch), not only the epoch drawn
# below, so that plots of different epochs share the same scale.
xmin, xmax = df["I(X,T)"].min(), df["I(X,T)"].max()
ymin, ymax = df["I(Y,T)"].min(), df["I(Y,T)"].max()

# Epoch shown in this preview figure.
epoch = 0

fig, ax = plt.subplots(figsize=(6, 4))

# One point per row of the selected epoch, coloured by layer.
g = sns.scatterplot(
    data=df.query(f"epoch=={epoch}"),
    x='I(X,T)',
    y="I(Y,T)",
    hue="layer",
    palette="muted",
    alpha=0.8,
    edgecolor="black",
    linewidth=1.2,
    ax=ax,
)


def set_scatterplot_legend(ax, epoch):
    """Label a scatter plot's axes/title and attach an opaque legend.

    Args:
        ax: Axes object holding the scatter plot. Axis labels, the title and
            the legend are all set on this object.
        epoch: Value interpolated into the plot title.

    Returns:
        The legend object returned by ``ax.legend``.
    """
    # Operate on the Axes that was passed in rather than on a notebook global,
    # so the helper works for any figure.
    ax.set_xlabel("I(X,T)", fontsize=12)
    ax.set_ylabel("I(Y,T)", fontsize=12)
    ax.set_title(f"Época: {epoch}", fontsize=14)

    leg = ax.legend(loc='lower right',
                    title='Layers',
                    fontsize=12)

    # Matplotlib renamed ``Legend.legendHandles`` to ``Legend.legend_handles``;
    # prefer the new attribute and fall back to the old one on older releases.
    handles = getattr(leg, "legend_handles", None)
    if handles is None:
        handles = leg.legendHandles

    # Make the legend markers opaque even though the plotted points are translucent.
    for lh in handles:
        # NOTE(review): some seaborn/Matplotlib versions appear to use legend
        # handles without a face-colour accessor (e.g. line-based markers);
        # skip the RGBA edit for those and rely on set_alpha alone.
        if hasattr(lh, "get_fc"):
            fc_arr = lh.get_fc().copy()
            fc_arr[:, -1] = 1  # last column is written as the alpha channel
            lh.set_fc(fc_arr)
        lh.set_alpha(1)

    return leg

# Apply labels, title and the opaque legend to the scatter plot drawn above.
leg = set_scatterplot_legend(g, epoch)

# Clamp both axes to the limits computed over the whole dataset, then render.
ax.set_xlim([xmin, xmax])
ax.set_ylim([ymin, ymax])
fig.tight_layout()
plt.show()

In [None]:
# Recover the number of epochs from the "epochs<N>" token in the results filename.
# A raw string is used so that "\d" reaches the regex engine instead of being
# treated as an (invalid) Python string escape.
match = re.search(r"epochs(\d+)", filename)
if match is None:
    # Fail with a clear message instead of an AttributeError on None.group.
    raise ValueError(f"Could not find an 'epochs<N>' token in filename {filename!r}")

epochs = int(match.group(1))

filename_no_ext = path.splitext(filename)[0]

# Images are written to <root>/images/<folder_type>/<dataset>/<filename without extension>/.
save_path = path.join(root_path, "images", folder_type, dataset, filename_no_ext)

# Create the output folder (and any missing parents); no error if it already exists.
path_creater = Path(save_path)
path_creater.mkdir(parents=True, exist_ok=True)

In [None]:
# Write one scatter plot per epoch to <save_path>/<epoch>.png.
#
# The frame's "T" and "Y" columns were renamed to "I(X,T)" and "I(Y,T)" right
# after the CSV was loaded, so the renamed labels must be used here; the old
# names raise KeyError on a fresh Restart & Run All.
x_col, y_col = "I(X,T)", "I(Y,T)"

# Shared axis limits over all epochs so that consecutive frames are comparable.
xmax, xmin = df[x_col].max(), df[x_col].min()
ymax, ymin = df[y_col].max(), df[y_col].min()

for epoch in range(epochs):
    # Explicit figure/axes instead of the pyplot state machine.
    fig, ax = plt.subplots()
    sns.scatterplot(data=df.query(f"epoch=={epoch}"), x=x_col, y=y_col, hue="layer", ax=ax)
    ax.set_title(f"Época: {epoch}")
    ax.legend(bbox_to_anchor=[1.2, 0.5])

    ax.set_xlim([xmin, xmax])
    ax.set_ylim([ymin, ymax])
    fig.tight_layout()
    fig.savefig(path.join(save_path, f"{epoch}.png"), facecolor='w')
    # Close each figure explicitly so open figures do not accumulate across the loop.
    plt.close(fig)