# MAML Regression

In [1]:
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
import tensorflow as tf

In [None]:
# Open a TensorFlow session bound explicitly to the current default graph.
sess = tf.Session(graph=tf.get_default_graph())

## Generate Sine Waves

In [None]:
# Preview a handful of random sine tasks: amplitude in [0.1, 5.0],
# phase in [0, pi], evaluated on a dense grid over [-5, 5].
num_preview_waves = 10
amp = np.random.uniform(0.1, 5.0, size=num_preview_waves)
phase = np.random.uniform(0, np.pi, size=num_preview_waves)

# The x grid is identical for every wave, so build it once outside the loop.
inputs = np.linspace(-5, 5, 10000)[:, None]
for wave_amp, wave_phase in zip(amp, phase):
    outputs = wave_amp * np.sin(inputs + wave_phase)
    plt.scatter(inputs[:, 0], outputs[:, 0], s=0.1, color='k', marker='o')
plt.show()

## MAML Regression Model

In [None]:
def _maml_dense(scope_name, inputs, num_units, learning_rate, grads, activation=None):
    """Build one fully connected layer whose parameters can be fast-adapted.

    The layer's variables are ``<scope_name>/fully_connected/weights`` and
    ``.../biases``; they are shared between calls through ``tf.AUTO_REUSE``.

    Args:
        scope_name: Variable scope that owns this layer's parameters.
        inputs: Rank-2 input tensor (required by ``tf.nn.xw_plus_b``).
        num_units: Number of output units.
        learning_rate: Inner-loop step size, used only when ``grads`` is given.
        grads: ``None`` to use the shared parameters directly, or a dict that
            maps variable name to gradient tensor. In that case the layer uses
            the one-step adapted parameters ``theta - learning_rate * grad``.
        activation: Optional callable applied to the affine output.

    Returns:
        The layer's output tensor.
    """
    with tf.variable_scope(scope_name, reuse=tf.AUTO_REUSE):
        # Called only to create (or reuse) the layer's variables; its own
        # output tensor is discarded and the affine map is rebuilt below.
        tf.contrib.layers.fully_connected(inputs, num_units)
        w = tf.get_variable('fully_connected/weights')
        b = tf.get_variable('fully_connected/biases')
        if grads is not None:
            # Inner-loop SGD step. The subtraction must wrap the whole update.
            # `w - 0 if cond else lr * g` parses as `(w - 0) if cond else
            # (lr * g)`, which replaces the weights by the scaled gradient.
            w = w - learning_rate * grads[w.name]
            b = b - learning_rate * grads[b.name]
        out = tf.nn.xw_plus_b(inputs, w, b)
        return out if activation is None else activation(out)


def create_model(in_ph, learning_rate=0.01, grads=None):
    """Build the in_features -> 40 -> 40 -> 1 ReLU regression network for MAML.

    Parameters live in the variable scopes 'fc1', 'fc2' and 'out' (relative
    to the current scope) and are reused across calls. The same weights are
    therefore shared by every forward pass built from this function.

    Args:
        in_ph: Rank-2 input tensor, e.g. a ``[None, 1]`` placeholder.
        learning_rate: Inner-loop step size applied when ``grads`` is given.
        grads: ``None`` for a forward pass with the shared parameters. Otherwise
            a dict mapping variable name to its gradient, and every layer uses
            ``theta - learning_rate * grads[theta.name]`` (one adaptation step).

    Returns:
        The ``[batch, 1]`` prediction tensor.
    """
    hidden_1 = _maml_dense('fc1', in_ph, 40, learning_rate, grads, tf.nn.relu)
    hidden_2 = _maml_dense('fc2', hidden_1, 40, learning_rate, grads, tf.nn.relu)
    return _maml_dense('out', hidden_2, 1, learning_rate, grads)

In [None]:
class MAMLTask:
    """Graph pieces for one task of the meta-batch.

    Each instance owns four placeholders: train/test inputs and targets.
    It also holds a pre-update MSE loss on the train split and a post-update
    MSE loss on the test split. The test split's network is built by
    ``create_model`` from the train-split gradients.
    """

    def __init__(self):
        # Column-shaped inputs and targets for the train (inner-loop) split
        # and the test (meta-loss) split.
        self.train_in_ph = tf.placeholder(
            dtype=tf.float32, name='task_train_input', shape=[None, 1])
        self.train_targets = tf.placeholder(
            dtype=tf.float32, name='task_train_targets', shape=[None, 1])
        self.test_in_ph = tf.placeholder(
            dtype=tf.float32, name='task_test_input', shape=[None, 1])
        self.test_targets = tf.placeholder(
            dtype=tf.float32, name='task_test_targets', shape=[None, 1])

        # Pre-update loss on the train split and its gradient with respect to
        # every trainable variable in the graph (lines 6 & 7 of Algorithm 2).
        # Variables that do not feed this loss get a None gradient.
        self.out_pred = create_model(self.train_in_ph)
        self.loss = tf.losses.mean_squared_error(self.train_targets, self.out_pred)
        params = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
        param_grads = tf.gradients(self.loss, params)
        grads = dict((param.name, grad) for param, grad in zip(params, param_grads))

        # Post-update prediction and loss on the test split. The gradients
        # above are handed to create_model to form the task-adapted network,
        # whose loss feeds the meta-objective (line 10 of Algorithm 2).
        self.task_out_pred = create_model(self.test_in_ph, grads=grads)
        self.task_loss = tf.losses.mean_squared_error(self.test_targets, self.task_out_pred)


