# Projection plots

In [None]:
from sklearn.mixture import GaussianMixture
from src.datamodules.gmm_datamodule import feat_label_concat
import numpy as np 

def gmm_generator(means, num_samples=500, mean_scale=1.2, variance=0.03):
    """Draw labelled samples from a hand-configured Gaussian mixture.

    The mixture is not fitted: one equally weighted component is created per
    entry of ``means``, centred at ``mean_scale`` times that entry with
    covariance ``variance`` times the identity matrix.

    Args:
        means: Non-empty sequence of equal-length component centres.
        num_samples: Total number of points drawn from the mixture.
        mean_scale: Factor applied to every centre.
        variance: Diagonal value of each component's covariance matrix.

    Returns:
        The result of ``feat_label_concat`` applied to the sampled features
        and the index of the component each sample was drawn from.
    """
    num_components = len(means)
    dim = len(means[0])

    # Setting the fitted attributes directly lets ``sample`` be used without
    # running ``fit`` on any data.
    source_gmm = GaussianMixture(n_components=num_components)
    source_gmm.weights_ = np.ones(num_components) / num_components
    source_gmm.means_ = np.array(means) * mean_scale
    source_gmm.covariances_ = [
        np.eye(dim) * variance for _ in range(num_components)
    ]

    feature, label = source_gmm.sample(num_samples)
    return feat_label_concat(feature, label)

In [None]:
import matplotlib.pyplot as plt

class Axis_Params:
    """Plain container for the styling options used by the plotting helpers.

    Every constructor argument is stored unchanged on the instance under its
    own name, except ``bandwidth_kde`` which is stored as ``bandwidth``.
    """

    def __init__(
        self,
        left_place,
        right_place,
        figsize=(10, 10),
        title='',
        label_font_size=15,
        axis_font_size=15,
        title_font_size=22,
        opacity=1,
        scatter_size=20,
        new_fig=True,
        bandwidth_kde=0.9,
        num_grid=100,
    ):
        # Axis extent and title.
        self.left_place, self.right_place = left_place, right_place
        self.title = title

        # Font sizes.
        self.label_font_size, self.axis_font_size = label_font_size, axis_font_size
        self.title_font_size = title_font_size

        # Figure and marker appearance.
        self.figsize = figsize
        self.opacity, self.scatter_size = opacity, scatter_size
        self.new_fig = new_fig

        # Density-estimation settings (note the shortened attribute name).
        self.bandwidth = bandwidth_kde
        self.num_grid = num_grid


class Axis_Params_3d(Axis_Params):
    """``Axis_Params`` plus view angles, colours and axis labels for 3D plots.

    NOTE: the 3D-specific parameters come first in the signature, so
    positional arguments fill ``x_rotate``, ``z_rotate``, ``colors``,
    ``xlabel``, ``ylabel`` and ``zlabel`` before anything is forwarded to
    ``Axis_Params``. Pass the base-class options by keyword.
    """

    def __init__(self, x_rotate=None, z_rotate=None, colors=None, xlabel=None, ylabel=None, zlabel=None, *kargs, **kwargs):
        # Everything not consumed above is handed to the base class untouched.
        super().__init__(*kargs, **kwargs)
        self.x_rotate, self.z_rotate = x_rotate, z_rotate
        self.colors = colors
        self.xlabel, self.ylabel, self.zlabel = xlabel, ylabel, zlabel


def set_matplotlib_axis(ax, ax_params):
    """Apply the tick-label size and title from ``ax_params`` to ``ax``.

    Returns the same ``ax`` object so the call can be chained.
    """
    # The explicit axis limits are currently disabled:
    # ax.set_xlim(ax_params.left_place, ax_params.right_place)
    # ax.set_ylim(ax_params.left_place, ax_params.right_place)
    tick_style = {
        "axis": 'both',
        "which": 'major',
        "labelsize": ax_params.axis_font_size,
    }
    ax.tick_params(**tick_style)
    ax.set_title(ax_params.title, fontsize=ax_params.title_font_size)
    return ax

