In [7]:
import numpy as np 
import pandas as pd 

import os
import string
from string import digits
import matplotlib.pyplot as plt
%matplotlib inline
import re

import seaborn as sns
from sklearn.utils import shuffle
from sklearn.model_selection import train_test_split
from keras.layers import Input, LSTM, Embedding, Dense
from keras.models import Model

print(os.listdir("../input"))

pd.set_option('display.max_rows', 500)
pd.set_option('display.max_columns', 500)
pd.set_option('display.width', 1000)
pd.set_option('display.max_colwidth', -1)


['Hindi_English_Truncated_Corpus.csv']


In [8]:
lines=pd.read_csv("../input/Hindi_English_Truncated_Corpus.csv",encoding='utf-8')

In [13]:
lines['source'].value_counts()

tides        50000
ted          39881
indic2012    37726
Name: source, dtype: int64

In [14]:
lines=lines[lines['source']=='ted']

In [15]:
lines.head(10)

Unnamed: 0,source,english_sentence,hindi_sentence
0,ted,politicians do not have permission to do what needs to be done.,"राजनीतिज्ञों के पास जो कार्य करना चाहिए, वह करने कि अनुमति नहीं है ."
1,ted,"I'd like to tell you about one such child,","मई आपको ऐसे ही एक बच्चे के बारे में बताना चाहूंगी,"
3,ted,what we really mean is that they're bad at not paying attention.,हम ये नहीं कहना चाहते कि वो ध्यान नहीं दे पाते
7,ted,"And who are we to say, even, that they are wrong",और हम होते कौन हैं यह कहने भी वाले कि वे गलत हैं
13,ted,So there is some sort of justice,तो वहाँ न्याय है
23,ted,This changed slowly,धीरे धीरे ये सब बदला
26,ted,were being produced.,उत्पन्न नहीं कि जाती थी.
30,ted,"And you can see, this LED is going to glow.","और जैसा आप देख रहे है, ये एल.ई.डी. जल उठेगी।"
32,ted,"to turn on the lights or to bring him a glass of water,","लाईट जलाने के लिए या उनके लिए पानी लाने के लिए,"
35,ted,Can you imagine saying that?,क्या आप ये कल्पना कर सकते है


In [16]:
pd.isnull(lines).sum()

source              0
english_sentence    0
hindi_sentence      0
dtype: int64

In [17]:
lines=lines[~pd.isnull(lines['english_sentence'])]

In [18]:
lines.drop_duplicates(inplace=True)

In [19]:
lines=lines.sample(n=25000,random_state=42)
lines.shape

(25000, 3)

In [20]:
# Lowercase all characters
lines['english_sentence']=lines['english_sentence'].apply(lambda x: x.lower())
lines['hindi_sentence']=lines['hindi_sentence'].apply(lambda x: x.lower())

In [21]:
# Remove quotes
lines['english_sentence']=lines['english_sentence'].apply(lambda x: re.sub("'", '', x))
lines['hindi_sentence']=lines['hindi_sentence'].apply(lambda x: re.sub("'", '', x))

In [22]:
exclude = set(string.punctuation) # Set of all special characters
# Remove all the special characters
lines['english_sentence']=lines['english_sentence'].apply(lambda x: ''.join(ch for ch in x if ch not in exclude))
lines['hindi_sentence']=lines['hindi_sentence'].apply(lambda x: ''.join(ch for ch in x if ch not in exclude))

In [23]:
# Remove all numbers from text
remove_digits = str.maketrans('', '', digits)
lines['english_sentence']=lines['english_sentence'].apply(lambda x: x.translate(remove_digits))
lines['hindi_sentence']=lines['hindi_sentence'].apply(lambda x: x.translate(remove_digits))

lines['hindi_sentence'] = lines['hindi_sentence'].apply(lambda x: re.sub("[२३०८१५७९४६]", "", x))

# Remove extra spaces
lines['english_sentence']=lines['english_sentence'].apply(lambda x: x.strip())
lines['hindi_sentence']=lines['hindi_sentence'].apply(lambda x: x.strip())
lines['english_sentence']=lines['english_sentence'].apply(lambda x: re.sub(" +", " ", x))
lines['hindi_sentence']=lines['hindi_sentence'].apply(lambda x: re.sub(" +", " ", x))