In [None]:
class MAMLNetworkBuilder:
    """Complete MAML graph: a fine-tuning head plus one sub-graph per task.

    Args:
        num_tasks: Number of ``MAMLTask`` sub-graphs to build. The meta
            objective is the sum of their post-update test losses.

    Attributes:
        in_ph, targets: Placeholders for direct regression on the shared weights.
        out_pred, loss: Prediction and MSE loss of the shared network.
        adam: The single ``AdamOptimizer`` behind both ``opt`` and ``task_opt``.
        opt: Training op minimizing ``loss`` (used for fine-tuning).
        tasks: List of ``MAMLTask`` instances.
        task_loss: Sum of every task's ``task_loss``.
        task_opt: Training op minimizing ``task_loss`` (the meta update).
    """

    def __init__(self, num_tasks):
        # Forward pass through the shared parameters, plus a regression loss
        # that can be minimized directly to fine-tune on a single task.
        self.in_ph = tf.placeholder(dtype=tf.float32, shape=[None, 1])
        self.out_pred = create_model(self.in_ph)
        self.targets = tf.placeholder(
            dtype=tf.float32, name='tune_targets', shape=[None, 1])
        self.loss = tf.losses.mean_squared_error(self.targets, self.out_pred)

        # One optimizer object is shared by both training ops.
        self.adam = tf.train.AdamOptimizer()
        self.opt = self.adam.minimize(self.loss)

        # Meta-objective: summed per-task post-update losses (line 10 of Algorithm 2).
        self.tasks = [MAMLTask() for _task_index in range(num_tasks)]
        per_task_losses = [task.task_loss for task in self.tasks]
        self.task_loss = tf.reduce_sum(per_task_losses)
        self.task_opt = self.adam.minimize(self.task_loss)
        self.task_opt = self.adam.minimize(self.task_loss)


In [None]:
meta_iters = 75_000   # number of meta (outer-loop) updates
K = 10                # points per task in each of the train and test splits
meta_batch_size = 25  # tasks sampled per meta update
# Build the meta-training graph with one task sub-graph per meta-batch slot.
mnb = MAMLNetworkBuilder(meta_batch_size)

In [None]:
# Release any session opened earlier in the notebook before replacing it.
# Rebinding `sess` alone would leak the old session's resources.
if 'sess' in globals():
    sess.close()
# Fresh session on the fully built graph, with every global variable initialised.
sess = tf.Session()
sess.run(tf.global_variables_initializer())

In [None]:
# Meta-training loop (Algorithm 2). Each iteration does three things:
#   - samples `meta_batch_size` sine tasks,
#   - feeds K train and K test points per task,
#   - runs `mnb.task_opt` once on the summed post-update test loss.
outputs = []  # raw sess.run results, one list per meta-iteration

num_tasks = len(mnb.tasks)
if num_tasks != meta_batch_size:
    raise ValueError(
        f'mnb has {num_tasks} task graphs but meta_batch_size={meta_batch_size}')

# The fetch list is identical every iteration, so build it once.
# Layout of each result (T = num_tasks):
#   [0]            result of mnb.task_opt
#   [1]            mnb.task_loss (summed post-update test loss)
#   [2 : 2+T]      pre-update train loss of each task  (t.loss)
#   [2+T : 2+2T]   post-update test loss of each task  (t.task_loss)
fetches = ([mnb.task_opt, mnb.task_loss]
           + [t.loss for t in mnb.tasks]
           + [t.task_loss for t in mnb.tasks])

for m in range(meta_iters):
    # Sample a batch of tasks: the amplitude and phase of each sine wave.
    amp = np.random.uniform(0.1, 5.0, size=meta_batch_size)
    phase = np.random.uniform(0, np.pi, size=meta_batch_size)

    feed_dict = {}
    for t, a, p in zip(mnb.tasks, amp, phase):
        # Sample 2K points from the task. The first K drive the inner-loop
        # update and the remaining K evaluate the adapted model.
        x = np.random.uniform(-5.0, 5.0, size=K * 2)
        truth = a * np.sin(x + p)
        # [:, None] turns each 1-D slice into a (K, 1) column for the placeholders.
        feed_dict[t.train_in_ph] = x[:K, None]
        feed_dict[t.train_targets] = truth[:K, None]
        feed_dict[t.test_in_ph] = x[K:, None]
        feed_dict[t.test_targets] = truth[K:, None]

    output = sess.run(fetches, feed_dict=feed_dict)
    outputs.append(output)

    if m % 1000 == 0:
        print(m)
        # Report task 0's losses. Its test loss comes after all T train losses,
        # so it is at index 2 + T, not 3 (index 3 is task 1's train loss).
        print(f'Overall loss:{output[1]}\n'
              f'Task Train loss:{output[2]}\n'
              f'Task Test loss:{output[2 + num_tasks]}')

