In [None]:
"""
Normalized mutual information
"""
import tensorflow as tf
import numpy as np
from GeneralTools.graph_func import MySession

num_x = 3000  # number of samples x
num_c = 3  # number of clusters c

# The samples are split evenly over the clusters; the tiling below only yields
# num_x rows (the normaliser used for every marginal) when this holds.
assert num_x % num_c == 0, 'num_x must be a multiple of num_c'

# Hard (one-hot) cluster assignment cycling through the clusters.
# The repetition count is derived from num_x instead of being hard-coded.
p_c_on_x = tf.tile(np.eye(num_c, dtype=np.float32), [num_x // num_c, 1])  # num_x-by-num_c
# small offset keeps every probability strictly positive so the logs stay finite
p_c_on_x = p_c_on_x + 1e-6
# p_c_on_x = tf.random_uniform((num_x, 3), minval=1e-4, maxval=1.0)
p_c_on_x = p_c_on_x / tf.reduce_sum(p_c_on_x, axis=1, keepdims=True)
# p_y_on_x = tf.random_uniform((num_x, 2), minval=1e-4, maxval=1.0)
# p_y_on_x = p_y_on_x / tf.reduce_sum(p_y_on_x, axis=1, keepdims=True)
# labels coincide with clusters here, so the NMI is expected to be close to 1
p_y_on_x = p_c_on_x

# marginals and entropies, using a uniform p(x) = 1 / num_x
p_y = tf.reduce_sum(p_y_on_x, axis=0, keepdims=True) / num_x  # 1-by-num_y
h_y = -tf.reduce_sum(p_y * tf.math.log(p_y))
p_c = tf.reduce_sum(p_c_on_x, axis=0) / num_x  # vector of length num_c (no keepdims)
h_c = -tf.reduce_sum(p_c * tf.math.log(p_c))
# Bayes rule: p(x|y) = p(y|x) p(x) / p(y)
p_x_on_y = p_y_on_x / num_x / p_y  # num_x-by-num_y
# p(c|y) = sum_x p(c|x) p(x|y)
p_c_on_y = tf.matmul(p_c_on_x, p_x_on_y, transpose_a=True)  # num_c-by-num_y
# H(C|Y) = -sum_y p(y) sum_c p(c|y) log p(c|y)
h_c_on_y = -tf.reduce_sum(tf.reduce_sum(p_c_on_y * tf.math.log(p_c_on_y), axis=0) * p_y)
# I(Y;C) = H(C) - H(C|Y), normalised by the mean of the two entropies
i_y_c = h_c - h_c_on_y
nmi = 2 * i_y_c / (h_y + h_c)

with MySession() as sess:
    # the tensor names are rebound to whatever run_once returns
    # (presumably the evaluated values - confirm in GeneralTools.graph_func)
    nmi, h_y, h_c, h_c_on_y = sess.run_once([nmi, h_y, h_c, h_c_on_y])
    
print(nmi)
print([h_y, h_c, h_c_on_y])

In [None]:
"""
Illustration of MMD loss
"""
import numpy as np
import matplotlib.pyplot as plt

do_plot = True

# data: two independent clouds of 6 two-dimensional points each, stored
# column-wise (rows are coordinates, columns are points).
# A dedicated, seeded generator makes the illustration reproducible without
# touching NumPy's global random state.
data_rng = np.random.RandomState(0)
xr = 0.15*data_rng.randn(2, 6)  # 2-by-N
xg = 0.15*data_rng.randn(2, 6)  # 2-by-N


In [85]:
# plot the initial positions of both point sets on one set of axes
if do_plot:
    fig, ax = plt.subplots(1, 1, figsize=(6.4, 4.8))
    # R is drawn first and G second so the legend entries line up with them
    for pts, color in ((xr, 'tab:gray'), (xg, 'tab:red')):
        ax.scatter(
            pts[0], pts[1], marker='o', c=color,
            s=40, linewidths=20, alpha=0.5)
    ax.legend(
        ['R', 'G'], frameon=True, fontsize=15, labelspacing=1, borderpad=0.5)
    plt.show()

In [141]:
import tensorflow as tf
from GeneralTools.math_func import matrix_mean_wo_diagonal
from GeneralTools.graph_func import MySession

# optimisation schedule
max_step, query_step = 450, 30  # total updates / updates between two reports
lr = 1e-2  # step size of each update

# switches
mmd_plan = False  # True: full MMD objective; False: reduced objective
do_plot, do_save = True, False  # draw figures / write them to `folder` instead of showing

# output directory for saved frames
folder = '/home/richard/Pictures/Tensorflow Screenshots/animation_mmd/'


def pd2(m1, m2):
    """ Squared pair-wise Euclidean distance between the columns of m1 and m2.

    Uses the expansion ||a - b||^2 = ||a||^2 + ||b||^2 - 2 a.b and clips the
    result to [0, 10000]; the lower bound removes the small negative values
    that this expansion can produce through round-off.

    :param m1: 2-by-N1
    :param m2: 2-by-N2
    :return: N1-by-N2 tensor, entry (i, j) is the squared distance between
        column i of m1 and column j of m2
    """
    sq_norm_1 = tf.reduce_sum(m1 * m1, axis=0, keepdims=True)  # 1-by-N1
    sq_norm_2 = tf.reduce_sum(m2 * m2, axis=0, keepdims=True)  # 1-by-N2
    cross = tf.matmul(m1, m2, transpose_a=True)  # N1-by-N2

    # column of N1 norms + row of N2 norms broadcast to N1-by-N2
    dist2 = tf.transpose(sq_norm_1) + sq_norm_2 - 2.0*cross
    return tf.clip_by_value(dist2, clip_value_min=0.0, clip_value_max=10000.0)


def kernel(m, sigma=1.0):
    """ Element-wise exponential kernel exp(-m / sigma).

    :param m: tensor of distances, e.g. the N1-by-N2 output of pd2
    :param sigma: scale dividing the distances before exponentiation
    :return: tensor with the same shape as m
    """
    scaled = -m/sigma
    return tf.exp(scaled)  # N1-by-N2


def e_kernel(m):
    """ Reduce a kernel matrix to a scalar with matrix_mean_wo_diagonal.

    The static shape of m is passed along as float32 row and column counts.
    NOTE(review): presumably the result is the mean of m with the diagonal
    left out - confirm in GeneralTools.math_func.

    :param m: kernel matrix with a fully defined static shape
    :return: output of matrix_mean_wo_diagonal
    """
    static_shape = m.get_shape().as_list()
    shape_f = tf.cast(static_shape, tf.float32)
    return matrix_mean_wo_diagonal(m, shape_f[0], shape_f[1])


# Move both point sets by plain gradient ascent on the selected objective and
# draw the per-point gradient contributions every `query_step` steps.
with tf.Graph().as_default():
    xr_tf = tf.Variable(xr, name='r', dtype=tf.float32)  # 2-by-N
    xg_tf = tf.Variable(xg, name='g', dtype=tf.float32)  # 2-by-N
    
    # kernel matrices within R, within G and between R and G
    d2r = pd2(xr_tf, xr_tf)
    d2g = pd2(xg_tf, xg_tf)
    d2rg = pd2(xr_tf, xg_tf)
    kr = kernel(d2r, sigma=1.0)
    kg = kernel(d2g, sigma=1.0)
    krg = kernel(d2rg, sigma=1.0)
    
    if mmd_plan:
        # full objective: within-set terms minus twice the cross term
        mmd_att = e_kernel(kr) + e_kernel(kg)
        mmd_rep = - 2*e_kernel(krg)
        
        gr_att = tf.gradients(mmd_att, xr_tf)[0]
        gg_att = tf.gradients(mmd_att, xg_tf)[0]
        gr_rep = tf.gradients(mmd_rep, xr_tf)[0]
        gg_rep = tf.gradients(mmd_rep, xg_tf)[0]
    else:
        # reduced objective: one term depends on G only, the other on R only
        mmd_att = e_kernel(kg)
        mmd_rep = - e_kernel(kr)
        
        # the missing gradients are zero; size them from the data instead of
        # hard-coding the number of points
        gr_att = tf.zeros(xr.shape)
        gg_att = tf.gradients(mmd_att, xg_tf)[0]
        gr_rep = tf.gradients(mmd_rep, xr_tf)[0]
        gg_rep = tf.zeros(xg.shape)
    
    mmd = mmd_att + mmd_rep
    gr = tf.gradients(mmd, xr_tf)[0]
    gg = tf.gradients(mmd, xg_tf)[0]
    # separate name so the Python float `lr` is not replaced by a tensor
    lr_tf = tf.constant(lr, dtype=tf.float32)
    # "+" : gradient ascent on mmd
    opr = tf.assign(xr_tf, xr_tf + gr*lr_tf)
    opg = tf.assign(xg_tf, xg_tf + gg*lr_tf)
    
    # Halved gradients used as arrow lengths. They are created once, outside
    # the loop, so that plotting does not add new ops to the graph each time.
    arrow_fetches = [gr_att/2.0, gg_att/2.0, gr_rep/2.0, gg_rep/2.0]
    init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())
    
    # the context manager closes the session even if a step or a plot fails
    with tf.Session() as sess:
        sess.run(init_op)
        
        fig_index = 0
        for step in range(max_step):
            # Fetch the outputs of the assign ops: they hold the updated
            # positions, which is the state the arrows below are evaluated at.
            # Reading the variables in the same run has no defined order
            # relative to the assignment.
            xr_np, xg_np, mmd_np = sess.run([opr, opg, mmd])
            
            if step % query_step == 0:
                print('step {}, mmd {}'.format(step, mmd_np))
                if do_plot:
                    fig, ax = plt.subplots(1, 1, figsize=(6.4, 4.8))
                    plt.scatter(
                        xr_np[0], xr_np[1], marker='o', c='tab:gray', 
                        s=40, linewidths=20, alpha=0.5)
                    plt.scatter(
                        xg_np[0], xg_np[1], marker='o', c='tab:red', 
                        s=40, linewidths=20, alpha=0.5)
                    ax.legend(
                        ['R', 'G'], frameon=True, fontsize=15, labelspacing=1, borderpad=0.5)
                    
                    gr_att_np, gg_att_np, gr_rep_np, gg_rep_np = sess.run(arrow_fetches)
                    # blue arrows: gradient of mmd_att; orange arrows: gradient of mmd_rep
                    for i in range(xr_np.shape[1]):
                        if mmd_plan:
                            plt.arrow(
                                xr_np[0, i], xr_np[1, i], gr_att_np[0, i], gr_att_np[1, i], 
                                color='tab:blue', width=0.01)
                        plt.arrow(
                            xr_np[0, i], xr_np[1, i], gr_rep_np[0, i], gr_rep_np[1, i], 
                            color='tab:orange', width=0.01)
                    for i in range(xg_np.shape[1]):
                        plt.arrow(
                            xg_np[0, i], xg_np[1, i], gg_att_np[0, i], gg_att_np[1, i], 
                            color='tab:blue', width=0.01)
                        if mmd_plan:
                            plt.arrow(
                                xg_np[0, i], xg_np[1, i], gg_rep_np[0, i], gg_rep_np[1, i], 
                                color='tab:orange', width=0.01)
                    
                    # fixed axis limits keep consecutive frames comparable
                    plt.axis([-1.0, 1.0, -1.0, 1.0])
                    
                    if do_save:
                        fig_index = fig_index+1
                        if mmd_plan:
                            figurename = 'mmd_att_{:03d}.png'.format(fig_index)
                        else:
                            figurename = 'mmd_rep_{:03d}.png'.format(fig_index)
                        plt.savefig(
                            folder + figurename, format='png', bbox_inches='tight')
                    else:
                        plt.show()
                    # release the frame right away so figures do not pile up
                    plt.close(fig)
    
    plt.close('all')

step 0, mmd -0.03920328617095947
step 30, mmd 0.025973975658416748


step 60, mmd 0.09224051237106323
step 90, mmd 0.16177546977996826


step 120, mmd 0.23458749055862427


step 150, mmd 0.3084554672241211
step 180, mmd 0.38008588552474976


step 210, mmd 0.446527898311615
step 240, mmd 0.5059857368469238


step 270, mmd 0.557883620262146
step 300, mmd 0.6025186777114868


step 330, mmd 0.6406514644622803
step 360, mmd 0.6732016801834106


step 390, mmd 0.7010701894760132
step 420, mmd 0.725059986114502


In [142]:
# schedule of the second animation: more updates, reported more often
max_step, query_step = 1200, 10
lr = 1e-2  # step size of each update

# frames are written to disk this time
do_save = True
folder = '/home/richard/Pictures/Tensorflow Screenshots/animation_mmd/'

# Continue from the positions left by the previous run (xr_np, xg_np) and move
# only G by gradient descent on the full objective; R is never updated.
with tf.Graph().as_default():
    xr_tf = tf.Variable(xr_np, name='r', dtype=tf.float32)
    xg_tf = tf.Variable(xg_np, name='g', dtype=tf.float32)
    
    # kernel matrices within R, within G and between R and G
    d2r = pd2(xr_tf, xr_tf)
    d2g = pd2(xg_tf, xg_tf)
    d2rg = pd2(xr_tf, xg_tf)
    kr = kernel(d2r, sigma=1.0)
    kg = kernel(d2g, sigma=1.0)
    krg = kernel(d2rg, sigma=1.0)
    
    # within-set terms minus twice the cross term
    mmd_att = e_kernel(kr) + e_kernel(kg)
    mmd_rep = - 2*e_kernel(krg)
    
    gg_att = tf.gradients(mmd_att, xg_tf)[0]
    gg_rep = tf.gradients(mmd_rep, xg_tf)[0]
    
    mmd = mmd_att + mmd_rep
    gg = tf.gradients(mmd, xg_tf)[0]
    # separate name so the Python float `lr` is not replaced by a tensor
    lr_tf = tf.constant(lr, dtype=tf.float32)
    # "-" : gradient descent on mmd
    opg = tf.assign(xg_tf, xg_tf - gg*lr_tf)
    
    # Halved gradients used as arrow lengths. They are created once, outside
    # the loop, so that plotting does not add new ops to the graph each time.
    arrow_fetches = [gg_att/2.0, gg_rep/2.0]
    init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())
    
    # the context manager closes the session even if a step or a plot fails
    with tf.Session() as sess:
        sess.run(init_op)
        
        fig_index = 0
        for step in range(max_step):
            # Fetch the output of the assign op: it holds the updated G
            # positions, which is the state the arrows below are evaluated at.
            # Reading the variable in the same run has no defined order
            # relative to the assignment.
            xg_np2, mmd_np2, xr_np2 = sess.run([opg, mmd, xr_tf])
            
            if step % query_step == 0:
                print('step {}, mmd {}'.format(step, mmd_np2))
                # NOTE(review): `do_plot` is not set in this cell; it is
                # whatever an earlier cell left behind.
                if do_plot:
                    fig, ax = plt.subplots(1, 1, figsize=(6.4, 4.8))
                    plt.scatter(
                        xr_np2[0], xr_np2[1], marker='o', c='tab:gray', 
                        s=40, linewidths=20, alpha=0.5)
                    plt.scatter(
                        xg_np2[0], xg_np2[1], marker='o', c='tab:red', 
                        s=40, linewidths=20, alpha=0.5)
                    ax.legend(
                        ['R', 'G'], loc='upper left', frameon=True, fontsize=15, labelspacing=1, borderpad=0.5)
                    
                    gg_att_np, gg_rep_np = sess.run(arrow_fetches)
                    # arrows point along the negative gradients, i.e. the
                    # direction in which the descent step moves each point
                    for i in range(xg_np2.shape[1]):
                        plt.arrow(
                            xg_np2[0, i], xg_np2[1, i], -gg_att_np[0, i], -gg_att_np[1, i], 
                            color='tab:orange', width=0.01)
                        plt.arrow(
                            xg_np2[0, i], xg_np2[1, i], -gg_rep_np[0, i], -gg_rep_np[1, i], 
                            color='tab:blue', width=0.01)
                    
                    # fixed axis limits keep consecutive frames comparable
                    plt.axis([-1.0, 1.0, -1.0, 1.0])
                    
                    if do_save:
                        fig_index = fig_index+1
                        # NOTE(review): `mmd_plan` is not set in this cell, so
                        # the file prefix depends on the value left behind by
                        # an earlier cell although the objective here is fixed.
                        if mmd_plan:
                            figurename = 'g_mmd_att_{:03d}.png'.format(fig_index)
                        else:
                            figurename = 'g_mmd_rep_{:03d}.png'.format(fig_index)
                        plt.savefig(
                            folder + figurename, format='png', bbox_inches='tight')
                    else:
                        plt.show()
                    # release the frame right away; with 120 saved frames the
                    # open figures would otherwise accumulate until the end
                    plt.close(fig)
    
    plt.close('all')
    

