In [None]:
%matplotlib inline
import numpy as np
import matplotlib
import sys
import time

In [None]:
# import pytorch modules
import torch
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data

In [None]:
# plotting
sys.path.insert(0,'..')
from utils import plot_stroke

In [None]:
# detect whether a CUDA device is available; later cells move tensors and the
# model to the GPU when this flag is True
cuda = torch.cuda.is_available()

In [None]:
# hyperparameters
timesteps = 800       # number of steps taken from each sequence (input and target windows)
num_clusters = 20     # number of bivariate Gaussian mixture components in the output
cell_size = 400       # hidden size of each LSTM
nlayers = 2           # NOTE(review): not referenced in the visible cells - confirm it is still needed
bsize = 150           # mini-batch size used by both data loaders
init_lr = 1e-3        # initial learning rate handed to Adam

In [None]:
# prepare training and validation data
def _load_split(split):
    """Load one data split as a list of per-sample (strokes, mask, onehot) tuples.

    Reads '<split>_strokes_800.npy', '<split>_masks_800.npy' and
    '<split>_onehot_800.npy' from the working directory, converts each array
    to a float32 tensor (moved to the GPU when `cuda` is set) and regroups
    them so that element i holds the three tensors belonging to sample i.

    Assumes the three arrays share the same first (sample) dimension; the
    strokes array decides how many samples are produced.
    """
    tensors = []
    for part in ('strokes', 'masks', 'onehot'):
        tensor = torch.from_numpy(np.load('{}_{}_800.npy'.format(split, part)))
        # Both splits are cast to float32 so they match the model's parameters
        # (previously only the training split was cast).
        tensor = tensor.type(torch.FloatTensor)
        if cuda:
            tensor = tensor.cuda()
        tensors.append(tensor)
    strokes, masks, onehots = tensors
    return [(strokes[i], masks[i], onehots[i]) for i in range(len(strokes))]


train_data = _load_split('train')
train_loader = torch.utils.data.DataLoader(
    train_data, batch_size=bsize, shuffle=True)

# The validation split is sized by its own sample count. It used to be sized by
# len(train_data[0]), which by then was a 3-tuple, so only 3 samples were kept.
validation_data = _load_split('validation')
validation_loader = torch.utils.data.DataLoader(
    validation_data, batch_size=bsize, shuffle=True)

In [None]:
# 2-layer lstm with mixture of gaussian parameters as outputs
# skip connections added
class LSTM(nn.Module):
    """Two stacked single-layer LSTMs with a mixture-density output head.

    The second LSTM receives the first LSTM's output concatenated with the raw
    input (skip connection), and the linear head sees the outputs of both
    LSTMs. For every time step the head emits the parameters of a mixture of
    bivariate Gaussians plus one Bernoulli probability.

    Args:
        hidden_size: hidden size of each LSTM. Defaults to the notebook-level
            `cell_size` when omitted.
        n_components: number of mixture components. Defaults to the
            notebook-level `num_clusters` when omitted.
    """

    def __init__(self, hidden_size=None, n_components=None):
        super(LSTM, self).__init__()
        # Fall back to the notebook hyperparameters so `LSTM()` keeps working.
        self.hidden_size = cell_size if hidden_size is None else hidden_size
        self.n_components = num_clusters if n_components is None else n_components
        # Attribute names are unchanged so existing checkpoints still load.
        self.lstm = nn.LSTM(input_size=3, hidden_size=self.hidden_size,
                            num_layers=1, batch_first=True)
        self.lstm2 = nn.LSTM(input_size=self.hidden_size + 3, hidden_size=self.hidden_size,
                             num_layers=1, batch_first=True)
        # 6 parameters per component (weight, 2 means, 2 log-sigmas, correlation)
        # plus a single logit for the `end` probability.
        self.linear1 = nn.Linear(self.hidden_size * 2, 1 + self.n_components * 6)
        self.tanh = nn.Tanh()

    def forward(self, x, prev, prev2):
        """Run the network over a batch of sequences.

        Args:
            x: input of shape (batch, time, 3) (the LSTMs are batch_first with
                input_size 3).
            prev: (h_0, c_0) state for the first LSTM.
            prev2: (h_0, c_0) state for the second LSTM.

        Returns:
            Tuple (end, weights, mu_1, mu_2, log_sigma_1, log_sigma_2, p,
            (h1_n, c1_n), (h2_n, c2_n)). `end` has a trailing dimension of 1
            and lies in [0, 1]; `weights` is a softmax over the components;
            `p` is squashed into [-1, 1] by tanh; the remaining parameter
            tensors are raw linear outputs with a trailing dimension of
            `n_components`. The last two items are the final LSTM states.
        """
        k = self.n_components
        h1, (h1_n, c1_n) = self.lstm(x, prev)
        # Skip connection: the second LSTM also sees the raw input.
        h2, (h2_n, c2_n) = self.lstm2(torch.cat([h1, x], dim=-1), prev2)
        params = self.linear1(torch.cat([h1, h2], dim=-1))

        weights = F.softmax(params.narrow(-1, 0, k), dim=-1)
        mu_1 = params.narrow(-1, k, k)
        mu_2 = params.narrow(-1, 2 * k, k)
        log_sigma_1 = params.narrow(-1, 3 * k, k)
        log_sigma_2 = params.narrow(-1, 4 * k, k)
        p = self.tanh(params.narrow(-1, 5 * k, k))
        # torch.sigmoid replaces the deprecated F.sigmoid.
        end = torch.sigmoid(params.narrow(-1, 6 * k, 1))

        return end, weights, mu_1, mu_2, log_sigma_1, log_sigma_2, p, (h1_n, c1_n), (h2_n, c2_n)

