In [1]:
import tensorflow as tf
from tensorflow import keras
import numpy as np
import datetime
import tensorflow_addons as tfa
PAD_ID = 0

class DateData:
    """Synthetic date-format translation corpus.

    Every sample is one random date rendered twice: as ``"%y-%m-%d"``
    (stored in ``date_cn``) and as ``"%d/%b/%Y"`` (stored in ``date_en``).
    ``x`` holds the first form encoded one id per character; ``y`` holds the
    second form wrapped in ``<GO>`` ... ``<EOS>``, with the month abbreviation
    encoded as a single token.
    """

    def __init__(self, n):
        """Generate ``n`` date pairs and build the vocabulary and id arrays.

        The global NumPy RNG is re-seeded with 1, so the drawn timestamps are
        the same on every construction. Dates are rendered through
        ``datetime.fromtimestamp``, i.e. in the machine's local time zone.
        """
        np.random.seed(1)
        stamps = np.random.randint(143835585, 2043835585, n)
        moments = [datetime.datetime.fromtimestamp(stamp) for stamp in stamps]
        self.date_cn = [moment.strftime("%y-%m-%d") for moment in moments]
        self.date_en = [moment.strftime("%d/%b/%Y") for moment in moments]

        # Vocabulary: digits, separators, control tokens, plus every month
        # abbreviation that actually occurs in the generated dates.
        tokens = [str(digit) for digit in range(0, 10)]
        tokens.extend(["-", "/", "<GO>", "<EOS>"])
        tokens.extend(en.split("/")[1] for en in self.date_en)
        self.vocab = set(tokens)

        # Ids start at 1 in sorted token order; <PAD> is registered afterwards
        # under PAD_ID so it does not shift the other ids.
        self.v2i = {}
        for token_id, token in enumerate(sorted(list(self.vocab)), start=1):
            self.v2i[token] = token_id
        self.v2i["<PAD>"] = PAD_ID
        self.vocab.add("<PAD>")
        self.i2v = {token_id: token for token, token_id in self.v2i.items()}

        go_id = self.v2i["<GO>"]
        eos_id = self.v2i["<EOS>"]
        encoded_x, encoded_y = [], []
        for cn, en in zip(self.date_cn, self.date_en):
            encoded_x.append([self.v2i[ch] for ch in cn])
            # en is "dd/Mon/yyyy": "dd/" and "/yyyy" go char by char,
            # the three-letter month is looked up as one token.
            head, month, tail = en[:3], en[3:6], en[6:]
            target = [go_id]
            target.extend(self.v2i[ch] for ch in head)
            target.append(self.v2i[month])
            target.extend(self.v2i[ch] for ch in tail)
            target.append(eos_id)
            encoded_y.append(target)
        self.x, self.y = np.array(encoded_x), np.array(encoded_y)
        self.start_token = go_id
        self.end_token = eos_id

    def sample(self, n=64):
        """Return a random batch ``(bx, by, decoder_len)`` of ``n`` rows.

        Rows are drawn with ``np.random.randint`` (so repeats are possible).
        ``decoder_len`` is an int32 vector filled with ``by.shape[1] - 1``.
        """
        picks = np.random.randint(0, len(self.x), size=n)
        bx = self.x[picks]
        by = self.y[picks]
        steps = by.shape[1] - 1
        decoder_len = np.full((len(bx),), steps, dtype=np.int32)
        return bx, by, decoder_len

    def idx2str(self, idx):
        """Decode a sequence of ids to text, stopping right after ``<EOS>``.

        The ``<EOS>`` token itself is included in the returned string.
        """
        text = ""
        for token_id in idx:
            text += self.i2v[token_id]
            if token_id == self.end_token:
                return text
        return text

    @property
    def num_word(self):
        """Vocabulary size, including ``<PAD>``."""
        return len(self.vocab)

In [2]:
import tensorflow as tf
from tensorflow import keras
import numpy as np
import tensorflow_addons as tfa
import pickle


