In [1]:
import math
import json
import requests
import itertools
import numpy as np
import time
import datetime

In [2]:
from keras.preprocessing.text import Tokenizer
from keras.preprocessing.sequence import pad_sequences
from keras.models import Sequential
from keras.layers import Dense, Flatten, LSTM, Conv1D, MaxPooling1D, Dropout, Activation
from keras.layers.embeddings import Embedding

## Plot
import plotly.offline as py
import plotly.graph_objs as go
py.init_notebook_mode(connected=True)
import matplotlib as plt

# NLTK
import nltk
from nltk.corpus import stopwords
from nltk.stem import SnowballStemmer

# Other
import re
import string
import numpy as np
import pandas as pd
from sklearn.manifold import TSNE

Using TensorFlow backend.


In [3]:
df=pd.read_csv(r'/kaggle/input/raw-4-years-data-for-corpus/reddit_data_raw_4years.csv')

In [4]:
corpus=[]

for title in df['title']:
    corpus.append(title)
    
len(corpus)

433802

In [5]:
wpt = nltk.WordPunctTokenizer()
stop_words = nltk.corpus.stopwords.words('english')

def normalize_document(doc):
    try:
        # lower case and remove special characters\whitespaces
        doc = re.sub(r'[^a-zA-Z\s]', '', doc, re.I|re.A)
        doc = doc.lower()
        doc = doc.strip()
        # tokenize document
        tokens = wpt.tokenize(doc)
        # filter stopwords out of document
        filtered_tokens = [token for token in tokens if token not in stop_words]
        # re-create document from filtered tokens
        doc = ' '.join(filtered_tokens)
        return doc
    except:
        return ""

normalize_corpus = np.vectorize(normalize_document)

In [6]:
norm_corpus = normalize_corpus(corpus)
norm_corpus

array(['cash cow venezuelas oil company verges collapse',
       'new study squelches treasured theory indians originsthe aryans come india conquered',
       'indias garment exports may hit billion fy', ..., 'aayega toh',
       'junior hockey world cup india claim title win belgium',
       'ribs broken money gone stranded german grateful support gurdwaras np'],
      dtype='<U300')

In [7]:
MAX_NB_WORDS = 30000
# Max number of words in each complaint.
MAX_SEQUENCE_LENGTH = 30
# This is fixed.
EMBEDDING_DIM = 300

In [8]:
from gensim.models import word2vec

# tokenize sentences in corpus
wpt = nltk.WordPunctTokenizer()
tokenized_corpus = [wpt.tokenize(document) for document in norm_corpus]

# Set values for various parameters
feature_size = EMBEDDING_DIM    # Word vector dimensionality  
window_context = 6          # Context window size                                                                                    
min_word_count = 1   # Minimum word count                        
sample = 1e-3   # Downsample setting for frequent words

w2v_model = word2vec.Word2Vec(tokenized_corpus, size=feature_size, 
                          window=window_context, min_count=min_word_count,
                          sample=sample, iter=50)

In [9]:
words = w2v_model.wv.index2word
wvs = w2v_model.wv[words]

In [10]:
df_train=pd.read_csv('/kaggle/input/reddit-data-balanced/reddit_data_balanced.csv')
# df_train2=pd.read_csv('/kaggle/input/reddit-balanced-modified/reddit_data_balanced_modified_27-4-2020.csv')

In [11]:
df_train.head(30)

