diff --git a/.gitignore b/.gitignore
index c07513e..a917dee 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,2 +1,7 @@
*.class
*.lst
+**/models/**
+**/data/**
+**/.idea/**
+*.tar.gz
+**/log.txt
\ No newline at end of file
diff --git a/PathContextReader.py b/PathContextReader.py
deleted file mode 100644
index 7b556a5..0000000
--- a/PathContextReader.py
+++ /dev/null
@@ -1,156 +0,0 @@
-import tensorflow as tf
-import common
-
-no_such_word = 'NOSUCH'
-no_such_composite = no_such_word + ',' + no_such_word + ',' + no_such_word
-
-
-class PathContextReader:
- class_word_table = None
- class_target_word_table = None
- class_path_table = None
-
- def __init__(self, word_to_index, target_word_to_index, path_to_index, config, is_evaluating=False):
- self.file_path = config.TEST_PATH if is_evaluating else (config.TRAIN_PATH + '.train.c2v')
- self.batch_size = config.TEST_BATCH_SIZE if is_evaluating else min(config.BATCH_SIZE, config.NUM_EXAMPLES)
- self.num_epochs = config.NUM_EPOCHS
- self.reading_batch_size = config.READING_BATCH_SIZE if is_evaluating else min(config.READING_BATCH_SIZE, config.NUM_EXAMPLES)
- self.num_batching_threads = config.NUM_BATCHING_THREADS
- self.batch_queue_size = config.BATCH_QUEUE_SIZE
- self.data_num_contexts = config.MAX_CONTEXTS
- self.max_contexts = config.MAX_CONTEXTS
- self.is_evaluating = is_evaluating
-
- self.word_table = PathContextReader.get_word_table(word_to_index)
- self.target_word_table = PathContextReader.get_target_word_table(target_word_to_index)
- self.path_table = PathContextReader.get_path_table(path_to_index)
- self.filtered_output = self.get_filtered_input()
-
- @classmethod
- def get_word_table(cls, word_to_index):
- if cls.class_word_table is None:
- cls.class_word_table = cls.initalize_hash_map(word_to_index, 0)
- return cls.class_word_table
-
- @classmethod
- def get_target_word_table(cls, target_word_to_index):
- if cls.class_target_word_table is None:
- cls.class_target_word_table = cls.initalize_hash_map(target_word_to_index, 0)
- return cls.class_target_word_table
-
- @classmethod
- def get_path_table(cls, path_to_index):
- if cls.class_path_table is None:
- cls.class_path_table = cls.initalize_hash_map(path_to_index, 0)
- return cls.class_path_table
-
- @classmethod
- def initalize_hash_map(cls, word_to_index, default_value):
- return tf.contrib.lookup.HashTable(
- tf.contrib.lookup.KeyValueTensorInitializer(list(word_to_index.keys()), list(word_to_index.values()),
- key_dtype=tf.string,
- value_dtype=tf.int32), default_value)
-
- def get_input_placeholder(self):
- return self.input_placeholder
-
- def start(self, session, data_lines=None):
- self.coord = tf.train.Coordinator()
- self.threads = tf.train.start_queue_runners(sess=session, coord=self.coord)
- return self
-
- def read_file(self):
- row = self.get_row_input()
- record_defaults = [[no_such_composite]] * (self.data_num_contexts + 1)
- row_parts = tf.decode_csv(row, record_defaults=record_defaults, field_delim=' ')
- word = row_parts[0] # (batch, )
- contexts = tf.stack(row_parts[1:(self.max_contexts + 1)], axis=1) # (batch, max_contexts)
-
- flat_contexts = tf.reshape(contexts, [-1]) # (batch * max_contexts, )
- split_contexts = tf.string_split(flat_contexts, delimiter=',')
- dense_split_contexts = tf.reshape(tf.sparse_tensor_to_dense(split_contexts,
- default_value=no_such_word),
- shape=[-1, self.max_contexts, 3]) # (batch, max_contexts, 3)
-
- if self.is_evaluating:
- target_word_label = word # (batch, ) of string
- else:
- target_word_label = self.target_word_table.lookup(word) # (batch, ) of int
-
- path_source_strings = tf.slice(dense_split_contexts, [0, 0, 0], [-1, self.max_contexts, 1])
- path_source_indices = self.word_table.lookup(path_source_strings) # (batch, max_contexts, 1)
- path_strings = tf.slice(dense_split_contexts, [0, 0, 1], [-1, self.max_contexts, 1])
- path_indices = self.path_table.lookup(path_strings) # (batch, max_contexts, 1)
- path_target_strings = tf.slice(dense_split_contexts, [0, 0, 2], [-1, self.max_contexts, 1])
- path_target_indices = self.word_table.lookup(path_target_strings) # (batch, max_contexts, 1)
-
- return target_word_label, path_source_indices, path_target_indices, path_indices, \
- path_source_strings, path_strings, path_target_strings
-
- def get_row_input(self):
- if self.is_evaluating: # test, read from queue (small data)
- row = self.input_placeholder = tf.placeholder(tf.string)
- else: # training, read from file
- filename_queue = tf.train.string_input_producer([self.file_path], num_epochs=self.num_epochs, shuffle=False)
- reader = tf.TextLineReader()
- _, row = reader.read_up_to(filename_queue, num_records=self.reading_batch_size)
- return row
-
- def input_tensors(self):
- return self.initialize_batch_outputs(self.filtered_output[:-3])
-
- def get_filtered_batches(self):
- return self.filtered_output
-
- def initialize_batch_outputs(self, filtered_input):
- return tf.train.shuffle_batch(filtered_input,
- batch_size=self.batch_size,
- enqueue_many=True,
- capacity=self.batch_queue_size,
- min_after_dequeue=int(self.batch_queue_size * 0.85),
- num_threads=self.num_batching_threads,
- allow_smaller_final_batch=True)
-
- def get_filtered_input(self):
- word_label, path_source_indices, path_target_indices, path_indices, \
- source_strings, path_strings, target_strings = self.read_file()
- any_contexts_is_valid = tf.logical_or(
- tf.greater(tf.squeeze(tf.reduce_max(path_source_indices, 1), axis=1), 0),
- tf.logical_or(
- tf.greater(tf.squeeze(tf.reduce_max(path_target_indices, 1), axis=1), 0),
- tf.greater(tf.squeeze(tf.reduce_max(path_indices, 1), axis=1), 0))
- ) # (batch, )
-
- if self.is_evaluating:
- cond = tf.where(any_contexts_is_valid)
- else: # training
- word_is_valid = tf.greater(word_label, 0) # (batch, )
- cond = tf.where(tf.logical_and(word_is_valid, any_contexts_is_valid)) # (batch, 1)
- valid_mask = tf.to_float( # (batch, max_contexts, 1)
- tf.logical_or(tf.logical_or(tf.greater(path_source_indices, 0),
- tf.greater(path_target_indices, 0)),
- tf.greater(path_indices, 0))
- )
-
- filtered = \
- tf.gather(word_label, cond), \
- tf.squeeze(tf.gather(path_source_indices, cond), [1, 3]), \
- tf.squeeze(tf.gather(path_indices, cond), [1, 3]), \
- tf.squeeze(tf.gather(path_target_indices, cond), [1, 3]), \
- tf.squeeze(tf.gather(valid_mask, cond), [1, 3]), \
- tf.squeeze(tf.gather(source_strings, cond)), \
- tf.squeeze(tf.gather(path_strings, cond)), \
- tf.squeeze(tf.gather(target_strings, cond)) # (batch, max_contexts)
-
- return filtered
-
- def __enter__(self):
- return self
-
- def should_stop(self):
- return self.coord.should_stop()
-
- def __exit__(self, type, value, traceback):
- print('Reader stopping')
- self.coord.request_stop()
- self.coord.join(self.threads)
diff --git a/README.md b/README.md
index 16e7ced..51dd179 100644
--- a/README.md
+++ b/README.md
@@ -5,10 +5,12 @@ This is an official implementation of the model described in:
[Uri Alon](http://urialon.cswp.cs.technion.ac.il), [Meital Zilberstein](http://www.cs.technion.ac.il/~mbs/), [Omer Levy](https://levyomer.wordpress.com) and [Eran Yahav](http://www.cs.technion.ac.il/~yahave/),
"code2vec: Learning Distributed Representations of Code", POPL'2019 [[PDF]](https://urialon.cswp.cs.technion.ac.il/wp-content/uploads/sites/83/2018/12/code2vec-popl19.pdf)
-_**October 2018** - the paper was accepted to [POPL'2019](https://popl19.sigplan.org)_!
+_**October 2018** - The paper was accepted to [POPL'2019](https://popl19.sigplan.org)_!
_**April 2019** - The talk video is available [here](https://www.youtube.com/watch?v=EJ8okcxL2Iw)_.
+_**July 2019** - Added a `tf.keras` model implementation (see [here](#choosing-implementation-to-use))._
+
An **online demo** is available at [https://code2vec.org/](https://code2vec.org/).
This is a TensorFlow implementation, designed to be easy and useful in research,
@@ -16,6 +18,7 @@ and for experimenting with new ideas in machine learning for code tasks.
By default, it learns Java source code and predicts Java method names, but it can be easily extended to other languages,
since the TensorFlow network is agnostic to the input programming language (see [Extending to other languages](#extending-to-other-languages)).
Contributions are welcome.
+This repo actually contains two model implementations: the first uses pure TensorFlow, and the second uses TensorFlow's Keras API.
@@ -33,13 +36,18 @@ Table of Contents
On Ubuntu:
* [Python3](https://www.linuxbabe.com/ubuntu/install-python-3-6-ubuntu-16-04-16-10-17-04). To check if you have it:
> python3 --version
- * TensorFlow - version 1.5 or newer ([install](https://www.tensorflow.org/install/install_linux)). To check TensorFlow version:
+ * TensorFlow - version 2.0.0-beta1 ([install](https://www.tensorflow.org/install/install_linux)).
+ To check TensorFlow version:
> python3 -c 'import tensorflow as tf; print(tf.\_\_version\_\_)'
- * If you are using a GPU, you will need CUDA 9.0 ([download](https://developer.nvidia.com/cuda-90-download-archive))
+ * If you are using a GPU, you will need CUDA 10.0
+ ([download](https://developer.nvidia.com/cuda-10.0-download-archive-base))
as this is the version that is currently supported by TensorFlow. To check CUDA version:
> nvcc --version
- * For GPU: cuDNN (>=7.0) ([download](http://developer.nvidia.com/cudnn))
- * For [creating a new dataset](#creating-and-preprocessing-a-new-java-dataset) or [manually examining a trained model](#step-4-manual-examination-of-a-trained-model) (any operation that requires parsing of a new code example) - [Java JDK](https://openjdk.java.net/install/)
+ * For GPU: cuDNN (>=7.5) ([download](http://developer.nvidia.com/cudnn)). To check cuDNN version:
+> cat /usr/include/cudnn.h | grep CUDNN_MAJOR -A 2
+ * For [creating a new dataset](#creating-and-preprocessing-a-new-java-dataset)
+ or [manually examining a trained model](#step-4-manual-examination-of-a-trained-model)
+ (any operation that requires parsing of a new code example) - [Java JDK](https://openjdk.java.net/install/)
## Quickstart
### Step 0: Cloning this repository
@@ -124,46 +132,74 @@ To manually examine a trained model, run:
```
python3 code2vec.py --load models/java14_model/saved_model_iter8 --predict
```
-After the model loads, follow the instructions and edit the file Input.java and enter a Java
+After the model loads, follow the instructions: edit the file [Input.java](Input.java), enter a Java
method or code snippet, and examine the model's predictions and attention scores.
## Configuration
-Changing hyper-parameters is possible by editing the file [common.py](common
-.py).
+Changing hyper-parameters is possible by editing the file
+[config.py](config.py).
Here are some of the parameters and their description:
-#### config.NUM_EPOCHS = 20
+#### config.NUM_TRAIN_EPOCHS = 20
The max number of epochs to train the model. Stopping earlier must be done manually (by killing the process).
#### config.SAVE_EVERY_EPOCHS = 1
After how many training iterations a model should be saved.
-#### config.BATCH_SIZE = 1024
+#### config.TRAIN_BATCH_SIZE = 1024
Batch size in training.
-#### config.TEST_BATCH_SIZE = config.BATCH_SIZE
+#### config.TEST_BATCH_SIZE = config.TRAIN_BATCH_SIZE
Batch size in evaluating. Affects only the evaluation speed and memory consumption, does not affect the results.
-#### config.READING_BATCH_SIZE = 1300 * 4
-The batch size of reading text lines to the queue that feeds examples to the network during training.
-#### config.NUM_BATCHING_THREADS = 2
-The number of threads enqueuing examples.
-#### config.BATCH_QUEUE_SIZE = 300000
-Max number of elements in the feeding queue.
-#### config.DATA_NUM_CONTEXTS = 200
-The number of contexts in a single example, as was created in preprocessing.
+#### config.TOP_K_WORDS_CONSIDERED_DURING_PREDICTION = 10
+Number of words with the highest scores in $\hat{y}$ to consider during prediction and evaluation.
+#### config.NUM_BATCHES_TO_LOG_PROGRESS = 100
+Number of batches (during training / evaluating) to complete between two progress-logging records.
+#### config.NUM_TRAIN_BATCHES_TO_EVALUATE = 1800
+Number of training batches to complete between model evaluations on the test set.
+#### config.READER_NUM_PARALLEL_BATCHES = 6
+The number of threads enqueuing examples to the reader queue.
+#### config.SHUFFLE_BUFFER_SIZE = 10000
+Size of the buffer the reader uses to shuffle examples during training.
+A bigger buffer allows better randomness, but requires more memory and may harm training throughput.
+#### config.CSV_BUFFER_SIZE = 100 * 1024 * 1024 # 100 MB
+The buffer size (in bytes) of the CSV dataset reader.
+
#### config.MAX_CONTEXTS = 200
The number of contexts to use in each example.
-#### config.WORDS_VOCAB_SIZE = 1301136
+#### config.MAX_TOKEN_VOCAB_SIZE = 1301136
The max size of the token vocabulary.
-#### config.TARGET_VOCAB_SIZE = 261245
+#### config.MAX_TARGET_VOCAB_SIZE = 261245
The max size of the target words vocabulary.
-#### config.PATHS_VOCAB_SIZE = 911417
+#### config.MAX_PATH_VOCAB_SIZE = 911417
The max size of the path vocabulary.
-#### config.EMBEDDINGS_SIZE = 128
-Embedding size for tokens and paths.
+#### config.DEFAULT_EMBEDDINGS_SIZE = 128
+Default embedding size to be used for token and path if not specified otherwise.
+#### config.TOKEN_EMBEDDINGS_SIZE = config.DEFAULT_EMBEDDINGS_SIZE
+Embedding size for tokens.
+#### config.PATH_EMBEDDINGS_SIZE = config.DEFAULT_EMBEDDINGS_SIZE
+Embedding size for paths.
+#### config.CODE_VECTOR_SIZE = config.PATH_EMBEDDINGS_SIZE + 2 * config.TOKEN_EMBEDDINGS_SIZE
+Size of code vectors.
+#### config.TARGET_EMBEDDINGS_SIZE = config.CODE_VECTOR_SIZE
+Embedding size for target words.
#### config.MAX_TO_KEEP = 10
Keep this number of newest trained versions during training.
+#### config.DROPOUT_KEEP_RATE = 0.75
+Dropout keep rate used during training (a keep rate of 0.75 means a dropout rate of 0.25).
+#### config.SEPARATE_OOV_AND_PAD = False
+Whether to treat `<OOV>` and `<PAD>` as two different special tokens whenever possible.
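+
+Hyper-parameters can also be overridden programmatically when using the code as a library,
+by setting attributes on a `Config` instance (a minimal sketch; the chosen value is arbitrary):
+```
+from config import Config
+
+# Build a config with the defaults listed above, without parsing command-line args.
+config = Config(set_defaults=True, load_from_args=False, verify=False)
+config.TRAIN_BATCH_SIZE = 512  # override the default of 1024
+config.TEST_BATCH_SIZE = config.TRAIN_BATCH_SIZE
+```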
## Features
Code2vec supports the following features:
+### Choosing implementation to use
+This repo comes with two model implementations:
+(i) one that uses pure TensorFlow (written in [tensorflow_model.py](tensorflow_model.py));
+(ii) one that uses TensorFlow's Keras (written in [keras_model.py](keras_model.py)).
+The default implementation used by `code2vec.py` is the pure TensorFlow one.
+To explicitly choose the desired implementation, specify `--framework tensorflow` or `--framework keras`
+as an additional argument when executing the script `code2vec.py`.
+In particular, this argument can be added to any of the usage examples (of `code2vec.py`) detailed in this file.
+Note that in order to load a trained model (from file), one should use the same implementation used during its training.
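+
+For example, to run the interactive prediction shell with the Keras implementation
+(a usage sketch; it assumes the loaded model was itself trained with `--framework keras`):
+```
+python3 code2vec.py --load models/java14_model/saved_model_iter8 --predict --framework keras
+```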
+
### Releasing the model
If you wish to keep a trained model for inference only (without the ability to continue training it) you can
release the model using:
diff --git a/code2vec.py b/code2vec.py
index 554d49e..135a24a 100644
--- a/code2vec.py
+++ b/code2vec.py
@@ -1,56 +1,38 @@
-from common import Config, VocabType
-from argparse import ArgumentParser
+from vocabularies import VocabType
+from config import Config
from interactive_predict import InteractivePredictor
-from model import Model
-import sys
+from model_base import Code2VecModelBase
-if __name__ == '__main__':
- parser = ArgumentParser()
- parser.add_argument("-d", "--data", dest="data_path",
- help="path to preprocessed dataset", required=False)
- parser.add_argument("-te", "--test", dest="test_path",
- help="path to test file", metavar="FILE", required=False)
- is_training = '--train' in sys.argv or '-tr' in sys.argv
- parser.add_argument("-s", "--save", dest="save_path",
- help="path to save file", metavar="FILE", required=False)
- parser.add_argument("-w2v", "--save_word2v", dest="save_w2v",
- help="path to save file", metavar="FILE", required=False)
- parser.add_argument("-t2v", "--save_target2v", dest="save_t2v",
- help="path to save file", metavar="FILE", required=False)
- parser.add_argument("-l", "--load", dest="load_path",
- help="path to save file", metavar="FILE", required=False)
- parser.add_argument('--save_w2v', dest='save_w2v', required=False,
- help="save word (token) vectors in word2vec format")
- parser.add_argument('--save_t2v', dest='save_t2v', required=False,
- help="save target vectors in word2vec format")
- parser.add_argument('--export_code_vectors', action='store_true', required=False,
- help="export code vectors for the given examples")
- parser.add_argument('--release', action='store_true',
- help='if specified and loading a trained model, release the loaded model for a lower model '
- 'size.')
- parser.add_argument('--predict', action='store_true')
- args = parser.parse_args()
+def load_model_dynamically(config: Config) -> Code2VecModelBase:
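+ # Import the implementation lazily, so that only the chosen framework (and its dependencies) is loaded.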
+ assert config.DL_FRAMEWORK in {'tensorflow', 'keras'}
+ if config.DL_FRAMEWORK == 'tensorflow':
+ from tensorflow_model import Code2VecModel
+ elif config.DL_FRAMEWORK == 'keras':
+ from keras_model import Code2VecModel
+ return Code2VecModel(config)
+
+
+if __name__ == '__main__':
+ config = Config(set_defaults=True, load_from_args=True, verify=True)
- config = Config.get_default_config(args)
+ model = load_model_dynamically(config)
+ config.log('Done creating code2vec model')
- model = Model(config)
- print('Created model')
- if config.TRAIN_PATH:
+ if config.is_training:
model.train()
- if args.save_w2v is not None:
- model.save_word2vec_format(args.save_w2v, source=VocabType.Token)
- print('Origin word vectors saved in word2vec text format in: %s' % args.save_w2v)
- if args.save_t2v is not None:
- model.save_word2vec_format(args.save_t2v, source=VocabType.Target)
- print('Target word vectors saved in word2vec text format in: %s' % args.save_t2v)
- if config.TEST_PATH and not args.data_path:
+ if config.SAVE_W2V is not None:
+ model.save_word2vec_format(config.SAVE_W2V, VocabType.Token)
+ config.log('Origin word vectors saved in word2vec text format in: %s' % config.SAVE_W2V)
+ if config.SAVE_T2V is not None:
+ model.save_word2vec_format(config.SAVE_T2V, VocabType.Target)
+ config.log('Target word vectors saved in word2vec text format in: %s' % config.SAVE_T2V)
+ if config.is_testing and not config.is_training:
eval_results = model.evaluate()
if eval_results is not None:
- results, precision, recall, f1 = eval_results
- print(results)
- print('Precision: ' + str(precision) + ', recall: ' + str(recall) + ', F1: ' + str(f1))
- if args.predict:
+ config.log(
+ str(eval_results).replace('topk', 'top{}'.format(config.TOP_K_WORDS_CONSIDERED_DURING_PREDICTION)))
+ if config.PREDICT:
predictor = InteractivePredictor(config, model)
predictor.predict()
model.close_session()
diff --git a/common.py b/common.py
index 1bb2708..4c99658 100644
--- a/common.py
+++ b/common.py
@@ -1,58 +1,13 @@
import re
-import json
-import sys
-from enum import Enum
-
-
-class Config:
- @staticmethod
- def get_default_config(args):
- config = Config()
- config.NUM_EPOCHS = 20
- config.SAVE_EVERY_EPOCHS = 1
- config.BATCH_SIZE = 1024
- config.TEST_BATCH_SIZE = config.BATCH_SIZE
- config.READING_BATCH_SIZE = 1300 * 4
- config.NUM_BATCHING_THREADS = 2
- config.BATCH_QUEUE_SIZE = 300000
- config.MAX_CONTEXTS = 200
- config.WORDS_VOCAB_SIZE = 1301136
- config.TARGET_VOCAB_SIZE = 261245
- config.PATHS_VOCAB_SIZE = 911417
- config.EMBEDDINGS_SIZE = 128
- config.MAX_TO_KEEP = 10
- # Automatically filled, do not edit:
- config.TRAIN_PATH = args.data_path
- config.TEST_PATH = args.test_path
- config.SAVE_PATH = args.save_path
- config.LOAD_PATH = args.load_path
- config.RELEASE = args.release
- config.EXPORT_CODE_VECTORS = args.export_code_vectors
- return config
-
- def __init__(self):
- self.NUM_EPOCHS = 0
- self.SAVE_EVERY_EPOCHS = 0
- self.BATCH_SIZE = 0
- self.TEST_BATCH_SIZE = 0
- self.READING_BATCH_SIZE = 0
- self.NUM_BATCHING_THREADS = 0
- self.BATCH_QUEUE_SIZE = 0
- self.TRAIN_PATH = ''
- self.TEST_PATH = ''
- self.MAX_CONTEXTS = 0
- self.WORDS_VOCAB_SIZE = 0
- self.TARGET_VOCAB_SIZE = 0
- self.PATHS_VOCAB_SIZE = 0
- self.EMBEDDINGS_SIZE = 0
- self.SAVE_PATH = ''
- self.LOAD_PATH = ''
- self.MAX_TO_KEEP = 0
- self.RELEASE = False
- self.EXPORT_CODE_VECTORS = False
+import numpy as np
+import tensorflow as tf
+from itertools import takewhile, repeat
+from typing import List, Optional, Tuple, Iterable
+from datetime import datetime
+from collections import OrderedDict
+
class common:
- noSuchWord = "NoSuchWord"
@staticmethod
def normalize_word(word):
@@ -88,22 +43,6 @@ def _load_vocab_from_histogram(path, min_count=0, start_from=0, return_counts=Fa
result = (*result, word_to_count)
return result
- @staticmethod
- def _load_vocab_from_dict(word_to_count, min_count=0, start_from=0):
- word_to_index = {}
- index_to_word = {}
- next_index = start_from
- for word, count in word_to_count.items():
- if count < min_count:
- continue
- if word in word_to_index:
- continue
- word_to_index[word] = next_index
- index_to_word[next_index] = word
- word_to_count[word] = count
- next_index += 1
- return word_to_index, index_to_word, next_index - start_from
-
@staticmethod
def load_vocab_from_histogram(path, min_count=0, start_from=0, max_size=None, return_counts=False):
if max_size is not None:
@@ -118,15 +57,6 @@ def load_vocab_from_histogram(path, min_count=0, start_from=0, max_size=None, re
min_count = sorted(word_to_count.values(), reverse=True)[max_size] + 1
return common._load_vocab_from_histogram(path, min_count, start_from, return_counts)
- @staticmethod
- def load_vocab_from_dict(word_to_count, max_size=None, start_from=0):
- if max_size is not None:
- if max_size > len(word_to_count):
- min_count = 0
- else:
- min_count = sorted(word_to_count.values(), reverse=True)[max_size] + 1
- return common._load_vocab_from_dict(word_to_count, min_count, start_from)
-
@staticmethod
def load_json(json_file):
data = []
@@ -150,12 +80,15 @@ def load_json_streaming(json_file):
yield (element, scope)
@staticmethod
- def save_word2vec_file(file, vocab_size, dimension, index_to_word, vectors):
- file.write('%d %d\n' % (vocab_size, dimension))
- for i in range(1, vocab_size + 1):
- if i in index_to_word:
- file.write(index_to_word[i] + ' ')
- file.write(' '.join(map(str, vectors[i])) + '\n')
+ def save_word2vec_file(output_file, index_to_word, vocab_embedding_matrix: np.ndarray):
+ assert len(vocab_embedding_matrix.shape) == 2
+ vocab_size, embedding_dimension = vocab_embedding_matrix.shape
+ output_file.write('%d %d\n' % (vocab_size, embedding_dimension))
+ for word_idx in range(0, vocab_size):
+ assert word_idx in index_to_word
+ word_str = index_to_word[word_idx]
+ output_file.write(word_str + ' ')
+ output_file.write(' '.join(map(str, vocab_embedding_matrix[word_idx])) + '\n')
@staticmethod
def calculate_max_contexts(file):
@@ -183,15 +116,16 @@ def load_file_lines(path):
@staticmethod
def split_to_batches(data_lines, batch_size):
- return [data_lines[x:x + batch_size] for x in range(0, len(data_lines), batch_size)]
+ for x in range(0, len(data_lines), batch_size):
+ yield data_lines[x:x + batch_size]
@staticmethod
- def legal_method_names_checker(name):
- return name != common.noSuchWord and re.match('^[a-zA-Z\|]+$', name)
+ def legal_method_names_checker(special_words, name):
+ return name != special_words.OOV and re.match(r'^[a-zA-Z|]+$', name)
@staticmethod
- def filter_impossible_names(top_words):
- result = list(filter(common.legal_method_names_checker, top_words))
+ def filter_impossible_names(special_words, top_words):
+ result = list(filter(lambda word: common.legal_method_names_checker(special_words, word), top_words))
return result
@staticmethod
@@ -199,19 +133,22 @@ def get_subtokens(str):
return str.split('|')
@staticmethod
- def parse_results(result, unhash_dict, topk=5):
+ def parse_prediction_results(raw_prediction_results, unhash_dict, special_words, topk: int = 5) -> List['MethodPredictionResults']:
prediction_results = []
- for single_method in result:
- original_name, top_suggestions, top_scores, attention_per_context = list(single_method)
- current_method_prediction_results = PredictionResults(original_name)
- for i, predicted in enumerate(top_suggestions):
- if predicted == common.noSuchWord:
+ for single_method_prediction in raw_prediction_results:
+ current_method_prediction_results = MethodPredictionResults(single_method_prediction.original_name)
+ for i, predicted in enumerate(single_method_prediction.topk_predicted_words):
+ if predicted == special_words.OOV:
continue
suggestion_subtokens = common.get_subtokens(predicted)
- current_method_prediction_results.append_prediction(suggestion_subtokens, top_scores[i].item())
- for context, attention in [(key, attention_per_context[key]) for key in
- sorted(attention_per_context, key=attention_per_context.get, reverse=True)][
- :topk]:
+ current_method_prediction_results.append_prediction(
+ suggestion_subtokens, single_method_prediction.topk_predicted_words_scores[i].item())
+ topk_attention_per_context = [
+ (key, single_method_prediction.attention_per_context[key])
+ for key in sorted(single_method_prediction.attention_per_context,
+ key=single_method_prediction.attention_per_context.get, reverse=True)
+ ][:topk]
+ for context, attention in topk_attention_per_context:
token1, hashed_path, token2 = context
if hashed_path in unhash_dict:
unhashed_path = unhash_dict[hashed_path]
@@ -220,8 +157,51 @@ def parse_results(result, unhash_dict, topk=5):
prediction_results.append(current_method_prediction_results)
return prediction_results
+ @staticmethod
+ def tf_get_first_true(bool_tensor: tf.Tensor) -> tf.Tensor:
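+ # The running count (cumsum) equals 1 from the first True up to (but excluding) the second True;
+ # AND-ing with the original tensor therefore keeps only the first True along the last axis.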
+ bool_tensor_as_int32 = tf.cast(bool_tensor, dtype=tf.int32)
+ cumsum = tf.cumsum(bool_tensor_as_int32, axis=-1, exclusive=False)
+ return tf.logical_and(tf.equal(cumsum, 1), bool_tensor)
+
+ @staticmethod
+ def count_lines_in_file(file_path: str):
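+ # Count newlines by scanning the file in 1 MB binary chunks, so large files are never fully loaded into memory.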
+ with open(file_path, 'rb') as f:
+ bufgen = takewhile(lambda x: x, (f.raw.read(1024 * 1024) for _ in repeat(None)))
+ return sum(buf.count(b'\n') for buf in bufgen)
+
+ @staticmethod
+ def squeeze_single_batch_dimension_for_np_arrays(arrays):
+ assert all(array is None or isinstance(array, np.ndarray) or isinstance(array, tf.Tensor) for array in arrays)
+ return tuple(
+ None if array is None else np.squeeze(array, axis=0)
+ for array in arrays
+ )
+
+ @staticmethod
+ def get_first_match_word_from_top_predictions(special_words, original_name, top_predicted_words) -> Optional[Tuple[int, str]]:
+ normalized_original_name = common.normalize_word(original_name)
+ for suggestion_idx, predicted_word in enumerate(common.filter_impossible_names(special_words, top_predicted_words)):
+ normalized_possible_suggestion = common.normalize_word(predicted_word)
+ if normalized_original_name == normalized_possible_suggestion:
+ return suggestion_idx, predicted_word
+ return None
+
+ @staticmethod
+ def now_str():
+ return datetime.now().strftime("%Y%m%d-%H%M%S: ")
+
+ @staticmethod
+ def chunks(l, n):
+ """Yield successive n-sized chunks from l."""
+ for i in range(0, len(l), n):
+ yield l[i:i + n]
-class PredictionResults:
+ @staticmethod
+ def get_unique_list(lst: Iterable) -> list:
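+ # An OrderedDict keyed by the items preserves first-occurrence order while dropping duplicates.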
+ return list(OrderedDict(((item, 0) for item in lst)).keys())
+
+
+class MethodPredictionResults:
def __init__(self, original_name):
self.original_name = original_name
self.predictions = list()
@@ -235,8 +215,3 @@ def append_attention_path(self, attention_score, token1, path, token2):
'path': path,
'token1': token1,
'token2': token2})
-
-
-class VocabType(Enum):
- Token = 1
- Target = 2
diff --git a/config.py b/config.py
new file mode 100644
index 0000000..204af01
--- /dev/null
+++ b/config.py
@@ -0,0 +1,269 @@
+from math import ceil
+from typing import Optional
+import logging
+from argparse import ArgumentParser
+import sys
+
+
+class Config:
+ @classmethod
+ def arguments_parser(cls) -> ArgumentParser:
+ parser = ArgumentParser()
+ parser.add_argument("-d", "--data", dest="data_path",
+ help="path to preprocessed dataset", required=False)
+ parser.add_argument("-te", "--test", dest="test_path",
+ help="path to test file", metavar="FILE", required=False)
+ parser.add_argument("-s", "--save", dest="save_path",
+ help="path to save the model file", metavar="FILE", required=False)
+ parser.add_argument("-w2v", "--save_word2v", dest="save_w2v",
+ help="path to save the tokens embeddings file", metavar="FILE", required=False)
+ parser.add_argument("-t2v", "--save_target2v", dest="save_t2v",
+ help="path to save the targets embeddings file", metavar="FILE", required=False)
+ parser.add_argument("-l", "--load", dest="load_path",
+ help="path to load the model from", metavar="FILE", required=False)
+ parser.add_argument('--save_w2v', dest='save_w2v', required=False,
+ help="save word (token) vectors in word2vec format")
+ parser.add_argument('--save_t2v', dest='save_t2v', required=False,
+ help="save target vectors in word2vec format")
+ parser.add_argument('--export_code_vectors', action='store_true', required=False,
+ help="export code vectors for the given examples")
+ parser.add_argument('--release', action='store_true',
+ help='if specified and loading a trained model, release the loaded model for a lower model '
+ 'size.')
+ parser.add_argument('--predict', action='store_true',
+ help='execute the interactive prediction shell')
+ parser.add_argument("-fw", "--framework", dest="dl_framework", choices=['keras', 'tensorflow'],
+ default='tensorflow', help="deep learning framework to use.")
+ parser.add_argument("-v", "--verbose", dest="verbose_mode", type=int, required=False, default=1,
+ help="verbose mode (should be in {0,1,2}).")
+ parser.add_argument("-lp", "--logs-path", dest="logs_path", metavar="FILE", required=False,
+ help="path to store logs into. if not given logs are not saved to file.")
+ parser.add_argument('-tb', '--tensorboard', dest='use_tensorboard', action='store_true',
+ help='use tensorboard during training')
+ return parser
+
+ def set_defaults(self):
+ self.NUM_TRAIN_EPOCHS = 20
+ self.SAVE_EVERY_EPOCHS = 1
+ self.TRAIN_BATCH_SIZE = 1024
+ self.TEST_BATCH_SIZE = self.TRAIN_BATCH_SIZE
+ self.TOP_K_WORDS_CONSIDERED_DURING_PREDICTION = 10
+ self.NUM_BATCHES_TO_LOG_PROGRESS = 100
+ self.NUM_TRAIN_BATCHES_TO_EVALUATE = 1800
+ self.READER_NUM_PARALLEL_BATCHES = 6 # cpu cores [for tf.contrib.data.map_and_batch() in the reader]
+ self.SHUFFLE_BUFFER_SIZE = 10000
+ self.CSV_BUFFER_SIZE = 100 * 1024 * 1024 # 100 MB
+ self.MAX_TO_KEEP = 10
+
+ # model hyper-params
+ self.MAX_CONTEXTS = 200
+ self.MAX_TOKEN_VOCAB_SIZE = 1301136
+ self.MAX_TARGET_VOCAB_SIZE = 261245
+ self.MAX_PATH_VOCAB_SIZE = 911417
+ self.DEFAULT_EMBEDDINGS_SIZE = 128
+ self.TOKEN_EMBEDDINGS_SIZE = self.DEFAULT_EMBEDDINGS_SIZE
+ self.PATH_EMBEDDINGS_SIZE = self.DEFAULT_EMBEDDINGS_SIZE
+ self.CODE_VECTOR_SIZE = self.context_vector_size
+ self.TARGET_EMBEDDINGS_SIZE = self.CODE_VECTOR_SIZE
+ self.DROPOUT_KEEP_RATE = 0.75
+ self.SEPARATE_OOV_AND_PAD = False
+
+ def load_from_args(self):
+ args = self.arguments_parser().parse_args()
+ # Automatically filled, do not edit:
+ self.PREDICT = args.predict
+ self.MODEL_SAVE_PATH = args.save_path
+ self.MODEL_LOAD_PATH = args.load_path
+ self.TRAIN_DATA_PATH_PREFIX = args.data_path
+ self.TEST_DATA_PATH = args.test_path
+ self.RELEASE = args.release
+ self.EXPORT_CODE_VECTORS = args.export_code_vectors
+ self.SAVE_W2V = args.save_w2v
+ self.SAVE_T2V = args.save_t2v
+ self.VERBOSE_MODE = args.verbose_mode
+ self.LOGS_PATH = args.logs_path
+ self.DL_FRAMEWORK = 'tensorflow' if not args.dl_framework else args.dl_framework
+ self.USE_TENSORBOARD = args.use_tensorboard
+
+ def __init__(self, set_defaults: bool = False, load_from_args: bool = False, verify: bool = False):
+ self.NUM_TRAIN_EPOCHS: int = 0
+ self.SAVE_EVERY_EPOCHS: int = 0
+ self.TRAIN_BATCH_SIZE: int = 0
+ self.TEST_BATCH_SIZE: int = 0
+ self.TOP_K_WORDS_CONSIDERED_DURING_PREDICTION: int = 0
+ self.NUM_BATCHES_TO_LOG_PROGRESS: int = 0
+ self.NUM_TRAIN_BATCHES_TO_EVALUATE: int = 0
+ self.READER_NUM_PARALLEL_BATCHES: int = 0
+ self.SHUFFLE_BUFFER_SIZE: int = 0
+ self.CSV_BUFFER_SIZE: int = 0
+ self.MAX_TO_KEEP: int = 0
+
+ # model hyper-params
+ self.MAX_CONTEXTS: int = 0
+ self.MAX_TOKEN_VOCAB_SIZE: int = 0
+ self.MAX_TARGET_VOCAB_SIZE: int = 0
+ self.MAX_PATH_VOCAB_SIZE: int = 0
+ self.DEFAULT_EMBEDDINGS_SIZE: int = 0
+ self.TOKEN_EMBEDDINGS_SIZE: int = 0
+ self.PATH_EMBEDDINGS_SIZE: int = 0
+ self.CODE_VECTOR_SIZE: int = 0
+ self.TARGET_EMBEDDINGS_SIZE: int = 0
+ self.DROPOUT_KEEP_RATE: float = 0
+ self.SEPARATE_OOV_AND_PAD: bool = False
+
+ # Automatically filled by `args`.
+ self.PREDICT: bool = False # TODO: update README;
+ self.MODEL_SAVE_PATH: Optional[str] = None
+ self.MODEL_LOAD_PATH: Optional[str] = None
+ self.TRAIN_DATA_PATH_PREFIX: Optional[str] = None
+ self.TEST_DATA_PATH: Optional[str] = None
+ self.RELEASE: bool = False
+ self.EXPORT_CODE_VECTORS: bool = False
+ self.SAVE_W2V: Optional[str] = None # TODO: update README;
+ self.SAVE_T2V: Optional[str] = None # TODO: update README;
+ self.VERBOSE_MODE: int = 0
+ self.LOGS_PATH: Optional[str] = None
+ self.DL_FRAMEWORK: str = '' # in {'keras', 'tensorflow'}
+ self.USE_TENSORBOARD: bool = False
+
+ # Automatically filled by `Code2VecModelBase._init_num_of_examples()`.
+ self.NUM_TRAIN_EXAMPLES: int = 0
+ self.NUM_TEST_EXAMPLES: int = 0
+
+ self.__logger: Optional[logging.Logger] = None
+
+ if set_defaults:
+ self.set_defaults()
+ if load_from_args:
+ self.load_from_args()
+ if verify:
+ self.verify()
+
+ @property
+ def context_vector_size(self) -> int:
+ # The context vector is actually a concatenation of the embedded
+ # source & target vectors and the embedded path vector.
+ return self.PATH_EMBEDDINGS_SIZE + 2 * self.TOKEN_EMBEDDINGS_SIZE
+
+ @property
+ def is_training(self) -> bool:
+ return bool(self.TRAIN_DATA_PATH_PREFIX)
+
+ @property
+ def is_loading(self) -> bool:
+ return bool(self.MODEL_LOAD_PATH)
+
+ @property
+ def is_saving(self) -> bool:
+ return bool(self.MODEL_SAVE_PATH)
+
+ @property
+ def is_testing(self) -> bool:
+ return bool(self.TEST_DATA_PATH)
+
+ @property
+ def train_steps_per_epoch(self) -> int:
+ return ceil(self.NUM_TRAIN_EXAMPLES / self.TRAIN_BATCH_SIZE) if self.TRAIN_BATCH_SIZE else 0
+
+ @property
+ def test_steps(self) -> int:
+ return ceil(self.NUM_TEST_EXAMPLES / self.TEST_BATCH_SIZE) if self.TEST_BATCH_SIZE else 0
+
+ def data_path(self, is_evaluating: bool = False):
+ return self.TEST_DATA_PATH if is_evaluating else self.train_data_path
+
+ def batch_size(self, is_evaluating: bool = False):
+ return self.TEST_BATCH_SIZE if is_evaluating else self.TRAIN_BATCH_SIZE # take min with NUM_TRAIN_EXAMPLES?
+
+ @property
+ def train_data_path(self) -> Optional[str]:
+ if not self.is_training:
+ return None
+ return '{}.train.c2v'.format(self.TRAIN_DATA_PATH_PREFIX)
+
+ @property
+ def word_freq_dict_path(self) -> Optional[str]:
+ if not self.is_training:
+ return None
+ return '{}.dict.c2v'.format(self.TRAIN_DATA_PATH_PREFIX)
+
+ @classmethod
+ def get_vocabularies_path_from_model_path(cls, model_file_path: str) -> str:
+ vocabularies_save_file_name = "dictionaries.bin"
+ return '/'.join(model_file_path.split('/')[:-1] + [vocabularies_save_file_name])
+
+ @classmethod
+ def get_entire_model_path(cls, model_path: str) -> str:
+ return model_path + '__entire-model'
+
+ @classmethod
+ def get_model_weights_path(cls, model_path: str) -> str:
+ return model_path + '__only-weights'
+
+ @property
+ def entire_model_load_path(self) -> Optional[str]:
+ if not self.is_loading:
+ return None
+ return self.get_entire_model_path(self.MODEL_LOAD_PATH)
+
+ @property
+ def model_weights_load_path(self) -> Optional[str]:
+ if not self.is_loading:
+ return None
+ return self.get_model_weights_path(self.MODEL_LOAD_PATH)
+
+ @property
+ def entire_model_save_path(self) -> Optional[str]:
+ if not self.is_saving:
+ return None
+ return self.get_entire_model_path(self.MODEL_SAVE_PATH)
+
+ @property
+ def model_weights_save_path(self) -> Optional[str]:
+ if not self.is_saving:
+ return None
+ return self.get_model_weights_path(self.MODEL_SAVE_PATH)
+
+ def verify(self):
+ if not self.is_training and not self.is_loading:
+ raise ValueError("Must train or load a model.")
+ if self.DL_FRAMEWORK not in {'tensorflow', 'keras'}:
+ raise ValueError("config.DL_FRAMEWORK must be in {'tensorflow', 'keras'}.")
+
+ def __iter__(self):
+ for attr_name in dir(self):
+ if attr_name.startswith("__"):
+ continue
+ try:
+ attr_value = getattr(self, attr_name, None)
+ except:
+ attr_value = None
+ if callable(attr_value):
+ continue
+ yield attr_name, attr_value
+
+ def get_logger(self) -> logging.Logger:
+ if self.__logger is None:
+ self.__logger = logging.getLogger('code2vec')
+ self.__logger.setLevel(logging.INFO)
+ self.__logger.handlers = []
+ self.__logger.propagate = 0
+
+ formatter = logging.Formatter('%(asctime)s %(levelname)-8s %(message)s')
+
+ if self.VERBOSE_MODE >= 1:
+ ch = logging.StreamHandler(sys.stdout)
+ ch.setLevel(logging.INFO)
+ ch.setFormatter(formatter)
+ self.__logger.addHandler(ch)
+
+ if self.LOGS_PATH:
+ fh = logging.FileHandler(self.LOGS_PATH)
+ fh.setLevel(logging.INFO)
+ fh.setFormatter(formatter)
+ self.__logger.addHandler(fh)
+
+ return self.__logger
+
+ def log(self, msg):
+ self.get_logger().info(msg)
diff --git a/interactive_predict.py b/interactive_predict.py
index f5986d0..78aac6c 100644
--- a/interactive_predict.py
+++ b/interactive_predict.py
@@ -40,9 +40,11 @@ def predict(self):
except ValueError as e:
print(e)
continue
- results, code_vectors = self.model.predict(predict_lines)
- prediction_results = common.parse_results(results, hash_to_string_dict, topk=SHOW_TOP_CONTEXTS)
- for i, method_prediction in enumerate(prediction_results):
+ raw_prediction_results = self.model.predict(predict_lines)
+ method_prediction_results = common.parse_prediction_results(
+ raw_prediction_results, hash_to_string_dict,
+ self.model.vocabs.target_vocab.special_words, topk=SHOW_TOP_CONTEXTS)
+ for raw_prediction, method_prediction in zip(raw_prediction_results, method_prediction_results):
print('Original name:\t' + method_prediction.original_name)
for name_prob_pair in method_prediction.predictions:
print('\t(%f) predicted: %s' % (name_prob_pair['probability'], name_prob_pair['name']))
@@ -52,4 +54,4 @@ def predict(self):
attention_obj['score'], attention_obj['token1'], attention_obj['path'], attention_obj['token2']))
if self.config.EXPORT_CODE_VECTORS:
print('Code vector:')
- print(' '.join(map(str, code_vectors[i])))
+ print(' '.join(map(str, raw_prediction.code_vector)))
diff --git a/keras_attention_layer.py b/keras_attention_layer.py
new file mode 100644
index 0000000..fb29bdc
--- /dev/null
+++ b/keras_attention_layer.py
@@ -0,0 +1,66 @@
+import tensorflow as tf
+from tensorflow.python import keras
+from tensorflow.python.keras.layers import Layer
+import tensorflow.python.keras.backend as K
+from typing import Optional
+
+
+class AttentionLayer(Layer):
+ def __init__(self, **kwargs):
+ super(AttentionLayer, self).__init__(**kwargs)
+
+ def build(self, inputs_shape):
+ inputs_shape = inputs_shape if isinstance(inputs_shape, list) else [inputs_shape]
+
+ if len(inputs_shape) < 1 or len(inputs_shape) > 2:
+ raise ValueError("AttentionLayer expect one or two inputs.")
+
+ # The first (and required) input is the actual input to the layer
+ input_shape = inputs_shape[0]
+
+ # Expected input shape consists of a triplet: (batch, input_length, input_dim)
+ if len(input_shape) != 3:
+ raise ValueError("Input shape for AttentionLayer should be of 3 dimension.")
+
+ self.input_length = int(input_shape[1])
+ self.input_dim = int(input_shape[2])
+ attention_param_shape = (self.input_dim, 1)
+
+ self.attention_param = self.add_weight(
+ name='attention_param',
+ shape=attention_param_shape,
+ initializer='uniform',
+ trainable=True,
+ dtype=tf.float32)
+ super(AttentionLayer, self).build(input_shape)
+
+ def call(self, inputs, **kwargs):
+ inputs = inputs if isinstance(inputs, list) else [inputs]
+
+ if len(inputs) < 1 or len(inputs) > 2:
+ raise ValueError("AttentionLayer expect one or two inputs.")
+
+ actual_input = inputs[0]
+ mask = inputs[1] if len(inputs) > 1 else None
+ if mask is not None and not (((len(mask.shape) == 3 and mask.shape[2] == 1) or len(mask.shape) == 2)
+ and mask.shape[1] == self.input_length):
+ raise ValueError("`mask` should be of shape (batch, input_length) or (batch, input_length, 1) "
+ "when calling an AttentionLayer.")
+
+ assert actual_input.shape[-1] == self.attention_param.shape[0]
+
+ # (batch, input_length, input_dim) * (input_dim, 1) ==> (batch, input_length, 1)
+ attention_weights = K.dot(actual_input, self.attention_param)
+
+ if mask is not None:
+ if len(mask.shape) == 2:
+ mask = K.expand_dims(mask, axis=2) # (batch, input_length, 1)
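+ # log(1) = 0 leaves valid contexts unchanged, while log(0) = -inf zeroes padded contexts out after the softmax.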
+ mask = K.log(mask)
+ attention_weights += mask
+
+ attention_weights = K.softmax(attention_weights, axis=1) # (batch, input_length, 1)
+ result = K.sum(actual_input * attention_weights, axis=1) # (batch, input_dim) [multiplication uses broadcast]
+ return result, attention_weights
+
+ def compute_output_shape(self, input_shape):
+ return input_shape[0], input_shape[2] # (batch, input_dim)
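+
+# A minimal usage sketch (hypothetical shapes; not part of the layer itself):
+#   contexts = tf.random.normal((32, 200, 384))  # (batch, input_length, input_dim)
+#   valid_mask = tf.ones((32, 200))              # 1.0 for real contexts, 0.0 for padding
+#   code_vectors, attention_weights = AttentionLayer()([contexts, valid_mask])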
diff --git a/keras_checkpoint_saver_callback.py b/keras_checkpoint_saver_callback.py
new file mode 100644
index 0000000..e23ebac
--- /dev/null
+++ b/keras_checkpoint_saver_callback.py
@@ -0,0 +1,127 @@
+import time
+import datetime
+import logging
+from typing import Optional, Dict
+from collections import defaultdict
+import tensorflow as tf
+from tensorflow.python import keras
+from tensorflow.python.keras.callbacks import Callback
+
+from config import Config
+
+
+class ModelTrainingStatus:
+ def __init__(self):
+ self.nr_epochs_trained: int = 0
+ self.trained_full_last_epoch: bool = False
+
+
+class ModelTrainingStatusTrackerCallback(Callback):
+ def __init__(self, training_status: ModelTrainingStatus):
+ self.training_status: ModelTrainingStatus = training_status
+ super(ModelTrainingStatusTrackerCallback, self).__init__()
+
+ def on_epoch_begin(self, epoch, logs=None):
+ self.training_status.trained_full_last_epoch = False
+
+ def on_epoch_end(self, epoch, logs=None):
+ assert self.training_status.nr_epochs_trained == epoch
+ self.training_status.nr_epochs_trained += 1
+ self.training_status.trained_full_last_epoch = True
+
+
+class ModelCheckpointSaverCallback(Callback):
+ """
+ @model_wrapper should have a `.save()` method.
+ """
+ def __init__(self, model_wrapper, nr_epochs_to_save: int = 1,
+ logger: logging.Logger = None):
+ self.model_wrapper = model_wrapper
+ self.nr_epochs_to_save: int = nr_epochs_to_save
+ self.logger = logger if logger is not None else logging.getLogger()
+
+ self.last_saved_epoch: Optional[int] = None
+ super(ModelCheckpointSaverCallback, self).__init__()
+
+ def on_epoch_begin(self, epoch, logs=None):
+ if self.last_saved_epoch is None:
+ self.last_saved_epoch = epoch # number of epochs fully trained before this one
+
+ def on_epoch_end(self, epoch, logs=None):
+ nr_epochs_trained = epoch + 1
+ nr_non_saved_epochs = nr_epochs_trained - self.last_saved_epoch
+ if nr_non_saved_epochs >= self.nr_epochs_to_save:
+ self.logger.info('Saving model after {} epochs.'.format(nr_epochs_trained))
+ self.model_wrapper.save()
+ self.logger.info('Done saving model.')
+ self.last_saved_epoch = nr_epochs_trained
+
+
+class MultiBatchCallback(Callback):
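+ """
+ Aggregates the logs over `multi_batch_size` consecutive batches, and calls `on_multi_batch_end()`
+ once per such group with the (optionally averaged) logs and the elapsed time of the whole group.
+ """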
+ def __init__(self, multi_batch_size: int, average_logs: bool = False):
+ self.multi_batch_size = multi_batch_size
+ self.average_logs = average_logs
+ self._multi_batch_start_time: int = 0
+ self._multi_batch_logs_sum: Dict[str, float] = defaultdict(float)
+ super(MultiBatchCallback, self).__init__()
+
+ def on_batch_begin(self, batch, logs=None):
+ if self.multi_batch_size == 1 or (batch + 1) % self.multi_batch_size == 1:
+ self._multi_batch_start_time = time.time()
+ if self.average_logs:
+ self._multi_batch_logs_sum = defaultdict(float)
+
+ def on_batch_end(self, batch, logs=None):
+ if self.average_logs:
+ assert isinstance(logs, dict)
+ for log_key, log_value in logs.items():
+ self._multi_batch_logs_sum[log_key] += log_value
+ if self.multi_batch_size == 1 or (batch + 1) % self.multi_batch_size == 0:
+ multi_batch_elapsed = time.time() - self._multi_batch_start_time
+ if self.average_logs:
+ multi_batch_logs = {log_key: log_value / self.multi_batch_size
+ for log_key, log_value in self._multi_batch_logs_sum.items()}
+ else:
+ multi_batch_logs = logs
+ self.on_multi_batch_end(batch, multi_batch_logs, multi_batch_elapsed)
+
+ def on_multi_batch_end(self, batch, logs, multi_batch_elapsed):
+ pass
+
+
+class ModelTrainingProgressLoggerCallback(MultiBatchCallback):
+ def __init__(self, config: Config, training_status: ModelTrainingStatus):
+ self.config = config
+ self.training_status = training_status
+ self.avg_throughput: Optional[float] = None
+ super(ModelTrainingProgressLoggerCallback, self).__init__(
+ self.config.NUM_BATCHES_TO_LOG_PROGRESS, average_logs=True)
+
+ def on_train_begin(self, logs=None):
+ self.config.log('Starting training...')
+
+ def on_epoch_end(self, epoch, logs=None):
+ self.config.log('Completed epoch #{}: {}'.format(epoch + 1, logs))
+
+ def on_multi_batch_end(self, batch, logs, multi_batch_elapsed):
+ nr_samples_in_multi_batch = self.config.TRAIN_BATCH_SIZE * \
+ self.config.NUM_BATCHES_TO_LOG_PROGRESS
+ throughput = nr_samples_in_multi_batch / multi_batch_elapsed
+ if self.avg_throughput is None:
+ self.avg_throughput = throughput
+ else:
+ self.avg_throughput = 0.5 * throughput + 0.5 * self.avg_throughput
+ remained_batches = self.config.train_steps_per_epoch - (batch + 1)
+ remained_samples = remained_batches * self.config.TRAIN_BATCH_SIZE
+ remained_time_sec = remained_samples / self.avg_throughput
+
+ self.config.log(
+ 'Train: during epoch #{epoch} batch {batch}/{tot_batches} ({batch_precision}%) -- '
+ 'throughput (#samples/sec): {throughput} -- epoch ETA: {epoch_ETA} -- loss: {loss:.4f}'.format(
+ epoch=self.training_status.nr_epochs_trained + 1,
+ batch=batch + 1,
+ batch_precision=int(((batch + 1) / self.config.train_steps_per_epoch) * 100),
+ tot_batches=self.config.train_steps_per_epoch,
+ throughput=int(throughput),
+ epoch_ETA=str(datetime.timedelta(seconds=int(remained_time_sec))),
+ loss=logs['loss']))
diff --git a/keras_model.py b/keras_model.py
new file mode 100644
index 0000000..dc42cbe
--- /dev/null
+++ b/keras_model.py
@@ -0,0 +1,415 @@
+import tensorflow as tf
+from tensorflow import keras
+from tensorflow.keras.layers import Input, Embedding, Concatenate, Dropout, TimeDistributed, Dense
+from tensorflow.keras.callbacks import Callback
+import tensorflow.keras.backend as K
+from tensorflow.keras.metrics import sparse_top_k_categorical_accuracy
+
+from path_context_reader import PathContextReader, ModelInputTensorsFormer, ReaderInputTensors, EstimatorAction
+import os
+import numpy as np
+from functools import partial
+from typing import List, Optional, Iterable, Union, Callable, Dict
+from collections import namedtuple
+import time
+import datetime
+from vocabularies import VocabType
+from keras_attention_layer import AttentionLayer
+from keras_topk_word_predictions_layer import TopKWordPredictionsLayer
+from keras_words_subtoken_metrics import WordsSubtokenPrecisionMetric, WordsSubtokenRecallMetric, WordsSubtokenF1Metric
+from config import Config
+from common import common
+from model_base import Code2VecModelBase, ModelEvaluationResults, ModelPredictionResults
+from keras_checkpoint_saver_callback import ModelTrainingStatus, ModelTrainingStatusTrackerCallback,\
+ ModelCheckpointSaverCallback, MultiBatchCallback, ModelTrainingProgressLoggerCallback
+
+
+class Code2VecModel(Code2VecModelBase):
+ def __init__(self, config: Config):
+ self.keras_train_model: Optional[keras.Model] = None
+ self.keras_eval_model: Optional[keras.Model] = None
+ self.keras_model_predict_function: Optional[K.GraphExecutionFunction] = None
+ self.training_status: ModelTrainingStatus = ModelTrainingStatus()
+ self._checkpoint: Optional[tf.train.Checkpoint] = None
+ self._checkpoint_manager: Optional[tf.train.CheckpointManager] = None
+ super(Code2VecModel, self).__init__(config)
+
+ def _create_keras_model(self):
+ # Each input sample consists of a bag of `MAX_CONTEXTS` context tuples (source_terminal, path, target_terminal).
+ # The valid mask indicates, for each context, whether it actually exists or is just padding.
+ path_source_token_input = Input((self.config.MAX_CONTEXTS,), dtype=tf.int32)
+ path_input = Input((self.config.MAX_CONTEXTS,), dtype=tf.int32)
+ path_target_token_input = Input((self.config.MAX_CONTEXTS,), dtype=tf.int32)
+ context_valid_mask = Input((self.config.MAX_CONTEXTS,))
+
+ # Input paths are indexes, we embed these here.
+ paths_embedded = Embedding(
+ self.vocabs.path_vocab.size, self.config.PATH_EMBEDDINGS_SIZE, name='path_embedding')(path_input)
+
+ # Input terminals are indexes, we embed these here.
+ token_embedding_shared_layer = Embedding(
+ self.vocabs.token_vocab.size, self.config.TOKEN_EMBEDDINGS_SIZE, name='token_embedding')
+ path_source_token_embedded = token_embedding_shared_layer(path_source_token_input)
+ path_target_token_embedded = token_embedding_shared_layer(path_target_token_input)
+
+ # `Context` is a concatenation of the 2 terminal embeddings & the path embedding.
+ # Each context is a vector of size PATH_EMBEDDINGS_SIZE + 2 * TOKEN_EMBEDDINGS_SIZE (3 * DEFAULT_EMBEDDINGS_SIZE by default).
+ context_embedded = Concatenate()([path_source_token_embedded, paths_embedded, path_target_token_embedded])
+ context_embedded = Dropout(1 - self.config.DROPOUT_KEEP_RATE)(context_embedded)
+
+ # Let's get dense: apply a dense layer to each context vector (using the same weights for all contexts).
+ context_after_dense = TimeDistributed(
+ Dense(self.config.CODE_VECTOR_SIZE, use_bias=False, activation='tanh'))(context_embedded)
+
+ # The final code vectors are obtained by applying attention to the "densed" context vectors.
+ code_vectors, attention_weights = AttentionLayer(name='attention')(
+ [context_after_dense, context_valid_mask])
+
+ # "Decode": Now we use another dense layer to get the target word embedding from each code vector.
+ target_index = Dense(
+ self.vocabs.target_vocab.size, use_bias=False, activation='softmax', name='target_index')(code_vectors)
+
+ # Wrap the layers into a Keras model, using our subtoken-metrics and the CE loss.
+ inputs = [path_source_token_input, path_input, path_target_token_input, context_valid_mask]
+ self.keras_train_model = keras.Model(inputs=inputs, outputs=target_index)
+
+ # Actual target word predictions (as strings). Used as a second output layer.
+ # Used for predict() and for the evaluation metrics calculations.
+ topk_predicted_words, topk_predicted_words_scores = TopKWordPredictionsLayer(
+ self.config.TOP_K_WORDS_CONSIDERED_DURING_PREDICTION,
+ self.vocabs.target_vocab.get_index_to_word_lookup_table(),
+ name='target_string')(target_index)
+
+ # We use another dedicated Keras model for evaluation.
+ # The evaluation model outputs the `topk_predicted_words` as a 2nd output.
+ # The separation between train and eval models is for efficiency.
+ self.keras_eval_model = keras.Model(
+ inputs=inputs, outputs=[target_index, topk_predicted_words], name="code2vec-keras-model")
+
+ # We use another dedicated Keras function to produce predictions.
+ # It has additional outputs beyond those of the original model.
+ # It is based on the trained layers of the original model and uses their weights.
+ predict_outputs = tuple(KerasPredictionModelOutput(
+ target_index=target_index, code_vectors=code_vectors, attention_weights=attention_weights,
+ topk_predicted_words=topk_predicted_words, topk_predicted_words_scores=topk_predicted_words_scores))
+ self.keras_model_predict_function = K.function(inputs=inputs, outputs=predict_outputs)
+
+ def _create_metrics_for_keras_eval_model(self) -> Dict[str, List[Union[Callable, keras.metrics.Metric]]]:
+ top_k_acc_metrics = []
+ for k in range(1, self.config.TOP_K_WORDS_CONSIDERED_DURING_PREDICTION + 1):
+ top_k_acc_metric = partial(
+ sparse_top_k_categorical_accuracy, k=k)
+ top_k_acc_metric.__name__ = 'top{k}_acc'.format(k=k)
+ top_k_acc_metrics.append(top_k_acc_metric)
+ predicted_words_filters = [
+ lambda word_strings: tf.not_equal(word_strings, self.vocabs.target_vocab.special_words.OOV),
+ lambda word_strings: tf.strings.regex_full_match(word_strings, r'^[a-zA-Z\|]+$')
+ ]
+ words_subtokens_metrics = [
+ WordsSubtokenPrecisionMetric(predicted_words_filters=predicted_words_filters, name='subtoken_precision'),
+ WordsSubtokenRecallMetric(predicted_words_filters=predicted_words_filters, name='subtoken_recall'),
+ WordsSubtokenF1Metric(predicted_words_filters=predicted_words_filters, name='subtoken_f1')
+ ]
+ return {'target_index': top_k_acc_metrics, 'target_string': words_subtokens_metrics}
+
+ @classmethod
+ def _create_optimizer(cls):
+ return tf.optimizers.Adam()
+
+ def _compile_keras_model(self, optimizer=None):
+ if optimizer is None:
+ optimizer = self.keras_train_model.optimizer
+ if optimizer is None:
+ optimizer = self._create_optimizer()
+
+ def zero_loss(true_word, topk_predictions):
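+ # The 'target_string' output exists only for computing metrics; a constant zero loss
+ # keeps it from contributing to the evaluation loss.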
+ return tf.constant(0.0, shape=(), dtype=tf.float32)
+
+ self.keras_train_model.compile(
+ loss='sparse_categorical_crossentropy',
+ optimizer=optimizer)
+
+ self.keras_eval_model.compile(
+ loss={'target_index': 'sparse_categorical_crossentropy', 'target_string': zero_loss},
+ optimizer=optimizer,
+ metrics=self._create_metrics_for_keras_eval_model())
+
+ def _create_data_reader(self, estimator_action: EstimatorAction, repeat_endlessly: bool = False):
+ return PathContextReader(
+ vocabs=self.vocabs,
+ config=self.config,
+ model_input_tensors_former=_KerasModelInputTensorsFormer(estimator_action=estimator_action),
+ estimator_action=estimator_action,
+ repeat_endlessly=repeat_endlessly)
+
+ def _create_train_callbacks(self) -> List[Callback]:
+ # TODO: do we want to use early stopping? If so, use the right checkpoint manager and set the correct
+ # `monitor` quantity (example: monitor='val_acc', mode='max')
+
+ keras_callbacks = [
+ ModelTrainingStatusTrackerCallback(self.training_status),
+ ModelTrainingProgressLoggerCallback(self.config, self.training_status),
+ ]
+ if self.config.is_saving:
+ keras_callbacks.append(ModelCheckpointSaverCallback(
+ self, self.config.SAVE_EVERY_EPOCHS, self.logger))
+ if self.config.is_testing:
+ keras_callbacks.append(ModelEvaluationCallback(self))
+ if self.config.USE_TENSORBOARD:
+ log_dir = "logs/scalars/train_" + common.now_str()
+ tensorboard_callback = keras.callbacks.TensorBoard(
+ log_dir=log_dir,
+ update_freq=self.config.NUM_BATCHES_TO_LOG_PROGRESS * self.config.TRAIN_BATCH_SIZE)
+ keras_callbacks.append(tensorboard_callback)
+ return keras_callbacks
+
+ def train(self):
+ # initialize the input pipeline reader
+ train_data_input_reader = self._create_data_reader(estimator_action=EstimatorAction.Train)
+
+ training_history = self.keras_train_model.fit(
+ train_data_input_reader.get_dataset(),
+ steps_per_epoch=self.config.train_steps_per_epoch,
+ epochs=self.config.NUM_TRAIN_EPOCHS,
+ initial_epoch=self.training_status.nr_epochs_trained,
+ verbose=self.config.VERBOSE_MODE,
+ callbacks=self._create_train_callbacks())
+
+ self.log(training_history)
+
+ def evaluate(self) -> Optional[ModelEvaluationResults]:
+ val_data_input_reader = self._create_data_reader(estimator_action=EstimatorAction.Evaluate)
+ eval_res = self.keras_eval_model.evaluate(
+ val_data_input_reader.get_dataset(),
+ steps=self.config.test_steps,
+ verbose=self.config.VERBOSE_MODE)
+ k = self.config.TOP_K_WORDS_CONSIDERED_DURING_PREDICTION
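+ # eval_res layout: [total_loss, target_index_loss, target_string_loss, top1..top-k acc, precision, recall, f1].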
+ return ModelEvaluationResults(
+ topk_acc=eval_res[3:k+3],
+ subtoken_precision=eval_res[k+3],
+ subtoken_recall=eval_res[k+4],
+ subtoken_f1=eval_res[k+5],
+ loss=eval_res[1]
+ )
+
+ def predict(self, predict_data_rows: Iterable[str]) -> List[ModelPredictionResults]:
+ predict_input_reader = self._create_data_reader(estimator_action=EstimatorAction.Predict)
+ input_iterator = predict_input_reader.process_and_iterate_input_from_data_lines(predict_data_rows)
+ all_model_prediction_results = []
+ for input_row in input_iterator:
+ # perform the actual prediction and get raw results.
+ input_for_predict = input_row[0][:4] # we want only the relevant input vectors (w.o. the targets).
+ prediction_results = self.keras_model_predict_function(input_for_predict)
+
+ # make `input_row` and `prediction_results` easy to read (by accessing named fields).
+ prediction_results = KerasPredictionModelOutput(
+ *common.squeeze_single_batch_dimension_for_np_arrays(prediction_results))
+ input_row = _KerasModelInputTensorsFormer(
+ estimator_action=EstimatorAction.Predict).from_model_input_form(input_row)
+ input_row = ReaderInputTensors(*common.squeeze_single_batch_dimension_for_np_arrays(input_row))
+
+ # calculate the attention weight for each context
+ attention_per_context = self._get_attention_weight_per_context(
+ path_source_strings=input_row.path_source_token_strings,
+ path_strings=input_row.path_strings,
+ path_target_strings=input_row.path_target_token_strings,
+ attention_weights=prediction_results.attention_weights
+ )
+
+            # store the calculated prediction results in the desired format.
+ model_prediction_results = ModelPredictionResults(
+ original_name=common.binary_to_string(input_row.target_string.item()),
+ topk_predicted_words=common.binary_to_string_list(prediction_results.topk_predicted_words),
+ topk_predicted_words_scores=prediction_results.topk_predicted_words_scores,
+ attention_per_context=attention_per_context,
+ code_vector=prediction_results.code_vectors)
+ all_model_prediction_results.append(model_prediction_results)
+
+ return all_model_prediction_results
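+    # Typical use (illustrative): `results = model.predict(raw_test_lines)`; each item
+    # exposes `original_name`, `topk_predicted_words`, `topk_predicted_words_scores`,
+    # `attention_per_context` and `code_vector` (see `ModelPredictionResults`).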
+
+ def _save_inner_model(self, path):
+ if self.config.RELEASE:
+ self.keras_train_model.save_weights(self.config.get_model_weights_path(path))
+ else:
+ self._get_checkpoint_manager().save(checkpoint_number=self.training_status.nr_epochs_trained)
+
+ def _create_inner_model(self):
+ self._create_keras_model()
+ self._compile_keras_model()
+ self.keras_train_model.summary(print_fn=self.log)
+
+ def _load_inner_model(self):
+ self._create_keras_model()
+ self._compile_keras_model()
+
+ # when loading the model for further training, we must use the full saved model file (not just weights).
+        # we load the entire model when we must, or when there is no model-weights file to load.
+ must_use_entire_model = self.config.is_training
+ entire_model_exists = os.path.exists(self.config.entire_model_load_path)
+ model_weights_exist = os.path.exists(self.config.model_weights_load_path)
+ use_full_model = must_use_entire_model or not model_weights_exist
+
+ if must_use_entire_model and not entire_model_exists:
+ raise ValueError(
+ "There is no model at path `{model_file_path}`. When loading the model for further training, "
+ "we must use an entire saved model file (not just weights).".format(
+ model_file_path=self.config.entire_model_load_path))
+ if not entire_model_exists and not model_weights_exist:
+ raise ValueError(
+ "There is no entire model to load at path `{entire_model_path}`, "
+ "and there is no model weights file to load at path `{model_weights_path}`.".format(
+ entire_model_path=self.config.entire_model_load_path,
+ model_weights_path=self.config.model_weights_load_path))
+
+ if use_full_model:
+ self.log('Loading entire model from path `{}`.'.format(self.config.entire_model_load_path))
+ latest_checkpoint = tf.train.latest_checkpoint(self.config.entire_model_load_path)
+ if latest_checkpoint is None:
+ raise ValueError("Failed to load model: Model latest checkpoint is not found.")
+ self.log('Loading latest checkpoint `{}`.'.format(latest_checkpoint))
+ status = self._get_checkpoint().restore(latest_checkpoint)
+ status.initialize_or_restore()
+ # FIXME: are we sure we have to re-compile here? I turned it off to save the optimizer state
+ # self._compile_keras_model() # We have to re-compile because we also recovered the `tf.train.AdamOptimizer`.
+ self.training_status.nr_epochs_trained = int(latest_checkpoint.split('-')[-1])
+ else:
+ # load the "released" model (only the weights).
+ self.log('Loading model weights from path `{}`.'.format(self.config.model_weights_load_path))
+ self.keras_train_model.load_weights(self.config.model_weights_load_path)
+
+ self.keras_train_model.summary(print_fn=self.log)
+
+ def _get_checkpoint(self):
+ assert self.keras_train_model is not None and self.keras_train_model.optimizer is not None
+ if self._checkpoint is None:
+ # TODO: we would like to save (& restore) the `nr_epochs_trained`.
+ self._checkpoint = tf.train.Checkpoint(
+ # nr_epochs_trained=tf.Variable(self.training_status.nr_epochs_trained, name='nr_epochs_trained'),
+ optimizer=self.keras_train_model.optimizer, model=self.keras_train_model)
+ return self._checkpoint
+
+ def _get_checkpoint_manager(self):
+ if self._checkpoint_manager is None:
+ self._checkpoint_manager = tf.train.CheckpointManager(
+ self._get_checkpoint(), self.config.entire_model_save_path,
+ max_to_keep=self.config.MAX_TO_KEEP)
+ return self._checkpoint_manager
+
+ def _get_vocab_embedding_as_np_array(self, vocab_type: VocabType) -> np.ndarray:
+ assert vocab_type in VocabType
+
+ vocab_type_to_embedding_layer_mapping = {
+ VocabType.Target: 'target_index',
+ VocabType.Token: 'token_embedding',
+ VocabType.Path: 'path_embedding'
+ }
+ embedding_layer_name = vocab_type_to_embedding_layer_mapping[vocab_type]
+ weight = np.array(self.keras_train_model.get_layer(embedding_layer_name).get_weights()[0])
+ assert len(weight.shape) == 2
+
+        # tokens and paths have actual `Embedding` layers, but the target has just a `Dense` layer;
+        # hence, transpose the weight when necessary.
+ assert self.vocabs.get(vocab_type).size in weight.shape
+ if self.vocabs.get(vocab_type).size != weight.shape[0]:
+ weight = np.transpose(weight)
+
+ return weight
+
+ def _create_lookup_tables(self):
+ PathContextReader.create_needed_vocabs_lookup_tables(self.vocabs)
+ self.log('Lookup tables created.')
+
+ def _initialize(self):
+ self._create_lookup_tables()
+
+
+class ModelEvaluationCallback(MultiBatchCallback):
+ """
+ This callback is passed to the `model.fit()` call.
+    It is responsible for triggering model evaluation during training.
+    The reasons we use a callback rather than passing validation data to `model.fit()` are:
+    (i) the training model differs from the evaluation model, for efficiency considerations;
+    (ii) we want to control the logging format;
+    (iii) we want the evaluation to occur every `NUM_TRAIN_BATCHES_TO_EVALUATE` batches (rather than only once per epoch).
+ """
+
+ def __init__(self, code2vec_model: 'Code2VecModel'):
+ self.code2vec_model = code2vec_model
+        self.avg_eval_duration: Optional[float] = None
+ super(ModelEvaluationCallback, self).__init__(self.code2vec_model.config.NUM_TRAIN_BATCHES_TO_EVALUATE)
+
+ def on_epoch_end(self, epoch, logs=None):
+ self.perform_evaluation()
+
+ def on_multi_batch_end(self, batch, logs, multi_batch_elapsed):
+ self.perform_evaluation()
+
+ def perform_evaluation(self):
+ if self.avg_eval_duration is None:
+ self.code2vec_model.log('Evaluating...')
+ else:
+ self.code2vec_model.log('Evaluating... (takes ~{})'.format(
+ str(datetime.timedelta(seconds=int(self.avg_eval_duration)))))
+ eval_start_time = time.time()
+ evaluation_results = self.code2vec_model.evaluate()
+ eval_duration = time.time() - eval_start_time
+ if self.avg_eval_duration is None:
+ self.avg_eval_duration = eval_duration
+ else:
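+            # running estimate: an equal-weight exponential moving average of the evaluation duration.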
+ self.avg_eval_duration = eval_duration * 0.5 + self.avg_eval_duration * 0.5
+ self.code2vec_model.log('Done evaluating (took {}). Evaluation results:'.format(
+ str(datetime.timedelta(seconds=int(eval_duration)))))
+
+ self.code2vec_model.log(
+ ' loss: {loss:.4f}, f1: {f1:.4f}, recall: {recall:.4f}, precision: {precision:.4f}'.format(
+ loss=evaluation_results.loss, f1=evaluation_results.subtoken_f1,
+ recall=evaluation_results.subtoken_recall, precision=evaluation_results.subtoken_precision))
+        top_k_acc_formatted = ['top{}: {:.4f}'.format(i, acc) for i, acc in enumerate(evaluation_results.topk_acc, start=1)]
+        for top_k_acc_chunk in common.chunks(top_k_acc_formatted, 5):
+ self.code2vec_model.log(' ' + (', '.join(top_k_acc_chunk)))
+
+
+class _KerasModelInputTensorsFormer(ModelInputTensorsFormer):
+ """
+    An instance of this class is passed to the reader to help it construct the input
+    in the form that the model expects to receive it.
+    This class also enables convenient and clear access to input parts by their field names,
+    e.g. `tensors.path_indices` instead of `tensors[1]`.
+ This allows the input tensors to be passed as pure tuples along the computation graph, while the
+ python functions that construct the graph can easily (and clearly) access tensors.
+ """
+
+ def __init__(self, estimator_action: EstimatorAction):
+ self.estimator_action = estimator_action
+
+ def to_model_input_form(self, input_tensors: ReaderInputTensors):
+ inputs = (input_tensors.path_source_token_indices, input_tensors.path_indices,
+ input_tensors.path_target_token_indices, input_tensors.context_valid_mask)
+ if self.estimator_action.is_train:
+ targets = input_tensors.target_index
+ else:
+ targets = {'target_index': input_tensors.target_index, 'target_string': input_tensors.target_string}
+ if self.estimator_action.is_predict:
+ inputs += (input_tensors.path_source_token_strings, input_tensors.path_strings,
+ input_tensors.path_target_token_strings)
+ return inputs, targets
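+
+    # Illustrative summary of the returned structure (names as defined above):
+    #   train:    inputs = (source_token_indices, path_indices, target_token_indices, valid_mask),
+    #             targets = target_index
+    #   evaluate: targets = {'target_index': ..., 'target_string': ...}
+    #   predict:  inputs additionally carry the three raw context-string tensors.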
+
+ def from_model_input_form(self, input_row) -> ReaderInputTensors:
+ inputs, targets = input_row
+ return ReaderInputTensors(
+ path_source_token_indices=inputs[0],
+ path_indices=inputs[1],
+ path_target_token_indices=inputs[2],
+ context_valid_mask=inputs[3],
+ target_index=targets if self.estimator_action.is_train else targets['target_index'],
+ target_string=targets['target_string'] if not self.estimator_action.is_train else None,
+ path_source_token_strings=inputs[4] if self.estimator_action.is_predict else None,
+ path_strings=inputs[5] if self.estimator_action.is_predict else None,
+ path_target_token_strings=inputs[6] if self.estimator_action.is_predict else None
+ )
+
+
+"""Used for convenient-and-clear access to raw prediction result parts (by their names)."""
+KerasPredictionModelOutput = namedtuple(
+ 'KerasModelOutput', ['target_index', 'code_vectors', 'attention_weights',
+ 'topk_predicted_words', 'topk_predicted_words_scores'])
diff --git a/keras_topk_word_predictions_layer.py b/keras_topk_word_predictions_layer.py
new file mode 100644
index 0000000..89ed9de
--- /dev/null
+++ b/keras_topk_word_predictions_layer.py
@@ -0,0 +1,39 @@
+import tensorflow as tf
+from tensorflow import keras
+from tensorflow.keras.layers import Layer
+import tensorflow.keras.backend as K
+from collections import namedtuple
+
+
+TopKWordPredictionsLayerResult = namedtuple('TopKWordPredictionsLayerResult', ['words', 'scores'])
+
+
+class TopKWordPredictionsLayer(Layer):
+ def __init__(self,
+ top_k: int,
+ index_to_word_table: tf.lookup.StaticHashTable,
+ **kwargs):
+ kwargs['dtype'] = tf.string
+ kwargs['trainable'] = False
+ super(TopKWordPredictionsLayer, self).__init__(**kwargs)
+ self.top_k = top_k
+ self.index_to_word_table = index_to_word_table
+
+ def build(self, input_shape):
+        if len(input_shape) < 2:
+            raise ValueError("Input shape for TopKWordPredictionsLayer should have >= 2 dimensions.")
+        if input_shape[-1] < self.top_k:
+            raise ValueError("The last dimension of the input shape for TopKWordPredictionsLayer should be >= `top_k`.")
+ super(TopKWordPredictionsLayer, self).build(input_shape)
+ self.trainable = False
+
+ def call(self, y_pred, **kwargs) -> TopKWordPredictionsLayerResult:
+ top_k_pred_scores, top_k_pred_indices = tf.nn.top_k(y_pred, k=self.top_k, sorted=True)
+ top_k_pred_indices = tf.cast(top_k_pred_indices, dtype=self.index_to_word_table.key_dtype)
+ top_k_pred_words = self.index_to_word_table.lookup(top_k_pred_indices)
+
+ return TopKWordPredictionsLayerResult(words=top_k_pred_words, scores=top_k_pred_scores)
+
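+    # Illustrative example: with top_k=2 and a prediction row [0.1, 0.7, 0.2],
+    # `call` looks up the words for indices [1, 2] and returns scores [0.7, 0.2].
+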
+ def compute_output_shape(self, input_shape):
+ output_shape = tuple(input_shape[:-1]) + (self.top_k, )
+ return output_shape, output_shape
diff --git a/keras_word_prediction_layer.py b/keras_word_prediction_layer.py
new file mode 100644
index 0000000..69d37ba
--- /dev/null
+++ b/keras_word_prediction_layer.py
@@ -0,0 +1,57 @@
+import tensorflow as tf
+from tensorflow.python import keras
+from tensorflow.python.keras.layers import Layer
+import tensorflow.python.keras.backend as K
+from typing import Optional, List, Callable
+from functools import reduce
+from common import common
+
+
+class WordPredictionLayer(Layer):
+ FilterType = Callable[[tf.Tensor, tf.Tensor], tf.Tensor]
+
+ def __init__(self,
+ top_k: int,
+                 index_to_word_table: tf.lookup.StaticHashTable,
+ predicted_words_filters: Optional[List[FilterType]] = None,
+ **kwargs):
+ kwargs['dtype'] = tf.string
+ kwargs['trainable'] = False
+ super(WordPredictionLayer, self).__init__(**kwargs)
+ self.top_k = top_k
+ self.index_to_word_table = index_to_word_table
+ self.predicted_words_filters = predicted_words_filters
+
+ def build(self, input_shape):
+        if len(input_shape) != 2:
+            raise ValueError("Input shape for WordPredictionLayer should have exactly 2 dimensions.")
+ super(WordPredictionLayer, self).build(input_shape)
+ self.trainable = False
+
+ def call(self, y_pred, **kwargs):
+ y_pred.shape.assert_has_rank(2)
+ top_k_pred_indices = tf.cast(tf.nn.top_k(y_pred, k=self.top_k).indices,
+ dtype=self.index_to_word_table.key_dtype)
+ predicted_target_words_strings = self.index_to_word_table.lookup(top_k_pred_indices)
+
+        # apply the given filters (if any)
+ masks = []
+ if self.predicted_words_filters is not None:
+ masks = [fltr(top_k_pred_indices, predicted_target_words_strings) for fltr in self.predicted_words_filters]
+ if masks:
+ # assert all(mask.shape.assert_is_compatible_with(top_k_pred_indices) for mask in masks)
+ legal_predicted_target_words_mask = reduce(tf.logical_and, masks)
+ else:
+ legal_predicted_target_words_mask = tf.cast(tf.ones_like(top_k_pred_indices), dtype=tf.bool)
+
+ # the first legal predicted word is our prediction
+ first_legal_predicted_target_word_mask = common.tf_get_first_true(legal_predicted_target_words_mask)
+ first_legal_predicted_target_word_idx = tf.where(first_legal_predicted_target_word_mask)
+ first_legal_predicted_word_string = tf.gather_nd(predicted_target_words_strings,
+ first_legal_predicted_target_word_idx)
+
+ prediction = tf.reshape(first_legal_predicted_word_string, [-1])
+ return prediction
+
+ def compute_output_shape(self, input_shape):
+ return input_shape[0], # (batch,)
diff --git a/keras_words_subtoken_metrics.py b/keras_words_subtoken_metrics.py
new file mode 100644
index 0000000..3f01a9c
--- /dev/null
+++ b/keras_words_subtoken_metrics.py
@@ -0,0 +1,128 @@
+import tensorflow as tf
+import tensorflow.keras.backend as K
+
+import abc
+from typing import Optional, Callable, List
+from functools import reduce
+
+from common import common
+
+
+class WordsSubtokenMetricBase(tf.metrics.Metric):
+ FilterType = Callable[[tf.Tensor, tf.Tensor], tf.Tensor]
+
+ def __init__(self,
+ index_to_word_table: Optional[tf.lookup.StaticHashTable] = None,
+ topk_predicted_words=None,
+ predicted_words_filters: Optional[List[FilterType]] = None,
+ subtokens_delimiter: str = '|', name=None, dtype=None):
+ super(WordsSubtokenMetricBase, self).__init__(name=name, dtype=dtype)
+ self.tp = self.add_weight('true_positives', shape=(), initializer=tf.zeros_initializer)
+ self.fp = self.add_weight('false_positives', shape=(), initializer=tf.zeros_initializer)
+ self.fn = self.add_weight('false_negatives', shape=(), initializer=tf.zeros_initializer)
+ self.index_to_word_table = index_to_word_table
+ self.topk_predicted_words = topk_predicted_words
+ self.predicted_words_filters = predicted_words_filters
+ self.subtokens_delimiter = subtokens_delimiter
+
+ def _get_true_target_word_string(self, true_target_word):
+ if self.index_to_word_table is None:
+ return true_target_word
+ true_target_word_index = tf.cast(true_target_word, dtype=self.index_to_word_table.key_dtype)
+ return self.index_to_word_table.lookup(true_target_word_index)
+
+ def update_state(self, true_target_word, predictions, sample_weight=None):
+ """Accumulates true positive, false positive and false negative statistics."""
+ if sample_weight is not None:
+            raise NotImplementedError("WordsSubtokenMetricBase with non-None `sample_weight` is not implemented.")
+
+        # For each example in the batch we have:
+        # (i) one ground-truth target word;
+        # (ii) one predicted word (the first legal word among the top-k predictions).
+
+ topk_predicted_words = predictions if self.topk_predicted_words is None else self.topk_predicted_words
+ assert topk_predicted_words is not None
+ predicted_word = self._get_prediction_from_topk(topk_predicted_words)
+
+ true_target_word_string = self._get_true_target_word_string(true_target_word)
+ true_target_word_string = tf.reshape(true_target_word_string, [-1])
+
+ # We split each word into subtokens
+ true_target_subwords = tf.compat.v1.string_split(true_target_word_string, sep=self.subtokens_delimiter)
+ prediction_subwords = tf.compat.v1.string_split(predicted_word, sep=self.subtokens_delimiter)
+ true_target_subwords = tf.sparse.to_dense(true_target_subwords, default_value='')
+ prediction_subwords = tf.sparse.to_dense(prediction_subwords, default_value='')
+ true_target_subwords_mask = tf.not_equal(true_target_subwords, '')
+ prediction_subwords_mask = tf.not_equal(prediction_subwords, '')
+        # Now the shapes of true_target_subwords & prediction_subwords are (batch, subtokens)
+
+        # We use broadcasting to test subtoken membership in both directions while preserving duplicates.
+ true_target_subwords = tf.expand_dims(true_target_subwords, -1)
+ prediction_subwords = tf.expand_dims(prediction_subwords, -1)
+        # Now the shapes of true_target_subwords & prediction_subwords are (batch, subtokens, 1)
+ true_target_subwords__in__prediction_subwords = \
+ tf.reduce_any(tf.equal(true_target_subwords, tf.transpose(prediction_subwords, perm=[0, 2, 1])), axis=2)
+ prediction_subwords__in__true_target_subwords = \
+ tf.reduce_any(tf.equal(prediction_subwords, tf.transpose(true_target_subwords, perm=[0, 2, 1])), axis=2)
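+        # Worked example (illustrative): true = 'get|name', prediction = 'set|name':
+        # 'name' is a true positive, 'set' a false positive, 'get' a false negative.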
+
+        # Count predicted subwords that exist in the ground-truth target word.
+        batch_true_positive = tf.reduce_sum(tf.cast(
+            tf.logical_and(prediction_subwords__in__true_target_subwords, prediction_subwords_mask), tf.float32))
+        # Count predicted subwords that don't exist in the ground-truth target word.
+        batch_false_positive = tf.reduce_sum(tf.cast(
+            tf.logical_and(~prediction_subwords__in__true_target_subwords, prediction_subwords_mask), tf.float32))
+        # Count ground-truth subwords that don't exist in the predicted word.
+        batch_false_negative = tf.reduce_sum(tf.cast(
+            tf.logical_and(~true_target_subwords__in__prediction_subwords, true_target_subwords_mask), tf.float32))
+
+ self.tp.assign_add(batch_true_positive)
+ self.fp.assign_add(batch_false_positive)
+ self.fn.assign_add(batch_false_negative)
+
+ def _get_prediction_from_topk(self, topk_predicted_words):
+        # apply the given filters (if any)
+ masks = []
+ if self.predicted_words_filters is not None:
+ masks = [fltr(topk_predicted_words) for fltr in self.predicted_words_filters]
+ if masks:
+ # assert all(mask.shape.assert_is_compatible_with(top_k_pred_indices) for mask in masks)
+ legal_predicted_target_words_mask = reduce(tf.logical_and, masks)
+ else:
+ legal_predicted_target_words_mask = tf.cast(tf.ones_like(topk_predicted_words), dtype=tf.bool)
+
+ # the first legal predicted word is our prediction
+ first_legal_predicted_target_word_mask = common.tf_get_first_true(legal_predicted_target_words_mask)
+ first_legal_predicted_target_word_idx = tf.where(first_legal_predicted_target_word_mask)
+ first_legal_predicted_word_string = tf.gather_nd(topk_predicted_words,
+ first_legal_predicted_target_word_idx)
+
+ prediction = tf.reshape(first_legal_predicted_word_string, [-1])
+ return prediction
+
+ @abc.abstractmethod
+ def result(self):
+ ...
+
+ def reset_states(self):
+ for v in self.variables:
+ K.set_value(v, 0)
+
+
+class WordsSubtokenPrecisionMetric(WordsSubtokenMetricBase):
+ def result(self):
+ precision = tf.math.divide_no_nan(self.tp, self.tp + self.fp)
+ return precision
+
+
+class WordsSubtokenRecallMetric(WordsSubtokenMetricBase):
+ def result(self):
+ recall = tf.math.divide_no_nan(self.tp, self.tp + self.fn)
+ return recall
+
+
+class WordsSubtokenF1Metric(WordsSubtokenMetricBase):
+ def result(self):
+ recall = tf.math.divide_no_nan(self.tp, self.tp + self.fn)
+ precision = tf.math.divide_no_nan(self.tp, self.tp + self.fp)
+ f1 = tf.math.divide_no_nan(2 * precision * recall, precision + recall + K.epsilon())
+ return f1
diff --git a/model.py b/model.py
deleted file mode 100644
index 8c9a95f..0000000
--- a/model.py
+++ /dev/null
@@ -1,474 +0,0 @@
-import tensorflow as tf
-
-import PathContextReader
-import numpy as np
-import time
-import pickle
-from common import common, VocabType
-
-
-class Model:
- topk = 10
- num_batches_to_log = 100
-
- def __init__(self, config):
- self.config = config
- self.sess = tf.Session()
-
- self.eval_data_lines = None
- self.eval_queue = None
- self.predict_queue = None
-
- self.eval_placeholder = None
- self.predict_placeholder = None
- self.eval_top_words_op, self.eval_top_values_op, self.eval_original_names_op, self.eval_code_vectors = None, None, None, None
- self.predict_top_words_op, self.predict_top_values_op, self.predict_original_names_op = None, None, None
-
- if config.TRAIN_PATH:
- with open('{}.dict.c2v'.format(config.TRAIN_PATH), 'rb') as file:
- word_to_count = pickle.load(file)
- path_to_count = pickle.load(file)
- target_to_count = pickle.load(file)
- num_training_examples = pickle.load(file)
- self.config.NUM_EXAMPLES = num_training_examples
- print('Dictionaries loaded.')
-
- if config.LOAD_PATH:
- self.load_model(sess=None)
- else:
- self.word_to_index, self.index_to_word, self.word_vocab_size = \
- common.load_vocab_from_dict(word_to_count, config.WORDS_VOCAB_SIZE, start_from=1)
- print('Loaded word vocab. size: %d' % self.word_vocab_size)
-
- self.target_word_to_index, self.index_to_target_word, self.target_word_vocab_size = \
- common.load_vocab_from_dict(target_to_count, config.TARGET_VOCAB_SIZE,
- start_from=1)
- print('Loaded target word vocab. size: %d' % self.target_word_vocab_size)
-
- self.path_to_index, self.index_to_path, self.path_vocab_size = \
- common.load_vocab_from_dict(path_to_count, config.PATHS_VOCAB_SIZE,
- start_from=1)
- print('Loaded paths vocab. size: %d' % self.path_vocab_size)
-
- self.create_index_to_target_word_map()
-
- def create_index_to_target_word_map(self):
- self.index_to_target_word_table = tf.contrib.lookup.HashTable(
- tf.contrib.lookup.KeyValueTensorInitializer(list(self.index_to_target_word.keys()),
- list(self.index_to_target_word.values()),
- key_dtype=tf.int64, value_dtype=tf.string),
- default_value=tf.constant(common.noSuchWord, dtype=tf.string))
-
- def close_session(self):
- self.sess.close()
-
- def train(self):
- print('Starting training')
- start_time = time.time()
-
- batch_num = 0
- sum_loss = 0
- multi_batch_start_time = time.time()
- num_batches_to_evaluate = max(int(
- self.config.NUM_EXAMPLES / self.config.BATCH_SIZE * self.config.SAVE_EVERY_EPOCHS), 1)
-
- self.queue_thread = PathContextReader.PathContextReader(word_to_index=self.word_to_index,
- path_to_index=self.path_to_index,
- target_word_to_index=self.target_word_to_index,
- config=self.config)
- optimizer, train_loss = self.build_training_graph(self.queue_thread.input_tensors())
- self.saver = tf.train.Saver(max_to_keep=self.config.MAX_TO_KEEP)
-
- self.initialize_session_variables(self.sess)
- print('Initalized variables')
- if self.config.LOAD_PATH:
- self.load_model(self.sess)
- with self.queue_thread.start(self.sess):
- time.sleep(1)
- print('Started reader...')
- try:
- while True:
- batch_num += 1
- _, batch_loss = self.sess.run([optimizer, train_loss])
- sum_loss += batch_loss
- if batch_num % self.num_batches_to_log == 0:
- self.trace(sum_loss, batch_num, multi_batch_start_time)
- print('Number of waiting examples in queue: %d' % self.sess.run(
- "shuffle_batch/random_shuffle_queue_Size:0"))
- sum_loss = 0
- multi_batch_start_time = time.time()
- if batch_num % num_batches_to_evaluate == 0:
- epoch_num = int((batch_num / num_batches_to_evaluate) * self.config.SAVE_EVERY_EPOCHS)
- save_target = self.config.SAVE_PATH + '_iter' + str(epoch_num)
- self.save_model(self.sess, save_target)
- print('Saved after %d epochs in: %s' % (epoch_num, save_target))
- results, precision, recall, f1 = self.evaluate()
- print('Accuracy after %d epochs: %s' % (epoch_num, results[:5]))
- print('After ' + str(epoch_num) + ' epochs: Precision: ' + str(precision) + ', recall: ' + str(
- recall) + ', F1: ' + str(f1))
- except tf.errors.OutOfRangeError:
- print('Done training')
-
- if self.config.SAVE_PATH:
- self.save_model(self.sess, self.config.SAVE_PATH)
- print('Model saved in file: %s' % self.config.SAVE_PATH)
-
- elapsed = int(time.time() - start_time)
- print("Training time: %sH:%sM:%sS\n" % ((elapsed // 60 // 60), (elapsed // 60) % 60, elapsed % 60))
-
- def trace(self, sum_loss, batch_num, multi_batch_start_time):
- multi_batch_elapsed = time.time() - multi_batch_start_time
- avg_loss = sum_loss / (self.num_batches_to_log * self.config.BATCH_SIZE)
- print('Average loss at batch %d: %f, \tthroughput: %d samples/sec' % (batch_num, avg_loss,
- self.config.BATCH_SIZE * self.num_batches_to_log / (
- multi_batch_elapsed if multi_batch_elapsed > 0 else 1)))
-
- def evaluate(self):
- eval_start_time = time.time()
- if self.eval_queue is None:
- self.eval_queue = PathContextReader.PathContextReader(word_to_index=self.word_to_index,
- path_to_index=self.path_to_index,
- target_word_to_index=self.target_word_to_index,
- config=self.config, is_evaluating=True)
- self.eval_placeholder = self.eval_queue.get_input_placeholder()
- self.eval_top_words_op, self.eval_top_values_op, self.eval_original_names_op, _, _, _, _, self.eval_code_vectors = \
- self.build_test_graph(self.eval_queue.get_filtered_batches())
- self.saver = tf.train.Saver()
-
- if self.config.LOAD_PATH and not self.config.TRAIN_PATH:
- self.initialize_session_variables(self.sess)
- self.load_model(self.sess)
- if self.config.RELEASE:
- release_name = self.config.LOAD_PATH + '.release'
- print('Releasing model, output model: %s' % release_name )
- self.saver.save(self.sess, release_name )
- return None
-
- if self.eval_data_lines is None:
- print('Loading test data from: ' + self.config.TEST_PATH)
- self.eval_data_lines = common.load_file_lines(self.config.TEST_PATH)
- print('Done loading test data')
-
- with open('log.txt', 'w') as output_file:
- if self.config.EXPORT_CODE_VECTORS:
- code_vectors_file = open(self.config.TEST_PATH + '.vectors', 'w')
- num_correct_predictions = np.zeros(self.topk)
- total_predictions = 0
- total_prediction_batches = 0
- true_positive, false_positive, false_negative = 0, 0, 0
- start_time = time.time()
-
- for batch in common.split_to_batches(self.eval_data_lines, self.config.TEST_BATCH_SIZE):
- top_words, top_scores, original_names, code_vectors = self.sess.run(
- [self.eval_top_words_op, self.eval_top_values_op, self.eval_original_names_op, self.eval_code_vectors],
- feed_dict={self.eval_placeholder: batch})
- top_words, original_names = common.binary_to_string_matrix(top_words), common.binary_to_string_matrix(
- original_names)
- # Flatten original names from [[]] to []
- original_names = [w for l in original_names for w in l]
-
- num_correct_predictions = self.update_correct_predictions(num_correct_predictions, output_file,
- zip(original_names, top_words))
- true_positive, false_positive, false_negative = self.update_per_subtoken_statistics(
- zip(original_names, top_words),
- true_positive, false_positive, false_negative)
-
- total_predictions += len(original_names)
- total_prediction_batches += 1
- if self.config.EXPORT_CODE_VECTORS:
- self.write_code_vectors(code_vectors_file, code_vectors)
- if total_prediction_batches % self.num_batches_to_log == 0:
- elapsed = time.time() - start_time
- # start_time = time.time()
- self.trace_evaluation(output_file, num_correct_predictions, total_predictions, elapsed, len(self.eval_data_lines))
-
- print('Done testing, epoch reached')
- output_file.write(str(num_correct_predictions / total_predictions) + '\n')
- if self.config.EXPORT_CODE_VECTORS:
- code_vectors_file.close()
-
- elapsed = int(time.time() - eval_start_time)
- precision, recall, f1 = self.calculate_results(true_positive, false_positive, false_negative)
- print("Evaluation time: %sH:%sM:%sS" % ((elapsed // 60 // 60), (elapsed // 60) % 60, elapsed % 60))
- del self.eval_data_lines
- self.eval_data_lines = None
- return num_correct_predictions / total_predictions, precision, recall, f1
-
- def write_code_vectors(self, file, code_vectors):
- for vec in code_vectors:
- file.write(' '.join(map(str, vec)) + '\n')
-
- def update_per_subtoken_statistics(self, results, true_positive, false_positive, false_negative):
- for original_name, top_words in results:
- prediction = common.filter_impossible_names(top_words)[0]
- original_subtokens = common.get_subtokens(original_name)
- predicted_subtokens = common.get_subtokens(prediction)
- for subtok in predicted_subtokens:
- if subtok in original_subtokens:
- true_positive += 1
- else:
- false_positive += 1
- for subtok in original_subtokens:
- if not subtok in predicted_subtokens:
- false_negative += 1
- return true_positive, false_positive, false_negative
-
- @staticmethod
- def calculate_results(true_positive, false_positive, false_negative):
- precision = true_positive / (true_positive + false_positive)
- recall = true_positive / (true_positive + false_negative)
- f1 = 2 * precision * recall / (precision + recall)
- return precision, recall, f1
-
- @staticmethod
- def trace_evaluation(output_file, correct_predictions, total_predictions, elapsed, total_examples):
- state_message = 'Evaluated %d/%d examples...' % (total_predictions, total_examples)
- throughput_message = "Prediction throughput: %d samples/sec" % int(total_predictions / (elapsed if elapsed > 0 else 1))
- print(state_message)
- print(throughput_message)
-
- def update_correct_predictions(self, num_correct_predictions, output_file, results):
- for original_name, top_words in results:
- normalized_original_name = common.normalize_word(original_name)
- predicted_something = False
- for i, predicted_word in enumerate(common.filter_impossible_names(top_words)):
- if i == 0:
- output_file.write('Original: ' + original_name + ', predicted 1st: ' + predicted_word + '\n')
- predicted_something = True
- normalized_suggestion = common.normalize_word(predicted_word)
- if normalized_original_name == normalized_suggestion:
- output_file.write('\t\t predicted correctly at rank: ' + str(i + 1) + '\n')
- for j in range(i, self.topk):
- num_correct_predictions[j] += 1
- break
- if not predicted_something:
- output_file.write('No results for predicting: ' + original_name)
- return num_correct_predictions
-
- def build_training_graph(self, input_tensors):
- words_input, source_input, path_input, target_input, valid_mask = input_tensors # (batch, 1), (batch, max_contexts)
-
- with tf.variable_scope('model'):
- words_vocab = tf.get_variable('WORDS_VOCAB', shape=(self.word_vocab_size + 1, self.config.EMBEDDINGS_SIZE),
- dtype=tf.float32,
- initializer=tf.contrib.layers.variance_scaling_initializer(factor=1.0,
- mode='FAN_OUT',
- uniform=True))
- target_words_vocab = tf.get_variable('TARGET_WORDS_VOCAB',
- shape=(
- self.target_word_vocab_size + 1, self.config.EMBEDDINGS_SIZE * 3),
- dtype=tf.float32,
- initializer=tf.contrib.layers.variance_scaling_initializer(factor=1.0,
- mode='FAN_OUT',
- uniform=True))
- attention_param = tf.get_variable('ATTENTION',
- shape=(self.config.EMBEDDINGS_SIZE * 3, 1), dtype=tf.float32)
- paths_vocab = tf.get_variable('PATHS_VOCAB', shape=(self.path_vocab_size + 1, self.config.EMBEDDINGS_SIZE),
- dtype=tf.float32,
- initializer=tf.contrib.layers.variance_scaling_initializer(factor=1.0,
- mode='FAN_OUT',
- uniform=True))
-
- code_vectors, _ = self.calculate_weighted_contexts(words_vocab, paths_vocab, attention_param,
- source_input, path_input, target_input,
- valid_mask)
-
- logits = tf.matmul(code_vectors, target_words_vocab, transpose_b=True)
- batch_size = tf.to_float(tf.shape(words_input)[0])
- loss = tf.reduce_sum(tf.nn.sparse_softmax_cross_entropy_with_logits(
- labels=tf.reshape(words_input, [-1]),
- logits=logits)) / batch_size
-
- optimizer = tf.train.AdamOptimizer().minimize(loss)
-
- return optimizer, loss
-
- def calculate_weighted_contexts(self, words_vocab, paths_vocab, attention_param, source_input, path_input,
- target_input, valid_mask, is_evaluating=False):
- keep_prob1 = 0.75
- max_contexts = self.config.MAX_CONTEXTS
-
- source_word_embed = tf.nn.embedding_lookup(params=words_vocab, ids=source_input) # (batch, max_contexts, dim)
- path_embed = tf.nn.embedding_lookup(params=paths_vocab, ids=path_input) # (batch, max_contexts, dim)
- target_word_embed = tf.nn.embedding_lookup(params=words_vocab, ids=target_input) # (batch, max_contexts, dim)
-
- context_embed = tf.concat([source_word_embed, path_embed, target_word_embed],
- axis=-1) # (batch, max_contexts, dim * 3)
- if not is_evaluating:
- context_embed = tf.nn.dropout(context_embed, keep_prob1)
-
- flat_embed = tf.reshape(context_embed, [-1, self.config.EMBEDDINGS_SIZE * 3]) # (batch * max_contexts, dim * 3)
- transform_param = tf.get_variable('TRANSFORM',
- shape=(self.config.EMBEDDINGS_SIZE * 3, self.config.EMBEDDINGS_SIZE * 3),
- dtype=tf.float32)
-
- flat_embed = tf.tanh(tf.matmul(flat_embed, transform_param)) # (batch * max_contexts, dim * 3)
-
- contexts_weights = tf.matmul(flat_embed, attention_param) # (batch * max_contexts, 1)
- batched_contexts_weights = tf.reshape(contexts_weights,
- [-1, max_contexts, 1]) # (batch, max_contexts, 1)
- mask = tf.log(valid_mask) # (batch, max_contexts)
- mask = tf.expand_dims(mask, axis=2) # (batch, max_contexts, 1)
- batched_contexts_weights += mask # (batch, max_contexts, 1)
- attention_weights = tf.nn.softmax(batched_contexts_weights, axis=1) # (batch, max_contexts, 1)
-
- batched_embed = tf.reshape(flat_embed, shape=[-1, max_contexts, self.config.EMBEDDINGS_SIZE * 3])
- code_vectors = tf.reduce_sum(tf.multiply(batched_embed, attention_weights),
- axis=1) # (batch, dim * 3)
-
- return code_vectors, attention_weights
-
- def build_test_graph(self, input_tensors, normalize_scores=False):
- with tf.variable_scope('model', reuse=self.get_should_reuse_variables()):
- words_vocab = tf.get_variable('WORDS_VOCAB', shape=(self.word_vocab_size + 1, self.config.EMBEDDINGS_SIZE),
- dtype=tf.float32, trainable=False)
- target_words_vocab = tf.get_variable('TARGET_WORDS_VOCAB',
- shape=(
- self.target_word_vocab_size + 1, self.config.EMBEDDINGS_SIZE * 3),
- dtype=tf.float32, trainable=False)
- attention_param = tf.get_variable('ATTENTION',
- shape=(self.config.EMBEDDINGS_SIZE * 3, 1),
- dtype=tf.float32, trainable=False)
- paths_vocab = tf.get_variable('PATHS_VOCAB',
- shape=(self.path_vocab_size + 1, self.config.EMBEDDINGS_SIZE),
- dtype=tf.float32, trainable=False)
-
- target_words_vocab = tf.transpose(target_words_vocab) # (dim * 3, target_word_vocab+1)
-
- words_input, source_input, path_input, target_input, valid_mask, source_string, path_string, path_target_string = input_tensors # (batch, 1), (batch, max_contexts)
-
- code_vectors, attention_weights = self.calculate_weighted_contexts(words_vocab, paths_vocab,
- attention_param,
- source_input, path_input,
- target_input,
- valid_mask, True)
-
- scores = tf.matmul(code_vectors, target_words_vocab) # (batch, target_word_vocab+1)
-
- topk_candidates = tf.nn.top_k(scores, k=tf.minimum(self.topk, self.target_word_vocab_size))
- top_indices = tf.to_int64(topk_candidates.indices)
- top_words = self.index_to_target_word_table.lookup(top_indices)
- original_words = words_input
- top_scores = topk_candidates.values
- if normalize_scores:
- top_scores = tf.nn.softmax(top_scores)
-
- return top_words, top_scores, original_words, attention_weights, source_string, path_string, path_target_string, code_vectors
-
- def predict(self, predict_data_lines):
- if self.predict_queue is None:
- self.predict_queue = PathContextReader.PathContextReader(word_to_index=self.word_to_index,
- path_to_index=self.path_to_index,
- target_word_to_index=self.target_word_to_index,
- config=self.config, is_evaluating=True)
- self.predict_placeholder = self.predict_queue.get_input_placeholder()
- self.predict_top_words_op, self.predict_top_values_op, self.predict_original_names_op, \
- self.attention_weights_op, self.predict_source_string, self.predict_path_string, self.predict_path_target_string, self.predict_code_vectors = \
- self.build_test_graph(self.predict_queue.get_filtered_batches(), normalize_scores=True)
-
- self.initialize_session_variables(self.sess)
- self.saver = tf.train.Saver()
- self.load_model(self.sess)
-
- code_vectors = []
- results = []
- for batch in common.split_to_batches(predict_data_lines, 1):
- top_words, top_scores, original_names, attention_weights, source_strings, path_strings, target_strings, batch_code_vectors = self.sess.run(
- [self.predict_top_words_op, self.predict_top_values_op, self.predict_original_names_op,
- self.attention_weights_op, self.predict_source_string, self.predict_path_string,
- self.predict_path_target_string, self.predict_code_vectors],
- feed_dict={self.predict_placeholder: batch})
- top_words, original_names = common.binary_to_string_matrix(top_words), common.binary_to_string_matrix(
- original_names)
- # Flatten original names from [[]] to []
- attention_per_path = self.get_attention_per_path(source_strings, path_strings, target_strings,
- attention_weights)
- original_names = [w for l in original_names for w in l]
- results.append((original_names[0], top_words[0], top_scores[0], attention_per_path))
- if self.config.EXPORT_CODE_VECTORS:
- code_vectors.append(batch_code_vectors)
- if len(code_vectors) > 0:
- code_vectors = np.vstack(code_vectors)
- return results, code_vectors
-
- def get_attention_per_path(self, source_strings, path_strings, target_strings, attention_weights):
- attention_weights = np.squeeze(attention_weights) # (max_contexts, )
- attention_per_context = {}
- for source, path, target, weight in zip(source_strings, path_strings, target_strings, attention_weights):
- string_triplet = (
- common.binary_to_string(source), common.binary_to_string(path), common.binary_to_string(target))
- attention_per_context[string_triplet] = weight
- return attention_per_context
-
- @staticmethod
- def get_dictionaries_path(model_file_path):
- dictionaries_save_file_name = "dictionaries.bin"
- return '/'.join(model_file_path.split('/')[:-1] + [dictionaries_save_file_name])
-
- def save_model(self, sess, path):
- self.saver.save(sess, path)
- with open(self.get_dictionaries_path(path), 'wb') as file:
- pickle.dump(self.word_to_index, file)
- pickle.dump(self.index_to_word, file)
- pickle.dump(self.word_vocab_size, file)
-
- pickle.dump(self.target_word_to_index, file)
- pickle.dump(self.index_to_target_word, file)
- pickle.dump(self.target_word_vocab_size, file)
-
- pickle.dump(self.path_to_index, file)
- pickle.dump(self.index_to_path, file)
- pickle.dump(self.path_vocab_size, file)
-
- def load_model(self, sess):
- if not sess is None:
- print('Loading model weights from: ' + self.config.LOAD_PATH)
- self.saver.restore(sess, self.config.LOAD_PATH)
- print('Done')
- dictionaries_path = self.get_dictionaries_path(self.config.LOAD_PATH)
- with open(dictionaries_path , 'rb') as file:
- print('Loading model dictionaries from: %s' % dictionaries_path)
- self.word_to_index = pickle.load(file)
- self.index_to_word = pickle.load(file)
- self.word_vocab_size = pickle.load(file)
-
- self.target_word_to_index = pickle.load(file)
- self.index_to_target_word = pickle.load(file)
- self.target_word_vocab_size = pickle.load(file)
-
- self.path_to_index = pickle.load(file)
- self.index_to_path = pickle.load(file)
- self.path_vocab_size = pickle.load(file)
- print('Done')
-
- def save_word2vec_format(self, dest, source):
- with tf.variable_scope('model', reuse=None):
- if source is VocabType.Token:
- vocab_size = self.word_vocab_size
- embedding_size = self.config.EMBEDDINGS_SIZE
- index = self.index_to_word
- var_name = 'WORDS_VOCAB'
- elif source is VocabType.Target:
- vocab_size = self.target_word_vocab_size
- embedding_size = self.config.EMBEDDINGS_SIZE * 3
- index = self.index_to_target_word
- var_name = 'TARGET_WORDS_VOCAB'
- else:
- raise ValueError('vocab type should be VocabType.Token or VocabType.Target.')
- embeddings = tf.get_variable(var_name, shape=(vocab_size + 1, embedding_size), dtype=tf.float32,
- trainable=False)
- self.saver = tf.train.Saver()
- self.load_model(self.sess)
- np_embeddings = self.sess.run(embeddings)
- with open(dest, 'w') as words_file:
- common.save_word2vec_file(words_file, vocab_size, embedding_size, index, np_embeddings)
-
- @staticmethod
- def initialize_session_variables(sess):
- sess.run(tf.group(tf.global_variables_initializer(), tf.local_variables_initializer(), tf.tables_initializer()))
-
- def get_should_reuse_variables(self):
- if self.config.TRAIN_PATH:
- return True
- else:
- return None
diff --git a/model_base.py b/model_base.py
new file mode 100644
index 0000000..763f1a9
--- /dev/null
+++ b/model_base.py
@@ -0,0 +1,181 @@
+import numpy as np
+import abc
+import os
+from typing import NamedTuple, Optional, List, Dict, Tuple, Iterable
+
+from common import common
+from vocabularies import Code2VecVocabs, VocabType
+from config import Config
+
+
+class ModelEvaluationResults(NamedTuple):
+ topk_acc: float
+ subtoken_precision: float
+ subtoken_recall: float
+ subtoken_f1: float
+ loss: Optional[float] = None
+
+ def __str__(self):
+ res_str = 'topk_acc: {topk_acc}, precision: {precision}, recall: {recall}, F1: {f1}'.format(
+ topk_acc=self.topk_acc,
+ precision=self.subtoken_precision,
+ recall=self.subtoken_recall,
+ f1=self.subtoken_f1)
+ if self.loss is not None:
+ res_str = ('loss: {}, '.format(self.loss)) + res_str
+ return res_str
+
+
+class ModelPredictionResults(NamedTuple):
+ original_name: str
+ topk_predicted_words: np.ndarray
+ topk_predicted_words_scores: np.ndarray
+ attention_per_context: Dict[Tuple[str, str, str], float]
+ code_vector: Optional[np.ndarray] = None
+
+
+class Code2VecModelBase(abc.ABC):
+ def __init__(self, config: Config):
+ self.config = config
+ self.config.verify()
+
+ self._log_creating_model()
+
+ self._init_num_of_examples()
+ self._log_model_configuration()
+ self.vocabs = Code2VecVocabs(config)
+ self.vocabs.target_vocab.get_index_to_word_lookup_table() # just to initialize it (if not already initialized)
+ self._load_or_create_inner_model()
+ self._initialize()
+
+ def _log_creating_model(self):
+ self.log('')
+ self.log('')
+ self.log('---------------------------------------------------------------------')
+ self.log('---------------------------------------------------------------------')
+        self.log('---------------------- Creating code2vec model ----------------------')
+ self.log('---------------------------------------------------------------------')
+ self.log('---------------------------------------------------------------------')
+
+ def _log_model_configuration(self):
+ self.log('---------------------------------------------------------------------')
+ self.log('----------------- Configuration - Hyper Parameters ------------------')
+ longest_param_name_len = max(len(param_name) for param_name, _ in self.config)
+ for param_name, param_val in self.config:
+ self.log('{name: <{name_len}}{val}'.format(
+ name=param_name, val=param_val, name_len=longest_param_name_len+2))
+ self.log('---------------------------------------------------------------------')
+
+ @property
+ def logger(self):
+ return self.config.get_logger()
+
+ def log(self, msg):
+ self.logger.info(msg)
+
+ def _init_num_of_examples(self):
+ self.log('Checking number of examples ...')
+ if self.config.is_training:
+ self.config.NUM_TRAIN_EXAMPLES = self._get_num_of_examples_for_dataset(self.config.train_data_path)
+ self.log(' Number of train examples: {}'.format(self.config.NUM_TRAIN_EXAMPLES))
+ if self.config.is_testing:
+ self.config.NUM_TEST_EXAMPLES = self._get_num_of_examples_for_dataset(self.config.TEST_DATA_PATH)
+ self.log(' Number of test examples: {}'.format(self.config.NUM_TEST_EXAMPLES))
+
+ @staticmethod
+ def _get_num_of_examples_for_dataset(dataset_path: str) -> int:
+ dataset_num_examples_file_path = dataset_path + '.num_examples'
+ if os.path.isfile(dataset_num_examples_file_path):
+ with open(dataset_num_examples_file_path, 'r') as file:
+ num_examples_in_dataset = int(file.readline())
+ else:
+ num_examples_in_dataset = common.count_lines_in_file(dataset_path)
+ with open(dataset_num_examples_file_path, 'w') as file:
+ file.write(str(num_examples_in_dataset))
+ return num_examples_in_dataset
+
+ def load_or_build(self):
+ self.vocabs = Code2VecVocabs(self.config)
+ self._load_or_create_inner_model()
+
+ def save(self, model_save_path=None):
+ if model_save_path is None:
+ model_save_path = self.config.MODEL_SAVE_PATH
+ model_save_dir = '/'.join(model_save_path.split('/')[:-1])
+ if not os.path.isdir(model_save_dir):
+ os.mkdir(model_save_dir)
+ self.vocabs.save(self.config.get_vocabularies_path_from_model_path(model_save_path))
+ self._save_inner_model(model_save_path)
+
+ def _write_code_vectors(self, file, code_vectors):
+ for vec in code_vectors:
+ file.write(' '.join(map(str, vec)) + '\n')
+
+ def _get_attention_weight_per_context(
+ self, path_source_strings: Iterable[str], path_strings: Iterable[str], path_target_strings: Iterable[str],
+ attention_weights: Iterable[float]) -> Dict[Tuple[str, str, str], float]:
+ attention_weights = np.squeeze(attention_weights, axis=-1) # (max_contexts, )
+ attention_per_context: Dict[Tuple[str, str, str], float] = {}
+ # shape of path_source_strings, path_strings, path_target_strings, attention_weights is (max_contexts, )
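+        # e.g. one resulting entry (illustrative values): ('obj', 'Nm0|Mth|Nm1', 'toString') -> 0.91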
+
+ # iterate over contexts
+ for path_source, path, path_target, weight in \
+ zip(path_source_strings, path_strings, path_target_strings, attention_weights):
+ string_context_triplet = (common.binary_to_string(path_source),
+ common.binary_to_string(path),
+ common.binary_to_string(path_target))
+ attention_per_context[string_context_triplet] = weight
+ return attention_per_context
+
+ def close_session(self):
+ # can be overridden by the implementation model class.
+ # default implementation just does nothing.
+ pass
+
+ @abc.abstractmethod
+ def train(self):
+ ...
+
+ @abc.abstractmethod
+ def evaluate(self) -> Optional[ModelEvaluationResults]:
+ ...
+
+ @abc.abstractmethod
+ def predict(self, predict_data_lines: Iterable[str]) -> List[ModelPredictionResults]:
+ ...
+
+ @abc.abstractmethod
+ def _save_inner_model(self, path):
+ ...
+
+ def _load_or_create_inner_model(self):
+ if self.config.is_loading:
+ self._load_inner_model()
+ else:
+ self._create_inner_model()
+
+ @abc.abstractmethod
+ def _load_inner_model(self):
+ ...
+
+ def _create_inner_model(self):
+ # can be overridden by the implementation model class.
+ # default implementation just does nothing.
+ pass
+
+ def _initialize(self):
+ # can be overridden by the implementation model class.
+ # default implementation just does nothing.
+ pass
+
+ @abc.abstractmethod
+ def _get_vocab_embedding_as_np_array(self, vocab_type: VocabType) -> np.ndarray:
+ ...
+
+ def save_word2vec_format(self, dest_save_path: str, vocab_type: VocabType):
+ if vocab_type not in VocabType:
+ raise ValueError('`vocab_type` should be `VocabType.Token`, `VocabType.Target` or `VocabType.Path`.')
+ vocab_embedding_matrix = self._get_vocab_embedding_as_np_array(vocab_type)
+ index_to_word = self.vocabs.get(vocab_type).index_to_word
+ with open(dest_save_path, 'w') as words_file:
+ common.save_word2vec_file(words_file, index_to_word, vocab_embedding_matrix)
diff --git a/path_context_reader.py b/path_context_reader.py
new file mode 100644
index 0000000..c9424eb
--- /dev/null
+++ b/path_context_reader.py
@@ -0,0 +1,229 @@
+import tensorflow as tf
+from typing import Dict, Tuple, NamedTuple, Union, Optional, Iterable
+from config import Config
+from vocabularies import Code2VecVocabs
+import abc
+from functools import reduce
+from enum import Enum
+
+
+class EstimatorAction(Enum):
+ Train = 'train'
+ Evaluate = 'evaluate'
+ Predict = 'predict'
+
+ @property
+ def is_train(self):
+ return self is EstimatorAction.Train
+
+ @property
+ def is_evaluate(self):
+ return self is EstimatorAction.Evaluate
+
+ @property
+ def is_predict(self):
+ return self is EstimatorAction.Predict
+
+ @property
+ def is_evaluate_or_predict(self):
+ return self.is_evaluate or self.is_predict
+
+
+class ReaderInputTensors(NamedTuple):
+ """
+ Used mostly for convenient-and-clear access to input parts (by their names).
+ """
+ path_source_token_indices: tf.Tensor
+ path_indices: tf.Tensor
+ path_target_token_indices: tf.Tensor
+ context_valid_mask: tf.Tensor
+ target_index: Optional[tf.Tensor] = None
+ target_string: Optional[tf.Tensor] = None
+ path_source_token_strings: Optional[tf.Tensor] = None
+ path_strings: Optional[tf.Tensor] = None
+ path_target_token_strings: Optional[tf.Tensor] = None
+
+
+class ModelInputTensorsFormer(abc.ABC):
+ """
+ Should be inherited by the model implementation.
+    An instance of the inherited class is passed by the model to the reader, to help the reader
+    construct the input in the form that the model expects to receive it.
+    This class also enables convenient and clear access to input parts by their field names,
+    e.g. `tensors.path_indices` instead of `tensors[1]`.
+ This allows the input tensors to be passed as pure tuples along the computation graph, while the
+ python functions that construct the graph can easily (and clearly) access tensors.
+ """
+
+ @abc.abstractmethod
+ def to_model_input_form(self, input_tensors: ReaderInputTensors):
+ ...
+
+ @abc.abstractmethod
+ def from_model_input_form(self, input_row) -> ReaderInputTensors:
+ ...
+
+
+class PathContextReader:
+ def __init__(self,
+ vocabs: Code2VecVocabs,
+ config: Config,
+ model_input_tensors_former: ModelInputTensorsFormer,
+ estimator_action: EstimatorAction,
+ repeat_endlessly: bool = False):
+ self.vocabs = vocabs
+ self.config = config
+ self.model_input_tensors_former = model_input_tensors_former
+ self.estimator_action = estimator_action
+ self.repeat_endlessly = repeat_endlessly
+ self.CONTEXT_PADDING = ','.join([self.vocabs.token_vocab.special_words.PAD,
+ self.vocabs.path_vocab.special_words.PAD,
+ self.vocabs.token_vocab.special_words.PAD])
+ self.csv_record_defaults = [[self.vocabs.target_vocab.special_words.OOV]] + \
+ ([[self.CONTEXT_PADDING]] * self.config.MAX_CONTEXTS)
+
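+        # Each raw input row has the form '<target> <ctx_1> ... <ctx_MAX_CONTEXTS>',
+        # where every context is a 'source,path,target' triplet; missing fields fall
+        # back to these defaults when a row is decoded.
+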
+ # initialize the needed lookup tables (if not already initialized).
+ self.create_needed_vocabs_lookup_tables(self.vocabs)
+
+ self._dataset: Optional[tf.data.Dataset] = None
+
+ @classmethod
+ def create_needed_vocabs_lookup_tables(cls, vocabs: Code2VecVocabs):
+ vocabs.token_vocab.get_word_to_index_lookup_table()
+ vocabs.path_vocab.get_word_to_index_lookup_table()
+ vocabs.target_vocab.get_word_to_index_lookup_table()
+
+ @tf.function
+ def process_input_row(self, row_placeholder):
+ parts = tf.io.decode_csv(
+ row_placeholder, record_defaults=self.csv_record_defaults, field_delim=' ', use_quote_delim=False)
+ # Note: we DON'T apply the filter `_filter_input_rows()` here.
+ tensors = self._map_raw_dataset_row_to_input_tensors(*parts)
+
+ # make it batched (first batch axis is going to have dimension 1)
+ tensors_expanded = ReaderInputTensors(
+ **{name: None if tensor is None else tf.expand_dims(tensor, axis=0)
+ for name, tensor in tensors._asdict().items()})
+ return self.model_input_tensors_former.to_model_input_form(tensors_expanded)
+
+ def process_and_iterate_input_from_data_lines(self, input_data_lines: Iterable) -> Iterable:
+ for data_row in input_data_lines:
+ processed_row = self.process_input_row(data_row)
+ yield processed_row
+
+ def get_dataset(self, input_data_rows: Optional = None) -> tf.data.Dataset:
+ if self._dataset is None:
+ self._dataset = self._create_dataset_pipeline(input_data_rows)
+ return self._dataset
+
+ def _create_dataset_pipeline(self, input_data_rows: Optional = None) -> tf.data.Dataset:
+ if input_data_rows is None:
+ assert not self.estimator_action.is_predict
+ dataset = tf.data.experimental.CsvDataset(
+ self.config.data_path(is_evaluating=self.estimator_action.is_evaluate),
+ record_defaults=self.csv_record_defaults, field_delim=' ', use_quote_delim=False,
+ buffer_size=self.config.CSV_BUFFER_SIZE)
+ else:
+ dataset = tf.data.Dataset.from_tensor_slices(input_data_rows)
+ dataset = dataset.map(
+ lambda input_line: tf.io.decode_csv(
+ tf.reshape(tf.cast(input_line, tf.string), ()),
+ record_defaults=self.csv_record_defaults,
+ field_delim=' ', use_quote_delim=False))
+
+ if self.repeat_endlessly:
+ dataset = dataset.repeat()
+ if self.estimator_action.is_train:
+ if not self.repeat_endlessly and self.config.NUM_TRAIN_EPOCHS > 1:
+ dataset = dataset.repeat(self.config.NUM_TRAIN_EPOCHS)
+ dataset = dataset.shuffle(self.config.SHUFFLE_BUFFER_SIZE, reshuffle_each_iteration=True)
+
+ dataset = dataset.map(self._map_raw_dataset_row_to_expected_model_input_form,
+ num_parallel_calls=self.config.READER_NUM_PARALLEL_BATCHES)
+ batch_size = self.config.batch_size(is_evaluating=self.estimator_action.is_evaluate)
+ if self.estimator_action.is_predict:
+ dataset = dataset.batch(1)
+ else:
+ dataset = dataset.filter(self._filter_input_rows)
+ dataset = dataset.batch(batch_size)
+
+ dataset = dataset.prefetch(buffer_size=40) # original: tf.contrib.data.AUTOTUNE) -- got OOM err; 10 seems promising.
+ return dataset
+
+ def _filter_input_rows(self, *row_parts) -> tf.bool:
+ row_parts = self.model_input_tensors_former.from_model_input_form(row_parts)
+
+ assert all(tensor.shape == (self.config.MAX_CONTEXTS,) for tensor in
+ {row_parts.path_source_token_indices, row_parts.path_indices,
+ row_parts.path_target_token_indices, row_parts.context_valid_mask})
+
+ # FIXME: Does "valid" here mean just "no padding" or "neither padding nor OOV"? I assumed just "no padding".
+ any_word_valid_mask_per_context_part = [
+ tf.not_equal(tf.reduce_max(row_parts.path_source_token_indices, axis=0),
+ self.vocabs.token_vocab.word_to_index[self.vocabs.token_vocab.special_words.PAD]),
+ tf.not_equal(tf.reduce_max(row_parts.path_target_token_indices, axis=0),
+ self.vocabs.token_vocab.word_to_index[self.vocabs.token_vocab.special_words.PAD]),
+ tf.not_equal(tf.reduce_max(row_parts.path_indices, axis=0),
+ self.vocabs.path_vocab.word_to_index[self.vocabs.path_vocab.special_words.PAD])]
+ any_contexts_is_valid = reduce(tf.logical_or, any_word_valid_mask_per_context_part) # scalar
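+        # (a context part is considered "valid" when any of its MAX_CONTEXTS entries is non-PAD;
+        # comparing the max index against PAD assumes PAD gets the lowest index.)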
+
+ if self.estimator_action.is_evaluate:
+ cond = any_contexts_is_valid # scalar
+ else: # training
+ word_is_valid = tf.greater(
+ row_parts.target_index, self.vocabs.target_vocab.word_to_index[self.vocabs.target_vocab.special_words.OOV]) # scalar
+ cond = tf.logical_and(word_is_valid, any_contexts_is_valid) # scalar
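+            # (assumes special words like PAD/OOV occupy the lowest target indices,
+            # so any real target word index is strictly greater than OOV's.)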
+
+ return cond # scalar
+
+ def _map_raw_dataset_row_to_expected_model_input_form(self, *row_parts) -> \
+ Tuple[Union[tf.Tensor, Tuple[tf.Tensor, ...], Dict[str, tf.Tensor]], ...]:
+ tensors = self._map_raw_dataset_row_to_input_tensors(*row_parts)
+ return self.model_input_tensors_former.to_model_input_form(tensors)
+
+ def _map_raw_dataset_row_to_input_tensors(self, *row_parts) -> ReaderInputTensors:
+ row_parts = list(row_parts)
+ target_str = row_parts[0]
+ target_index = self.vocabs.target_vocab.lookup_index(target_str)
+
+ contexts_str = tf.stack(row_parts[1:(self.config.MAX_CONTEXTS + 1)], axis=0)
+ split_contexts = tf.compat.v1.string_split(contexts_str, sep=',', skip_empty=False)
+ # dense_split_contexts = tf.sparse_tensor_to_dense(split_contexts, default_value=self.vocabs.token_vocab.special_words.PAD)
+ sparse_split_contexts = tf.sparse.SparseTensor(
+ indices=split_contexts.indices, values=split_contexts.values, dense_shape=[self.config.MAX_CONTEXTS, 3])
+ dense_split_contexts = tf.reshape(
+ tf.sparse.to_dense(sp_input=sparse_split_contexts, default_value=self.vocabs.token_vocab.special_words.PAD),
+ shape=[self.config.MAX_CONTEXTS, 3]) # (max_contexts, 3)
+
+ path_source_token_strings = tf.squeeze(
+ tf.slice(dense_split_contexts, begin=[0, 0], size=[self.config.MAX_CONTEXTS, 1]), axis=1) # (max_contexts,)
+ path_strings = tf.squeeze(
+ tf.slice(dense_split_contexts, begin=[0, 1], size=[self.config.MAX_CONTEXTS, 1]), axis=1) # (max_contexts,)
+ path_target_token_strings = tf.squeeze(
+ tf.slice(dense_split_contexts, begin=[0, 2], size=[self.config.MAX_CONTEXTS, 1]), axis=1) # (max_contexts,)
+
+ path_source_token_indices = self.vocabs.token_vocab.lookup_index(path_source_token_strings) # (max_contexts, )
+ path_indices = self.vocabs.path_vocab.lookup_index(path_strings) # (max_contexts, )
+ path_target_token_indices = self.vocabs.token_vocab.lookup_index(path_target_token_strings) # (max_contexts, )
+
+ # FIXME: Does "valid" here mean just "no padding" or "neither padding nor OOV"? I assumed just "no padding".
+ valid_word_mask_per_context_part = [
+ tf.not_equal(path_source_token_indices, self.vocabs.token_vocab.word_to_index[self.vocabs.token_vocab.special_words.PAD]),
+ tf.not_equal(path_target_token_indices, self.vocabs.token_vocab.word_to_index[self.vocabs.token_vocab.special_words.PAD]),
+ tf.not_equal(path_indices, self.vocabs.path_vocab.word_to_index[self.vocabs.path_vocab.special_words.PAD])] # [(max_contexts, )]
+ context_valid_mask = tf.cast(reduce(tf.logical_or, valid_word_mask_per_context_part), dtype=tf.float32) # (max_contexts, )
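+ # The mask is cast to float32 so the model can later add its element-wise log
+ # (0 for valid contexts, -inf for padded ones) to the attention logits.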
+
+ assert all(tensor.shape == (self.config.MAX_CONTEXTS,) for tensor in
+ {path_source_token_indices, path_indices, path_target_token_indices, context_valid_mask})
+
+ return ReaderInputTensors(
+ path_source_token_indices=path_source_token_indices,
+ path_indices=path_indices,
+ path_target_token_indices=path_target_token_indices,
+ context_valid_mask=context_valid_mask,
+ target_index=target_index,
+ target_string=target_str,
+ path_source_token_strings=path_source_token_strings,
+ path_strings=path_strings,
+ path_target_token_strings=path_target_token_strings
+ )
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000..7350dc6
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,2 @@
+tensorflow==2.0.0-beta1
+numpy
diff --git a/tensorflow_model.py b/tensorflow_model.py
new file mode 100644
index 0000000..295ed0e
--- /dev/null
+++ b/tensorflow_model.py
@@ -0,0 +1,530 @@
+import tensorflow as tf
+import numpy as np
+import time
+from typing import Dict, Optional, List, Iterable
+from collections import Counter
+from functools import partial
+
+from path_context_reader import PathContextReader, ModelInputTensorsFormer, ReaderInputTensors, EstimatorAction
+from common import common
+from vocabularies import VocabType
+from config import Config
+from model_base import Code2VecModelBase, ModelEvaluationResults, ModelPredictionResults
+
+
+tf.compat.v1.disable_eager_execution()
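+# This port keeps the TF1-style graph-and-session workflow (Session, Saver,
+# get_variable) via the tf.compat.v1 shims, so eager execution -- the TF2
+# default -- is disabled globally at import time.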
+
+
+class Code2VecModel(Code2VecModelBase):
+ def __init__(self, config: Config):
+ self.sess = tf.compat.v1.Session()
+ self.saver = None
+
+ self.eval_reader = None
+ self.eval_input_iterator_reset_op = None
+ self.predict_reader = None
+
+ # self.eval_placeholder = None
+ self.predict_placeholder = None
+ self.eval_top_words_op, self.eval_top_values_op, self.eval_original_names_op, self.eval_code_vectors = None, None, None, None
+ self.predict_top_words_op, self.predict_top_values_op, self.predict_original_names_op = None, None, None
+
+ self.vocab_type_to_tf_variable_name_mapping: Dict[VocabType, str] = {
+ VocabType.Token: 'WORDS_VOCAB',
+ VocabType.Target: 'TARGET_WORDS_VOCAB',
+ VocabType.Path: 'PATHS_VOCAB'
+ }
+
+ super(Code2VecModel, self).__init__(config)
+
+ def train(self):
+ self.log('Starting training')
+ start_time = time.time()
+
+ batch_num = 0
+ sum_loss = 0
+ multi_batch_start_time = time.time()
+ num_batches_to_save_and_eval = max(int(self.config.train_steps_per_epoch * self.config.SAVE_EVERY_EPOCHS), 1)
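+ # Example: with 1000 training steps per epoch and SAVE_EVERY_EPOCHS=2, the
+ # model is saved and evaluated every 2000 batches.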
+
+ train_reader = PathContextReader(vocabs=self.vocabs,
+ model_input_tensors_former=_TFTrainModelInputTensorsFormer(),
+ config=self.config, estimator_action=EstimatorAction.Train)
+ input_iterator = tf.compat.v1.data.make_initializable_iterator(train_reader.get_dataset())
+ input_iterator_reset_op = input_iterator.initializer
+ input_tensors = input_iterator.get_next()
+
+ optimizer, train_loss = self._build_tf_training_graph(input_tensors)
+ self.saver = tf.compat.v1.train.Saver(max_to_keep=self.config.MAX_TO_KEEP)
+
+ self.log('Number of trainable params: {}'.format(
+ np.sum([np.prod(v.get_shape().as_list()) for v in tf.compat.v1.trainable_variables()])))
+ for variable in tf.compat.v1.trainable_variables():
+ self.log("variable name: {} -- shape: {} -- #params: {}".format(
+ variable.name, variable.get_shape(), np.prod(variable.get_shape().as_list())))
+
+ self._initialize_session_variables()
+
+ if self.config.MODEL_LOAD_PATH:
+ self._load_inner_model(self.sess)
+
+ self.sess.run(input_iterator_reset_op)
+ time.sleep(1)
+ self.log('Started reader...')
+ # Run training in a loop until the input iterator is exhausted.
+ try:
+ while True:
+ # Each iteration = batch. We iterate as long as the tf iterator (reader) yields batches.
+ batch_num += 1
+
+ # Actual training for the current batch.
+ _, batch_loss = self.sess.run([optimizer, train_loss])
+
+ sum_loss += batch_loss
+ if batch_num % self.config.NUM_BATCHES_TO_LOG_PROGRESS == 0:
+ self._trace_training(sum_loss, batch_num, multi_batch_start_time)
+ # Note: the "shuffle_batch/random_shuffle_queue_Size:0" op belonged to the
+ # old queue-based reader and does not exist in the tf.data pipeline, so
+ # running it here would fail; the log call is therefore disabled.
+ # self.log('Number of waiting examples in queue: %d' % self.sess.run(
+ # "shuffle_batch/random_shuffle_queue_Size:0"))
+ sum_loss = 0
+ multi_batch_start_time = time.time()
+ if batch_num % num_batches_to_save_and_eval == 0:
+ epoch_num = int((batch_num / num_batches_to_save_and_eval) * self.config.SAVE_EVERY_EPOCHS)
+ save_path = self.config.MODEL_SAVE_PATH + '_iter' + str(epoch_num)
+ self._save_inner_model(save_path)
+ self.log('Saved after %d epochs in: %s' % (epoch_num, save_path))
+ evaluation_results = self.evaluate()
+ evaluation_results_str = (str(evaluation_results).replace('topk', 'top{}'.format(
+ self.config.TOP_K_WORDS_CONSIDERED_DURING_PREDICTION)))
+ self.log('After {nr_epochs} epochs -- {evaluation_results}'.format(
+ nr_epochs=epoch_num,
+ evaluation_results=evaluation_results_str
+ ))
+ except tf.errors.OutOfRangeError:
+ pass # The reader iterator is exhausted and has no more batches to produce.
+
+ self.log('Done training')
+
+ if self.config.MODEL_SAVE_PATH:
+ self._save_inner_model(self.config.MODEL_SAVE_PATH)
+ self.log('Model saved in file: %s' % self.config.MODEL_SAVE_PATH)
+
+ elapsed = int(time.time() - start_time)
+ self.log("Training time: %sH:%sM:%sS\n" % ((elapsed // 60 // 60), (elapsed // 60) % 60, elapsed % 60))
+
+ def evaluate(self) -> Optional[ModelEvaluationResults]:
+ eval_start_time = time.time()
+ if self.eval_reader is None:
+ self.eval_reader = PathContextReader(vocabs=self.vocabs,
+ model_input_tensors_former=_TFEvaluateModelInputTensorsFormer(),
+ config=self.config, estimator_action=EstimatorAction.Evaluate)
+ input_iterator = tf.compat.v1.data.make_initializable_iterator(self.eval_reader.get_dataset())
+ self.eval_input_iterator_reset_op = input_iterator.initializer
+ input_tensors = input_iterator.get_next()
+
+ self.eval_top_words_op, self.eval_top_values_op, self.eval_original_names_op, _, _, _, _, \
+ self.eval_code_vectors = self._build_tf_test_graph(input_tensors)
+ self.saver = tf.compat.v1.train.Saver()
+
+ if self.config.MODEL_LOAD_PATH and not self.config.TRAIN_DATA_PATH_PREFIX:
+ self._initialize_session_variables()
+ self._load_inner_model(self.sess)
+ if self.config.RELEASE:
+ release_name = self.config.MODEL_LOAD_PATH + '.release'
+ self.log('Releasing model, output model: %s' % release_name)
+ self.saver.save(self.sess, release_name)
+ return None # FIXME: why do we return None here?
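+ # (Presumably: in release mode the goal is only to re-save the loaded weights
+ # as a slim `.release` checkpoint, so there is nothing left to evaluate.)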
+
+ with open('log.txt', 'w') as log_output_file:
+ if self.config.EXPORT_CODE_VECTORS:
+ code_vectors_file = open(self.config.TEST_DATA_PATH + '.vectors', 'w')
+ total_predictions = 0
+ total_prediction_batches = 0
+ subtokens_evaluation_metric = SubtokensEvaluationMetric(
+ partial(common.filter_impossible_names, self.vocabs.target_vocab.special_words))
+ topk_accuracy_evaluation_metric = TopKAccuracyEvaluationMetric(
+ self.config.TOP_K_WORDS_CONSIDERED_DURING_PREDICTION,
+ partial(common.get_first_match_word_from_top_predictions, self.vocabs.target_vocab.special_words))
+ start_time = time.time()
+
+ self.sess.run(self.eval_input_iterator_reset_op)
+
+ self.log('Starting evaluation')
+
+ # Run evaluation in a loop until iterator is exhausted.
+ # Each iteration = batch. We iterate as long as the tf iterator (reader) yields batches.
+ try:
+ while True:
+ top_words, top_scores, original_names, code_vectors = self.sess.run(
+ [self.eval_top_words_op, self.eval_top_values_op,
+ self.eval_original_names_op, self.eval_code_vectors],
+ )
+
+ # shapes:
+ # top_words: (batch, top_k); top_scores: (batch, top_k)
+ # original_names: (batch, ); code_vectors: (batch, code_vector_size)
+
+ top_words = common.binary_to_string_matrix(top_words) # (batch, top_k)
+ original_names = common.binary_to_string_list(original_names) # (batch,)
+
+ self._log_predictions_during_evaluation(zip(original_names, top_words), log_output_file)
+ topk_accuracy_evaluation_metric.update_batch(zip(original_names, top_words))
+ subtokens_evaluation_metric.update_batch(zip(original_names, top_words))
+
+ total_predictions += len(original_names)
+ total_prediction_batches += 1
+ if self.config.EXPORT_CODE_VECTORS:
+ self._write_code_vectors(code_vectors_file, code_vectors)
+ if total_prediction_batches % self.config.NUM_BATCHES_TO_LOG_PROGRESS == 0:
+ elapsed = time.time() - start_time
+ # start_time = time.time()
+ self._trace_evaluation(total_predictions, elapsed)
+ except tf.errors.OutOfRangeError:
+ pass # The reader iterator is exhausted and has no more batches to produce.
+ self.log('Done evaluating, epoch reached')
+ log_output_file.write(str(topk_accuracy_evaluation_metric.topk_correct_predictions) + '\n')
+ if self.config.EXPORT_CODE_VECTORS:
+ code_vectors_file.close()
+
+ elapsed = int(time.time() - eval_start_time)
+ self.log("Evaluation time: %sH:%sM:%sS" % ((elapsed // 60 // 60), (elapsed // 60) % 60, elapsed % 60))
+ return ModelEvaluationResults(
+ topk_acc=topk_accuracy_evaluation_metric.topk_correct_predictions,
+ subtoken_precision=subtokens_evaluation_metric.precision,
+ subtoken_recall=subtokens_evaluation_metric.recall,
+ subtoken_f1=subtokens_evaluation_metric.f1)
+
+ def _build_tf_training_graph(self, input_tensors):
+ # Use `_TFTrainModelInputTensorsFormer` to access input tensors by name.
+ input_tensors = _TFTrainModelInputTensorsFormer().from_model_input_form(input_tensors)
+ # shape of (batch, 1) for input_tensors.target_index
+ # shape of (batch, max_contexts) for others:
+ # input_tensors.path_source_token_indices, input_tensors.path_indices,
+ # input_tensors.path_target_token_indices, input_tensors.context_valid_mask
+
+ with tf.compat.v1.variable_scope('model'):
+ tokens_vocab = tf.compat.v1.get_variable(
+ self.vocab_type_to_tf_variable_name_mapping[VocabType.Token],
+ shape=(self.vocabs.token_vocab.size, self.config.TOKEN_EMBEDDINGS_SIZE), dtype=tf.float32,
+ initializer=tf.compat.v1.initializers.variance_scaling(scale=1.0, mode='fan_out', distribution="uniform"))
+ targets_vocab = tf.compat.v1.get_variable(
+ self.vocab_type_to_tf_variable_name_mapping[VocabType.Target],
+ shape=(self.vocabs.target_vocab.size, self.config.TARGET_EMBEDDINGS_SIZE), dtype=tf.float32,
+ initializer=tf.compat.v1.initializers.variance_scaling(scale=1.0, mode='fan_out', distribution="uniform"))
+ attention_param = tf.compat.v1.get_variable(
+ 'ATTENTION',
+ shape=(self.config.CODE_VECTOR_SIZE, 1), dtype=tf.float32)
+ paths_vocab = tf.compat.v1.get_variable(
+ self.vocab_type_to_tf_variable_name_mapping[VocabType.Path],
+ shape=(self.vocabs.path_vocab.size, self.config.PATH_EMBEDDINGS_SIZE), dtype=tf.float32,
+ initializer=tf.compat.v1.initializers.variance_scaling(scale=1.0, mode='fan_out', distribution="uniform"))
+
+ code_vectors, _ = self._calculate_weighted_contexts(
+ tokens_vocab, paths_vocab, attention_param, input_tensors.path_source_token_indices,
+ input_tensors.path_indices, input_tensors.path_target_token_indices, input_tensors.context_valid_mask)
+
+ logits = tf.matmul(code_vectors, targets_vocab, transpose_b=True)
+ batch_size = tf.cast(tf.shape(input_tensors.target_index)[0], dtype=tf.float32)
+ loss = tf.reduce_sum(tf.nn.sparse_softmax_cross_entropy_with_logits(
+ labels=tf.reshape(input_tensors.target_index, [-1]),
+ logits=logits)) / batch_size
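+ # i.e. the standard softmax cross-entropy of the target word given the code
+ # vector, averaged over the examples in the batch.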
+
+ optimizer = tf.compat.v1.train.AdamOptimizer().minimize(loss)
+
+ return optimizer, loss
+
+ def _calculate_weighted_contexts(self, tokens_vocab, paths_vocab, attention_param, source_input, path_input,
+ target_input, valid_mask, is_evaluating=False):
+ source_word_embed = tf.nn.embedding_lookup(params=tokens_vocab, ids=source_input) # (batch, max_contexts, dim)
+ path_embed = tf.nn.embedding_lookup(params=paths_vocab, ids=path_input) # (batch, max_contexts, dim)
+ target_word_embed = tf.nn.embedding_lookup(params=tokens_vocab, ids=target_input) # (batch, max_contexts, dim)
+
+ context_embed = tf.concat([source_word_embed, path_embed, target_word_embed],
+ axis=-1) # (batch, max_contexts, dim * 3)
+
+ if not is_evaluating:
+ context_embed = tf.nn.dropout(context_embed, rate=1-self.config.DROPOUT_KEEP_RATE)
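+ # TF2's tf.nn.dropout takes a drop *rate*, while the config keeps the
+ # TF1-style *keep* probability; hence the `1 - DROPOUT_KEEP_RATE` conversion.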
+
+ flat_embed = tf.reshape(context_embed, [-1, self.config.context_vector_size]) # (batch * max_contexts, dim * 3)
+ transform_param = tf.compat.v1.get_variable(
+ 'TRANSFORM', shape=(self.config.context_vector_size, self.config.CODE_VECTOR_SIZE), dtype=tf.float32)
+
+ flat_embed = tf.tanh(tf.matmul(flat_embed, transform_param)) # (batch * max_contexts, code_vector_size)
+
+ contexts_weights = tf.matmul(flat_embed, attention_param) # (batch * max_contexts, 1)
+ batched_contexts_weights = tf.reshape(
+ contexts_weights, [-1, self.config.MAX_CONTEXTS, 1]) # (batch, max_contexts, 1)
+ mask = tf.math.log(valid_mask) # (batch, max_contexts)
+ mask = tf.expand_dims(mask, axis=2) # (batch, max_contexts, 1)
+ batched_contexts_weights += mask # (batch, max_contexts, 1)
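+ # log(mask) is 0 for valid contexts and -inf for padded ones, so adding it to
+ # the attention logits makes the softmax below assign padded contexts zero weight.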
+ attention_weights = tf.nn.softmax(batched_contexts_weights, axis=1) # (batch, max_contexts, 1)
+
+ batched_embed = tf.reshape(flat_embed, shape=[-1, self.config.MAX_CONTEXTS, self.config.CODE_VECTOR_SIZE])
+ code_vectors = tf.reduce_sum(tf.multiply(batched_embed, attention_weights), axis=1) # (batch, code_vector_size)
+
+ return code_vectors, attention_weights
+
+ def _build_tf_test_graph(self, input_tensors, normalize_scores=False):
+ with tf.compat.v1.variable_scope('model', reuse=self.get_should_reuse_variables()):
+ tokens_vocab = tf.compat.v1.get_variable(
+ self.vocab_type_to_tf_variable_name_mapping[VocabType.Token],
+ shape=(self.vocabs.token_vocab.size, self.config.TOKEN_EMBEDDINGS_SIZE),
+ dtype=tf.float32, trainable=False)
+ targets_vocab = tf.compat.v1.get_variable(
+ self.vocab_type_to_tf_variable_name_mapping[VocabType.Target],
+ shape=(self.vocabs.target_vocab.size, self.config.TARGET_EMBEDDINGS_SIZE),
+ dtype=tf.float32, trainable=False)
+ attention_param = tf.compat.v1.get_variable(
+ 'ATTENTION', shape=(self.config.CODE_VECTOR_SIZE, 1), # must match the training graph's ATTENTION shape
+ dtype=tf.float32, trainable=False)
+ paths_vocab = tf.compat.v1.get_variable(
+ self.vocab_type_to_tf_variable_name_mapping[VocabType.Path],
+ shape=(self.vocabs.path_vocab.size, self.config.PATH_EMBEDDINGS_SIZE),
+ dtype=tf.float32, trainable=False)
+
+ targets_vocab = tf.transpose(targets_vocab) # (target_embeddings_size, target_vocab_size)
+
+ # Use `_TFEvaluateModelInputTensorsFormer` to access input tensors by name.
+ input_tensors = _TFEvaluateModelInputTensorsFormer().from_model_input_form(input_tensors)
+ # shape of (batch, 1) for input_tensors.target_string
+ # shape of (batch, max_contexts) for the other tensors
+
+ code_vectors, attention_weights = self._calculate_weighted_contexts(
+ tokens_vocab, paths_vocab, attention_param, input_tensors.path_source_token_indices,
+ input_tensors.path_indices, input_tensors.path_target_token_indices,
+ input_tensors.context_valid_mask, is_evaluating=True)
+
+ scores = tf.matmul(code_vectors, targets_vocab) # (batch, target_word_vocab)
+
+ topk_candidates = tf.nn.top_k(scores, k=tf.minimum(
+ self.config.TOP_K_WORDS_CONSIDERED_DURING_PREDICTION, self.vocabs.target_vocab.size))
+ top_indices = topk_candidates.indices
+ top_words = self.vocabs.target_vocab.lookup_word(top_indices)
+ original_words = input_tensors.target_string
+ top_scores = topk_candidates.values
+ if normalize_scores:
+ top_scores = tf.nn.softmax(top_scores)
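+ # Note: the softmax runs over the top-k scores only, so the "normalized"
+ # scores are relative weights among the k candidates, not probabilities over
+ # the full target vocabulary.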
+
+ return top_words, top_scores, original_words, attention_weights, input_tensors.path_source_token_strings, \
+ input_tensors.path_strings, input_tensors.path_target_token_strings, code_vectors
+
+ def predict(self, predict_data_lines: Iterable[str]) -> List[ModelPredictionResults]:
+ if self.predict_reader is None:
+ self.predict_reader = PathContextReader(vocabs=self.vocabs,
+ model_input_tensors_former=_TFEvaluateModelInputTensorsFormer(),
+ config=self.config, estimator_action=EstimatorAction.Predict)
+ self.predict_placeholder = tf.compat.v1.placeholder(tf.string)
+ reader_output = self.predict_reader.process_input_row(self.predict_placeholder)
+
+ self.predict_top_words_op, self.predict_top_values_op, self.predict_original_names_op, \
+ self.attention_weights_op, self.predict_source_string, self.predict_path_string, \
+ self.predict_path_target_string, self.predict_code_vectors = \
+ self._build_tf_test_graph(reader_output, normalize_scores=True)
+
+ self._initialize_session_variables()
+ self.saver = tf.compat.v1.train.Saver()
+ self._load_inner_model(sess=self.sess)
+
+ prediction_results: List[ModelPredictionResults] = []
+ for line in predict_data_lines:
+ batch_top_words, batch_top_scores, batch_original_name, batch_attention_weights, batch_path_source_strings,\
+ batch_path_strings, batch_path_target_strings, batch_code_vectors = self.sess.run(
+ [self.predict_top_words_op, self.predict_top_values_op, self.predict_original_names_op,
+ self.attention_weights_op, self.predict_source_string, self.predict_path_string,
+ self.predict_path_target_string, self.predict_code_vectors],
+ feed_dict={self.predict_placeholder: line})
+ # shapes:
+ # batch_top_words, top_scores: (batch, top_k)
+ # batch_original_name: (batch, )
+ # batch_attention_weights: (batch, max_context, 1)
+ # batch_path_source_strings, batch_path_strings, batch_path_target_strings: (batch, max_context)
+ # batch_code_vectors: (batch, code_vector_size)
+
+ # remove first axis: (batch=1, ...)
+ assert all(tensor.shape[0] == 1 for tensor in (batch_top_words, batch_top_scores, batch_original_name,
+ batch_attention_weights, batch_path_source_strings,
+ batch_path_strings, batch_path_target_strings,
+ batch_code_vectors))
+ top_words = np.squeeze(batch_top_words, axis=0)
+ top_scores = np.squeeze(batch_top_scores, axis=0)
+ original_name = batch_original_name[0]
+ attention_weights = np.squeeze(batch_attention_weights, axis=0)
+ path_source_strings = np.squeeze(batch_path_source_strings, axis=0)
+ path_strings = np.squeeze(batch_path_strings, axis=0)
+ path_target_strings = np.squeeze(batch_path_target_strings, axis=0)
+ code_vectors = np.squeeze(batch_code_vectors, axis=0)
+
+ top_words = common.binary_to_string_list(top_words)
+ original_name = common.binary_to_string(original_name)
+ attention_per_context = self._get_attention_weight_per_context(
+ path_source_strings, path_strings, path_target_strings, attention_weights)
+ prediction_results.append(ModelPredictionResults(
+ original_name=original_name,
+ topk_predicted_words=top_words,
+ topk_predicted_words_scores=top_scores,
+ attention_per_context=attention_per_context,
+ code_vector=(code_vectors if self.config.EXPORT_CODE_VECTORS else None)
+ ))
+ return prediction_results
+
+ def _save_inner_model(self, path: str):
+ self.saver.save(self.sess, path)
+
+ def _load_inner_model(self, sess=None):
+ if sess is not None:
+ self.log('Loading model weights from: ' + self.config.MODEL_LOAD_PATH)
+ self.saver.restore(sess, self.config.MODEL_LOAD_PATH)
+ self.log('Done loading model weights')
+
+ def _get_vocab_embedding_as_np_array(self, vocab_type: VocabType) -> np.ndarray:
+ assert vocab_type in VocabType
+ vocab_tf_variable_name = self.vocab_type_to_tf_variable_name_mapping[vocab_type]
+ # `get_variable` without a shape can only return an existing variable, so the
+ # model graph must already be built here and the scope must be in reuse mode.
+ with tf.compat.v1.variable_scope('model', reuse=True):
+ embeddings = tf.compat.v1.get_variable(vocab_tf_variable_name)
+ self.saver = tf.compat.v1.train.Saver()
+ self._load_inner_model(self.sess)
+ vocab_embedding_matrix = self.sess.run(embeddings)
+ return vocab_embedding_matrix
+
+ def get_should_reuse_variables(self):
+ if self.config.TRAIN_DATA_PATH_PREFIX:
+ return True
+ else:
+ return None
+
+ def _log_predictions_during_evaluation(self, results, output_file):
+ for original_name, top_predicted_words in results:
+ found_match = common.get_first_match_word_from_top_predictions(
+ self.vocabs.target_vocab.special_words, original_name, top_predicted_words)
+ if found_match is not None:
+ prediction_idx, predicted_word = found_match
+ if prediction_idx == 0:
+ output_file.write('Original: ' + original_name + ', predicted 1st: ' + predicted_word + '\n')
+ else:
+ output_file.write('\t\t predicted correctly at rank: ' + str(prediction_idx + 1) + '\n')
+ else:
+ output_file.write('No results for predicting: ' + original_name + '\n')
+
+ def _trace_training(self, sum_loss, batch_num, multi_batch_start_time):
+ multi_batch_elapsed = time.time() - multi_batch_start_time
+ avg_loss = sum_loss / (self.config.NUM_BATCHES_TO_LOG_PROGRESS * self.config.TRAIN_BATCH_SIZE)
+ throughput = self.config.TRAIN_BATCH_SIZE * self.config.NUM_BATCHES_TO_LOG_PROGRESS / \
+ (multi_batch_elapsed if multi_batch_elapsed > 0 else 1)
+ self.log('Average loss at batch %d: %f, \tthroughput: %d samples/sec' % (
+ batch_num, avg_loss, throughput))
+
+ def _trace_evaluation(self, total_predictions, elapsed):
+ state_message = 'Evaluated %d examples...' % total_predictions
+ throughput_message = "Prediction throughput: %d samples/sec" % int(
+ total_predictions / (elapsed if elapsed > 0 else 1))
+ self.log(state_message)
+ self.log(throughput_message)
+
+ def close_session(self):
+ self.sess.close()
+
+ def _initialize_session_variables(self):
+ self.sess.run(tf.group(
+ tf.compat.v1.global_variables_initializer(),
+ tf.compat.v1.local_variables_initializer(),
+ tf.compat.v1.tables_initializer()))
+ self.log('Initialized variables')
+
+
+class SubtokensEvaluationMetric:
+ def __init__(self, filter_impossible_names_fn):
+ self.nr_true_positives: int = 0
+ self.nr_false_positives: int = 0
+ self.nr_false_negatives: int = 0
+ self.nr_predictions: int = 0
+ self.filter_impossible_names_fn = filter_impossible_names_fn
+
+ def update_batch(self, results):
+ for original_name, top_words in results:
+ prediction = self.filter_impossible_names_fn(top_words)[0]
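+ # Assumes at least one of the top predictions survives the filter; if every
+ # top word were a special (impossible) name, indexing [0] would raise IndexError.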
+ original_subtokens = Counter(common.get_subtokens(original_name))
+ predicted_subtokens = Counter(common.get_subtokens(prediction))
+ self.nr_true_positives += sum(count for element, count in predicted_subtokens.items()
+ if element in original_subtokens)
+ self.nr_false_positives += sum(count for element, count in predicted_subtokens.items()
+ if element not in original_subtokens)
+ self.nr_false_negatives += sum(count for element, count in original_subtokens.items()
+ if element not in predicted_subtokens)
+ self.nr_predictions += 1
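+ # Worked example: original name "open|file" vs. predicted "open|input|stream"
+ # gives TP=1 (open), FP=2 (input, stream), FN=1 (file), i.e. precision 1/3
+ # and recall 1/2 for that single prediction.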
+
+ @property
+ def true_positive(self):
+ return self.nr_true_positives / self.nr_predictions
+
+ @property
+ def false_positive(self):
+ return self.nr_false_positives / self.nr_predictions
+
+ @property
+ def false_negative(self):
+ return self.nr_false_negatives / self.nr_predictions
+
+ @property
+ def precision(self):
+ return self.nr_true_positives / (self.nr_true_positives + self.nr_false_positives)
+
+ @property
+ def recall(self):
+ return self.nr_true_positives / (self.nr_true_positives + self.nr_false_negatives)
+
+ @property
+ def f1(self):
+ return 2 * self.precision * self.recall / (self.precision + self.recall)
+
+
+class TopKAccuracyEvaluationMetric:
+ def __init__(self, top_k: int, get_first_match_word_from_top_predictions_fn):
+ self.top_k = top_k
+ self.nr_correct_predictions = np.zeros(self.top_k)
+ self.nr_predictions: int = 0
+ self.get_first_match_word_from_top_predictions_fn = get_first_match_word_from_top_predictions_fn
+
+ def update_batch(self, results):
+ for original_name, top_predicted_words in results:
+ self.nr_predictions += 1
+ found_match = self.get_first_match_word_from_top_predictions_fn(original_name, top_predicted_words)
+ if found_match is not None:
+ suggestion_idx, _ = found_match
+ self.nr_correct_predictions[suggestion_idx:self.top_k] += 1
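+ # A match at 0-based rank `suggestion_idx` is correct for every cutoff
+ # k' > suggestion_idx, so all slots from that rank up to top_k are incremented;
+ # dividing by the number of predictions later yields accuracy@1..accuracy@k.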
+
+ @property
+ def topk_correct_predictions(self):
+ return self.nr_correct_predictions / self.nr_predictions
+
+
+class _TFTrainModelInputTensorsFormer(ModelInputTensorsFormer):
+ def to_model_input_form(self, input_tensors: ReaderInputTensors):
+ return input_tensors.target_index, input_tensors.path_source_token_indices, input_tensors.path_indices, \
+ input_tensors.path_target_token_indices, input_tensors.context_valid_mask
+
+ def from_model_input_form(self, input_row) -> ReaderInputTensors:
+ return ReaderInputTensors(
+ target_index=input_row[0],
+ path_source_token_indices=input_row[1],
+ path_indices=input_row[2],
+ path_target_token_indices=input_row[3],
+ context_valid_mask=input_row[4]
+ )
+
+
+class _TFEvaluateModelInputTensorsFormer(ModelInputTensorsFormer):
+ def to_model_input_form(self, input_tensors: ReaderInputTensors):
+ return input_tensors.target_string, input_tensors.path_source_token_indices, input_tensors.path_indices, \
+ input_tensors.path_target_token_indices, input_tensors.context_valid_mask, \
+ input_tensors.path_source_token_strings, input_tensors.path_strings, \
+ input_tensors.path_target_token_strings
+
+ def from_model_input_form(self, input_row) -> ReaderInputTensors:
+ return ReaderInputTensors(
+ target_string=input_row[0],
+ path_source_token_indices=input_row[1],
+ path_indices=input_row[2],
+ path_target_token_indices=input_row[3],
+ context_valid_mask=input_row[4],
+ path_source_token_strings=input_row[5],
+ path_strings=input_row[6],
+ path_target_token_strings=input_row[7]
+ )
diff --git a/train.sh b/train.sh
index 3cc567b..00205d1 100644
--- a/train.sh
+++ b/train.sh
@@ -1,3 +1,4 @@
+#!/usr/bin/env bash
###########################################################
# Change the following values to train a new model.
# type: the name of the new model, only affects the saved file name.
@@ -14,4 +15,4 @@ model_dir=models/${type}
mkdir -p models/${model_dir}
set -e
-python3 -u code2vec.py --data ${data} --test ${test_data} --save ${model_dir}/saved_model
\ No newline at end of file
+python3 -u code2vec.py --data ${data} --test ${test_data} --save ${model_dir}/saved_model
diff --git a/vocabularies.py b/vocabularies.py
new file mode 100644
index 0000000..0a86365
--- /dev/null
+++ b/vocabularies.py
@@ -0,0 +1,238 @@
+from itertools import chain
+from typing import Optional, Dict, Iterable, Set, NamedTuple
+import pickle
+import os
+from enum import Enum
+from config import Config
+import tensorflow as tf
+from argparse import Namespace
+
+from common import common
+
+
+class VocabType(Enum):
+ Token = 1
+ Target = 2
+ Path = 3
+
+
+SpecialVocabWordsType = Namespace
+
+
+_SpecialVocabWords_OnlyOov = Namespace(
+ OOV='<OOV>'
+)
+
+_SpecialVocabWords_SeparateOovPad = Namespace(
+ PAD='<PAD>',
+ OOV='<OOV>'
+)
+
+_SpecialVocabWords_JoinedOovPad = Namespace(
+ PAD_OR_OOV='<PAD_OR_OOV>',
+ PAD='<PAD_OR_OOV>',
+ OOV='<PAD_OR_OOV>'
+)
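+
+# In the joined setting, PAD_OR_OOV, PAD and OOV all alias the same token (and
+# hence the same index); with `config.SEPARATE_OOV_AND_PAD` enabled, PAD and OOV
+# receive distinct tokens and therefore distinct indices.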
+
+
+class Vocab:
+ def __init__(self, vocab_type: VocabType, words: Iterable[str],
+ special_words: Optional[SpecialVocabWordsType] = None):
+ if special_words is None:
+ special_words = Namespace()
+
+ self.vocab_type = vocab_type
+ self.word_to_index: Dict[str, int] = {}
+ self.index_to_word: Dict[int, str] = {}
+ self._word_to_index_lookup_table = None
+ self._index_to_word_lookup_table = None
+ self.special_words: SpecialVocabWordsType = special_words
+
+ for index, word in enumerate(chain(common.get_unique_list(special_words.__dict__.values()), words)):
+ self.word_to_index[word] = index
+ self.index_to_word[index] = word
+
+ self.size = len(self.word_to_index)
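+ # Special words are enumerated first, so they always occupy the lowest
+ # indices (e.g. PAD gets 0 when present); the reader's padding checks rely
+ # on this ordering.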
+
+ def save_to_file(self, file):
+ # Note: for historical reasons, a saved vocabulary doesn't include the special words.
+ special_words_as_unique_list = common.get_unique_list(self.special_words.__dict__.values())
+ nr_special_words = len(special_words_as_unique_list)
+ word_to_index_wo_specials = {word: idx for word, idx in self.word_to_index.items() if idx >= nr_special_words}
+ index_to_word_wo_specials = {idx: word for idx, word in self.index_to_word.items() if idx >= nr_special_words}
+ size_wo_specials = self.size - nr_special_words
+ pickle.dump(word_to_index_wo_specials, file)
+ pickle.dump(index_to_word_wo_specials, file)
+ pickle.dump(size_wo_specials, file)
+
+ @classmethod
+ def load_from_file(cls, vocab_type: VocabType, file, special_words: SpecialVocabWordsType) -> 'Vocab':
+ special_words_as_unique_list = common.get_unique_list(special_words.__dict__.values())
+
+ # Note: for historical reasons, a saved vocabulary doesn't include the special
+ # words, so they must be re-added upon loading.
+
+ word_to_index_wo_specials = pickle.load(file)
+ index_to_word_wo_specials = pickle.load(file)
+ size_wo_specials = pickle.load(file)
+ assert len(index_to_word_wo_specials) == len(word_to_index_wo_specials) == size_wo_specials
+ min_word_idx_wo_specials = min(index_to_word_wo_specials.keys())
+
+ if min_word_idx_wo_specials != len(special_words_as_unique_list):
+ raise ValueError(
+ "Error while attempting to load vocabulary `{vocab_type}` from file `{file_path}`. "
+ "The stored vocabulary has minimum word index {min_word_idx}, "
+ "while expecting minimum word index to be {nr_special_words} "
+ "because having to use {nr_special_words} special words, which are: {special_words}. "
+ "Please check the parameter `config.SEPARATE_OOV_AND_PAD`.".format(
+ vocab_type=vocab_type, file_path=file.name, min_word_idx=min_word_idx_wo_specials,
+ nr_special_words=len(special_words_as_unique_list), special_words=special_words))
+
+ vocab = cls(vocab_type, [], special_words)
+ vocab.word_to_index = {**word_to_index_wo_specials,
+ **{word: idx for idx, word in enumerate(special_words_as_unique_list)}}
+ vocab.index_to_word = {**index_to_word_wo_specials,
+ **{idx: word for idx, word in enumerate(special_words_as_unique_list)}}
+ vocab.size = size_wo_specials + len(special_words_as_unique_list)
+ return vocab
+
+ @classmethod
+ def create_from_freq_dict(cls, vocab_type: VocabType, word_to_count: Dict[str, int], max_size: int,
+ special_words: Optional[SpecialVocabWordsType] = None):
+ if special_words is None:
+ special_words = Namespace()
+ words_sorted_by_counts = sorted(word_to_count, key=word_to_count.get, reverse=True)
+ words_sorted_by_counts_and_limited = words_sorted_by_counts[:max_size]
+ return cls(vocab_type, words_sorted_by_counts_and_limited, special_words)
+
+ @staticmethod
+ def _create_word_to_index_lookup_table(word_to_index: Dict[str, int], default_value: int):
+ return tf.lookup.StaticHashTable(
+ tf.lookup.KeyValueTensorInitializer(
+ list(word_to_index.keys()), list(word_to_index.values()), key_dtype=tf.string, value_dtype=tf.int32),
+ default_value=tf.constant(default_value, dtype=tf.int32))
+
+ @staticmethod
+ def _create_index_to_word_lookup_table(index_to_word: Dict[int, str], default_value: str) \
+ -> tf.lookup.StaticHashTable:
+ return tf.lookup.StaticHashTable(
+ tf.lookup.KeyValueTensorInitializer(
+ list(index_to_word.keys()), list(index_to_word.values()), key_dtype=tf.int32, value_dtype=tf.string),
+ default_value=tf.constant(default_value, dtype=tf.string))
+
+ def get_word_to_index_lookup_table(self) -> tf.lookup.StaticHashTable:
+ if self._word_to_index_lookup_table is None:
+ self._word_to_index_lookup_table = self._create_word_to_index_lookup_table(
+ self.word_to_index, default_value=self.word_to_index[self.special_words.OOV])
+ return self._word_to_index_lookup_table
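+ # Lookup semantics sketch: with word_to_index {'<PAD>': 0, '<OOV>': 1, 'foo': 2},
+ # looking up the tensor ['foo', 'bar'] yields [2, 1] -- any word missing from
+ # the table falls back to the OOV index via default_value.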
+
+ def get_index_to_word_lookup_table(self) -> tf.lookup.StaticHashTable:
+ if self._index_to_word_lookup_table is None:
+ self._index_to_word_lookup_table = self._create_index_to_word_lookup_table(
+ self.index_to_word, default_value=self.special_words.OOV)
+ return self._index_to_word_lookup_table
+
+ def lookup_index(self, word: tf.Tensor) -> tf.Tensor:
+ return self.get_word_to_index_lookup_table().lookup(word)
+
+ def lookup_word(self, index: tf.Tensor) -> tf.Tensor:
+ return self.get_index_to_word_lookup_table().lookup(index)
+
+
+WordFreqDictType = Dict[str, int]
+
+
+class Code2VecWordFreqDicts(NamedTuple):
+ token_to_count: WordFreqDictType
+ path_to_count: WordFreqDictType
+ target_to_count: WordFreqDictType
+
+
+class Code2VecVocabs:
+ def __init__(self, config: Config):
+ self.config = config
+ self.token_vocab: Optional[Vocab] = None
+ self.path_vocab: Optional[Vocab] = None
+ self.target_vocab: Optional[Vocab] = None
+
+ # Used to avoid re-saving a non-modified vocabulary to a path it is already saved in.
+ self._already_saved_in_paths: Set[str] = set()
+
+ self._load_or_create()
+
+ def _load_or_create(self):
+ vocabularies_load_path = None
+ if not self.config.is_training or self.config.is_loading:
+ vocabularies_load_path = self.config.get_vocabularies_path_from_model_path(self.config.MODEL_LOAD_PATH)
+ if not os.path.isfile(vocabularies_load_path):
+ vocabularies_load_path = None
+ if vocabularies_load_path is None:
+ self._create_from_word_freq_dict()
+ else:
+ self._load_from_path(vocabularies_load_path)
+
+ def _load_from_path(self, vocabularies_load_path: str):
+ assert os.path.exists(vocabularies_load_path)
+ self.config.log('Loading model vocabularies from: `%s` ... ' % vocabularies_load_path)
+ with open(vocabularies_load_path, 'rb') as file:
+ self.token_vocab = Vocab.load_from_file(
+ VocabType.Token, file, self._get_special_words_by_vocab_type(VocabType.Token))
+ self.target_vocab = Vocab.load_from_file(
+ VocabType.Target, file, self._get_special_words_by_vocab_type(VocabType.Target))
+ self.path_vocab = Vocab.load_from_file(
+ VocabType.Path, file, self._get_special_words_by_vocab_type(VocabType.Path))
+ self.config.log('Done loading model vocabularies.')
+ self._already_saved_in_paths.add(vocabularies_load_path)
+
+ def _create_from_word_freq_dict(self):
+ word_freq_dict = self._load_word_freq_dict()
+ self.config.log('Word-frequency dictionaries loaded. Now creating vocabularies.')
+ self.token_vocab = Vocab.create_from_freq_dict(
+ VocabType.Token, word_freq_dict.token_to_count, self.config.MAX_TOKEN_VOCAB_SIZE,
+ special_words=self._get_special_words_by_vocab_type(VocabType.Token))
+ self.config.log('Created token vocab. size: %d' % self.token_vocab.size)
+ self.path_vocab = Vocab.create_from_freq_dict(
+ VocabType.Path, word_freq_dict.path_to_count, self.config.MAX_PATH_VOCAB_SIZE,
+ special_words=self._get_special_words_by_vocab_type(VocabType.Path))
+ self.config.log('Created path vocab. size: %d' % self.path_vocab.size)
+ self.target_vocab = Vocab.create_from_freq_dict(
+ VocabType.Target, word_freq_dict.target_to_count, self.config.MAX_TARGET_VOCAB_SIZE,
+ special_words=self._get_special_words_by_vocab_type(VocabType.Target))
+ self.config.log('Created target vocab. size: %d' % self.target_vocab.size)
+
+ def _get_special_words_by_vocab_type(self, vocab_type: VocabType) -> SpecialVocabWordsType:
+ if not self.config.SEPARATE_OOV_AND_PAD:
+ return _SpecialVocabWords_JoinedOovPad
+ if vocab_type == VocabType.Target:
+ return _SpecialVocabWords_OnlyOov
+ return _SpecialVocabWords_SeparateOovPad
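+ # The target vocabulary only ever needs OOV: targets are single labels rather
+ # than padded sequences, so no PAD word is required for them.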
+
+ def save(self, vocabularies_save_path: str):
+ if vocabularies_save_path in self._already_saved_in_paths:
+ return
+ with open(vocabularies_save_path, 'wb') as file:
+ self.token_vocab.save_to_file(file)
+ self.target_vocab.save_to_file(file)
+ self.path_vocab.save_to_file(file)
+ self._already_saved_in_paths.add(vocabularies_save_path)
+
+ def _load_word_freq_dict(self) -> Code2VecWordFreqDicts:
+ self.config.log('Loading word-frequency dictionaries from: %s ... ' % self.config.word_freq_dict_path)
+ with open(self.config.word_freq_dict_path, 'rb') as file:
+ token_to_count = pickle.load(file)
+ path_to_count = pickle.load(file)
+ target_to_count = pickle.load(file)
+ self.config.log('Done loading word-frequency dictionaries.')
+ # assert all(isinstance(item, WordFreqDictType) for item in {token_to_count, path_to_count, target_to_count})
+ return Code2VecWordFreqDicts(
+ token_to_count=token_to_count, path_to_count=path_to_count, target_to_count=target_to_count)
+
+ def get(self, vocab_type: VocabType) -> Vocab:
+ if not isinstance(vocab_type, VocabType):
+ raise ValueError('`vocab_type` should be `VocabType.Token`, `VocabType.Target` or `VocabType.Path`.')
+ if vocab_type == VocabType.Token:
+ return self.token_vocab
+ if vocab_type == VocabType.Target:
+ return self.target_vocab
+ if vocab_type == VocabType.Path:
+ return self.path_vocab