In [None]:
import tensorflow as tf
import numpy as np
import pathlib
import warnings
import tensorflow_text as tf_text
warnings.filterwarnings('ignore')

2024-03-11 15:41:41.329903: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
2024-03-11 15:41:41.465738: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-03-11 15:41:41.465778: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-03-11 15:41:41.471917: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-03-11 15:41:41.485468: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
2024-03-11 15:41:41.486267: I tensorflow/core/platform/cpu_feature_guard.cc:1

In [None]:
# Download (and cache) the sentence-pair archive. HTTPS is used so the file cannot be
# altered in transit; the cache file name ("spa-eng.zip") is unchanged.
archive_path = tf.keras.utils.get_file("spa-eng.zip", origin='https://storage.googleapis.com/download.tensorflow.org/data/spa-eng.zip', extract=True)
# NOTE(review): assumes get_file returns the path of the downloaded archive and that the
# extracted 'spa-eng' folder sits in the same directory -- confirm for the installed Keras.
path_to_file = pathlib.Path(archive_path).parent/'spa-eng/spa.txt'
path_to_file

In [None]:
def load_data(path):
    """Read a tab-separated parallel corpus into two aligned NumPy string arrays.

    Every non-blank line must hold exactly two tab-separated fields, in the
    order ``target<TAB>context``. Blank / whitespace-only lines are skipped.

    Args:
        path: A ``pathlib.Path`` (or any object exposing ``read_text``) for the corpus.

    Returns:
        ``(context, target)``: two 1-D ``np.ndarray`` of strings with equal length;
        element ``i`` of each array comes from the same line of the file.

    Raises:
        ValueError: If a non-blank line does not contain exactly two fields.
    """
    contexts, targets = [], []
    lines = path.read_text(encoding="utf-8").splitlines()
    for line_no, line in enumerate(lines, start=1):
        if not line.strip():
            # A stray blank line carries no sentence pair; do not let it crash the load.
            continue
        fields = line.split("\t")
        if len(fields) != 2:
            raise ValueError(
                f"{path}: line {line_no} has {len(fields)} tab-separated field(s), expected 2"
            )
        target, context = fields
        targets.append(target)
        contexts.append(context)
    # dtype=str keeps a string dtype even when the file is empty (np.array([]) would be float64).
    return np.array(contexts, dtype=str), np.array(targets, dtype=str)

In [None]:
# Load the whole corpus into two aligned NumPy arrays (context sentences, target sentences).
context_raw, target_raw = load_data(path_to_file)

In [None]:
# Shape of the context array, i.e. how many sentence pairs were loaded.
context_raw.shape

In [None]:
# Peek at the last context sentence in the corpus.
context_raw[-1]

In [None]:
# Peek at the matching last target sentence.
target_raw[-1]

In [None]:
# Shuffle buffer covers the whole corpus; batches hold 64 sentence pairs.
BUFFER_SIZE=len(target_raw)
BATCH_SIZE=64
# Fixed seed so train/validation membership is identical on every Restart & Run All
# (the text vectorizers are adapted on the training split, so it must be stable).
SPLIT_SEED=42
# Boolean mask with one entry per sentence pair; True (~80%) selects the training split.
# A local Generator is used so NumPy's global random state is left untouched.
is_train = np.random.default_rng(SPLIT_SEED).uniform(size=(len(target_raw),)) < 0.8
# Both splits are datasets of (context, target) string pairs, shuffled and batched.
train_raw = tf.data.Dataset.from_tensor_slices( (context_raw[is_train], target_raw[is_train]) ).shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
val_raw = tf.data.Dataset.from_tensor_slices( (context_raw[~is_train], target_raw[~is_train]) ).shuffle(BUFFER_SIZE).batch(BATCH_SIZE)

In [None]:
# Pull a single batch from the training split and print its first five raw string pairs.
# The loop variables stay bound afterwards and are reused below as example inputs.
for (example_context_strings, example_target_strings) in train_raw.take(1):
    print(example_context_strings[:5])
    print(example_target_strings[:5])

In [None]:
def tf_lower_and_split_punct(text):
    """Standardize raw sentences before tokenization.

    Steps, in order: NFKD Unicode normalization, lower-casing, removal of every
    character that is not a space, ``a``-``z`` or one of ``. ? ! , ¿``, padding each
    of those punctuation marks with spaces, trimming outer whitespace, and finally
    wrapping the sentence in ``[START]`` / ``[END]`` markers.

    Args:
        text: String tensor of sentences.

    Returns:
        String tensor of standardized sentences.
    """
    normalized = tf_text.normalize_utf8(text, "NFKD")
    lowered = tf.strings.lower(normalized)
    # Drop everything outside the allowed alphabet.
    allowed_only = tf.strings.regex_replace(lowered, '[^ a-z.?!,¿]', '')
    # Surround punctuation with spaces so each mark is separated from its neighbours.
    spaced = tf.strings.regex_replace(allowed_only, '[.?!,¿]', r' \0 ')
    trimmed = tf.strings.strip(spaced)
    return tf.strings.join(['[START]', trimmed, '[END]'], separator=' ')

