forked from nim4/DBShield
-
Notifications
You must be signed in to change notification settings - Fork 0
/
oracle.go
142 lines (126 loc) · 2.99 KB
/
oracle.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
package dbms
import (
"bytes"
"crypto/tls"
"io"
"net"
"time"
"github.com/nim4/DBShield/dbshield/logger"
"github.com/nim4/DBShield/dbshield/sql"
)
// Oracle is the DBMS handler for Oracle Database (TNS) connections. It keeps
// the proxied sockets and the session details learned while parsing traffic.
type Oracle struct {
	client, server net.Conn                        // sockets set by SetSockets
	certificate    tls.Certificate                 // key pair loaded by SetCertificate
	currentDB      []byte                          // service name taken from the Connect packet
	username       []byte                          // username taken from a login Data packet
	reader         func(io.Reader) ([]byte, error) // socket read function set by SetReader
}
// SetCertificate loads the X.509 key pair stored in the crt and key files and
// keeps it as the certificate to use if a client asks for SSL. The stored
// certificate is overwritten with whatever tls.LoadX509KeyPair returns, even
// when it also returns an error.
func (o *Oracle) SetCertificate(crt, key string) (err error) {
	pair, err := tls.LoadX509KeyPair(crt, key)
	o.certificate = pair
	return err
}
// SetReader installs the function readPacket uses to pull raw bytes from a
// socket; readPacket stitches its results into whole TNS packets.
func (o *Oracle) SetReader(readFn func(io.Reader) ([]byte, error)) {
	o.reader = readFn
}
// SetSockets stores the client-side connection (c) and the server-side
// connection (s) that Handler relays between and Close shuts down.
func (o *Oracle) SetSockets(c, s net.Conn) {
	defer handlePanic()
	o.client, o.server = c, s
}
// Close closes the client socket, then the server socket. Errors returned by
// the individual Close calls are ignored. If closing the client panics (for
// example on a nil connection), the server socket is not closed and the
// deferred handlePanic deals with the panic.
func (o *Oracle) Close() {
	defer handlePanic()
	for _, conn := range [...]net.Conn{o.client, o.server} {
		conn.Close()
	}
}
// DefaultPort reports the DBMS's default port: 1521, the standard Oracle
// listener port.
func (o *Oracle) DefaultPort() uint {
	const listenerPort uint = 1521
	return listenerPort
}
// Handler gets incoming requests and relays one client connection to the
// Oracle server until an error occurs or the client signals end of data.
//
// Every TNS packet read from the client is inspected before it is forwarded:
//   - Connect packets (type 0x01) update the current service name;
//   - Data packets (type 0x06) may carry the username or a SQL query, and
//     queries are passed to processContext.
//
// After each forwarded packet, the server's response is relayed back to the
// client with readWrite. A packet too short to hold the fields read from it
// ends the session with io.ErrUnexpectedEOF instead of an index-out-of-range
// panic, so a query packet that cannot be parsed is never forwarded
// uninspected.
func (o *Oracle) Handler() error {
	defer handlePanic()
	defer o.Close()
	for {
		buf, err := o.readPacket(o.client)
		if err != nil {
			return err
		}
		if len(buf) < 5 {
			return io.ErrUnexpectedEOF
		}
		var eof bool
		switch buf[4] { // Packet Type
		case 0x01: // Connect
			err = o.handleConnect(buf)
		case 0x06: // Data
			eof, err = o.handleData(buf)
		}
		if err != nil {
			return err
		}
		_, err = o.server.Write(buf)
		if err != nil || eof {
			return err
		}
		err = readWrite(o.server, o.client, o.readPacket)
		if err != nil {
			return err
		}
	}
}

// handleConnect records the service name announced in a TNS Connect packet.
// Bytes 24-25 hold the big-endian length of the connect descriptor, which is
// taken from the end of buf. If the descriptor contains neither SERVICE_NAME
// nor SID, o.currentDB is left unchanged. It returns io.ErrUnexpectedEOF when
// buf is too short for these fields.
func (o *Oracle) handleConnect(buf []byte) error {
	if len(buf) < 26 {
		return io.ErrUnexpectedEOF
	}
	connectDataLen := int(buf[24])<<8 | int(buf[25])
	if connectDataLen > len(buf) {
		return io.ErrUnexpectedEOF
	}
	connectData := buf[len(buf)-connectDataLen:]
	if name := oracleServiceName(connectData); name != nil {
		o.currentDB = name
	}
	logger.Debugf("Connect Data: %s", connectData)
	logger.Debugf("Service Name: %s", o.currentDB)
	return nil
}

