# Install libraries

In [None]:
!pip install datasets
!pip install transformers
!pip install 'huggingface_hub[tensorflow]'
!pip install tensorflow

# Import libraries

In [None]:
import os

import datasets
import tensorflow as tf
from datasets import load_dataset, load_dataset_builder
from huggingface_hub import notebook_login
from transformers import AdamWeightDecay, AutoConfig, AutoTokenizer, TFAutoModelForPreTraining
from transformers.keras_callbacks import PushToHubCallback

# Declare config parameters

In [None]:
DRIVE_MOUNT_PATH="/content/drive/"

FULL_DATASET_PATH = DRIVE_MOUNT_PATH + "MyDrive/colab/product_review_generator/datasets/full_dataset"
TRAIN_DATASET_PATH = DRIVE_MOUNT_PATH + "MyDrive/colab/product_review_generator/datasets/train_dataset"
VAL_DATASET_PATH= DRIVE_MOUNT_PATH + "MyDrive/colab/product_review_generator/datasets/val_dataset"
MODEL_SAVE_PATH=DRIVE_MOUNT_PATH + "MyDrive/colab/product_review_generator/model/fine_tuned"

DATASET_NAME="amazon_us_reviews"
SUBSET_NAME="Apparel_v1_00"
# Columns dropped before filtering; the remaining columns feed the training text
COLUMNS_TO_REMOVE = ['marketplace', 'customer_id', 'review_id', 'product_id', 'product_parent',
                     'product_category','helpful_votes', 'total_votes', 'vine','review_date']

# Maximum number of sampled records per star rating (1-5)
RECORDS_PER_LABEL = 50000

# Fraction of the balanced records held out for validation
VALIDATION_DATA_SPLIT = 0.2

MODEL_NAME="distilgpt2"

HF_MODEL_LOCAL_PATH="/content/model_local"
HF_MODEL_ID="praveenseb/product_review_generator"
# Never hard-code the Hub token in the notebook: read it from the environment.
# When HF_TOKEN is unset this is None and PushToHubCallback falls back to the
# token cached by notebook_login().
HF_TOKEN = os.environ.get("HF_TOKEN")

SPECIAL_TOKENS  = { "bos_token": "<|BOS|>",
                   "eos_token": "<|EOS|>",
                   "unk_token": "<|UNK|>",
                   "pad_token": "<|PAD|>",
                   "sep_token": "<|SEP|>"}

# Allowed review_body word counts; the upper bound is exclusive
REVIEW_LEN_MIN = 10
REVIEW_LEN_MAX = 100

# Every training sequence is padded/truncated to this many tokens
TOKEN_LEN_MAX  = 300

# Words that encode star_rating in both the training text and the generation prompt.
# No trailing whitespace, so every label is tokenized the same way before <|SEP|>.
RATING_DEF = {
    1: 'Terrible',
    2: 'Bad',
    3: 'Acceptable',
    4: 'Good',
    5: 'Excellent'
}

TRAIN_BATCH_SIZE=32

EPOCHS=2

# Mount Google Drive in the runtime's VM

In [None]:
from google.colab import drive

# Remount unconditionally so a stale mount from an earlier session is replaced
drive.mount(mountpoint=DRIVE_MOUNT_PATH, force_remount=True)

# Fetch dataset info

In [None]:
# Inspect the schema and split sizes before downloading the full data
amz_builder = load_dataset_builder(DATASET_NAME, SUBSET_NAME)
builder_info = amz_builder.info
print("Dataset features -", builder_info.features)
print("Dataset splits -", builder_info.splits)

# Load the dataset and save to disk

In [None]:
# Reuse the copy already saved on Drive when there is one; otherwise download the
# subset and save it there, so re-running the notebook does not re-download it.
if os.path.isdir(FULL_DATASET_PATH):
    amz_dataset = datasets.load_from_disk(FULL_DATASET_PATH)
else:
    amz_dataset = load_dataset(DATASET_NAME,  SUBSET_NAME)

    #Save to disk
    amz_dataset.save_to_disk(FULL_DATASET_PATH)

In [None]:
#load from disk if there is a saved version available
#(uncomment to reuse the copy written to FULL_DATASET_PATH instead of downloading again)
#amz_dataset=datasets.load_from_disk(FULL_DATASET_PATH)

In [None]:
print("Total number of records - ", len(amz_dataset["train"]))

# Pre-process the dataset