step 0, mmd 0.2750566601753235


step 10, mmd 0.273040235042572
step 20, mmd 0.2708194851875305


step 30, mmd 0.26836615800857544
step 40, mmd 0.2656477689743042


step 50, mmd 0.2626284956932068


step 60, mmd 0.2592681646347046


step 70, mmd 0.25552308559417725
step 80, mmd 0.25134408473968506


step 90, mmd 0.24667847156524658
step 100, mmd 0.24146902561187744


step 110, mmd 0.23565423488616943


step 120, mmd 0.2291703224182129
step 130, mmd 0.22195005416870117


step 140, mmd 0.21392643451690674
step 150, mmd 0.2050337791442871


step 160, mmd 0.19520938396453857
step 170, mmd 0.18439805507659912


step 180, mmd 0.172554612159729
step 190, mmd 0.15964794158935547


step 200, mmd 0.14566540718078613
step 210, mmd 0.13061630725860596


step 220, mmd 0.11453545093536377
step 230, mmd 0.09748494625091553


step 240, mmd 0.07955563068389893
step 250, mmd 0.06086623668670654


step 260, mmd 0.04156064987182617
step 270, mmd 0.021802783012390137


step 280, mmd 0.001771688461303711
step 290, mmd -0.01834738254547119


step 300, mmd -0.038369596004486084
step 310, mmd -0.058117568492889404


