Permalink
Switch branches/tags
Nothing to show
Find file Copy path
Fetching contributors…
Cannot retrieve contributors at this time
937 lines (818 sloc) 21 KB
package main
import (
crand "crypto/rand"
"crypto/sha1"
"database/sql"
"encoding/binary"
"encoding/json"
"fmt"
"html/template"
"io"
"io/ioutil"
"log"
"math/rand"
"net"
"net/http"
"os"
"strconv"
"strings"
"time"
"github.com/go-redis/redis"
"github.com/go-sql-driver/mysql"
"github.com/gorilla/sessions"
"github.com/jmoiron/sqlx"
"github.com/labstack/echo"
"github.com/labstack/echo-contrib/session"
"github.com/labstack/echo/middleware"
)
const (
// avatarMaxBytes is 1 MiB (1 * 1024 * 1024 bytes); postProfile rejects an
// uploaded avatar whose data is longer than this.
avatarMaxBytes = 1 * 1024 * 1024
)
var (
// db is the shared MySQL handle used by every query in this file; it is
// assigned in init.
db *sqlx.DB
// ErrBadReqeust is the HTTP 400 error handlers return for invalid input.
// NOTE(review): the identifier is misspelled ("Reqeust"); it is left as-is
// because handlers throughout this file reference it by this name.
ErrBadReqeust = echo.NewHTTPError(http.StatusBadRequest)
// redisClients holds the Redis clients appended in init; getRedisClient
// picks one of them for a given key.
redisClients []*redis.Client
)
// getRedisClient returns the Redis client that owns key.
//
// Keys are folded into four buckets; bucket 3 shares a client with bucket 2,
// so only redisClients[0..2] are ever selected. The mapping for non-negative
// keys is unchanged from the original implementation, which matters because
// existing data is already distributed according to it.
func getRedisClient(key int64) *redis.Client {
	k := key % 4
	if k < 0 {
		// Go's % keeps the sign of the dividend. Without this adjustment a
		// negative key (for example a negative channel_id submitted by a
		// client) would index redisClients out of range and panic.
		k += 4
	}
	if k == 3 {
		k = 2
	}
	return redisClients[k]
}
// iconKey reduces digest to a bucket number in the range 0..3 by adding up
// the code points of its characters modulo 4. The result is used as the key
// passed to getRedisClient when storing or fetching icon data.
func iconKey(digest string) int64 {
	const buckets = 4

	var bucket int64
	for _, r := range digest {
		bucket += int64(r)
		bucket %= buckets
	}
	return bucket
}
// Renderer wraps a parsed html/template set so it can be installed as the
// echo instance's Renderer.
type Renderer struct {
	templates *template.Template
}

// Render executes the template called name with data and writes the result to
// w. The echo context argument is not used.
func (r *Renderer) Render(w io.Writer, name string, data interface{}, _ echo.Context) error {
	err := r.templates.ExecuteTemplate(w, name, data)
	return err
}
// getenvDefault returns the value of the environment variable name, or def
// when the variable is unset or empty.
func getenvDefault(name, def string) string {
	if v := os.Getenv(name); v != "" {
		return v
	}
	return def
}

