In [None]:
import drawing
import numpy as np

import os

import torch 
from torch import tensor
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt

import time

from IPython.core.interactiveshell import InteractiveShell
# Display the value of every bare expression in a cell, not only the last one;
# several cells below list multiple expressions to inspect intermediate values.
InteractiveShell.ast_node_interactivity = "all"

# Prefer the second GPU (the original choice), but fall back to cuda:0 on
# single-GPU machines, where 'cuda:1' would fail with an invalid device ordinal.
if torch.cuda.is_available():
    device = 'cuda:1' if torch.cuda.device_count() > 1 else 'cuda:0'
else:
    device = 'cpu'
print(device)


n_output_mixtures = 20   # Gaussian components in the output mixture density
n_attn_mixtures = 1      # Gaussian components in the attention window
hidden_size = 400        # LSTM hidden units (both layers)
input_size = 3           # values per stroke point: 2 coordinates + pen-lift flag
batch_size = 128
n_epochs = 1000
timestamps = 1200        # padded stroke length (assumed to match the data arrays)
lr = 1e-5
eps = 1e-6               # small constant for numerical stability
model_name = "20210811_Synthesis-K1"
# Checkpoint directory. Override it with the HANDWRITING_CHECKPOINT_DIR environment
# variable instead of editing this absolute path. os.path.join(..., "") guarantees a
# trailing separator because other cells build paths as `base_dir + model_name`.
base_dir = os.path.join(
    os.environ.get("HANDWRITING_CHECKPOINT_DIR",
                   "/u/home/lyrebird_code/handwriting_synthesis/model_checkpoints/"),
    "")
embedding_size = None  # set to the vocabulary size (73) by the preprocessing cell


In [None]:
# Scratch check: softmax over dim 0 normalizes across the two rows, so every
# column of the (1, 3) slices sums to 1.
a = torch.randn(2, 1, 3)
a
torch.softmax(a, dim=0)

In [None]:
data_dir = "data"
# Each array is stored as <data_dir>/<name>.npy, loaded in this fixed order.
data = [np.load(os.path.join(data_dir, f"{name}.npy")) for name in ('x', 'x_len', 'c', 'c_len', 'w_id')]
strokes_og, stroke_lens, strings, string_lens, w_id = data

In [None]:
from collections import defaultdict

# Character set of the transcriptions; index 0 is the end-of-sequence character.
alphabet = [
    '\x00', ' ', '!', '"', '#', "'", '(', ')', ',', '-', '.',
    '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', ':', ';',
    '?', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K',
    'L', 'M', 'N', 'O', 'P', 'R', 'S', 'T', 'U', 'V', 'W', 'Y',
    'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l',
    'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x',
    'y', 'z'
]
vocab_size = len(alphabet)
eos_char = '\x00'
# Characters missing from the alphabet map to index 0 (the EOS character).
alpha_to_num = defaultdict(int, {char: idx for idx, char in enumerate(alphabet)})
# Unknown indices decode to '' (a str, so ''.join over decoded characters never
# sees an int, which the previous int default would have produced).
num_to_alpha = defaultdict(str, dict(enumerate(alphabet)))

In [None]:
def standarize(strokes, lengths):
    """Shift/scale the coordinate channels to zero mean and unit std dev.

    Statistics are pooled over the interior points of every sample, i.e.
    rows 1 .. lengths[i]-2 of strokes[i]; the first and last point of each
    sample are excluded from the statistics and left untouched.

    Returns:
        (standardized copy of strokes, per-channel means, per-channel std).
    """
    interior = [strokes[i][1:lengths[i] - 1, 0:2] for i in range(len(lengths))]
    pooled = (np.concatenate(interior, axis=0) if interior else np.zeros((0, 2))).astype(np.float64)
    means = pooled.mean(axis=0)
    std = pooled.std(axis=0)

    normalized = strokes.copy()
    for row, n_points in zip(normalized, lengths):
        # Two in-place steps, matching the rounding of subtract-then-divide.
        row[1:n_points - 1, 0:2] -= means
        row[1:n_points - 1, 0:2] /= std
    return normalized, means, std

def destandarize(strokes_standarized, lengths, means, std):
    """Undo `standarize`: scale by std, then add back the means.

    Only rows 1 .. lengths[i]-2 of each sample are changed; the input array
    is not modified (a copy is returned).
    """
    restored = strokes_standarized.copy()
    for row, n_points in zip(restored, lengths):
        segment = row[1:n_points - 1, 0:2]  # view into `restored`
        segment *= std
        segment += means
    return restored