step 320, mmd -0.07742863893508911
step 330, mmd -0.09616035223007202


step 340, mmd -0.11419302225112915
step 350, mmd -0.13143134117126465


step 360, mmd -0.14780592918395996
step 370, mmd -0.16327130794525146


step 380, mmd -0.1778031587600708
step 390, mmd -0.19139736890792847


step 400, mmd -0.20406442880630493
step 410, mmd -0.2158282995223999


step 420, mmd -0.22672319412231445
step 430, mmd -0.23678970336914062


step 440, mmd -0.24607348442077637
step 450, mmd -0.2546231746673584


step 460, mmd -0.26248806715011597
step 470, mmd -0.2697175145149231


step 480, mmd -0.27636003494262695
step 490, mmd -0.28246253728866577


step 500, mmd -0.288068950176239
step 510, mmd -0.293221116065979


step 520, mmd -0.29795849323272705
step 530, mmd -0.30231690406799316


step 540, mmd -0.3063305616378784
step 550, mmd -0.31002992391586304


step 560, mmd -0.31344348192214966
step 570, mmd -0.31659698486328125


step 580, mmd -0.31951409578323364
step 590, mmd -0.32221633195877075


step 600, mmd -0.3247230648994446
step 610, mmd -0.32705211639404297


step 620, mmd -0.3292192816734314
step 630, mmd -0.3312390446662903


