In [1]:
import matplotlib.pyplot as plt
from typing import Optional
import seaborn as sns
import functools

import matplotlib.pyplot as plt
import matplotlib.font_manager as fm
import numpy as np
import os.path as osp

In [2]:
class PlotGeneratedSuccessfully(Exception):
    """Control-flow signal that the requested plot has been produced.

    ``plot_single_grid`` raises this in theoretical mode once its figure is
    done; the calling cell catches it.
    """

    def __init__(self):
        message = "Plot generated successfully."
        Exception.__init__(self, message)

In [30]:
# Shared styling for the grid figures:
#   line_width -- stroke width used for the markers and the hand-drawn grid lines
#   fontname   -- font family used for every title, label and tick label
line_width, fontname = 1.3, "DejaVu Serif"

def plot_single_grid(
        sogo_type: str,
        theoretical: bool = False,
        save: bool = False,
        datadir: Optional[str] = None,
    ):
    """Draw the 3x3 train/test grid figure(s) for one SOGO setting.

    In the default mode one figure is produced per (factor, source) pair:
    hue and position, each for "Source 1" and "Source 2" (four figures).
    In theoretical mode a single figure with symbolic F^L / F^P tick labels
    is produced, after which ``PlotGeneratedSuccessfully`` is raised.

    Training positions are drawn as crosses and test positions as hollow
    circles, centred in the grid cell given by each ``(row, col)`` pair
    returned by ``point_cross_selector``.

    Args:
        sogo_type: Setting name accepted by ``point_cross_selector``
            ("ZGO", "1-CGO", "2-CGO", "3-CGO", "ZSO").
        theoretical: Draw the single symbolic figure instead of the
            per-factor figures.
        save: If True, write each figure as a PDF into ``datadir``;
            otherwise display it with ``plt.show()``.
        datadir: Output directory for saved PDFs. Defaults to the original
            cluster results directory.

    Raises:
        PlotGeneratedSuccessfully: after the figure in theoretical mode.
        NotImplementedError: from ``point_cross_selector`` for an unknown
            ``sogo_type``.
    """
    if datadir is None:
        datadir = r"/cluster/home/vjimenez/adv_pa_new/results/dg/sogo"

    if not theoretical:
        factors_list = ["hue", "pos"]
        xaxis_labels, yaxis_labels = ["Hue", "Position"], ["Shape", "Shape"]
        shape_labels = ["9", "4", "7"]
        # One entry per factor; each holds the x tick labels for sources 1 and 2.
        labels_per_factor = [
            [['B', 'R', 'G'], ['Y', 'C', 'M']],
            [['UL', 'CC', 'LR'], ['LC', 'UR', 'CL']],
        ]
    else:
        factors_list = ["theoretical"]
        xaxis_labels, yaxis_labels = [r"Learned $F^{L}$"], [r"Predicted $F^{P}$"]
        shape_labels = [r"$F^{P}_j$", r"$F^{P}_k$", r"$F^{P}_i$"]
        # Theoretical mode draws exactly one figure.
        labels_per_factor = [
            [[r"$F^{L}_l$", r"$F^{L}_m$", r"$F^{L}_n$"]],
        ]

    train_positions, test_positions = point_cross_selector(sogo_type)
    # Some settings list the same cell more than once (e.g. ZGO repeats every
    # training cell). Overplotting an identical marker makes it look bolder
    # than its neighbours, so draw each distinct cell once (order preserved).
    train_positions = list(dict.fromkeys(train_positions))
    test_positions = list(dict.fromkeys(test_positions))

    for ifactor, labels_list in enumerate(labels_per_factor):
        for iplot, x_labels in enumerate(labels_list):
            figure_name = sogo_type if theoretical else f"Source {iplot + 1}"
            fig, ax = plt.subplots(figsize=(3, 3))

            # Markers sit at cell centres: (row, col) -> (x, y) = (col + 0.5, row + 0.5).
            for row, col in train_positions:
                ax.scatter(col + 0.5, row + 0.5, marker='x', color='black', s=250, linewidth=line_width)
            for row, col in test_positions:
                ax.scatter(col + 0.5, row + 0.5, marker='o', facecolors='none', edgecolors='black', s=500, linewidth=line_width)

            # Draw every grid line (inner and outer) by hand so all share line_width.
            for k in range(4):
                ax.plot([0, 3], [k, k], color='black', linewidth=line_width)  # horizontal
                ax.plot([k, k], [0, 3], color='black', linewidth=line_width)  # vertical

            # Ticks near cell centres. The x ticks are shifted right and their labels
            # right-aligned; the per-factor shift presumably compensates for the
            # different label widths (hue initials vs. two-letter positions).
            x_offset = 0.65 if ifactor == 0 else 0.8
            ax.set_xticks(x_offset + np.arange(3))
            ax.set_yticks(0.5 + np.arange(3))

            ax.set_xticklabels(x_labels, rotation=0, ha='right', fontsize=16, fontname=fontname)
            ax.set_yticklabels(shape_labels, fontsize=16, fontname=fontname)

            for xtick, label in zip(ax.get_xticklabels(), x_labels):
                # Hue labels are colour initials that double as matplotlib
                # single-letter colour codes ('b', 'r', 'g', 'y', 'c', 'm').
                if ifactor == 0 and not theoretical:
                    xtick.set_color(label[0].lower())
                xtick.set_fontweight('bold')
            for ytick in ax.get_yticklabels():
                ytick.set_fontstyle('italic')

            ax.set_title(figure_name, fontname=fontname, fontsize=18)
            ax.set_xlabel(xaxis_labels[ifactor], fontsize=16, fontname=fontname, labelpad=10)
            ax.set_ylabel(yaxis_labels[ifactor], fontsize=16, fontname=fontname, labelpad=10, fontstyle='italic')

            # Hide the tick marks but keep their labels.
            ax.tick_params(left=False, bottom=False)

            # Equal aspect and a small margin so the outer grid lines render
            # the same as the inner ones.
            ax.set_aspect('equal')
            ax.set_xlim(-0.1, 3.1)
            ax.set_ylim(-0.1, 3.1)
            # The hand-drawn grid replaces the default axis spines.
            for spine in ax.spines.values():
                spine.set_visible(False)

            fig.tight_layout()
            if save:
                if theoretical:
                    file_name = f"_{sogo_type}.pdf"
                else:
                    file_name = f"{sogo_type}_fact={factors_list[ifactor]}_env={iplot}.pdf"
                fig.savefig(osp.join(datadir, file_name), dpi=300, bbox_inches='tight')
            else:
                plt.show()
            # Close by handle so figures do not accumulate across the loop.
            plt.close(fig)

    if theoretical:
        # Kept for compatibility: callers catch this to learn the figure is done.
        raise PlotGeneratedSuccessfully