Unnamed: 0.1,Unnamed: 0,title,flair
0,0,HELP HELP TEST,[R]eddiquette
1,1,Lets have a conversation Randians,[R]eddiquette
2,2,Forest guards ordered to watch over python tha...,Non-Political
3,3,Engineering pass-outs from Shitty colleges (Ti...,AskIndia
4,4,"The Constitution, as ABVP would have it. [Old]",Politics
5,5,Pune city,Photography
6,6,Chicken Tikka Masala - Chicken Tikka Gravy,Food
7,7,Rangoli Chandel loses calm after Twitter warni...,Politics
8,8,fake followers data of media,[R]eddiquette
9,9,"'Death of currency' not a new subject, virtual...",Demonetization


In [12]:
df_train=df_train.loc[df_train['flair']!='other']

In [13]:
df_train.shape

(83000, 3)

In [14]:
corpus2=[]

for title in df_train['title']:
    corpus2.append(title)
    
len(corpus2)

# corpus3=[]

# for title in df_train2['title']:
#     corpus3.append(title)
    
# len(corpus3)

83000

In [15]:
norm_corpus2 = normalize_corpus(corpus2)
norm_corpus2

# norm_corpus3 = normalize_corpus(corpus3)
# norm_corpus3

array(['help help test', 'lets conversation randians',
       'forest guards ordered watch python swallowed deer', ...,
       'whatsapp getting ready worlds bigvest election',
       'india japan plan obor alternative project unveiled next monday',
       'np mumbai feel everything shifting gujarat cm vijay rupani indian express'],
      dtype='<U286')

In [16]:
tokenizer = Tokenizer(num_words=MAX_NB_WORDS, filters='!"#$%&()*+,-./:;<=>?@[\]^_`{|}~', lower=True)
tokenizer.fit_on_texts(norm_corpus2)

# tokenizer2 = Tokenizer(num_words=MAX_NB_WORDS, filters='!"#$%&()*+,-./:;<=>?@[\]^_`{|}~', lower=True)
# tokenizer2.fit_on_texts(norm_corpus3)

In [17]:
list_tokenized_train = tokenizer.texts_to_sequences(norm_corpus2)

# list_tokenized_train2 = tokenizer2.texts_to_sequences(norm_corpus3)

In [18]:
X = pad_sequences(list_tokenized_train, maxlen=MAX_SEQUENCE_LENGTH)

# X2 = pad_sequences(list_tokenized_train2, maxlen=MAX_SEQUENCE_LENGTH)

In [19]:
print('Found %s unique tokens.' % len(tokenizer.word_index))

Found 51226 unique tokens.


In [20]:
t_dict={}

for word in w2v_model.wv.vocab:
    t_dict[word]=w2v_model[word]


Call to deprecated `__getitem__` (Method will be removed in 4.0.0, use self.wv.__getitem__() instead).



In [21]:
# t_dict

In [22]:
all_embs = np.stack(t_dict.values())
emb_mean,emb_std = all_embs.mean(), all_embs.std()
emb_mean,emb_std


arrays to stack must be passed as a "sequence" type such as list or tuple. Support for non-sequence iterables such as generators is deprecated as of NumPy 1.16 and will raise an error in the future.



(-0.00039438746, 0.366756)

In [23]:
word_index = tokenizer.word_index
nb_words = min(MAX_NB_WORDS, len(word_index))
embedding_matrix = np.random.normal(emb_mean, emb_std, (nb_words, EMBEDDING_DIM))
for word, i in word_index.items():
    if i >= MAX_NB_WORDS: continue
    embedding_vector = t_dict.get(word)
    if embedding_vector is not None: embedding_matrix[i] = embedding_vector

# word_index2 = tokenizer2.word_index
# nb_words2 = min(MAX_NB_WORDS, len(word_index2))
# embedding_matrix2 = np.random.normal(emb_mean, emb_std, (nb_words2, EMBEDDING_DIM))
# for word, i in word_index2.items():
#     if i >= MAX_NB_WORDS: continue
#     embedding_vector = t_dict.get(word)
#     if embedding_vector is not None: embedding_matrix2[i] = embedding_vector

In [24]:
embedding_matrix[0]

array([-5.54035758e-01, -7.14715725e-01, -3.77627676e-01, -3.42692560e-01,
        3.60031046e-02, -1.22589549e-02, -3.23111645e-01,  6.18959673e-02,
       -1.22807597e+00,  8.61130617e-01, -4.25768456e-01, -7.83921861e-02,
       -2.89682581e-01, -3.90151753e-01, -5.50786936e-01, -6.45206111e-01,
       -2.69191365e-01, -2.46128844e-01,  6.28626516e-01,  4.39771890e-01,
        2.07312066e-01,  3.85357659e-01,  3.98939635e-01,  6.13101030e-02,
        1.14850893e-01, -3.69680888e-01, -2.47556722e-02, -3.73322505e-01,
        3.23721058e-01,  2.63995182e-01, -1.37180337e-01,  1.32020673e-01,
        1.48231016e-01,  6.59202017e-02,  3.76199717e-02, -2.50323107e-02,
        2.45526469e-01,  4.00565316e-02, -8.86870992e-01,  1.50214185e-01,
       -3.34767641e-01, -4.64592434e-01,  2.37097084e-01, -4.43036339e-01,
       -1.77844703e-01,  7.42442460e-02,  8.36147984e-02,  5.17094480e-01,
       -1.06315766e-01, -2.81727410e-01, -4.56499929e-03,  7.68575157e-01,
       -3.21768668e-02,  

In [25]:
df_train.head()

Unnamed: 0.1,Unnamed: 0,title,flair
0,0,HELP HELP TEST,[R]eddiquette
1,1,Lets have a conversation Randians,[R]eddiquette
2,2,Forest guards ordered to watch over python tha...,Non-Political
3,3,Engineering pass-outs from Shitty colleges (Ti...,AskIndia
4,4,"The Constitution, as ABVP would have it. [Old]",Politics


In [26]:
df_train['category_id'] = df_train['flair'].factorize()[0]
category_id_df = df_train[['flair', 'category_id']].drop_duplicates().sort_values('category_id')
category_to_id = dict(category_id_df.values)
id_to_category = dict(category_id_df[['category_id', 'flair']].values)
df_train.head()

# df_train2['category_id'] = df_train2['flair'].factorize()[0]
# category_id_df2 = df_train2[['flair', 'category_id']].drop_duplicates().sort_values('category_id')
# category_to_id2 = dict(category_id_df2.values)
# id_to_category2 = dict(category_id_df2[['category_id', 'flair']].values)
# df_train2.head()

Unnamed: 0.1,Unnamed: 0,title,flair,category_id
0,0,HELP HELP TEST,[R]eddiquette,0
1,1,Lets have a conversation Randians,[R]eddiquette,0
2,2,Forest guards ordered to watch over python tha...,Non-Political,1
3,3,Engineering pass-outs from Shitty colleges (Ti...,AskIndia,2
4,4,"The Constitution, as ABVP would have it. [Old]",Politics,3


In [27]:
df_train.head(30)

Unnamed: 0.1,Unnamed: 0,title,flair,category_id
0,0,HELP HELP TEST,[R]eddiquette,0
1,1,Lets have a conversation Randians,[R]eddiquette,0
2,2,Forest guards ordered to watch over python tha...,Non-Political,1
3,3,Engineering pass-outs from Shitty colleges (Ti...,AskIndia,2
4,4,"The Constitution, as ABVP would have it. [Old]",Politics,3
5,5,Pune city,Photography,4
6,6,Chicken Tikka Masala - Chicken Tikka Gravy,Food,5
7,7,Rangoli Chandel loses calm after Twitter warni...,Politics,3
8,8,fake followers data of media,[R]eddiquette,0
9,9,"'Death of currency' not a new subject, virtual...",Demonetization,6


In [28]:
labels=np.reshape(np.array(df_train['category_id']), (83000, 1))
from keras.utils import to_categorical
y_binary = to_categorical(labels)

# labels2=np.reshape(np.array(df_train2['category_id']), (118000, 1))
# from keras.utils import to_categorical
# y_binary2 = to_categorical(labels2)

In [29]:
embedding_matrix.shape

(30000, 300)

In [30]:
y_binary.shape

(83000, 13)

In [31]:
# from keras.layers import Dense, Embedding, LSTM, SpatialDropout1D

# model = Sequential()
# # x = Embedding(max_features, embed_size, weights=[embedding_matrix])(inp)
# model.add(Embedding(MAX_NB_WORDS, EMBEDDING_DIM, weights=[embedding_matrix]))
# # model.add(SpatialDropout1D(0.2))
# model.add(LSTM(512, dropout=0.2, recurrent_dropout=0.2))
# model.add(Dense(14, activation='softmax'))
# model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
# print(model.summary())

In [32]:
from keras.layers import Dense, Embedding, LSTM, SpatialDropout1D
from keras.layers import GRU, Bidirectional, GlobalAveragePooling1D, GlobalMaxPooling1D
from keras.layers import Input, Dense, Embedding, SpatialDropout1D, concatenate
from keras.layers.normalization import BatchNormalization
from keras.layers import Dense, Activation
from keras import optimizers

model = Sequential()
# x = Embedding(max_features, embed_size, weights=[embedding_matrix])(inp)
model.add(Embedding(MAX_NB_WORDS, EMBEDDING_DIM, weights=[embedding_matrix]))
model.add(SpatialDropout1D(0.4))
model.add(LSTM(1000, dropout=0.2, recurrent_dropout=0.2, return_sequences=True, kernel_initializer='glorot_uniform'))
# model.add(GlobalAveragePooling1D())
# model.add(BatchNormalization())
model.add(GlobalMaxPooling1D())
model.add(Dense(13, activation='softmax'))
model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])

print(model.summary())

Model: "sequential_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
embedding_1 (Embedding)      (None, None, 300)         9000000   
_________________________________________________________________
spatial_dropout1d_1 (Spatial (None, None, 300)         0         
_________________________________________________________________
lstm_1 (LSTM)                (None, None, 1000)        5204000   
_________________________________________________________________
global_max_pooling1d_1 (Glob (None, 1000)              0         
_________________________________________________________________
dense_1 (Dense)              (None, 13)                13013     
Total params: 14,217,013
Trainable params: 14,217,013
Non-trainable params: 0
_________________________________________________________________
None


In [33]:
# from keras.layers import Dense, Embedding, LSTM, SpatialDropout1D
# from keras.layers import GRU, Bidirectional, GlobalAveragePooling1D, GlobalMaxPooling1D
# from keras.layers import Input, Dense, Embedding, SpatialDropout1D, concatenate
# from keras.layers.normalization import BatchNormalization
# from keras.layers import Dense, Activation

# model = Sequential()
# # x = Embedding(max_features, embed_size, weights=[embedding_matrix])(inp)
# model.add(Embedding(MAX_NB_WORDS, EMBEDDING_DIM, weights=[embedding_matrix]))
# model.add(Conv1D(filters=64, kernel_size=3 ,strides=1, padding='same' , activation= 'relu')) 
# model.add(MaxPooling1D(pool_size=2))
# model.add(SpatialDropout1D(0.4))
# model.add(LSTM(1000, dropout=0.2, recurrent_dropout=0.2, return_sequences=True, kernel_initializer='glorot_uniform'))
# # model.add(GlobalAveragePooling1D())
# # model.add(BatchNormalization())
# model.add(GlobalMaxPooling1D())
# model.add(Dense(14, activation='softmax'))
# model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])

# print(model.summary())

In [34]:
# from keras.layers import Dense, Embedding, LSTM, SpatialDropout1D
# from keras.layers import GRU, Bidirectional, GlobalAveragePooling1D, GlobalMaxPooling1D
# from keras.layers import Input, Dense, Embedding, SpatialDropout1D, concatenate
# from keras.layers.normalization import BatchNormalization
# from keras.layers import Dense, Activation
# from keras.models import Model


# inp = Input( shape=(30,))
# x=Embedding(MAX_NB_WORDS, EMBEDDING_DIM, weights=[embedding_matrix])(inp)
# x=SpatialDropout1D(0.2)(x)
# x=LSTM(1000, dropout=0.2, recurrent_dropout=0.2, return_sequences=True)(x)
# x=BatchNormalization()(x)
# x=Activation('relu')(x)
# avg_pool = GlobalAveragePooling1D()(x)
# max_pool = GlobalMaxPooling1D()(x)
# conc = concatenate([avg_pool, max_pool])
# out=Dense(14, activation='softmax')(conc)

# model = Model(inputs=inp, outputs=out)
# model.compile(loss='logcosh', optimizer='adam', metrics=['accuracy'])

In [35]:
# from keras.layers import GRU, Bidirectional, GlobalAveragePooling1D, GlobalMaxPooling1D

# model = Sequential()
# # x = Embedding(max_features, embed_size, weights=[embedding_matrix])(inp)
# model.add(Embedding(MAX_NB_WORDS, EMBEDDING_DIM, weights=[embedding_matrix]))
# model.add(SpatialDropout1D(0.2))
# model.add(LSTM(1000, dropout=0.2, recurrent_dropout=0.2, return_sequences=True))
# model.add(GlobalMaxPooling1D())
# model.add(Dense(14, activation='softmax'))
# model.compile(loss=tf.keras.losses.Huber(delta=1.0), optimizer='adam', metrics=['accuracy'])
# print(model.summary())

In [36]:
X.shape

(83000, 30)

In [37]:
epochs = 10
batch_size = 64

history = model.fit(X[0:70000], y_binary[0:70000],
                    batch_size,
                    epochs=epochs,
                    verbose=1,
                    validation_split=0.2)


Converting sparse IndexedSlices to a dense Tensor of unknown shape. This may consume a large amount of memory.



Train on 56000 samples, validate on 14000 samples
Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10


In [38]:
import tensorflow as tf

In [39]:
from sklearn.metrics import classification_report

In [40]:
df_train.head()

Unnamed: 0.1,Unnamed: 0,title,flair,category_id
0,0,HELP HELP TEST,[R]eddiquette,0
1,1,Lets have a conversation Randians,[R]eddiquette,0
2,2,Forest guards ordered to watch over python tha...,Non-Political,1
3,3,Engineering pass-outs from Shitty colleges (Ti...,AskIndia,2
4,4,"The Constitution, as ABVP would have it. [Old]",Politics,3


In [41]:
y_pred=model.predict(X[70000:])
# y_pred2=model.predict(X2[90000:])

In [42]:
# df_train.head(30)

In [43]:
# df_train['flair'].value_counts()

In [44]:
classification_report(df_train['category_id'][70000:], np.argmax(y_pred, axis=1), output_dict=True)

# classification_report(df_train2['category_id'][90000:], np.argmax(y_pred2, axis=1), output_dict=True)

{'0': {'precision': 0.4636363636363636,
  'recall': 0.1951530612244898,
  'f1-score': 0.2746858168761221,
  'support': 784},
 '1': {'precision': 0.3847184986595174,
  'recall': 0.364907819453274,
  'f1-score': 0.3745513866231648,
  'support': 1573},
 '2': {'precision': 0.4808743169398907,
  'recall': 0.5854956753160346,
  'f1-score': 0.528052805280528,
  'support': 1503},
 '3': {'precision': 0.6175438596491228,
  'recall': 0.6675094816687737,
  'f1-score': 0.6415552855407046,
  'support': 1582},
 '4': {'precision': 0.6,
  'recall': 0.6459074733096085,
  'f1-score': 0.6221079691516709,
  'support': 562},
 '5': {'precision': 0.6624087591240876,
  'recall': 0.6927480916030534,
  'f1-score': 0.6772388059701492,
  'support': 524},
 '6': {'precision': 0.669195751138088,
  'recall': 0.7875,
  'f1-score': 0.723543888433142,
  'support': 560},
 '7': {'precision': 0.5939479239971851,
  'recall': 0.6482334869431644,
  'f1-score': 0.6199045170767536,
  'support': 1302},
 '8': {'precision': 0.79126

In [45]:
model.save('/kaggle/working/reddit_predictor.h5')

In [46]:
model_json = model.to_json()
with open("/kaggle/working/model.json", "w") as json_file:
    json_file.write(model_json)

In [47]:
model.outputs

[<tf.Tensor 'dense_1/Softmax:0' shape=(None, 13) dtype=float32>]

In [48]:
out_names = [x.op.name for x in model.outputs]
out_names

['dense_1/Softmax']

In [49]:
# def freeze_session(session, keep_var_names=None, output_names=None, clear_devices=True):
#     """
#     Freezes the state of a session into a pruned computation graph.

#     Creates a new computation graph where variable nodes are replaced by
#     constants taking their current value in the session. The new graph will be
#     pruned so subgraphs that are not necessary to compute the requested
#     outputs are removed.
#     @param session The TensorFlow session to be frozen.
#     @param keep_var_names A list of variable names that should not be frozen,
#                           or None to freeze all the variables in the graph.
#     @param output_names Names of the relevant graph outputs.
#     @param clear_devices Remove the device directives from the graph for better portability.
#     @return The frozen graph definition.
#     """
#     graph = session.graph
#     with graph.as_default():
#         freeze_var_names = list(set(v.op.name for v in tf.compat.v1.global_variables()).difference(keep_var_names or []))
#         output_names = output_names or []
#         output_names += [v.op.name for v in tf.compat.v1.global_variables()]
#         input_graph_def = graph.as_graph_def()
#         if clear_devices:
#             for node in input_graph_def.node:
#                 node.device = ""
#         frozen_graph = tf.compat.v1.graph_util.convert_variables_to_constants(
#             session, input_graph_def, output_names, freeze_var_names)
#         return frozen_graph

In [50]:
# from keras import backend as K

# # frozen_graph = freeze_session(tf.compat.v1.Session(),
# #                               output_names=[out.op.name for out in model.outputs])

# frozen_graph = freeze_session(tf.compat.v1.Session(),
#                               output_names=['dense_2'])

In [51]:
import tensorflow as tf

# The export path contains the name and the version of the model
tf.keras.backend.set_learning_phase(0) # Ignore dropout at inference
model = tf.keras.models.load_model('/kaggle/working/reddit_predictor.h5')
export_path = '/kaggle/working/model2.pb'

# Fetch the Keras session and save the model
# The signature definition is defined by the input and output tensors
# And stored with the default serving key
with tf.compat.v1.Session() as sess:
    tf.compat.v1.saved_model.simple_save(
        sess,
        export_path,
        inputs={'input_image': model.input},
        outputs={t.name:t for t in model.outputs})

In [52]:
import os
file_name = "/kaggle/working/model2.pb/saved_model.pb"

file_stats = os.stat(file_name)

In [53]:
print(f'File Size in MegaBytes is {file_stats.st_size / (1024 * 1024)}')

File Size in MegaBytes is 0.000232696533203125


In [54]:
# tf.io.write_graph(frozen_graph, "/kaggle/working", "reddit_trained_tf.pb", as_text=False)

In [55]:
yaml_model= model.to_yaml()
# writing the yaml model to the yaml file
with open('/kaggle/working/yamlmodel.yaml', 'w') as yaml_file:
    yaml_file.write(yaml_model)

In [56]:
# import tensorflow as tf
# from tensorflow.python.framework import graph_util
# from tensorflow.python.framework import graph_io
# from pathlib import Path
# from absl import app
# from absl import flags
# from absl import logging
# import keras
# from keras import backend as K
# from keras.models import model_from_json, model_from_yaml

# K.set_learning_phase(0)
# FLAGS = flags.FLAGS

# # def del_all_flags(FLAGS):
# #     flags_dict = FLAGS._flags()
# #     keys_list = [keys for keys in flags_dict]
# #     for keys in keys_list:
# #         FLAGS.delattr(keys)

# # del_all_flags(FLAGS)

# # def del_all_flags(FLAGS):
# #     flags_dict = FLAGS._flags()
# # #     keys_list = [keys for keys in flags_dict]
# #     for keys, values in flags_dict.items():
# #         delattr(keys, values)

# # del_all_flags(FLAGS)

# def del_all_flags(FLAGS):
#     flags_dict = FLAGS._flags()
#     keys_list = [keys for keys in flags_dict]
#     for keys in keys_list:
#         FLAGS.__delattr__(keys)
        
# del_all_flags(FLAGS)

# flags.DEFINE_string('input_model2', None, '/kaggle/working/reddit_predictor.h5')
# flags.DEFINE_string('input_model_json', None, '/kaggle/working/model.json')
# flags.DEFINE_string('input_model_yaml', None, '/kaggle/working/yamlmodel.yaml')
# flags.DEFINE_string('output_model', None, 'letsee.pb')
# flags.DEFINE_boolean('save_graph_def', False,
#                      'Whether to save the graphdef.pbtxt file which contains '
#                      'the graph definition in ASCII format.')
# # flags.DEFINE_string('output_nodes_prefix', None,
# #                     'If set, the output nodes will be renamed to '
# #                     '`output_nodes_prefix`+i, where `i` will numerate the '
# #                     'number of of output nodes of the network.')
# # flags.DEFINE_boolean('quantize', False,
# #                      'If set, the resultant TensorFlow graph weights will be '
# #                      'converted from float into eight-bit equivalents. See '
# #                      'documentation here: '
# #                      'https://github.com/tensorflow/tensorflow/tree/master/tensorflow/tools/graph_transforms')
# flags.DEFINE_boolean('channels_first', False,
#                      'Whether channels are the first dimension of a tensor. '
#                      'The default is TensorFlow behaviour where channels are '
#                      'the last dimension.')
# flags.DEFINE_boolean('output_meta_ckpt', False,
#                      'If set to True, exports the model as .meta, .index, and '
#                      '.data files, with a checkpoint file. These can be later '
#                      'loaded in TensorFlow to continue training.')

# flags.mark_flag_as_required('input_model2')
# flags.mark_flag_as_required('output_model')


# def load_model(input_model_path, input_json_path=None, input_yaml_path=None):
#     if not Path(input_model_path).exists():
#         raise FileNotFoundError(
#             'Model file `{}` does not exist.'.format(input_model_path))
#     try:
#         model = keras.models.load_model(input_model_path)
#         return model
#     except FileNotFoundError as err:
#         logging.error('Input mode file (%s) does not exist.', FLAGS.input_model2)
#         raise err
#     except ValueError as wrong_file_err:
#         if input_json_path:
#             if not Path(input_json_path).exists():
#                 raise FileNotFoundError(
#                     'Model description json file `{}` does not exist.'.format(
#                         input_json_path))
#             try:
#                 model = model_from_json(open(str(input_json_path)).read())
#                 model.load_weights(input_model_path)
#                 return model
#             except Exception as err:
#                 logging.error("Couldn't load model from json.")
#                 raise err
#         elif input_yaml_path:
#             if not Path(input_yaml_path).exists():
#                 raise FileNotFoundError(
#                     'Model description yaml file `{}` does not exist.'.format(
#                         input_yaml_path))
#             try:
#                 model = model_from_yaml(open(str(input_yaml_path)).read())
#                 model.load_weights(input_model_path)
#                 return model
#             except Exception as err:
#                 logging.error("Couldn't load model from yaml.")
#                 raise err
#         else:
#             logging.error(
#                 'Input file specified only holds the weights, and not '
#                 'the model definition. Save the model using '
#                 'model.save(filename.h5) which will contain the network '
#                 'architecture as well as its weights. '
#                 'If the model is saved using the '
#                 'model.save_weights(filename) function, either '
#                 'input_model_json or input_model_yaml flags should be set to '
#                 'to import the network architecture prior to loading the '
#                 'weights. \n'
#                 'Check the keras documentation for more details '
#                 '(https://keras.io/getting-started/faq/)')
#             raise wrong_file_err


# def main(args):
#     # If output_model path is relative and in cwd, make it absolute from root
#     output_model = FLAGS.output_model
#     if str(Path(output_model).parent) == '.':
#         output_model = str((Path.cwd() / output_model))

#     output_fld = Path(output_model).parent
#     output_model_name = Path(output_model).name
#     output_model_stem = Path(output_model).stem
#     output_model_pbtxt_name = output_model_stem + '.pbtxt'

#     # Create output directory if it does not exist
#     Path(output_model).parent.mkdir(parents=True, exist_ok=True)

#     if FLAGS.channels_first:
#         K.set_image_data_format('channels_first')
#     else:
#         K.set_image_data_format('channels_last')

#     model = load_model(FLAGS.input_model2, FLAGS.input_model_json, FLAGS.input_model_yaml)

#     # TODO(amirabdi): Support networks with multiple inputs
#     orig_output_node_names = [node.op.name for node in model.outputs]
#     if FLAGS.output_nodes_prefix:
#         num_output = len(orig_output_node_names)
#         pred = [None] * num_output
#         converted_output_node_names = [None] * num_output

#         # Create dummy tf nodes to rename output
#         for i in range(num_output):
#             converted_output_node_names[i] = '{}{}'.format(
#                 FLAGS.output_nodes_prefix, i)
#             pred[i] = tf.identity(model.outputs[i],
#                                   name=converted_output_node_names[i])
#     else:
#         converted_output_node_names = orig_output_node_names
#     logging.info('Converted output node names are: %s',
#                  str(converted_output_node_names))

#     sess = K.get_session()
#     if FLAGS.output_meta_ckpt:
#         saver = tf.train.Saver()
#         saver.save(sess, str(output_fld / output_model_stem))

#     if FLAGS.save_graph_def:
#         tf.train.write_graph(sess.graph.as_graph_def(), str(output_fld),
#                              output_model_pbtxt_name, as_text=True)
#         logging.info('Saved the graph definition in ascii format at %s',
#                      str(Path(output_fld) / output_model_pbtxt_name))

#     if FLAGS.quantize:
#         from tensorflow.tools.graph_transforms import TransformGraph
#         transforms = ["quantize_weights", "quantize_nodes"]
#         transformed_graph_def = TransformGraph(sess.graph.as_graph_def(), [],
#                                                converted_output_node_names,
#                                                transforms)
#         constant_graph = graph_util.convert_variables_to_constants(
#             sess,
#             transformed_graph_def,
#             converted_output_node_names)
#     else:
#         constant_graph = graph_util.convert_variables_to_constants(
#             sess,
#             sess.graph.as_graph_def(),
#             converted_output_node_names)

#     graph_io.write_graph(constant_graph, str(output_fld), output_model_name,
#                          as_text=False)
#     logging.info('Saved the freezed graph at %s',
#                  str(Path(output_fld) / output_model_name))


# if __name__ == "__main__":
#     app.run(main)