In [1]:
%matplotlib inline

import matplotlib.pyplot as plt
import tensorflow as tf
import numpy as np
from itertools import product
from math import ceil
from utils import *

from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets('MNIST_data')


Extracting MNIST_data/train-images-idx3-ubyte.gz
Extracting MNIST_data/train-labels-idx1-ubyte.gz
Extracting MNIST_data/t10k-images-idx3-ubyte.gz
Extracting MNIST_data/t10k-labels-idx1-ubyte.gz


In [2]:
class MNISTEncoder(object):
    def __init__(self, emb_dim, sess):
        self.emb_dim = emb_dim
        self.sess = sess
        
        self.inputs = tf.placeholder("float32", [None, 28*28])
        self.labels = tf.placeholder("bool", [None])
        
        self._build_model()

    def _build_model(self):
        # Convolutional encoder
        x_image = tf.reshape(self.inputs, [-1,28,28,1])

        self.W_conv1 = weight_variable([5, 5, 1, 32])
        self.b_conv1 = bias_variable([32])
        self.h_conv1 = tf.nn.relu(conv2d(x_image, self.W_conv1) + self.b_conv1)
        self.h_pool1 = max_pool_2x2(self.h_conv1)

        self.W_conv2 = weight_variable([5, 5, 32, 64])
        self.b_conv2 = bias_variable([64])

        self.h_conv2 = tf.nn.relu(conv2d(self.h_pool1, self.W_conv2) + self.b_conv2)
        self.h_pool2 = max_pool_2x2(self.h_conv2)

        self.W_fc1 = weight_variable([7 * 7 * 64, self.emb_dim])
        self.b_fc1 = bias_variable([self.emb_dim])

        h_pool2_flat = tf.reshape(self.h_pool2, [-1, 7*7*64])
        self.emb = tf.matmul(h_pool2_flat, self.W_fc1) + self.b_fc1

        # L2 normalize
        self.norm_emb = tf.nn.l2_normalize(self.emb, 1)

    def get_norm_embedding(self, batch):
        return self.sess.run(self.norm_emb, feed_dict={self.inputs: batch})        
        
    def get_embedding(self, batch):
        return self.sess.run(self.emb, feed_dict={self.inputs: batch})

In [10]:
def magnet_loss(r, m, d, alpha=1.0):
    """Compute magnet loss for batch.
    
    Given a batch of features r consisting of m batches
    each with d assigned examples and a cluster separation
    gap of alpha, compute the total magnet loss and the per
    example losses.
    
    Args:
        r: A batch of features.
        m: The number of clusters in the batch.
        d: The number of examples in each cluster.
        alpha: The cluster separation gap hyperparameter.
        
    Returns:
        total_loss: The total magnet loss for the batch.
        losses: The loss for each example in the batch.
    
    """

    # Take cluster means within the batch
    cluster_means = tf.reduce_mean(tf.reshape(r, [m, d, -1]), 1)

    # Compute squared differences of each example to each cluster centroid
    sample_cluster_pair_inds = np.array(list(product(range(m*d), range(m))))
    sample_costs = tf.squared_difference(
        tf.gather(r, sample_cluster_pair_inds[:,0]),
        tf.gather(cluster_means, sample_cluster_pair_inds[:,1]))

    # Sum to compute squared distances of each example to each cluster centroid
    # and reshape such that tensor is indexed by
    # [true cluster, comparison cluster, example in true cluster]
    sample_costs = tf.reshape(tf.reduce_sum(sample_costs, 1), [m, d, m])
    sample_costs = tf.transpose(sample_costs, [0, 2, 1])

    # Select distances of examples to their own centroid
    same_cluster_inds = np.vstack(np.diag_indices(m)).T
    intra_cluster_costs = tf.gather_nd(sample_costs, same_cluster_inds)

    # Select distances of examples to other centroids and reshape such that
    # tensor is indexed by [true cluster, other cluster, example in true cluster]
    cluster_inds = np.arange(m)
    diff_cluster_inds = np.vstack(
        [np.repeat(cluster_inds, m-1), 
         np.hstack([cluster_inds[cluster_inds != i] for i in range(m)])]).T
    inter_cluster_costs = tf.reshape(tf.gather_nd(sample_costs, diff_cluster_inds), [m, m-1, d])

    # Compute variance of intra-cluster squared distances
    variance = tf.reduce_sum(intra_cluster_costs) / (m * d - 1)
    var_normalizer = -1 / 2*variance**2

    # Compute numerator and denominator of inner term
    numerator = tf.exp(var_normalizer * intra_cluster_costs - alpha)
    denominator = tf.reduce_sum(tf.exp(var_normalizer * inter_cluster_costs), 1)

    # Compute example losses and total loss
    losses = tf.nn.relu(-tf.log(numerator / denominator))
    total_loss = tf.reduce_mean(losses)
    
    return total_loss, losses




# sess = tf.InteractiveSession()

# m = 12
# d = 4
# K = 5
# alpha = 1.0

# r = tf.placeholder(tf.float32, [None, 8])
# magnet_loss, losses = magnet_loss(r, m, d, alpha)

# feed_dict = {r: np.random.random([m*d, 8])}
# print sess.run(magnet_loss, feed_dict=feed_dict)

# sess.close()
# tf.reset_default_graph()

In [11]:
from sklearn.cluster import KMeans