In [None]:
def create_train_val_data(input_dataset, seed=None):
  """Build a star-rating-balanced train/validation split from the raw reviews.

  Args:
    input_dataset: DatasetDict from load_dataset; only its "train" split is used.
    seed: optional int. When given, every shuffle and the final split are seeded,
      so the same records are sampled on every run. None (the default) keeps the
      unseeded behaviour.

  Returns:
    DatasetDict with "train" and "test" splits; "test" holds
    VALIDATION_DATA_SPLIT of the balanced records.
  """
  #Remove columns that are not required
  reviews = input_dataset["train"].remove_columns(COLUMNS_TO_REMOVE)

  #Keep verified purchases (verified_purchase == 1) whose review_body word count
  #lies in [REVIEW_LEN_MIN, REVIEW_LEN_MAX). No shuffle here: each per-rating
  #subset is shuffled below, and filtering an unshuffled dataset is faster.
  filtered_dataset = reviews.filter(
      lambda example: example["verified_purchase"] == 1
      and REVIEW_LEN_MIN <= len(example["review_body"].split()) < REVIEW_LEN_MAX)
  print("Record count after filtering on verified_purchase and review_body word count -",filtered_dataset.num_rows)

  #Pick up to RECORDS_PER_LABEL random records for each rating 1 to 5.
  #select() keeps the dataset features and avoids materialising Python lists.
  per_rating = []
  for rating in range(1, 6):
    rating_dataset = filtered_dataset.filter(
        lambda example: example["star_rating"] == rating).shuffle(seed=seed)
    take = min(RECORDS_PER_LABEL, rating_dataset.num_rows)
    if take < RECORDS_PER_LABEL:
      #Make under-represented ratings visible instead of silently unbalancing the data
      print(f"Warning: only {take} records available for star_rating {rating}")
    per_rating.append(rating_dataset.select(range(take)))

  processed_dataset = datasets.concatenate_datasets(per_rating)
  print("Number of records in the processed dataset -",processed_dataset.num_rows)
  return processed_dataset.train_test_split(shuffle=True, test_size=VALIDATION_DATA_SPLIT, seed=seed)

# Create Training and Validation datasets

In [None]:
amz_train_val_dataset = create_train_val_data(amz_dataset)

train_dataset = amz_train_val_dataset["train"]
val_dataset = amz_train_val_dataset["test"]  # train_test_split names the held-out part "test"

#Persist the final train and validation datasets to Drive
for split_dataset, split_path in ((train_dataset, TRAIN_DATASET_PATH),
                                  (val_dataset, VAL_DATASET_PATH)):
  split_dataset.save_to_disk(split_path)

In [None]:
#load from disk if there is a saved version available
#(uncomment to reuse the splits written to TRAIN_DATASET_PATH / VAL_DATASET_PATH)

#train_dataset=datasets.load_from_disk(TRAIN_DATASET_PATH)
#val_dataset=datasets.load_from_disk(VAL_DATASET_PATH)

In [None]:
for caption, split in (("Number of training records -", train_dataset),
                       ("Number of validation records -", val_dataset)):
    print(caption, len(split))

# Tokenizer and helper functions

In [None]:
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

# Register the BOS/EOS/UNK/PAD/SEP markers; the number of newly added tokens is the cell output
num_added_tokens = tokenizer.add_special_tokens(SPECIAL_TOKENS)
num_added_tokens

In [None]:
#Normalise the HTML residue found in review text (<br /> tags, &#34; and other entities)
def clean_string(text):
    """Return ``text`` with HTML line breaks and character entities normalised.

    * Any run of ``<br>``, ``<br/>`` or ``<br />`` tags, together with the
      whitespace around it, becomes a single space. Deleting the tags outright
      would glue neighbouring words together ("great.<br />Fits" -> "great.Fits").
    * HTML entities are decoded with ``html.unescape``, so ``&#34;`` becomes ``"``
      and ``&amp;``, ``&#39;`` etc. are handled too.

    ``None`` is treated as empty text and returns ``""``.
    """
    import html
    import re

    if text is None:
        return ""
    text = re.sub(r"(?:\s*<br\s*/?>\s*)+", " ", text, flags=re.IGNORECASE)
    return html.unescape(text)

In [None]:
def tokenize_function(example):
    """Format one review as a training sequence and tokenize it.

    Layout (the generation prompt is built with the same layout):
        BOS title SEP rating-word SEP headline SEP body EOS
    The formatted string is also stored in ``example['input_text']``.

    Returns the tokenizer output (``input_ids``, ``attention_mask``) padded or
    truncated to TOKEN_LEN_MAX, plus ``labels``. Labels equal ``input_ids``
    except at padding positions, which are set to -100. Transformers' causal-LM
    loss ignores -100, so the model is not trained to predict <|PAD|>.
    """
    sep = SPECIAL_TOKENS['sep_token']
    example['input_text'] = (
        SPECIAL_TOKENS['bos_token']
        + clean_string(example['product_title']) + sep
        + RATING_DEF[example['star_rating']] + sep
        + clean_string(example['review_headline']) + sep
        + clean_string(example['review_body'])
        + SPECIAL_TOKENS['eos_token']
    )

    tokens = tokenizer(example["input_text"], padding="max_length", truncation=True, max_length=TOKEN_LEN_MAX)
    # attention_mask is 0 exactly on the padding the tokenizer appended
    tokens["labels"] = [
        token_id if attended else -100
        for token_id, attended in zip(tokens["input_ids"], tokens["attention_mask"])
    ]
    return tokens

In [None]:
# Apply the same formatting + tokenization to both splits (train first, then validation)
train_token, val_token = [split.map(tokenize_function) for split in (train_dataset, val_dataset)]