def preprocess(strokes_og, stroke_lens, strings, string_lengths):
    """Clean, standardize and one-hot encode the raw handwriting data.

    Preprocessing is deliberately simple: every kept sample starts with a point
    whose pen-lift flag is 0, and gets one extra point with pen-lift flag 1
    written at index ``stroke_lens[i]`` (its length grows by one).

    Args:
        strokes_og: padded stroke array indexed as [sample, point, channel],
            channels 0-1 being coordinates and channel 2 the pen-lift flag
            (assumed shape (N, max_len, 3) — TODO confirm against the data).
        stroke_lens: (N,) number of valid points per sample.
        strings: per-sample encoded transcriptions, decoded with
            ``drawing.decode_ascii``.
        string_lengths: (N,) transcription lengths; only filtered, not used otherwise.

    Returns:
        (standarized_strokes, means, std, stroke_lens, texts, text_lengths,
        one_hots, vocab_size). ``texts`` have the EOS character appended and
        ``one_hots`` is a float64 (N, max_text_length, vocab_size) tensor on
        ``device``.
    """
    # Samples that already fill the padded width have no slot left for the
    # appended end point, so drop them. Use the array's actual padded width
    # rather than a hard-coded 1200 so other padding sizes work too.
    # (Boolean indexing copies, so the caller's arrays are never mutated.)
    keep = stroke_lens < strokes_og.shape[1]
    strokes_og = strokes_og[keep]
    stroke_lens = stroke_lens[keep]
    strings = strings[keep]
    string_lengths = string_lengths[keep]

    # Force the start pen-lift flag to 0 and append an end point with flag 1.
    strokes_og[:, 0, 2] = 0
    for i in range(strokes_og.shape[0]):
        strokes_og[i, stroke_lens[i], 2] = 1
    stroke_lens += 1

    # Standardize the two coordinate channels separately (zero mean, unit std).
    standarized_strokes, means, std = standarize(strokes_og, stroke_lens)

    texts = [drawing.decode_ascii(s) + eos_char for s in strings]
    text_lengths = [len(t) for t in texts]

    max_text_length = max(text_lengths)
    num_texts = len(texts)
    one_hots = torch.zeros(num_texts, max_text_length, vocab_size, device=device, dtype=torch.float64)
    for i, text in enumerate(texts):
        for j, c in enumerate(text):
            one_hots[i, j, alpha_to_num[c]] = 1

    return standarized_strokes, means, std, stroke_lens, texts, text_lengths, one_hots, vocab_size

def string_to_one_hot(string):
    """Encode ``string`` as a (len, vocab_size) float one-hot tensor on ``device``.

    An EOS character is appended unless the string already ends with one, so
    the result always has at least one row. The empty string is handled too;
    it previously raised IndexError on ``string[-1]``. Characters outside the
    alphabet map to index 0 through ``alpha_to_num``'s default.
    """
    if not string.endswith(eos_char):
        string += eos_char
    # One indexed assignment instead of one tiny device write per character.
    rows = torch.arange(len(string), device=device)
    cols = torch.tensor([alpha_to_num[c] for c in string], device=device)
    one_hot = torch.zeros(len(string), vocab_size, device=device)
    one_hot[rows, cols] = 1
    return one_hot

def one_hot_to_string(one_hot):
    """Decode a (T, vocab_size) one-hot matrix back into a string.

    For each row, the first column equal to 1 selects the character; rows with
    no 1 (e.g. padding) decode to ''. Accepts a numpy array or a torch tensor
    on any device (tensors are moved to CPU first).
    """
    if torch.is_tensor(one_hot):
        one_hot = one_hot.detach().cpu().numpy()
    chars = []
    for row in np.asarray(one_hot):
        hits = np.flatnonzero(row == 1)
        chars.append(num_to_alpha[int(hits[0])] if hits.size else '')
    return ''.join(chars)

In [None]:
# Round-trip sanity check: encoding then decoding "sample" should display
# 'sample\x00' (string_to_one_hot appends the EOS character).
one_hot_to_string(string_to_one_hot("sample").cpu().numpy())

In [None]:
# Clean, standardize and one-hot encode the raw arrays; embedding_size receives
# the vocabulary size returned by preprocess.
strokes, means, std, stroke_lengths, texts, text_lengths, one_hots, embedding_size = preprocess(
    strokes_og=strokes_og,
    stroke_lens=stroke_lens,
    strings=strings,
    string_lengths=string_lens,
)


In [None]:
# Points around the end of sample 2: after preprocess, index stroke_lengths[2]-1
# is the appended end point whose pen-lift flag is 1.
strokes[2][stroke_lengths[2]-2:stroke_lengths[2]+2]

In [None]:
# Inspect a random example: its text (EOS included), its length, and the text
# decoded back from its one-hot rows (padding rows decode to '', so it should match).
idx = torch.randint(low=0, high=len(texts), size=(1,)).item()
print(idx)
texts[idx]
text_lengths[idx]
one_hot_to_string(one_hots[idx])

In [None]:
# Vocabulary size returned by preprocess (len(alphabet) == 73).
embedding_size

In [None]:
import matplotlib.pyplot as plt

def offsets_to_coords(offsets):
    """
    Turn per-point (dx, dy, pen_lift) offsets into absolute (x, y, pen_lift) rows
    by cumulatively summing the first two columns; the pen-lift column is kept as is.
    """
    absolute_xy = np.cumsum(offsets[:, :2], axis=0)
    pen_flags = offsets[:, 2:3]
    return np.hstack((absolute_xy, pen_flags))


