Skip to content

Conversation

yhliang2018
Copy link
Contributor

Add recommendation model to the garden. Will update readme and add unit tests.

@yhliang2018 yhliang2018 requested review from a team and karmel as code owners May 4, 2018 20:08
qlzh727
qlzh727 previously requested changes May 4, 2018
@@ -0,0 +1,35 @@
# Recommendation Model
This is an implementation of the Neural Collaborative Filtering (NCF) framework with Neural Matrix Factorization (NeuMF) model as described in the [Neural Collaborative Filtering](https://arxiv.org/abs/1708.05031) paper. Current implementation is based on the code from the authors' [NCF code](https://github.com/hexiangnan/neural_collaborative_filtering) and the Standford pytorch implementation at [mlperf repo](https://github.com/mlperf/reference/tree/master/recommendation/pytorch).
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Probably should change "pytorch" to "PyTorch" which is their official name.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Or, better yet, just "the Stanford implementation." Also note the typo in Stanford.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Another nit: "implementation in the MLPerf Repo"

Copy link
Contributor Author

@yhliang2018 yhliang2018 May 8, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will do. Thank you for the careful checking.

import tensorflow as tf # pylint: disable=g-bad-import-order

import numpy as np
import pandas as pd
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure if it worth the effort or not, maybe we can convert this to use the tf.data for importing CSV files.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

See https://www.tensorflow.org/api_docs/python/tf/estimator/inputs/pandas_input_fn and the containing module for convenient converters

parse_file_to_csv(FLAGS.data_dir, FLAGS.dataset)


if __name__ == '__main__':
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This will probably need some update to use absl flags and app.

if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument(
'--data_dir', type=str, default='/tmp/ml_data',
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The default value "/tmp/ml_data" seems to be too generic. Can we give it a more specific name?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agreed. ml_data could also stand for "machine learning data", which could confuse the user.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Aha, this ml is abbr. of "MoiveLens" which is the name of the dataset used for model training and evaluation. Maybe rename it as "movie_lens"?

file_path, _ = urllib.request.urlretrieve(
DATA_URL + dataset_zip, file_path, _progress)
statinfo = os.stat(file_path)
print('\nSuccessfully downloaded', file_path, statinfo.st_size, 'bytes.')
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Leading \n in a string is bit weird, maybe just print('\n', 'Successfully downloaded', ...)

Or we can change to tf.logging to mitigate the line return issue.


# Get the info of users who have more than 20 ratings on items.
grouped = df.groupby(USER_COLUMN)
df = grouped.filter(lambda x: len(x) >= 20)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The magic number 20 need some clarification.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hah, it's from the paper. Will add comments on it.


# Calculate HR score
def _get_hr(ranklist, gt_item):
for item in ranklist:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this should be simplified as:

return 1 if gt_item in ranklist else 0


# Calculate NDCG score
def _get_ndcg(ranklist, gt_item):
for i in range(len(ranklist)):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This can be updated to:

return 0 if gt_item not in ranklist else math.log(2) / math.log(ranklist.index(gt_item))

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The expression should be math.log(2) / math.log(ranklist.index(gt_item) + 2)?



def main(_):
tf.logging.set_verbosity(tf.logging.INFO)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The main function is way too long. Please extract them into sub functions, like load_data, create_model, and run_training.



if __name__ == '__main__':
parser = argparse.ArgumentParser()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same, but probably in future PR, change to use absl flag and tf.logging.

Copy link
Contributor

@k-w-w k-w-w left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks great and easy to understand overall. The model written with Keras is very clean.


# serialize to csv file
df_train_ratings = pd.DataFrame(list(all_ratings))
df_train_ratings['fake_rating'] = 1
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why does a fake rating need to be added?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This fake rating is to indicate the interaction between the user and the item. Here the 'all_rating' only contains user_id and item_id fields, but the final csv file needs three fields of 'user_id, item_id, interaction'. I will add some comments here.


Args:
est_model: The Estimator.
user_input: The user input for evaluation.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Expected types/shapes would be good for these arguments and return items

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will add it.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, and it would be helpful on most of the other functions as well.



def _validate_batch_size_for_multi_gpu(batch_size):
"""For multi-gpu, batch-size must be a multiple of the number of GPUs.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Perhaps add an explanation about why the batch size needs to be a multiple (is the batch being split across each of the GPUs?)/

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes. The batch will be evenly split across all GPUs. Will add comments there.

if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument(
'--data_dir', type=str, default='/tmp/ml_data',
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agreed. ml_data could also stand for "machine learning data", which could confuse the user.

Copy link
Contributor

@karmel karmel left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry for the delay

@@ -0,0 +1,35 @@
# Recommendation Model
This is an implementation of the Neural Collaborative Filtering (NCF) framework with Neural Matrix Factorization (NeuMF) model as described in the [Neural Collaborative Filtering](https://arxiv.org/abs/1708.05031) paper. Current implementation is based on the code from the authors' [NCF code](https://github.com/hexiangnan/neural_collaborative_filtering) and the Standford pytorch implementation at [mlperf repo](https://github.com/mlperf/reference/tree/master/recommendation/pytorch).
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Or, better yet, just "the Stanford implementation." Also note the typo in Stanford.

@@ -0,0 +1,35 @@
# Recommendation Model
This is an implementation of the Neural Collaborative Filtering (NCF) framework with Neural Matrix Factorization (NeuMF) model as described in the [Neural Collaborative Filtering](https://arxiv.org/abs/1708.05031) paper. Current implementation is based on the code from the authors' [NCF code](https://github.com/hexiangnan/neural_collaborative_filtering) and the Standford pytorch implementation at [mlperf repo](https://github.com/mlperf/reference/tree/master/recommendation/pytorch).
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Another nit: "implementation in the MLPerf Repo"

# Recommendation Model
This is an implementation of the Neural Collaborative Filtering (NCF) framework with Neural Matrix Factorization (NeuMF) model as described in the [Neural Collaborative Filtering](https://arxiv.org/abs/1708.05031) paper. Current implementation is based on the code from the authors' [NCF code](https://github.com/hexiangnan/neural_collaborative_filtering) and the Standford pytorch implementation at [mlperf repo](https://github.com/mlperf/reference/tree/master/recommendation/pytorch).

NCF is a general framework under which a neural network architecture is proposed to model latent features of users and items in collaborative filtering of recommendation. Unlike traditional models, NCF does not resort to Matrix Factorization (MF) with an inner product on latent features of users and items. It replaces the inner product with a multi-layer perceptron that can learn an arbitrary function from data.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't really understand precisely what the first sentence means here. Maybe, "NCF is a type of machine learning model used for collaborative filtering and recommendation in which a neural network is used to determine latent features representing users and items."?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually the first sentence is quoted from the paper. I will rephrase it better.

# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Download and extract the movielens dataset from grouplens website.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

MovieLens, GroupLens

import tensorflow as tf # pylint: disable=g-bad-import-order

import numpy as np
import pandas as pd
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

See https://www.tensorflow.org/api_docs/python/tf/estimator/inputs/pandas_input_fn and the containing module for convenient converters

.format(time.time()-t1, num_users, num_items, train.nnz,
len(test_ratings)))

# Create NeuMF model from tf.keras Model
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is really a Keras model that is returned, right? Create a keras model from the NeuMF wrapper? (But, see comment below-- why?)

self.model_layers = model_layers
self.learning_rate = learning_rate

def __call__(self, multi_gpu, batch_size):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

call is a reserved method that takes specific inputs and gives specific outputs. (See in Resnet, for example.) If this is just assembling a keras model, let's name it something more straightforward.


class NeuMF(object):
"""Neural matrix factorization (NeuMF) model for recommendations."""

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't really understand why this is a class. It seems that this is just a helper that gets the keras model. Why make it a class? And why with the given method names? Somewhat confused here, since this is not ever passed to the estimator.

Copy link
Contributor Author

@yhliang2018 yhliang2018 May 9, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will update this class as a subclass of keras model.

import zipfile

import numpy as np
import pandas as pd
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure where @qlzh727 's comment went on pandas-- but, it occurs to me now that we are using pandas just to download the data and re-save as csv. I guess that's fine, especially since it's not on the training path, but, it does seem strange. Let's not worry now about fixing this, but maybe add a note that makes it explicit that the actual data loading for the model is with numpy; this is just a series of helpers for processing the data prior to training on it. Although if it's easy to do this without the additional pandas dependencies, that would be preferable.

Prepare input for model training and evaluation.
"""
import numpy as np
import scipy.sparse as sp
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is another dependency-- note that it will have to be added, along with pandas, to the list of dependencies, and also to any BUILD files.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

On second thought, starting to review this file, but I think we do this in TF instead of numpy/scipy. That would likely improve performance, and reduce dependencies. See https://www.tensorflow.org/get_started/datasets_quickstart#reading_a_csv_file , and then the rest should just be general matrix manipulation, which should be done in TF directly. Not going to review the rest of this file now, but will re-review when done as "TF-first".

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Okay, will do.

robieta
robieta previously requested changes May 7, 2018
Returns:
A processed pandas DataFrame of the rating dataset.
"""
names = ['user_id', 'item_id', 'rating', 'timestamp']
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Make this a global constant?

t2 = time.time()
test_ratings = []
test_negs = []
all_items = set(range(len(original_items)))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you add a comment about what exactly this is?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will do.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This basically tries to generate the 0-based index for all the items, and put it into a set for the later use. Comments have added in the code.

test_item = user_to_items[user].pop() # Get the latest one

all_ratings.remove((user, test_item)) # Remove the test item
all_negs = all_items - set(user_to_items[user])
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure how long this takes, but the .difference(x) set method is generally faster than - set(x).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Okay, will update it.



def main(_):
"""Download and extract the data from grouplens website."""
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you change this to use tempfile and tf.gfile so it works on distributed filesystems?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will do.

from six.moves import xrange # pylint: disable=redefined-builtin
import tensorflow as tf # pylint: disable=g-bad-import-order

import data_download
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should these be official.recommendation.module_name?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes. Will update it in next PR with absl flags

est_model: The Estimator.
user_input: The user input for evaluation.
item_input: The item input for evaluation.
gt_items: The test item for HR and NDCG calculation.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what does gt stand for?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess gt here is short for ground truth, as this one is the leave-one-out validation item. Will add comment on it.

num_users = len(gt_items)
step = len(user_input) // num_users # Step for each user
# Evaluation on each user
for idx in range(num_users):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

undefined abbreviations make this hard to follow. Does idx have some significance/stand for something? If so could you spell it out, and if not could you make it clear that this is just a loop counter?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

:) I think idx should be among the acceptable loop counters. idx is commonly used in numpy, elsewhere for index-- ie, a counter.

start = idx * step
end = (idx + 1) * step

items = item_input[start:end]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is is easy to refactor these into something like a list of lists? There is just extra mental overhead of reading "index into slice" code as opposed to something like predicted_scores=all_predicted_scores[user_id]. If it's too much trouble to change feel free to ignore this request.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good idea. Done!

def main(_):
tf.logging.set_verbosity(tf.logging.INFO)

print('NeuMF arguments: {}'.format(FLAGS))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please make the logic use a flags object rather than FLAGS following the pattern of the other main functions.

return [int(tmp[0]), int(tmp[1]), float(tmp[2]) > 0]

lines = open(train_fname, 'r').readlines()[1:]
data = list(map(process_line, lines))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I generally prefer list comprehension to map() because the behavior changes from 2 to 3.

NCF is a general framework under which a neural network architecture is proposed to model latent features of users and items in collaborative filtering of recommendation. Unlike traditional models, NCF does not resort to Matrix Factorization (MF) with an inner product on latent features of users and items. It replaces the inner product with a multi-layer perceptron that can learn an arbitrary function from data.

Two instantiations of NCF are Generalized Matrix Factorization (GMF) and Multi-Layer Perceptron (MLP). GMF applies a linear kernel to model the latent feature interactions, and and MLP uses a nonlinear kernel to learn the interaction function from data. NeuMF is a fused model of GMF and MLP to better model the complex user-item interactions, and unifies the strengths of linearity of MF and non-linearity of MLP for modeling the user-item latent structures. NeuMF allows GMF and MLP to learn separate embeddings, and combines the two models by concatenating their last hidden layer. [neumf_model.py](neumf_model.py) defines the architecture details.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would be helpful to get a list of abbreviations used (like the docstring at the top of neumf_model.py), including HR, GT, NDCG, ml-1m, etc.


Note the ml-20m dataset is large (the rating file is ~500 MB), and it may take several minutes (<10 mins)for data preprocessing.

### Train and Evaluate Model
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Including how long the model takes the train and expected final metrics would be great

dataset_name: The dataset name to be processed.
"""

# Use random seed as parameter
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I might be misunderstanding something, but the random seed is being set to 0, not a parameter.


Args:
est_model: The Estimator.
user_input: The user input for evaluation.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, and it would be helpful on most of the other functions as well.

inputs=[user_input, item_input], outputs=prediction)

# Use multiple gpus
if multi_gpu:
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will update this multi_gpu mode with distribution strategy, based on discussions with @robieta

@yhliang2018 yhliang2018 force-pushed the official/recommendation branch 2 times, most recently from a98dc2f to 89770cf Compare May 16, 2018 22:11
@yhliang2018
Copy link
Contributor Author

Hi All,

Thanks a lot for the helpful comments. I have updated the data input pipeline with tf.dataset, and also add the distribution strategy for multiple-gpu mode. Now the Adam optimizer doesn't work well with distribution strategy, so we use GradientDescent optimizer for now. Will update it as the bug is fixed.

Will add unit test and official flags in next PR.

Copy link
Contributor

@karmel karmel left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Almost there-- a few relatively minor comments. Thanks--

return ratings


def load_movielens_1million(file_name, sort=True):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: extra underscore (1_million, 20_million) would be clearer.

- Ratings are made on a 5-star scale (whole-star ratings only)
- Timestamp is represented in seconds since midnight Coordinated Universal
Time (UTC) of January 1, 1970.
- Each user has at least 20 ratings
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

:)