def plt_scatter_3d_alone(sample_n_3_list, ax_params, path, num_total_class, x_bound=None, y_bound=None, z_bound=None):
    """Scatter several labelled 3D point sets on one axes and save the figure.

    For every array in ``sample_n_3_list`` columns 0-2 are plotted as x, y, z
    and the last column, divided by ``num_total_class``, is mapped through the
    "jet" colormap to colour the points. Ticks and grid are removed, the view
    is set from ``ax_params.x_rotate`` / ``ax_params.z_rotate``, the optional
    ``*_bound`` values are forwarded to the axes' ``set_*bound`` methods, and
    the current figure is written to ``path``.
    """
    jet = plt.get_cmap("jet")
    fig = plt.figure(figsize=ax_params.figsize)
    ax = fig.add_subplot(111, projection='3d')

    for points in sample_n_3_list:
        xs, ys, zs = points[:, 0], points[:, 1], points[:, 2]
        point_colors = jet(points[:, -1] / num_total_class)
        ax.scatter(
            xs, ys, zs,
            alpha=ax_params.opacity,
            s=ax_params.scatter_size,
            c=point_colors,
        )

    set_matplotlib_axis(ax, ax_params)
    ax.view_init(elev=ax_params.x_rotate, azim=ax_params.z_rotate)
    # you will need this line to change the Z-axis
    ax.autoscale(enable=False, axis='both')
    ax.grid(False)

    # Strip the tick marks from all three axes.
    for clear_ticks in (ax.set_xticks, ax.set_yticks, ax.set_zticks):
        clear_ticks([])

    # Bounds first, then labels, one axis at a time (x, y, z).
    for apply_bound, bound in (
        (ax.set_xbound, x_bound),
        (ax.set_ybound, y_bound),
        (ax.set_zbound, z_bound),  # -1, 10.5
    ):
        apply_bound(bound)
    for apply_label, text in (
        (ax.set_xlabel, ax_params.xlabel),
        (ax.set_ylabel, ax_params.ylabel),
        (ax.set_zlabel, ax_params.zlabel),
    ):
        apply_label(text)

    plt.savefig(path, bbox_inches='tight')

In [None]:
# means_list=[
#     [[14, 1, 10], [15, 1, 10], [14, 0, 10], [15, 0, 10]], 
#     [[1, -6, 0], [1, -7, 0]], 
#     [[0, 5, 0], [0, 4, 0], [-1, 4.5, 0]], 
#     [[-8.8, -4, 0], [-8, -4.2, 0], [-7.2, -3.9, 0]]]
# One list of component centres per dataset.
means_list = [
    [[-2, -1, 10], [-1, -1, 10], [-2, 0, 10], [-1, 0, 10]],
    [[1, -6, 0], [1, -7, 0]],
    [[0, 5, 0], [0, 4, 0], [-1, 4.5, 0]],
    [[-8.8, -4, 0], [-8, -4.2, 0], [-7.2, -3.9, 0]],
]

# Sample every mixture and shift its last column by the number of components
# seen so far, so the values in that column do not collide across datasets.
batch_list, num_class_list = [], []
num_total_class = 0
for means in means_list:
    num_components = len(means)
    batch = gmm_generator(means)
    batch[:, -1] += num_total_class
    batch_list.append(batch)
    num_class_list.append(num_components)
    num_total_class += num_components

# --------- plot 3d ------------
parameter = Axis_Params_3d(
    left_place=-9,
    right_place=6,
    figsize=(10, 10),
    opacity=0.3,
    scatter_size=5,
    x_rotate=20,
    z_rotate=-50,
)

plt_scatter_3d_alone(batch_list, parameter, "train_data.png", num_total_class)

# --------- plot 2d ------------
# draw_data(
#     batch_list,
#     "train_data.png",
#     num_total_class,
# )

In [None]:
import torch 
from src.otdd.pytorch.distance import DatasetDistance
import torch.nn.functional as F