In [24]:
# Add start and end tokens to target sequences
lines['hindi_sentence'] = lines['hindi_sentence'].apply(lambda x : 'START_ '+ x + ' _END')

In [25]:
lines.head()

Unnamed: 0,source,english_sentence,hindi_sentence
82040,ted,we still dont know who her parents are who she is,START_ हम अभी तक नहीं जानते हैं कि उसके मातापिता कौन हैं वह कौन है _END
85038,ted,no keyboard,START_ कोई कुंजीपटल नहीं _END
58018,ted,but as far as being a performer,START_ लेकिन एक कलाकार होने के साथ _END
74470,ted,and this particular balloon,START_ और यह खास गुब्बारा _END
122330,ted,and its not as hard as you think integrate climate solutions into all of your innovations,START_ और जितना आपको लगता है यह उतना कठिन नहीं हैअपने सभी नवाचारों में जलवायु समाधान को एकीकृत करें _END


In [26]:
### Get English and Hindi Vocabulary
all_eng_words=set()
for eng in lines['english_sentence']:
    for word in eng.split():
        if word not in all_eng_words:
            all_eng_words.add(word)

all_hindi_words=set()
for hin in lines['hindi_sentence']:
    for word in hin.split():
        if word not in all_hindi_words:
            all_hindi_words.add(word)

In [27]:
len(all_eng_words)

14030

In [28]:
len(all_hindi_words)

17540

In [29]:
lines['length_eng_sentence']=lines['english_sentence'].apply(lambda x:len(x.split(" ")))
lines['length_hin_sentence']=lines['hindi_sentence'].apply(lambda x:len(x.split(" ")))

In [30]:
lines.head()

Unnamed: 0,source,english_sentence,hindi_sentence,length_eng_sentence,length_hin_sentence
82040,ted,we still dont know who her parents are who she is,START_ हम अभी तक नहीं जानते हैं कि उसके मातापिता कौन हैं वह कौन है _END,11,16
85038,ted,no keyboard,START_ कोई कुंजीपटल नहीं _END,2,5
58018,ted,but as far as being a performer,START_ लेकिन एक कलाकार होने के साथ _END,7,8
74470,ted,and this particular balloon,START_ और यह खास गुब्बारा _END,4,6
122330,ted,and its not as hard as you think integrate climate solutions into all of your innovations,START_ और जितना आपको लगता है यह उतना कठिन नहीं हैअपने सभी नवाचारों में जलवायु समाधान को एकीकृत करें _END,16,20


In [31]:
lines[lines['length_eng_sentence']>30].shape

(0, 5)

In [32]:
lines=lines[lines['length_eng_sentence']<=20]
lines=lines[lines['length_hin_sentence']<=20]

In [33]:
lines.shape

(24774, 5)

In [34]:
print("maximum length of Hindi Sentence ",max(lines['length_hin_sentence']))
print("maximum length of English Sentence ",max(lines['length_eng_sentence']))

maximum length of Hindi Sentence  20
maximum length of English Sentence  20


In [35]:
max_length_src=max(lines['length_hin_sentence'])
max_length_tar=max(lines['length_eng_sentence'])

In [36]:
input_words = sorted(list(all_eng_words))
target_words = sorted(list(all_hindi_words))
num_encoder_tokens = len(all_eng_words)
num_decoder_tokens = len(all_hindi_words)
num_encoder_tokens, num_decoder_tokens

(14030, 17540)

In [37]:
num_decoder_tokens += 1 #for zero padding


In [38]:
input_token_index = dict([(word, i+1) for i, word in enumerate(input_words)])
target_token_index = dict([(word, i+1) for i, word in enumerate(target_words)])

In [39]:
reverse_input_char_index = dict((i, word) for word, i in input_token_index.items())
reverse_target_char_index = dict((i, word) for word, i in target_token_index.items())

In [40]:
lines = shuffle(lines)
lines.head(10)

