Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 10 additions & 11 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,9 @@ Listed models for word embedding, and checked it already implemented.
- [x] Word2Vec
- Distributed Representations of Words and Phrases
and their Compositionality [[pdf]](https://papers.nips.cc/paper/5021-distributed-representations-of-words-and-phrases-and-their-compositionality.pdf)
- [ ] GloVe
- [x] GloVe
- GloVe: Global Vectors for Word Representation [[pdf]](http://nlp.stanford.edu/pubs/glove.pdf)
- [ ] SPPMI-SVD
- Neural Word Embedding as Implicit Matrix Factorization [[pdf]](https://papers.nips.cc/paper/5477-neural-word-embedding-as-implicit-matrix-factorization.pdf)
- and more...

## Installation

Expand All @@ -46,6 +45,7 @@ Usage:

Available Commands:
distance Estimate the distance between words
glove Embed words using glove
help Help about any command
word2vec Embed words using word2vec

Expand All @@ -57,6 +57,7 @@ For more information about each sub-command, see below:
- [distance](./distance/README.md)
- [word2vec](./model/README.md)
- For code-based usage, refer to the [example](./example/example.go).
- [glove](./model/README.md)

## Demo

Expand All @@ -66,14 +67,12 @@ Downloading [text8](http://mattmahoney.net/dc/textdata) corpus, and training by
$ sh demo.sh
```

## File I/O
- Input
- Ideally, the input text is composed of one sentence per line.
- Output
- The output file has the following format:
```
<word> <value1> <value2> ...
```
## Output
The output file has the following format:

```
<word> <value1> <value2> ...
```

## References
- See the following for a deeper understanding:
Expand Down
170 changes: 170 additions & 0 deletions builder/glove.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,170 @@
// Copyright © 2017 Makoto Ito
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package builder

import (
"github.com/pkg/errors"
"github.com/spf13/viper"

"github.com/ynqa/word-embedding/config"
"github.com/ynqa/word-embedding/model"
"github.com/ynqa/word-embedding/model/glove"
)

// GloveBuilder manages the members to build the Model interface.
// It holds every option that Build forwards to model.NewConfig and
// glove.NewGlove; the Set* methods mutate these fields and return the
// receiver so that calls can be chained.
type GloveBuilder struct {
// dimension is the dimension of word vector (forwarded to model.NewConfig).
dimension int
// window is the context window size (forwarded to model.NewConfig).
window int
// initLearningRate is the initial learning rate (forwarded to model.NewConfig).
initLearningRate float64
// thread is the number of goroutines (forwarded to model.NewConfig).
thread int
// toLower reports whether the words in corpus are converted to lowercase.
toLower bool
// verbose reports whether verbose mode is enabled.
verbose bool

// solver is the solver name; Build accepts only "sgd" or "adagrad".
solver string
// iteration is the number of iterations (forwarded to glove.NewGlove).
iteration int
// alpha is forwarded to glove.NewGlove.
// NOTE(review): presumably the exponent of the GloVe weighting function — confirm in model/glove.
alpha float64
// xmax is forwarded to glove.NewGlove.
// NOTE(review): presumably the co-occurrence cutoff of the GloVe weighting function — confirm in model/glove.
xmax int
// minCount is forwarded to glove.NewGlove.
minCount int
// batchSize is forwarded to glove.NewGlove.
batchSize int
}

// NewGloveBuilder returns a *GloveBuilder whose fields are all initialized
// from the Default* values declared in the config package.
func NewGloveBuilder() *GloveBuilder {
	gb := new(GloveBuilder)

	// Options that Build forwards to model.NewConfig.
	gb.dimension = config.DefaultDimension
	gb.window = config.DefaultWindow
	gb.initLearningRate = config.DefaultInitLearningRate
	gb.thread = config.DefaultThread
	gb.toLower = config.DefaultToLower
	gb.verbose = config.DefaultVerbose

	// Options that Build uses to select the solver and create the GloVe model.
	gb.solver = config.DefaultSolver
	gb.iteration = config.DefaultIteration
	gb.alpha = config.DefaultAlpha
	gb.xmax = config.DefaultXmax
	gb.minCount = config.DefaultMinCount
	gb.batchSize = config.DefaultBatchSize

	return gb
}

// NewGloveBuilderViper returns a *GloveBuilder whose fields are read from
// viper, using the string form of the corresponding config keys.
func NewGloveBuilderViper() *GloveBuilder {
	var gb GloveBuilder

	// Options that Build forwards to model.NewConfig.
	gb.dimension = viper.GetInt(config.Dimension.String())
	gb.window = viper.GetInt(config.Window.String())
	gb.initLearningRate = viper.GetFloat64(config.InitLearningRate.String())
	gb.thread = viper.GetInt(config.Thread.String())
	gb.toLower = viper.GetBool(config.ToLower.String())
	gb.verbose = viper.GetBool(config.Verbose.String())

	// Options that Build uses to select the solver and create the GloVe model.
	gb.solver = viper.GetString(config.Solver.String())
	gb.iteration = viper.GetInt(config.Iteration.String())
	gb.alpha = viper.GetFloat64(config.Alpha.String())
	gb.xmax = viper.GetInt(config.Xmax.String())
	gb.minCount = viper.GetInt(config.MinCount.String())
	gb.batchSize = viper.GetInt(config.BatchSize.String())

	return &gb
}

// SetDimension sets the dimension of word vector.
// It returns the receiver so that setter calls can be chained.
func (gb *GloveBuilder) SetDimension(dimension int) *GloveBuilder {
gb.dimension = dimension
return gb
}

// SetWindow sets the context window size.
// It returns the receiver so that setter calls can be chained.
func (gb *GloveBuilder) SetWindow(window int) *GloveBuilder {
gb.window = window
return gb
}

// SetInitLearningRate sets the initial learning rate.
// It returns the receiver so that setter calls can be chained.
func (gb *GloveBuilder) SetInitLearningRate(initlr float64) *GloveBuilder {
gb.initLearningRate = initlr
return gb
}

// SetThread sets the number of goroutines.
// It returns the receiver so that setter calls can be chained.
func (gb *GloveBuilder) SetThread(thread int) *GloveBuilder {
gb.thread = thread
return gb
}

// SetToLower enables converting the words in corpus to lowercase.
// It only switches the option on; this method cannot switch it off again.
// It returns the receiver so that setter calls can be chained.
func (gb *GloveBuilder) SetToLower() *GloveBuilder {
gb.toLower = true
return gb
}

// SetVerbose enables verbose mode.
// It only switches the option on; this method cannot switch it off again.
// It returns the receiver so that setter calls can be chained.
func (gb *GloveBuilder) SetVerbose() *GloveBuilder {
gb.verbose = true
return gb
}

// SetSolver sets the solver name.
// The value is not validated here; Build rejects anything other than
// "sgd" or "adagrad".
// It returns the receiver so that setter calls can be chained.
func (gb *GloveBuilder) SetSolver(solver string) *GloveBuilder {
gb.solver = solver
return gb
}

// SetIteration sets the number of iterations.
// It returns the receiver so that setter calls can be chained.
func (gb *GloveBuilder) SetIteration(iter int) *GloveBuilder {
gb.iteration = iter
return gb
}

// SetAlpha sets alpha, which Build forwards to glove.NewGlove.
// It returns the receiver so that setter calls can be chained.
func (gb *GloveBuilder) SetAlpha(alpha float64) *GloveBuilder {
gb.alpha = alpha
return gb
}

// SetXmax sets x-max, which Build forwards to glove.NewGlove.
// It returns the receiver so that setter calls can be chained.
func (gb *GloveBuilder) SetXmax(xmax int) *GloveBuilder {
gb.xmax = xmax
return gb
}

// SetMinCount sets min count, which Build forwards to glove.NewGlove.
// It returns the receiver so that setter calls can be chained.
func (gb *GloveBuilder) SetMinCount(minCount int) *GloveBuilder {
gb.minCount = minCount
return gb
}

// SetBatchSize sets the batch size, which Build forwards to glove.NewGlove.
// It returns the receiver so that setter calls can be chained.
func (gb *GloveBuilder) SetBatchSize(batchSize int) *GloveBuilder {
gb.batchSize = batchSize
return gb
}

// Build assembles a GloVe model from the builder's current settings.
//
// The common options are packed into a model config first, then the solver
// named by gb.solver is created from that config. Only "sgd" and "adagrad"
// are accepted; any other name yields a nil model and an error.
func (gb *GloveBuilder) Build() (model.Model, error) {
	conf := model.NewConfig(
		gb.dimension,
		gb.window,
		gb.initLearningRate,
		gb.thread,
		gb.toLower,
		gb.verbose,
	)

	var optimizer glove.Solver
	switch name := gb.solver; name {
	case "adagrad":
		optimizer = glove.NewAdaGrad(conf)
	case "sgd":
		optimizer = glove.NewSGD(conf)
	default:
		return nil, errors.Errorf("Invalid solver: %s not in sgd|adagrad", name)
	}

	return glove.NewGlove(
		conf,
		optimizer,
		gb.iteration,
		gb.xmax,
		gb.alpha,
		gb.minCount,
		gb.batchSize,
	), nil
}
127 changes: 127 additions & 0 deletions builder/glove_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
// Copyright © 2017 Makoto Ito
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package builder

import (
"testing"
)

// TestGloveSetDimension verifies that SetDimension stores the given
// dimension on the builder.
func TestGloveSetDimension(t *testing.T) {
	b := &GloveBuilder{}
	b.SetDimension(100)

	if b.dimension != 100 {
		// Report the field under test; the message previously named alpha
		// and printed b.alpha, which hid the actual failing value.
		t.Errorf("Expected builder.dimension=100: %v", b.dimension)
	}
}

// TestGloveSetWindow verifies that SetWindow stores the given window size
// on the builder.
func TestGloveSetWindow(t *testing.T) {
	const want = 10

	var gb GloveBuilder
	gb.SetWindow(want)

	if got := gb.window; got != want {
		t.Errorf("Expected builder.window=10: %v", got)
	}
}

// TestGloveSetInitLearningRate verifies that SetInitLearningRate stores the
// given initial learning rate on the builder.
func TestGloveSetInitLearningRate(t *testing.T) {
	const want = 0.001

	var gb GloveBuilder
	gb.SetInitLearningRate(want)

	if got := gb.initLearningRate; got != want {
		t.Errorf("Expected builder.initLearningRate=0.001: %v", got)
	}
}

// TestGloveSetToLower verifies that SetToLower switches the toLower option
// on.
func TestGloveSetToLower(t *testing.T) {
	b := &GloveBuilder{}
	b.SetToLower()

	if !b.toLower {
		// Name the real field (toLower); the message previously referred to
		// a non-existent "lower" field.
		t.Errorf("Expected builder.toLower=true: %v", b.toLower)
	}
}

// TestGloveSetVerbose verifies that SetVerbose switches verbose mode on.
func TestGloveSetVerbose(t *testing.T) {
	var gb GloveBuilder
	gb.SetVerbose()

	if gb.verbose {
		return
	}
	t.Errorf("Expected builder.verbose=true: %v", gb.verbose)
}

// TestGloveSetSolver verifies that SetSolver stores the given solver name
// on the builder.
func TestGloveSetSolver(t *testing.T) {
	const want = "adagrad"

	var gb GloveBuilder
	gb.SetSolver(want)

	if got := gb.solver; got != want {
		t.Errorf("Expected builder.solver=adagrad: %v", got)
	}
}

// TestGloveSetIteration verifies that SetIteration stores the given
// iteration count on the builder.
func TestGloveSetIteration(t *testing.T) {
	const want = 50

	var gb GloveBuilder
	gb.SetIteration(want)

	if got := gb.iteration; got != want {
		t.Errorf("Expected builder.iteration=50: %v", got)
	}
}

// TestGloveSetAlpha verifies that SetAlpha stores the given alpha on the
// builder.
func TestGloveSetAlpha(t *testing.T) {
	const want = 0.1

	var gb GloveBuilder
	gb.SetAlpha(want)

	if got := gb.alpha; got != want {
		t.Errorf("Expected builder.alpha=0.1: %v", got)
	}
}

// TestGloveSetXmax verifies that SetXmax stores the given x-max on the
// builder.
func TestGloveSetXmax(t *testing.T) {
	b := &GloveBuilder{}
	b.SetXmax(10)

	if b.xmax != 10 {
		// Report the field under test; the message previously named alpha.
		t.Errorf("Expected builder.xmax=10: %v", b.xmax)
	}
}

// TestGloveSetMinCount verifies that SetMinCount stores the given min count
// on the builder.
func TestGloveSetMinCount(t *testing.T) {
	const want = 10

	var gb GloveBuilder
	gb.SetMinCount(want)

	if got := gb.minCount; got != want {
		t.Errorf("Expected builder.minCount=10: %v", got)
	}
}

// TestGloveSetBatchSize verifies that SetBatchSize stores the given batch
// size on the builder.
func TestGloveSetBatchSize(t *testing.T) {
	const want = 2048

	var gb GloveBuilder
	gb.SetBatchSize(want)

	if got := gb.batchSize; got != want {
		t.Errorf("Expected builder.batchSize=2048: %v", got)
	}
}

// TestGloveInvalidSolverBuild verifies that Build returns an error when the
// solver name is neither "sgd" nor "adagrad".
func TestGloveInvalidSolverBuild(t *testing.T) {
	var gb GloveBuilder
	gb.SetSolver("fake_solver")

	_, err := gb.Build()
	if err != nil {
		// Build rejected the unknown solver, as required.
		return
	}
	t.Errorf("Expected to fail building with invalid solver except for sgd|adagrad: %v", gb.solver)
}
Loading