In [None]:
# Sanity-check the standardization on one sentence containing accents and '¿'.
example_text = tf.constant('¿Todavía está en casa?')
print(tf_lower_and_split_punct(example_text).numpy().decode('UTF-8'))

In [None]:
# Upper bound on the vocabulary size of each text vectorizer.
max_vocab_size=5000
# Vectorizer for context sentences: applies the custom standardization above and
# returns ragged tensors (ragged=True).
context_text_processor = tf.keras.layers.TextVectorization(max_tokens=max_vocab_size, standardize=tf_lower_and_split_punct, ragged=True)
# Build the vocabulary from the context side of the training split only.
context_text_processor.adapt(train_raw.map(lambda context, target : context))

In [None]:
# Same vectorizer configuration for the target sentences, with its own vocabulary.
target_text_processor = tf.keras.layers.TextVectorization(max_tokens=max_vocab_size, standardize=tf_lower_and_split_punct, ragged=True)
# Build the vocabulary from the target side of the training split only.
target_text_processor.adapt(train_raw.map(lambda context, target : target))

In [None]:
# Vectorize the first five example context strings.
context_text_processor(example_context_strings[:5])

In [None]:
# Vectorize the first five example target strings.
target_text_processor(example_target_strings[:5])

In [None]:
# First ten entries of the context vocabulary.
context_text_processor.get_vocabulary()[:10]

In [None]:
# First ten entries of the target vocabulary.
target_text_processor.get_vocabulary()[:10]

In [None]:
def process_text(context, target):
    """Turn a batch of raw string pairs into model inputs and labels.

    Args:
        context: Batch of raw context sentences.
        target: Batch of raw target sentences.

    Returns:
        ``((context_ids, targ_in), targ_out)`` as dense tensors. ``targ_in`` is the
        target sequence without its last token and ``targ_out`` is the same sequence
        without its first token, so each label is the token that follows the
        corresponding input token.
    """
    context_ids = context_text_processor(context).to_tensor()
    target_ids = target_text_processor(target)
    # Shift by one position: inputs drop the final token, labels drop the first.
    decoder_inputs = target_ids[:, :-1].to_tensor()
    decoder_labels = target_ids[:, 1:].to_tensor()
    return (context_ids, decoder_inputs), decoder_labels

In [None]:
# Tokenize both splits batch-by-batch; AUTOTUNE lets tf.data choose the map parallelism.
train_ds = train_raw.map(process_text, tf.data.AUTOTUNE)
val_ds = val_raw.map(process_text, tf.data.AUTOTUNE)

In [None]:
# Take one tokenized batch and print the first example of each tensor.
# ex_context_tok / ex_tar_in / ex_tar_out stay bound and are reused by later cells.
for (ex_context_tok, ex_tar_in), ex_tar_out in train_ds.take(1):
    print(ex_context_tok[0])
    print(ex_tar_in[0])
    print(ex_tar_out[0])

In [None]:
# Shape of the example batch of context token ids.
ex_context_tok.shape

In [None]:
# Shape of the example decoder-input batch.
ex_tar_in.shape

In [None]:
# Shape of the example label batch.
ex_tar_out.shape

In [None]:
# Width passed to the encoder, the attention layer, the demo embedding and the decoder below.
UNITS=256

In [None]:
class Encoder(tf.keras.layers.Layer):
    """Embeds context token ids and runs them through a bidirectional GRU.

    Args:
        text_processor: Vectorizer whose ``vocabulary_size()`` sizes the embedding table.
        units: Embedding dimension and GRU width.
    """

    def __init__(self, text_processor, units):
        super().__init__()
        self.text_processor = text_processor
        self.vocab_size = text_processor.vocabulary_size()
        self.units = units

        # Token id -> vector lookup; mask_zero=True makes id 0 a masked position.
        self.embedding = tf.keras.layers.Embedding(
            input_dim=self.vocab_size,
            output_dim=self.units,
            mask_zero=True,
        )

        # One GRU is wrapped so it runs in both directions; the two outputs are
        # summed (merge_mode='sum') rather than concatenated.
        gru = tf.keras.layers.GRU(
            self.units,
            return_sequences=True,
            recurrent_initializer='glorot_uniform',
        )
        self.rnn = tf.keras.layers.Bidirectional(layer=gru, merge_mode='sum')

    def call(self, context):
        """Return the per-position encoding for a batch of context token ids."""
        embedded = self.embedding(context)
        return self.rnn(embedded)