def load_data(file_name):
"""Load data from a csv file."""
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This says csv, but splits on \t.


def load_data(file_name):
"""Load data from a csv file."""
lines = open(file_name, "r").readlines()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

tf.gfile.GFile().

"""
# Load training positive instances into memory for later train data generation
train_data = load_data(train_fname)
num_users = max(train_data, key=lambda x: x[0])[0] + 1
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we return a np.array from load_dataset instead? That allows us to grab max more efficiently than doing this. Also, this seems unreliable; if I have non-sequential uids, this is not true, right? Better to do unique()?

cycle_index, total_training_cycle - 1))

# Train the model
train_cycle_begin = time.time()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can be done in the followup, but probably not necessary to collect/print out this very rough timing info.

help="Batch size.")
parser.add_argument(
"--num_factors", type=int, default=8,
help="Embedding size of MF model.")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

embedding_size maybe? "Factors" is used a lot in MF proper, but our users are more likely deep-learning-first rather than MF-first.

help="Size of hidden layers for MLP.")
parser.add_argument(
"--reg_mf", type=float, default=0,
help="Regularization for MF embeddings.")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

mf_regularization?

reg_mf: A floating number, the regularization for MF embeddings.

Raises:
ValueError: if the first model layer is not even.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

"not a multiple of two" would be clearer


# Input variables
user_input = tf.keras.layers.Input(
shape=(1,), dtype="int32", name=constants.USER)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's weird that the dtype is a string here. tf.int32 doesn't work?

@yhliang2018
Copy link
Contributor Author

Hi Karmel,

Thank you for the helpful comments! I have addressed most of them. Others like remove print statements, unit test, and official flags will be done in next PR.

As our offline discussion, for input_fn we still use csv files for tf.dataset which is required by distribution strategy. Will create a CL for dataset_input_fn (similar to numpy_input_fn), or seeking other possible approaches later.

@robieta robieta dismissed their stale review May 21, 2018 23:30

My issues have been addressed.

@karmel
Copy link
Contributor

karmel commented May 22, 2018

Thanks-- note that @qlzh727 is out, so you may have to assume he approves for now, and make any last requested changes in a follow-up.

@yhliang2018 yhliang2018 dismissed qlzh727’s stale review May 22, 2018 05:31

Will make the requested changes in next PR as Karmel suggested

@yhliang2018 yhliang2018 merged commit 81d7766 into master May 22, 2018
@yhliang2018 yhliang2018 deleted the official/recommendation branch May 22, 2018 05:31
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

Projects

None yet

Development

Successfully merging this pull request may close these issues.

6 participants