In [1]:
%matplotlib inline

import matplotlib.pyplot as plt
import tensorflow as tf
import numpy as np
from itertools import product
from math import ceil
from utils import *

from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets('MNIST_data')


Extracting MNIST_data/train-images-idx3-ubyte.gz
Extracting MNIST_data/train-labels-idx1-ubyte.gz
Extracting MNIST_data/t10k-images-idx3-ubyte.gz
Extracting MNIST_data/t10k-labels-idx1-ubyte.gz


In [2]:
class MNISTEncoder(object):
    def __init__(self, emb_dim, sess):
        self.emb_dim = emb_dim
        self.sess = sess
        
        self.inputs = tf.placeholder("float32", [None, 28*28])
        self.labels = tf.placeholder("bool", [None])
        
        self._build_model()

    def _build_model(self):
        # Convolutional encoder
        x_image = tf.reshape(self.inputs, [-1,28,28,1])

        self.W_conv1 = weight_variable([5, 5, 1, 32])
        self.b_conv1 = bias_variable([32])
        self.h_conv1 = tf.nn.relu(conv2d(x_image, self.W_conv1) + self.b_conv1)
        self.h_pool1 = max_pool_2x2(self.h_conv1)

        self.W_conv2 = weight_variable([5, 5, 32, 64])
        self.b_conv2 = bias_variable([64])

        self.h_conv2 = tf.nn.relu(conv2d(self.h_pool1, self.W_conv2) + self.b_conv2)
        self.h_pool2 = max_pool_2x2(self.h_conv2)

        self.W_fc1 = weight_variable([7 * 7 * 64, self.emb_dim])
        self.b_fc1 = bias_variable([self.emb_dim])

        h_pool2_flat = tf.reshape(self.h_pool2, [-1, 7*7*64])
        self.emb = tf.matmul(h_pool2_flat, self.W_fc1) + self.b_fc1

        # L2 normalize
        self.norm_emb = tf.nn.l2_normalize(self.emb, 1)

    def get_norm_embedding(self, batch):
        return self.sess.run(self.norm_emb, feed_dict={self.inputs: batch})        
        
    def get_embedding(self, batch):
        return self.sess.run(self.emb, feed_dict={self.inputs: batch})

In [None]:
def magnet_loss_old(r, m, d, alpha=1.0):
    """Compute magnet loss for batch.
    
    Given a batch of features r consisting of m batches
    each with d assigned examples and a cluster separation
    gap of alpha, compute the total magnet loss and the per
    example losses.
    
    Args:
        r: A batch of features.
        m: The number of clusters in the batch.
        d: The number of examples in each cluster.
        alpha: The cluster separation gap hyperparameter.
        
    Returns:
        total_loss: The total magnet loss for the batch.
        losses: The loss for each example in the batch.
    
    """

    # Take cluster means within the batch
    cluster_means = tf.reduce_mean(tf.reshape(r, [m, d, -1]), 1)

    # Compute squared differences of each example to each cluster centroid
    sample_cluster_pair_inds = np.array(list(product(range(m*d), range(m))))
    sample_costs = tf.squared_difference(
        tf.gather(r, sample_cluster_pair_inds[:,0]),
        tf.gather(cluster_means, sample_cluster_pair_inds[:,1]))

    # Sum to compute squared distances of each example to each cluster centroid
    # and reshape such that tensor is indexed by
    # [true cluster, comparison cluster, example in true cluster]
    sample_costs = tf.reshape(tf.reduce_sum(sample_costs, 1), [m, d, m])
    sample_costs = tf.transpose(sample_costs, [0, 2, 1])

    # Select distances of examples to their own centroid
    same_cluster_inds = np.vstack(np.diag_indices(m)).T
    intra_cluster_costs = tf.gather_nd(sample_costs, same_cluster_inds)

    # Select distances of examples to other centroids and reshape such that
    # tensor is indexed by [true cluster, comparison cluster, example]
    cluster_inds = np.arange(m)
    diff_cluster_inds = np.vstack(
        [np.repeat(cluster_inds, m-1), 
         np.hstack([cluster_inds[cluster_inds != i] for i in range(m)])]).T
    inter_cluster_costs = tf.reshape(tf.gather_nd(sample_costs, diff_cluster_inds), [m, m-1, d])

    # Compute variance of intra-cluster squared distances
    variance = tf.reduce_sum(intra_cluster_costs) / (m * d - 1)
    var_normalizer = -1 / 2*variance**2

    # Compute numerator and denominator of inner term
    numerator = tf.exp(var_normalizer * intra_cluster_costs - alpha)
    denominator = tf.reduce_sum(tf.exp(var_normalizer * inter_cluster_costs), 1)

    # Compute example losses and total loss
    losses = tf.nn.relu(-tf.log(numerator / denominator))
    total_loss = tf.reduce_mean(losses)
    
    return total_loss, losses


