This is an experiment in using SGD + a simple feedforward network as an autoregressive generative model.
MAML and Reptile effectively turn feedforward networks (e.g. CNNs and MLPs) into sequence models. Just like in an RNN, MAML and Reptile have a state (the network parameters) which is updated over time (every inner-loop training step). Thus, we can pose a sequence of
(x_i, y_i) as a dataset of inputs (
x_i) and outputs (
y_i), and use MAML/Reptile to train the network to predict the output
y_i given input
x_i after being tuned (with SGD) on all of the previous input/output pairs in the dataset. In text generation, I set
x_i to the index of the current character. For image generation, I set
x_i to the 2-D pixel coordinate.
Generation using such a model is simple:
theta= parameters learned with Reptile/MAML.
x_i= some fixed indices/coordinates.
ifrom 1 to
P = f_theta(x_i), i.e. the network output for input
y_iby sampling from
P(which is typically logits).
- Take an SGD step to increase
- The result of all
y_iis the sampled sequence.
A more concrete example can be found in text_sample.py, and is simplified here for illustration:
model = load_text_model() opt = optim.SGD(model.parameters()) sequence =  for i in range(128): # Create input Tensor containing x_i (the index). inputs = torch.from_numpy(np.array([[i]])).to(device).long() # Compute output distribution, P. logits = model(inputs) probs = F.softmax(logits, dim=0).detach().cpu().numpy() # Sample from the output distribution. sample = np.random.choice(np.arange(256), p=probs) sequence.append(int(sample)) # NUL character ends the sequence early. if sample == 0: break # Generate a target vector to increase P(sample). targets = torch.from_numpy(np.array([sample])).to(device).long() # Compute -P(sample) and take a step to minimize it. loss = F.cross_entropy(logits, targets) opt.zero_grad() loss.backward() opt.step() # Print the final sample. print(str(bytes([min(0x79, x) for x in sequence]), 'ascii'))
I tried reptile-gen with both MAML and Reptile on two tasks: text generation and binarized sequential MNIST generation.
I also experimented with different model architectures. First I tried vanilla MLPs, but quickly saw some improvement by switching to gated activation functions. Presumably, this allowed the model to prevent certain parameters from being heavily affected by every SGD step. I found the most improvement from using an LSTM block directly, but with no recurrent connections. Essentially, it was just an LSTM block where the state vector was just another trainable parameter.
Here are the best MNIST samples my models were able to generate. Not very nice looking!
All of my experiments resulted in significant underfitting. However, MAML tended to produce better models than Reptile, indicating that Reptile's gradient-free nature prevents it from modeling long-term dependencies.
The best architecture for reptile-gen was, by a small margin, the LSTM block with no recurrent connections. Since LSTMs were not taylored for reptile-gen, it seems likely that a gating mechanism more taylored for reptile-gen could produce better results. However, my interested in the project started to run thin before I got to that.