In [1]:
import tensorflow as tf
from transformers import AutoTokenizer, TFT5ForConditionalGeneration, DataCollatorForSeq2Seq
from datasets import load_dataset
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.optimizers.schedules import ExponentialDecay

In [2]:
gpus = tf.config.list_physical_devices("GPU")
if gpus:
    for gpu in gpus:
        tf.config.experimental.set_memory_growth(gpu, True)

In [3]:
model_name = "google/flan-t5-small"

In [31]:
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = TFT5ForConditionalGeneration.from_pretrained(model_name)

All PyTorch model weights were used when initializing TFT5ForConditionalGeneration.

All the weights of TFT5ForConditionalGeneration were initialized from the PyTorch model.
If your task is similar to the task the model of the checkpoint was trained on, you can already use TFT5ForConditionalGeneration for predictions without further training.


In [5]:
dataset = load_dataset("gigaword")

In [6]:
dataset

DatasetDict({
    train: Dataset({
        features: ['document', 'summary'],
        num_rows: 3803957
    })
    validation: Dataset({
        features: ['document', 'summary'],
        num_rows: 189651
    })
    test: Dataset({
        features: ['document', 'summary'],
        num_rows: 1951
    })
})

In [7]:
train_d = []
train_s = []
for data in dataset["train"]:
    train_d.append(data["document"])
    train_s.append(data["summary"])

In [8]:
import numpy as np

In [9]:
train_d_length = [len(text.split()) for text in train_d]
train_s_length = [len(text.split()) for text in train_s]

In [10]:
np.percentile(train_d_length, 95)

45.0

In [11]:
np.percentile(train_s_length, 95)

13.0

In [12]:
def preprocess_function(examples):
    inputs = tokenizer(examples["document"], max_length=48, truncation=True, padding="max_length")
    outputs = tokenizer(examples["summary"], max_length=16, truncation=True, padding="max_length")
    return {
        "input_ids": inputs["input_ids"],
        "attention_mask": inputs["attention_mask"],
        "labels": outputs["input_ids"]
    }

In [13]:
tokenized_dataset = dataset.map(preprocess_function, batched=True, remove_columns=dataset["train"].column_names)

In [14]:
data_collator = DataCollatorForSeq2Seq(model=model, tokenizer=tokenizer, return_tensors="tf")

In [15]:
data_collator