step 640, mmd -0.33312493562698364
step 650, mmd -0.33488816022872925


step 660, mmd -0.33653998374938965
step 670, mmd -0.3380897641181946


step 680, mmd -0.3395463228225708
step 690, mmd -0.34091752767562866


step 700, mmd -0.34221047163009644
step 710, mmd -0.34343141317367554


step 720, mmd -0.34458643198013306


step 730, mmd -0.34568071365356445
step 740, mmd -0.3467187285423279


step 750, mmd -0.3477051258087158
step 760, mmd -0.34864336252212524


step 770, mmd -0.34953737258911133
step 780, mmd -0.3503899574279785


step 790, mmd -0.3512040972709656
step 800, mmd -0.3519825339317322


step 810, mmd -0.35272741317749023
step 820, mmd -0.3534409999847412


step 830, mmd -0.3541252017021179
step 840, mmd -0.3547818660736084


step 850, mmd -0.3554123044013977
step 860, mmd -0.3560183048248291


step 870, mmd -0.35660111904144287
step 880, mmd -0.3571620583534241


step 890, mmd -0.3577020764350891
step 900, mmd -0.35822224617004395


step 910, mmd -0.3587237596511841
step 920, mmd -0.3592072129249573


step 930, mmd -0.3596736192703247
step 940, mmd -0.3601234555244446


