# About this Notebook

The goal of this notebook is to build a classifier to find toxic comments. The data has been taken from a series of Kaggle competitions to classify Wikipedia comments as toxic/nontoxic. The data has been sourced from Google and Jigsaw. 

The notebook will start with simple bag-of-words and tf-idf features and use simple models like logistic regression and Naive Bayes to perform classification with these features. Though the full dataset includes non-English comments, I will restrict myself to English-only comments for this iteration. 

We will then move on to deep learning approaches, using a combination of pretrained word embeddings and simple deep learning models like RNNs and 1D convolutions to do more benchmarking. 

Next, we will explore deep learning models that have 'memory' using LSTMs (Long Short Term Memory) and GRUs (Gated Recurrent Units). 

Finally, we will approach state-of-the-art performance using pretrained models like BERT and XLNet.

For metrics, I will focus on both ROC and precision-recall curves. In addition, I will look at the confusion matrix and performance across different flavors of toxicity.

Credits:
- https://www.kaggle.com/tanulsingh077/deep-learning-for-nlp-zero-to-transformers-bert
- https://www.kaggle.com/jagangupta/stop-the-s-toxic-comments-eda
- https://www.kaggle.com/clinma/eda-toxic-comment-classification-challenge
- https://www.kaggle.com/abhi111/naive-bayes-baseline-and-logistic-regression

I will do a tiered approach to feature engineering and building the model:

No Deep Learning:
1. Do cleanup of text for things like punctuation, numbers, weird symbols, etc. 
2. Use regex and string functions to essentially do tokenization.
3. Create non-semantic features related to capitalization, misspelling, punctuation, length, repetition, etc. 
4. Use NB and logistic models with regularization.
5. Have clear metrics that evaluate on different types of toxicity and pull out examples where the model does poorly.
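
As a rough illustration of step 3 above, here is a minimal sketch of non-semantic features computed with plain string functions. The helper name `surface_features` and the particular feature set are my own assumptions for illustration, not something already defined in this notebook.

```python
import re
import string


def surface_features(text):
    """Return simple, non-semantic features for one comment (hypothetical helper)."""
    letters = [c for c in text if c.isalpha()]
    words = text.split()
    return {
        'n_chars': len(text),
        'n_words': len(words),
        # Share of letters that are upper case ("SHOUTING")
        'upper_ratio': sum(c.isupper() for c in letters) / max(len(letters), 1),
        # Number of punctuation characters
        'n_punct': sum(c in string.punctuation for c in text),
        # Longest run of one repeated character, e.g. "!!!!" -> 4
        'max_char_run': max((len(m.group(0)) for m in re.finditer(r'(.)\1*', text)), default=0),
        # Share of distinct words; low values indicate repetition
        'unique_word_ratio': len({w.lower() for w in words}) / max(len(words), 1),
    }


feats = surface_features("STOP it!!!! stop stop")
print(feats)
```

Features like these could later be stacked next to the sparse bag-of-words matrix (for example with `scipy.sparse.hstack`, which is already imported below).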

Deep Learning:
1. Use standard tokenizers and compare with 'homegrown' version from above.
2. Use open source word embeddings for corpus as input to RNN models. Quantify how misspellings affect the standard tokenizers.
3. Find a way to feed additional features like punctuation/capitalization from the approach above into the deep learning RNN models.
4. Try progressively more complicated deep learning sequence models approaching SOTA.
5. Use metrics from above.

Potential Modules:
1. Correct misspellings
2. Analytics for preprocessing
3. Analytics for model performance (use multi-labels, make easy way to look at specific examples)
4. Automatically generate a lookup table for common variations of words (particularly toxic words, e.g., 'mothafucka' -> 'motherfucker')
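
One possible way to approach module 4 is fuzzy string matching from the standard library. The sketch below is only an assumption about how such a lookup could be built: the helper names, the mild example words, and the 0.7 cutoff are all made up for illustration, and a real version would need a curated list of canonical words.

```python
import difflib
import re


def squeeze_repeats(word):
    """Collapse runs of 3+ identical characters to one, e.g. 'stuuupid' -> 'stupid'."""
    return re.sub(r'(.)\1{2,}', r'\1', word)


def build_variant_lookup(observed_words, canonical_words, cutoff=0.7):
    """Map each observed word to its closest canonical spelling, if one is close enough."""
    lookup = {}
    for word in observed_words:
        candidate = squeeze_repeats(word.lower())
        match = difflib.get_close_matches(candidate, canonical_words, n=1, cutoff=cutoff)
        # Skip words that are already canonical or have no close match
        if match and match[0] != word:
            lookup[word] = match[0]
    return lookup


canonical = ['stupid', 'idiot', 'moron']                        # hypothetical target spellings
observed = ['stoopid', 'stuuupid', 'idiottt', 'table', 'moron']  # hypothetical corpus words
variant_lookup = build_variant_lookup(observed, canonical)
print(variant_lookup)
```

The cutoff trades false matches against missed variants, so it would need tuning on real comments.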




In [None]:
import numpy as np
import pandas as pd 
from collections import defaultdict as ddict, Counter
from itertools import compress
from tqdm import tqdm
from scipy.sparse import csr_matrix, hstack
from sklearn import metrics
from sklearn.feature_extraction.text import CountVectorizer, TfidfVectorizer 
from sklearn.model_selection import train_test_split
from sklearn.naive_bayes import MultinomialNB
from wordcloud import WordCloud

import matplotlib.pyplot as plt
import seaborn as sns
%matplotlib inline
from plotly import graph_objs as go
import plotly.express as px
import plotly.figure_factory as ff
import re

import random
import string
from nltk.corpus import stopwords
stop = stopwords.words('english')
from nltk.stem.wordnet import WordNetLemmatizer 
lemmatizer = WordNetLemmatizer()
from nltk.tokenize import word_tokenize
# Tweet tokenizer does not split at apostrophes which is what we want
from nltk.tokenize import TweetTokenizer   
pd.options.display.max_rows = 999

## Load data

In [None]:
pre_path = '/kaggle/input/'

In [None]:
train = pd.read_csv(pre_path + 'jigsaw-multilingual-toxic-comment-classification/jigsaw-toxic-comment-train.csv')
#The following is a non-English dataset and won't be used presently
validation = pd.read_csv(pre_path + 'jigsaw-multilingual-toxic-comment-classification/validation.csv')
#The following is a non-English dataset and won't be used presently
test = pd.read_csv(pre_path + 'jigsaw-multilingual-toxic-comment-classification/test.csv')

## EDA

In [None]:
print(train.info())
train.head()
print(train.describe())

In [None]:
CATEGORIES = list(train.columns[2:8])
df_comb = train.groupby(CATEGORIES)\
                    .size()\
                    .sort_values(ascending=False)\
                    .reset_index()\
                    .rename(columns={0: 'count'})

df_comb['label'] ='nontoxic'

for i in range(len(df_comb)):
    label_index = df_comb.iloc[i,0:6].values.astype(bool)
    label = ', '.join(list(compress(CATEGORIES, label_index)))
    if label:
        df_comb.loc[i, 'label'] = label

df_comb.head(n=20)

In [None]:
df_comb[(df_comb['count']>20) & (df_comb['count']<100000)].plot.bar(x='label', y='count', figsize=(17,8))
plt.yscale('log')
plt.xticks(size=15)

#Think of a more compelling multilabel visualization (decision tree/dendrogram with labels?)

In [None]:
sns.heatmap(train.iloc[:,2:8].corr(), annot=True)

By looking at the labels, we can see that roughly 90% of the 200K+ comments are nontoxic. The remaining unsavory comments have a combination of labels including toxic, severe_toxic, obscene, insult, identity_hate, and threat. The bulk of these comments are plain toxic; the next most common are comments that are toxic in combination with obscene and/or insult. Interestingly, about 4% of the 22.5K unsavory comments do not have a toxic label; they are a combination of obscene and insult. Overall, toxic is the label to predict, though it will be interesting to see how different types of models do with different flavors of toxicity.
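
The label-combination counts discussed above can be reproduced on a toy example with the standard library. The five rows below are made up for illustration, and the column order is assumed to match the six label columns of the training file.

```python
from collections import Counter

# Assumed order of the six label columns
CATEGORY_NAMES = ['toxic', 'severe_toxic', 'obscene', 'threat', 'insult', 'identity_hate']

# Toy label rows (made up), one 0/1 flag per category
rows = [
    (0, 0, 0, 0, 0, 0),
    (0, 0, 0, 0, 0, 0),
    (1, 0, 0, 0, 0, 0),
    (1, 0, 1, 0, 1, 0),
    (0, 0, 1, 0, 1, 0),
]


def combo_label(row):
    """Join the names of all active categories, or 'nontoxic' if none are set."""
    active = [name for name, flag in zip(CATEGORY_NAMES, row) if flag]
    return ', '.join(active) if active else 'nontoxic'


combo_counts = Counter(combo_label(r) for r in rows)

# Share of flagged comments that carry no 'toxic' label
flagged = [r for r in rows if any(r)]
flagged_without_toxic = [r for r in flagged if r[0] == 0]
share_without_toxic = len(flagged_without_toxic) / len(flagged)

print(combo_counts.most_common())
print(share_without_toxic)
```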

## Preprocessing steps

In [None]:
#s="string. With. Punctuation?" 
#s.translate(str.maketrans('', '', string.punctuation))
#set(string.punctuation)

lemmatizer.lemmatize('rocks')

In [None]:
#First let's try simple regex and string-based options. 
#Later I can use more nltk and/or word embedding based models
#Additional cleaning steps should include lemmatization and trying to correct for misspellings (n-gram approach?)


def remove_punctuation(text, exclude=["'"]):
    #Remove punctuation but keep the characters in `exclude` (the apostrophe by default)
    if exclude:
        punctuation_to_remove = ''.join(set(string.punctuation) - set(exclude))
    else:
        punctuation_to_remove = string.punctuation
    text = text.translate(str.maketrans('', '', punctuation_to_remove))
    return text

def remove_numbers(text):
    #Remove all digits
    text = text.translate(str.maketrans('', '', string.digits))
    return text
                          
def keep_alpha_char(text):
    pass

def remove_stop_words(text):
    return ' '.join([word.strip() for word in text.split() if word not in stop])

def tokenize(text):
    return text.lower().split()

def clean_text(text):
    return remove_stop_words(remove_punctuation(text))

In [None]:
def lemmatize(text_list, lemmatizer=None):
    if lemmatizer:
        return [lemmatizer.lemmatize(word) for word in text_list]
    else:
        return text_list

In [None]:
word_counter = {}

for categ in CATEGORIES:
    d = Counter()
    train[train[categ] == 1]['comment_text'].apply(lambda t: d.update(lemmatize(clean_text(t).split())))
    word_counter[categ] = pd.DataFrame.from_dict(d, orient='index')\
                                        .rename(columns={0: 'count'})\
                                        .sort_values('count', ascending=False)

In [None]:
def angry_color_func(word, font_size, position, orientation, random_state=None,
                    **kwargs):
    return "hsl(%d, 100%%, 50%%)" % ((random.randint(-40, 40)+360)%360)

for w in word_counter:
    wc = word_counter[w]

    wordcloud = WordCloud(
          background_color='black',
          max_words=200,
          max_font_size=100, 
          random_state=461
         ).generate_from_frequencies(wc.to_dict()['count'])

    fig = plt.figure(figsize=(8, 8))
    plt.title(w.upper().replace('_', ' '), size=40)
    plt.imshow(wordcloud.recolor(color_func=angry_color_func, random_state=3),
           interpolation="bilinear")
    plt.axis('off')

    plt.show()

We can see a lot of disturbing words in the word clouds for each category. Identity Hate has more specific attacks against race, religion, sexual orientation, and gender. Threat has more hate-related verbs and seems to be a bit different in its words from all the other categories. The most represented categories in the dataset are toxic/obscene/insult, and these categories seem to share the same highly represented words. We will now see if these common words translate into highly predictive features. 

In [None]:
#Quantify how much misspellings and long tail might be affecting results
word_counter['threat'].hist(bins=60)
plt.yscale('log')
print(sum(word_counter['threat']['count']==1))
print(len(word_counter['threat']))
threat = word_counter['threat']
threat.tail(n=999)



We will drop the other columns and approach this as a binary classification problem. We will also run the exercise on a smaller subsection of the dataset (only 12000 data points) to make the models easier to train.

## Sample data

In [None]:
train

In [None]:
#train.drop(['severe_toxic','obscene','threat','insult','identity_hate'],axis=1,inplace=True)

train_full = train.copy()
#train = train.loc[:10000,:]
train.comment_text[train.toxic==1][1:2].values
train.toxic.value_counts()

## Create train and test sets

In [None]:
xtrain, xvalid, ytrain, yvalid = train_test_split(train.comment_text.values, train.toxic.values, 
                                                  stratify=train.toxic.values, 
                                                  random_state=42, 
                                                  test_size=0.2, shuffle=True)


In [None]:
xtrain

## Create Features

In [None]:
count_vectorizer = CountVectorizer(stop_words='english')
count_train = count_vectorizer.fit_transform(xtrain)
count_valid = count_vectorizer.transform(xvalid)


In [None]:
type(count_valid)
count_vectorizer.get_feature_names()[200000:]

## Define Metrics

In [None]:
def run_metrics(predictions, predictions_prob, target, visualize=True):
    #Use the function arguments throughout rather than the globals yvalid/pred
    fpr, tpr, _ = metrics.roc_curve(target, predictions_prob)
    roc_auc = metrics.auc(fpr, tpr)
    precision, recall, _ = metrics.precision_recall_curve(target, predictions_prob)
    #Average precision summarizes the precision-recall curve, so it takes scores, not hard labels
    average_precision = metrics.average_precision_score(target, predictions_prob)
    #average_recall = metrics.recall_score(target, predictions)
    print('Average precision-recall score: {0:0.2f}'.format(
      average_precision))
    accuracy = metrics.accuracy_score(target, predictions)
    print(metrics.confusion_matrix(target, predictions, labels=[0,1]))
    print("Accuracy Score: {0:0.2f}".format(accuracy))
    if visualize:
        plt.figure()
        plt.plot(fpr, tpr)
        plt.title('ROC curve, AUC: {0:0.2f}'.format(roc_auc))
        plt.xlabel('False Positive Rate')
        plt.ylabel('True Positive Rate')
        
        plt.show()
        
        plt.figure()
        plt.plot(recall, precision)
        plt.title('Precision-Recall curve')
        plt.xlabel('Recall')
        plt.ylabel('Precision')
        plt.show()
        #disp = metrics.plot_precision_recall_curve(nb_classifier, count_valid, yvalid)
        #disp.ax_.set_title('2-class Precision-Recall curve: '
                   #'AP={0:0.2f}'.format(average_precision))
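
To make the quantities in `run_metrics` more concrete, here is a small pure-Python sketch of what ROC AUC and a single precision/recall point measure. It is not how scikit-learn computes them internally; the pairwise loop is quadratic and only meant for tiny toy inputs, and the function names are my own.

```python
def roc_auc_pairwise(labels, scores):
    """AUC as the share of (positive, negative) pairs ranked correctly; ties count half."""
    pos = [s for y, s in zip(labels, scores) if y == 1]
    neg = [s for y, s in zip(labels, scores) if y == 0]
    if not pos or not neg:
        raise ValueError('need at least one positive and one negative example')
    wins = 0.0
    for p in pos:
        for n in neg:
            if p > n:
                wins += 1.0
            elif p == n:
                wins += 0.5
    return wins / (len(pos) * len(neg))


def precision_recall_at(labels, scores, threshold=0.5):
    """Precision and recall when scores >= threshold are predicted positive."""
    predicted = [1 if s >= threshold else 0 for s in scores]
    tp = sum(1 for y, p in zip(labels, predicted) if y == 1 and p == 1)
    fp = sum(1 for y, p in zip(labels, predicted) if y == 0 and p == 1)
    fn = sum(1 for y, p in zip(labels, predicted) if y == 1 and p == 0)
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    return precision, recall


toy_labels = [0, 0, 1, 1]
toy_scores = [0.1, 0.4, 0.35, 0.8]
auc = roc_auc_pairwise(toy_labels, toy_scores)
precision, recall = precision_recall_at(toy_labels, toy_scores, threshold=0.5)
print(auc, precision, recall)  # 0.75 1.0 0.5
```

Sweeping the threshold over all observed scores traces out the precision-recall curve one point at a time.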

## Train Naive Bayes Model

In [None]:
nb_classifier = MultinomialNB()

nb_classifier.fit(count_train, ytrain)
pred = nb_classifier.predict(count_valid)
pred_proba = nb_classifier.predict_proba(count_valid)[:,1]
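
For intuition about what the classifier above estimates, here is a minimal multinomial Naive Bayes with add-one (Laplace) smoothing in pure Python. It is a sketch under simplifying assumptions (pre-tokenized toy documents, made-up data, my own function names), not a reimplementation of `MultinomialNB`.

```python
import math
from collections import Counter


def train_multinomial_nb(docs, labels, alpha=1.0):
    """Fit log priors and smoothed log likelihoods from tokenized documents."""
    classes = sorted(set(labels))
    vocab = sorted({tok for doc in docs for tok in doc})
    log_prior, log_likelihood = {}, {}
    for c in classes:
        class_docs = [doc for doc, y in zip(docs, labels) if y == c]
        log_prior[c] = math.log(len(class_docs) / len(docs))
        counts = Counter(tok for doc in class_docs for tok in doc)
        total = sum(counts.values()) + alpha * len(vocab)
        log_likelihood[c] = {tok: math.log((counts[tok] + alpha) / total) for tok in vocab}
    return log_prior, log_likelihood


def predict_nb(doc, log_prior, log_likelihood):
    """Return the class with the highest log posterior; unseen tokens are skipped."""
    scores = {}
    for c in log_prior:
        scores[c] = log_prior[c] + sum(
            log_likelihood[c][tok] for tok in doc if tok in log_likelihood[c])
    return max(scores, key=scores.get)


# Toy tokenized comments (made up); 1 = toxic, 0 = nontoxic
toy_docs = [
    ['you', 'are', 'great'],
    ['thanks', 'for', 'the', 'edit'],
    ['you', 'are', 'an', 'idiot'],
    ['idiot', 'idiot', 'go', 'away'],
]
toy_labels = [0, 0, 1, 1]

log_prior, log_likelihood = train_multinomial_nb(toy_docs, toy_labels)
print(predict_nb(['idiot'], log_prior, log_likelihood))
print(predict_nb(['thanks', 'great'], log_prior, log_likelihood))
```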


In [None]:
run_metrics(pred, pred_proba, yvalid, visualize=True)

## Use Tfidf for the features

In [None]:
tfidf_vectorizer = TfidfVectorizer(stop_words='english', max_df=0.7)
count_train_idf = tfidf_vectorizer.fit_transform(xtrain)
count_valid_idf = tfidf_vectorizer.transform(xvalid)
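
As a sanity check on what tf-idf weighting does, here is a small pure-Python sketch. It uses the smoothed idf, `ln((1 + n) / (1 + df)) + 1`, followed by L2 row normalization, which is my understanding of scikit-learn's defaults. It does not apply stop-word removal or the `max_df` filter used above, and the toy documents are made up.

```python
import math
from collections import Counter


def tfidf(docs):
    """Smoothed tf-idf with L2-normalized rows; docs are lists of tokens."""
    n_docs = len(docs)
    # Document frequency: in how many documents each token appears
    df = Counter(tok for doc in docs for tok in set(doc))
    idf = {tok: math.log((1 + n_docs) / (1 + d)) + 1 for tok, d in df.items()}
    rows = []
    for doc in docs:
        tf = Counter(doc)
        weights = {tok: count * idf[tok] for tok, count in tf.items()}
        norm = math.sqrt(sum(w * w for w in weights.values()))
        rows.append({tok: w / norm for tok, w in weights.items()} if norm else {})
    return idf, rows


toy_docs = [
    ['you', 'are', 'rude'],
    ['you', 'are', 'kind'],
    ['you', 'rude', 'rude'],
]
idf, rows = tfidf(toy_docs)
print(idf)
print(rows[2])
```

A token that appears in every document ("you") gets the minimum idf of 1, while rarer tokens are weighted up.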

In [None]:
tfidf_vectorizer.get_feature_names()

In [None]:
nb_classifier.fit(count_train_idf, ytrain)
pred = nb_classifier.predict(count_valid_idf)
pred_proba = nb_classifier.predict_proba(count_valid_idf)[:,1]


In [None]:
run_metrics(pred, pred_proba, yvalid, visualize=True)

## Use Deep Learning

In [None]:
import tensorflow as tf
from keras.models import Sequential
from keras.layers.recurrent import LSTM, GRU,SimpleRNN
from keras.layers.core import Dense, Activation, Dropout
from keras.layers.embeddings import Embedding
from keras.layers.normalization import BatchNormalization
from keras.utils import np_utils
from sklearn import preprocessing, decomposition, model_selection, metrics, pipeline
from keras.layers import GlobalMaxPooling1D, Conv1D, MaxPooling1D, Flatten, Bidirectional, SpatialDropout1D
from keras.preprocessing import sequence, text
from keras.callbacks import EarlyStopping, History, ModelCheckpoint, TensorBoard
from tensorflow.keras.optimizers import Adam



## Preprocess data

### Check the maximum number of words in a comment; this will help us with padding later

In [None]:
max_len = int(round(train['comment_text'].apply(lambda x:len(str(x).split())).max(), -2)+100)
print("Max length of comment text is: {}".format(max_len))

### First do Tokenization of input corpus

In [None]:
# using keras tokenizer here
token = text.Tokenizer(num_words=None)
token_toxic = text.Tokenizer(num_words=None)
token_nontoxic = text.Tokenizer(num_words=None)

token.fit_on_texts(list(xtrain) + list(xvalid))
token_toxic.fit_on_texts(train.comment_text.values[train.toxic==1])
token_nontoxic.fit_on_texts(train.comment_text.values[train.toxic==0])

xtrain_seq = token.texts_to_sequences(xtrain)
xvalid_seq = token.texts_to_sequences(xvalid)

#zero pad the sequences
xtrain_pad = sequence.pad_sequences(xtrain_seq, maxlen=max_len)
xvalid_pad = sequence.pad_sequences(xvalid_seq, maxlen=max_len)

word_index = token.word_index

In [None]:
word_toxic = token_toxic.word_index
word_nontoxic = token_nontoxic.word_index

In [None]:
print(len(word_toxic))
print(len(word_nontoxic))
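
For reference, the following pure-Python sketch mimics what the tokenization and zero-padding steps above do, as I understand the Keras defaults: word ids are assigned by descending frequency starting at 1, id 0 is reserved for padding, and both padding and truncation happen at the front of the sequence. Unlike the Keras tokenizer it does not strip punctuation, and the helper names are my own.

```python
from collections import Counter


def fit_word_index(texts):
    """Assign integer ids by descending frequency, starting at 1 (0 is kept for padding)."""
    counts = Counter(word for t in texts for word in t.lower().split())
    # sorted() is stable, so ties keep first-seen order
    ranked = sorted(counts.items(), key=lambda kv: -kv[1])
    return {word: i + 1 for i, (word, _) in enumerate(ranked)}


def to_sequences(texts, word_index):
    """Replace each known word with its id; unknown words are dropped."""
    return [[word_index[w] for w in t.lower().split() if w in word_index] for t in texts]


def pad_pre(seqs, maxlen):
    """Left-pad with zeros and keep only the last maxlen ids of longer sequences."""
    return [[0] * (maxlen - len(s[-maxlen:])) + s[-maxlen:] for s in seqs]


toy_texts = ['you are rude', 'you are very very kind']
toy_word_index = fit_word_index(toy_texts)
toy_seqs = to_sequences(toy_texts, toy_word_index)
toy_padded = pad_pre(toy_seqs, maxlen=4)
print(toy_word_index)
print(toy_padded)
```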

Example of fitting the tokenizer line by line if the corpus is too big to fit into memory:

    with open('/Users/liling.tan/test.txt') as fin:
        for line in fin:
            t.fit_on_texts(line.split()) # Fitting the tokenizer line-by-line.

    M = []

    with open('/Users/liling.tan/test.txt') as fin:
        for line in fin:
            # Converting the lines into matrix, line-by-line.
            m = t.texts_to_matrix([line], mode='count')[0]
            M.append(m)

## Use pretrained word embeddings

## Convert our integer word index into semantically rich GloVe vectors

In [None]:
# load the GloVe vectors in a dictionary:

embeddings_index = {}
f = open(pre_path + 'glove840b300dtxt/glove.840B.300d.txt','r',encoding='utf-8')
for line in tqdm(f):
    values = line.split(' ')
    word = values[0]
    coefs = np.asarray([float(val) for val in values[1:]])
    embeddings_index[word] = coefs
f.close()

print('Found %s word vectors.' % len(embeddings_index))

In [None]:


words_not_in_corpus = ddict(int)
words_in_corpus = ddict(int)
# create an embedding matrix for the words we have in the dataset
embedding_matrix = np.zeros((len(word_index) + 1, 300))
for word, i in tqdm(word_index.items()):  #ids must come from the tokenizer that built xtrain_pad
    embedding_vector = embeddings_index.get(word)
    if embedding_vector is not None:
        embedding_matrix[i] = embedding_vector
        words_in_corpus[word]+=1
    else:
        words_not_in_corpus[word]+=1

In [None]:
print(len(words_not_in_corpus))
print(len(words_in_corpus))
max(words_not_in_corpus.values())
max(words_in_corpus.values())

#For the full dataset, more than half the 'words' are not found in the glove embeddings
#For the 10K sample dataset, only ~25% of the words are not found in the glove embeddings
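
The coverage check above can be illustrated on a toy vocabulary. This sketch uses plain lists instead of NumPy, made-up 2-dimensional vectors, and a hypothetical helper name; it only shows the bookkeeping, where row `i` of the matrix belongs to the word with id `i` and out-of-vocabulary words keep an all-zero row.

```python
def build_embedding_matrix(word_index, embeddings, dim):
    """Row i holds the vector for the word with id i; unknown words stay all-zero."""
    matrix = [[0.0] * dim for _ in range(len(word_index) + 1)]  # row 0 is the padding id
    missing = []
    for word, i in word_index.items():
        vector = embeddings.get(word)
        if vector is None:
            missing.append(word)
        else:
            matrix[i] = list(vector)
    return matrix, missing


toy_embeddings = {'you': [0.1, 0.2], 'are': [0.3, 0.4]}   # made-up vectors
toy_word_index = {'you': 1, 'are': 2, 'stoopid': 3}        # made-up tokenizer output
matrix, missing = build_embedding_matrix(toy_word_index, toy_embeddings, dim=2)
coverage = 1 - len(missing) / len(toy_word_index)

print(missing)
print(round(coverage, 3))
```

Inspecting the most frequent entries in `missing` is a quick way to judge whether misspellings or tokenization are responsible for the gaps.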


In [None]:
print(len(words_not_in_corpus))
print(len(words_in_corpus))
max(words_not_in_corpus.values())
max(words_in_corpus.values())

#For the full dataset, more than half the 'words' are not found in the glove embeddings
#For the 10K sample dataset, only ~25% of the words are not found in the glove embeddings


In [None]:
#Save embeddings so they can be easily loaded
np.save('/kaggle/working/glove_embedding_for_full_data', embedding_matrix)

In [None]:
#Load embeddings
embedding_matrix = np.load('/kaggle/working/glove_embedding_for_10K_sample.npy')

In [None]:
embedding_matrix.shape

## Simple RNN Model

In [None]:
opt = Adam(learning_rate=0.0001)

In [None]:
model1 = Sequential()
model1.add(Embedding(len(word_index) + 1,
                 300,
                 input_length=max_len))
model1.add(SimpleRNN(100))
model1.add(Dense(1, activation='sigmoid'))  #sigmoid keeps the output in (0, 1) for binary_crossentropy
model1.compile(loss='binary_crossentropy', optimizer=opt, metrics=['accuracy'])
    
model1.summary()

In [None]:
from keras.callbacks import ModelCheckpoint,TensorBoard, EarlyStopping
EPOCHS = 10
checkpoint_filepath = '/kaggle/working/'
model_checkpoint_callback = ModelCheckpoint(
    filepath=checkpoint_filepath,
    save_weights_only=True,
    monitor='val_accuracy',  #name logged by TF2 Keras for metrics=['accuracy']; older Keras used 'val_acc'
    mode='max',
    save_best_only=True)


my_callbacks = [
    model_checkpoint_callback,
    TensorBoard(log_dir='/kaggle/working/logs'),
    EarlyStopping(monitor='val_loss', patience=3)
]
model_checkpoint_callback

In [None]:
model1.fit(xtrain_pad, 
           ytrain, 
           epochs=50, 
           batch_size=100, 
           callbacks=my_callbacks,
           validation_split=0.2,)

In [None]:
scores = model1.predict(xvalid_pad)[:, 0]
preds = scores>.5
run_metrics(preds, scores, yvalid)

## Simple LSTM Model

In [None]:
%%time
# A simple LSTM with glove embeddings and one dense layer
model = Sequential()
model.add(Embedding(len(word_index) + 1,
                 300,
                 weights=[embedding_matrix],
                 input_length=max_len,
                 trainable=False))

model.add(LSTM(100, dropout=0.3, recurrent_dropout=0.3))
model.add(Dense(1, activation='sigmoid'))
model.compile(loss='binary_crossentropy', optimizer='adam',metrics=['accuracy'])
    
model.summary()

In [None]:
model.fit(xtrain_pad, 
          ytrain, 
          epochs=50, 
          batch_size=100,
          callbacks=my_callbacks,
          validation_split=0.2,)

In [None]:
scores = model.predict(xvalid_pad)[:, 0]
preds = scores>.5
run_metrics(preds, scores, yvalid)

# Summary

So far, with very little preprocessing, we have achieved high accuracy. This is somewhat misleading, however, because the training set is highly imbalanced (roughly 10% positive/toxic class). 
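
To see why accuracy alone is misleading at this class balance, consider a toy "model" that always predicts nontoxic. The numbers below are made up to match the roughly 10% positive rate.

```python
def confusion_counts(y_true, y_pred):
    """Return (tp, fp, fn, tn) for binary 0/1 labels."""
    tp = sum(1 for t, p in zip(y_true, y_pred) if t == 1 and p == 1)
    fp = sum(1 for t, p in zip(y_true, y_pred) if t == 0 and p == 1)
    fn = sum(1 for t, p in zip(y_true, y_pred) if t == 1 and p == 0)
    tn = sum(1 for t, p in zip(y_true, y_pred) if t == 0 and p == 0)
    return tp, fp, fn, tn


y_true = [1] * 10 + [0] * 90   # 10% toxic, 90% nontoxic
y_pred = [0] * 100             # always predict nontoxic

tp, fp, fn, tn = confusion_counts(y_true, y_pred)
accuracy = (tp + tn) / len(y_true)
recall = tp / (tp + fn)

print(accuracy)  # 0.9
print(recall)    # 0.0
```

The baseline reaches 90% accuracy without finding a single toxic comment, which is why the ROC and precision-recall curves are the more informative comparison.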

Slightly older techniques, bag-of-words and tf-idf, have done better than simple deep learning models out of the box. This can be seen in the higher AUCs and accuracy of these models compared to the simple RNN model. In addition, training these models was extremely fast, even on a local machine. In contrast, the deep learning models required more than 10 minutes to train even five epochs. Training the simple RNN also required playing around with the learning rate to get the network to learn. The first few attempts produced labels of all zeros. 

The simple LSTM model starts to improve dramatically over the simple RNN model even with only 5 epochs, showing that using the semantically rich word embeddings and including memory already improves simple deep learning results. Though the overall accuracy has decreased in the LSTM model vs. the Naive Bayes models, the AUC and the precision-recall and ROC curves are much better than for the simple models. As we approach more state-of-the-art (SOTA) models and move beyond simple proof-of-concept model training (i.e., try different network parameters, experiment with data preprocessing, do hyperparameter optimization, train until the results start to degrade, add regularization, etc.), the results will likely improve even more dramatically.


## Try a GRU Model

In [None]:
%%time
# GRU with GloVe embeddings, spatial dropout and one dense output layer
model = Sequential()
model.add(Embedding(len(word_index) + 1,
                 300,
                 weights=[embedding_matrix],
                 input_length=max_len,
                 trainable=False))
model.add(SpatialDropout1D(0.3))
model.add(GRU(300))
model.add(Dense(1, activation='sigmoid'))

model.compile(loss='binary_crossentropy', optimizer='adam',metrics=['accuracy'])   
    
model.summary()

model.fit(xtrain_pad, ytrain, epochs=5, batch_size=64)  #nb_epoch is the old Keras 1 name for epochs

scores = model.predict(xvalid_pad)


## Bidirectional RNN Model

%%time
# A simple bidirectional LSTM with glove embeddings and one dense layer
model = Sequential()
model.add(Embedding(len(word_index) + 1,
                 300,
                 weights=[embedding_matrix],
                 input_length=max_len,
                 trainable=False))
model.add(Bidirectional(LSTM(300, dropout=0.3, recurrent_dropout=0.3)))

model.add(Dense(1,activation='sigmoid'))
model.compile(loss='binary_crossentropy', optimizer='adam',metrics=['accuracy'])
    
    
model.summary()

model.fit(xtrain_pad, ytrain, epochs=5, batch_size=64)  #nb_epoch is the old Keras 1 name for epochs

scores = model.predict(xvalid_pad)


## Seq2seq Architecture

In [None]:
#TBD


## Transformers/Attention/BERT

# Loading Dependencies
import os
import tensorflow as tf
from tensorflow.keras.layers import Dense, Input
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.models import Model
from tensorflow.keras.callbacks import ModelCheckpoint
from kaggle_datasets import KaggleDatasets
import transformers

from tokenizers import BertWordPieceTokenizer

Encoder for the data. To understand what `encode_batch` does, read the Hugging Face tokenizer documentation:
https://huggingface.co/transformers/main_classes/tokenizer.html

def fast_encode(texts, tokenizer, chunk_size=256, maxlen=512):
    """
    Encoder for encoding the text into sequence of integers for BERT Input
    """
    tokenizer.enable_truncation(max_length=maxlen)
    tokenizer.enable_padding(max_length=maxlen)
    all_ids = []
    
    for i in tqdm(range(0, len(texts), chunk_size)):
        text_chunk = texts[i:i+chunk_size].tolist()
        encs = tokenizer.encode_batch(text_chunk)
        all_ids.extend([enc.ids for enc in encs])
    
    return np.array(all_ids)
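
The chunking, truncation and padding logic of `fast_encode` can be sketched without the tokenizer library. The toy whitespace "tokenizer", the vocabulary, and the pad/unknown ids below are assumptions for illustration only; the real WordPiece tokenizer splits text into subwords and adds special tokens, which this sketch does not do.

```python
def encode_in_chunks(texts, vocab, chunk_size=2, maxlen=4, pad_id=0, unk_id=1):
    """Encode texts chunk by chunk, truncating and right-padding every row to maxlen."""
    all_ids = []
    for start in range(0, len(texts), chunk_size):
        chunk = texts[start:start + chunk_size]
        for t in chunk:
            # Map words to ids, fall back to unk_id, then truncate
            ids = [vocab.get(w, unk_id) for w in t.split()][:maxlen]
            # Pad on the right so every row has the same length
            all_ids.append(ids + [pad_id] * (maxlen - len(ids)))
    return all_ids


toy_vocab = {'you': 2, 'are': 3, 'kind': 4}   # made-up vocabulary
toy_texts = ['you are kind', 'you are very very kind indeed', 'hello']
encoded = encode_in_chunks(toy_texts, toy_vocab)
print(encoded)
```

Processing in chunks keeps memory bounded per batch, and the fixed row length is what allows the result to be stacked into a single array.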

# Important configuration values

AUTO = tf.data.experimental.AUTOTUNE


# Configuration
EPOCHS = 3
BATCH_SIZE = 16 
MAX_LEN = 192

## Tokenization

For details, please refer to the Hugging Face documentation again.

# First load the real tokenizer
tokenizer = transformers.DistilBertTokenizer.from_pretrained('distilbert-base-multilingual-cased')
# Save the loaded tokenizer locally
tokenizer.save_pretrained('.')
# Reload it with the huggingface tokenizers library
fast_tokenizer = BertWordPieceTokenizer('vocab.txt', lowercase=False)
fast_tokenizer

# `train` and `validation` are the DataFrames loaded in the "Load data" section
x_train = fast_encode(train.comment_text.astype(str), fast_tokenizer, maxlen=MAX_LEN)
x_valid = fast_encode(validation.comment_text.astype(str), fast_tokenizer, maxlen=MAX_LEN)
x_test = fast_encode(test.content.astype(str), fast_tokenizer, maxlen=MAX_LEN)

y_train = train.toxic.values
y_valid = validation.toxic.values

train_dataset = (
    tf.data.Dataset
    .from_tensor_slices((x_train, y_train))
    .repeat()
    .shuffle(2048)
    .batch(BATCH_SIZE)
    .prefetch(AUTO)
)

valid_dataset = (
    tf.data.Dataset
    .from_tensor_slices((x_valid, y_valid))
    .batch(BATCH_SIZE)
    .cache()
    .prefetch(AUTO)
)

test_dataset = (
    tf.data.Dataset
    .from_tensor_slices(x_test)
    .batch(BATCH_SIZE)
)

def build_model(transformer, max_len=512):
    """
    Build and compile a binary classifier on top of the pretrained transformer
    """
    input_word_ids = Input(shape=(max_len,), dtype=tf.int32, name="input_word_ids")
    sequence_output = transformer(input_word_ids)[0]
    cls_token = sequence_output[:, 0, :]
    out = Dense(1, activation='sigmoid')(cls_token)
    
    model = Model(inputs=input_word_ids, outputs=out)
    model.compile(Adam(learning_rate=1e-5), loss='binary_crossentropy', metrics=['accuracy'])
    
    return model

## Starting Training

If you want to use another model, just replace the model name in transformers._____ and use it accordingly.

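%%time
# `strategy` is not defined earlier in this notebook; falling back to TensorFlow's
# default strategy is an assumption (swap in a TPU/GPU strategy if one is available)
strategy = tf.distribute.get_strategy()
with strategy.scope():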
    transformer_layer = (
        transformers.TFDistilBertModel
        .from_pretrained('distilbert-base-multilingual-cased')
    )
    model = build_model(transformer_layer, max_len=MAX_LEN)
model.summary()

In [None]:
n_steps = x_train.shape[0] // BATCH_SIZE
train_history = model.fit(
    train_dataset,
    steps_per_epoch=n_steps,
    validation_data=valid_dataset,
    epochs=EPOCHS
)

n_steps = x_valid.shape[0] // BATCH_SIZE
train_history_2 = model.fit(
    valid_dataset.repeat(),
    steps_per_epoch=n_steps,
    epochs=EPOCHS*2
)