// init seeds math/rand, connects to MySQL and creates the three Redis
// clients, then blocks until every backend responds.
//
// Configuration is read from the environment:
//   - ISUBATA_DB_HOST (default "127.0.0.1"), ISUBATA_DB_PORT (default "3306"),
//     ISUBATA_DB_USER (default "root"), ISUBATA_DB_PASSWORD (default empty)
//   - ISUBATA_REDIS_ADDR0, ISUBATA_REDIS_ADDR1, ISUBATA_REDIS_ADDR2
//     (each defaults to "localhost:6379")
//
// Connection attempts are retried every three seconds without limit.
func init() {
	seedBuf := make([]byte, 8)
	if _, err := crand.Read(seedBuf); err != nil {
		// Fall back to a time-based seed rather than seeding with zeros.
		log.Printf("reading random seed: %v", err)
		binary.LittleEndian.PutUint64(seedBuf, uint64(time.Now().UnixNano()))
	}
	rand.Seed(int64(binary.LittleEndian.Uint64(seedBuf)))

	dbHost := getenvDefault("ISUBATA_DB_HOST", "127.0.0.1")
	dbPort := getenvDefault("ISUBATA_DB_PORT", "3306")
	dbUser := getenvDefault("ISUBATA_DB_USER", "root")
	dbPassword := os.Getenv("ISUBATA_DB_PASSWORD")
	if dbPassword != "" {
		dbPassword = ":" + dbPassword
	}

	dsn := fmt.Sprintf("%s%s@tcp(%s:%s)/isubata?parseTime=true&loc=Local&charset=utf8mb4",
		dbUser, dbPassword, dbHost, dbPort)
	// Log the target without the password so credentials do not end up in logs.
	log.Printf("Connecting to db: %s@tcp(%s:%s)/isubata", dbUser, dbHost, dbPort)

	// sqlx.Connect opens the handle and pings it, returning a nil *sqlx.DB on
	// failure. Retrying the whole call keeps db from ever being used while nil.
	for {
		conn, err := sqlx.Connect("mysql", dsn)
		if err == nil {
			db = conn
			break
		}
		log.Println(err)
		time.Sleep(3 * time.Second)
	}
	db.SetMaxOpenConns(20)
	db.SetConnMaxLifetime(5 * time.Minute)
	log.Printf("Succeeded to connect db.")

	const redisInstances = 3
	for i := 0; i < redisInstances; i++ {
		addr := getenvDefault(fmt.Sprintf("ISUBATA_REDIS_ADDR%d", i), "localhost:6379")
		redisClients = append(redisClients, redis.NewClient(&redis.Options{
			Addr:     addr,
			Password: "",
			DB:       0,
		}))
	}

	for i, client := range redisClients {
		for {
			_, err := client.Ping().Result()
			if err == nil {
				break
			}
			log.Println(err)
			time.Sleep(3 * time.Second)
		}
		log.Printf("Succeeded to connect Redis %d.", i)
	}
}
// User is a row of the user table. Only Name, DisplayName and AvatarIcon are
// included when a User is encoded as JSON; the other fields are tagged "-".
type User struct {
ID int64 `json:"-" db:"id"`
// Name is the login name supplied at registration.
Name string `json:"name" db:"name"`
// Salt is the random string prepended to the password before hashing.
Salt string `json:"-" db:"salt"`
// Password holds the hex-encoded SHA-1 digest of Salt+password.
Password string `json:"-" db:"password"`
DisplayName string `json:"display_name" db:"display_name"`
// AvatarIcon is the icon file name; register stores "default.png".
AvatarIcon string `json:"avatar_icon" db:"avatar_icon"`
CreatedAt time.Time `json:"-" db:"created_at"`
}
// getUser loads the user whose id equals userID. When no such row exists it
// returns (nil, nil); any other database failure is returned as the error.
func getUser(userID int64) (*User, error) {
	var found User
	err := db.Get(&found, "SELECT * FROM user WHERE id = ?", userID)
	switch {
	case err == sql.ErrNoRows:
		return nil, nil
	case err != nil:
		return nil, err
	}
	return &found, nil
}
// addMessage stores a message from userID in channelID and bumps the
// channel's entry in the Redis "message_count" hash. It returns the ID of the
// inserted row.
//
// A failure to update the counter is logged but does not fail the call: the
// message is already persisted at that point, and reporting an error would
// invite the client to retry and post a duplicate.
func addMessage(channelID, userID int64, content string) (int64, error) {
	res, err := db.Exec(
		"INSERT INTO message (channel_id, user_id, content, created_at) VALUES (?, ?, ?, NOW())",
		channelID, userID, content)
	if err != nil {
		return 0, err
	}

	field := strconv.FormatInt(channelID, 10)
	if err := getRedisClient(channelID).HIncrBy("message_count", field, 1).Err(); err != nil {
		log.Printf("incrementing message_count for channel %d: %v", channelID, err)
	}
	return res.LastInsertId()
}
// Message is a row of the message table, as selected by queryMessages.
type Message struct {
ID int64 `db:"id"`
ChannelID int64 `db:"channel_id"`
UserID int64 `db:"user_id"`
Content string `db:"content"`
CreatedAt time.Time `db:"created_at"`
}
// MessageWithUser is the result row of the message/user join used by
// getMessage and getHistory: the message's id, content and created_at plus
// the author's name, display_name and avatar_icon columns.
type MessageWithUser struct {
ID int64 `db:"id"`
Content string `db:"content"`
CreatedAt time.Time `db:"created_at"`
UserName string `db:"name"`
UserDisplayName string `db:"display_name"`
UserAvatarIcon string `db:"avatar_icon"`
}
// queryMessages returns up to 100 messages of channel chanID whose id is
// greater than lastID, newest first. The slice is returned alongside any
// error from the query.
func queryMessages(chanID, lastID int64) ([]Message, error) {
	const query = "SELECT * FROM message WHERE id > ? AND channel_id = ? ORDER BY id DESC LIMIT 100"

	rows := []Message{}
	if err := db.Select(&rows, query, lastID, chanID); err != nil {
		return rows, err
	}
	return rows, nil
}
// sessUserID returns the user ID stored under "user_id" in the request's
// session, or 0 when the value is absent or is not an int64.
func sessUserID(c echo.Context) int64 {
	sess, _ := session.Get("session", c)
	raw, present := sess.Values["user_id"]
	if !present {
		return 0
	}
	id, _ := raw.(int64)
	return id
}
// sessSetUserID stores id under "user_id" in the request's session, replaces
// the session options with an HttpOnly cookie of MaxAge 360000 seconds, and
// saves the session to the response.
func sessSetUserID(c echo.Context, id int64) {
	sess, _ := session.Get("session", c)

	opts := sessions.Options{
		MaxAge:   360000,
		HttpOnly: true,
	}
	sess.Options = &opts
	sess.Values["user_id"] = id

	req, res := c.Request(), c.Response()
	sess.Save(req, res)
}
// ensureLogin resolves the logged-in user for the request.
//
// It returns (user, nil) when the session refers to an existing user. When
// the session has no user ID, or the ID no longer matches a row, it sends a
// redirect to /login and returns a nil user together with whatever error
// sending the redirect produced. A database failure is returned as
// (nil, err). Callers therefore treat a nil user as "response already
// handled" and return the accompanying error.
func ensureLogin(c echo.Context) (*User, error) {
	userID := sessUserID(c)
	if userID == 0 {
		return nil, c.Redirect(http.StatusSeeOther, "/login")
	}

	user, err := getUser(userID)
	if err != nil {
		return nil, err
	}
	if user != nil {
		return user, nil
	}

	// The session points at a user that no longer exists: drop the stale ID so
	// the next request starts clean, then send the client to the login page.
	sess, _ := session.Get("session", c)
	delete(sess.Values, "user_id")
	if err := sess.Save(c.Request(), c.Response()); err != nil {
		// Still redirect; failing to rewrite the cookie must not block login.
		log.Printf("clearing stale session: %v", err)
	}
	return nil, c.Redirect(http.StatusSeeOther, "/login")
}
// LettersAndDigits is the alphabet randomString draws from: a-z, A-Z, 0-9.
const LettersAndDigits = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"