# Train/test cell layouts per SOGO setting, as (row, col) index pairs in 0..2.
# Cells may repeat within a list (e.g. ZGO lists each training cell twice).
_SOGO_POSITIONS = {
    "ZGO": (
        [(2, 0), (1, 1), (0, 2), (2, 0), (1, 1), (0, 2)],
        [(2, 1), (2, 2), (1, 0), (1, 2), (0, 0), (0, 1)],
    ),
    "1-CGO": (
        [(2, 0), (2, 1), (1, 1), (0, 2), (2, 0), (2, 1), (1, 1), (0, 2)],
        [(2, 2), (2, 2), (1, 0), (1, 2), (0, 0), (0, 1)],
    ),
    "2-CGO": (
        [(2, 0), (2, 1), (1, 1), (0, 0), (0, 2), (2, 1), (1, 1), (0, 0)],
        [(2, 2), (2, 2), (1, 0), (1, 2), (0, 1)],
    ),
    "3-CGO": (
        [(2, 0), (2, 1), (1, 1), (0, 0), (0, 2), (2, 0), (2, 1), (1, 1), (0, 2), (1, 0)],
        [(2, 2), (1, 2), (0, 1)],
    ),
    "ZSO": (
        [(2, 0), (2, 1), (2, 2), (1, 0), (1, 1), (1, 2), (0, 0), (0, 1), (0, 2),
         (2, 0), (2, 1), (2, 2), (1, 0), (1, 1), (1, 2), (0, 0), (0, 1), (0, 2)],
        [(2, 0), (2, 1), (2, 2), (2, 0), (2, 1), (2, 2), (1, 0), (1, 1), (1, 2),
         (1, 0), (1, 1), (1, 2), (0, 0), (0, 1), (0, 2), (0, 0), (0, 1), (0, 2)],
    ),
}