def barycentric_projection(source_ds, target_ds, batch_shape, num_target_class, device):
    """Project source samples onto the target dataset through the OTDD coupling.

    A ``DatasetDistance`` between ``source_ds`` and ``target_ds`` is evaluated
    on at most ``batch_shape[0]`` samples. Its coupling is sanitised and
    row-normalised, then used to average the target features and the one-hot
    target labels.

    Args:
        source_ds: Source dataset handed to ``DatasetDistance``.
        target_ds: Target dataset handed to ``DatasetDistance``.
        batch_shape: Shape the projected features are reshaped to; its first
            entry is also passed as ``maxsamples``.
        num_target_class: Number of target classes (a Python int or a 0-dim
            integer tensor); width of the returned probability matrix.
        device: Device forwarded to ``DatasetDistance``.

    Returns:
        ``(pf_feat, pf_probs)``: the projected features with shape
        ``batch_shape`` and the coupling-weighted one-hot target labels.
    """
    dist = DatasetDistance(
        source_ds,
        target_ds,
        inner_ot_method="exact",
        inner_ot_debiased=True,
        inner_ot_entreg=1e-2,
        entreg=1e-2,
        device=device,
        λ_y=0.1,
    )

    # This coupling is calculated based on the subsampling of #[maxsamples] samples.
    _, coupling, target_feat, target_hard_label = dist.distance(
        maxsamples=batch_shape[0], return_coupling=True
    )

    # Replace non-finite entries and keep every weight strictly positive so
    # each row can be normalised to sum to one.
    coupling = (
        torch.nan_to_num(coupling, nan=10.0, posinf=10.0, neginf=1e-5).abs() + 1e-4
    )
    coupling = coupling / coupling.sum(dim=1, keepdim=True)
    assert abs(coupling.sum(dim=1).min() - 1.0) < 1e-3

    # OTDD solver reshapes features to a flat vector
    pf_feat = (coupling @ target_feat).reshape(batch_shape)

    # OTDD solver is shifting target_hard_labels; shift them back to start at
    # zero without modifying the tensor returned by the solver in place.
    target_hard_label = target_hard_label - target_hard_label.min()
    # int() accepts both plain ints and the 0-dim tensors callers pass in.
    target_soft_labels = F.one_hot(target_hard_label, int(num_target_class)).float()
    pf_probs = coupling @ target_soft_labels

    # Check the projected labels themselves: one-hot rows always sum to one,
    # so asserting on ``target_soft_labels`` could never fail.
    assert (pf_probs.sum(dim=-1) - 1.0).abs().max() < 1e-3
    return pf_feat, pf_probs

In [None]:
# Generate pushforward data

from src.otdd.pytorch.datasets import CustomTensorDataset

# Pushforward of the source dataset (batch_list[0]) onto every other dataset,
# kept once with soft labels and once with arg-max hard labels.
pf_soft_label_list, pf_hard_label_list = [], []

src_batch = batch_list[0]
src_ds = CustomTensorDataset([
    torch.Tensor(src_batch[:, :-1]).float(),
    torch.Tensor(src_batch[:, -1]).long(),
])

for idx in range(1, len(batch_list)):
    trg_batch = batch_list[idx]
    trg_label = torch.Tensor(trg_batch[:, -1]).long()
    label_floor = trg_label.min()
    # Target labels are re-based to start at zero.
    trg_ds = CustomTensorDataset([
        torch.Tensor(trg_batch[:, :-1]).float(),
        trg_label - label_floor,
    ])
    num_trg_class = trg_label.max() - label_floor + 1

    pf_feat, pf_probs = barycentric_projection(
        src_ds, trg_ds, src_ds.tensors[0].shape, num_trg_class, "cpu"
    )
    pf_hard_label = pf_probs.argmax(axis=-1).long()
    pf_soft_label_list.append((pf_feat, pf_probs))
    pf_hard_label_list.append((pf_feat, pf_hard_label))

In [None]:
from src.models.loss_zoo import mse_loss,label_cost
# Costs between the source dataset (batch_list[0]) and each pushforward
# (w_inner_outer), and between every ordered pair of distinct pushforwards
# (w_inner_matrix). Each cost is mse_loss on the features plus λ_y times
# label_cost on the labels.
num_targets = len(batch_list) - 1  # every dataset except the source
w_inner_matrix = np.zeros([num_targets, num_targets])
w_inner_outer = np.zeros([num_targets])
λ_y = 0.1
source_ds = CustomTensorDataset([
    torch.Tensor(batch_list[0][:, :-1]).float(),
    torch.Tensor(batch_list[0][:, -1]).long(),
])