def draw(offsets, plot_end_points=True):
    """Plot a handwriting sample given as (dx, dy, pen_lift) offset rows.

    Consecutive points are joined into one black line until a point with
    pen-lift flag 1 ends the current segment.

    Args:
        offsets: array-like of shape (T, 3); numpy arrays and CPU torch tensors
            are both accepted (converted with ``np.asarray``).
        plot_end_points: if True, mark every segment end point with a red dot.
    """
    coords = offsets_to_coords(np.asarray(offsets))

    fig, ax = plt.subplots(figsize=(12, 6))
    segment = []
    for x, y, eos in coords:
        segment.append((x, y))
        if eos == 1:
            xs, ys = zip(*segment)
            ax.plot(xs, ys, 'k')
            if plot_end_points:
                ax.plot(x, y, 'ro')
            segment = []
    if segment:
        # Trailing points after the last pen lift.
        xs, ys = zip(*segment)
        ax.plot(xs, ys, 'k')

    padding = 10
    ax.set_xlim(coords[:, 0].min() - padding, coords[:, 0].max() + padding)
    ax.set_ylim(coords[:, 1].min() - padding, coords[:, 1].max() + padding)

    ax.set_aspect('equal')
    # matplotlib expects booleans here; string values such as 'off' were
    # deprecated (matplotlib 2.2) and are not reliably interpreted as False.
    ax.tick_params(
        axis='both',
        left=False,
        top=False,
        right=False,
        bottom=False,
        labelleft=False,
        labeltop=False,
        labelright=False,
        labelbottom=False,
    )
    plt.show()
    plt.close(fig)
    

def attention_plot(phis):
    """Show a 2-D matrix twice: as line plots of its rows, then as a heat map.

    Args:
        phis: 2-D torch tensor (any device, may require grad) or numpy array,
            e.g. attention weights with one row per text position and one
            column per generation step.

    Each figure gets an explicit 12x6 size instead of changing the global
    ``plt.rcParams`` between the two plots (which previously left the first
    plot at the default size and leaked the setting to later figures).
    """
    if torch.is_tensor(phis):
        phis = phis.detach().cpu().numpy()
    phis = np.asarray(phis)

    fig, ax = plt.subplots(figsize=(12, 6))
    ax.plot(phis.T)  # one line per row of phis, over the column index
    plt.show()
    plt.close(fig)

    fig, ax = plt.subplots(figsize=(12, 6))
    ax.set_xlabel('handwriting generation')
    ax.set_ylabel('text scanning')
    ax.imshow(phis, cmap='Greys', interpolation='nearest', aspect='auto')
    plt.show()
    plt.close(fig)

In [None]:
# Visual sanity checks: sample 0 in standardized coordinates, the same sample
# after destandarize (undoes the scaling), and its transcription.
draw(strokes[0])
res2 = destandarize(strokes, stroke_lengths, means, std) # for checking
draw(res2[0])
print(drawing.decode_ascii(strings[0]))

# Smoke-test attention_plot on a random integer matrix (5 rows x 20 columns).
a = torch.randint(low = 0, high = 10, size = (5, 20))
a
attention_plot(a)

In [None]:
class HandwritingDataset(Dataset):
    """Map-style dataset pairing padded stroke sequences with their one-hot texts.

    Each item is the tuple
    ``(strokes, stroke_length, stroke_mask, one_hots, one_hot_length, one_hot_mask)``.
    The masks are float64 arrays holding 1.0 for positions before the
    corresponding length and 0.0 afterwards.
    """

    def __init__(self, strokes, stroke_lengths, one_hots, text_lengths):
        assert len(strokes) == len(one_hots)
        self.strokes = strokes
        self.stroke_lengths = stroke_lengths
        self.one_hots = one_hots
        self.one_hot_lengths = text_lengths
        self.stroke_mask = self.getStrokeMask()
        self.one_hot_mask = self.getOneHotMask()

    def __len__(self):
        return self.strokes.shape[0]

    def __getitem__(self, idx):
        return (self.strokes[idx], self.stroke_lengths[idx], self.stroke_mask[idx],
                self.one_hots[idx], self.one_hot_lengths[idx], self.one_hot_mask[idx])

    @staticmethod
    def _length_mask(n_rows, width, lengths):
        """Return a float64 (n_rows, width) array: 1.0 where column < lengths[row], else 0.0."""
        # Vectorized replacement for a per-row Python loop.
        lengths = np.asarray(lengths).reshape(-1)[:n_rows]
        return (np.arange(width)[None, :] < lengths[:, None]).astype(np.float64)

    def getStrokeMask(self):
        """Mask covering exactly the first stroke_lengths[i] points of each sample.

        Any offsets needed for special RNN training must be handled separately.
        """
        return self._length_mask(self.strokes.shape[0], self.strokes.shape[1], self.stroke_lengths)

    def getOneHotMask(self):
        """Mask covering exactly the first text_lengths[i] characters of each sample.

        Any offsets needed for special RNN training must be handled separately.
        """
        return self._length_mask(self.one_hots.shape[0], self.one_hots.shape[1], self.one_hot_lengths)
    
# Hold out the last 150 samples for evaluation; everything before them is training data.
train_dataset, test_dataset = (
    HandwritingDataset(strokes[part], stroke_lengths[part], one_hots[part], text_lengths[part])
    for part in (slice(None, -150), slice(-150, None))
)


In [None]:
# dl = DataLoader(
#     train_dataset,
#     shuffle=True,
#     batch_size=2,
#     drop_last=True)

# k = iter(dl)
# for j in range(1):
#     print("JJJJJJJJJJJJJJJJJ:", j)
#     s, l, m, oh, ohl, ohm = next(k)
#     l0 = l[0]
#     l1 = l[1]
#     l
#     s.shape
#     m.shape