In [None]:
# training objective
def log_likelihood(end, weights, mu_1, mu_2, log_sigma_1, log_sigma_2, p, x, masks):
    """Masked log-likelihood of targets `x` under the predicted distribution.

    Channel 0 of `x` is scored with a Bernoulli of probability `end`; channels
    1 and 2 are scored with a mixture of bivariate Gaussians parameterised by
    `weights`, `mu_*`, `log_sigma_*` and correlation `p`.

    Args:
        end: Bernoulli probabilities with a trailing dimension of 1.
        weights, mu_1, mu_2, log_sigma_1, log_sigma_2, p: mixture parameters
            whose last dimension indexes the components.
        x: targets whose last dimension holds the 3 channels.
        masks: per-step weights, either with or without a trailing singleton
            dimension; steps with mask 0 do not contribute.

    Returns:
        A scalar tensor: the mask-weighted sum of per-step log-likelihoods.
    """
    x_0 = x.narrow(-1, 0, 1)
    x_1 = x.narrow(-1, 1, 1)
    x_2 = x.narrow(-1, 2, 1)

    # Drop the trailing singleton so both terms and the mask line up
    # element-for-element, instead of depending on implicit reshaping between
    # (batch, time, 1) and (batch, time).
    end_loglik = (x_0 * end + (1 - x_0) * (1 - end)).log().squeeze(-1)
    if masks.dim() > end_loglik.dim():
        masks = masks.squeeze(-1)

    one_minus_p2 = 1 - p ** 2
    z = (x_1 - mu_1) ** 2 / (log_sigma_1.exp() ** 2) \
        + (x_2 - mu_2) ** 2 / (log_sigma_2.exp() ** 2) \
        - 2 * p * (x_1 - mu_1) * (x_2 - mu_2) / ((log_sigma_1 + log_sigma_2).exp())
    # Log-density of each bivariate Gaussian component, plus its log weight.
    log_norm = float(-np.log(2 * np.pi)) - log_sigma_1 - log_sigma_2 - 0.5 * one_minus_p2.log()
    component_loglik = weights.log() + log_norm - z / (2 * one_minus_p2)

    # Stable log-sum-exp over components: subtracting the per-step maximum
    # keeps exp() from underflowing to 0, which previously produced log(0) = -inf.
    # (The old post-log `+ 1E-20` offset could not prevent that and is dropped.)
    peak = component_loglik.max(dim=-1, keepdim=True)[0].detach()
    mog_loglik = peak.squeeze(-1) + (component_loglik - peak).exp().sum(dim=-1).log()

    return (end_loglik * masks).sum() + (mog_loglik * masks).sum()