def magnet_loss(r, m, d, alpha=1.0):
    """Compute magnet loss for batch.
    
    Given a batch of features r consisting of m batches
    each with d assigned examples and a cluster separation
    gap of alpha, compute the total magnet loss and the per
    example losses.
    
    Args:
        r: A batch of features.
        m: The number of clusters in the batch.
        d: The number of examples in each cluster.
        alpha: The cluster separation gap hyperparameter.
        
    Returns:
        total_loss: The total magnet loss for the batch.
        losses: The loss for each example in the batch.
    
    """
    
    # Helper to compute indexes to select intra- and inter-cluster
    # distances
    def compute_comparison_inds():
        same_cluster_inds = []
        for i in range(m*d):
            c = i / d
            same_cluster_inds.append(c*d*m + c*d + (i % d))
        diff_cluster_inds = sorted(set(range(m*m*d)) - set(same_cluster_inds))
        
        return same_cluster_inds, diff_cluster_inds

        
    # Take cluster means within the batch
    cluster_means = tf.reduce_mean(tf.reshape(r, [m, d, -1]), 1)

    # Compute squared distance of each example to each cluster centroid
    sample_cluster_pair_inds = np.array(list(product(range(m), range(m*d))))
    sample_costs = tf.squared_difference(
        tf.gather(cluster_means, sample_cluster_pair_inds[:,0]),
        tf.gather(r, sample_cluster_pair_inds[:,1]))
    sample_costs = tf.reduce_sum(sample_costs, 1)
    
    # Compute intra- and inter-cluster comparison indexes
    same_cluster_inds, diff_cluster_inds = compute_comparison_inds()
    
    # Select distances of examples to their own centroid
    intra_cluster_costs = tf.gather(sample_costs, same_cluster_inds)
    intra_cluster_costs = tf.reshape(intra_cluster_costs, [m, d])
    
    # Select distances of examples to other centroids
    inter_cluster_costs = tf.reshape(tf.gather(sample_costs, diff_cluster_inds), [m, m-1, d])

    # Compute variance of intra-cluster squared distances
    variance = tf.reduce_sum(intra_cluster_costs) / (m * d - 1)
    var_normalizer = -1 / 2*variance**2

    # Compute numerator and denominator of inner term
    numerator = tf.exp(var_normalizer * intra_cluster_costs - alpha)
    denominator = tf.reduce_sum(tf.exp(var_normalizer * inter_cluster_costs), 1)

    # Compute example losses and total loss
    losses = tf.nn.relu(-tf.log(numerator / denominator))
    total_loss = tf.reduce_mean(losses)
    
    return total_loss, losses


    
# sess = tf.InteractiveSession()

# m = 3
# d = 2

# m = 6
# d = 4

# K = 5
# alpha = 15.0

# r = tf.placeholder(tf.float32, [None, 8])
# magnet_loss1, losses1 = magnet_loss_old(r, m, d, alpha)
# magnet_loss2, losses2 = magnet_loss(r, m, d, alpha)


# # Helper to generate debug data
# def gen_data(m, d):
#     data = []
#     for c in range(m):
#         a = (c + 1) * 3
#         centroid = np.random.random([1, 8]) * a
#         data.append(centroid + np.random.random([d, 8]))
#     return np.vstack(data)


# # feed_dict = {r: np.random.random([m*d, 8])}
# feed_dict = {r: gen_data(m, d)}

# print sess.run([magnet_loss1, magnet_loss2], feed_dict=feed_dict)


# sess.close()
# tf.reset_default_graph()

In [15]:

def magnet_loss(r, c, m, d, alpha=1.0):
    """Compute magnet loss for batch.
    
    Given a batch of features r consisting of m batches
    each with d assigned examples and a cluster separation
    gap of alpha, compute the total magnet loss and the per
    example losses.
    
    Args:
        r: A batch of features.
        c: Class labels for each example.
        m: The number of clusters in the batch.
        d: The number of examples in each cluster.
        alpha: The cluster separation gap hyperparameter.
        
    Returns:
        total_loss: The total magnet loss for the batch.
        losses: The loss for each example in the batch.
    
    """
    
    # Helper to compute indexes to select intra- and inter-cluster
    # distances
    def compute_comparison_inds():
        same_cluster_inds = []
        for i in range(m*d):
            c = i / d
            same_cluster_inds.append(c*d*m + c*d + (i % d))
        diff_cluster_inds = sorted(set(range(m*m*d)) - set(same_cluster_inds))
        
        return same_cluster_inds, diff_cluster_inds

        
    # Take cluster means within the batch
    cluster_means = tf.reduce_mean(tf.reshape(r, [m, d, -1]), 1)

    # Compute squared distance of each example to each cluster centroid
    sample_cluster_pair_inds = np.array(list(product(range(m), range(m*d))))
    sample_costs = tf.squared_difference(
        tf.gather(cluster_means, sample_cluster_pair_inds[:,0]),
        tf.gather(r, sample_cluster_pair_inds[:,1]))
    sample_costs = tf.reduce_sum(sample_costs, 1)
    
    # Compute intra- and inter-cluster comparison indexes
    same_cluster_inds, diff_cluster_inds = compute_comparison_inds()
    
    # Select distances of examples to their own centroid
    intra_cluster_costs = tf.gather(sample_costs, same_cluster_inds)
    intra_cluster_costs = tf.reshape(intra_cluster_costs, [m, d])
    
    
    
    # Select distances of examples to other centroids
    inter_cluster_costs = tf.reshape(tf.gather(sample_costs, diff_cluster_inds), [m, m-1, d])
    
    
    # Select distances of examples to other class centroids
    cluster_classes = tf.strided_slice(c, [0], [m*d], [d])
    print cluster_classes
    

    # Compute variance of intra-cluster squared distances
    variance = tf.reduce_sum(intra_cluster_costs) / (m * d - 1)
    var_normalizer = -1 / 2*variance**2

    # Compute numerator and denominator of inner term
    numerator = tf.exp(var_normalizer * intra_cluster_costs - alpha)
    denominator = tf.reduce_sum(tf.exp(var_normalizer * inter_cluster_costs), 1)

    # Compute example losses and total loss
    losses = tf.nn.relu(-tf.log(numerator / denominator))
    total_loss = tf.reduce_mean(losses)
    
    return total_loss, losses



sess = tf.InteractiveSession()

m = 6
d = 4

K = 5
alpha = 15.0

r = tf.placeholder(tf.float32, [m*d, 8])
c = tf.placeholder(tf.int32, [m*d])
magnet_loss2, losses2, cc = magnet_loss(r, c, m, d, alpha)


# Helper to generate debug data
def gen_data(m, d):
    data = []
    for c in range(m):
        a = (c + 1) * 3
        centroid = np.random.random([1, 8]) * a
        data.append(centroid + np.random.random([d, 8]))
    return np.vstack(data)


# feed_dict = {r: np.random.random([m*d, 8])}
feed_dict = {r: gen_data(m, d), c: np.repeat(range(m), d)}

print sess.run([magnet_loss2, cc], feed_dict=feed_dict)


sess.close()
tf.reset_default_graph()

Tensor("StridedSlice:0", shape=(6,), dtype=int32)
[1.8370082, array([0, 1, 2, 3, 4, 5], dtype=int32)]


In [9]:
np.repeat(range(m), d)

array([0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4, 5, 5, 5,
       5])

In [None]:
from sklearn.cluster import KMeans

class ClusterBatchBuilder(object):
    def __init__(self, labels, k, m, d):
        
        self.num_classes = np.unique(labels).shape[0]
        self.labels = labels

        self.k = k
        self.m = m
        self.d = d

        self.centroids = None
        self.assignments = np.zeros_like(labels, int)
        self.cluster_assignments = {}
        self.example_losses = None
        self.cluster_losses = np.zeros([self.k * self.num_classes], float)

    
    def update_clusters(self, rep_data, max_iter=20):
        """
        Given an array of representations for the entire training set,
        recompute clusters and store example cluster assignments in a
        quickly sampleable form.
        """
        # Lazily allocate array for centroids
        if self.centroids is None:
            self.centroids = np.zeros([self.num_classes * self.k, rep_data.shape[1]])
        
        for c in range(self.num_classes):

            class_mask = self.labels == c
            class_examples = rep_data[class_mask]
            kmeans = KMeans(n_clusters=self.k, init='k-means++', n_init=1, max_iter=max_iter)
            kmeans.fit(class_examples)

            # Save cluster centroids for finding impostor clusters
            start = self.get_cluster_ind(c, 0)
            stop = self.get_cluster_ind(c, self.k)
            self.centroids[start:stop] = kmeans.cluster_centers_

            # Update assignments with new global cluster indexes
            self.assignments[class_mask] = self.get_cluster_ind(c, kmeans.predict(class_examples))
            
        # Construct a map from cluster to example indexes for fast batch creation
        for cluster in range(self.k * self.num_classes):
            cluster_mask = self.assignments == cluster
            self.cluster_assignments[cluster] = np.flatnonzero(cluster_mask)

        
    def update_losses(self, indexes, losses):
        """
        Given a list of examples indexes and corresponding losses
        store the new losses and update corresponding cluster losses.
        """
        
        # Update example losses
        self.example_losses[indexes] = losses

        # Find affected clusters and update the corresponding cluster losses
        clusters = np.unique(self.assignments[indexes])
        for cluster in clusters:
            cluster_example_losses = self.example_losses[self.assignments == cluster]
            self.cluster_losses[cluster] = np.mean(cluster_example_losses)

        
    def gen_batch(self):
        """
        Sample a batch by first sampling a seed cluster proportionally to
        the mean loss of the clusters, then finding nearest neighbor
        "impostor" clusters, then sampling d examples uniformly from each cluster.
        
        The generated batch will consist of m clusters each with d consecutive
        examples.
        """
        
        # Sample seed cluster proportionally to cluster losses if available
        if self.cluster_losses is not None:
            p = self.cluster_losses / np.sum(self.cluster_losses)
            seed_cluster = np.random.choice(self.num_classes * self.k, p=p)
        else:
            seed_cluster = np.random.choice(self.num_classes * self.k)

        # Get imposter clusters by ranking centroids by distance
        # The seed cluster itself is guaranteed to be included
        sq_dists = ((self.centroids[seed_cluster] - self.centroids) ** 2).sum(axis=1)
        clusters = np.argpartition(sq_dists, self.m)[:self.m]
        
        # Sample examples uniformly from cluster
        batch_indexes = np.empty([self.m * self.d], int)
        for i, c in enumerate(clusters):
            x = np.random.choice(self.cluster_assignments[c], self.d, replace=False)

            start = i * self.d
            stop = start + self.d
            batch_indexes[start:stop] = x

        return batch_indexes

    
    def get_cluster_ind(self, c, i):
        """
        Given a class index and a cluster index within the class
        return the global cluster index
        """
        return c * self.k + i
    