Unnamed: 0,source,english_sentence,hindi_sentence,length_eng_sentence,length_hin_sentence
18072,ted,especially when the people cant pay for it,START_ ख़ासतौर पर तब जब कि लोग पैसे देने में सक्षम हैं ही नहीं _END,8,15
34880,ted,so thats what t is doing,START_ और ठीक यही टी कर रहा है _END,6,9
78273,ted,any one of you could be famous on the internet,START_ और कोई भी इंटरनेट पर अगले शनिवार तक _END,10,10
117038,ted,seen and heard different versions,START_ ऐसी ही एकतरफा कहानी के विभिन्न _END,5,8
19890,ted,theyre prophecy,START_ वे भविष्यवाणी हैं _END,2,5
100770,ted,and i think that this will be one of the more important companies in east africa,START_ और मुझे लगता है कि ये पूर्वी अफ़्रीका की महत्वपूर्ण कंपनियों में से एक होगी। _END,16,17
23151,ted,so it was like about a half an hour for a lightning call to come through,START_ उस लाइटनिंग काल के लिए भी हमे आधा घंटा इंतज़ार करना पड़ता था _END,16,15
68140,ted,and it will make very highresolution maps,START_ जो बहुत अधिक रिज़ोल्यूशन के मानचित्र उपलब्ध कराएगा _END,7,10
96287,ted,so whats a technology that will allow us,START_ तो क्या है वेह तकनीक जो कि हमें असाधारण _END,8,11
6541,ted,you think that its not true of course,START_ आप ये सोच तै है मगर ये सच नहीं है _END,8,12


### Split the data into train and test

In [41]:
X, y = lines['english_sentence'], lines['hindi_sentence']
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size = 0.2,random_state=42)
X_train.shape, X_test.shape

((19819,), (4955,))

### Let us save this data

In [42]:
X_train.to_pickle('X_train.pkl')
X_test.to_pickle('X_test.pkl')


In [43]:
def generate_batch(X = X_train, y = y_train, batch_size = 128):
    ''' Generate a batch of data '''
    while True:
        for j in range(0, len(X), batch_size):
            encoder_input_data = np.zeros((batch_size, max_length_src),dtype='float32')
            decoder_input_data = np.zeros((batch_size, max_length_tar),dtype='float32')
            decoder_target_data = np.zeros((batch_size, max_length_tar, num_decoder_tokens),dtype='float32')
            for i, (input_text, target_text) in enumerate(zip(X[j:j+batch_size], y[j:j+batch_size])):
                for t, word in enumerate(input_text.split()):
                    encoder_input_data[i, t] = input_token_index[word] # encoder input seq
                for t, word in enumerate(target_text.split()):
                    if t<len(target_text.split())-1:
                        decoder_input_data[i, t] = target_token_index[word] # decoder input seq
                    if t>0:
                        # decoder target sequence (one hot encoded)
                        # does not include the START_ token
                        # Offset by one timestep
                        decoder_target_data[i, t - 1, target_token_index[word]] = 1.
            yield([encoder_input_data, decoder_input_data], decoder_target_data)

### Encoder-Decoder Architecture

In [44]:
latent_dim=300

In [45]:
# Encoder
encoder_inputs = Input(shape=(None,))
enc_emb =  Embedding(num_encoder_tokens, latent_dim, mask_zero = True)(encoder_inputs)
encoder_lstm = LSTM(latent_dim, return_state=True)
encoder_outputs, state_h, state_c = encoder_lstm(enc_emb)
# We discard `encoder_outputs` and only keep the states.
encoder_states = [state_h, state_c]

Instructions for updating:
Colocations handled automatically by placer.


In [46]:
# Set up the decoder, using `encoder_states` as initial state.
decoder_inputs = Input(shape=(None,))
dec_emb_layer = Embedding(num_decoder_tokens, latent_dim, mask_zero = True)
dec_emb = dec_emb_layer(decoder_inputs)
# We set up our decoder to return full output sequences,
# and to return internal states as well. We don't use the
# return states in the training model, but we will use them in inference.
decoder_lstm = LSTM(latent_dim, return_sequences=True, return_state=True)
decoder_outputs, _, _ = decoder_lstm(dec_emb,
                                     initial_state=encoder_states)
decoder_dense = Dense(num_decoder_tokens, activation='softmax')
decoder_outputs = decoder_dense(decoder_outputs)

# Define the model that will turn
# `encoder_input_data` & `decoder_input_data` into `decoder_target_data`
model = Model([encoder_inputs, decoder_inputs], decoder_outputs)

In [47]:
model.compile(optimizer='rmsprop', loss='categorical_crossentropy')