for idx_i in range(1, num_targets + 1):
    trg_label = torch.Tensor(batch_list[idx_i][:, -1]).long()
    trg_ds = CustomTensorDataset([
        torch.Tensor(batch_list[idx_i][:, :-1]).float(),
        trg_label - trg_label.min(),
    ])

    dist = DatasetDistance(
        source_ds,
        trg_ds,
        inner_ot_method="exact",
        inner_ot_debiased=True,
        inner_ot_entreg=1e-2,
        entreg=1e-2,
        device="cpu",
        λ_y=λ_y,
    )
    # w_inner_outer[idx_i-1] = dist.distance()
    # The distance itself is discarded; the call is made before reading
    # dist.pwlabel_stats below.
    _ = dist.distance()
    w2_matrix = dist.pwlabel_stats["dlabs"] / 2

    mapped_feat1 = batch_list[0][:, :-1]
    mapped_feat2 = pf_soft_label_list[idx_i - 1][0]
    mapped_labels1 = batch_list[0][:, -1]
    mapped_probs2 = pf_soft_label_list[idx_i - 1][1]

    # Compute each term once and reuse it for both the stored cost and the log.
    feature_term = mse_loss(mapped_feat1, mapped_feat2)
    label_term = label_cost(w2_matrix, mapped_labels1, mapped_probs2)
    w_inner_outer[idx_i - 1] = feature_term + λ_y * label_term
    print(feature_term, label_term)

    for idx_j in range(1, num_targets + 1):
        if idx_i == idx_j:
            continue

        another_trg_label = torch.Tensor(batch_list[idx_j][:, -1]).long()
        another_trg_ds = CustomTensorDataset([
            torch.Tensor(batch_list[idx_j][:, :-1]).float(),
            another_trg_label - another_trg_label.min(),
        ])

        dist = DatasetDistance(
            trg_ds,
            another_trg_ds,
            inner_ot_method="exact",
            inner_ot_debiased=True,
            inner_ot_entreg=1e-2,
            entreg=1e-2,
            device="cpu",
            λ_y=λ_y,  # same weight as the outer distance, no separate literal
        )
        _ = dist.distance()
        w2_matrix = dist.pwlabel_stats["dlabs"] / 2

        mapped_feat1 = pf_soft_label_list[idx_i - 1][0]
        mapped_feat2 = pf_soft_label_list[idx_j - 1][0]
        mapped_labels1 = pf_hard_label_list[idx_i - 1][1]
        mapped_probs2 = pf_soft_label_list[idx_j - 1][1]
        w_inner_matrix[idx_i - 1, idx_j - 1] = (
            mse_loss(mapped_feat1, mapped_feat2)
            + λ_y * label_cost(w2_matrix, mapped_labels1, mapped_probs2)
        )

In [None]:
# Show the pairwise cost matrix and the source-to-pushforward cost vector.
print(w_inner_matrix, w_inner_outer)
from src.transfer_learning.gen_geodesic import get_best_interp_param
# Weights returned by get_best_interp_param for those costs (presumably the
# interpolation weights that minimise the distance to the source; TODO confirm
# against the helper's implementation).
best_weight=get_best_interp_param(
                w_inner_outer, w_inner_matrix
            )
print(best_weight)

In [None]:
# from src.scripts.suff2insuff import identity_transform
from src.transfer_learning.mix_transformation import barycenteric_map_mix
from copy import deepcopy
from torch.distributions.categorical import Categorical

def identity_transform(*args):
    """Return the positional arguments unchanged, packed as a tuple."""
    unchanged = args
    return unchanged

# Barycentric mapping directly from tuple data: three hand-picked weight
# vectors followed by the computed best_weight.
weights = [
    [0.5, 0, 0.5],
    [0, 0.5, 0.5],
    [0.65, 0.35, 0],
    list(best_weight),
]
# weights=[list(best_weight)]
total_data = deepcopy(batch_list)

# Sampled labels are offset by the number of source classes.
label_offset = num_class_list[0]
for weight in weights:
    mix_feat, mix_probs = barycenteric_map_mix(
        pf_soft_label_list,
        np.array(weight),
        num_class_list[1:],
        "cpu",
        identity_transform,
    )
    catg = Categorical(probs=mix_probs)
    sampled_labels = catg.sample().reshape(-1, 1)
    pf_labels = sampled_labels + label_offset
    # pf_labels=mix_probs.argmax(axis=-1)+num_class_list[0]
    pf_batch = feat_label_concat(mix_feat, pf_labels)
    total_data.append(pf_batch)