step 950, mmd -0.3605577349662781
step 960, mmd -0.3609771132469177


step 970, mmd -0.36138200759887695
step 980, mmd -0.3617730736732483


step 990, mmd -0.3621508479118347
step 1000, mmd -0.36251574754714966


step 1010, mmd -0.36286842823028564
step 1020, mmd -0.36320942640304565


step 1030, mmd -0.36353886127471924
step 1040, mmd -0.36385732889175415


step 1050, mmd -0.36416512727737427
step 1060, mmd -0.3644627332687378


step 1070, mmd -0.36475056409835815
step 1080, mmd -0.3650286793708801


step 1090, mmd -0.36529773473739624
step 1100, mmd -0.36555778980255127


step 1110, mmd -0.3658093810081482
step 1120, mmd -0.3660525679588318


step 1130, mmd -0.3662877082824707
step 1140, mmd -0.3665151000022888


step 1150, mmd -0.3667351007461548
step 1160, mmd -0.36694765090942383


step 1170, mmd -0.3671532869338989
step 1180, mmd -0.3673522472381592


step 1190, mmd -0.36754459142684937


In [1]:
"""
Test of MMD loss for classification
"""
import tensorflow as tf
from GeneralTools.math_func import multiclass_mmd_g
from GeneralTools.input_func import SimData