In [48]:
model.summary()

__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_1 (InputLayer)            (None, None)         0                                            
__________________________________________________________________________________________________
input_2 (InputLayer)            (None, None)         0                                            
__________________________________________________________________________________________________
embedding_1 (Embedding)         (None, None, 300)    4209000     input_1[0][0]                    
__________________________________________________________________________________________________
embedding_2 (Embedding)         (None, None, 300)    5262300     input_2[0][0]                    
__________________________________________________________________________________________________
lstm_1 (LS

In [49]:
train_samples = len(X_train)
val_samples = len(X_test)
batch_size = 128
epochs = 100

In [79]:
model.fit_generator(generator = generate_batch(X_train, y_train, batch_size = batch_size),
                    steps_per_epoch = train_samples//batch_size,
                    epochs=epochs,
                    validation_data = generate_batch(X_test, y_test, batch_size = batch_size),
                    validation_steps = val_samples//batch_size)



Instructions for updating:
Use tf.cast instead.
Instructions for updating:
Deprecated in favor of operator or tf.math.divide.
Epoch 1/100
Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 5/100
Epoch 6/100
Epoch 7/100
Epoch 8/100
Epoch 9/100
Epoch 10/100
Epoch 11/100
Epoch 12/100
Epoch 13/100
Epoch 14/100
Epoch 15/100
Epoch 16/100
Epoch 17/100
Epoch 18/100
Epoch 19/100
Epoch 20/100
Epoch 21/100
Epoch 22/100
Epoch 23/100
Epoch 24/100
Epoch 25/100
Epoch 26/100
Epoch 27/100
Epoch 28/100
Epoch 29/100
Epoch 30/100
Epoch 31/100
Epoch 32/100
Epoch 33/100
Epoch 34/100
Epoch 35/100
Epoch 36/100
Epoch 37/100
Epoch 38/100
Epoch 39/100
Epoch 40/100
Epoch 41/100
Epoch 42/100
Epoch 43/100
Epoch 44/100
Epoch 45/100
Epoch 46/100
Epoch 47/100
Epoch 48/100
Epoch 49/100
Epoch 50/100
Epoch 51/100
Epoch 52/100
Epoch 53/100
Epoch 54/100
Epoch 55/100
Epoch 56/100
Epoch 57/100
Epoch 58/100
Epoch 59/100
Epoch 60/100
Epoch 61/100
Epoch 62/100
Epoch 63/100
Epoch 64/100
Epoch 65/100
Epoch 66/100
Epoch 67/100
Epoch 68/100

<keras.callbacks.History at 0x7ff30f20fb38>

In [80]:
model.save_weights('nmt_weights.h5')

In [50]:
# Encode the input sequence to get the "thought vectors"
encoder_model = Model(encoder_inputs, encoder_states)

# Decoder setup
# Below tensors will hold the states of the previous time step
decoder_state_input_h = Input(shape=(latent_dim,))
decoder_state_input_c = Input(shape=(latent_dim,))
decoder_states_inputs = [decoder_state_input_h, decoder_state_input_c]

dec_emb2= dec_emb_layer(decoder_inputs) # Get the embeddings of the decoder sequence

# To predict the next word in the sequence, set the initial states to the states from the previous time step
decoder_outputs2, state_h2, state_c2 = decoder_lstm(dec_emb2, initial_state=decoder_states_inputs)
decoder_states2 = [state_h2, state_c2]
decoder_outputs2 = decoder_dense(decoder_outputs2) # A dense softmax layer to generate prob dist. over the target vocabulary

# Final decoder model
decoder_model = Model(
    [decoder_inputs] + decoder_states_inputs,
    [decoder_outputs2] + decoder_states2)


In [51]:
def decode_sequence(input_seq):
    # Encode the input as state vectors.
    states_value = encoder_model.predict(input_seq)
    # Generate empty target sequence of length 1.
    target_seq = np.zeros((1,1))
    # Populate the first character of target sequence with the start character.
    target_seq[0, 0] = target_token_index['START_']

    # Sampling loop for a batch of sequences
    # (to simplify, here we assume a batch of size 1).
    stop_condition = False
    decoded_sentence = ''
    while not stop_condition:
        output_tokens, h, c = decoder_model.predict([target_seq] + states_value)

        # Sample a token
        sampled_token_index = np.argmax(output_tokens[0, -1, :])
        sampled_char = reverse_target_char_index[sampled_token_index]
        decoded_sentence += ' '+sampled_char

        # Exit condition: either hit max length
        # or find stop character.
        if (sampled_char == '_END' or
           len(decoded_sentence) > 50):
            stop_condition = True

        # Update the target sequence (of length 1).
        target_seq = np.zeros((1,1))
        target_seq[0, 0] = sampled_token_index

        # Update states
        states_value = [h, c]

    return decoded_sentence

In [83]:
train_gen = generate_batch(X_train, y_train, batch_size = 1)
k=-1


In [84]:
k+=1
(input_seq, actual_output), _ = next(train_gen)
decoded_sentence = decode_sequence(input_seq)
print('Input English sentence:', X_train[k:k+1].values[0])
print('Actual Hindi Translation:', y_train[k:k+1].values[0][6:-4])
print('Predicted Hindi Translation:', decoded_sentence[:-4])

Input English sentence: had used her brutally
Actual Hindi Translation:  उसका निर्दयता से उपभोग किया था 
Predicted Hindi Translation:  उसका निर्दयता से उपभोग किया था 


In [85]:
k+=1
(input_seq, actual_output), _ = next(train_gen)
decoded_sentence = decode_sequence(input_seq)
print('Input English sentence:', X_train[k:k+1].values[0])
print('Actual Hindi Translation:', y_train[k:k+1].values[0][6:-4])
print('Predicted Hindi Translation:', decoded_sentence[:-4])

Input English sentence: in fact legend has is that when doubting thomas the apostle saint thomas
Actual Hindi Translation:  यहाँ तक की यह भी कहा जाता है की देवदूत के शक करने पर संत थोमस 
Predicted Hindi Translation:  यहाँ भी यह कहा की है कि सामाजिक शक्ति से रहा ह


In [86]:
k+=1
(input_seq, actual_output), _ = next(train_gen)
decoded_sentence = decode_sequence(input_seq)
print('Input English sentence:', X_train[k:k+1].values[0])
print('Actual Hindi Translation:', y_train[k:k+1].values[0][6:-4])
print('Predicted Hindi Translation:', decoded_sentence[:-4])

Input English sentence: that i could actually get a fractional percent of a cent
Actual Hindi Translation:  मुझे एक पैसे का केवल कुछ भाग ही मिलेगा 
Predicted Hindi Translation:  मुझे एक पैसे का अवसर का एक जरुरत है 


In [87]:
k+=1
(input_seq, actual_output), _ = next(train_gen)
decoded_sentence = decode_sequence(input_seq)
print('Input English sentence:', X_train[k:k+1].values[0])
print('Actual Hindi Translation:', y_train[k:k+1].values[0][6:-4])
print('Predicted Hindi Translation:', decoded_sentence[:-4])

Input English sentence: the bad news is they were all photocopies so we didnt make a dime in revenue
Actual Hindi Translation:  बुरी खबर ये है कि ये सब फ़ोटोकापियाँ थी तो हमने इस से एक भी पैसा नहीं कमाया। 
Predicted Hindi Translation:  बुरी खबर ये है कि हम किसी बात पर विश्वास नहीं घ


In [88]:
k+=1
(input_seq, actual_output), _ = next(train_gen)
decoded_sentence = decode_sequence(input_seq)
print('Input English sentence:', X_train[k:k+1].values[0])
print('Actual Hindi Translation:', y_train[k:k+1].values[0][6:-4])
print('Predicted Hindi Translation:', decoded_sentence[:-4])

Input English sentence: that these kids would have been trained to see as “disabled”
Actual Hindi Translation:  जिसे ये बच्चे “विकलांग” या “अक्षम” समझने के आदी होते 
Predicted Hindi Translation:  जिसे यह बच्चों को जन्म है जिसे आपके पास तमाम ह


## Encoder Decoder with attention

In [None]:
import numpy as np
import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, LSTM, Dense, Embedding, Bidirectional, Concatenate
from tensorflow.keras.callbacks import ModelCheckpoint
from sklearn.model_selection import train_test_split

# Assuming you have loaded and preprocessed your data into X_train, X_test, y_train, y_test
# and have dictionaries input_token_index and target_token_index

# Parameters
batch_size = 128
epochs = 50
latent_dim = 300  # Dimensionality of the encoding space

# Encoder input
encoder_inputs = Input(shape=(None,))
enc_emb = Embedding(num_encoder_tokens, latent_dim, mask_zero=True)(encoder_inputs)
encoder_lstm = LSTM(latent_dim, return_sequences=True, return_state=True)
encoder_outputs, state_h, state_c = encoder_lstm(enc_emb)
encoder_states = [state_h, state_c]

# Decoder input
decoder_inputs = Input(shape=(None,))
dec_emb_layer = Embedding(num_decoder_tokens, latent_dim, mask_zero=True)
dec_emb = dec_emb_layer(decoder_inputs)
decoder_lstm = LSTM(latent_dim, return_sequences=True, return_state=True)
decoder_outputs, _, _ = decoder_lstm(dec_emb, initial_state=encoder_states)

# Custom attention mechanism
def attention(inputs):
    encoder_outputs, decoder_outputs = inputs
    # Score calculation
    score = tf.matmul(decoder_outputs, encoder_outputs, transpose_b=True)
    attention_weights = tf.nn.softmax(score, axis=-1)
    # Context vector calculation
    context_vector = tf.matmul(attention_weights, encoder_outputs)
    return context_vector

# Applying attention mechanism
attn_output = attention([encoder_outputs, decoder_outputs])
# Concatenate attention output and decoder LSTM output
decoder_concat_input = Concatenate(axis=-1)([decoder_outputs, attn_output])

# Dense layer
decoder_dense = Dense(num_decoder_tokens, activation='softmax')
decoder_outputs = decoder_dense(decoder_concat_input)

# Define the model
model = Model([encoder_inputs, decoder_inputs], decoder_outputs)

# Compile the model
model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])