In [None]:
# Plot settings for the combined 3D figure; x_rotate / z_rotate are used as the
# elevation and azimuth of the view by plt_scatter_3d_alone.
parameter = Axis_Params_3d(
    left_place=-9, right_place=6, 
    figsize=(10,10), opacity=0.5, scatter_size=5, x_rotate=25, z_rotate=10)

# Draw every batch in total_data on one 3D axes with fixed axis bounds and save
# the result to "gen_geodesic.png".
plt_scatter_3d_alone(total_data, parameter, "gen_geodesic.png", num_total_class, x_bound=(-14,2), y_bound=(-9,6), z_bound=(-1, 10.5))

In [None]:
from src.transfer_learning.mix_transformation import mixup
# Panels 2 and 3 of the figure: the last batch of total_data, then the mixup
# result built below.
projection_data = [total_data[-1]]

# mixup input: the first two feature columns of every non-source batch paired
# with labels re-based to start at zero.
tuple_batch = []
for trg_batch in batch_list[1:]:
    trg_labels = trg_batch[:, -1]
    tuple_batch.append((trg_batch[:, :2], (trg_labels - trg_labels.min()).long()))
mix_feat, mix_probs = mixup(tuple_batch, best_weight, num_class_list[1:], "cpu", identity_transform)

catg = Categorical(probs=mix_probs)
pf_labels = catg.sample().reshape(-1, 1) + num_class_list[0]
pf_batch = feat_label_concat(mix_feat, pf_labels)
projection_data.append(pf_batch)

fig, ax = plt.subplots(nrows=1, ncols=3, figsize=(10, 3), facecolor="w")
cmap = plt.get_cmap("jet")

# Panel 1: the source batch, coloured by its own label range.
source_batch = batch_list[0]
source_labels = source_batch[:, -1]
ax[0].scatter(
    source_batch[:, 0],
    source_batch[:, 1],
    color=cmap(source_labels / (source_labels.max() + 1)),
    alpha=0.5,
)
x_lims = (-3, -0.5)
y_lims = (-2, 0.9)
ax[0].set_xlim(x_lims)
ax[0].set_ylim(y_lims)
# ax[0].set_yticks([])
# ax[0].axis('off')

for idx_batch, batch in enumerate(projection_data):
    # Shuffle the rows before plotting.
    random_idxes = torch.randperm(batch.shape[0])
    batch = batch[random_idxes]
    panel = ax[idx_batch + 1]
    panel.scatter(
        batch[:, 0],
        batch[:, 1],
        color=cmap(batch[:, -1] / num_total_class),
        alpha=0.5,
    )
    panel.set_yticks([])

    panel.set_xlim(x_lims)
    panel.set_ylim(y_lims)
    # ax[idx_batch+1].axis('off')
plt.subplots_adjust(wspace=0.2, hspace=0)
fig.savefig("mixup_vs_projection.png", bbox_inches="tight", dpi=200)

In [None]:
import numpy as np
import matplotlib.pyplot as plt

# ---------- 2d plots ----------
def draw_data(
    data_list,
    fig_path,
    num_total_class,
    # plot_size=3,
):
    """Scatter the first two columns of every batch on one axes and save it.

    The rows of each batch are shuffled with ``torch.randperm`` before
    plotting. The last column, divided by ``num_total_class``, is mapped
    through the "jet" colormap. The axes frame is hidden and the figure is
    written to ``fig_path`` at 200 dpi.
    """
    fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(13, 9), facecolor="w")
    jet = plt.get_cmap("jet")
    for batch in data_list:
        shuffled = batch[torch.randperm(batch.shape[0])]
        point_colors = jet(shuffled[:, -1] / num_total_class)
        ax.scatter(
            shuffled[:, 0],
            shuffled[:, 1],
            color=point_colors,
            alpha=0.1,
        )
    # lims = (-plot_size, plot_size)
    # ax.set_xlim(lims)
    # ax.set_ylim(lims)
    ax.axis('off')
    fig.savefig(fig_path, bbox_inches="tight", dpi=200)