In [None]:
# Build the encoder on top of the context vocabulary.
encoder = Encoder(context_text_processor, UNITS)

In [None]:
# Encode the example batch of context token ids; reused below as the attention context.
ex_context=encoder(ex_context_tok)

In [None]:
# Shape of the encoder output for the example batch.
ex_context.shape

In [None]:
class CrossAttention(tf.keras.layers.Layer):
    """Single-head attention from a query sequence onto a context sequence.

    The attention output is added back to the query (residual connection) and the
    sum is layer-normalized. The attention scores of the most recent call are kept
    in ``last_attention_weights`` for inspection.

    Args:
        units: ``key_dim`` of the underlying ``MultiHeadAttention`` layer.
    """

    def __init__(self, units):
        super().__init__()
        self.mha = tf.keras.layers.MultiHeadAttention(num_heads=1, key_dim=units)
        self.layernorm = tf.keras.layers.LayerNormalization()
        self.add = tf.keras.layers.Add()

    def call(self, x, context):
        """Attend from ``x`` (query) to ``context`` (value) and return the normalized sum."""
        attended, scores = self.mha(query=x, value=context, return_attention_scores=True)
        # Average the scores over axis 1 (presumably the heads axis -- confirm against
        # the Keras MultiHeadAttention docs) and cache them on the layer.
        self.last_attention_weights = tf.reduce_mean(scores, axis=1)
        residual = self.add([x, attended])
        return self.layernorm(residual)

In [None]:
# Stand-alone attention layer used only to try the layer out in the next cells.
attention_layer = CrossAttention(UNITS)


In [None]:
# Throw-away embedding over the target vocabulary, used to produce query vectors
# for the attention demo.
embed = tf.keras.layers.Embedding(input_dim=target_text_processor.vocabulary_size(), output_dim=UNITS)
ex_embed = embed(ex_tar_in)
ex_embed.shape

In [None]:
# Attend from the embedded target tokens onto the encoded context and show the output shape.
result=attention_layer(ex_embed, ex_context)
result.shape

In [None]:
# Sum the cached attention weights of the first example over the last axis.
# NOTE(review): each sum is presumably ~1 if the scores are softmax-normalized -- confirm.
attention_layer.last_attention_weights[0].numpy().sum(axis=-1)

In [None]:
class Decoder(tf.keras.layers.Layer):
    """Decoder layer whose methods are attached afterwards via ``add_method``."""

    @classmethod
    def add_method(cls, fun):
        """Decorator that installs ``fun`` on the class under its own ``__name__``.

        Returns:
            ``fun`` itself, so the decorated module-level name still refers to the function.
        """
        method_name = fun.__name__
        setattr(cls, method_name, fun)
        return fun

In [None]:
@Decoder.add_method
def __init__(self, text_processor, units):
    """Create the decoder's lookup tables and sub-layers.

    Args:
        text_processor: Target-side vectorizer; supplies the vocabulary and its size.
        units: Embedding dimension, GRU width and attention ``key_dim``.
    """
    super(Decoder, self).__init__()
    self.text_processor = text_processor
    self.units = units
    self.vocab_size = text_processor.vocabulary_size()

    # Two lookups over the same vocabulary: word -> id and, with invert=True, id -> word.
    vocabulary = text_processor.get_vocabulary()
    lookup_options = dict(vocabulary=vocabulary, mask_token='', oov_token='[UNK]')
    self.word_to_id = tf.keras.layers.StringLookup(**lookup_options)
    self.id_to_word = tf.keras.layers.StringLookup(invert=True, **lookup_options)

    # Ids of the sentence markers added during standardization.
    self.start_token = self.word_to_id('[START]')
    self.end_token = self.word_to_id('[END]')

    self.embedding = tf.keras.layers.Embedding(
        input_dim=self.vocab_size,
        output_dim=self.units,
        mask_zero=True,
    )
    # return_state=True so generation can carry the GRU state from step to step.
    self.rnn = tf.keras.layers.GRU(
        units,
        return_sequences=True,
        return_state=True,
        recurrent_initializer='glorot_uniform',
    )
    self.attention = CrossAttention(self.units)
    # Maps each position to one logit per vocabulary entry.
    self.output_layer = tf.keras.layers.Dense(self.vocab_size)