// randomString returns n characters picked from LettersAndDigits using the
// package-level math/rand source.
func randomString(n int) string {
	buf := make([]byte, n)
	for pos := range buf {
		pick := rand.Intn(len(LettersAndDigits))
		buf[pos] = LettersAndDigits[pick]
	}
	return string(buf)
}
// register inserts a new user called name and returns the new row's ID. The
// stored password is the hex SHA-1 digest of a fresh 20-character salt
// followed by password; display_name starts out equal to name and the avatar
// is "default.png".
func register(name, password string) (int64, error) {
	const insertUser = "INSERT INTO user (name, salt, password, display_name, avatar_icon, created_at)" +
		" VALUES (?, ?, ?, ?, ?, NOW())"

	salt := randomString(20)
	sum := sha1.Sum([]byte(salt + password))
	digest := fmt.Sprintf("%x", sum)

	result, err := db.Exec(insertUser, name, salt, digest, name, "default.png")
	if err != nil {
		return 0, err
	}
	return result.LastInsertId()
}
// request handlers
// getInitialize resets the application to its initial data set.
//
// It deletes rows above the seed ID thresholds from user, image, channel and
// message, empties haveread, removes every "haveread/*" key and the
// "message_count" hash from each Redis client, and then rebuilds
// "message_count" from the remaining messages. It responds with 204 on
// success; any database or Redis failure is returned as the handler error.
func getInitialize(c echo.Context) error {
	resetQueries := []string{
		"DELETE FROM user WHERE id > 1000",
		"DELETE FROM image WHERE id > 1001",
		"DELETE FROM channel WHERE id > 10",
		"DELETE FROM message WHERE id > 10000",
		"DELETE FROM haveread",
	}
	for _, q := range resetQueries {
		if _, err := db.Exec(q); err != nil {
			return err
		}
	}

	for _, client := range redisClients {
		keys, err := client.Keys("haveread/*").Result()
		if err != nil {
			return err
		}
		// DEL requires at least one key; skip it when nothing matched.
		if len(keys) > 0 {
			if err := client.Del(keys...).Err(); err != nil {
				return err
			}
		}
		if err := client.Del("message_count").Err(); err != nil {
			return err
		}
	}

	counts := []struct {
		Count     int   `db:"cnt"`
		ChannelID int64 `db:"channel_id"`
	}{}
	err := db.Select(&counts, "SELECT COUNT(*) as cnt, channel_id FROM message group by channel_id")
	if err != nil {
		return err
	}
	for _, row := range counts {
		field := strconv.FormatInt(row.ChannelID, 10)
		if err := getRedisClient(row.ChannelID).HSet("message_count", field, row.Count).Err(); err != nil {
			return err
		}
	}
	return c.String(http.StatusNoContent, "")
}
// getInitializeRedis reloads every icon from the MySQL image table into
// Redis.
//
// All Redis databases in redisClients are flushed first, then each image is
// stored under "icons/<name>" on the client chosen by iconKey(name). It
// responds with 204 on success.
func getInitializeRedis(c echo.Context) error {
	type Image struct {
		Name string `db:"name"`
		Data []byte `db:"data"`
	}

	imgs := []Image{}
	log.Println("Start loading images from MySQL")
	if err := db.Select(&imgs, "select name, data from image"); err != nil {
		return err
	}
	log.Printf("Loaded %d images", len(imgs))

	for _, client := range redisClients {
		if err := client.FlushDB().Err(); err != nil {
			return err
		}
	}
	for _, img := range imgs {
		key := fmt.Sprintf("icons/%s", img.Name)
		if err := getRedisClient(iconKey(img.Name)).Set(key, img.Data, 0).Err(); err != nil {
			return err
		}
	}
	// The original format string had no verb for its argument, so the log
	// line came out as "...Redis%!(EXTRA int=N)".
	log.Printf("Saved %d images into Redis", len(imgs))
	return c.String(http.StatusNoContent, "")
}
// getIndex redirects a logged-in visitor to /channel/1 and renders the
// "index" template for everyone else.
func getIndex(c echo.Context) error {
	if sessUserID(c) != 0 {
		return c.Redirect(http.StatusSeeOther, "/channel/1")
	}

	data := map[string]interface{}{
		"ChannelID": nil,
	}
	return c.Render(http.StatusOK, "index", data)
}
// ChannelInfo is a row of the channel table, as loaded by the handlers that
// render the channel list.
type ChannelInfo struct {
ID int64 `db:"id"`
Name string `db:"name"`
Description string `db:"description"`
UpdatedAt time.Time `db:"updated_at"`
CreatedAt time.Time `db:"created_at"`
}
// getChannel renders the "channel" page for the channel named by the
// :channel_id path parameter.
//
// It requires a logged-in user (ensureLogin redirects otherwise). A
// channel_id that is not an integer is answered with 400, matching
// getHistory, instead of leaking the strconv error as a 500. The description
// passed to the template is empty when no channel has the requested ID.
func getChannel(c echo.Context) error {
	user, err := ensureLogin(c)
	if user == nil {
		return err
	}

	cID, err := strconv.Atoi(c.Param("channel_id"))
	if err != nil {
		return ErrBadReqeust
	}

	channels := []ChannelInfo{}
	if err := db.Select(&channels, "SELECT * FROM channel ORDER BY id"); err != nil {
		return err
	}

	var desc string
	for _, ch := range channels {
		if ch.ID == int64(cID) {
			desc = ch.Description
			break
		}
	}

	return c.Render(http.StatusOK, "channel", map[string]interface{}{
		"ChannelID":   cID,
		"Channels":    channels,
		"User":        user,
		"Description": desc,
	})
}
// getRegister renders the "register" template with no current channel, an
// empty channel list and no user.
func getRegister(c echo.Context) error {
	data := map[string]interface{}{
		"ChannelID": 0,
		"Channels":  []ChannelInfo{},
		"User":      nil,
	}
	return c.Render(http.StatusOK, "register", data)
}
// postRegister creates an account from the "name" and "password" form values,
// logs the new user in and redirects to "/". Missing values yield 400; a
// MySQL duplicate-entry error yields 409.
func postRegister(c echo.Context) error {
	name := c.FormValue("name")
	pw := c.FormValue("password")
	if len(name) == 0 || len(pw) == 0 {
		return ErrBadReqeust
	}

	// MySQL error number for "Duplicate entry xxxx for key zzzz".
	const duplicateEntry = 1062

	userID, err := register(name, pw)
	if err != nil {
		if merr, ok := err.(*mysql.MySQLError); ok && merr.Number == duplicateEntry {
			return c.NoContent(http.StatusConflict)
		}
		return err
	}

	sessSetUserID(c, userID)
	return c.Redirect(http.StatusSeeOther, "/")
}
// getLogin renders the "login" template with no current channel, an empty
// channel list and no user.
func getLogin(c echo.Context) error {
	data := map[string]interface{}{
		"ChannelID": 0,
		"Channels":  []ChannelInfo{},
		"User":      nil,
	}
	return c.Render(http.StatusOK, "login", data)
}
// postLogin authenticates the "name"/"password" form values. Missing values
// yield 400; an unknown name or a digest mismatch yields 403. On success the
// user ID is stored in the session and the client is redirected to "/".
func postLogin(c echo.Context) error {
	name := c.FormValue("name")
	pw := c.FormValue("password")
	if len(name) == 0 || len(pw) == 0 {
		return ErrBadReqeust
	}

	var user User
	switch err := db.Get(&user, "SELECT * FROM user WHERE name = ?", name); {
	case err == sql.ErrNoRows:
		return echo.ErrForbidden
	case err != nil:
		return err
	}

	// Stored passwords are hex SHA-1 digests of salt followed by password.
	sum := sha1.Sum([]byte(user.Salt + pw))
	if fmt.Sprintf("%x", sum) != user.Password {
		return echo.ErrForbidden
	}

	sessSetUserID(c, user.ID)
	return c.Redirect(http.StatusSeeOther, "/")
}
// getLogout removes "user_id" from the session, saves the session and
// redirects to "/".
func getLogout(c echo.Context) error {
	sess, _ := session.Get("session", c)
	delete(sess.Values, "user_id")

	req, res := c.Request(), c.Response()
	sess.Save(req, res)

	return c.Redirect(http.StatusSeeOther, "/")
}
// postMessage stores the "message" form value in the channel given by the
// "channel_id" form value on behalf of the logged-in user.
//
// Anonymous requests are redirected to /login. An empty message or a
// channel_id that is not an integer yields 403 (the status the original
// handler used for these cases). On success it responds with 204.
func postMessage(c echo.Context) error {
	userID := sessUserID(c)
	if userID == 0 {
		return c.Redirect(http.StatusSeeOther, "/login")
	}

	message := c.FormValue("message")
	if message == "" {
		return echo.ErrForbidden
	}

	// Parse straight to int64 so the accepted range matches the ID type and
	// does not depend on the platform's int width.
	chanID, err := strconv.ParseInt(c.FormValue("channel_id"), 10, 64)
	if err != nil {
		return echo.ErrForbidden
	}

	if _, err := addMessage(chanID, userID, message); err != nil {
		return err
	}
	return c.NoContent(http.StatusNoContent)
}
// getMessage returns, as JSON, up to 100 messages of the channel given by the
// channel_id query parameter whose ID is greater than last_message_id, in
// ascending ID order.
//
// Anonymous requests get 403. Query parameters that are not integers get 400
// rather than surfacing the strconv error as a 500. When at least one message
// is returned, the newest message ID is recorded in Redis under
// "haveread/<userID>/<channelID>".
func getMessage(c echo.Context) error {
	userID := sessUserID(c)
	if userID == 0 {
		return c.NoContent(http.StatusForbidden)
	}

	chanID, err := strconv.ParseInt(c.QueryParam("channel_id"), 10, 64)
	if err != nil {
		return ErrBadReqeust
	}
	lastID, err := strconv.ParseInt(c.QueryParam("last_message_id"), 10, 64)
	if err != nil {
		return ErrBadReqeust
	}

	messages := []MessageWithUser{}
	err = db.Select(&messages, "select m.id, m.content, m.created_at, u.name, u.display_name, u.avatar_icon from message m inner join user u on u.id = m.user_id where m.id > ? and m.channel_id = ? order by m.id desc limit 100", lastID, chanID)
	if err != nil {
		return err
	}

	// The query yields newest-first; walk it backwards to emit oldest-first.
	response := make([]map[string]interface{}, 0, len(messages))
	for i := len(messages) - 1; i >= 0; i-- {
		m := messages[i]
		response = append(response, map[string]interface{}{
			"id": m.ID,
			"user": User{
				Name:        m.UserName,
				DisplayName: m.UserDisplayName,
				AvatarIcon:  m.UserAvatarIcon,
			},
			"date":    m.CreatedAt.Format("2006/01/02 15:04:05"),
			"content": m.Content,
		})
	}

	if len(messages) > 0 {
		// messages[0] is the newest row because of the descending order.
		key := fmt.Sprintf("haveread/%d/%d", userID, chanID)
		if err := getRedisClient(userID+chanID).Set(key, messages[0].ID, 0).Err(); err != nil {
			return err
		}
	}
	return streamJSON(c, response)
}
// streamJSON writes data as a JSON response with status 200, encoding it
// directly onto the response writer.
func streamJSON(c echo.Context, data interface{}) error {
	res := c.Response()
	res.Header().Set(echo.HeaderContentType, echo.MIMEApplicationJSONCharsetUTF8)
	res.WriteHeader(http.StatusOK)

	enc := json.NewEncoder(res)
	return enc.Encode(data)
}
// queryChannels returns the IDs of all rows in the channel table, together
// with any error from the query.
func queryChannels() ([]int64, error) {
	ids := []int64{}
	if err := db.Select(&ids, "SELECT id FROM channel"); err != nil {
		return ids, err
	}
	return ids, nil
}
// queryHaveReads returns, for each channel in chIDs, the ID of the last
// message userID has read in that channel, or 0 when no marker exists. The
// result has the same length and order as chIDs.
//
// Markers live in Redis under "haveread/<userID>/<channelID>". Every client
// in redisClients is asked for every key and the non-empty answers are
// merged. An empty chIDs returns an empty result without contacting Redis,
// because MGET with no keys is rejected by the server.
func queryHaveReads(userID int64, chIDs []int64) ([]int64, error) {
	messageIDs := make([]int64, len(chIDs))
	if len(chIDs) == 0 {
		return messageIDs, nil
	}

	keys := make([]string, len(chIDs))
	for i, chID := range chIDs {
		keys[i] = fmt.Sprintf("haveread/%d/%d", userID, chID)
	}

	merged := make([]string, len(chIDs))
	for _, client := range redisClients {
		results, err := client.MGet(keys...).Result()
		if err != nil {
			return []int64{}, err
		}
		for j, result := range results {
			// Missing keys come back as nil; anything that is not a string is
			// treated the same way instead of panicking on the assertion.
			if s, ok := result.(string); ok {
				merged[j] = s
			}
		}
	}

	for i, raw := range merged {
		if raw == "" {
			continue // no marker: leave the zero value
		}
		id, err := strconv.ParseInt(raw, 10, 64)
		if err != nil {
			return []int64{}, err
		}
		messageIDs[i] = id
	}
	return messageIDs, nil
}
// fetchUnread reports, as a JSON array of {"channel_id", "unread"} objects,
// how many unread messages the logged-in user has in every channel.
//
// Anonymous requests get 403. The handler sleeps for one second before doing
// any work. For a channel with a read marker the count is the number of
// messages with a larger ID (queried from MySQL); for a channel without one
// it is the channel's entry in the Redis "message_count" hash, or 0 when that
// entry is missing.
func fetchUnread(c echo.Context) error {
	userID := sessUserID(c)
	if userID == 0 {
		return c.NoContent(http.StatusForbidden)
	}

	time.Sleep(time.Second)

	channels, err := queryChannels()
	if err != nil {
		return err
	}
	lastIDs, err := queryHaveReads(userID, channels)
	if err != nil {
		return err
	}

	// Channels the user has never read take their count from Redis.
	var (
		zeroChannels    []int64
		zeroChannelsStr []string
	)
	for i, chID := range channels {
		if lastIDs[i] == 0 {
			zeroChannels = append(zeroChannels, chID)
			zeroChannelsStr = append(zeroChannelsStr, strconv.FormatInt(chID, 10))
		}
	}

	messageCounts := make(map[int64]int64, len(zeroChannels))
	// HMGET needs at least one field. Without this guard a user who has a
	// marker for every channel would make Redis reject the command and the
	// whole request would fail.
	if len(zeroChannels) > 0 {
		merged := make([]string, len(zeroChannels))
		for _, client := range redisClients {
			results, err := client.HMGet("message_count", zeroChannelsStr...).Result()
			if err != nil {
				return err
			}
			for j, result := range results {
				if s, ok := result.(string); ok {
					merged[j] = s
				}
			}
		}
		for i, raw := range merged {
			if raw == "" {
				messageCounts[zeroChannels[i]] = 0
				continue
			}
			n, err := strconv.ParseInt(raw, 10, 64)
			if err != nil {
				return err
			}
			messageCounts[zeroChannels[i]] = n
		}
	}

	resp := make([]map[string]interface{}, 0, len(channels))
	for i, chID := range channels {
		var cnt int64
		if lastID := lastIDs[i]; lastID > 0 {
			err := db.Get(&cnt,
				"SELECT COUNT(*) as cnt FROM message WHERE channel_id = ? AND ? < id",
				chID, lastID)
			if err != nil {
				return err
			}
		} else {
			// A missing map entry reads as 0, the same value the original
			// assigned explicitly.
			cnt = messageCounts[chID]
		}
		resp = append(resp, map[string]interface{}{
			"channel_id": chID,
			"unread":     cnt,
		})
	}
	return streamJSON(c, resp)
}
// getHistory renders the "history" page: one page of 20 messages from the
// channel named by the :channel_id path parameter, oldest first within the
// page.
//
// A non-positive or non-integer channel_id, a page below 1, a non-integer
// page, or a page beyond the last one yields 400. The page count is derived
// from the channel's entry in the Redis "message_count" hash; a missing entry
// counts as zero messages, while any other Redis failure is returned instead
// of being silently treated as an empty channel.
func getHistory(c echo.Context) error {
	chID, err := strconv.ParseInt(c.Param("channel_id"), 10, 64)
	if err != nil || chID <= 0 {
		return ErrBadReqeust
	}

	user, err := ensureLogin(c)
	if user == nil {
		return err
	}

	page := int64(1)
	if pageStr := c.QueryParam("page"); pageStr != "" {
		page, err = strconv.ParseInt(pageStr, 10, 64)
		if err != nil || page < 1 {
			return ErrBadReqeust
		}
	}

	const N = 20
	cnt, err := getRedisClient(chID).HGet("message_count", strconv.FormatInt(chID, 10)).Int64()
	if err == redis.Nil {
		cnt = 0
	} else if err != nil {
		return err
	}

	maxPage := int64(cnt+N-1) / N
	if maxPage == 0 {
		maxPage = 1
	}
	if page > maxPage {
		return ErrBadReqeust
	}

	messages := []MessageWithUser{}
	err = db.Select(&messages,
		"select m.id, m.content, m.created_at, u.name, u.display_name, u.avatar_icon from message m inner join user u on u.id = m.user_id where m.channel_id = ? order by m.id desc limit ? offset ?",
		chID, N, (page-1)*N)
	if err != nil {
		return err
	}

	// The query yields newest-first; walk it backwards to render oldest-first.
	mjson := make([]map[string]interface{}, 0, len(messages))
	for i := len(messages) - 1; i >= 0; i-- {
		m := messages[i]
		mjson = append(mjson, map[string]interface{}{
			"id": m.ID,
			"user": User{
				Name:        m.UserName,
				DisplayName: m.UserDisplayName,
				AvatarIcon:  m.UserAvatarIcon,
			},
			"date":    m.CreatedAt.Format("2006/01/02 15:04:05"),
			"content": m.Content,
		})
	}

	channels := []ChannelInfo{}
	if err := db.Select(&channels, "SELECT * FROM channel ORDER BY id"); err != nil {
		return err
	}

	return c.Render(http.StatusOK, "history", map[string]interface{}{
		"ChannelID": chID,
		"Channels":  channels,
		"Messages":  mjson,
		"MaxPage":   maxPage,
		"Page":      page,
		"User":      user,
	})
}
// getProfile renders the "profile" page of the user named by the :user_name
// path parameter. It requires a logged-in user and yields 404 when no user
// has that name. "SelfProfile" tells the template whether the viewer is
// looking at their own profile.
func getProfile(c echo.Context) error {
	self, err := ensureLogin(c)
	if self == nil {
		return err
	}

	channels := []ChannelInfo{}
	if err = db.Select(&channels, "SELECT * FROM channel ORDER BY id"); err != nil {
		return err
	}

	var other User
	userName := c.Param("user_name")
	switch err = db.Get(&other, "SELECT * FROM user WHERE name = ?", userName); {
	case err == sql.ErrNoRows:
		return echo.ErrNotFound
	case err != nil:
		return err
	}

	data := map[string]interface{}{
		"ChannelID":   0,
		"Channels":    channels,
		"User":        self,
		"Other":       other,
		"SelfProfile": self.ID == other.ID,
	}
	return c.Render(http.StatusOK, "profile", data)
}
// getAddChannel renders the "add_channel" form for a logged-in user, passing
// the full channel list to the template.
func getAddChannel(c echo.Context) error {
	self, err := ensureLogin(c)
	if self == nil {
		return err
	}

	channels := []ChannelInfo{}
	if err = db.Select(&channels, "SELECT * FROM channel ORDER BY id"); err != nil {
		return err
	}

	data := map[string]interface{}{
		"ChannelID": 0,
		"Channels":  channels,
		"User":      self,
	}
	return c.Render(http.StatusOK, "add_channel", data)
}
// postAddChannel creates a channel from the "name" and "description" form
// values and redirects to the new channel's page.
//
// Anonymous requests are redirected to /login; an empty name or description
// yields 400. Errors from the insert, from reading the new row ID, and from
// sending the redirect are all returned to the caller.
func postAddChannel(c echo.Context) error {
	userID := sessUserID(c)
	if userID == 0 {
		return c.Redirect(http.StatusSeeOther, "/login")
	}

	name := c.FormValue("name")
	desc := c.FormValue("description")
	if name == "" || desc == "" {
		return ErrBadReqeust
	}

	res, err := db.Exec(
		"INSERT INTO channel (name, description, updated_at, created_at) VALUES (?, ?, NOW(), NOW())",
		name, desc)
	if err != nil {
		return err
	}
	// Without this check a failure here would redirect to /channel/0.
	lastID, err := res.LastInsertId()
	if err != nil {
		return err
	}
	return c.Redirect(http.StatusSeeOther, fmt.Sprintf("/channel/%v", lastID))
}
// postProfile updates the logged-in user's avatar and/or display name from a
// multipart form and redirects to "/".
//
// The optional "avatar_icon" file must have a .jpg, .jpeg, .png or .gif
// extension and be at most avatarMaxBytes long, otherwise the request yields
// 400. An accepted, non-empty image is stored in Redis under
// "icons/<sha1><ext>" and the user's avatar_icon column is pointed at it. A
// non-empty "display_name" form value replaces the user's display name.
// Anonymous requests are redirected to /login.
func postProfile(c echo.Context) error {
	userID := sessUserID(c)
	if userID == 0 {
		return c.Redirect(http.StatusSeeOther, "/login")
	}

	avatarName := ""
	var avatarData []byte

	fh, err := c.FormFile("avatar_icon")
	if err != nil && err != http.ErrMissingFile {
		return err
	}
	if err == nil { // a file was uploaded
		dotPos := strings.LastIndexByte(fh.Filename, '.')
		if dotPos < 0 {
			return ErrBadReqeust
		}
		ext := fh.Filename[dotPos:]
		switch ext {
		case ".jpg", ".jpeg", ".png", ".gif":
		default:
			return ErrBadReqeust
		}

		file, err := fh.Open()
		if err != nil {
			return err
		}
		defer file.Close()

		// Read one byte past the limit at most: enough to detect an oversized
		// upload without buffering all of it.
		avatarData, err = ioutil.ReadAll(io.LimitReader(file, avatarMaxBytes+1))
		if err != nil {
			return err
		}
		if len(avatarData) > avatarMaxBytes {
			return ErrBadReqeust
		}
		avatarName = fmt.Sprintf("%x%s", sha1.Sum(avatarData), ext)
	}

	if avatarName != "" && len(avatarData) > 0 {
		key := fmt.Sprintf("icons/%s", avatarName)
		if err := getRedisClient(iconKey(avatarName)).Set(key, avatarData, 0).Err(); err != nil {
			return err
		}
		if _, err := db.Exec("UPDATE user SET avatar_icon = ? WHERE id = ?", avatarName, userID); err != nil {
			return err
		}
	}

	if name := c.FormValue("display_name"); name != "" {
		if _, err := db.Exec("UPDATE user SET display_name = ? WHERE id = ?", name, userID); err != nil {
			return err
		}
	}
	return c.Redirect(http.StatusSeeOther, "/")
}
// getIcon serves the icon named by the :file_name path parameter from Redis.
//
// Long-lived caching headers are always set, and any request carrying
// If-Modified-Since or If-None-Match is answered with 304 without touching
// Redis. Otherwise the data stored under "icons/<name>" is returned with a
// content type derived from the file extension; a missing key or an
// unsupported extension yields 404.
func getIcon(c echo.Context) error {
	name := c.Param("file_name")

	// The ETag is the name without its extension. Cutting at the last dot
	// (instead of blindly dropping four bytes) cannot slice out of range on
	// short names and also handles ".jpeg" correctly.
	etag := name
	if dot := strings.LastIndexByte(name, '.'); dot >= 0 {
		etag = name[:dot]
	}

	h := c.Response().Header()
	h.Set("Cache-Control", "public, max-age=31536000")
	h.Set("ETag", etag)
	h.Set("Last-Modified", "Mon, 16 Oct 2017 16:33:02 GMT")

	req := c.Request()
	if req.Header.Get("If-Modified-Since") != "" || req.Header.Get("If-None-Match") != "" {
		return c.NoContent(http.StatusNotModified)
	}

	data, err := getRedisClient(iconKey(name)).Get(fmt.Sprintf("icons/%s", name)).Bytes()
	if err == redis.Nil {
		return echo.ErrNotFound
	}
	if err != nil {
		return err
	}

	var mime string
	switch {
	case strings.HasSuffix(name, ".jpg"), strings.HasSuffix(name, ".jpeg"):
		mime = "image/jpeg"
	case strings.HasSuffix(name, ".png"):
		mime = "image/png"
	case strings.HasSuffix(name, ".gif"):
		mime = "image/gif"
	default:
		return echo.ErrNotFound
	}
	return c.Blob(http.StatusOK, mime, data)
}
// tAdd returns the sum of a and b. It is registered as the "add" template
// function in main.
func tAdd(a, b int64) int64 {
	sum := a + b
	return sum
}
// tRange returns the integers a, a+1, ..., b inclusive. It is registered as
// the "xrange" template function in main.
//
// When b is less than a the result is empty. The original passed b-a+1
// straight to make, which panics on a negative length whenever b < a-1.
func tRange(a, b int64) []int64 {
	if b < a {
		return []int64{}
	}
	r := make([]int64, 0, b-a+1)
	for v := a; v <= b; v++ {
		r = append(r, v)
	}
	return r
}
// main wires up the echo server: template renderer, cookie-backed sessions,
// static files from ../public and all routes. It listens on the Unix socket
// named by ISUBATA_UNIX_SOCKET when that variable is set, and on TCP port
// 5000 otherwise. A failure to start or serve is logged as fatal instead of
// being discarded.
func main() {
	e := echo.New()

	funcs := template.FuncMap{
		"add":    tAdd,
		"xrange": tRange,
	}
	e.Renderer = &Renderer{
		templates: template.Must(template.New("").Funcs(funcs).ParseGlob("views/*.html")),
	}

	// NOTE(review): the session signing key is hard-coded here; consider
	// loading it from configuration.
	e.Use(session.Middleware(sessions.NewCookieStore([]byte("secretonymoris"))))
	e.Use(middleware.Static("../public"))

	e.GET("/initialize", getInitialize)
	e.GET("/initialize_redis", getInitializeRedis)
	e.GET("/", getIndex)
	e.GET("/register", getRegister)
	e.POST("/register", postRegister)
	e.GET("/login", getLogin)
	e.POST("/login", postLogin)
	e.GET("/logout", getLogout)
	e.GET("/channel/:channel_id", getChannel)
	e.GET("/message", getMessage)
	e.POST("/message", postMessage)
	e.GET("/fetch", fetchUnread)
	e.GET("/history/:channel_id", getHistory)
	e.GET("/profile/:user_name", getProfile)
	e.POST("/profile", postProfile)
	// Written with the leading slash like every other route.
	e.GET("/add_channel", getAddChannel)
	e.POST("/add_channel", postAddChannel)
	e.GET("/icons/:file_name", getIcon)

	if path := os.Getenv("ISUBATA_UNIX_SOCKET"); path != "" {
		// Best effort: a stale socket file from a previous run would make
		// Listen fail; a missing file is not an error worth reporting.
		os.Remove(path)
		l, err := net.Listen("unix", path)
		if err != nil {
			e.Logger.Fatal(err)
		}
		if err := os.Chmod(path, 0777); err != nil {
			e.Logger.Fatal(err)
		}
		e.Listener = l
		e.Logger.Fatal(e.Start(""))
	} else {
		e.Logger.Fatal(e.Start(":5000"))
	}
}