def point_cross_selector(sogo: str):
    """Return the training and test grid positions for a SOGO setting.

    Args:
        sogo: One of "ZGO", "1-CGO", "2-CGO", "3-CGO", "ZSO".

    Returns:
        ``(train_positions, test_positions)``: two new lists of ``(row, col)``
        tuples. They are fresh copies, so callers may mutate them freely.

    Raises:
        NotImplementedError: if ``sogo`` is not a known setting.
    """
    try:
        train_positions, test_positions = _SOGO_POSITIONS[sogo]
    except (KeyError, TypeError):
        # TypeError covers unhashable inputs, which the old if/elif chain also
        # rejected with NotImplementedError.
        raise NotImplementedError(
            f"Unknown SOGO type {sogo!r}; expected one of: {', '.join(_SOGO_POSITIONS)}"
        ) from None
    return list(train_positions), list(test_positions)

In [32]:
# Render and save the per-factor grids.
# Available settings: "ZGO", "1-CGO", "2-CGO", "3-CGO", "ZSO" -- only the
# one(s) listed below are regenerated.
for sogo_type in ("3-CGO",):
    plot_single_grid(sogo_type=sogo_type, save=True)

### Theoretical plot

In [33]:
# Render and save the symbolic (theoretical) grid.
# Available settings: "ZGO", "1-CGO", "2-CGO", "3-CGO", "ZSO".
for sogo_type in ("3-CGO",):
    try:
        plot_single_grid(sogo_type=sogo_type, theoretical=True, save=True)
    except PlotGeneratedSuccessfully:
        # Expected: theoretical mode signals completion by raising this.
        continue

### Legend

In [158]:
import matplotlib.pyplot as plt

# Standalone legend image shared by all grid figures.
fontname = "DejaVu Serif"
# Resolve the font up front, presumably so a missing-font fallback shows up here.
_ = fm.findfont(fm.FontProperties(family=fontname))

# Build the marker handles on a throwaway axes. The figure exists only so the
# scatter artists can be created, and it is closed right away.
# NOTE(review): the grid figures use line_width = 1.3; confirm 1.5 is intended here.
fig, ax = plt.subplots(figsize=(6, 6))
handles = [
    ax.scatter([], [], marker='x', color='black', s=250, linewidth=1.5),
    ax.scatter([], [], marker='o', facecolors='none', edgecolors='black', s=500, linewidth=1.5),
]
plt.close(fig)
labels = [
    "Training, Validation & Test #0",
    "Test #1 - #5",
]

fig_legend = plt.figure(figsize=(4, 2))  # Adjust the size as needed
ax_legend = fig_legend.add_subplot(111)
ax_legend.legend(
    handles,
    labels,
    frameon=False,
    loc="center",  # Center the legend in the new figure
    handlelength=0.5,
    ncol=2,
    prop={"family": fontname, "size": 18},
)
ax_legend.axis('off')
# bbox_inches='tight' crops the saved image to the legend, so no tight_layout()
# is needed (calling it after savefig would not affect the written file anyway).
fig_legend.savefig(r"/cluster/home/vjimenez/adv_pa_new/results/dg/sogo/_legend.png", bbox_inches='tight', dpi=300)
plt.close(fig_legend)