#     print("Testing the 0th position of every stroke")
#     s[:,0,:]

#     print("testing the length location of stroke")
#     s[0,l0-2:l0+2, :]
#     s[1,l1-2:l1+2, :]

#     print("testing the length location of mask")
#     m[0,l0-2:l0+2]
#     m[1,l1-2:l1+2]
    
    
#     l0 = ohl[0]
#     l1 = ohl[1]
#     ohl
#     oh.shape
#     ohm.shape

#     print("Testing the 0th position of every one hot")
#     oh[:,0,:]

#     print("testing the length location of one hot")
#     oh[0,l0-2:l0+2, :]
#     oh[1,l1-2:l1+2, :]

#     print("testing the length location of mask")
#     ohm[0,l0-2:l0+2]
#     ohm[1,l1-2:l1+2]
    
    
    

In [None]:
# bs = 2
# u = torch.arange(1, 11, device=device)[None, :, None].repeat(bs, 1, 1)
# alpha = torch.rand(bs, 5)[:, None]
# beta = torch.rand(bs, 5)[:, None]
# kappa = torch.randn(bs, 5)[:, None]
# u, u.shape
# alpha, alpha.shape
# beta, beta.shape
# kappa, kappa.shape
# res = alpha * torch.exp(-beta * (kappa - u)**2)
# res.shape
# res.sum(dim=-1).shape

In [None]:
# b = torch.rand(2, 9)
# b[None].shape
# b[:, None].shape
# b[:, :, None].shape
# b.chunk(5, -1)

# a = torch.rand(2, 3, 4)
# wt = torch.randn(2, 3)[:, :, None]
# a 
# wt
# (a * wt).sum(1)

In [None]:
class GaussianAttention(nn.Module):
    """Gaussian-mixture soft attention window over a one-hot character sequence.

    From the hidden state ``h`` a linear layer predicts, for each of the K
    mixture components, an importance ``alpha``, a width ``beta`` and a
    position increment that is added to the previous ``kappa``. The weight
    of text position u is
    ``phi[u] = sum_k alpha_k * exp(-beta_k * (kappa_k - u)^2)``, masked by
    ``one_hot_mask``, and the window vector is ``w = sum_u phi[u] * one_hot[u]``.
    """

    def __init__(self, lstm_hidden_size, n_mixtures, kappa_scale=0.04):
        super(GaussianAttention, self).__init__()
        self.n_mixtures = n_mixtures
        # Multiplier applied to the predicted kappa increment (0.04 was the
        # previously hard-coded value); values < 1 slow the window's movement.
        self.kappa_scale = kappa_scale
        self.linear_layer = nn.Linear(lstm_hidden_size, 3*n_mixtures)

    def forward(self, h, kappa_prev, one_hot_batch, one_hot_mask, eps=1e-6):
        """Compute the attention window for one time step.

        Args:
            h: (B, H) hidden state.
            kappa_prev: (B, 1, K) previous window positions.
            one_hot_batch: (B, T, V) one-hot text.
            one_hot_mask: (B, T) 1 for real characters, 0 for padding.
            eps: added after exp so alpha/beta/kappa increments are strictly positive.

        Returns:
            dict with "w" (B, V), "phi" (B, T, 1), "alpha"/"beta"/"kappa"
            (B, 1, K) and the raw linear output "out_1" (B, 3*K).
        """
        T = one_hot_mask.shape[1]

        out_1 = self.linear_layer(h)  # (B, 3*K)
        alpha, beta, kappa = (torch.exp(out_1) + eps)[:, None].chunk(3, dim=-1)  # (B, 1, K) each
        kappa = kappa * self.kappa_scale + kappa_prev

        # Text positions 0..T-1 as (1, T, 1), created directly on h's device
        # (not the global `device`); broadcasting with (B, 1, K) yields (B, T, K).
        u = torch.arange(T, device=h.device, dtype=h.dtype)[None, :, None]

        phi = alpha * torch.exp(-beta * torch.pow(kappa - u, 2))  # (B, T, K)
        phi = phi.sum(dim=-1)  # (B, T)
        phi = (phi * one_hot_mask)[:, :, None]  # (B, T, 1); padding gets zero weight

        w = (phi * one_hot_batch).sum(1)  # (B, V) V = vocab_size

        attn_params = {
            "w": w,
            "phi": phi,
            "alpha": alpha,
            "beta": beta,
            "kappa": kappa,
            "out_1": out_1
        }

        return attn_params
        

In [None]:
# batch_size = 2
# hidden_size = 4
# input_size = 3
# n_mixtures = 5
# vocab_size = 6
# timesteps = 7

# h = torch.randn(batch_size, hidden_size)
# kappa_prev = torch.randn(batch_size, n_mixtures)[:,None,:]
# ohb = torch.randn(batch_size, timesteps, vocab_size)
# ohm = torch.ones(batch_size, timesteps)
# ohm[0, 5:] = 0
# ohm[1, 3:] = 0
# ohm
# gl = GaussianAttention(hidden_size, n_mixtures)
# gl(h, kappa_prev, ohb, ohm)

In [None]:
from einops import rearrange 

