From 2c47f4172fb93df42654af1ba1c60b71d97a9193 Mon Sep 17 00:00:00 2001 From: Jenn Wheeler Date: Sun, 2 Sep 2018 16:04:07 -0400 Subject: [PATCH] slowly getting closer to properly working --- src/odbc/connection.cr | 12 ++-- src/odbc/decode.cr | 0 src/odbc/handle.cr | 4 +- src/odbc/libodbc.cr | 30 +++++++++- src/odbc/result.cr | 121 +++++++++++++++++++++++++++++++++++++++++ src/odbc/statement.cr | 31 +++++++++-- src/odbc/types.cr | 106 ++++++++++++++++++++++++++---------- 7 files changed, 263 insertions(+), 41 deletions(-) create mode 100644 src/odbc/decode.cr diff --git a/src/odbc/connection.cr b/src/odbc/connection.cr index fb3a982..5d85bd8 100644 --- a/src/odbc/connection.cr +++ b/src/odbc/connection.cr @@ -1,5 +1,7 @@ module ODBC class Connection < DB::Connection + getter raw_conn + def initialize(context : DB::ConnectionContext) # set up all the basic connection info super(context) @@ -9,19 +11,19 @@ module ODBC conns_size = conn_string.size.to_i16 @env = ODBC.alloc_env - @connection = ODBC.alloc_conn(@env) + @raw_conn = ODBC.alloc_conn(@env) - result = LibODBC.driver_connect(@connection, + result = LibODBC.driver_connect(@raw_conn, nil, conn_string, conns_size, nil, 0, nil, - LibODBC::SqlDriverConnect::SqlDriverComplete) + LibODBC::DriverConnect::SqlDriverComplete) if result == LibODBC::SqlReturn::SqlSuccessWithInfo - puts ODBC.get_detail("SQLDriverConnect", @connection, 1) + puts ODBC.get_detail("SQLDriverConnect", @raw_conn, 1) elsif result != LibODBC::SqlReturn::SqlSuccess raise Errno.new("Error establishing connection to server") end @@ -39,7 +41,7 @@ module ODBC def do_close LibODBC.disconnect(nil) LibODBC.free_handle(ODBC::HandleType::SqlHandleEnv.value, @env) - LibODBC.free_handle(ODBC::HandleType::SqlHandleDbc, @connection) + LibODBC.free_handle(ODBC::HandleType::SqlHandleDbc, @raw_conn) end # :nodoc: diff --git a/src/odbc/decode.cr b/src/odbc/decode.cr new file mode 100644 index 0000000..e69de29 diff --git a/src/odbc/handle.cr b/src/odbc/handle.cr index 2605486..e76ebca 100644 --- a/src/odbc/handle.cr +++ b/src/odbc/handle.cr @@ -15,8 +15,8 @@ module ODBC end version = Pointer(Void).new(LibODBC::OdbcVer::SqlOvOdbc3.value) - result2 = LibODBC.set_env_attr(output_handle_ptr, LibODBC::EnvAttr::SqlAttrOdbcVersion.value, version, 0) - if result2.value != 0 && result2.value != 1 + env_result = LibODBC.set_env_attr(output_handle_ptr, LibODBC::EnvAttr::SqlAttrOdbcVersion.value, version, 0) + if env_result.value != 0 && env_result.value != 1 error = ODBC.get_detail("SQLSetEnvAttr", output_handle_ptr, 1) raise Errno.new(error) end diff --git a/src/odbc/libodbc.cr b/src/odbc/libodbc.cr index a75dd72..c90edea 100644 --- a/src/odbc/libodbc.cr +++ b/src/odbc/libodbc.cr @@ -29,7 +29,7 @@ lib LibODBC SqlError = -1 end - enum SqlDriverConnect + enum DriverConnect SqlDriverNoPrompt = 0 SqlDriverComplete = 1 SqlDriverPrompt = 2 @@ -44,6 +44,17 @@ lib LibODBC SqlAttrOutputNts = 10001 end + enum FetchOrientation + SqlFetchNext = 1 + SqlFetchFirst = 2 + SqlFetchLast = 3 + SqlFetchPrior = 4 + SqlFetchAbsolute = 5 + SqlFetchRelative = 6 + SqlFetchFirstuser = 31 + SqlFetchFirstSystem = 32 + end + enum OdbcVer SqlOvOdbc2 = 2 SqlOvOdbc3 = 3 @@ -51,6 +62,12 @@ lib LibODBC SqlOvOdbc4 = 400 end + enum Nullable + SqlNullableUnknown = 2 + SqlNullable = 1 + SqlNoNulls = 0 + end + fun alloc_handle = SQLAllocHandle(handle_type : SqlSmallInt, input_handle : SqlHandle*, output_handle_ptr : SqlHandle*) : SqlReturn @@ -132,6 +149,10 @@ lib LibODBC fun fetch = SQLFetch(statement_handle : SqlHStmt) : SqlReturn + fun fetch_scroll = SQLFetchScroll(statement_handle : SqlHStmt, + fetch_orientation : SqlSmallInt, + fetch_offset : SqlLen) : SqlReturn + fun free_handle = SQLFreeHandle(handle_type : SqlSmallInt, handle : SqlHandle) fun free_stmt = SQLFreeStmt(statement_handle : SqlHStmt, option : SqlUSmallInt) @@ -160,7 +181,12 @@ lib LibODBC fun num_result_cols = SQLNumResultCols(statement_handle : SqlHStmt, column_count_ptr : SqlSmallInt*) : SqlReturn - fun prepare = SQLPrepare(statement_handle : SqlHStmt, statement_text : SqlChar*, text_length : SqlInteger) : SqlReturn + fun prepare = SQLPrepare(statement_handle : SqlHStmt, + statement_text : SqlChar*, + text_length : SqlInteger) : SqlReturn + + fun row_count = SQLRowCount(statement_handle : SqlHStmt, + row_count_ptr : SqlLen*) : SqlReturn fun set_connect_attr = SQLSetConnectAttr(connection_handle : SqlHDBC, attribute : SqlInteger, diff --git a/src/odbc/result.cr b/src/odbc/result.cr index e69de29..4c1e289 100644 --- a/src/odbc/result.cr +++ b/src/odbc/result.cr @@ -0,0 +1,121 @@ +module ODBC + class Field + @name : String + @col_type : SqlDataType + @col_size : LibODBC::SqlULen + @nullable : Bool + + getter name, col_type, col_size, nullable + + def initialize(stmt : Void*, col_num : Int32) + LibODBC.describe_col(stmt, col_num, out name, + 256, out name_len, out col_type, + out col_size, out digits, out nullable) + + @name = String.new(Pointer(UInt8).new(name)) + @col_type = SqlDataType.new(col_type.to_i32) + @col_size = col_size + @nullable = case nullable + when LibODBC::Nullable::SqlNullable + true + else + false + end + end + end + + class ResultSet < DB::ResultSet + @num_cols : Int32 + @rows_affected : Int64 + @buffer : Array(UInt8*) + @strlen : Array(Int64) + + getter rows_affected + + def initialize(statement) + super(statement) + @col_index = 0 + @row_index = 0_i64 + + LibODBC.row_count(statement.raw_stmt, out rows_affected) + @rows_affected = rows_affected + + LibODBC.num_result_cols(statement.raw_stmt, out num_cols) + @num_cols = num_cols.to_i32 + + @fields = Array(ODBC::Field).new + i = 0 + while i < @num_cols + @fields.push(ODBC::Field.new(statement.raw_stmt, i)) + i += 1 + end + + @buffer = Array(UInt8*).new(@num_cols, Pointer(UInt8).null) + @strlen = Array(Int64).new(@num_cols, 0) + i = 0 + while i < @num_cols + # kind of awkward workaround for dealing with an array of pointers and the fact that + # arrays are themselves built of pointers and somake accessing the contained pointers + # a bit clumsy + # + # TODO: a better way to handle this probably? + tmp_buf = Pointer(UInt8).malloc + + # and here, since we're calling C functions here we have to specify the length of the buffer into which + # we're reading the SqlCChars. does bind_col realloc strictly based on that? need to find some way to + # get around this since we'd rather dynamically rellocate memory to accommodate a large field than + # unnecessarily snip the end off + LibODBC.bind_col(statement.raw_stmt, i + 1, SqlCDataType::SqlCChar.value, tmp_buf.as(Void*), 256, out ind) + @buffer[i] = tmp_buf + @strlen[i] = ind + i += 1 + end + end + + protected def conn + statement.as(Statement).conn + end + + def move_next + if @row_index < @rows_affected - 1 + @row_index += 1 + true + else + false + end + end + + def column_count : Int32 + @num_cols + end + + def column_name(index : Int32) : String + @fields[index].name + end + + def column_type(index : Int32) : SqlDataType + @fields[index].col_type + end + + def read + case @col_index + when 0 + result = LibODBC.fetch_scroll(statement.raw_stmt, LibODBC::FetchOrientation::SqlFetchAbsolute, @row_index + 1) + if result != LibODBC::SqlReturn::SqlSuccess && result != LibODBC::SqlReturn::SqlSuccessWithInfo + err = ODBC.get_detail("SQLFetchScroll", statement.raw_stmt, 1) + raise "Error fetching row #{@row_index + 1}: #{err}" + end + + @col_index += 1 + return @buffer[0] + when .<(@num_cols) + value = buffer[@col_index] + @col_index += 1 + return value + else + @col_index = 0 + @row_index += 1 + end + end + end +end diff --git a/src/odbc/statement.cr b/src/odbc/statement.cr index 9cfe347..5a8b6c1 100644 --- a/src/odbc/statement.cr +++ b/src/odbc/statement.cr @@ -1,7 +1,14 @@ module ODBC class Statement < DB::Statement - def initialize(connection, @query_sql : String) + @raw_stmt : Void* + @encoded_query : Bytes + getter raw_stmt + + def initialize(connection, query : String) super(connection) + @raw_stmt = Pointer(Void).null + + @encoded_query = ODBC.encode_nts(query) end protected def conn @@ -9,11 +16,27 @@ module ODBC end protected def perform_query(args : Enumerable) : ODBC::ResultSet - body = ODBC.alloc_statement(@connection) - LibODBC.tables(body, nil, 0, nil, 0, nil, 0, "TABLE", 6) + @raw_stmt = ODBC.alloc_stmt(@connection.raw_conn) + + prep_result = LibODBC.prepare(raw_stmt, @encoded_query.to_unsafe, @encoded_query.size) + if prep_result != LibODBC::SqlReturn::SqlSuccess && prep_result != LibODBC::SqlReturn::SqlSuccessWithInfo + err = ODBC.get_detail("SQLPrepare", @raw_stmt, 1) + raise "Error preparing SQL statement: #{err}" + end + + exec_result = LibODBC.execute(raw_stmt) + if exec_result != LibODBC::SqlReturn::SqlSuccess && exec_result != LibODBC::SqlReturn::SqlSuccessWithInfo + err = ODBC.get_detail("SQLExecute", @raw_stmt, 1) + raise "Error executing SQL statement: #{err}" + end + + ODBC::ResultSet.new(self) end - protected def perform_exec(args : Enumerable) : ::DB::ExecResult + protected def perform_exec(args : Enumerable) : DB::ExecResult + result = perform_query(args) + result.each { } + DB::ExecResult.new(rows_affected: result.rows_affected, last_insert_id: 0_i64) end end end diff --git a/src/odbc/types.cr b/src/odbc/types.cr index f0f28da..b5bff6e 100644 --- a/src/odbc/types.cr +++ b/src/odbc/types.cr @@ -1,30 +1,80 @@ -alias SqlChar = UInt8 - -alias SqlWChar = UInt32 - -alias SqlSmallInt = Int16 - -alias SqlUSmallInt = UInt16 - -alias SqlInteger = Int32 - -alias SqlUInteger = UInt32 - -alias SqlReal = Float32 - -alias SqlDouble = Float64 - -alias SqlFloat = Float64 - -alias SqlBigInt = Int64 - -alias SqlUBigInt = UInt64 - -alias Bookmark = Array(UInt32) +enum SqlDataType + SqlUnknownType = 0 + SqlChar = 1 + SqlNumeric = 2 + SqlDecimal = 3 + SqlInteger = 4 + SqlSmallInt = 5 + SqlFloat = 6 + SqlReal = 7 + SqlDouble = 8 + SqlDatetime = 9 + SqlVarchar = 12 + SqlUdt = 17 + SqlRow = 19 + SqlArray = 50 + SqlMultiset = 55 + SqlDate = 91 + SqlTime = 92 + SqlTimestamp = 93 + SqlTimeWithTimezone = 94 + SqlTimestampWithTimezone = 95 + SqlExtlongVarchar = -1 + SqlExtBinary = -2 + SqlExtVarbinary = -3 + SqlExtlongvarbinary = -4 + SqlExtBigInt = -5 + SqlExtTinyInt = -6 + SqlExtBit = -7 + SqlExtWChar = -8 + SqlExtWVarchar = -9 + SqlExtwLongVarchar = -10 + SqlExtGuid = -11 + SqlSsVariant = -150 + SqlSsUdt = -151 + SqlSsXml = -152 + SqlSsTable = -153 + SqlSsTime2 = -154 + SqlSsTimestampOffset = -155 +end -enum SqlResult - SqlSuccess - SqlSuccessWithInfo - SqlInvalidHandle - SqlError +enum SqlCDataType + SqlCUTinyInt = -28 + SqlCUBigInt = -27 + SqlCSTinyInt = -26 + SqlCSBigInt = -25 + SqlCULong = -18 + SqlCUShort = -17 + SqlCSLong = -16 + SqlCSShort = -15 + SqlCGuid = -11 + SqlCWChar = -8 + SqlCBit = -7 + SqlCBinary = -2 + SqlCChar = 1 + SqlCNumeric = 2 + SqlCFloat = 7 + SqlCDouble = 8 + SqlCDate = 9 + SqlCTime = 10 + SqlCTimestamp = 11 + SqlCTypeDate = 91 + SqlCTypeTime = 92 + SqlCTypeTimestamp = 93 + SqlCTypeTimeWithTimezone = 94 + SqlCTypeTimestampWithTimezone = 95 + SqlCDefault = 99 + SqlCIntervalYear = 101 + SqlCIntervalMonth = 102 + SqlCIntervalDay = 103 + SqlCIntervalHour = 104 + SqlCIntervalMinute = 105 + SqlCIntervalSecond = 106 + SqlCIntervalYearToMonth = 107 + SqlCIntervalDayToHour = 108 + SqlCIntervalDayToMinute = 109 + SqlCIntervalDayToSecond = 110 + SqlCIntervalHourToMinute = 111 + SqlCIntervalHourToSecond = 112 + SqlCIntervalMinuteToSecond = 113 end