Skip to content

Commit

Permalink
Added Sequel plugin
Browse files Browse the repository at this point in the history
  • Loading branch information
ankane committed Jun 4, 2023
1 parent fcd2ad1 commit f1944d2
Show file tree
Hide file tree
Showing 5 changed files with 64 additions and 8 deletions.
16 changes: 11 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -61,22 +61,28 @@ DB.create_table :items do
end
```

Add the plugin to your model

```ruby
class Item < Sequel::Model
plugin :pgvector, :embedding
end
```

Insert a vector

```ruby
DB[:items].insert(embedding: Pgvector.encode([1, 1, 1]))
# or
Item.create(embedding: Pgvector.encode([1, 1, 1]))
```

Get the nearest neighbors to a vector

```ruby
DB[:items].order(Sequel.lit("embedding <-> ?", Pgvector.encode([1, 1, 1]))).limit(5)
# or
Item.order(Sequel.lit("embedding <-> ?", Pgvector.encode([1, 1, 1]))).limit(5)
Item.nearest_neighbors(:embedding, [1, 1, 1], distance: "euclidean").limit(5)
```

Also supports `inner_product` and `cosine` distance

## History

View the [changelog](https://github.com/pgvector/pgvector-ruby/blob/master/CHANGELOG.md)
Expand Down
3 changes: 2 additions & 1 deletion examples/disco_item_recs.rb
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
end

class Movie < Sequel::Model
plugin :pgvector, :factors
end

data = Disco.load_movielens
Expand All @@ -27,4 +28,4 @@ class Movie < Sequel::Model
Movie.multi_insert(movies)

movie = Movie.first(name: "Star Wars (1977)")
pp Movie.exclude(id: movie.id).order(Sequel.lit("factors <=> ?", movie.factors)).limit(5).map(&:name)
pp movie.nearest_neighbors(:factors, distance: "cosine").limit(5).map(&:name)
4 changes: 3 additions & 1 deletion examples/disco_user_recs.rb
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,11 @@
end

class Movie < Sequel::Model
plugin :pgvector, :factors
end

class User < Sequel::Model
plugin :pgvector, :factors
end

data = Disco.load_movielens
Expand All @@ -42,7 +44,7 @@ class User < Sequel::Model
User.multi_insert(users)

user = User[123]
pp Movie.order(Sequel.lit("factors <#> ?", user.factors)).limit(5).map(&:name)
pp Movie.nearest_neighbors(:factors, user.factors, distance: "inner_product").limit(5).map(&:name)

# excludes rated, so will be different for some users
# pp recommender.user_recs(user.id).map { |v| v[:item_id] }
46 changes: 46 additions & 0 deletions lib/sequel/plugins/pgvector.rb
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
module Sequel
module Plugins
module Pgvector
def self.configure(model, *columns)
model.vector_columns = columns.to_h { |c| [c.to_sym, {}] }
end

module ClassMethods
attr_accessor :vector_columns

def nearest_neighbors(column, value, distance:)
value = ::Pgvector.encode(value) unless value.is_a?(String)

operator =
case distance
when "inner_product"
"<#>"
when "cosine"
"<=>"
when "euclidean"
"<->"
end

raise ArgumentError, "Invalid distance: #{distance}" unless operator

quoted_column = dataset.quote_identifier(column)
exclude(column => nil).order(Sequel.lit("#{quoted_column} #{operator} ?", value))
end

Plugins.inherited_instance_variables(self, :@vector_columns => :dup)
end

module InstanceMethods
def nearest_neighbors(column, **options)
column = column.to_sym
# important! check if neighbor attribute before calling send
raise ArgumentError, "Invalid column" unless self.class.vector_columns[column]

self.class
.nearest_neighbors(column, self[column], **options)
.exclude(primary_key => self[primary_key])
end
end
end
end
end
3 changes: 2 additions & 1 deletion test/sequel_test.rb
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
end

class Item < Sequel::Model(DB[:sequel_items])
plugin :pgvector, :embedding
end

class TestSequel < Minitest::Test
Expand All @@ -29,7 +30,7 @@ def test_model
Item.create(embedding: Pgvector.encode([1, 1, 1]))
Item.create(embedding: Pgvector.encode([2, 2, 2]))
Item.create(embedding: Pgvector.encode([1, 1, 2]))
results = Item.order(Sequel.lit("embedding <-> ?", Pgvector.encode([1, 1, 1]))).limit(5)
results = Item.nearest_neighbors(:embedding, [1, 1, 1], distance: "euclidean").limit(5)
assert_equal ["[1,1,1]", "[1,1,2]", "[2,2,2]"], results.map(&:embedding)
end

Expand Down

0 comments on commit f1944d2

Please sign in to comment.