In [None]:
def decay_learning_rate(optimizer, decay = 1.01):
    """Divide the learning rate of every parameter group by `decay`.

    The rates are updated in place on `optimizer.param_groups`, so the
    optimizer's internal state (e.g. Adam moment estimates) is left untouched
    rather than being copied through a state_dict round trip. Each group keeps
    its own rate; groups are no longer all overwritten with group 0's rate.

    Args:
        optimizer: a torch optimizer.
        decay: positive divisor applied to each group's 'lr'.

    Returns:
        The same optimizer instance, for call sites that reassign it.

    Raises:
        ValueError: if `decay` is not strictly positive.
    """
    if decay <= 0:
        raise ValueError('decay must be positive, got {}'.format(decay))
    for param_group in optimizer.param_groups:
        param_group['lr'] /= decay
    return optimizer

In [None]:
def save_checkpoint(epoch, model, validation_loss, optimizer, filename='best.pt'):
    """Write a training checkpoint to `filename` with torch.save.

    The checkpoint is a dict with the keys 'epoch' (stored as `epoch + 1`),
    'model' (the model's state_dict), 'validation_loss' and 'optimizer'
    (the optimizer's state_dict).
    """
    checkpoint = dict(
        epoch=epoch + 1,
        model=model.state_dict(),
        validation_loss=validation_loss,
        optimizer=optimizer.state_dict(),
    )
    torch.save(checkpoint, filename)

In [None]:
# build the network, placing it on the GPU when one was detected
model = LSTM().cuda() if cuda else LSTM()

In [None]:
# optimisation setup: Adam over all model parameters, starting at init_lr
epochs = 10
optimizer = optim.Adam(model.parameters(), lr=init_lr)

In [None]:
# training
def _zero_state():
    """Return a (1, bsize, cell_size) zero tensor on the active device."""
    state = torch.zeros((1, bsize, cell_size))
    return Variable(state.cuda() if cuda else state)


def _fit_states(states, batch_size):
    """Trim zero initial states to `batch_size` samples.

    The last training batch and the validation batch can hold fewer than
    `bsize` samples, and nn.LSTM requires the state's batch dimension to match
    the input's.
    """
    return tuple(s.narrow(1, 0, batch_size).contiguous() for s in states)


# Zero initial (h, c) states for the two LSTMs: one set for training batches,
# one set for the validation batch.
h1_init, c1_init = _zero_state(), _zero_state()
h2_init, c2_init = _zero_state(), _zero_state()
h1_init2, c1_init2 = _zero_state(), _zero_state()
h2_init2, c2_init2 = _zero_state(), _zero_state()

# Loss histories and the best validation loss survive re-runs of this cell (so
# a run can be continued), but are created when missing so the notebook also
# works from a fresh kernel instead of failing with NameError.
if 't_loss' not in globals():
    t_loss = []
if 'v_loss' not in globals():
    v_loss = []
if 'best_validation_loss' not in globals():
    best_validation_loss = 1E10