In [None]:
# 2D overview of every batch in total_data.
# NOTE(review): this writes to "mixup_vs_projection.png", the same path the
# three-panel mixup/projection figure is saved to earlier in this file, so
# running both overwrites that figure. Confirm the intended file name.
draw_data(
    total_data,
    "mixup_vs_projection.png",
    num_total_class,
)

# Digits

In [None]:
# PIL is otherwise only imported in a later cell of this file; importing it
# here lets this cell run on a fresh top-to-bottom execution.
from PIL import Image

# Stack the saved sample grids vertically: the source images first, then one
# row per pushforward.
prefix = "../logs/test/E2X_few_shot/EMNIST/"
postfix = "_195000.png"
file_list = ["FMNIST/source", "FMNIST/pushforward", "MNIST/pushforward", "USPS/pushforward", "KMNIST/pushforward"]
n_col = len(file_list)  # used as the number of subplot rows
fig, axes = plt.subplots(n_col, 1, figsize=(13, 2 * n_col), dpi=200, facecolor="w")
for i, (file_name, ax) in enumerate(zip(file_list, axes.flatten())):
    # The context manager closes the file once imshow has read the pixels.
    with Image.open(prefix + file_name + postfix) as im:
        ax.imshow(im)
    if i == 0:
        title = '$Q$'
    else:
        title = '$\\mathcal{T}_' + str(i) + '\\sharp Q$'
    ax.set_ylabel(title, fontsize=20, rotation='horizontal', va="center", labelpad=45,)  # , color='limegreen')

    # Only the row label stays visible.
    ax.get_xaxis().set_visible(False)
    ax.set_yticks([])

fig.tight_layout(pad=0.01)
fig.savefig("EMNIST", bbox_inches="tight")

# Ternary plots

In [None]:
from src.transfer_learning.gen_geodesic import ternary_otdd_interpolation

num_segment = 6
# train_datasets_list[k] holds the three training datasets paired with
# test_dataset_list[k].
test_dataset_list = ["MNIST", "EMNIST", "KMNIST", "USPS", "FMNIST",]
train_datasets_list = [
    ["EMNIST", "FMNIST", "USPS"],
    ["MNIST", "FMNIST", "USPS"],
    ["MNIST", "EMNIST", "USPS"],
    ["MNIST", "EMNIST", "KMNIST"],
    ["MNIST", "KMNIST", "USPS"],
    ]
# Running extrema over all test datasets. Starting from +/-inf means the first
# dataset always initialises them, whatever the scale of the distances.
min_distance = float("inf")
max_distance = float("-inf")
range_distance = 0

for test_dataset, train_datasets in zip(test_dataset_list, train_datasets_list):
    # Alias built from the first letter of each training dataset name.
    target_alias = "".join(ds[0] for ds in train_datasets)
    otdd_path = f"../logs/otdd_ternary_transport_metric/external_{test_dataset}/from_{test_dataset}2{target_alias}_repeat5.pth"
    otdd_stat = torch.load(otdd_path)
    w2_vector_dict = otdd_stat["W(nu,mu_i)"]
    w2_matrix_dict = otdd_stat["W(mu_i,mu_j)"]

    # Average the stored entries before interpolating.
    avg_w2_matrix = sum(w2_matrix_dict.values()) / len(w2_matrix_dict)
    avg_w2_vector = sum(w2_vector_dict.values()) / len(w2_vector_dict)

    otdd_distance = ternary_otdd_interpolation(
        avg_w2_vector, avg_w2_matrix, num_segment
    )

    local_min_dist = min(otdd_distance.values())
    local_max_dist = max(otdd_distance.values())
    range_distance = max(range_distance, local_max_dist - local_min_dist)
    min_distance = min(min_distance, local_min_dist)
    max_distance = max(max_distance, local_max_dist)
print(min_distance, max_distance, range_distance)

In [None]:
# Running extrema of the loaded accuracies over all test datasets, plus the
# widest per-dataset spread.
min_accuracy = 100
max_accuracy = 0
range_accuracy = 0