# rep = np.random.random([mnist.train.images.shape[0], 8])
# batch_builder = ClusterBatchBuilder(mnist.train.labels, 5, 12, 4)
# batch_builder.update_clusters(rep)
# batch_builder.update_losses(range(mnist.train.labels.shape[0]), np.random.random([mnist.train.labels.shape[0]]))
# batch_builder.gen_batch()

In [None]:
# Define magnet loss parameters
m = 6
d = 4
k = 3
alpha = 1.0
batch_size = m * d

# Define training data
X = mnist.train.images
y = mnist.train.labels

# Define model and training parameters
epoch_steps = int(ceil(float(X.shape[0]) / batch_size)) 
emb_dim = 2
# num_steps = 5 * epoch_steps
# cluster_refresh_interval = epoch_steps
num_steps = epoch_steps
cluster_refresh_interval = 20



sess = tf.InteractiveSession()

# Model
with tf.variable_scope('model'):
    model = MNISTEncoder(emb_dim, sess)

# Loss
with tf.variable_scope('magnet_loss'):
    train_loss, losses = magnet_loss(model.emb, m, d, alpha)

train_op = tf.train.AdamOptimizer(1e-4).minimize(train_loss)

sess.run(tf.initialize_all_variables())



def compute_reps(extract_fn, X, chunk_size):
    chunks = int(ceil(float(X.shape[0]) / chunk_size))
    reps = []
    for i in range(chunks):
        start = i * chunk_size
        stop = start + chunk_size
        chunk_reps = extract_fn(X[start:stop])
        reps.append(chunk_reps)
    return np.vstack(reps)

# Get initial embedding
extract = lambda x: sess.run(model.emb, feed_dict={model.inputs: x})
initial_reps = compute_reps(extract, X, 400)

print initial_reps.shape


# # Create batcher
batch_builder = ClusterBatchBuilder(mnist.train.labels, k, m, d)
batch_builder.update_clusters(initial_reps)


for i in range(num_steps):
    
    batch_inds = batch_builder.gen_batch()
    _, batch_loss = sess.run([train_op, train_loss], feed_dict={model.inputs: X[batch_inds]})
    
    if not i % 100:
        print i, batch_loss
    
    if not i % cluster_refresh_interval:
        print 'Refreshing clusters'
        reps = compute_reps(extract, X, 400)
        print reps[0]
        print np.any(np.isnan(reps))
        batch_builder.update_clusters(reps)
        
final_reps = compute_reps(extract, X, 400)
    
sess.close()
tf.reset_default_graph()

In [None]:
reps

In [None]:
num_plot = 500
imgs = mnist.train.images[:num_plot]
imgs = np.reshape(imgs, [num_plot, 28, 28])
plot_embedding(initial_reps[:num_plot], mnist.train.labels[:num_plot], imgs)

In [None]:
num_plot = 500
imgs = mnist.train.images[:num_plot]
imgs = np.reshape(imgs, [num_plot, 28, 28])
plot_embedding(final_reps[:num_plot], mnist.train.labels[:num_plot], imgs)

In [None]:
sess.close()
tf.reset_default_graph()