Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implemented postgres #43

Merged
merged 17 commits into from Mar 13, 2019
15 changes: 8 additions & 7 deletions database/database.go
Expand Up @@ -19,6 +19,7 @@ package database
import (
"database/sql"

_ "github.com/lib/pq"
_ "github.com/mattn/go-sqlite3"

log "maunium.net/go/maulogger/v2"
Expand All @@ -34,8 +35,8 @@ type Database struct {
Message *MessageQuery
}

func New(file string) (*Database, error) {
conn, err := sql.Open("sqlite3", file)
func New(dbType string, uri string) (*Database, error) {
conn, err := sql.Open(dbType, uri)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -63,20 +64,20 @@ func New(file string) (*Database, error) {
return db, nil
}

func (db *Database) CreateTables() error {
err := db.User.CreateTable()
func (db *Database) CreateTables(dbType string) error {
err := db.User.CreateTable(dbType)
if err != nil {
return err
}
err = db.Portal.CreateTable()
err = db.Portal.CreateTable(dbType)
if err != nil {
return err
}
err = db.Puppet.CreateTable()
err = db.Puppet.CreateTable(dbType)
if err != nil {
return err
}
err = db.Message.CreateTable()
err = db.Message.CreateTable(dbType)
if err != nil {
return err
}
Expand Down
48 changes: 32 additions & 16 deletions database/message.go
Expand Up @@ -18,6 +18,7 @@ package database

import (
"bytes"
"strings"
"database/sql"
"encoding/json"

Expand All @@ -33,19 +34,34 @@ type MessageQuery struct {
log log.Logger
}

func (mq *MessageQuery) CreateTable() error {
_, err := mq.db.Exec(`CREATE TABLE IF NOT EXISTS message (
chat_jid VARCHAR(25),
chat_receiver VARCHAR(25),
jid VARCHAR(255),
mxid VARCHAR(255) NOT NULL UNIQUE,
sender VARCHAR(25) NOT NULL,
content BLOB NOT NULL,

PRIMARY KEY (chat_jid, chat_receiver, jid),
FOREIGN KEY (chat_jid, chat_receiver) REFERENCES portal(jid, receiver)
)`)
func (mq *MessageQuery) CreateTable(dbType string) error {
if strings.ToLower(dbType) == "postgres" {
_, err := mq.db.Exec(`CREATE TABLE IF NOT EXISTS message (
chat_jid VARCHAR(255),
chat_receiver VARCHAR(255),
jid VARCHAR(255),
mxid VARCHAR(255) NOT NULL UNIQUE,
sender VARCHAR(255) NOT NULL,
content bytea NOT NULL,

PRIMARY KEY (chat_jid, chat_receiver, jid),
FOREIGN KEY (chat_jid, chat_receiver) REFERENCES portal(jid, receiver)
)`)
return err
} else {
_, err := mq.db.Exec(`CREATE TABLE IF NOT EXISTS message (
chat_jid VARCHAR(255),
chat_receiver VARCHAR(255),
jid VARCHAR(255),
mxid VARCHAR(255) NOT NULL UNIQUE,
sender VARCHAR(255) NOT NULL,
content BLOB NOT NULL,

PRIMARY KEY (chat_jid, chat_receiver, jid),
FOREIGN KEY (chat_jid, chat_receiver) REFERENCES portal(jid, receiver)
)`)
return err
}
}

func (mq *MessageQuery) New() *Message {
Expand All @@ -56,7 +72,7 @@ func (mq *MessageQuery) New() *Message {
}

func (mq *MessageQuery) GetAll(chat PortalKey) (messages []*Message) {
rows, err := mq.db.Query("SELECT * FROM message WHERE chat_jid=? AND chat_receiver=?", chat.JID, chat.Receiver)
rows, err := mq.db.Query("SELECT * FROM message WHERE chat_jid=$1 AND chat_receiver=$2", chat.JID, chat.Receiver)
if err != nil || rows == nil {
return nil
}
Expand All @@ -68,11 +84,11 @@ func (mq *MessageQuery) GetAll(chat PortalKey) (messages []*Message) {
}

func (mq *MessageQuery) GetByJID(chat PortalKey, jid types.WhatsAppMessageID) *Message {
return mq.get("SELECT * FROM message WHERE chat_jid=? AND chat_receiver=? AND jid=?", chat.JID, chat.Receiver, jid)
return mq.get("SELECT * FROM message WHERE chat_jid=$1 AND chat_receiver=$2 AND jid=$3", chat.JID, chat.Receiver, jid)
}

func (mq *MessageQuery) GetByMXID(mxid types.MatrixEventID) *Message {
return mq.get("SELECT * FROM message WHERE mxid=?", mxid)
return mq.get("SELECT * FROM message WHERE mxid=$1", mxid)
}

func (mq *MessageQuery) get(query string, args ...interface{}) *Message {
Expand Down Expand Up @@ -130,7 +146,7 @@ func (msg *Message) encodeBinaryContent() []byte {
}

func (msg *Message) Insert() {
_, err := msg.db.Exec("INSERT INTO message VALUES (?, ?, ?, ?, ?, ?)", msg.Chat.JID, msg.Chat.Receiver, msg.JID, msg.MXID, msg.Sender, msg.encodeBinaryContent())
_, err := msg.db.Exec("INSERT INTO message VALUES ($1, $2, $3, $4, $5, $6)", msg.Chat.JID, msg.Chat.Receiver, msg.JID, msg.MXID, msg.Sender, msg.encodeBinaryContent())
if err != nil {
msg.log.Warnfln("Failed to insert %s@%s: %v", msg.Chat, msg.JID, err)
}
Expand Down
17 changes: 8 additions & 9 deletions database/portal.go
Expand Up @@ -59,18 +59,17 @@ type PortalQuery struct {
log log.Logger
}

func (pq *PortalQuery) CreateTable() error {
func (pq *PortalQuery) CreateTable(dbType string) error {
_, err := pq.db.Exec(`CREATE TABLE IF NOT EXISTS portal (
jid VARCHAR(25),
receiver VARCHAR(25),
jid VARCHAR(255),
receiver VARCHAR(255),
mxid VARCHAR(255) UNIQUE,

name VARCHAR(255) NOT NULL,
topic VARCHAR(255) NOT NULL,
avatar VARCHAR(255) NOT NULL,

PRIMARY KEY (jid, receiver),
FOREIGN KEY (receiver) REFERENCES user(mxid)
PRIMARY KEY (jid, receiver)
)`)
return err
}
Expand All @@ -95,11 +94,11 @@ func (pq *PortalQuery) GetAll() (portals []*Portal) {
}

func (pq *PortalQuery) GetByJID(key PortalKey) *Portal {
return pq.get("SELECT * FROM portal WHERE jid=? AND receiver=?", key.JID, key.Receiver)
return pq.get("SELECT * FROM portal WHERE jid=$1 AND receiver=$2", key.JID, key.Receiver)
}

func (pq *PortalQuery) GetByMXID(mxid types.MatrixRoomID) *Portal {
return pq.get("SELECT * FROM portal WHERE mxid=?", mxid)
return pq.get("SELECT * FROM portal WHERE mxid=$1", mxid)
}

func (pq *PortalQuery) get(query string, args ...interface{}) *Portal {
Expand Down Expand Up @@ -143,7 +142,7 @@ func (portal *Portal) mxidPtr() *string {
}

func (portal *Portal) Insert() {
_, err := portal.db.Exec("INSERT INTO portal VALUES (?, ?, ?, ?, ?, ?)",
_, err := portal.db.Exec("INSERT INTO portal VALUES ($1, $2, $3, $4, $5, $6)",
portal.Key.JID, portal.Key.Receiver, portal.mxidPtr(), portal.Name, portal.Topic, portal.Avatar)
if err != nil {
portal.log.Warnfln("Failed to insert %s: %v", portal.Key, err)
Expand All @@ -155,7 +154,7 @@ func (portal *Portal) Update() {
if len(portal.MXID) > 0 {
mxid = &portal.MXID
}
_, err := portal.db.Exec("UPDATE portal SET mxid=?, name=?, topic=?, avatar=? WHERE jid=? AND receiver=?",
_, err := portal.db.Exec("UPDATE portal SET mxid=$1, name=$2, topic=$3, avatar=$4 WHERE jid=$5 AND receiver=$6",
mxid, portal.Name, portal.Topic, portal.Avatar, portal.Key.JID, portal.Key.Receiver)
if err != nil {
portal.log.Warnfln("Failed to update %s: %v", portal.Key, err)
Expand Down
12 changes: 6 additions & 6 deletions database/puppet.go
Expand Up @@ -29,12 +29,12 @@ type PuppetQuery struct {
log log.Logger
}

func (pq *PuppetQuery) CreateTable() error {
func (pq *PuppetQuery) CreateTable(dbType string) error {
_, err := pq.db.Exec(`CREATE TABLE IF NOT EXISTS puppet (
jid VARCHAR(25) PRIMARY KEY,
jid VARCHAR(255) PRIMARY KEY,
avatar VARCHAR(255),
displayname VARCHAR(255),
name_quality TINYINT
name_quality SMALLINT
)`)
return err
}
Expand All @@ -59,7 +59,7 @@ func (pq *PuppetQuery) GetAll() (puppets []*Puppet) {
}

func (pq *PuppetQuery) Get(jid types.WhatsAppID) *Puppet {
row := pq.db.QueryRow("SELECT * FROM puppet WHERE jid=?", jid)
row := pq.db.QueryRow("SELECT * FROM puppet WHERE jid=$1", jid)
if row == nil {
return nil
}
Expand Down Expand Up @@ -93,15 +93,15 @@ func (puppet *Puppet) Scan(row Scannable) *Puppet {
}

func (puppet *Puppet) Insert() {
_, err := puppet.db.Exec("INSERT INTO puppet VALUES (?, ?, ?, ?)",
_, err := puppet.db.Exec("INSERT INTO puppet VALUES ($1, $2, $3, $4)",
puppet.JID, puppet.Avatar, puppet.Displayname, puppet.NameQuality)
if err != nil {
puppet.log.Warnfln("Failed to insert %s: %v", puppet.JID, err)
}
}

func (puppet *Puppet) Update() {
_, err := puppet.db.Exec("UPDATE puppet SET displayname=?, name_quality=?, avatar=? WHERE jid=?",
_, err := puppet.db.Exec("UPDATE puppet SET displayname=$1, name_quality=$2, avatar=$3 WHERE jid=$4",
puppet.Displayname, puppet.NameQuality, puppet.Avatar, puppet.JID)
if err != nil {
puppet.log.Warnfln("Failed to update %s->%s: %v", puppet.JID, err)
Expand Down
54 changes: 35 additions & 19 deletions database/user.go
Expand Up @@ -33,20 +33,36 @@ type UserQuery struct {
log log.Logger
}

func (uq *UserQuery) CreateTable() error {
_, err := uq.db.Exec(`CREATE TABLE IF NOT EXISTS user (
mxid VARCHAR(255) PRIMARY KEY,
jid VARCHAR(25) UNIQUE,

management_room VARCHAR(255),

client_id VARCHAR(255),
client_token VARCHAR(255),
server_token VARCHAR(255),
enc_key BLOB,
mac_key BLOB
)`)
return err
func (uq *UserQuery) CreateTable(dbType string) error {
if strings.ToLower(dbType) == "postgres" {
_, err := uq.db.Exec(`CREATE TABLE IF NOT EXISTS "user" (
mxid VARCHAR(255) PRIMARY KEY,
jid VARCHAR(255) UNIQUE,

management_room VARCHAR(255),

client_id VARCHAR(255),
client_token VARCHAR(255),
server_token VARCHAR(255),
enc_key bytea,
mac_key bytea
)`)
return err
} else {
_, err := uq.db.Exec(`CREATE TABLE IF NOT EXISTS "user" (
mxid VARCHAR(255) PRIMARY KEY,
jid VARCHAR(255) UNIQUE,

management_room VARCHAR(255),

client_id VARCHAR(255),
client_token VARCHAR(255),
server_token VARCHAR(255),
enc_key BLOB,
mac_key BLOB
)`)
return err
}
}

func (uq *UserQuery) New() *User {
Expand All @@ -57,7 +73,7 @@ func (uq *UserQuery) New() *User {
}

func (uq *UserQuery) GetAll() (users []*User) {
rows, err := uq.db.Query("SELECT * FROM user")
rows, err := uq.db.Query(`SELECT * FROM "user"`)
if err != nil || rows == nil {
return nil
}
Expand All @@ -69,15 +85,15 @@ func (uq *UserQuery) GetAll() (users []*User) {
}

func (uq *UserQuery) GetByMXID(userID types.MatrixUserID) *User {
row := uq.db.QueryRow("SELECT * FROM user WHERE mxid=?", userID)
row := uq.db.QueryRow(`SELECT * FROM "user" WHERE mxid=$1`, userID)
if row == nil {
return nil
}
return uq.New().Scan(row)
}

func (uq *UserQuery) GetByJID(userID types.WhatsAppID) *User {
row := uq.db.QueryRow("SELECT * FROM user WHERE jid=?", stripSuffix(userID))
row := uq.db.QueryRow(`SELECT * FROM "user" WHERE jid=$1`, stripSuffix(userID))
if row == nil {
return nil
}
Expand Down Expand Up @@ -150,7 +166,7 @@ func (user *User) sessionUnptr() (sess whatsapp.Session) {

func (user *User) Insert() {
sess := user.sessionUnptr()
_, err := user.db.Exec("INSERT INTO user VALUES (?, ?, ?, ?, ?, ?, ?, ?)", user.MXID, user.jidPtr(),
_, err := user.db.Exec(`INSERT INTO "user" VALUES ($1, $2, $3, $4, $5, $6, $7, $8)`, user.MXID, user.jidPtr(),
user.ManagementRoom,
sess.ClientId, sess.ClientToken, sess.ServerToken, sess.EncKey, sess.MacKey)
if err != nil {
Expand All @@ -160,7 +176,7 @@ func (user *User) Insert() {

func (user *User) Update() {
sess := user.sessionUnptr()
_, err := user.db.Exec("UPDATE user SET jid=?, management_room=?, client_id=?, client_token=?, server_token=?, enc_key=?, mac_key=? WHERE mxid=?",
_, err := user.db.Exec(`UPDATE "user" SET jid=$1, management_room=$2, client_id=$3, client_token=$4, server_token=$5, enc_key=$6, mac_key=$7 WHERE mxid=$8`,
user.jidPtr(), user.ManagementRoom,
sess.ClientId, sess.ClientToken, sess.ServerToken, sess.EncKey, sess.MacKey,
user.MXID)
Expand Down
1 change: 1 addition & 0 deletions example-config.yaml
Expand Up @@ -20,6 +20,7 @@ appservice:
# The database type. Only "sqlite3" is supported.
type: sqlite3
# The database URI. Usually file name. https://github.com/mattn/go-sqlite3#connection-string
# postres example: postgres://synapse:changeme@db/whatsapp?sslmode=disable
uri: mautrix-whatsapp.db
# Path to the Matrix room state store.
state_store_path: ./mx-state.json
Expand Down
5 changes: 3 additions & 2 deletions main.go
Expand Up @@ -133,7 +133,7 @@ func (bridge *Bridge) Init() {
bridge.AS.StateStore = bridge.StateStore

bridge.Log.Debugln("Initializing database")
bridge.DB, err = database.New(bridge.Config.AppService.Database.URI)
bridge.DB, err = database.New(bridge.Config.AppService.Database.Type, bridge.Config.AppService.Database.URI)
if err != nil {
bridge.Log.Fatalln("Failed to initialize database:", err)
os.Exit(14)
Expand All @@ -147,7 +147,7 @@ func (bridge *Bridge) Init() {
}

func (bridge *Bridge) Start() {
err := bridge.DB.CreateTables()
err := bridge.DB.CreateTables(bridge.Config.AppService.Database.Type)
if err != nil {
bridge.Log.Fatalln("Failed to create database tables:", err)
os.Exit(15)
Expand Down Expand Up @@ -185,6 +185,7 @@ func (bridge *Bridge) UpdateBotProfile() {
}

func (bridge *Bridge) StartUsers() {
bridge.Log.Debugln("Starting users")
for _, user := range bridge.GetAllUsers() {
go user.Connect(false)
}
Expand Down