# build the input pipeline in its own graph, separate from the default one
graph = tf.Graph()
with graph.as_default():
    # data: 'shell2' data set served in batches of 64
    p = SimData('shell2', batch_size=64)
    x = p.next_batch()


[1.0086658, 1.0086659, 1.0086659]


In [5]:
import tensorflow as tf
import numpy as np
from GeneralTools.graph_func import opt_config, global_step_config
from GeneralTools.input_func import Spirals
from GeneralTools.layer_func import Net, Routine
from GeneralTools.math_func import multiclass_mmd_g, get_squared_dist

# optimisation settings
batch_size = 64
lr = 1e-3
max_step, query_step = 1000, 100  # total updates / updates between two reports

# four weights that sum to one, and four 2-D points, one per quadrant
probs = [0.1, 0.2, 0.3, 0.4]
mu = np.array(
    [[3.0, 3.0],
     [3.0, -3.0],
     [-3.0, -3.0],
     [-3.0, 3.0]],
    dtype=np.float32)

# Train a small dense network on spiral data by maximising the multi-class
# MMD of its 2-D output codes; the loss is printed every `query_step` steps.
with tf.Graph().as_default():
    # input data
    data_distribution = Spirals(4, 2.5, scale=0.25, sigma=1.0)
    samples, labels = data_distribution.sample(batch_size)
    
    # network: two hidden dense layers of 10 units and a linear 2-unit output
    architecture = [{'name': 'l1', 'out': 10, 'op': 'd', 'act': 'lrelu'}, 
                    {'name': 'l2', 'out': 10, 'op': 'd', 'act': 'lrelu'}, 
                    {'name': 'l3', 'out': 2, 'op': 'd', 'act': 'linear'}]
    net = Net(architecture, net_name='classifier')
    model = Routine(net)
    # Use the configured batch size rather than repeating the literal 64.
    # NOTE(review): presumably the argument is the input shape
    # [batch, features] - confirm against Routine.add_input_layers.
    model.add_input_layers([batch_size, 2])
    model.seq_links()
    model.add_output_layers()
    
    # apply network
    data_batch = {'x': samples, 'y': labels}
    out = model(data_batch, is_training=True)
    codes = out['x']
    
    # calculate loss
    dist_zz = get_squared_dist(codes, mode='xx')
    mmd = multiclass_mmd_g(dist_zz, labels, batch_size, num_class=4, sigma=2.0)
    
    # define optimization process; the MMD is maximised by minimising -mmd
    global_step = global_step_config()
    _, opt_op = opt_config(lr, optimizer='adam')
    op = opt_op.minimize(-mmd, global_step)
    
    # the initialiser is part of the graph, so build it before opening the session
    init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())
    
    with tf.Session() as sess:
        sess.run(init_op)
    
        for step in range(max_step):
            # z and y keep the codes and labels of the latest batch
            _, z, y, loss = sess.run([op, codes, labels, mmd])
            
            if step % query_step == 0:
                print('mmd: {}'.format(loss))