# Define the model

In [None]:
# The model config must agree with the tokenizer on the ids of the added special tokens
special_token_ids = {
    "bos_token_id": tokenizer.bos_token_id,
    "eos_token_id": tokenizer.eos_token_id,
    "sep_token_id": tokenizer.sep_token_id,
    "pad_token_id": tokenizer.pad_token_id,
}
config = AutoConfig.from_pretrained(MODEL_NAME, output_hidden_states=False, **special_token_ids)

model = TFAutoModelForPreTraining.from_pretrained(MODEL_NAME, config=config)

# The tokenizer gained new special tokens, so the embeddings must cover len(tokenizer) ids
model.resize_token_embeddings(len(tokenizer))

In [None]:
# Step-wise decay: with staircase=True the learning rate is multiplied by 0.95
# once per 1000 optimizer steps, starting from 2e-4
lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
    initial_learning_rate=2e-4,
    decay_steps=1000,
    decay_rate=0.95,
    staircase=True,
)
optimizer = AdamWeightDecay(learning_rate=lr_schedule)

In [None]:
# No loss argument is passed; the loss comes from the "labels" produced during tokenization
model.compile(optimizer)
model.summary()

# Fine-tune the model

In [None]:
# Training batches are shuffled (drop_remainder then defaults to True, so the last
# partial batch is dropped each epoch).
train_tf_dataset=model.prepare_tf_dataset(
    train_token,shuffle=True,batch_size=TRAIN_BATCH_SIZE)

# Validation is not shuffled and keeps its final partial batch, so every validation
# record is scored, in a stable order. With shuffle=True, drop_remainder would
# default to True and silently skip a random subset.
val_tf_dataset=model.prepare_tf_dataset(
    val_token,shuffle=False,batch_size=TRAIN_BATCH_SIZE,drop_remainder=False)

In [None]:
#HF Login
#Interactive Hugging Face Hub login; the training cells below push checkpoints to HF_MODEL_ID
notebook_login()

In [None]:
# Push a checkpoint (model + tokenizer) to the Hub repository at the end of every epoch
hfhub_callback = PushToHubCallback(
    output_dir=HF_MODEL_LOCAL_PATH,
    hub_model_id=HF_MODEL_ID,
    hub_token=HF_TOKEN,
    tokenizer=tokenizer,
    save_strategy="epoch",
    checkpoint=True,
)

In [None]:
# Report the validation loss at the end of every epoch as well, so overfitting across
# epochs is visible. Keep the Keras History to inspect per-epoch loss/val_loss later.
history = model.fit(
    train_tf_dataset,
    validation_data=val_tf_dataset,
    epochs=EPOCHS,
    callbacks=[hfhub_callback],
)

In [None]:
# Save the tokenizer next to the weights. The embeddings were resized for the special
# tokens added to this tokenizer, and the base distilgpt2 tokenizer lacks them.
model.save_pretrained(MODEL_SAVE_PATH)
tokenizer.save_pretrained(MODEL_SAVE_PATH)
# Flush pending writes to Drive and detach it
drive.flush_and_unmount()

In [None]:
# Loss of the fine-tuned model on the held-out validation split
val_loss = model.evaluate(val_tf_dataset)
print(f"Validation loss is  {val_loss}")

# Generate review text with the fine-tuned model

In [None]:
title = "Columbia Women's Benton Springs Full-Zip Fleece Jacket"
rating = 5
review_title = "Awesome Jacket!"

# Same layout as the training text, up to (and including) the SEP before the review body:
# BOS title SEP rating-word SEP headline SEP
sep = SPECIAL_TOKENS['sep_token']
prompt = SPECIAL_TOKENS['bos_token'] + sep.join([title, RATING_DEF[rating], review_title]) + sep
print("The input prompt is -",prompt)

# Batch of one prompt: shape (1, prompt_length)
prompt_tokens = tf.constant([tokenizer.encode(prompt)])

In [None]:
#Generate 10 sample reviews
# Sampling settings; max_length counts the prompt tokens as well
generation_kwargs = dict(
    min_length=10,
    max_length=150,
    top_k=30,
    top_p=0.7,
    temperature=0.9,
    repetition_penalty=2.0,
    num_return_sequences=10,
    do_sample=True,
)
generated_text = model.generate(prompt_tokens, **generation_kwargs)

In [None]:
# Each returned sequence is the prompt followed by the generated continuation.
# Decode only the continuation (so the title/headline text cannot shift the split)
# and keep it up to the first SEP or EOS marker. The markers come from SPECIAL_TOKENS
# rather than hard-coded literals.
prompt_len = prompt_tokens.shape[1]
for i, sequence in enumerate(generated_text):
    continuation = tokenizer.decode(sequence[prompt_len:], skip_special_tokens=False)
    review = (continuation
              .split(SPECIAL_TOKENS['sep_token'])[0]
              .split(SPECIAL_TOKENS['eos_token'])[0]
              .strip())
    print("\n",i,review)