method = "otdd_map"
for test_dataset, train_datasets in zip(test_dataset_list, train_datasets_list):
    target_alias = "".join(ds[0] for ds in train_datasets)
    accuracy_save_path = f"../logs/generalized_geodesic/fine_tune_{test_dataset}/{method}/run/train_on_{target_alias}_epoch100_repeat5.pt"
    accuracy = torch.load(accuracy_save_path)

    highest = max(accuracy.values())
    lowest = min(accuracy.values())
    range_accuracy = max(range_accuracy, highest - lowest)
    min_accuracy = min(min_accuracy, lowest)
    max_accuracy = max(max_accuracy, highest)
print(min_accuracy, max_accuracy, range_accuracy)

In [None]:
import ternary
import matplotlib

matplotlib.rcParams["figure.dpi"] = 200


def draw_ternary_heatmap(accuracies, num_segment, fig_path, train_ds, title=None, vmin=0, vmax=100, figsize=(4,4), coloarbar=True, cmap='viridis'):
    """Draw ``accuracies`` as a ternary heatmap and save it to ``fig_path``.

    The three corners are labelled with ``train_ds[0]`` (right),
    ``train_ds[1]`` (top) and ``train_ds[2]`` (left). ``vmin`` / ``vmax`` and
    ``cmap`` are forwarded to the heatmap, and ``coloarbar`` (sic) toggles the
    horizontal colour bar.

    Side effect: ``matplotlib.rcParams["figure.figsize"]`` is overwritten with
    ``figsize`` and is not restored afterwards.
    """
    corner_font_size = 10
    # Must be set before the figure is created below.
    matplotlib.rcParams["figure.figsize"] = figsize

    _, tax = ternary.figure(scale=num_segment)
    # print(accuracies)
    colorbar_style = dict(shrink=0.8, pad=0.01, aspect=30, orientation="horizontal")

    tax.heatmap(
        accuracies,
        style="t",
        colorbar=coloarbar,
        cmap=cmap,
        cb_kwargs=colorbar_style,
        vmin=vmin,
        vmax=vmax,
    )
    tax.boundary()

    right_name, top_name, left_name = train_ds[0], train_ds[1], train_ds[2]
    tax.right_corner_label(right_name, fontsize=corner_font_size, position=(1.1, 0.1, 0))
    tax.top_corner_label(top_name, fontsize=corner_font_size)
    tax.left_corner_label(left_name, fontsize=corner_font_size, position=(-0.2, 0.1, 0))

    if title is not None:
        tax.set_title(title, y=1.2, pad=-14)

    tax.clear_matplotlib_ticks()
    tax.get_axes().axis("off")
    # tax.set_axis_limits({'b': [67, 76], 'l': [24, 33], 'r': [0, 9]})

    tax.savefig(fig_path, facecolor="w", bbox_inches="tight")

In [None]:
from src.transfer_learning.gen_geodesic import ternary_otdd_interpolation

# One accuracy heatmap per test dataset. Every heatmap gets a colour window of
# the same width (range_accuracy), centred on that dataset's own values.
for test_dataset, train_datasets in zip(test_dataset_list, train_datasets_list):
    target_alias = "".join(ds[0] for ds in train_datasets)
    accuracy_save_path = f"../logs/generalized_geodesic/fine_tune_{test_dataset}/{method}/run/train_on_{target_alias}_epoch100_repeat5.pt"
    accuracy = torch.load(accuracy_save_path)

    save_path = f"ternary/{test_dataset}_acc.png"
    local_min_acc = min(accuracy.values())
    local_max_acc = max(accuracy.values())
    local_range = local_max_acc - local_min_acc
    # Split the missing width evenly below and above the local range.
    half_pad = (range_accuracy - local_range) * 0.5
    vmin = local_min_acc - half_pad
    vmax = local_max_acc + half_pad
    figsize = (2, 2.2)

    # The former "last dataset" branch was identical to the regular one, so a
    # single call covers every dataset.
    draw_ternary_heatmap(
        accuracy,
        num_segment,
        save_path,
        train_datasets,
        vmin=vmin,
        vmax=vmax,
        figsize=figsize,
    )

In [None]:
from src.transfer_learning.gen_geodesic import ternary_otdd_interpolation

# Same figure size as the accuracy heatmaps. Defined here so this cell does
# not depend on a variable left behind by the accuracy loop.
figsize = (2, 2.2)

