Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for extensions like citext #250

Draft
wants to merge 4 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions spec/pg/decoder_spec.cr
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
require "uuid"
# Extensions must be loaded before spec_helper connects to the DB
require "../../src/pg/extensions/citext"
require "../spec_helper"

describe PG::Decoders do
Expand Down Expand Up @@ -92,4 +94,15 @@ describe PG::Decoders do
test_decode "path ", "'(1,2,3,4)'::path ", PG::Geo::Path.new([PG::Geo::Point.new(1.0, 2.0), PG::Geo::Point.new(3.0, 4.0)], closed: true)
test_decode "path ", "'[1,2,3,4,5,6]'::path", PG::Geo::Path.new([PG::Geo::Point.new(1.0, 2.0), PG::Geo::Point.new(3.0, 4.0), PG::Geo::Point.new(5.0, 6.0)], closed: false)
test_decode "polygon", "'1,2,3,4,5,6'::polygon", PG::Geo::Polygon.new([PG::Geo::Point.new(1.0, 2.0), PG::Geo::Point.new(3.0, 4.0), PG::Geo::Point.new(5.0, 6.0)])

it "decodes extension types on a per-connection basis" do
PG_DB.exec "CREATE EXTENSION IF NOT EXISTS citext"
PG_DB.query "select 'OMG lol'::citext" do |rs|
rs.each do
text = rs.read
text.should be_a PG::CIText
text.should eq "omg LOL"
end
end
end
end
29 changes: 29 additions & 0 deletions src/pg/connection.cr
Original file line number Diff line number Diff line change
@@ -1,12 +1,22 @@
require "./decoder"
require "../pq/*"

module PG
class Connection < ::DB::Connection
protected getter connection

private EXTENSIONS = [] of Extension

def self.register_extension(extension : Extension)
EXTENSIONS << extension
end

def initialize(context)
super
@connection = uninitialized PQ::Connection
@decoders = Hash(Int32, Decoders::Decoder).new do |_, oid|
Decoders.from_oid(oid)
end

begin
conn_info = PQ::ConnInfo.new(context.uri)
Expand All @@ -15,6 +25,25 @@ module PG
rescue ex
raise DB::ConnectionRefused.new(cause: ex)
end

auto_release = @auto_release
@auto_release = false
begin
EXTENSIONS.each(&.load(self))
ensure
@auto_release = auto_release
end
end

def register_decoder(decoder : Decoders::Decoder)
decoder.oids.each do |oid|
@decoders[oid] = decoder
end
end

@checked_oids = Set(Int32).new
def decoder(oid)
@decoders[oid]
end

def build_prepared_statement(query) : Statement
Expand Down
5 changes: 5 additions & 0 deletions src/pg/extensions.cr
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
module PG
module Extension
abstract def load(connection : Connection)
end
end
66 changes: 66 additions & 0 deletions src/pg/extensions/citext.cr
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
require "../../pg"

module PG
module Extension
# This extension adds support for decoding the Postgres `citext` type.
class CIText
include Extension

def load(connection)
oid = connection.query_one "SELECT oid FROM pg_type WHERE typname = 'citext'", as: UInt32
connection.register_decoder Decoder.new([oid.to_i])
end

struct Decoder
include Decoders::Decoder

getter oids : Array(Int32)

def initialize(@oids : Array(Int32))
end

def decode(io, bytesize, oid)
PG::CIText.new(Decoders::StringDecoder.new.decode(io, bytesize, oid))
end

def type
PG::CIText
end
end
end

Connection.register_extension CIText.new
end

struct CIText
def initialize(text : String)
@text = text
end

def hash(hasher)
@text.hash(hasher)
end

def ==(other : self)
self == other.@text
end

def ==(other : String)
@text.compare(other, case_insensitive: true)
end

def to_s
@text
end

def to_s(io)
@text.to_s io
end
end
end

class String
def ==(other : PG::CIText)
other == self
end
end
2 changes: 1 addition & 1 deletion src/pg/result_set.cr
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ class PG::ResultSet < ::DB::ResultSet
end

private def decoder(index = @column_index)
Decoders.from_oid(oid(index))
statement.connection.decoder(oid(index))
end

private def oid(index = @column_index)
Expand Down