start_time = time.time()
# NOTE(review): the epoch range is hard-coded to continue at epoch 40; on a
# fresh run this presumably should be `range(epochs)` - confirm.
for epoch in range(40,50):
    train_loss = 0
    for batch_idx, (data, masks, onehots) in enumerate(train_loader):
        batch = data.size(0)
        # Inputs are steps [0, timesteps); targets are the same window shifted by one.
        x = Variable(data.narrow(1, 0, timesteps), requires_grad=False)
        y = Variable(data.narrow(1, 1, timesteps), requires_grad=False)
        masks = Variable(masks, requires_grad=False)
        optimizer.zero_grad()

        end, weights, mu_1, mu_2, log_sigma_1, log_sigma_2, p, prev, prev2 = model(
            x, _fit_states((h1_init, c1_init), batch), _fit_states((h2_init, c2_init), batch))
        loss = -log_likelihood(end, weights, mu_1, mu_2, log_sigma_1, log_sigma_2, p, y, masks)/torch.sum(masks)
        loss.backward()
        optimizer.step()
        # .item() extracts the Python float; indexing a 0-dim loss with [0]
        # raises on PyTorch >= 0.5.
        batch_loss = loss.item()
        train_loss += batch_loss
        if batch_idx % 6 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch+1, batch_idx * batch, len(train_loader.dataset),
                100. * batch_idx / len(train_loader),
                batch_loss))

    # Each batch loss is already a per-step mean, so the epoch average is taken
    # over the number of batches (not the number of samples).
    avg_train_loss = train_loss / len(train_loader)
    print('====> Epoch: {} Average train loss: {:.4f}'.format(
          epoch+1, avg_train_loss))
    t_loss.append(avg_train_loss)

    # validation on a single (shuffled) batch; only the first batch is drawn
    # instead of materialising the whole loader
    validation_samples, masks, onehots = next(iter(validation_loader))
    batch = validation_samples.size(0)
    with torch.no_grad():  # no gradients are needed for evaluation
        masks = Variable(masks, requires_grad=False)
        x2 = Variable(validation_samples.narrow(1, 0, timesteps), requires_grad=False)
        y2 = Variable(validation_samples.narrow(1, 1, timesteps), requires_grad=False)
        end, weights, mu_1, mu_2, log_sigma_1, log_sigma_2, p, prev, prev2 = model(
            x2, _fit_states((h1_init2, c1_init2), batch), _fit_states((h2_init2, c2_init2), batch))
        loss = -log_likelihood(end, weights, mu_1, mu_2, log_sigma_1, log_sigma_2, p, y2, masks)/torch.sum(masks)
    validation_loss = loss.item()
    print('====> Epoch: {} Average validation loss: {:.4f}'.format(
          epoch+1, validation_loss))
    v_loss.append(validation_loss)

    if validation_loss < best_validation_loss:
        best_validation_loss = validation_loss
        save_checkpoint(epoch, model, validation_loss, optimizer)

    # learning rate annealing
    # optimizer = decay_learning_rate(optimizer, 1.03)

    # checkpoint model and training state every epoch
    filename = 'epoch_{}_800.pt'.format(epoch+1)
    save_checkpoint(epoch, model, validation_loss, optimizer, filename)

    # testing checkpoints: reload what was just written to make sure it restores
    state = torch.load(filename)
    model.load_state_dict(state['model'])
    optimizer.load_state_dict(state['optimizer'])

    print('wall time: {}s'.format(time.time()-start_time))

In [None]:
def generate_unconditionally(model, steps=700, random_seed=1):
    """Sample a sequence from `model`, one step at a time.

    Starting from an all-zero input and no initial state (nn.LSTM treats a
    missing state as zeros), each step draws the end flag from a Bernoulli
    with the predicted probability, picks a mixture component according to the
    predicted weights, samples a point from that component's bivariate
    Gaussian, and feeds [end, point] back in as the next input.

    Args:
        model: module called as model(x, prev, prev2) with a (1, 1, 3) input,
            returning (end, weights, mu_1, mu_2, log_sigma_1, log_sigma_2, p,
            state1, state2).
        steps: number of points to generate.
        random_seed: seed of the sampling generator; equal seeds give equal
            sequences.

    Returns:
        np.ndarray of shape (steps, 3); each row is [end_flag, x_1, x_2].
    """
    torch.manual_seed(random_seed)
    # All draws below come from NumPy, so seeding torch alone never made the
    # output reproducible. A private generator also leaves the global NumPy
    # random state alone.
    rng = np.random.RandomState(random_seed)
    # Follow the model's own device rather than a global flag.
    on_gpu = next(model.parameters()).is_cuda

    def as_input(point):
        step = torch.from_numpy(np.asarray(point, dtype=np.float32)).view(1, 1, 3)
        return Variable(step.cuda() if on_gpu else step, requires_grad=False)

    def as_array(output):
        # Convert the single (batch 0, step 0) prediction to float64 NumPy;
        # going through .cpu().numpy() works for CPU and CUDA tensors alike.
        return output.data.cpu().numpy()[0, 0].astype(np.float64)

    x = as_input(np.zeros(3))
    prev = prev2 = None
    record = []
    for _ in range(steps):
        end, weights, mu_1, mu_2, log_sigma_1, log_sigma_2, p, prev, prev2 = model(x, prev, prev2)
        # Detach the carried states so the autograd graph does not grow with
        # every generated step.
        prev = tuple(state.detach() for state in prev)
        prev2 = tuple(state.detach() for state in prev2)

        sample_end = rng.binomial(1, as_array(end)[0])

        mix = as_array(weights)
        mix /= mix.sum()  # absorb float32 rounding so the weights sum to 1
        sample_index = rng.choice(len(mix), p=mix)

        mu = np.array([as_array(mu_1)[sample_index], as_array(mu_2)[sample_index]])
        sd_1 = np.exp(as_array(log_sigma_1)[sample_index])
        sd_2 = np.exp(as_array(log_sigma_2)[sample_index])
        c = as_array(p)[sample_index] * sd_1 * sd_2
        cov = np.array([[sd_1 ** 2, c], [c, sd_2 ** 2]])
        sample_point = rng.multivariate_normal(mu, cov)

        out = np.insert(sample_point, 0, sample_end)
        record.append(out)
        x = as_input(out)
    return np.array(record)