# One OTDD heatmap per test dataset.
for test_dataset, train_datasets in zip(test_dataset_list, train_datasets_list):
    target_alias = "".join(ds[0] for ds in train_datasets)
    otdd_path = f"../logs/otdd_ternary_transport_metric/external_{test_dataset}/from_{test_dataset}2{target_alias}_repeat5.pth"
    otdd_stat = torch.load(otdd_path)
    w2_vector_dict = otdd_stat["W(nu,mu_i)"]
    w2_matrix_dict = otdd_stat["W(mu_i,mu_j)"]

    # Interpolate every stored entry separately, then average the results per
    # key (rather than averaging the inputs first and interpolating once).
    otdd_dist_dict = [
        ternary_otdd_interpolation(w2_vector_dict[key], w2_matrix_dict[key], num_segment)
        for key in w2_matrix_dict
    ]
    otdd_distance = {
        key: sum(single_run[key] for single_run in otdd_dist_dict) / len(otdd_dist_dict)
        for key in otdd_dist_dict[0]
    }

    # avg_w2_matrix = sum(w2_matrix_dict.values()) / len(w2_matrix_dict)
    # avg_w2_vector = sum(w2_vector_dict.values()) / len(w2_vector_dict)

    # otdd_distance = ternary_otdd_interpolation(
    #     avg_w2_vector, avg_w2_matrix, num_segment
    # )

    local_min_dist = min(otdd_distance.values())
    local_max_dist = max(otdd_distance.values())
    local_range = local_max_dist - local_min_dist
    # NOTE(review): vmin/vmax describe a window of width range_distance
    # centred on this dataset's values, but the heatmap below is drawn with
    # the global min_distance/max_distance instead. Confirm which colour range
    # is intended.
    vmin = local_min_dist - (range_distance - local_range) * 0.5
    vmax = local_max_dist + (range_distance - local_range) * 0.5

    save_path = f"ternary/{test_dataset}_otdd.png"

    # The former "last dataset" branch was identical to the regular one, so a
    # single call covers every dataset.
    draw_ternary_heatmap(
        otdd_distance,
        num_segment,
        save_path,
        train_datasets,
        vmin=min_distance,
        vmax=max_distance,
        figsize=figsize,
        cmap='viridis_r',
    )

In [None]:
from PIL import Image
# Assemble the saved heatmaps into a 2 x 5 grid: OTDD on the top row and
# accuracy on the bottom row, one column per dataset.
prefix = "ternary/"
postfix = ".png"
dataset_list = ["MNIST", "EMNIST", "USPS", "FMNIST", "KMNIST"]
# MNIST, EMNIST, USPS, FMNIST, KMNIST
fig, axes = plt.subplots(2, 5, figsize=(14, 4.5), dpi=200, facecolor="w")
fontsize = 15

for i, (ds_name) in enumerate(dataset_list):
    for row, suffix in enumerate(("_otdd", "_acc")):
        im = Image.open(prefix + ds_name + suffix + postfix)
        axes[row, i].imshow(im)
    axes[1, i].set_xlabel(ds_name, fontsize=fontsize, rotation='horizontal', va="center", labelpad=15)

for ax in axes.flatten():
    # ax.get_xaxis().set_visible(False)
    ax.set_xticks([])
    ax.set_yticks([])
    # ax.axis('off')
    for side in ('top', 'right', 'bottom', 'left'):
        ax.spines[side].set_visible(False)

label_pad = 10
ylabel_style = dict(fontsize=fontsize, rotation='vertical', va="center", labelpad=label_pad)

# axes[0,0].set_ylabel(r'$\mathcal{W}_{2,Q}^2(P_a,Q)$', fontsize=20, rotation='vertical', va="center", labelpad=label_pad,)
axes[0, 0].set_ylabel(r'$\mathcal{W}^2(P_a,Q)$', **ylabel_style)
axes[1, 0].set_ylabel("Test on      Accuracy", **ylabel_style)
axes[1, 0].yaxis.set_label_coords(-0.05, 0.25)

plt.subplots_adjust(wspace=0.1, hspace=0)
fig.tight_layout(pad=0.01)
fig.savefig("ternary", bbox_inches="tight")