// handleData inspects a TNS Data packet sent by the client. It reports eof
// when byte 9 of the packet (the second byte after the 8-byte header) is
// 0x40. Otherwise, when the payload after those two bytes starts with 0x03,
// a 0x5e call is treated as a query and passed to processContext, and a 0x76
// call is treated as a login whose username is stored in o.username. It
// returns io.ErrUnexpectedEOF when a field it must read lies beyond the end
// of buf.
func (o *Oracle) handleData(buf []byte) (eof bool, err error) {
	if len(buf) < 10 {
		return false, io.ErrUnexpectedEOF
	}
	data := buf[8:]
	if data[1] == 0x40 {
		return true, nil
	}
	payload := data[2:]
	if len(payload) > 16 && payload[0] == 0x11 && payload[15] == 0x03 && payload[16] == 0x5e {
		// An unidentified TTC message (0x11) sometimes precedes the query;
		// skip its 15 bytes so payload starts at the query call.
		payload = payload[15:]
	}
	if len(payload) < 2 || payload[0] != 0x03 {
		// Nothing we know how to inspect; forward as-is.
		return false, nil
	}
	switch payload[1] {
	case 0x5e: // query
		if len(payload) < 70 {
			return false, io.ErrUnexpectedEOF
		}
		// The parse error is deliberately ignored so that every query
		// packet still reaches processContext.
		query, _ := pascalString(payload[70:])
		processContext(sql.QueryContext{
			Query:    query,
			Database: o.currentDB,
			User:     o.username,
			Client:   remoteAddrToIP(o.client.RemoteAddr()),
			Time:     time.Now(),
		})
	case 0x76: // login carrying the username
		if len(payload) < 19 {
			return false, io.ErrUnexpectedEOF
		}
		val, _ := pascalString(payload[19:])
		o.username = val
		logger.Debugf("Username: %s", o.username)
	}
	return false, nil
}

// oracleServiceName returns the value of the first SERVICE_NAME= key in a TNS
// connect descriptor, e.g. "orcl" from
// "(CONNECT_DATA=(SERVICE_NAME=orcl)(CID=...))". If that key is absent, it
// falls back to a (SID= key for clients that connect by SID. The value runs
// up to the next ')' or the end of the data. It returns nil if neither key
// is present.
func oracleServiceName(connectData []byte) []byte {
	for _, key := range []string{"SERVICE_NAME=", "(SID="} {
		i := bytes.Index(connectData, []byte(key))
		if i < 0 {
			continue
		}
		value := connectData[i+len(key):]
		if j := bytes.IndexByte(value, ')'); j >= 0 {
			value = value[:j]
		}
		return value
	}
	return nil
}
// readPacket wraps the configured reader to reassemble segmented packets. The
// first two bytes of a TNS packet hold its big-endian total length, and
// fragments returned by o.reader are appended until at least that many bytes
// have been collected.
//
// If a read yields more bytes than the announced length (for instance, the
// start of the next packet, or a length field of 0), the loop stops and those
// bytes are kept in buf. It no longer waits forever for an exact match that
// can never occur. It returns io.ErrUnexpectedEOF if the first fragment is too
// short to contain the length field. On a read error, the bytes gathered so
// far are returned along with the error.
func (o *Oracle) readPacket(c io.Reader) (buf []byte, err error) {
	buf, err = o.reader(c)
	if err != nil {
		return buf, err
	}
	if len(buf) < 2 {
		return buf, io.ErrUnexpectedEOF
	}
	packetLen := int(buf[0])<<8 | int(buf[1])
	for len(buf) < packetLen {
		var b []byte
		b, err = o.reader(c)
		if err != nil {
			return buf, err
		}
		buf = append(buf, b...)
	}
	return buf, nil
}