In [None]:
def greedy_generate(model, steps=700):
    """Generate a sequence deterministically from `model`.

    Starting from an all-zero input and no initial state (nn.LSTM treats a
    missing state as zeros), each step rounds the predicted end probability to
    0 or 1, takes the means of the highest-weight mixture component as the
    point, and feeds [end, point] back in as the next input.

    Args:
        model: module called as model(x, prev, prev2) with a (1, 1, 3) input,
            returning (end, weights, mu_1, mu_2, log_sigma_1, log_sigma_2, p,
            state1, state2).
        steps: number of points to generate.

    Returns:
        np.ndarray of shape (steps, 3); each row is [end_flag, x_1, x_2].
    """
    # Follow the model's own device rather than a global flag.
    on_gpu = next(model.parameters()).is_cuda

    def as_input(point):
        step = torch.from_numpy(np.asarray(point, dtype=np.float32)).view(1, 1, 3)
        return Variable(step.cuda() if on_gpu else step, requires_grad=False)

    def as_array(output):
        # Convert the single (batch 0, step 0) prediction to NumPy; going
        # through .cpu().numpy() works for CPU and CUDA tensors alike.
        return output.data.cpu().numpy()[0, 0]

    x = as_input(np.zeros(3))
    prev = prev2 = None
    record = []
    for _ in range(steps):
        end, weights, mu_1, mu_2, log_sigma_1, log_sigma_2, p, prev, prev2 = model(x, prev, prev2)
        # Detach the carried states so the autograd graph does not grow with
        # every generated step.
        prev = tuple(state.detach() for state in prev)
        prev2 = tuple(state.detach() for state in prev2)

        sample_end = np.round(as_array(end)[0])
        sample_index = int(np.argmax(as_array(weights)))

        out = np.array([sample_end,
                        as_array(mu_1)[sample_index],
                        as_array(mu_2)[sample_index]])
        record.append(out)
        x = as_input(out)
    return np.array(record)

In [None]:
# FIXME: this assignment was left unfinished (`state2 = ` with no right-hand
# side is a SyntaxError and aborts Run All). Presumably a saved checkpoint was
# meant to be loaded here with torch.load and applied via
# model.load_state_dict(state2['model']) - confirm which file was intended.
state2 = None

In [None]:
# test generation 60
# draw a 700-step sample (seed 20) from the current model and plot it
s = generate_unconditionally(model, steps=700, random_seed=20)
plot_stroke(s)

In [None]:
# test generation 60
# greedy decode with the default 700 steps, then plot
g = greedy_generate(model)
# set the end flag (column 0) at step 650 before plotting
g[650, 0] = 1
plot_stroke(g)

