diff --git a/README.md b/README.md index b655176b..46e4eeab 100644 --- a/README.md +++ b/README.md @@ -106,6 +106,26 @@ Other supported formats are listed below. * `odbc:server=localhost;user id=sa;password={foo{bar}` // Literal `{`, password is "foo{bar" * `odbc:server=localhost;user id=sa;password={foo}}bar}` // Escaped `} with `}}`, password is "foo}bar" +### Azure Active Directory authentication - preview + +The configuration of functionality might change in the future. + +Azure Active Directory (AAD) access tokens are relatively short lived and need to be +valid when a new connection is made. Authentication is supported using a callback func that +provides a fresh and valid token using a connector: +``` golang +conn, err := mssql.NewAccessTokenConnector( + "Server=test.database.windows.net;Database=testdb", + tokenProvider) +if err != nil { + // handle errors in DSN +} +db := sql.OpenDB(conn) +``` +Where `tokenProvider` is a function that returns a fresh access token or an error. None of these statements +actually trigger the retrieval of a token, this happens when the first statment is issued and a connection +is created. + ## Executing Stored Procedures To run a stored procedure, set the query text to the procedure name: diff --git a/accesstokenconnector.go b/accesstokenconnector.go new file mode 100644 index 00000000..8dbe5099 --- /dev/null +++ b/accesstokenconnector.go @@ -0,0 +1,51 @@ +// +build go1.10 + +package mssql + +import ( + "context" + "database/sql/driver" + "errors" + "fmt" +) + +var _ driver.Connector = &accessTokenConnector{} + +// accessTokenConnector wraps Connector and injects a +// fresh access token when connecting to the database +type accessTokenConnector struct { + Connector + + accessTokenProvider func() (string, error) +} + +// NewAccessTokenConnector creates a new connector from a DSN and a token provider. +// The token provider func will be called when a new connection is requested and should return a valid access token. +// The returned connector may be used with sql.OpenDB. +func NewAccessTokenConnector(dsn string, tokenProvider func() (string, error)) (driver.Connector, error) { + if tokenProvider == nil { + return nil, errors.New("mssql: tokenProvider cannot be nil") + } + + conn, err := NewConnector(dsn) + if err != nil { + return nil, err + } + + c := &accessTokenConnector{ + Connector: *conn, + accessTokenProvider: tokenProvider, + } + return c, nil +} + +// Connect returns a new database connection +func (c *accessTokenConnector) Connect(ctx context.Context) (driver.Conn, error) { + var err error + c.Connector.params.fedAuthAccessToken, err = c.accessTokenProvider() + if err != nil { + return nil, fmt.Errorf("mssql: error retrieving access token: %+v", err) + } + + return c.Connector.Connect(ctx) +} diff --git a/accesstokenconnector_test.go b/accesstokenconnector_test.go new file mode 100644 index 00000000..826dedba --- /dev/null +++ b/accesstokenconnector_test.go @@ -0,0 +1,92 @@ +// +build go1.10 + +package mssql + +import ( + "context" + "database/sql/driver" + "errors" + "fmt" + "strings" + "testing" +) + +func TestNewAccessTokenConnector(t *testing.T) { + dsn := "Server=server.database.windows.net;Database=db" + tp := func() (string, error) { return "token", nil } + type args struct { + dsn string + tokenProvider func() (string, error) + } + tests := []struct { + name string + args args + want func(driver.Connector) error + wantErr bool + }{ + { + name: "Happy path", + args: args{ + dsn: dsn, + tokenProvider: tp}, + want: func(c driver.Connector) error { + tc, ok := c.(*accessTokenConnector) + if !ok { + return fmt.Errorf("Expected driver to be of type *accessTokenConnector, but got %T", c) + } + p := tc.Connector.params + if p.database != "db" { + return fmt.Errorf("expected params.database=db, but got %v", p.database) + } + if p.host != "server.database.windows.net" { + return fmt.Errorf("expected params.host=server.database.windows.net, but got %v", p.host) + } + if tc.accessTokenProvider == nil { + return fmt.Errorf("Expected tokenProvider to not be nil") + } + t, err := tc.accessTokenProvider() + if t != "token" || err != nil { + return fmt.Errorf("Unexpected results from tokenProvider: %v, %v", t, err) + } + return nil + }, + wantErr: false, + }, + { + name: "Nil tokenProvider gives error", + args: args{ + dsn: dsn, + tokenProvider: nil}, + want: nil, + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := NewAccessTokenConnector(tt.args.dsn, tt.args.tokenProvider) + if (err != nil) != tt.wantErr { + t.Errorf("NewAccessTokenConnector() error = %v, wantErr %v", err, tt.wantErr) + return + } + if tt.want != nil { + if err := tt.want(got); err != nil { + t.Error(err) + } + } + }) + } +} + +func TestAccessTokenConnectorFailsToConnectIfNoAccessToken(t *testing.T) { + errorText := "This is a test" + dsn := "Server=server.database.windows.net;Database=db" + tp := func() (string, error) { return "", errors.New(errorText) } + sut, err := NewAccessTokenConnector(dsn, tp) + if err != nil { + t.Fatalf("expected err==nil, but got %+v", err) + } + _, err = sut.Connect(context.TODO()) + if err == nil || !strings.Contains(err.Error(), errorText) { + t.Fatalf("expected error to contain %q, but got %q", errorText, err) + } +} diff --git a/conn_str.go b/conn_str.go index 4ff54b89..26ac50f3 100644 --- a/conn_str.go +++ b/conn_str.go @@ -37,6 +37,7 @@ type connectParams struct { failOverPartner string failOverPort uint64 packetSize uint16 + fedAuthAccessToken string } func parseConnectParams(dsn string) (connectParams, error) { diff --git a/examples/azuread-accesstoken/README.md b/examples/azuread-accesstoken/README.md new file mode 100644 index 00000000..cb20a760 --- /dev/null +++ b/examples/azuread-accesstoken/README.md @@ -0,0 +1,9 @@ +## Azure Managed Identity example + +This example shows how Azure Managed Identity can be used to access SQL Azure. Take note of the +trust boundary before using MSI to prevent exposure of the tokens outside of the trust boundary. + +This example can only be run from a Azure Virtual Machine with Managed Identity configured. +You can follow the steps from this tutorial to turn on managed identity for your VM and grant the +VM access to a SQL Azure database: +https://docs.microsoft.com/en-us/azure/active-directory/managed-identities-azure-resources/tutorial-windows-vm-access-sql diff --git a/examples/azuread-accesstoken/go.mod b/examples/azuread-accesstoken/go.mod new file mode 100644 index 00000000..eb113f3e --- /dev/null +++ b/examples/azuread-accesstoken/go.mod @@ -0,0 +1,8 @@ +module github.com/denisenkom/go-mssqldb/examples/azure-ad-accesstoken + +go 1.13 + +require ( + github.com/Azure/go-autorest/autorest/adal v0.8.1 + github.com/denisenkom/go-mssqldb v0.0.0-20191128021309-1d7a30a10f73 +) diff --git a/examples/azuread-accesstoken/go.sum b/examples/azuread-accesstoken/go.sum new file mode 100644 index 00000000..17148ff8 --- /dev/null +++ b/examples/azuread-accesstoken/go.sum @@ -0,0 +1,30 @@ +github.com/Azure/go-autorest v13.3.2+incompatible h1:VxzPyuhtnlBOzc4IWCZHqpyH2d+QMLQEuy3wREyY4oc= +github.com/Azure/go-autorest/autorest v0.9.0 h1:MRvx8gncNaXJqOoLmhNjUAKh33JJF8LyxPhomEtOsjs= +github.com/Azure/go-autorest/autorest v0.9.0/go.mod h1:xyHB1BMZT0cuDHU7I0+g046+BFDTQ8rEZB0s4Yfa6bI= +github.com/Azure/go-autorest/autorest v0.9.4 h1:1cM+NmKw91+8h5vfjgzK4ZGLuN72k87XVZBWyGwNjUM= +github.com/Azure/go-autorest/autorest/adal v0.5.0/go.mod h1:8Z9fGy2MpX0PvDjB1pEgQTmVqjGhiHBW7RJJEciWzS0= +github.com/Azure/go-autorest/autorest/adal v0.8.1 h1:pZdL8o72rK+avFWl+p9nE8RWi1JInZrWJYlnpfXJwHk= +github.com/Azure/go-autorest/autorest/adal v0.8.1/go.mod h1:ZjhuQClTqx435SRJ2iMlOxPYt3d2C/T/7TiQCVZSn3Q= +github.com/Azure/go-autorest/autorest/date v0.1.0/go.mod h1:plvfp3oPSKwf2DNjlBjWF/7vwR+cUD/ELuzDCXwHUVA= +github.com/Azure/go-autorest/autorest/date v0.2.0 h1:yW+Zlqf26583pE43KhfnhFcdmSWlm5Ew6bxipnr/tbM= +github.com/Azure/go-autorest/autorest/date v0.2.0/go.mod h1:vcORJHLJEh643/Ioh9+vPmf1Ij9AEBM5FuBIXLmIy0g= +github.com/Azure/go-autorest/autorest/mocks v0.1.0/go.mod h1:OTyCOPRA2IgIlWxVYxBee2F5Gr4kF2zd2J5cFRaIDN0= +github.com/Azure/go-autorest/autorest/mocks v0.2.0/go.mod h1:OTyCOPRA2IgIlWxVYxBee2F5Gr4kF2zd2J5cFRaIDN0= +github.com/Azure/go-autorest/autorest/mocks v0.3.0/go.mod h1:a8FDP3DYzQ4RYfVAxAN3SVSiiO77gL2j2ronKKP0syM= +github.com/Azure/go-autorest/logger v0.1.0/go.mod h1:oExouG+K6PryycPJfVSxi/koC6LSNgds39diKLz7Vrc= +github.com/Azure/go-autorest/tracing v0.5.0 h1:TRn4WjSnkcSy5AEG3pnbtFSwNtwzjr4VYyQflFE619k= +github.com/Azure/go-autorest/tracing v0.5.0/go.mod h1:r/s2XiOKccPW3HrqB+W0TQzfbtp2fGCgRFtBroKn4Dk= +github.com/denisenkom/go-mssqldb v0.0.0-20191128021309-1d7a30a10f73 h1:OGNva6WhsKst5OZf7eZOklDztV3hwtTHovdrLHV+MsA= +github.com/denisenkom/go-mssqldb v0.0.0-20191128021309-1d7a30a10f73/go.mod h1:xbL0rPBG9cCiLr28tMa8zpbdarY27NDyej4t/EjAShU= +github.com/dgrijalva/jwt-go v3.2.0+incompatible h1:7qlOGliEKZXTDg6OTjfoBKDXWrumCAMpl/TFQ4/5kLM= +github.com/dgrijalva/jwt-go v3.2.0+incompatible/go.mod h1:E3ru+11k8xSBh+hMPgOLZmtrrCbhqsmaPHjLKYnJCaQ= +github.com/golang-sql/civil v0.0.0-20190719163853-cb61b32ac6fe h1:lXe2qZdvpiX5WZkZR4hgp4KJVfY3nMkvmwbVkpv1rVY= +github.com/golang-sql/civil v0.0.0-20190719163853-cb61b32ac6fe/go.mod h1:8vg3r2VgvsThLBIFL93Qb5yWzgyZWhEmBwUJWevAkK0= +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20190325154230-a5d413f7728c/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20191206172530-e9b2fee46413 h1:ULYEB3JvPRE/IfO+9uO7vKV/xzVTO7XPAwm8xbf4w2g= +golang.org/x/crypto v0.0.0-20191206172530-e9b2fee46413/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= +golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= diff --git a/examples/azuread-accesstoken/managed_identity.go b/examples/azuread-accesstoken/managed_identity.go new file mode 100644 index 00000000..0adf99cb --- /dev/null +++ b/examples/azuread-accesstoken/managed_identity.go @@ -0,0 +1,82 @@ +package main + +import ( + "database/sql" + "flag" + "fmt" + "log" + + "github.com/Azure/go-autorest/autorest/adal" + mssql "github.com/denisenkom/go-mssqldb" +) + +var ( + debug = flag.Bool("debug", false, "enable debugging") + server = flag.String("server", "", "the database server") + database = flag.String("database", "", "the database") +) + +func main() { + flag.Parse() + + if *debug { + fmt.Printf(" server:%s\n", *server) + fmt.Printf(" database:%s\n", *database) + } + + if *server == "" { + log.Fatal("Server name cannot be left empty") + } + + if *database == "" { + log.Fatal("Database name cannot be left empty") + } + + connString := fmt.Sprintf("Server=%s;Database=%s", *server, *database) + if *debug { + fmt.Printf(" connString:%s\n", connString) + } + + tokenProvider, err := getMSITokenProvider() + if err != nil { + log.Fatal("Error creating token provider for system assigned Azure Managed Identity:", err.Error()) + } + + connector, err := mssql.NewAccessTokenConnector( + connString, tokenProvider) + if err != nil { + log.Fatal("Connector creation failed:", err.Error()) + } + conn := sql.OpenDB(connector) + defer conn.Close() + + row := conn.QueryRow("select 1, 'abc'") + var somenumber int64 + var somechars string + err = row.Scan(&somenumber, &somechars) + if err != nil { + log.Fatal("Scan failed:", err.Error()) + } + fmt.Printf("somenumber:%d\n", somenumber) + fmt.Printf("somechars:%s\n", somechars) + + fmt.Printf("bye\n") +} + +func getMSITokenProvider() (func() (string, error), error) { + msiEndpoint, err := adal.GetMSIEndpoint() + if err != nil { + return nil, err + } + msi, err := adal.NewServicePrincipalTokenFromMSI( + msiEndpoint, "https://database.windows.net/") + if err != nil { + return nil, err + } + + return func() (string, error) { + msi.EnsureFresh() + token := msi.OAuthToken() + return token, nil + }, nil +} diff --git a/tds.go b/tds.go index 9a8376da..80bb8967 100644 --- a/tds.go +++ b/tds.go @@ -100,13 +100,15 @@ const ( // prelogin fields // http://msdn.microsoft.com/en-us/library/dd357559.aspx const ( - preloginVERSION = 0 - preloginENCRYPTION = 1 - preloginINSTOPT = 2 - preloginTHREADID = 3 - preloginMARS = 4 - preloginTRACEID = 5 - preloginTERMINATOR = 0xff + preloginVERSION = 0 + preloginENCRYPTION = 1 + preloginINSTOPT = 2 + preloginTHREADID = 3 + preloginMARS = 4 + preloginTRACEID = 5 + preloginFEDAUTHREQUIRED = 6 + preloginNONCEOPT = 7 + preloginTERMINATOR = 0xff ) const ( @@ -245,6 +247,12 @@ const ( fReadOnlyIntent = 32 ) +// OptionFlags3 +// https://docs.microsoft.com/en-us/openspecs/windows_protocols/ms-tds/773a62b6-ee89-4c02-9e5e-344882630aac +const ( + fExtension = 0x10 +) + type login struct { TDSVersion uint32 PacketSize uint32 @@ -269,6 +277,89 @@ type login struct { SSPI []byte AtchDBFile string ChangePassword string + FeatureExt featureExts +} + +type featureExts struct { + features map[byte]featureExt +} + +type featureExt interface { + featureID() byte + toBytes() []byte +} + +func (e *featureExts) Add(f featureExt) error { + if f == nil { + return nil + } + id := f.featureID() + if _, exists := e.features[id]; exists { + f := "Login error: Feature with ID '%v' is already present in FeatureExt block." + return fmt.Errorf(f, id) + } + if e.features == nil { + e.features = make(map[byte]featureExt) + } + e.features[id] = f + return nil +} + +func (e featureExts) toBytes() []byte { + if len(e.features) == 0 { + return nil + } + var d []byte + for featureID, f := range e.features { + featureData := f.toBytes() + + hdr := make([]byte, 5) + hdr[0] = featureID // FedAuth feature extension BYTE + binary.LittleEndian.PutUint32(hdr[1:], uint32(len(featureData))) // FeatureDataLen DWORD + d = append(d, hdr...) + + d = append(d, featureData...) // FeatureData *BYTE + } + if d != nil { + d = append(d, 0xff) // Terminator + } + return d +} + +type featureExtFedAuthSTS struct { + FedAuthEcho bool + FedAuthToken string + Nonce []byte +} + +func (e *featureExtFedAuthSTS) featureID() byte { + return 0x02 +} + +func (e *featureExtFedAuthSTS) toBytes() []byte { + if e == nil { + return nil + } + + options := byte(0x01) << 1 // 0x01 => STS bFedAuthLibrary 7BIT + if e.FedAuthEcho { + options |= 1 // fFedAuthEcho + } + + d := make([]byte, 5) + d[0] = options + + // looks like string in + // https://docs.microsoft.com/en-us/openspecs/windows_protocols/ms-tds/f88b63bb-b479-49e1-a87b-deda521da508 + tokenBytes := str2ucs2(e.FedAuthToken) + binary.LittleEndian.PutUint32(d[1:], uint32(len(tokenBytes))) // Should be a signed int32, but since the length is relatively small, this should work + d = append(d, tokenBytes...) + + if len(e.Nonce) == 32 { + d = append(d, e.Nonce...) + } + + return d } type loginHeader struct { @@ -295,7 +386,7 @@ type loginHeader struct { ServerNameOffset uint16 ServerNameLength uint16 ExtensionOffset uint16 - ExtensionLenght uint16 + ExtensionLength uint16 CtlIntNameOffset uint16 CtlIntNameLength uint16 LanguageOffset uint16 @@ -357,6 +448,8 @@ func sendLogin(w *tdsBuffer, login login) error { database := str2ucs2(login.Database) atchdbfile := str2ucs2(login.AtchDBFile) changepassword := str2ucs2(login.ChangePassword) + featureExt := login.FeatureExt.toBytes() + hdr := loginHeader{ TDSVersion: login.TDSVersion, PacketSize: login.PacketSize, @@ -405,7 +498,18 @@ func sendLogin(w *tdsBuffer, login login) error { offset += uint16(len(atchdbfile)) hdr.ChangePasswordOffset = offset offset += uint16(len(changepassword)) - hdr.Length = uint32(offset) + + featureExtOffset := uint32(0) + featureExtLen := len(featureExt) + if featureExtLen > 0 { + hdr.OptionFlags3 |= fExtension + hdr.ExtensionOffset = offset + hdr.ExtensionLength = 4 + offset += hdr.ExtensionLength // DWORD + featureExtOffset = uint32(offset) + } + hdr.Length = uint32(offset) + uint32(featureExtLen) + var err error err = binary.Write(w, binary.LittleEndian, &hdr) if err != nil { @@ -455,6 +559,16 @@ func sendLogin(w *tdsBuffer, login login) error { if err != nil { return err } + if featureExtOffset > 0 { + err = binary.Write(w, binary.LittleEndian, featureExtOffset) + if err != nil { + return err + } + _, err = w.Write(featureExt) + if err != nil { + return err + } + } return w.FinishPacket() } @@ -844,15 +958,23 @@ initiate_connection: AppName: p.appname, TypeFlags: p.typeFlags, } - auth, auth_ok := getAuth(p.user, p.password, p.serverSPN, p.workstation) - if auth_ok { + auth, authOk := getAuth(p.user, p.password, p.serverSPN, p.workstation) + switch { + case p.fedAuthAccessToken != "": // accesstoken ignores user/password + featurext := &featureExtFedAuthSTS{ + FedAuthEcho: len(fields[preloginFEDAUTHREQUIRED]) > 0 && fields[preloginFEDAUTHREQUIRED][0] == 1, + FedAuthToken: p.fedAuthAccessToken, + Nonce: fields[preloginNONCEOPT], + } + login.FeatureExt.Add(featurext) + case authOk: login.SSPI, err = auth.InitialBytes() if err != nil { return nil, err } login.OptionFlags2 |= fIntSecurity defer auth.Free() - } else { + default: login.UserName = p.user login.Password = p.password } diff --git a/tds_test.go b/tds_test.go index 00360d85..e725a668 100644 --- a/tds_test.go +++ b/tds_test.go @@ -5,7 +5,6 @@ import ( "context" "database/sql" "encoding/hex" - "fmt" "net/url" "os" "runtime" @@ -62,11 +61,60 @@ func TestSendLogin(t *testing.T) { 116, 0, 104, 0} out := memBuf.Bytes() if !bytes.Equal(ref, out) { - fmt.Println("Expected:") - fmt.Print(hex.Dump(ref)) - fmt.Println("Returned:") - fmt.Print(hex.Dump(out)) - t.Error("input output don't match") + t.Log("Expected:") + t.Log(hex.Dump(ref)) + t.Log("Returned:") + t.Log(hex.Dump(out)) + t.Fatal("input output don't match") + } +} + +func TestSendLoginWithFeatureExt(t *testing.T) { + memBuf := new(MockTransport) + buf := newTdsBuffer(1024, memBuf) + login := login{ + TDSVersion: verTDS74, + PacketSize: 0x1000, + ClientProgVer: 0x01060100, + ClientPID: 100, + ClientTimeZone: -4 * 60, + ClientID: [6]byte{0x12, 0x34, 0x56, 0x78, 0x90, 0xab}, + OptionFlags1: 0xe0, + OptionFlags3: 8, + HostName: "subdev1", + AppName: "appname", + ServerName: "servername", + CtlIntName: "library", + Language: "en", + Database: "database", + ClientLCID: 0x204, + } + login.FeatureExt.Add(&featureExtFedAuthSTS{ + FedAuthToken: "fedauthtoken", + }) + err := sendLogin(buf, login) + if err != nil { + t.Error("sendLogin should succeed") + } + ref := []byte{ + 16, 1, 0, 223, 0, 0, 1, 0, 215, 0, 0, 0, 4, 0, 0, 116, 0, 16, 0, 0, 0, 1, + 6, 1, 100, 0, 0, 0, 0, 0, 0, 0, 224, 0, 0, 24, 16, 255, 255, 255, 4, 2, 0, + 0, 94, 0, 7, 0, 108, 0, 0, 0, 108, 0, 0, 0, 108, 0, 7, 0, 122, 0, 10, 0, 176, + 0, 4, 0, 142, 0, 7, 0, 156, 0, 2, 0, 160, 0, 8, 0, 18, 52, 86, 120, 144, 171, + 176, 0, 0, 0, 176, 0, 0, 0, 176, 0, 0, 0, 0, 0, 0, 0, 115, 0, 117, 0, 98, + 0, 100, 0, 101, 0, 118, 0, 49, 0, 97, 0, 112, 0, 112, 0, 110, 0, 97, 0, + 109, 0, 101, 0, 115, 0, 101, 0, 114, 0, 118, 0, 101, 0, 114, 0, 110, 0, 97, + 0, 109, 0, 101, 0, 108, 0, 105, 0, 98, 0, 114, 0, 97, 0, 114, 0, 121, 0, 101, + 0, 110, 0, 100, 0, 97, 0, 116, 0, 97, 0, 98, 0, 97, 0, 115, 0, 101, 0, 180, 0, + 0, 0, 2, 29, 0, 0, 0, 2, 24, 0, 0, 0, 102, 0, 101, 0, 100, 0, 97, 0, 117, 0, + 116, 0, 104, 0, 116, 0, 111, 0, 107, 0, 101, 0, 110, 0, 255} + out := memBuf.Bytes() + if !bytes.Equal(ref, out) { + t.Log("Expected:") + t.Log(hex.Dump(ref)) + t.Log("Returned:") + t.Log(hex.Dump(out)) + t.Fatal("input output don't match") } } @@ -109,7 +157,7 @@ loop: case []interface{}: lastRow = token default: - fmt.Println("unknown token", tok) + t.Log("unknown token", tok) } } @@ -176,8 +224,6 @@ func (l testLogger) Println(v ...interface{}) { l.t.Log(v...) } - - func TestConnect(t *testing.T) { checkConnStr(t) SetLogger(testLogger{t}) @@ -413,7 +459,7 @@ func TestSSPIAuth(t *testing.T) { func TestUcs22Str(t *testing.T) { // Test valid input - s, err := ucs22str([]byte{0x31, 0, 0x32, 0, 0x33, 0}) // 123 in UCS2 encoding + s, err := ucs22str([]byte{0x31, 0, 0x32, 0, 0x33, 0}) // 123 in UCS2 encoding if err != nil { t.Errorf("ucs22str should not fail for valid ucs2 byte sequence: %s", err) } @@ -429,7 +475,7 @@ func TestUcs22Str(t *testing.T) { } func TestReadUcs2(t *testing.T) { - buf := bytes.NewBuffer([]byte{0x31, 0, 0x32, 0, 0x33, 0}) // 123 in UCS2 encoding + buf := bytes.NewBuffer([]byte{0x31, 0, 0x32, 0, 0x33, 0}) // 123 in UCS2 encoding s, err := readUcs2(buf, 3) if err != nil { t.Errorf("readUcs2 should not fail for valid ucs2 byte sequence: %s", err) @@ -447,7 +493,7 @@ func TestReadUcs2(t *testing.T) { func TestReadUsVarChar(t *testing.T) { // should succeed for valid buffer - buf := bytes.NewBuffer([]byte{3, 0, 0x31, 0, 0x32, 0, 0x33, 0}) // 123 in UCS2 encoding with length prefix 3 uint16 + buf := bytes.NewBuffer([]byte{3, 0, 0x31, 0, 0x32, 0, 0x33, 0}) // 123 in UCS2 encoding with length prefix 3 uint16 s, err := readUsVarChar(buf) if err != nil { t.Errorf("readUsVarChar should not fail for valid ucs2 byte sequence: %s", err) @@ -487,4 +533,4 @@ func TestReadBVarByte(t *testing.T) { if err == nil { t.Error("readUsVarByte should fail on short buffer, but it didn't") } -} \ No newline at end of file +} diff --git a/token.go b/token.go index 1acac8a5..25385e89 100644 --- a/token.go +++ b/token.go @@ -17,20 +17,21 @@ type token byte // token ids const ( - tokenReturnStatus token = 121 // 0x79 - tokenColMetadata token = 129 // 0x81 - tokenOrder token = 169 // 0xA9 - tokenError token = 170 // 0xAA - tokenInfo token = 171 // 0xAB - tokenReturnValue token = 0xAC - tokenLoginAck token = 173 // 0xad - tokenRow token = 209 // 0xd1 - tokenNbcRow token = 210 // 0xd2 - tokenEnvChange token = 227 // 0xE3 - tokenSSPI token = 237 // 0xED - tokenDone token = 253 // 0xFD - tokenDoneProc token = 254 - tokenDoneInProc token = 255 + tokenReturnStatus token = 121 // 0x79 + tokenColMetadata token = 129 // 0x81 + tokenOrder token = 169 // 0xA9 + tokenError token = 170 // 0xAA + tokenInfo token = 171 // 0xAB + tokenReturnValue token = 0xAC + tokenLoginAck token = 173 // 0xad + tokenFeatureExtAck token = 174 // 0xae + tokenRow token = 209 // 0xd1 + tokenNbcRow token = 210 // 0xd2 + tokenEnvChange token = 227 // 0xE3 + tokenSSPI token = 237 // 0xED + tokenDone token = 253 // 0xFD + tokenDoneProc token = 254 + tokenDoneInProc token = 255 ) // done flags @@ -447,6 +448,22 @@ func parseLoginAck(r *tdsBuffer) loginAckStruct { return res } +// https://docs.microsoft.com/en-us/openspecs/windows_protocols/ms-tds/2eb82f8e-11f0-46dc-b42d-27302fa4701a +func parseFeatureExtAck(r *tdsBuffer) { + // at most 1 featureAck per feature in featureExt + // go-mssqldb will add at most 1 feature, the spec defines 7 different features + for i := 0; i < 8; i++ { + featureID := r.byte() // FeatureID + if featureID == 0xff { + return + } + size := r.uint32() // FeatureAckDataLen + d := make([]byte, size) + r.ReadFull(d) + } + panic("parsed more than 7 featureAck's, protocol implementation error?") +} + // http://msdn.microsoft.com/en-us/library/dd357363.aspx func parseColMetadata72(r *tdsBuffer) (columns []columnStruct) { count := r.uint16() @@ -577,6 +594,8 @@ func processSingleResponse(sess *tdsSession, ch chan tokenStruct, outs map[strin case tokenLoginAck: loginAck := parseLoginAck(sess.buf) ch <- loginAck + case tokenFeatureExtAck: + parseFeatureExtAck(sess.buf) case tokenOrder: order := parseOrder(sess.buf) ch <- order