
return length of the Java embedTokens() method #2

Open
kkalouli opened this issue Jun 24, 2019 · 4 comments

@kkalouli

Hi,

first of all, well done for this great work and thanks for making it publicly available!

I have the following problem: I am using the Java version and I want to match each token to its embedding. I am loading the English uncased model and getting the embeddings of 2 strings (str1, str2) with
float[][][] embeddings = bert.embedTokens(str1, str2);
After that, I can get the embedding corresponding to each sequence/string by

float[][] firstSent = embeddings[0];
float[][] secondSent = embeddings[1];

However, firstSent and secondSent always have a fixed length of 127, not the length of my strings str1 and str2. If I then take firstSent[0], it has a length of 768, which is the expected embedding size, but I don't understand why firstSent and secondSent have length 127. Given this length, I assume firstSent[0] does NOT correspond to the first token of my first sentence, which is what I would like to get.

Any help is much appreciated! Thanks a lot!

@kkalouli
Author

OK, so I figured out the answer myself; I'm posting it here in case it helps someone else:

the array float[][] that one gets back from e.g. embeddings[0] always has size 128, because this configuration (max_seq_len) is loaded together with the pretrained model (in my case: bert-uncased-L-12-H-768-A-12). The first position of this array is always occupied by [CLS], and [SEP] is also included at the appropriate position. The remaining positions that do not correspond to the given sentence are filled with padding. See here for what this looks like: https://github.com/hanxiao/bert-as-service#getting-elmo-like-contextual-word-embedding
(check the "Getting ELMo-like contextual word embedding" section) and here: https://github.com/google-research/bert (check the "Tokenization" section). This means you cannot do a one-to-one mapping from the 128-size array to your original sentence.
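
To make that layout concrete, here is a minimal sketch (my own illustration, not part of the easy-bert API; it assumes max_seq_len = 128 and reuses the bert, str1, str2 variables from above):

float[][][] embeddings = bert.embedTokens(str1, str2);
float[][] firstSent = embeddings[0];     // always has length max_seq_len (128), not the sentence length

float[] clsVector = firstSent[0];        // position 0 always holds the [CLS] vector
float[] firstWordpiece = firstSent[1];   // position 1 is the first wordpiece of str1 (not necessarily a whole word)
// the remaining wordpieces of str1 follow, then [SEP], and every position after
// that up to index 127 is padding that does not correspond to the sentence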

There are 2 options as I see it:

  1. Change max_seq_len in the original configuration when loading the model so that it fits the exact size of your sentence ("Set it to NONE for dynamically using the longest sequence in a (mini)batch." from https://github.com/hanxiao/bert-as-service#getting-elmo-like-contextual-word-embedding). This is more straightforward in the original Python implementation. In the Java implementation, maxSequenceLength has to be an integer or a String; since it is currently an integer, one would need to overwrite its value, but I think this is easy. In the getInputs() method of the Bert class we could add something like:
maxSequenceLength = tokens.length + 2; // +2 accounts for the reserved [CLS] and [SEP] at the beginning and end of the sequence

@robrua would you consider adding this? According to https://github.com/hanxiao/bert-as-service, a smaller maxSequenceLength is more efficient (faster); see the question "How about the speed? Is it fast enough for production?".

  2. Write a piece of code that maps the positions in the BERT array back to the original tokens of the sentence. This is also proposed here: https://github.com/google-research/bert. I converted it to Java and am pasting it below in case somebody else wants to use it. It gives you a mapping from your original tokens to the positions of the corresponding tokens within the BERT array, similarly to the link above. (A usage sketch follows the method.)
public HashMap<String, Integer> matchOriginalTokens2BERTTokens(String[] originalTokens) {
	// requires java.io.File, java.util.ArrayList, java.util.HashMap and easy-bert's FullTokenizer
	ArrayList<String> bertTokens = new ArrayList<String>();
	HashMap<String, Integer> orig2TokenMap = new HashMap<String, Integer>();
	// create a wordpiece tokenizer from the model's vocabulary
	FullTokenizer tokenizer = new FullTokenizer(new File("/path/to/file/vocab.txt"), true);
	// the BERT sequence starts with [CLS]
	bertTokens.add("[CLS]");
	// go through the original tokens
	for (String origToken : originalTokens) {
		// record the position in the BERT array at which this original token starts
		orig2TokenMap.put(origToken, bertTokens.size());
		// tokenize the current original token with the wordpiece tokenizer
		String[] wordpieces = tokenizer.tokenize(origToken);
		// add each wordpiece to bertTokens so that the position counter advances correctly
		for (String wordpiece : wordpieces) {
			bertTokens.add(wordpiece);
		}
	}
	// the BERT sequence ends with [SEP]
	bertTokens.add("[SEP]");
	// note: the map is keyed by the token string, so repeated words keep only the
	// position of their last occurrence
	return orig2TokenMap;
}
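
For example, a rough usage sketch (assumptions: bert, str1, and str2 as in the snippets above, the method above in scope, and a simple whitespace split standing in for however you tokenize your original sentence; repeated words share one map entry):

String[] originalTokens = str1.split(" ");  // your own tokenization of str1
HashMap<String, Integer> orig2TokenMap = matchOriginalTokens2BERTTokens(originalTokens);

float[][][] embeddings = bert.embedTokens(str1, str2);
float[][] firstSent = embeddings[0];

// vector of the first wordpiece of the first original word of str1
int bertPosition = orig2TokenMap.get(originalTokens[0]);
float[] wordVector = firstSent[bertPosition];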

@robrua
Owner

robrua commented Jun 28, 2019

Hey, thanks for the research and the detailed issue.

For (1) that's an excellent idea to add here and I'll look into allowing dynamic max sequence length on both the Python and Java ends next time I sit down and do some work on this project.

For (2) there's an extra complication in matching the output token vectors back to the original source: BERT uses a wordpiece vocabulary, which may split a single word from your sequence into multiple subtokens before it is fed to the model. Because of this, the output size doesn't necessarily match the number of words in the input sequence (even after accounting for the [CLS] and [SEP] tokens); you'd need to inject some "tracking" logic into the tokenizer to record which words get subdivided during tokenization. There's no reason this wouldn't work, and I think providing a way to match the output vectors to each word in the input sequence would be useful, so I'll also take a look at this in the future.

@kkalouli
Author

Hi Rob!

Thanks for considering adding (1) to the code.

About (2): you are right that this is not straightforward because of the special wordpiece tokenizer that BERT uses, but the code linked in my post (https://github.com/google-research/bert), which I converted to Java, takes this into account. It uses the same wordpiece tokenizer to tokenize the words and keeps track of how each word is split: e.g. the verb "faked" is tokenized as "fake" + "d", and each of these two subtokens gets its own vector. The code above keeps track of the "real" word, in this case "fake", and gives you back the vector of "fake" rather than "d" as the representation of "faked". In other words, the code always tracks the first subtoken of each word, which is also its base form.

Thanks!

@robrua
Owner

robrua commented Jul 11, 2019

Reopening this to remind myself to add this in the future.

On (2): Right then, I hadn't read it closely enough and failed to notice you were tracking the start indices for each token. I'll probably end up including something very similar to this, just integrated into the tokenizer itself to avoid needing to run it twice on each sequence.