In [None]:
def control_generate(model, steps=700, random_seed=1, temp=.9):
    """Generate a sequence by mixing greedy and sampled steps.

    At every step a uniform draw decides the mode: with probability `temp` the
    step is greedy (rounded end probability, means of the highest-weight
    mixture component); otherwise the end flag is drawn from a Bernoulli, a
    component is drawn according to the mixture weights, and the point is
    sampled from that component's bivariate Gaussian. The result is fed back
    in as the next input.

    Args:
        model: module called as model(x, prev, prev2) with a (1, 1, 3) input,
            returning (end, weights, mu_1, mu_2, log_sigma_1, log_sigma_2, p,
            state1, state2).
        steps: number of points to generate.
        random_seed: seed of the sampling generator; equal seeds give equal
            sequences.
        temp: probability of taking a greedy step (1.0 means always greedy,
            0.0 means always sampled).

    Returns:
        np.ndarray of shape (steps, 3); each row is [end_flag, x_1, x_2].
    """
    torch.manual_seed(random_seed)
    # All draws below come from NumPy, so seeding torch alone never made the
    # output reproducible. A private generator also leaves the global NumPy
    # random state alone.
    rng = np.random.RandomState(random_seed)
    # Follow the model's own device rather than a global flag.
    on_gpu = next(model.parameters()).is_cuda

    def as_input(point):
        step = torch.from_numpy(np.asarray(point, dtype=np.float32)).view(1, 1, 3)
        return Variable(step.cuda() if on_gpu else step, requires_grad=False)

    def as_array(output):
        # Convert the single (batch 0, step 0) prediction to float64 NumPy;
        # going through .cpu().numpy() works for CPU and CUDA tensors alike.
        return output.data.cpu().numpy()[0, 0].astype(np.float64)

    x = as_input(np.zeros(3))
    prev = prev2 = None
    record = []
    for _ in range(steps):
        greedy = rng.random_sample() < temp
        end, weights, mu_1, mu_2, log_sigma_1, log_sigma_2, p, prev, prev2 = model(x, prev, prev2)
        # Detach the carried states so the autograd graph does not grow with
        # every generated step.
        prev = tuple(state.detach() for state in prev)
        prev2 = tuple(state.detach() for state in prev2)

        prob_end = as_array(end)[0]
        mix = as_array(weights)
        if greedy:
            sample_end = np.round(prob_end)
            sample_index = int(np.argmax(mix))
            out = np.array([sample_end,
                            as_array(mu_1)[sample_index],
                            as_array(mu_2)[sample_index]])
        else:
            sample_end = rng.binomial(1, prob_end)
            mix /= mix.sum()  # absorb float32 rounding so the weights sum to 1
            sample_index = rng.choice(len(mix), p=mix)
            mu = np.array([as_array(mu_1)[sample_index], as_array(mu_2)[sample_index]])
            # The network predicts log-sigmas, so they must be exponentiated
            # before building the covariance (the raw logs were squared before).
            sd_1 = np.exp(as_array(log_sigma_1)[sample_index])
            sd_2 = np.exp(as_array(log_sigma_2)[sample_index])
            c = as_array(p)[sample_index] * sd_1 * sd_2
            cov = np.array([[sd_1 ** 2, c], [c, sd_2 ** 2]])
            out = np.insert(rng.multivariate_normal(mu, cov), 0, sample_end)

        record.append(out)
        x = as_input(out)
    return np.array(record)

In [None]:
# test generation 30
# The model must be passed first: the old call generate_unconditionally(200, 42)
# bound 200 to `model` and 42 to `steps`.
s = generate_unconditionally(model, 200, 42)
plot_stroke(s)

In [None]:
# test generation 30
# greedy_generate requires the model argument (calling it bare raises TypeError)
g = greedy_generate(model)
# set the end flag (column 0) at step 650; valid because the default is 700 steps
g[650][0]=1
plot_stroke(g)


In [None]:
# generate with control_generate's defaults (700 steps, seed 1, temp 0.9) and plot
c = control_generate(model)
plot_stroke(c)