class HandwritingSynthesis(nn.Module):
    """Two-layer LSTM handwriting synthesis network with a Gaussian attention window.

    One ``forward`` call advances the network by a single time step:
    LSTM1 reads the current point and the previous window vector;
    GaussianAttention turns LSTM1's hidden state into a new window over the
    one-hot text; LSTM2 reads the point, the new window and LSTM1's hidden
    state; a linear layer maps LSTM2's hidden state to
    ``6 * n_output_mixtures + 1`` output parameters.
    """

    def __init__(self, input_size, hidden_size, n_output_mixtures, n_attn_mixtures, embedding_size):
        super(HandwritingSynthesis, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.n_output_mixtures = n_output_mixtures
        self.n_attn_mixtures = n_attn_mixtures
        self.embedding_size = embedding_size
        self.output_size = 6*n_output_mixtures + 1

        self.lstm1 = nn.LSTMCell(input_size + embedding_size, hidden_size)
        self.gaussian_attn = GaussianAttention(hidden_size, n_attn_mixtures)
        self.lstm2 = nn.LSTMCell(embedding_size + input_size + hidden_size, hidden_size)
        self.linear = nn.Linear(hidden_size, self.output_size)

    def forward(self, stroke_t, one_hots, one_hot_mask, prev_state):
        """Advance the network by one step.

        Shapes (K: n_attn_mixtures, E: embedding_size, H: hidden_size, T: text length):
            stroke_t: (B, input_size) current point.
            one_hots: (B, T, E) one-hot text; one_hot_mask: (B, T).
            prev_state: dict with "w" (B, E), "kappa" (B, 1, K) and
                "hidden_1"/"hidden_2" as (h, c) tuples of (B, H).

        Returns:
            (out, state): out is (B, 6*n_output_mixtures + 1). state is the
            attention dict produced by GaussianAttention, extended with the new
            "hidden_1"/"hidden_2" tuples, and can be fed back as prev_state.
        """
        attn_prev, kappa_prev, hid_1, hid_2 = prev_state["w"], prev_state["kappa"], prev_state["hidden_1"], prev_state["hidden_2"]

        hid_1 = self.lstm1(
            torch.cat([stroke_t, attn_prev], dim=-1),
            hid_1)

        params_new = self.gaussian_attn(
            hid_1[0],
            kappa_prev,
            one_hots,
            one_hot_mask)

        attn_new = params_new["w"]

        # LSTM2 sees the input point, the fresh window and LSTM1's hidden state.
        hid_2 = self.lstm2(
            torch.cat([stroke_t, attn_new, hid_1[0]], dim=-1),
            hid_2)

        out = self.linear(hid_2[0])

        params_new["hidden_1"] = hid_1
        params_new["hidden_2"] = hid_2

        return out, params_new

In [None]:
# batch_size = 2
# input_size = 3
# hidden_size = 4
# n_output_mixtures = 8
# n_attn_mixtures = 5
# embedding_size = 6
# timestamps = 7

# h = HandwritingSynthesis(input_size, hidden_size, n_output_mixtures, n_attn_mixtures, embedding_size)

# stroke = torch.randn(batch_size, 3)
# stroke[:, 2] = 0
# stroke

# one_hots = torch.zeros(batch_size, timestamps, embedding_size)
# one_hots[0,:5, 0]= 1
# one_hots[1, :3, 3] =1
# print("one_hots", one_hots)

# one_hot_mask = torch.ones(batch_size, timestamps)
# one_hot_mask[0,5:] = 0
# one_hot_mask[1,3:] = 0
# print("one_hot_mask", one_hot_mask)

# attn = {
#     "w": torch.randn(batch_size, embedding_size),
#     "phi": torch.randn(batch_size, timestamps, 1),
#     "alpha": torch.randn(batch_size, 1, n_attn_mixtures),
#     "beta": torch.randn(batch_size, 1, n_attn_mixtures),
#     "kappa": torch.randn(batch_size, 1, n_attn_mixtures)
# }
# prev_state = (attn,
#               (torch.zeros(batch_size, hidden_size), torch.zeros(batch_size, hidden_size)),
#               (torch.zeros(batch_size, hidden_size), torch.zeros(batch_size, hidden_size)))
# print("prev_state", prev_state)


# h(stroke, one_hots, one_hot_mask, prev_state)

In [None]:
def get_mixture_params_from_output(outputs, bias=0):
    """Split raw network outputs into mixture weights, means and covariances.

    Args:
        outputs: (B, L, 6*n_output_mixtures + 1) raw outputs, laid out as
            [weight logits | means | log-sigmas | correlation pre-activations |
            pen-lift logit]. Means and log-sigmas store the (x, y) pair of each
            component next to each other.
        bias: sampling bias. The weight logits are scaled by (1 + bias) and
            bias is subtracted from the log-sigmas; 0 leaves both unchanged.

    Returns:
        pis: (B, L, M) mixture weights (softmax over components).
        mus: (B, L, M, 2) component means.
        covs: (B, L, M, 2, 2) component covariance matrices, on outputs'
            device with outputs' dtype.
    """
    n = n_output_mixtures
    pis = torch.softmax((1 + bias) * outputs[:, :, :n], dim=2)
    mus = rearrange(outputs[:, :, n:3 * n], 'b l (n d) -> b l n d', d=2)
    sigmas = rearrange(torch.exp(outputs[:, :, 3 * n:5 * n] - bias), 'b l (n d) -> b l n d', d=2) + eps
    # (1 - eps) keeps |rho| < 1 even when tanh saturates, so covs stay positive definite.
    rhos = (1 - eps) * torch.tanh(outputs[:, :, 5 * n:6 * n])  # (B, L, M)

    sigma_x, sigma_y = sigmas[..., 0], sigmas[..., 1]
    cov_xy = rhos * sigma_x * sigma_y
    # Assemble [[sx^2, cxy], [cxy, sy^2]] with stack rather than writing into a
    # zeros tensor allocated on the global `device` with the default dtype.
    covs = torch.stack(
        (torch.stack((sigma_x ** 2, cov_xy), dim=-1),
         torch.stack((cov_xy, sigma_y ** 2), dim=-1)),
        dim=-2)

    return pis, mus, covs

def get_mixture_distributions_from_output(outputs):
    """Return (pis, batched bivariate Gaussian) built from raw network outputs (no sampling bias)."""
    pis, mus, covs = get_mixture_params_from_output(outputs)
    return pis, torch.distributions.MultivariateNormal(mus, covariance_matrix=covs)

def get_pen_lift_probs_from_output(outputs):
    """Pen-lift probability e = 1 / (1 + exp(x)) from the last output channel.

    Computed as sigmoid(-x), which is mathematically identical but does not
    overflow: the explicit form yields a NaN gradient once exp(x) becomes inf.

    Args:
        outputs: (B, L, S) raw outputs; the last channel is the pen-lift logit.

    Returns:
        (B, L) probabilities.
    """
    return torch.sigmoid(-outputs[:, :, -1])

def nll(outputs, targets, mask_batch):
    """Masked negative log-likelihood of the mixture-density output, per sequence.

    Args:
        outputs: (B, L, 6*n_output_mixtures + 1) raw network outputs.
        targets: (B, L, 3) next points; channels 0-1 are coordinates, channel 2
            is the pen-lift flag (0/1).
        mask_batch: (B, L), 1 for real points and 0 for padding.

    Returns:
        Scalar tensor: sum of masked per-point losses divided by B.
    """
    outputs = outputs.to(device)
    targets = targets.to(device)
    mask_batch = mask_batch.to(device)

    # (B, L, 1, 2) expanded over the mixture axis -> (B, L, M, 2)
    target_coords = targets[:, :, None, 0:2].expand(-1, -1, n_output_mixtures, -1)
    stroke_lift = targets[:, :, -1]  # (B, L)

    _, distributions = get_mixture_distributions_from_output(outputs)
    # log_softmax equals log(softmax) (the unbiased mixture weights) but cannot
    # underflow to log(0) = -inf.
    log_pis = torch.log_softmax(outputs[:, :, :n_output_mixtures], dim=2)
    log_probs = distributions.log_prob(target_coords)  # (B, L, M)
    loss_coords = -torch.logsumexp(log_pis + log_probs, dim=2)

    # With e = 1 / (1 + exp(x)) = sigmoid(-x): log e = logsigmoid(-x) and
    # log(1 - e) = logsigmoid(x). Evaluating these directly avoids log(0),
    # whose -inf multiplied by a 0 target previously produced NaN.
    lift_logit = outputs[:, :, -1]
    loss_lift = -(stroke_lift * torch.nn.functional.logsigmoid(-lift_logit)
                  + (1 - stroke_lift) * torch.nn.functional.logsigmoid(lift_logit))

    loss_per_point = (loss_coords + loss_lift) * mask_batch
    return loss_per_point.sum() / outputs.shape[0]

In [None]:
def get_initial_prev_states(mode='train'):
    """Build the all-zeros recurrent state consumed by HandwritingSynthesis.forward.

    Args:
        mode: 'train' sizes the batch dimension to ``batch_size``; any other
            value (e.g. 'sample') uses a batch of 1.

    Returns:
        dict with "kappa" (dim, 1, n_attn_mixtures), "w" (dim, embedding_size)
        and "hidden_1"/"hidden_2" as (h, c) tuples of (dim, hidden_size),
        all zeros on ``device``.
    """
    dim = batch_size if mode == 'train' else 1

    def zeros(*shape):
        return torch.zeros(*shape, device=device)

    return {
        "kappa": zeros(dim, 1, n_attn_mixtures),
        "w": zeros(dim, embedding_size),
        "hidden_1": (zeros(dim, hidden_size), zeros(dim, hidden_size)),
        "hidden_2": (zeros(dim, hidden_size), zeros(dim, hidden_size)),
    }

def get_next_point(model, point_prev, text_one_hot, mask, prev_states, bias=0):
    """Advance the model one step and sample the next point.

    Args:
        model: called as ``model(point[None], text_one_hot, mask, prev_states)``
            and expected to return ``(outputs, new_states)``.
        point_prev: (3,) previous point (2 coordinates + pen-lift flag).
        text_one_hot: (1, T, vocab_size) one-hot text.
        mask: (1, T) text mask.
        prev_states: recurrent state dict from the previous step.
        bias: sampling bias forwarded to get_mixture_params_from_output.
            It used to be ignored (0 was always passed).

    Returns:
        (point, new_states): point is a (3,) tensor on ``device`` holding the
        sampled coordinates and the sampled pen-lift bit.
    """
    with torch.no_grad():
        outputs, prev_states = model(point_prev.unsqueeze(0), text_one_hot, mask, prev_states)

        outputs = outputs[:, None]  # add a length-1 time axis: (1, 1, S)

        es = get_pen_lift_probs_from_output(outputs)
        pis, mus, covs = get_mixture_params_from_output(outputs, bias=bias)

        sample_index = 0
        if n_output_mixtures > 1:
            sample_index = np.random.choice(
                range(n_output_mixtures),
                p=pis.squeeze().cpu().numpy())

        pen_off = torch.tensor(
            [np.random.binomial(1, es.item())],
            device=device)

        mean = mus[0, 0, sample_index].cpu().numpy()
        cov = covs[0, 0, sample_index].cpu().numpy()
        sample_point = torch.tensor(
            np.random.multivariate_normal(mean, cov),
            device=device)

        return torch.cat((sample_point, pen_off), dim=0), prev_states
        
def sample(model, text, bias=0, timestamps=500):
    """Generate a handwriting sample for ``text`` one point at a time.

    Args:
        model: network stepped through get_next_point.
        text: string to write; an EOS character is appended if missing.
        bias: sampling bias forwarded to every get_next_point call.
            It was previously ignored, so 0 was always used.
        timestamps: number of rows in the returned stroke buffer (default 500,
            the previously hard-coded value). Row 0 is the all-zero start point,
            rows 1 .. timestamps-2 are sampled, and the last row stays zero.

    Returns:
        (sample_stroke, phis, hid1s, hid2s, alphas, betas, kappas). sample_stroke
        is (timestamps, 3). Each diagnostic is the squeezed per-step value,
        stacked along dim 0 and then transposed, so the generation step is the
        last axis.
    """
    text_one_hot = string_to_one_hot(text)[None]  # (1, T_text, vocab_size)
    mask = torch.ones(1, text_one_hot.shape[1], device=device)
    sample_stroke = torch.zeros(timestamps, 3, device=device)
    prev_states = get_initial_prev_states("sample")

    # How each diagnostic is read from the state dict returned after a step.
    readers = {
        "phis": lambda state: state["phi"],
        "hid1s": lambda state: state["hidden_1"][0],
        "hid2s": lambda state: state["hidden_2"][0],
        "alphas": lambda state: state["alpha"],
        "betas": lambda state: state["beta"],
        "kappas": lambda state: state["kappa"],
    }
    history = {name: [] for name in readers}

    for i in range(timestamps - 2):
        new_point, prev_states = get_next_point(
            model, sample_stroke[i], text_one_hot, mask, prev_states, bias=bias)
        for name, read in readers.items():
            history[name].append(read(prev_states).squeeze()[None])
        sample_stroke[i + 1] = new_point

    stacked = {name: torch.cat(rows, dim=0).T for name, rows in history.items()}
    return (sample_stroke, stacked["phis"], stacked["hid1s"], stacked["hid2s"],
            stacked["alphas"], stacked["betas"], stacked["kappas"])



In [None]:
from einops import rearrange

def get_inputs_targets_mask(strokes_batch, strokes_mask):
    """Shift a (B, L, C) batch for teacher forcing.

    Returns (inputs = points 0..L-2, targets = points 1..L-1, mask aligned with
    the targets).
    """
    return strokes_batch[:, :-1, :], strokes_batch[:, 1:, :], strokes_mask[:, 1:]

def train_batch(model,
                optimizer,
                strokes,
                strokes_mask,
                one_hots,
                one_hot_lengths,
                one_hots_mask,
                prev_states,
                output_holder):
    """Run one teacher-forced forward/backward pass on a batch and step the optimizer.

    Args:
        strokes, strokes_mask: (B, L, 3) points and their (B, L) mask.
        one_hots, one_hots_mask: (B, T, V) one-hot texts and their (B, T) mask.
        one_hot_lengths: unused; kept for call compatibility.
        prev_states: initial recurrent state dict (not mutated; a local name is rebound).
        output_holder: preallocated (L-1, B, S) buffer filled with per-step outputs.

    Returns:
        (loss as a float, final attention kappa squeezed).
    """
    optimizer.zero_grad()

    inputs, targets, mask = get_inputs_targets_mask(strokes, strokes_mask)

    for t in range(inputs.shape[1]):
        output_holder[t], prev_states = model(
            inputs[:, t, :],
            one_hots,
            one_hots_mask,
            prev_states)

    outputs = rearrange(output_holder, 'T B S -> B T S')
    if outputs.requires_grad:
        # Clip the gradient w.r.t. the network outputs to [-100, 100]. The old
        # `clip_grad_value_(loss, 100)` after backward() was a no-op, because a
        # non-leaf loss has no .grad.
        outputs.register_hook(lambda grad: grad.clamp(-100, 100))
    loss = nll(outputs, targets, mask)
    loss.backward()

    # Clip parameter gradients to [-10, 10].
    torch.nn.utils.clip_grad_value_(model.parameters(), 10)

    optimizer.step()

    return loss.item(), prev_states["kappa"].squeeze()

def save_model(model, model_name, directory=None):
    """Save ``{'model_state_dict': model.state_dict()}`` to ``directory/model_name``.

    Args:
        model: module whose state dict is saved.
        model_name: checkpoint file name.
        directory: target directory; defaults to the notebook-level ``base_dir``.
            It is created if it does not exist yet.
    """
    if directory is None:
        directory = base_dir
    os.makedirs(directory, exist_ok=True)
    # os.path.join works whether or not the directory has a trailing separator.
    torch.save({
        'model_state_dict': model.state_dict(),
    }, os.path.join(directory, model_name))
    
def detach(x):
    """Recursively detach tensors from the autograd graph.

    Tuples (e.g. LSTM ``(h, c)`` pairs) and dicts (e.g. the recurrent state
    dict) are rebuilt with every contained tensor detached. Previously, tuple
    elements were returned undetached and dicts holding tuples raised.
    """
    if isinstance(x, tuple):
        return tuple(detach(v) for v in x)
    if isinstance(x, dict):
        return {k: detach(v) for k, v in x.items()}
    return x.detach()
    
def _show_sample_diagnostics(model, ka, text="can you write this please", bias=0.5):
    """Sample ``text`` from ``model``, draw it in data coordinates, and plot attention/state traces."""
    sample_stroke, phis, hid1s, hid2s, alphas, betas, kappas = sample(model, text, bias=bias)
    print(text)
    # Undo the coordinate standardization; row 0 is the fixed start point.
    sample_stroke[1:, 0:2] *= torch.tensor(std, device=device)
    sample_stroke[1:, 0:2] += torch.tensor(means, device=device)
    draw(sample_stroke.cpu(), plot_end_points=False)
    for title, values in (
        ("actual kappa", ka.squeeze().detach()[None]),
        ("phi", phis),
        ("hid_1", hid1s),
        ("hid_2", hid2s),
        ("alphas", alphas[None]),
        ("betas", betas[None]),
        ("kappas", kappas[None]),
    ):
        print(title)
        attention_plot(values)


def train(model, optimizer, train_dataloader, num_epochs):
    """Train ``model`` with teacher forcing for ``num_epochs`` epochs.

    Every ``print_every`` iterations, the mean batch loss since the previous
    report is printed and appended to the returned list, and the model is
    checkpointed. Every ``sample_every`` iterations, a sample is generated and
    plotted.

    Returns:
        list of the reported average losses.
    """
    losses = []
    n_iter = 0
    total_loss = 0
    batches_since_report = 0
    start = time.time()

    print_every = 50
    sample_every = print_every * 5

    prev_states = get_initial_prev_states("train")
    output_holder = torch.zeros(timestamps-1, batch_size, 6*n_output_mixtures+1, device=device).float()
    for epoch in range(num_epochs):

        for strokes, _, strokes_mask, one_hots, one_hot_lengths, one_hots_mask in train_dataloader:
            strokes = strokes.to(device).float()
            strokes_mask = strokes_mask.to(device).float()
            one_hots = one_hots.to(device).float()
            one_hots_mask = one_hots_mask.to(device).float()
            one_hot_lengths = one_hot_lengths.to(device).float()

            batch_loss, ka = train_batch(
                model,
                optimizer,
                strokes,
                strokes_mask,
                one_hots,
                one_hot_lengths,
                one_hots_mask,
                prev_states,
                output_holder)

            # train_batch does not mutate this dict, so every batch starts from
            # the same zero state; detaching guarantees no graph is carried over.
            for k in prev_states:
                prev_states[k] = detach(prev_states[k])

            # The buffer picked up autograd history from the in-place writes.
            output_holder = output_holder.detach()

            total_loss += batch_loss
            batches_since_report += 1

            if n_iter % print_every == 0:
                # Average over the batches actually accumulated (only one at
                # iteration 0, instead of dividing that loss by print_every).
                avg_loss = total_loss / batches_since_report
                losses.append(avg_loss)
                print(f"iteration: {n_iter} "\
                      f"of {len(train_dataloader) * num_epochs}, " \
                      f"avg_loss: {avg_loss:.2f}, "\
                      f"timeSinceStart: {time.time() - start :.2f}, "\
                      f"Epoch: {epoch}")
                total_loss = 0
                batches_since_report = 0
                save_model(model, model_name)

            if n_iter % sample_every == 0:
                _show_sample_diagnostics(model, ka)

            n_iter += 1
    return losses

In [None]:
# Build the model and optionally resume from the last checkpoint of `model_name`.
load_previous_state = True
model = HandwritingSynthesis(input_size, hidden_size, n_output_mixtures, n_attn_mixtures, embedding_size).to(device)
if load_previous_state:
    # map_location lets a checkpoint written on another device (e.g. cuda:1)
    # be loaded on the current one.
    checkpoint = torch.load(os.path.join(base_dir, model_name), map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
optim = torch.optim.Adam(model.parameters(), lr = lr)
# drop_last keeps every batch at exactly batch_size, matching the preallocated
# state and output buffers used during training.
train_dataloader = DataLoader(
    train_dataset,
    shuffle=True,
    batch_size=batch_size,
    drop_last=True)


In [None]:
# Launch training; the returned list of averaged losses is displayed when it finishes.
train(model, optim, train_dataloader, n_epochs)