class Seq2Seq(keras.Model):
    """LSTM encoder-decoder with Luong attention built on ``tfa.seq2seq``.

    A single attention-wrapped LSTM cell and a single output ``Dense`` layer
    are shared by two ``BasicDecoder`` instances: ``decoder_train`` uses a
    ``TrainingSampler`` (decoder inputs come from the ground truth) and
    ``decoder_eval`` uses a ``GreedyEmbeddingSampler`` (decoder inputs come
    from the previous prediction).
    """

    def __init__(self, enc_v_dim, dec_v_dim, emb_dim, units, attention_layer_size, max_pred_len, start_token, end_token):
        """Create the layers, the two decoders, the loss and the optimizer.

        Args:
            enc_v_dim: ``input_dim`` of the encoder embedding.
            dec_v_dim: ``input_dim`` of the decoder embedding and number of
                units of the output ``Dense`` layer.
            emb_dim: ``output_dim`` of both embeddings.
            units: size of the encoder LSTM, the decoder LSTM cell and the
                Luong attention.
            attention_layer_size: forwarded to ``AttentionWrapper``.
            max_pred_len: number of decoding steps run by ``inference``.
            start_token: id fed to the decoder at the first inference step.
            end_token: id handed to the greedy sampler as its end token.
        """
        super().__init__()
        self.units = units

        # encoder
        self.enc_embeddings = keras.layers.Embedding(
            input_dim=enc_v_dim, output_dim=emb_dim,    # [enc_n_vocab, emb_dim]
            embeddings_initializer=tf.initializers.RandomNormal(0., 0.1),
        )
        self.encoder = keras.layers.LSTM(units=units, return_sequences=True, return_state=True)

        # decoder
        # The attention memory is left empty here; set_attention() fills it per batch.
        self.attention = tfa.seq2seq.LuongAttention(units, memory=None, memory_sequence_length=None)
        self.decoder_cell = tfa.seq2seq.AttentionWrapper(
            cell=keras.layers.LSTMCell(units=units),
            attention_mechanism=self.attention,
            attention_layer_size=attention_layer_size,
            alignment_history=True,                     # for attention visualization
        )

        self.dec_embeddings = keras.layers.Embedding(
            input_dim=dec_v_dim, output_dim=emb_dim,    # [dec_n_vocab, emb_dim]
            embeddings_initializer=tf.initializers.RandomNormal(0., 0.1),
        )
        decoder_dense = keras.layers.Dense(dec_v_dim)   # output layer

        # train decoder
        self.decoder_train = tfa.seq2seq.BasicDecoder(
            cell=self.decoder_cell,
            sampler=tfa.seq2seq.sampler.TrainingSampler(),   # sampler for train
            output_layer=decoder_dense
        )
        self.cross_entropy = keras.losses.SparseCategoricalCrossentropy(from_logits=True)
        self.opt = keras.optimizers.Adam(0.05, clipnorm=5.0)

        # predict decoder (same cell and output layer as the train decoder)
        self.decoder_eval = tfa.seq2seq.BasicDecoder(
            cell=self.decoder_cell,
            sampler=tfa.seq2seq.sampler.GreedyEmbeddingSampler(),       # sampler for predict
            output_layer=decoder_dense
        )

        # prediction restriction
        self.max_pred_len = max_pred_len
        self.start_token = start_token
        self.end_token = end_token

    def encode(self, x):
        """Embed ``x`` and run the encoder LSTM from an all-zero initial state.

        ``x.shape[0]`` is used as the batch size.

        Returns:
            ``(o, h, c)`` as produced by the LSTM configured with
            ``return_sequences=True, return_state=True``: the per-step outputs
            followed by the two final state tensors.
        """
        o = self.enc_embeddings(x)
        batch_size = x.shape[0]
        init_s = [tf.zeros((batch_size, self.units)), tf.zeros((batch_size, self.units))]
        o, h, c = self.encoder(o, initial_state=init_s)
        return o, h, c

    def set_attention(self, x):
        """Encode ``x``, point the attention at the result, and build the decoder start state.

        Side effect: replaces the memory of ``self.attention``, so it has to be
        called for each new batch before decoding.

        Returns:
            The attention wrapper's initial state with its ``cell_state``
            replaced by the encoder's final ``[h, c]``.
        """
        o, h, c = self.encode(x)
        # encoder output for attention to focus
        self.attention.setup_memory(o)
        # wrap state by attention wrapper
        s = self.decoder_cell.get_initial_state(batch_size=x.shape[0], dtype=tf.float32).clone(cell_state=[h, c])
        return s

    def inference(self, x, return_align=False):
        """Greedy-decode ``x`` for exactly ``max_pred_len`` steps.

        The loop does not stop early when the decoder reports ``done``; callers
        are expected to cut the result at the end token themselves.

        Args:
            x: batch of encoder input ids; ``x.shape[0]`` is the batch size.
            return_align: when true, return the attention alignments instead
                of the predicted ids.

        Returns:
            An int32 NumPy array of shape ``(x.shape[0], max_pred_len)`` with
            the sampled ids, or, if ``return_align`` is set, the stacked
            alignment history with its first two axes swapped (presumably
            ``[batch, step, source position]`` - TODO confirm).
        """
        s = self.set_attention(x)
        # dec_embeddings.variables[0] is the embedding matrix; it only exists
        # once the layer has been built (e.g. after a training step).
        done, i, s = self.decoder_eval.initialize(
            self.dec_embeddings.variables[0],
            start_tokens=tf.fill([x.shape[0], ], self.start_token),
            end_token=self.end_token,
            initial_state=s,
        )
        pred_id = np.zeros((x.shape[0], self.max_pred_len), dtype=np.int32)
        for t in range(self.max_pred_len):
            o, s, i, done = self.decoder_eval.step(
                time=t, inputs=i, state=s, training=False)
            pred_id[:, t] = o.sample_id
        if return_align:
            return np.transpose(s.alignment_history.stack().numpy(), (1, 0, 2))
        else:
            s.alignment_history.mark_used()  # otherwise gives warning
            return pred_id

    def train_logits(self, x, y, seq_len):
        """Run the teacher-forced decoder and return its logits.

        Args:
            x: batch of encoder input ids.
            y: batch of target ids; the last column is dropped to form the
                decoder input.
            seq_len: per-sample decoder lengths, passed as ``sequence_length``.

        Returns:
            ``rnn_output`` of the training decoder (the output layer's logits).
        """
        s = self.set_attention(x)
        dec_in = y[:, :-1]   # ignore <EOS>
        dec_emb_in = self.dec_embeddings(dec_in)
        o, _, _ = self.decoder_train(dec_emb_in, s, sequence_length=seq_len)
        logits = o.rnn_output
        return logits

    def step(self, x, y, seq_len):
        """Perform one optimisation step on a batch and return the loss value.

        The targets are ``y`` without its first column, compared against the
        logits from ``train_logits`` with sparse categorical cross-entropy.

        Returns:
            The batch loss converted with ``.numpy()``.
        """
        with tf.GradientTape() as tape:
            logits = self.train_logits(x, y, seq_len)
            dec_out = y[:, 1:]  # ignore <GO>
            loss = self.cross_entropy(dec_out, logits)
        # Only first-order gradients are needed, so differentiate after the
        # tape context has been left rather than inside it.
        variables = self.trainable_variables
        grads = tape.gradient(loss, variables)
        self.opt.apply_gradients(zip(grads, variables))
        return loss.numpy()


