In [None]:
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from matplotlib import cm
from matplotlib import gridspec
import matplotlib.pyplot as plt
import matplotlib.tri as mtri
from tqdm import tqdm_notebook
import odr
import odr.ed_em_golf as ed
from odr.model_operational import OperationalNetwork
from odr.data_handler import Dataset

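## Generate or load data

Run the first cell below to generate 100k samples and save them as `multi_enc_noisy_corr_100k`, or skip it and run the second cell to load a previously saved copy.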

In [None]:
# Generate 100k samples (multi_enc=True, correlated=True, noise_level=0.1) and
# write them to disk under the name the load cell below reads back.
n_samples = 100000
data = ed.create_dataset(n_samples, multi_enc=True, correlated=True, noise_level=0.1)
data.save('multi_enc_noisy_corr_100k')

In [None]:
# Load a dataset previously written with `data.save(...)`.
dataset_name = 'multi_enc_noisy_corr_100k'
data = Dataset.load(dataset_name)

In [None]:
# Split into training (td) and validation (vd) sets used by the training cell.
# NOTE(review): 0.05 is presumably the validation fraction — confirm in Dataset.train_val_separation.
split_fraction = 0.05
td, vd = data.train_val_separation(split_fraction)

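## Train model

Up to 100 freshly initialized models are tried. Each is pretrained briefly and discarded if its `cost_nloc` on the validation set exceeds 0.1; otherwise it is trained to completion and saved as `multi_enc_corr_<i>`.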

In [None]:
# Hyperparameters shared by every training phase.
gamma = 1e-3
reg_loss = 1e-7
batch_size = 256      # NOTE(review): 2nd positional arg of model.train — presumably the batch size
learning_rate = 1e-3  # NOTE(review): 3rd positional arg of model.train — presumably the learning rate
max_attempts = 100    # number of fresh initializations to try
nloc_threshold = 0.1  # restart if model.cost_nloc on vd exceeds this after the first pretraining phase

# Keyword arguments common to all model.train calls; pretraining adds nloc_factor/pretrain.
common_train_kwargs = dict(test_step=10, reg_loss_factor=reg_loss, gamma=gamma,
                           progress_bar=tqdm_notebook)
pretrain_kwargs = dict(common_train_kwargs, nloc_factor=5., pretrain=True)

for i in range(max_attempts):
    # NOTE(review): saved as 'multi_enc_corr_<i>'; the later load cell reads 'multi_enc_corr'.
    model = OperationalNetwork(encoder_num=2, decoder_num=4, input_sizes=[20, 20], latent_sizes=[2, 2],
                               question_sizes=[1, 1, 1, 1], answer_sizes=[1, 1, 1, 1],
                               encoder_num_units=[200, 200], name='multi_enc_corr_{}'.format(i))
    # Pretraining, short first phase: abandon this initialization if cost_nloc is still too high.
    model.train(100, batch_size, learning_rate, td, vd, **pretrain_kwargs)
    if model.run(vd, model.cost_nloc) > nloc_threshold:
        continue
    # Finish pretraining and checkpoint.
    model.train(400, batch_size, learning_rate, td, vd, **pretrain_kwargs)
    model.save(model.name)

    # Train with unconstrained selection neurons
    model.train(4500, batch_size, learning_rate, td, vd, **common_train_kwargs)
    model.save(model.name)
    break
else:
    # Without this, exhausting all attempts would pass silently and leave `model`
    # bound to an unsaved, rejected initialization.
    print('No initialization reached cost_nloc <= {} after pretraining in {} attempts; '
          'no model was saved.'.format(nloc_threshold, max_attempts))

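## Load pre-trained model and plot latent layer

Note: the next cell replaces the `model` from the training cell with the network saved as `multi_enc_corr`. The training loop itself saves under `multi_enc_corr_<i>`.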

In [None]:
# Replace any model trained above with the network saved as 'multi_enc_corr'.
# NOTE(review): the training cell saves under 'multi_enc_corr_<i>', not this name.
saved_model_name = 'multi_enc_corr'
model = OperationalNetwork.from_saved(saved_model_name)

In [None]:
def get_triangles(x_data, y_data, threshold=5.):
    """Triangulate scattered 2-D points and mask triangles with a large perimeter.

    ``matplotlib.tri.Triangulation`` builds a Delaunay triangulation of the
    points. Every triangle whose perimeter exceeds ``threshold`` is then
    masked. Presumably this keeps ``plot_trisurf`` from drawing large faces
    that bridge sparse regions of the point cloud.

    Parameters
    ----------
    x_data, y_data : array_like
        1-D point coordinates of equal length.
    threshold : float, optional
        Maximum allowed triangle perimeter, in the units of the coordinates.

    Returns
    -------
    matplotlib.tri.Triangulation
        The triangulation, with its mask set on triangles whose perimeter is
        greater than ``threshold``.
    """
    x_data = np.asarray(x_data)
    y_data = np.asarray(y_data)
    tri = mtri.Triangulation(x_data, y_data)
    # Shape (n_triangles, 3, 2): the (x, y) coordinates of each triangle's corners.
    vertices = np.stack([x_data[tri.triangles], y_data[tri.triangles]], axis=-1)
    # Vectorized side lengths replace the former per-triangle Python loop.
    side_01 = np.linalg.norm(vertices[:, 0] - vertices[:, 1], axis=-1)
    side_02 = np.linalg.norm(vertices[:, 0] - vertices[:, 2], axis=-1)
    side_12 = np.linalg.norm(vertices[:, 1] - vertices[:, 2], axis=-1)
    perimeter = side_01 + side_02 + side_12
    tri.set_mask(perimeter > threshold)
    return tri


def plot_multi_enc(model, mass_data, charge_data, charge_zmin=None):
    """Plot each of the four latent neurons as a 3-D surface over the hidden parameters.

    Top row: latent activations on ``mass_data`` plotted against
    ``hidden_states`` columns 0 and 1 (labelled m_1, m_2).
    Bottom row: latent activations on ``charge_data`` plotted against
    ``hidden_states`` columns 2 and 3 (labelled q_1, q_2).

    Parameters
    ----------
    model : OperationalNetwork
        Network whose ``run(data, model.full_latent)`` output is indexed as a
        2-D array with at least 4 latent columns (TODO confirm for other configs).
    mass_data, charge_data : Dataset
        Evaluation datasets, each with a ``hidden_states`` array of >= 4 columns.
    charge_zmin : dict, optional
        Maps a latent index to the lower z-limit of its bottom-row plot. The
        z-window height for those plots equals the height of the shared charge
        range. Defaults to ``{1: 0.5, 2: -0.7}``. Indices not in the dict use
        the shared charge range directly.

    Returns
    -------
    matplotlib.figure.Figure
        The 2x4 figure.
    """
    if charge_zmin is None:
        # Hand-tuned defaults: latents 1 and 2 are shown in a window of the same
        # height as the shared range, but shifted to these lower limits.
        charge_zmin = {1: 0.5, 2: -0.7}

    latent_mass = model.run(mass_data, model.full_latent)
    latent_charge = model.run(charge_data, model.full_latent)
    z_lim_mass = [np.min(latent_mass), np.max(latent_mass)]
    # The shared charge colour/z range is derived from latents 0 and 3 only.
    z_lim_charge = [np.min(latent_charge[:, [0, 3]]), np.max(latent_charge[:, [0, 3]])]

    # The triangulations depend only on the datasets, so build each once
    # instead of once per latent neuron.
    tri_mass = get_triangles(mass_data.hidden_states[:, 0], mass_data.hidden_states[:, 1])
    tri_charge = get_triangles(charge_data.hidden_states[:, 2], charge_data.hidden_states[:, 3],
                               threshold=1.)

    fig = plt.figure(figsize=(12, 7))
    gs = gridspec.GridSpec(2, 4)
    for latent_index in range(4):
        ax_mass = fig.add_subplot(gs[latent_index], projection='3d')
        ax_charge = fig.add_subplot(gs[latent_index + 4], projection='3d')

        # Plot dependent on mass, fixed charge
        ax_mass.plot_trisurf(tri_mass, latent_mass[:, latent_index], cmap=cm.inferno,
                             vmin=z_lim_mass[0], vmax=z_lim_mass[1])
        ax_mass.set_xlabel(r'$m_1$')
        ax_mass.set_ylabel(r'$m_2$')
        ax_mass.set_zlim(z_lim_mass)
        ax_mass.azim = -45
        ax_mass.set_xticks([0, 5, 10])
        ax_mass.set_yticks([0, 5, 10])

        # Plot dependent on charge, fixed mass
        if latent_index in charge_zmin:
            zmin = charge_zmin[latent_index]
            zmax = zmin + z_lim_charge[1] - z_lim_charge[0]
        else:
            zmin, zmax = z_lim_charge
        ax_charge.plot_trisurf(tri_charge, latent_charge[:, latent_index],
                               cmap=cm.inferno, vmin=z_lim_charge[0], vmax=z_lim_charge[1])
        ax_charge.set_xlabel(r'$q_1$')
        ax_charge.set_ylabel(r'$q_2$')
        ax_charge.set_zlim([zmin, zmax])
        ax_charge.set_xticks([-1, 0, 1])
        ax_charge.set_yticks([-1, 0, 1])

    fig.tight_layout()
    return fig

In [None]:
%matplotlib tk
mass_data = ed.create_dataset(1000, multi_enc=True, charge_range=[.5, .5], mass_range=[1, 10])
charge_data = ed.create_dataset(1000, multi_enc=True, mass_range=[1.5, 1.5])
fig = plot_multi_enc(model, mass_data, charge_data)