# Print model summary
print(model.summary())

# Train the model
train_samples = len(X_train)
val_samples = len(X_test)
steps_per_epoch = train_samples // batch_size
val_steps = val_samples // batch_size

checkpoint = ModelCheckpoint('best_model.h5', save_best_only=True)

# Train the model
model.fit(generate_batch(X_train, y_train, batch_size=batch_size),
          steps_per_epoch=steps_per_epoch,
          epochs=epochs,
          validation_data=generate_batch(X_test, y_test, batch_size=batch_size),
          validation_steps=val_steps,
          callbacks=[checkpoint])

# Save the model
model.save('s2s_attention_custom.h5')


In [91]:
batch_size = 128
epochs = 50
latent_dim = 300

In [93]:
decoder_outputs, _, _ = decoder_lstm(dec_emb, initial_state=encoder_states)

In [92]:
def attention(inputs):
    encoder_outputs, decoder_outputs = inputs
    # score shape == (batch_size, max_length, 1)
    score = tf.matmul(decoder_outputs, encoder_outputs, transpose_b=True)
    # attention_weights shape == (batch_size, max_length, 1)
    attention_weights = tf.nn.softmax(score, axis=2)
    # context_vector shape after sum == (batch_size, hidden_size)
    context_vector = tf.matmul(attention_weights, encoder_outputs)
    return context_vector

In [None]:
import tensorflow as tf
attn_output = attention([encoder_outputs, decoder_outputs])
# Concatenate attention output and decoder LSTM output
decoder_concat_input = Concatenate(axis=-1)([decoder_outputs, attn_output])

In [None]:
decoder_dense = Dense(num_decoder_tokens, activation='softmax')
decoder_outputs = decoder_dense(decoder_concat_input)

In [None]:
model = Model([encoder_inputs, decoder_inputs], decoder_outputs)
model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
print(model.summary())

In [None]:
train_samples = len(X_train)
val_samples = len(X_test)
steps_per_epoch = train_samples 
val_steps = val_samples 

In [None]:
checkpoint = ModelCheckpoint('attention_lstm.h5', save_best_only=True)

In [None]:
model.fit(generate_batch(X_train, y_train, batch_size=batch_size),
          steps_per_epoch=steps_per_epoch,
          epochs=epochs,
          validation_data=generate_batch(X_test, y_test, batch_size=batch_size),
          validation_steps=val_steps,
          callbacks=[checkpoint])

In [None]:
model.save('attention_custom.h5')