In [3]:
# Build the synthetic corpus and show a few raw / encoded samples.
data = DateData(2000)
cn_preview, en_preview = data.date_cn[:3], data.date_en[:3]
print("Chinese time order: yy/mm/dd ", cn_preview, "\nEnglish time order: dd/M/yyyy ", en_preview)
print("vocabularies: ", data.vocab)
first_x, first_y = data.x[0], data.y[0]
x_report = "x index sample: \n{}\n{}".format(data.idx2str(first_x), first_x)
y_report = "\ny index sample: \n{}\n{}".format(data.idx2str(first_y), first_y)
print(x_report, y_report)

# Encoder and decoder share one vocabulary, hence the same size twice.
model = Seq2Seq(
    data.num_word,
    data.num_word,
    emb_dim=12,
    units=14,
    attention_layer_size=16,
    max_pred_len=11,
    start_token=data.start_token,
    end_token=data.end_token,
)

# training
for t in range(1000):
    bx, by, decoder_len = data.sample(64)
    loss = model.step(bx, by, decoder_len)
    if t % 70 != 0:
        continue
    # Every 70th step: decode the first sample of the batch and print it
    # next to its input and its target (target shown without <GO>/<EOS>).
    target = data.idx2str(by[0, 1:-1])
    pred = model.inference(bx[0:1])
    res = data.idx2str(pred[0])
    src = data.idx2str(bx[0])
    loss_text = "| loss: %.5f" % loss
    print("t: ", t, loss_text, "| input: ", src, "| target: ", target, "| inference: ", res)



Chinese time order: yy/mm/dd  ['31-04-26', '04-07-18', '33-06-06'] 
English time order: dd/M/yyyy  ['26/Apr/2031', '18/Jul/2004', '06/Jun/2033']
vocabularies:  {'7', '2', 'Apr', '1', 'Sep', 'Mar', '4', '<GO>', '9', 'Feb', 'Oct', 'Nov', 'Jun', '<PAD>', 'Jul', 'Dec', '<EOS>', '8', '3', 'May', '-', '5', 'Jan', '0', '6', 'Aug', '/'}
x index sample: 
31-04-26
[6 4 1 3 7 1 5 9] 
y index sample: 
<GO>26/Apr/2031<EOS>
[14  5  9  2 15  2  5  3  6  4 13]
t:  0 | loss: 3.29482 | input:  89-05-25 | target:  25/May/1989 | inference:  22222000000
t:  70 | loss: 0.50147 | input:  03-09-13 | target:  13/Sep/2003 | inference:  13/Sep/2000<EOS>
t:  140 | loss: 0.11515 | input:  92-06-01 | target:  01/Jun/1992 | inference:  01/Jan/1992<EOS>
t:  210 | loss: 0.00174 | input:  23-01-28 | target:  28/Jan/2023 | inference:  28/Jan/2023<EOS>
t:  280 | loss: 0.00048 | input:  25-08-01 | target:  01/Aug/2025 | inference:  01/Aug/2025<EOS>
t:  350 | loss: 0.00032 | input:  75-04-21 | target:  21/Apr/1975 | infere