DataCollatorForSeq2Seq(tokenizer=T5TokenizerFast(name_or_path='google/flan-t5-small', vocab_size=32100, model_max_length=512, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'eos_token': '</s>', 'unk_token': '<unk>', 'pad_token': '<pad>', 'additional_special_tokens': ['<extra_id_0>', '<extra_id_1>', '<extra_id_2>', '<extra_id_3>', '<extra_id_4>', '<extra_id_5>', '<extra_id_6>', '<extra_id_7>', '<extra_id_8>', '<extra_id_9>', '<extra_id_10>', '<extra_id_11>', '<extra_id_12>', '<extra_id_13>', '<extra_id_14>', '<extra_id_15>', '<extra_id_16>', '<extra_id_17>', '<extra_id_18>', '<extra_id_19>', '<extra_id_20>', '<extra_id_21>', '<extra_id_22>', '<extra_id_23>', '<extra_id_24>', '<extra_id_25>', '<extra_id_26>', '<extra_id_27>', '<extra_id_28>', '<extra_id_29>', '<extra_id_30>', '<extra_id_31>', '<extra_id_32>', '<extra_id_33>', '<extra_id_34>', '<extra_id_35>', '<extra_id_36>', '<extra_id_37>', '<extra_id_38>', '<extra_id_39>', '<extra_id_40>', '<extra_id_41>'

In [32]:
train_dataset = tokenized_dataset["train"].to_tf_dataset(shuffle=True,
                                                         batch_size=128,
                                                         collate_fn=data_collator)

val_dataset = tokenized_dataset["validation"].to_tf_dataset(shuffle=True,
                                                            batch_size=128,
                                                            collate_fn=data_collator)

test_dataset = tokenized_dataset["test"].to_tf_dataset(shuffle=True,
                                                       batch_size=128,
                                                       collate_fn=data_collator)

In [33]:
train_dataset, val_dataset, test_dataset

(<PrefetchDataset element_spec={'input_ids': TensorSpec(shape=(None, 48), dtype=tf.int64, name=None), 'attention_mask': TensorSpec(shape=(None, 48), dtype=tf.int64, name=None), 'labels': TensorSpec(shape=(None, 16), dtype=tf.int64, name=None), 'decoder_input_ids': TensorSpec(shape=(None, 16), dtype=tf.int64, name=None)}>,
 <PrefetchDataset element_spec={'input_ids': TensorSpec(shape=(None, 48), dtype=tf.int64, name=None), 'attention_mask': TensorSpec(shape=(None, 48), dtype=tf.int64, name=None), 'labels': TensorSpec(shape=(None, 16), dtype=tf.int64, name=None), 'decoder_input_ids': TensorSpec(shape=(None, 16), dtype=tf.int64, name=None)}>,
 <PrefetchDataset element_spec={'input_ids': TensorSpec(shape=(None, 48), dtype=tf.int64, name=None), 'attention_mask': TensorSpec(shape=(None, 48), dtype=tf.int64, name=None), 'labels': TensorSpec(shape=(None, 16), dtype=tf.int64, name=None), 'decoder_input_ids': TensorSpec(shape=(None, 16), dtype=tf.int64, name=None)}>)

In [34]:
len(train_dataset)

29719

# Unfreeze the last layer

In [19]:
initial_learning_rate = 5e-5
decay_steps = 10000
decay_rate = 0.95

lr_schedule = ExponentialDecay(
    initial_learning_rate,
    decay_steps=decay_steps,
    decay_rate=decay_rate,
    staircase=True
)

optimizer = Adam(learning_rate=lr_schedule)

In [20]:
loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

In [35]:
for layer in model.layers[:-1]:
    layer.trainable = False

In [36]:
for layer in model.layers:
    print(layer.name, layer.trainable)

shared False
encoder False
decoder False
lm_head True


In [37]:
model.summary()

Model: "tft5_for_conditional_generation_1"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 shared (Embedding)          multiple                  16449536  
                                                                 
 encoder (TFT5MainLayer)     multiple                  35332800  
                                                                 
 decoder (TFT5MainLayer)     multiple                  41628352  
                                                                 
 lm_head (Dense)             multiple                  16449536  
                                                                 
Total params: 76,961,152
Trainable params: 16,449,536
Non-trainable params: 60,511,616
_________________________________________________________________


In [24]:
import os
import datetime

In [25]:
def tensorboard_cb(dirpath, model_name):
    return tf.keras.callbacks.TensorBoard(os.path.join(dirpath, 
                                                       model_name, 
                                                       datetime.datetime.now().strftime("%Y%m%d-%H%M%S")))

In [38]:
model.compile(optimizer=optimizer, loss=loss)

In [39]:
history = model.fit(train_dataset,
                    validation_data=val_dataset,
                    epochs=3,
                    callbacks=[tensorboard_cb("model_logs/", "model_flan_t5_small_ggw_v1")])

Epoch 1/3
Epoch 2/3
Epoch 3/3


In [40]:
model.save_pretrained("model_flan_t5_small_ggw_v1")

In [41]:
tokenizer.save_pretrained("model_flan_t5_small_ggw_v1")

('model_flan_t5_small_ggw_v1\\tokenizer_config.json',
 'model_flan_t5_small_ggw_v1\\special_tokens_map.json',
 'model_flan_t5_small_ggw_v1\\spiece.model',
 'model_flan_t5_small_ggw_v1\\added_tokens.json',
 'model_flan_t5_small_ggw_v1\\tokenizer.json')

In [42]:
loaded_model = TFT5ForConditionalGeneration.from_pretrained("model_flan_t5_small_ggw_v1")
loaded_tokenizer = AutoTokenizer.from_pretrained("model_flan_t5_small_ggw_v1")

All model checkpoint layers were used when initializing TFT5ForConditionalGeneration.

All the layers of TFT5ForConditionalGeneration were initialized from the model checkpoint at model_flan_t5_small_ggw_v1.
If your task is similar to the task the model of the checkpoint was trained on, you can already use TFT5ForConditionalGeneration for predictions without further training.


In [45]:
model.evaluate(test_dataset)



2.308198928833008

In [46]:
loaded_model.compile(loss=loss)

In [47]:
loaded_model.evaluate(test_dataset)



2.308198928833008

# Testing fine tuned model

In [101]:
testing_statement = "Advancements in artificial intelligence are reshaping industries, with applications ranging from healthcare to finance. The integration of AI technologies is transforming how businesses operate and make decisions, leading to increased efficiency and innovation."

In [109]:
input = loaded_tokenizer(testing_statement, return_tensors="tf", max_length=48, truncation=True)
input

{'input_ids': <tf.Tensor: shape=(1, 45), dtype=int32, numpy=
array([[18377,  4128,    16,  7353,  6123,    33,     3,    60,     7,
         9516,    53,  5238,     6,    28,  1564,     3,  6836,    45,
         4640,    12,  4747,     5,    37,  5660,    13,  7833,  2896,
           19,     3, 21139,   149,  1623,  4368,    11,   143,  3055,
            6,  1374,    12,  1936,  3949,    11,  4337,     5,     1]])>, 'attention_mask': <tf.Tensor: shape=(1, 45), dtype=int32, numpy=
array([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1]])>}

In [110]:
summary_ids = loaded_model.generate(input["input_ids"], max_length=48, length_penalty=2.0, num_beams=4, early_stopping=True)
summary_ids

<tf.Tensor: shape=(1, 10), dtype=int32, numpy=
array([[    0,  7353,  6123,    19,     3, 21139,   149,  1623,  4368,
            1]])>

In [111]:
summary = loaded_tokenizer.decode(summary_ids[0], skip_special_tokens=True)
print(summary)

artificial intelligence is transforming how businesses operate


In [113]:
base_model = TFT5ForConditionalGeneration.from_pretrained("Yihim/flan_t5_small_ggw_v1")

config.json:   0%|          | 0.00/1.53k [00:00<?, ?B/s]

tf_model.h5:   0%|          | 0.00/440M [00:00<?, ?B/s]

All model checkpoint layers were used when initializing TFT5ForConditionalGeneration.

All the layers of TFT5ForConditionalGeneration were initialized from the model checkpoint at Yihim/flan_t5_small_ggw_v1.
If your task is similar to the task the model of the checkpoint was trained on, you can already use TFT5ForConditionalGeneration for predictions without further training.


generation_config.json:   0%|          | 0.00/142 [00:00<?, ?B/s]

In [114]:
summary_ids = base_model.generate(input["input_ids"], max_length=48, length_penalty=2.0, num_beams=4, early_stopping=True)
summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
summary

'artificial intelligence is transforming how businesses operate'