class ClusterBatchBuilder(object):
    def __init__(self, labels, k, m, d):
        
        self.num_classes = np.unique(labels).shape[0]
        self.labels = labels

        self.k = k
        self.m = m
        self.d = d

        self.centroids = None
        self.assignments = np.zeros_like(labels, int)
        self.cluster_assignments = {}
        self.example_losses = np.zeros_like(labels, float)
        self.cluster_losses = np.zeros([self.k * self.num_classes], float)

    
    def update_clusters(self, rep_data, max_iter=20):
        """
        Given an array of representations for the entire training set,
        recompute clusters and store example cluster assignments in a
        quickly sampleable form.
        """
        # Lazily allocate array for centroids
        if self.centroids is None:
            self.centroids = np.zeros([self.num_classes * self.m, rep_data.shape[1]])
        
        for c in range(self.num_classes):

            class_mask = self.labels == c
            class_examples = rep_data[class_mask]
            kmeans = KMeans(n_clusters=self.k, init='k-means++', n_init=1, max_iter=max_iter)
            kmeans.fit(class_examples)

            # Save cluster centroids for finding impostor clusters
            start = self.get_cluster_ind(c, 0)
            stop = self.get_cluster_ind(c, self.k)
            self.centroids[start:stop] = kmeans.cluster_centers_

            # Update assignments with new global cluster indexes
            self.assignments[class_mask] = self.get_cluster_ind(c, kmeans.predict(class_examples))
            
        # Construct a map from cluster to example indexes for fast batch creation
        for c in range(self.k * self.num_classes):
            cluster_mask = self.assignments == c
            self.cluster_assignments[c] = np.flatnonzero(cluster_mask)

        
    def update_losses(self, indexes, losses):
        """
        Given a list of examples indexes and corresponding losses
        store the new losses and update corresponding cluster losses.
        """
        
        # Update example losses
        self.example_losses[indexes] = losses

        # Find affected clusters and update the corresponding cluster losses
        clusters = np.unique(self.assignments[indexes])
        for c in clusters:
            cluster_example_losses = self.example_losses[self.assignments == c]
            self.cluster_losses[c] = np.mean(cluster_example_losses)

        
    def gen_batch(self):
        """
        Sample a batch by first sampling a seed cluster proportionally to
        the mean loss of the clusters, then finding nearest neighbor
        "impostor" clusters, then sampling d examples uniformly from each cluster.
        
        The generated batch will consist of m clusters each with d consecutive
        examples.
        """
        
        # Sample seed cluster proportionally to cluster losses
        p = self.cluster_losses / np.sum(self.cluster_losses)
        seed_cluster = np.random.choice(self.num_classes * self.k, p=p)

        # Get imposter clusters by ranking centroids by distance
        # The seed cluster itself is guaranteed to be included
        sq_dists = ((self.centroids[seed_cluster] - self.centroids) ** 2).sum(axis=1)
        clusters = np.argpartition(sq_dists, self.m)[:self.m]
        
        # Sample examples uniformly from cluster
        batch_indexes = np.empty([self.m * self.d], int)
        for i, c in enumerate(clusters):
            x = np.random.choice(self.cluster_assignments[c], self.d, replace=False)

            start = i * self.d
            stop = start + self.d
            batch_indexes[start:stop] = x

        return batch_indexes

    
    def get_cluster_ind(self, c, i):
        """
        Given a class index and a cluster index within the class
        return the global cluster index
        """
        return c * self.k + i
    

# rep = np.random.random([mnist.train.images.shape[0], 8])
# batch_builder = ClusterBatchBuilder(mnist.train.labels, 5, 12, 4)
# batch_builder.update_clusters(rep)
# batch_builder.update_losses(range(mnist.train.labels.shape[0]), np.random.random([mnist.train.labels.shape[0]]))
# batch_builder.gen_batch()

In [12]:
# Define magnet loss parameters
m = 12
d = 4
k = 5
alpha = 1.0

# Define model and training parameters
emb_dim = 2
num_steps = 100
cluster_interval = 10

# Define training data
X = mnist.train.images
y = mnist.train.labels


sess = tf.InteractiveSession()

# Model
with tf.variable_scope('model'):
    model = MNISTEncoder(emb_dim, sess)

# Loss
with tf.variable_scope('magnet_loss'):
    train_loss, losses = magnet_loss(model.emb, m, d, alpha)

train_op = tf.train.AdamOptimizer(1e-4).minimize(train_loss)


# Get initial embedding and losses
chunks = 100
examples_per_chunk = ceil(float(X.shape[0]), chunks)
reps = []
example_losses = []
for i in range(chunks):
    start = i * examples_per_chunk
    stop = start + examples_per_chunk
    chunk_reps, chunk_losses = sess.run([model.emb, losses], feed_dict={model.inputs:X[start:stop]})
    reps.append(chunk_reps)
    example_losses.append(chunk_losses)
    
initial_reps = np.vstack(reps)
example_losses = np.vstack(example_losses)

print initial_reps.shape
print example_losses.shape

# # Create batcher
# batch_builder = ClusterBatchBuilder(mnist.train.labels, 5, 12, 4)
# batch_builder.update_clusters(rep)


# for i in range(num_steps):
    
sess.close()
tf.reset_default_graph()

NotImplementedError: Gradient for gather_nd is not implemented.

In [None]:
sess.close()
tf.reset_default_graph()