Skip to content

Commit

Permalink
Embedding similarity
Browse files Browse the repository at this point in the history
  • Loading branch information
sivareddyg committed Dec 20, 2016
1 parent 5026252 commit 9eb979c
Show file tree
Hide file tree
Showing 13 changed files with 296 additions and 146 deletions.
Expand Up @@ -88,11 +88,11 @@ public CcgParseToUngroundedGraphs(String dataFolder, String languageCode,
GroundedLexicon groundedLexicon = new GroundedLexicon(null);
graphCreator = new GroundedGraphs(schema, kb, groundedLexicon,
normalCcgAutoLexicon, questionCcgAutoLexicon,
relationLexicalIdentifiers, relationTypingIdentifiers, null, 1, false,
relationLexicalIdentifiers, relationTypingIdentifiers, null, null, 1,
false, false, false, false, false, false, false, false, false, false,
false, false, false, false, false, false, false, false, false, false,
false, false, false, false, false, false, false, false, false, false,
false, false, false, 10.0, 1.0, 0.0, 0.0, 0.0);
false, false, false, false, false, 10.0, 1.0, 0.0, 0.0, 0.0);
}

public List<List<LexicalGraph>> processText(String line)
Expand Down
61 changes: 36 additions & 25 deletions src/in/sivareddy/graphparser/cli/RunGraphToQueryTrainingMain.java
Expand Up @@ -51,6 +51,7 @@ public class RunGraphToQueryTrainingMain extends AbstractCli {
// Log File
private OptionSpec<String> logFile;
private OptionSpec<String> loadModelFromFile;
private OptionSpec<String> embeddingFile;
private OptionSpec<String> lexicon;
private OptionSpec<String> cachedKB;
private OptionSpec<String> testFile;
Expand Down Expand Up @@ -119,6 +120,7 @@ public class RunGraphToQueryTrainingMain extends AbstractCli {
private OptionSpec<Boolean> argumentStemMatchingFlag;
private OptionSpec<Boolean> argumentStemGrelPartMatchingFlag;
private OptionSpec<Boolean> ngramStemMatchingFlag;
private OptionSpec<Boolean> useEmbeddingSimilarityFlag;

// Graph features
private OptionSpec<Boolean> graphIsConnectedFlag;
Expand Down Expand Up @@ -278,6 +280,12 @@ public void initializeOptions(OptionParser parser) {
.accepts("loadModelFromFile",
"Load model from serialized model file").withRequiredArg()
.ofType(String.class).defaultsTo("");

embeddingFile =
parser
.accepts("embeddingFile",
"Load word embeddings from file").withRequiredArg()
.ofType(String.class).defaultsTo("");

lexicon =
parser.accepts("lexicon", "lexicon containing nl to grounded mappings")
Expand Down Expand Up @@ -476,6 +484,10 @@ public void initializeOptions(OptionParser parser) {
.accepts("ngramStemMatchingFlag",
"use stem overlaps between words in the sentence and grounded edges")
.withRequiredArg().ofType(Boolean.class).defaultsTo(false);
useEmbeddingSimilarityFlag = parser
.accepts("useEmbeddingSimilarityFlag",
"use embedding similarity between grounded edges and words or ungronded edges")
.withRequiredArg().ofType(Boolean.class).defaultsTo(false);

// Graph features
graphIsConnectedFlag =
Expand Down Expand Up @@ -671,21 +683,18 @@ public void run(OptionSet options) {
KnowledgeBase kb = null;

if (!options.valueOf(cachedKB).equals("")) {
kb =
new KnowledgeBaseCached(options.valueOf(cachedKB),
relationTypesFileName);
kb = new KnowledgeBaseCached(options.valueOf(cachedKB),
relationTypesFileName);
} else {
KnowledgeBaseOnline.TYPE_KEY = options.valueOf(typeKey);
kb =
new KnowledgeBaseOnline(options.valueOf(endpoint), String.format(
"http://%s:8890/sparql", options.valueOf(endpoint)), "dba",
"dba", 50000, schemaObj);
kb = new KnowledgeBaseOnline(options.valueOf(endpoint),
String.format("http://%s:8890/sparql", options.valueOf(endpoint)),
"dba", "dba", 50000, schemaObj);
}

RdfGraphTools rdfGraphTools =
new RdfGraphTools(options.valueOf(endpoint), String.format(
"http://%s:8890/sparql", options.valueOf(endpoint)), "dba",
"dba", options.valueOf(timeout));
RdfGraphTools rdfGraphTools = new RdfGraphTools(options.valueOf(endpoint),
String.format("http://%s:8890/sparql", options.valueOf(endpoint)),
"dba", "dba", options.valueOf(timeout));
GraphToSparqlConverter.TYPE_KEY = options.valueOf(typeKey);
GroundedGraphs.CONTENT_WORD_POS =
Sets.newHashSet(Splitter.on(";").trimResults().omitEmptyStrings()
Expand All @@ -694,15 +703,13 @@ public void run(OptionSet options) {
List<String> kbGraphUri =
Lists.newArrayList(Splitter.on(";").split(options.valueOf(domain)));

CcgAutoLexicon normalCcgAutoLexicon =
new CcgAutoLexicon(options.valueOf(ccgIndexedMapping),
options.valueOf(unaryRules), options.valueOf(binaryRules),
options.valueOf(ccgLexicon));
CcgAutoLexicon normalCcgAutoLexicon = new CcgAutoLexicon(
options.valueOf(ccgIndexedMapping), options.valueOf(unaryRules),
options.valueOf(binaryRules), options.valueOf(ccgLexicon));

CcgAutoLexicon questionCcgAutoLexicon =
new CcgAutoLexicon(options.valueOf(ccgIndexedMapping),
options.valueOf(unaryRules), options.valueOf(binaryRules),
options.valueOf(ccgLexiconQuestions));
CcgAutoLexicon questionCcgAutoLexicon = new CcgAutoLexicon(
options.valueOf(ccgIndexedMapping), options.valueOf(unaryRules),
options.valueOf(binaryRules), options.valueOf(ccgLexiconQuestions));

GroundedLexicon groundedLexicon =
new GroundedLexicon(options.valueOf(lexicon));
Expand All @@ -718,6 +725,7 @@ public void run(OptionSet options) {

String logfile = options.valueOf(logFile);
String loadModelFromFileVal = options.valueOf(loadModelFromFile);
String embeddingFileVal = options.valueOf(embeddingFile);
boolean debugEnabled = options.valueOf(debugEnabledFlag);

int threadCount = options.valueOf(nthreads);
Expand Down Expand Up @@ -771,6 +779,8 @@ public void run(OptionSet options) {
boolean argumentStemGrelPartMatchingFlagVal =
options.valueOf(argumentStemGrelPartMatchingFlag);
boolean ngramStemMatchingFlagVal = options.valueOf(ngramStemMatchingFlag);
boolean useEmbeddingSimilarityFlagVal =
options.valueOf(useEmbeddingSimilarityFlag);

// Graph features
boolean graphIsConnectedFlagVal = options.valueOf(graphIsConnectedFlag);
Expand Down Expand Up @@ -831,16 +841,16 @@ public void run(OptionSet options) {
options.valueOf(groundTrainingCorpusInTheEnd);

// Set pointWiseF1Threshold for learning. IMPORTANT.
GraphToQueryTraining.setPointWiseF1Threshold(options
.valueOf(pointWiseF1Threshold));
GraphToQueryTraining
.setPointWiseF1Threshold(options.valueOf(pointWiseF1Threshold));

GraphToQueryTrainingMain graphToQueryModel = new GraphToQueryTrainingMain(
schemaObj, kb, groundedLexicon, normalCcgAutoLexicon,
questionCcgAutoLexicon, rdfGraphTools, kbGraphUri, testfile, devfile,
supervisedTrainingFile, corupusTrainingFile, groundInputCorporaFiles,
semanticParseKeyString, goldParsesFileVal, mostFrequentTypesFileVal,
debugEnabled, groundTrainingCorpusInTheEndVal,
trainingSampleSizeCount, logfile, loadModelFromFileVal,
trainingSampleSizeCount, logfile, loadModelFromFileVal, embeddingFileVal,
nBestTrainSyntacticParsesVal, nBestTestSyntacticParsesVal,
nbestEdgesVal, nbestGraphsVal, forestSizeVal, ngramLengthVal,
useSchemaVal, useKBVal, groundFreeVariablesVal,
Expand All @@ -852,9 +862,10 @@ public void run(OptionSet options) {
questionTypeGrelPartFlagVal, stemMatchingFlagVal,
mediatorStemGrelPartMatchingFlagVal, argumentStemMatchingFlagVal,
argumentStemGrelPartMatchingFlagVal, ngramStemMatchingFlagVal,
graphIsConnectedFlagVal, graphHasEdgeFlagVal, countNodesFlagVal,
edgeNodeCountFlagVal, duplicateEdgesFlagVal, grelGrelFlagVal,
useLexiconWeightsRelVal, useLexiconWeightsTypeVal, validQueryFlagVal,
useEmbeddingSimilarityFlagVal, graphIsConnectedFlagVal,
graphHasEdgeFlagVal, countNodesFlagVal, edgeNodeCountFlagVal,
duplicateEdgesFlagVal, grelGrelFlagVal, useLexiconWeightsRelVal,
useLexiconWeightsTypeVal, validQueryFlagVal,
useAnswerTypeQuestionWordFlagVal, useNbestGraphsVal,
addBagOfWordsGraphVal, addOnlyBagOfWordsGraphVal,
handleNumbersFlagVal, entityScoreFlagVal, entityWordOverlapFlagVal,
Expand Down
15 changes: 7 additions & 8 deletions src/in/sivareddy/graphparser/cli/RunPrintDomainLexicon.java
Expand Up @@ -169,14 +169,13 @@ public void run(OptionSet options) {
String[] relationTypingIdentifiers = {};

GroundedLexicon groundedLexicon = new GroundedLexicon(null);
GroundedGraphs graphCreator =
new GroundedGraphs(schemaObj, kb, groundedLexicon,
normalCcgAutoLexicon, questionCcgAutoLexicon,
relationLexicalIdentifiers, relationTypingIdentifiers, null, 1,
false, false, false, false, false, false, false, false, false,
false, false, false, false, false, false, false, false, false, false,
false, false, false, false, false, false, false, false, false,
false, false, false, false, false, false, 0.0, 0.0, 0.0, 0.0, 0.0);
GroundedGraphs graphCreator = new GroundedGraphs(schemaObj, kb,
groundedLexicon, normalCcgAutoLexicon, questionCcgAutoLexicon,
relationLexicalIdentifiers, relationTypingIdentifiers, null, null, 1,
false, false, false, false, false, false, false, false, false, false,
false, false, false, false, false, false, false, false, false, false,
false, false, false, false, false, false, false, false, false, false,
false, false, false, false, false, 0.0, 0.0, 0.0, 0.0, 0.0);

CreateGroundedLexicon engine =
new CreateGroundedLexicon(graphCreator, kb, semanticParseKeyString,
Expand Down
61 changes: 33 additions & 28 deletions src/in/sivareddy/graphparser/learning/GraphToQueryTraining.java
Expand Up @@ -7,6 +7,7 @@
import in.sivareddy.graphparser.parsing.LexicalGraph;
import in.sivareddy.graphparser.parsing.LexicalGraph.AnswerTypeQuestionWordFeature;
import in.sivareddy.graphparser.parsing.LexicalGraph.ValidQueryFeature;
import in.sivareddy.graphparser.util.CrossLingualEmbeddingSimilarity;
import in.sivareddy.graphparser.util.GroundedLexicon;
import in.sivareddy.graphparser.util.RdfGraphTools;
import in.sivareddy.graphparser.util.Schema;
Expand Down Expand Up @@ -70,6 +71,7 @@ public class GraphToQueryTraining {
private static double POINTWISE_F1_THRESHOLD = 0.90;

private StructuredPercepton learningModel;
private CrossLingualEmbeddingSimilarity embeddings;
private Schema schema;
private KnowledgeBase kb;
private GroundedLexicon groundedLexicon;
Expand Down Expand Up @@ -118,7 +120,8 @@ public GraphToQueryTraining(Schema schema, KnowledgeBase kb,
boolean useSchema, boolean useKB, boolean groundFreeVariables,
boolean groundEntityVariableEdges, boolean groundEntityEntityEdges,
boolean useEmtpyTypes, boolean ignoreTypes,
StructuredPercepton learningModel, boolean urelGrelFlag,
StructuredPercepton learningModel,
CrossLingualEmbeddingSimilarity embeddings, boolean urelGrelFlag,
boolean urelPartGrelPartFlag, boolean utypeGtypeFlag,
boolean gtypeGrelFlag, boolean grelGrelFlag, boolean ngramGrelPartFlag,
boolean wordGrelPartFlag, boolean wordGrelFlag, boolean argGrelPartFlag,
Expand All @@ -127,20 +130,20 @@ public GraphToQueryTraining(Schema schema, KnowledgeBase kb,
boolean mediatorStemGrelPartMatchingFlag,
boolean argumentStemMatchingFlag,
boolean argumentStemGrelPartMatchingFlag, boolean ngramStemMatchingFlag,
boolean graphIsConnectedFlag, boolean graphHasEdgeFlag,
boolean countNodesFlag, boolean edgeNodeCountFlag,
boolean useLexiconWeightsRel, boolean useLexiconWeightsType,
boolean duplicateEdgesFlag, boolean validQueryFlag,
boolean useAnswerTypeQuestionWordFlag, boolean useNbestSurrogateGraphs,
boolean addBagOfWordsGraph, boolean addOnlyBagOfWordsGraph,
boolean handleNumbers, boolean entityScoreFlag,
boolean entityWordOverlapFlag, boolean paraphraseScoreFlag,
boolean paraphraseClassifierScoreFlag, boolean allowMerging,
boolean useGoldRelations, boolean evaluateOnlyTheFirstBest,
boolean handleEventEventEdges, boolean useExpand, boolean useHyperExpand,
double initialEdgeWeight, double initialTypeWeight,
double initialWordWeight, double mergeEdgeWeight,
double stemFeaturesWeight,
boolean useEmbeddingSimilarityFlag, boolean graphIsConnectedFlag,
boolean graphHasEdgeFlag, boolean countNodesFlag,
boolean edgeNodeCountFlag, boolean useLexiconWeightsRel,
boolean useLexiconWeightsType, boolean duplicateEdgesFlag,
boolean validQueryFlag, boolean useAnswerTypeQuestionWordFlag,
boolean useNbestSurrogateGraphs, boolean addBagOfWordsGraph,
boolean addOnlyBagOfWordsGraph, boolean handleNumbers,
boolean entityScoreFlag, boolean entityWordOverlapFlag,
boolean paraphraseScoreFlag, boolean paraphraseClassifierScoreFlag,
boolean allowMerging, boolean useGoldRelations,
boolean evaluateOnlyTheFirstBest, boolean handleEventEventEdges,
boolean useExpand, boolean useHyperExpand, double initialEdgeWeight,
double initialTypeWeight, double initialWordWeight,
double mergeEdgeWeight, double stemFeaturesWeight,
RdfGraphTools rdfGraphTools, List<String> kbGraphUri) throws IOException {
String[] relationLexicalIdentifiers = {"lemma"};
String[] relationTypingIdentifiers = {};
Expand All @@ -165,6 +168,7 @@ public GraphToQueryTraining(Schema schema, KnowledgeBase kb,
this.useAnswerTypeQuestionWordFlag = useAnswerTypeQuestionWordFlag;

this.learningModel = learningModel;
this.embeddings = embeddings;
this.schema = schema;
this.kb = kb;
this.groundedLexicon = groundedLexicon;
Expand All @@ -188,19 +192,20 @@ public GraphToQueryTraining(Schema schema, KnowledgeBase kb,
this.graphCreator = new GroundedGraphs(this.schema, this.kb,
this.groundedLexicon, normalCcgAutoLexicon, questionCcgAutoLexicon,
relationLexicalIdentifiers, relationTypingIdentifiers,
this.learningModel, ngramLength, urelGrelFlag, urelPartGrelPartFlag,
utypeGtypeFlag, gtypeGrelFlag, grelGrelFlag, ngramGrelPartFlag,
wordGrelPartFlag, wordGrelFlag, argGrelPartFlag, argGrelFlag,
questionTypeGrelPartFlag, eventTypeGrelPartFlag, stemMatchingFlag,
mediatorStemGrelPartMatchingFlag, argumentStemMatchingFlag,
argumentStemGrelPartMatchingFlag, ngramStemMatchingFlag,
graphIsConnectedFlag, graphHasEdgeFlag, countNodesFlag,
edgeNodeCountFlag, useLexiconWeightsRel, useLexiconWeightsType,
duplicateEdgesFlag, ignorePronouns, handleNumbers, entityScoreFlag,
entityWordOverlapFlag, paraphraseScoreFlag,
paraphraseClassifierScoreFlag, allowMerging, handleEventEventEdges,
useExpand, useHyperExpand, initialEdgeWeight, initialTypeWeight,
initialWordWeight, mergeEdgeWeight, stemFeaturesWeight);
this.learningModel, this.embeddings, ngramLength, urelGrelFlag,
urelPartGrelPartFlag, utypeGtypeFlag, gtypeGrelFlag, grelGrelFlag,
ngramGrelPartFlag, wordGrelPartFlag, wordGrelFlag, argGrelPartFlag,
argGrelFlag, questionTypeGrelPartFlag, eventTypeGrelPartFlag,
stemMatchingFlag, mediatorStemGrelPartMatchingFlag,
argumentStemMatchingFlag, argumentStemGrelPartMatchingFlag,
ngramStemMatchingFlag, useEmbeddingSimilarityFlag, graphIsConnectedFlag,
graphHasEdgeFlag, countNodesFlag, edgeNodeCountFlag,
useLexiconWeightsRel, useLexiconWeightsType, duplicateEdgesFlag,
ignorePronouns, handleNumbers, entityScoreFlag, entityWordOverlapFlag,
paraphraseScoreFlag, paraphraseClassifierScoreFlag, allowMerging,
handleEventEventEdges, useExpand, useHyperExpand, initialEdgeWeight,
initialTypeWeight, initialWordWeight, mergeEdgeWeight,
stemFeaturesWeight);
}


Expand Down

0 comments on commit 9eb979c

Please sign in to comment.