In [None]:
@Decoder.add_method
def call(self, context, x, state=None, return_state=False):
    """Compute next-token logits for the target tokens ``x`` given the encoded ``context``.

    Args:
        context: Encoder output the attention layer attends to.
        x: Batch of target token ids.
        state: Optional initial GRU state; ``None`` lets the GRU use its default.
        return_state: When true, also return the GRU's final state.

    Returns:
        ``logits``, or ``(logits, state)`` when ``return_state`` is true.
    """
    embedded = self.embedding(x)
    rnn_output, state = self.rnn(embedded, initial_state=state)
    attended = self.attention(rnn_output, context)
    logits = self.output_layer(attended)
    return (logits, state) if return_state else logits

In [None]:
# Build the decoder on top of the target vocabulary.
decoder = Decoder(target_text_processor, UNITS)

In [None]:
# Run the decoder on the example batch and show the shape of the logits.
decoder(ex_context,ex_tar_in).shape

In [None]:
@Decoder.add_method
def get_initial_state(self, context):
    """Build the starting point of a generation loop for a batch of contexts.

    Args:
        context: Encoded context; only the size of its first axis (the batch size) is used.

    Returns:
        ``(start_tokens, done, state)``: a ``[batch, 1]`` tensor filled with the
        ``[START]`` id, a ``[batch, 1]`` all-False boolean tensor, and the first
        element of the GRU's initial state.
    """
    batch_size = tf.shape(context)[0]
    column_shape = [batch_size, 1]
    start_tokens = tf.fill(column_shape, self.start_token)
    done = tf.zeros(column_shape, dtype=tf.bool)
    # The GRU derives its initial state from the embedded start tokens.
    initial_state = self.rnn.get_initial_state(self.embedding(start_tokens))[0]
    return start_tokens, done, initial_state

In [None]:
@Decoder.add_method
def tokens_to_text(self, tokens):
    """Convert token ids back into plain-text sentences.

    Args:
        tokens: Token ids understood by ``self.id_to_word``; the last axis is the
            sequence axis that gets joined.

    Returns:
        A string tensor with the last axis reduced away: the words joined by single
        spaces, with a leading ``[START]`` and a trailing ``[END]`` marker (plus
        adjacent spaces) removed.
    """
    words = self.id_to_word(tokens)
    result = tf.strings.reduce_join(words, axis=-1, separator=' ')
    # Raw strings: `\[` is a regex escape, not a Python one. In a normal literal it is
    # an invalid escape sequence and triggers a warning; the pattern itself is unchanged.
    result = tf.strings.regex_replace(result, r'^ *\[START\] *', '')
    result = tf.strings.regex_replace(result, r' *\[END\] *$', '')
    return result

In [None]:
@Decoder.add_method
def get_next_token(self, context, next_token, done, state, temperature=0.0):
    """Run one generation step and pick the next token.

    Args:
        context: Encoded context passed through to the decoder.
        next_token: Token ids fed to the decoder for this step.
        done: Boolean tensor marking sequences that already produced ``[END]``.
        state: GRU state carried over from the previous step.
        temperature: ``0.0`` selects the arg-max token; any other value samples from
            the last position's logits divided by ``temperature``.

    Returns:
        ``(next_token, done, state)`` for the following step. Sequences that are done
        (including those that emitted ``[END]`` in this step) get token id 0.
    """
    logits, state = self(context, next_token, state, return_state=True)

    if temperature != 0.0:
        # Sampling: only the last position's logits are used.
        scaled_logits = logits[:,-1,:]/temperature
        chosen = tf.random.categorical(scaled_logits, num_samples=1)
    else:
        # Greedy: most likely token at every position of the logits.
        chosen = tf.argmax(logits, axis=-1)

    finished = done | (chosen==self.end_token)
    # Finished sequences are forced to id 0 from here on.
    chosen = tf.where(finished, tf.constant(0, dtype=tf.int64), chosen)
    return chosen, finished, state
        

In [None]:
# Smoke test of the generation loop with the decoder built above.
# Start every sequence in the example batch from the [START] token.
next_token, done, state = decoder.get_initial_state(ex_context)
tokens=[]

# Generate ten steps (default temperature=0.0), feeding each chosen token back in.
for n in range(10):
    next_token, done, state = decoder.get_next_token(ex_context, next_token, done, state)
    tokens.append(next_token)

# Join the per-step tokens along the last axis and decode the first three sequences to text.
tokens = tf.concat(tokens, axis=-1)
result=decoder.tokens_to_text(tokens)
result[:3]