## Plot meta-training loss

In [None]:
# Meta-loss (mnb.task_loss, index 1 of each sess.run result) per meta-iteration.
overall_loss_data = [step_result[1] for step_result in outputs]

In [None]:
# Meta-loss curve against meta-iteration index.
plt.plot(range(len(overall_loss_data)), overall_loss_data)

## Test on sine waves

### Copy the network first, so that fine-tuning does not overwrite the meta-trained weights

In [None]:
# Build an independent copy of the whole network under the 'test' scope.
# Fine-tuning at evaluation time then never touches the meta-trained weights.
# The copy uses the same task count as the meta-training builder, so the two
# graphs mirror each other rather than relying on a duplicated magic number.
with tf.variable_scope('test'):
    mnb_test = MAMLNetworkBuilder(meta_batch_size)

In [None]:
# Trainable variables of the 'test' evaluation copy. The bare name on the
# last line displays them as the cell output.
test_weights = tf.get_default_graph().get_collection(
    tf.GraphKeys.TRAINABLE_VARIABLES, scope='test')
test_weights

In [None]:
# All trainable variables (meta-trained network AND its 'test' copy) and the
# copy's own trainable variables.
weights = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
test_weights = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='test')

# Remaining state of the copy (e.g. its optimizer variables). The copy's own
# weights are excluded: GLOBAL_VARIABLES also contains them, so running this
# initialiser after `copy_op` would otherwise replace the copied meta-trained
# values with a fresh random initialisation.
_test_weight_names = {v.name for v in test_weights}
test_extras = [v for v in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='test')
               if v.name not in _test_weight_names]

# Assign each meta-trained variable into the test variable of the same name,
# e.g. 'fc1/fully_connected/weights:0' -> 'test/fc1/fully_connected/weights:0'.
# Matching by name does not depend on collection order, and a missing
# counterpart raises KeyError instead of silently mispairing.
_trained_by_name = {v.name: v for v in weights}
copy_op = [test_var.assign(_trained_by_name[test_var.name[len('test/'):]])
           for test_var in test_weights]

In [None]:
# Reset the evaluation copy: initialise its variables first, then copy the
# meta-trained weights in. Order matters when `test_extras` contains the copy's
# own weights, because initialising after the copy would re-randomise them.
sess.run(tf.variables_initializer(test_extras))
sess.run(copy_op)

In [None]:
# Evaluate fast adaptation on one freshly sampled sine task. For each entry in
# `step_list`, fine-tune the 'test' copy for that many steps, starting from the
# meta-trained weights each time, and overlay its predictions on the truth.
step_list = [0, 1, 10, 100]
color = ['g', 'b', 'orange', 'red']
figs, axis = plt.subplots(1, figsize=(10, 10))

# Sample one task and its dense ground-truth curve. The curve is named
# `target_outputs` so the meta-training log in `outputs` is not overwritten.
amp = np.random.uniform(0.1, 5.0, size=1)
phase = np.random.uniform(0, np.pi, size=1)
inputs = np.linspace(-5, 5, 10000)[:, None]
target_outputs = amp * np.sin(inputs + phase)

# Choose 5 random points from the task to fine-tune on.
x = np.random.uniform(-5.0, 5.0, size=5)
truth = amp * np.sin(x + phase)
tune_feed = {mnb_test.in_ph: x[:, None], mnb_test.targets: truth[:, None]}

# The ground truth and the fine-tuning points are the same in every trial,
# so draw them once. zorder keeps the points visible above the predictions.
axis.scatter(inputs[:, 0], target_outputs[:, 0], s=0.1, color='k', marker='o')
axis.scatter(x, truth, marker='^', color='red', zorder=3)

# Create the initializer op once instead of adding a new op to the graph per reset.
reset_test_extras = tf.variables_initializer(test_extras)


def _reset_test_network():
    """Restore the 'test' copy to the meta-trained weights with fresh extra state.

    Initialise first, then copy. If `test_extras` includes the copy's own
    weights, initialising after the copy would overwrite the meta-trained
    values with a random initialisation.
    """
    sess.run(reset_test_extras)
    sess.run(copy_op)


for steps, step_color in zip(step_list, color):
    _reset_test_network()
    # NOTE(review): fine-tuning runs mnb_test.opt (an Adam step), not the
    # plain gradient step used inside create_model during meta-training.
    # Confirm this mismatch is intended.
    for _ in range(steps):
        sess.run(mnb_test.opt, feed_dict=tune_feed)
    pred = sess.run(mnb_test.out_pred, feed_dict={mnb_test.in_ph: inputs})
    axis.scatter(inputs[:, 0], pred[:, 0], s=0.1, color=step_color, marker='x',
                 label=f'{steps} Grad Steps')
axis.legend(fontsize=20, markerscale=20)

# Leave the test copy holding the meta-trained weights for any later cells.
_reset_test_network()