Adam Optimizer is used.


mmd: 1.588876485824585
mmd: 7.4848127365112305


mmd: 8.526334762573242
mmd: 8.378314971923828
mmd: 9.896876335144043


mmd: 9.749398231506348


mmd: 9.138945579528809


mmd: 10.18325424194336


mmd: 10.5354585647583
mmd: 10.443592071533203


In [7]:
import matplotlib.pyplot as plt

fig, ax = plt.subplots(1, 1, figsize=(6.4, 4.8))
# the ten 'tab:' colours of matplotlib; entry k is used for label k
colors = [
    'tab:blue', 'tab:orange', 'tab:green', 'tab:red', 'tab:purple',
    'tab:brown', 'tab:pink', 'tab:gray', 'tab:olive', 'tab:cyan']
# every row of y holds a single label, which selects the colour of that code
z_colors = [colors[label[0]] for label in y]
ax.scatter(z[:, 0], z[:, 1], c=z_colors)
plt.show()


In [11]:
import tensorflow as tf
import numpy as np
from GeneralTools.input_func import Spirals
from GeneralTools.graph_func import MySession
import matplotlib.pyplot as plt

n_points = 4000
data = Spirals(4, 2.5, scale=0.25, sigma=1.0)
x, y = data.sample(n_points)

# evaluate the sampling ops; x and y are rebound to the fetched results
with MySession() as sess:
    x, y = sess.run([x, y])

# drop the singleton label axis so y can be used as a boolean row mask
y = np.squeeze(y, axis=1)
fig, ax = plt.subplots(1, 1, figsize=(6.4, 4.8))
# one dot series per label 0..3, named 'class 1' .. 'class 4'
for class_id in range(4):
    members = y == class_id
    ax.plot(x[members, 0], x[members, 1], '.', label='class {}'.format(class_id + 1))
ax.legend()
plt.show()


Graph initialization finished...
No ckpt model is loaded for current calculation.
Session finished.
