Skip to content
Permalink
master
Switch branches/tags

Name already in use

A tag already exists with the provided branch name. Many Git commands accept both tag and branch names, so creating this branch may cause unexpected behavior. Are you sure you want to create this branch?
Go to file
 
 
Cannot retrieve contributors at this time
package main
import (
crand "crypto/rand"
"crypto/sha1"
"database/sql"
"encoding/binary"
"encoding/json"
"fmt"
"html/template"
"io"
"io/ioutil"
"log"
"math/rand"
"net"
"net/http"
"os"
"strconv"
"strings"
"time"
"github.com/go-redis/redis"
"github.com/go-sql-driver/mysql"
"github.com/gorilla/sessions"
"github.com/jmoiron/sqlx"
"github.com/labstack/echo"
"github.com/labstack/echo-contrib/session"
"github.com/labstack/echo/middleware"
)
// avatarMaxBytes is the largest avatar upload, in bytes, that postProfile accepts.
const (
avatarMaxBytes = 1 * 1024 * 1024
)
var (
// db is the MySQL handle opened in init and shared by all handlers.
db *sqlx.DB
// ErrBadReqeust is the 400 Bad Request error returned by handlers for invalid
// input. The misspelling is kept because the name is referenced throughout the file.
ErrBadReqeust = echo.NewHTTPError(http.StatusBadRequest)
// redisClients holds the Redis clients created in init (three of them). Most
// callers pick one with getRedisClient; some operations loop over all of them.
redisClients []*redis.Client
)
// getRedisClient picks the Redis client responsible for key.
//
// Keys are bucketed modulo 4 and bucket 3 is folded onto client 2. With the
// three clients created in init, the mapping is 0→0, 1→1, 2→2, 3→2. Existing
// Redis data is placed according to this mapping, so changing it would
// orphan keys.
//
// A negative key would yield a negative remainder in Go. Negative keys are
// therefore folded onto the bucket of their absolute remainder, so the index
// always stays within redisClients. Callers can pass a negative key, for
// example a client-supplied channel_id.
func getRedisClient(key int64) *redis.Client {
	k := key % 4
	if k < 0 {
		k = -k
	}
	if k == 3 {
		k = 2
	}
	return redisClients[k]
}
// iconKey derives the shard key for an icon from its file name. The key is
// the sum of the name's runes, reduced modulo 4 after every addition, so the
// result is always in [0, 3].
func iconKey(digest string) int64 {
	var sum int64
	for _, r := range digest {
		sum += int64(r)
		sum %= 4
	}
	return sum
}
// Renderer lets echo render views from a parsed html/template set; main
// installs it as e.Renderer.
type Renderer struct {
	templates *template.Template
}

// Render executes the template called name with data and writes the result
// to w. The echo context is unused.
func (r *Renderer) Render(w io.Writer, name string, data interface{}, c echo.Context) error {
	set := r.templates
	return set.ExecuteTemplate(w, name, data)
}
// init performs the following startup steps:
//   - It seeds math/rand from crypto/rand.
//   - It opens the MySQL handle and creates the three Redis clients.
//   - It blocks, retrying every 3 seconds, until every backend answers a ping.
//
// Settings are read from the environment:
//   - ISUBATA_DB_HOST (default 127.0.0.1)
//   - ISUBATA_DB_PORT (default 3306)
//   - ISUBATA_DB_USER (default root)
//   - ISUBATA_DB_PASSWORD (default empty)
//   - ISUBATA_REDIS_ADDR0 .. ISUBATA_REDIS_ADDR2 (default localhost:6379)
func init() {
	seedBuf := make([]byte, 8)
	if _, err := crand.Read(seedBuf); err != nil {
		log.Printf("crypto/rand unavailable, math/rand seed is not random: %v", err)
	}
	rand.Seed(int64(binary.LittleEndian.Uint64(seedBuf)))

	dbHost := getenvDefault("ISUBATA_DB_HOST", "127.0.0.1")
	dbPort := getenvDefault("ISUBATA_DB_PORT", "3306")
	dbUser := getenvDefault("ISUBATA_DB_USER", "root")

	const dsnFormat = "%s%s@tcp(%s:%s)/isubata?parseTime=true&loc=Local&charset=utf8mb4"
	passwordPart, maskedPart := "", ""
	if pw := os.Getenv("ISUBATA_DB_PASSWORD"); pw != "" {
		passwordPart, maskedPart = ":"+pw, ":****"
	}
	dsn := fmt.Sprintf(dsnFormat, dbUser, passwordPart, dbHost, dbPort)
	// Log the DSN with the password masked so credentials do not end up in logs.
	log.Printf("Connecting to db: %q", fmt.Sprintf(dsnFormat, dbUser, maskedPart, dbHost, dbPort))

	// sqlx.Open does not dial, so db is non-nil from here on. The Ping loop
	// below is what waits for MySQL to become reachable.
	var err error
	db, err = sqlx.Open("mysql", dsn)
	if err != nil {
		log.Fatalf("opening db: %v", err)
	}

	for i := 0; i < 3; i++ {
		addr := getenvDefault(fmt.Sprintf("ISUBATA_REDIS_ADDR%d", i), "localhost:6379")
		redisClients = append(redisClients, redis.NewClient(&redis.Options{
			Addr:     addr,
			Password: "",
			DB:       0,
		}))
	}

	for {
		if err = db.Ping(); err == nil {
			break
		}
		log.Println(err)
		time.Sleep(3 * time.Second)
	}
	db.SetMaxOpenConns(20)
	db.SetConnMaxLifetime(5 * time.Minute)
	log.Printf("Succeeded to connect db.")

	for i, client := range redisClients {
		for {
			if _, err = client.Ping().Result(); err == nil {
				break
			}
			log.Println(err)
			time.Sleep(3 * time.Second)
		}
		log.Printf("Succeeded to connect Redis %d.", i)
	}
}

// getenvDefault returns the value of the environment variable key, or def
// when it is unset or empty.
func getenvDefault(key, def string) string {
	if v := os.Getenv(key); v != "" {
		return v
	}
	return def
}
// User is a row of the user table (loaded with SELECT * FROM user).
// Fields tagged json:"-" (ID, Salt, Password, CreatedAt) are omitted when a
// User is encoded as JSON, e.g. the "user" object in getMessage responses.
type User struct {
ID int64 `json:"-" db:"id"`
Name string `json:"name" db:"name"`
// Salt is the random string prepended to the password before hashing.
Salt string `json:"-" db:"salt"`
// Password is hex(SHA-1(Salt + password)), not the plaintext password.
Password string `json:"-" db:"password"`
DisplayName string `json:"display_name" db:"display_name"`
// AvatarIcon is the icon file name; the bytes are served by getIcon under /icons/.
AvatarIcon string `json:"avatar_icon" db:"avatar_icon"`
CreatedAt time.Time `json:"-" db:"created_at"`
}
// getUser loads the user row with the given id.
// It returns (nil, nil) when no such user exists and (nil, err) for any other
// database error.
func getUser(userID int64) (*User, error) {
	var u User
	err := db.Get(&u, "SELECT * FROM user WHERE id = ?", userID)
	switch {
	case err == sql.ErrNoRows:
		return nil, nil
	case err != nil:
		return nil, err
	}
	return &u, nil
}
// addMessage inserts a message into channelID on behalf of userID and
// increments the channel's field in the Redis hash "message_count".
// It returns the id of the inserted row.
//
// A failure to bump the Redis counter is logged rather than returned. The
// message row is already committed at that point. The counter is rebuilt from
// MySQL by getInitialize.
func addMessage(channelID, userID int64, content string) (int64, error) {
	res, err := db.Exec(
		"INSERT INTO message (channel_id, user_id, content, created_at) VALUES (?, ?, ?, NOW())",
		channelID, userID, content)
	if err != nil {
		return 0, err
	}
	field := strconv.FormatInt(channelID, 10)
	if err := getRedisClient(channelID).HIncrBy("message_count", field, 1).Err(); err != nil {
		log.Printf("incrementing message_count for channel %d: %v", channelID, err)
	}
	return res.LastInsertId()
}
// Message is a row of the message table (see queryMessages).
type Message struct {
ID int64 `db:"id"`
ChannelID int64 `db:"channel_id"`
UserID int64 `db:"user_id"`
Content string `db:"content"`
CreatedAt time.Time `db:"created_at"`
}
// MessageWithUser is one row of the message/user join used by getMessage and
// getHistory. It holds the message's id, content and created_at plus the
// author's name, display_name and avatar_icon.
type MessageWithUser struct {
ID int64 `db:"id"`
Content string `db:"content"`
CreatedAt time.Time `db:"created_at"`
UserName string `db:"name"`
UserDisplayName string `db:"display_name"`
UserAvatarIcon string `db:"avatar_icon"`
}
// queryMessages returns up to 100 messages of channel chanID whose id is
// greater than lastID, newest first.
func queryMessages(chanID, lastID int64) ([]Message, error) {
	const q = "SELECT * FROM message WHERE id > ? AND channel_id = ? ORDER BY id DESC LIMIT 100"
	result := []Message{}
	if err := db.Select(&result, q, lastID, chanID); err != nil {
		return result, err
	}
	return result, nil
}
// sessUserID returns the user id stored under "user_id" in the "session"
// cookie. It returns 0 when the value is absent or is not an int64.
func sessUserID(c echo.Context) int64 {
	sess, _ := session.Get("session", c)
	id, _ := sess.Values["user_id"].(int64)
	return id
}
// sessSetUserID stores id under "user_id" in the "session" cookie and writes
// the cookie to the response. The cookie is HttpOnly with a MaxAge of 360000
// seconds. Errors from the session store are ignored.
func sessSetUserID(c echo.Context, id int64) {
	sess, _ := session.Get("session", c)
	opts := sessions.Options{HttpOnly: true, MaxAge: 360000}
	sess.Options = &opts
	sess.Values["user_id"] = id
	sess.Save(c.Request(), c.Response())
}
// ensureLogin returns the logged-in user.
//
// If the session has no user_id, or the id matches no user, it writes a
// 303 redirect to /login and returns (nil, nil). In the second case the stale
// user_id is first removed from the session. A database error is returned as
// (nil, err) without redirecting. Callers treat a nil user as "response
// already written or failed" and simply return err.
func ensureLogin(c echo.Context) (*User, error) {
	toLogin := func() (*User, error) {
		c.Redirect(http.StatusSeeOther, "/login")
		return nil, nil
	}

	userID := sessUserID(c)
	if userID == 0 {
		return toLogin()
	}
	user, err := getUser(userID)
	if err != nil {
		return nil, err
	}
	if user != nil {
		return user, nil
	}

	sess, _ := session.Get("session", c)
	delete(sess.Values, "user_id")
	sess.Save(c.Request(), c.Response())
	return toLogin()
}
// LettersAndDigits is the alphabet randomString draws from (used for password salts).
const LettersAndDigits = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"

// randomString returns n characters drawn uniformly from LettersAndDigits
// using the package-level math/rand source, which is not cryptographically
// secure. It panics if n is negative.
func randomString(n int) string {
	out := make([]byte, n)
	for i := range out {
		out[i] = LettersAndDigits[rand.Intn(len(LettersAndDigits))]
	}
	return string(out)
}
// register creates a user and returns the new row id.
//
// The stored password is the hex SHA-1 of a fresh 20-character salt followed
// by the password. display_name starts out equal to name, and avatar_icon
// starts as "default.png". Driver errors are returned unchanged; postRegister
// inspects them for MySQL duplicate-entry 1062.
//
// NOTE(review): salted SHA-1 is weak for password storage. It is kept because
// postLogin verifies the same scheme.
func register(name, password string) (int64, error) {
	salt := randomString(20)
	sum := sha1.Sum([]byte(salt + password))
	digest := fmt.Sprintf("%x", sum)
	const stmt = "INSERT INTO user (name, salt, password, display_name, avatar_icon, created_at) VALUES (?, ?, ?, ?, ?, NOW())"
	res, err := db.Exec(stmt, name, salt, digest, name, "default.png")
	if err != nil {
		return 0, err
	}
	return res.LastInsertId()
}
// request handlers
// getInitialize resets the application to its seed data and responds 204.
//
// It performs three steps:
//   - It deletes MySQL rows above the seed ids.
//   - On every Redis client it clears the haveread/* keys and the
//     message_count hash.
//   - It rebuilds message_count from the message table.
//
// Any failure is returned to echo as an error rather than panicking.
func getInitialize(c echo.Context) error {
	resets := []string{
		"DELETE FROM user WHERE id > 1000",
		"DELETE FROM image WHERE id > 1001",
		"DELETE FROM channel WHERE id > 10",
		"DELETE FROM message WHERE id > 10000",
		"DELETE FROM haveread",
	}
	for _, q := range resets {
		if _, err := db.Exec(q); err != nil {
			return err
		}
	}

	for _, client := range redisClients {
		keys, err := client.Keys("haveread/*").Result()
		if err != nil {
			return err
		}
		// DEL with no keys is a Redis arity error, so skip it when nothing matched.
		if len(keys) > 0 {
			if err := client.Del(keys...).Err(); err != nil {
				return err
			}
		}
		if err := client.Del("message_count").Err(); err != nil {
			return err
		}
	}

	counts := []struct {
		Count     int   `db:"cnt"`
		ChannelID int64 `db:"channel_id"`
	}{}
	if err := db.Select(&counts, "SELECT COUNT(*) as cnt, channel_id FROM message group by channel_id"); err != nil {
		return err
	}
	for _, row := range counts {
		field := strconv.FormatInt(row.ChannelID, 10)
		if err := getRedisClient(row.ChannelID).HSet("message_count", field, row.Count).Err(); err != nil {
			return err
		}
	}
	return c.String(204, "")
}
// getInitializeRedis reloads the icon cache and responds 204.
//
// It flushes the current database of every Redis client. It then copies every
// row of the MySQL image table into Redis under "icons/<name>", on the client
// chosen by iconKey(name).
//
// Note that FlushDB also removes the message_count hash and the haveread/*
// keys.
func getInitializeRedis(c echo.Context) error {
	type Image struct {
		Name string `db:"name"`
		Data []byte `db:"data"`
	}
	imgs := []Image{}
	log.Println("Start loading images from MySQL")
	if err := db.Select(&imgs, "select name, data from image"); err != nil {
		return err
	}
	log.Printf("Loaded %d images", len(imgs))
	for _, client := range redisClients {
		if err := client.FlushDB().Err(); err != nil {
			return err
		}
	}
	for _, img := range imgs {
		key := "icons/" + img.Name
		if err := getRedisClient(iconKey(img.Name)).Set(key, img.Data, 0).Err(); err != nil {
			return err
		}
	}
	// The original format string had no verb for len(imgs), which garbled the log line.
	log.Printf("Saved %d images into Redis", len(imgs))
	return c.String(204, "")
}
// getIndex redirects a logged-in visitor to /channel/1. Everyone else gets
// the "index" template with a nil ChannelID.
func getIndex(c echo.Context) error {
	if sessUserID(c) == 0 {
		return c.Render(http.StatusOK, "index", map[string]interface{}{
			"ChannelID": nil,
		})
	}
	return c.Redirect(http.StatusSeeOther, "/channel/1")
}
// ChannelInfo is a row of the channel table. Handlers load it with
// "SELECT * FROM channel ORDER BY id" and pass it to templates as Channels.
type ChannelInfo struct {
ID int64 `db:"id"`
Name string `db:"name"`
Description string `db:"description"`
UpdatedAt time.Time `db:"updated_at"`
CreatedAt time.Time `db:"created_at"`
}
// getChannel renders the "channel" template for channel :channel_id. The page
// receives the full channel list and the selected channel's description,
// which is empty when the id matches no channel.
//
// Unauthenticated visitors are redirected to /login by ensureLogin. A
// non-numeric channel_id is rejected with 400, as in getHistory, instead of
// surfacing the strconv error as a 500.
func getChannel(c echo.Context) error {
	user, err := ensureLogin(c)
	if user == nil {
		return err
	}
	cID, err := strconv.Atoi(c.Param("channel_id"))
	if err != nil {
		return ErrBadReqeust
	}
	channels := []ChannelInfo{}
	if err := db.Select(&channels, "SELECT * FROM channel ORDER BY id"); err != nil {
		return err
	}
	var desc string
	for _, ch := range channels {
		if ch.ID == int64(cID) {
			desc = ch.Description
			break
		}
	}
	return c.Render(http.StatusOK, "channel", map[string]interface{}{
		"ChannelID":   cID,
		"Channels":    channels,
		"User":        user,
		"Description": desc,
	})
}
// getRegister renders the "register" form with no user and an empty channel list.
func getRegister(c echo.Context) error {
	data := map[string]interface{}{
		"User":      nil,
		"Channels":  []ChannelInfo{},
		"ChannelID": 0,
	}
	return c.Render(http.StatusOK, "register", data)
}
// postRegister creates an account from the "name" and "password" form fields.
// On success it logs the new user in and redirects to "/".
//
// Missing fields yield 400. MySQL duplicate-entry error 1062 yields 409.
func postRegister(c echo.Context) error {
	name, pw := c.FormValue("name"), c.FormValue("password")
	if name == "" || pw == "" {
		return ErrBadReqeust
	}
	userID, err := register(name, pw)
	if merr, ok := err.(*mysql.MySQLError); ok && merr.Number == 1062 {
		// Duplicate entry xxxx for key zzzz.
		return c.NoContent(http.StatusConflict)
	}
	if err != nil {
		return err
	}
	sessSetUserID(c, userID)
	return c.Redirect(http.StatusSeeOther, "/")
}
// getLogin renders the "login" form with no user and an empty channel list.
func getLogin(c echo.Context) error {
	noChannels := []ChannelInfo{}
	return c.Render(http.StatusOK, "login", map[string]interface{}{
		"Channels":  noChannels,
		"ChannelID": 0,
		"User":      nil,
	})
}
// postLogin checks the "name" and "password" form fields against the user
// table. On success it stores the user id in the session and redirects to "/".
//
// Missing fields give 400. An unknown name or a wrong password gives 403.
func postLogin(c echo.Context) error {
	name := c.FormValue("name")
	pw := c.FormValue("password")
	if name == "" || pw == "" {
		return ErrBadReqeust
	}

	var user User
	switch err := db.Get(&user, "SELECT * FROM user WHERE name = ?", name); {
	case err == sql.ErrNoRows:
		return echo.ErrForbidden
	case err != nil:
		return err
	}

	// Stored passwords are hex(SHA-1(salt + password)), as written by register.
	expected := user.Password
	if fmt.Sprintf("%x", sha1.Sum([]byte(user.Salt+pw))) != expected {
		return echo.ErrForbidden
	}
	sessSetUserID(c, user.ID)
	return c.Redirect(http.StatusSeeOther, "/")
}
// getLogout removes user_id from the session, saves the cookie and redirects
// to "/". Session-store errors are ignored.
func getLogout(c echo.Context) error {
	req, res := c.Request(), c.Response()
	sess, _ := session.Get("session", c)
	delete(sess.Values, "user_id")
	sess.Save(req, res)
	return c.Redirect(http.StatusSeeOther, "/")
}
// postMessage posts the "message" form field, as the session user, to the
// channel named by the "channel_id" form field, and responds 204.
//
// Without a session it redirects to /login. An empty message or a
// non-numeric channel_id gives 403.
func postMessage(c echo.Context) error {
	userID := sessUserID(c)
	if userID == 0 {
		c.Redirect(http.StatusSeeOther, "/login")
		return nil
	}
	message := c.FormValue("message")
	if message == "" {
		return echo.ErrForbidden
	}
	id, err := strconv.Atoi(c.FormValue("channel_id"))
	if err != nil {
		return echo.ErrForbidden
	}
	if _, err := addMessage(int64(id), userID, message); err != nil {
		return err
	}
	return c.NoContent(204)
}
// getMessage returns, as JSON, up to 100 messages of channel_id that are newer
// than last_message_id, in ascending id order. Each message carries its
// author's name, display_name and avatar_icon.
//
// It also records the newest returned id as the user's read position, in
// Redis key haveread/<user>/<channel>. It responds 403 without a session;
// unparsable query parameters are returned as errors.
func getMessage(c echo.Context) error {
	userID := sessUserID(c)
	if userID == 0 {
		return c.NoContent(http.StatusForbidden)
	}
	chanID, err := strconv.ParseInt(c.QueryParam("channel_id"), 10, 64)
	if err != nil {
		return err
	}
	lastID, err := strconv.ParseInt(c.QueryParam("last_message_id"), 10, 64)
	if err != nil {
		return err
	}

	const q = "select m.id, m.content, m.created_at, u.name, u.display_name, u.avatar_icon from message m inner join user u on u.id = m.user_id where m.id > ? and m.channel_id = ? order by m.id desc limit 100"
	rows := []MessageWithUser{}
	if err := db.Select(&rows, q, lastID, chanID); err != nil {
		return err
	}

	// rows arrive newest-first; the response lists them oldest-first.
	response := make([]map[string]interface{}, 0, len(rows))
	for i := range rows {
		m := rows[len(rows)-1-i]
		response = append(response, map[string]interface{}{
			"id": m.ID,
			"user": User{
				Name:        m.UserName,
				DisplayName: m.UserDisplayName,
				AvatarIcon:  m.UserAvatarIcon,
			},
			"date":    m.CreatedAt.Format("2006/01/02 15:04:05"),
			"content": m.Content,
		})
	}

	if len(rows) > 0 {
		key := fmt.Sprintf("haveread/%d/%d", userID, chanID)
		if err := getRedisClient(userID+chanID).Set(key, rows[0].ID, 0).Err(); err != nil {
			return err
		}
	}
	return streamJSON(c, response)
}
// streamJSON writes data as a 200 application/json (UTF-8) response, encoding
// it directly into the response writer.
func streamJSON(c echo.Context, data interface{}) error {
	res := c.Response()
	res.Header().Set(echo.HeaderContentType, echo.MIMEApplicationJSONCharsetUTF8)
	res.WriteHeader(http.StatusOK)
	enc := json.NewEncoder(res)
	return enc.Encode(data)
}
// queryChannels returns the ids of all channels, in whatever order MySQL
// yields them; the query has no ORDER BY.
func queryChannels() ([]int64, error) {
	ids := []int64{}
	if err := db.Select(&ids, "SELECT id FROM channel"); err != nil {
		return ids, err
	}
	return ids, nil
}
// queryHaveReads returns, for each channel in chIDs, the last message id the
// user has read, or 0 if none is recorded. The value comes from Redis key
// haveread/<userID>/<chID>, and the result is aligned index-for-index with
// chIDs.
//
// getMessage writes each key to a single client. The same MGET is therefore
// sent to every client, and the non-nil answers are merged.
func queryHaveReads(userID int64, chIDs []int64) ([]int64, error) {
	ids := make([]int64, len(chIDs))
	// MGET with no keys is a Redis arity error; with no channels there is nothing to read.
	if len(chIDs) == 0 {
		return ids, nil
	}

	keys := make([]string, len(chIDs))
	for i, chID := range chIDs {
		keys[i] = fmt.Sprintf("haveread/%d/%d", userID, chID)
	}

	merged := make([]string, len(chIDs))
	for _, client := range redisClients {
		results, err := client.MGet(keys...).Result()
		if err != nil {
			return []int64{}, err
		}
		for j, r := range results {
			if s, ok := r.(string); ok {
				merged[j] = s
			}
		}
	}

	for i, s := range merged {
		if s == "" {
			continue // nothing recorded: ids[i] stays 0
		}
		id, err := strconv.ParseInt(s, 10, 64)
		if err != nil {
			return ids, err
		}
		ids[i] = id
	}
	return ids, nil
}
// fetchUnread returns, as JSON, the unread count of every channel in
// queryChannels order, as [{"channel_id": id, "unread": n}, ...].
//
// The count depends on whether the user has a read position for the channel
// (see queryHaveReads):
//   - With a read position, it is the number of messages whose id is greater
//     than that position, counted in MySQL.
//   - Without one, it is the channel total from the Redis message_count hash.
//
// Responds 403 without a session.
func fetchUnread(c echo.Context) error {
	userID := sessUserID(c)
	if userID == 0 {
		return c.NoContent(http.StatusForbidden)
	}
	// Deliberate one-second delay before answering; presumably this throttles
	// client polling. TODO: confirm before removing.
	time.Sleep(time.Second)

	channels, err := queryChannels()
	if err != nil {
		return err
	}
	lastIDs, err := queryHaveReads(userID, channels)
	if err != nil {
		return err
	}
	totals, err := totalMessageCounts(channels, lastIDs)
	if err != nil {
		return err
	}

	resp := make([]map[string]interface{}, 0, len(channels))
	for i, chID := range channels {
		cnt := totals[chID] // 0 for channels that have a read position
		if lastID := lastIDs[i]; lastID > 0 {
			if err := db.Get(&cnt,
				"SELECT COUNT(*) as cnt FROM message WHERE channel_id = ? AND ? < id",
				chID, lastID); err != nil {
				return err
			}
		}
		resp = append(resp, map[string]interface{}{
			"channel_id": chID,
			"unread":     cnt,
		})
	}
	return streamJSON(c, resp)
}

// totalMessageCounts returns the Redis message_count value for every channel
// whose entry in lastIDs is 0. A missing counter reads as 0. The counters are
// sharded across clients, so the same HMGET is sent to each client and the
// non-nil answers are merged.
func totalMessageCounts(channels, lastIDs []int64) (map[int64]int64, error) {
	var ids []int64
	var fields []string
	for i, chID := range channels {
		if lastIDs[i] == 0 {
			ids = append(ids, chID)
			fields = append(fields, strconv.FormatInt(chID, 10))
		}
	}
	counts := make(map[int64]int64, len(ids))
	// HMGET with no fields is a Redis arity error. This happens whenever the
	// user has a read position for every channel.
	if len(fields) == 0 {
		return counts, nil
	}

	merged := make([]string, len(fields))
	for _, client := range redisClients {
		results, err := client.HMGet("message_count", fields...).Result()
		if err != nil {
			return nil, err
		}
		for j, r := range results {
			if s, ok := r.(string); ok {
				merged[j] = s
			}
		}
	}

	for i, s := range merged {
		if s == "" {
			counts[ids[i]] = 0
			continue
		}
		n, err := strconv.ParseInt(s, 10, 64)
		if err != nil {
			return nil, err
		}
		counts[ids[i]] = n
	}
	return counts, nil
}
// getHistory renders the "history" template for channel :channel_id. It shows
// page ?page= (default 1) of 20 messages, listed oldest-first within the page,
// where page 1 holds the newest messages.
//
// The page count comes from the Redis message_count hash; a missing counter
// means zero messages. Errors are handled as follows:
//   - An invalid channel id or page, or a page beyond the last, gives 400.
//   - Unauthenticated visitors are redirected by ensureLogin.
//   - Redis errors other than a missing key are returned, instead of being
//     silently treated as an empty channel.
func getHistory(c echo.Context) error {
	chID, err := strconv.ParseInt(c.Param("channel_id"), 10, 64)
	if err != nil || chID <= 0 {
		return ErrBadReqeust
	}
	user, err := ensureLogin(c)
	if user == nil {
		return err
	}

	page := int64(1)
	if pageStr := c.QueryParam("page"); pageStr != "" {
		page, err = strconv.ParseInt(pageStr, 10, 64)
		if err != nil || page < 1 {
			return ErrBadReqeust
		}
	}

	const N = 20
	cnt, err := getRedisClient(chID).HGet("message_count", strconv.FormatInt(chID, 10)).Int64()
	switch {
	case err == redis.Nil:
		// No counter yet (e.g. a newly created channel): treat as empty.
		cnt = 0
	case err != nil:
		return err
	}
	maxPage := (cnt + N - 1) / N
	if maxPage == 0 {
		maxPage = 1
	}
	if page > maxPage {
		return ErrBadReqeust
	}

	messages := []MessageWithUser{}
	err = db.Select(&messages,
		"select m.id, m.content, m.created_at, u.name, u.display_name, u.avatar_icon from message m inner join user u on u.id = m.user_id where m.channel_id = ? order by m.id desc limit ? offset ?",
		chID, N, (page-1)*N)
	if err != nil {
		return err
	}

	// messages arrive newest-first; the page lists them oldest-first.
	mjson := make([]map[string]interface{}, 0, len(messages))
	for i := len(messages) - 1; i >= 0; i-- {
		m := messages[i]
		mjson = append(mjson, map[string]interface{}{
			"id": m.ID,
			"user": User{
				Name:        m.UserName,
				DisplayName: m.UserDisplayName,
				AvatarIcon:  m.UserAvatarIcon,
			},
			"date":    m.CreatedAt.Format("2006/01/02 15:04:05"),
			"content": m.Content,
		})
	}

	channels := []ChannelInfo{}
	if err := db.Select(&channels, "SELECT * FROM channel ORDER BY id"); err != nil {
		return err
	}
	return c.Render(http.StatusOK, "history", map[string]interface{}{
		"ChannelID": chID,
		"Channels":  channels,
		"Messages":  mjson,
		"MaxPage":   maxPage,
		"Page":      page,
		"User":      user,
	})
}
// getProfile renders the "profile" template for the user named :user_name.
// The page gets the channel list and a SelfProfile flag saying whether that
// user is the one logged in.
//
// Unknown names give 404. Unauthenticated visitors are redirected to /login
// by ensureLogin.
func getProfile(c echo.Context) error {
	self, err := ensureLogin(c)
	if self == nil {
		return err
	}
	channels := []ChannelInfo{}
	if err := db.Select(&channels, "SELECT * FROM channel ORDER BY id"); err != nil {
		return err
	}

	var other User
	switch err := db.Get(&other, "SELECT * FROM user WHERE name = ?", c.Param("user_name")); err {
	case nil:
	case sql.ErrNoRows:
		return echo.ErrNotFound
	default:
		return err
	}

	view := map[string]interface{}{
		"ChannelID":   0,
		"Channels":    channels,
		"User":        self,
		"Other":       other,
		"SelfProfile": self.ID == other.ID,
	}
	return c.Render(http.StatusOK, "profile", view)
}
// getAddChannel renders the "add_channel" form, with the channel list, for a
// logged-in user. Others are redirected to /login by ensureLogin.
func getAddChannel(c echo.Context) error {
	self, err := ensureLogin(c)
	if self == nil {
		return err
	}
	channels := []ChannelInfo{}
	if err := db.Select(&channels, "SELECT * FROM channel ORDER BY id"); err != nil {
		return err
	}
	view := map[string]interface{}{
		"ChannelID": 0,
		"Channels":  channels,
		"User":      self,
	}
	return c.Render(http.StatusOK, "add_channel", view)
}
// postAddChannel creates a channel from the "name" and "description" form
// fields and redirects to the new channel's page.
//
// Without a session it redirects to /login. A missing field gives 400. If the
// new row id cannot be read, the error is returned instead of redirecting to
// /channel/0.
func postAddChannel(c echo.Context) error {
	if sessUserID(c) == 0 {
		c.Redirect(http.StatusSeeOther, "/login")
		return nil
	}
	name := c.FormValue("name")
	desc := c.FormValue("description")
	if name == "" || desc == "" {
		return ErrBadReqeust
	}
	res, err := db.Exec(
		"INSERT INTO channel (name, description, updated_at, created_at) VALUES (?, ?, NOW(), NOW())",
		name, desc)
	if err != nil {
		return err
	}
	lastID, err := res.LastInsertId()
	if err != nil {
		return err
	}
	return c.Redirect(http.StatusSeeOther, fmt.Sprintf("/channel/%v", lastID))
}
// postProfile updates the session user's avatar and/or display name, then
// redirects to "/". Without a session it redirects to /login.
//
// An optional "avatar_icon" upload is handled as follows:
//   - It must have a .jpg, .jpeg, .png or .gif extension and be at most
//     avatarMaxBytes long; otherwise the request gets 400.
//   - It is stored in Redis as "icons/<sha1-of-content><ext>", and the user's
//     avatar_icon is set to that name.
//
// A non-empty "display_name" field replaces the display name.
func postProfile(c echo.Context) error {
	userID := sessUserID(c)
	if userID == 0 {
		c.Redirect(http.StatusSeeOther, "/login")
		return nil
	}

	avatarName := ""
	var avatarData []byte
	fh, err := c.FormFile("avatar_icon")
	switch {
	case err == http.ErrMissingFile:
		// No avatar uploaded; only display_name may change.
	case err != nil:
		return err
	default:
		dotPos := strings.LastIndexByte(fh.Filename, '.')
		if dotPos < 0 {
			return ErrBadReqeust
		}
		ext := fh.Filename[dotPos:]
		switch ext {
		case ".jpg", ".jpeg", ".png", ".gif":
		default:
			return ErrBadReqeust
		}
		file, err := fh.Open()
		if err != nil {
			return err
		}
		// Read at most one byte past the limit. That is enough to detect an
		// oversized upload without buffering the whole file in memory.
		avatarData, err = ioutil.ReadAll(io.LimitReader(file, avatarMaxBytes+1))
		file.Close()
		if err != nil {
			return err
		}
		if len(avatarData) > avatarMaxBytes {
			return ErrBadReqeust
		}
		avatarName = fmt.Sprintf("%x%s", sha1.Sum(avatarData), ext)
	}

	if avatarName != "" && len(avatarData) > 0 {
		key := "icons/" + avatarName
		if err := getRedisClient(iconKey(avatarName)).Set(key, avatarData, 0).Err(); err != nil {
			return err
		}
		if _, err := db.Exec("UPDATE user SET avatar_icon = ? WHERE id = ?", avatarName, userID); err != nil {
			return err
		}
	}

	if name := c.FormValue("display_name"); name != "" {
		if _, err := db.Exec("UPDATE user SET display_name = ? WHERE id = ?", name, userID); err != nil {
			return err
		}
	}
	return c.Redirect(http.StatusSeeOther, "/")
}
// getIcon serves the avatar image :file_name from Redis key "icons/<name>".
//
// Uploaded icon names embed the SHA-1 of their content (see postProfile), so:
//   - Responses are marked cacheable for a year.
//   - Any conditional request (If-Modified-Since / If-None-Match) is answered
//     304 without comparing the validator.
//
// Names without a .jpg/.jpeg/.png/.gif extension, and icons missing from
// Redis, give 404. The ETag is the name without its extension, quoted as
// RFC 7232 requires. Names shorter than four bytes are handled without
// panicking.
func getIcon(c echo.Context) error {
	name := c.Param("file_name")
	dot := strings.LastIndexByte(name, '.')
	if dot < 0 {
		return echo.ErrNotFound
	}
	base, ext := name[:dot], name[dot:]

	var mime string
	switch ext {
	case ".jpg", ".jpeg":
		mime = "image/jpeg"
	case ".png":
		mime = "image/png"
	case ".gif":
		mime = "image/gif"
	default:
		return echo.ErrNotFound
	}

	h := c.Response().Header()
	h.Set("Cache-Control", "public, max-age=31536000")
	h.Set("ETag", `"`+base+`"`)
	h.Set("Last-Modified", "Mon, 16 Oct 2017 16:33:02 GMT")
	reqHeader := c.Request().Header
	if reqHeader.Get("If-Modified-Since") != "" || reqHeader.Get("If-None-Match") != "" {
		return c.NoContent(304)
	}

	data, err := getRedisClient(iconKey(name)).Get("icons/" + name).Bytes()
	if err == redis.Nil {
		return echo.ErrNotFound
	}
	if err != nil {
		return err
	}
	return c.Blob(http.StatusOK, mime, data)
}
// tAdd backs the "add" template function: it returns a + b.
func tAdd(a, b int64) int64 {
	sum := a
	sum += b
	return sum
}
// tRange backs the "xrange" template function. It returns the inclusive
// sequence a, a+1, ..., b. When b < a it returns an empty slice; the previous
// code panicked on the negative make length, which would abort template
// execution.
func tRange(a, b int64) []int64 {
	if b < a {
		return []int64{}
	}
	r := make([]int64, b-a+1)
	for i := range r {
		r[i] = a + int64(i)
	}
	return r
}
// main wires up and starts the echo server. It sets up:
//   - the html/template renderer, with the "add" and "xrange" helpers;
//   - cookie-backed sessions;
//   - static files served from ../public;
//   - every route.
//
// It listens on the unix socket named by ISUBATA_UNIX_SOCKET, or on :5000
// when that variable is unset. If the server fails to start or stops with an
// error, the error is logged fatally instead of being dropped.
func main() {
	e := echo.New()
	funcs := template.FuncMap{
		"add":    tAdd,
		"xrange": tRange,
	}
	e.Renderer = &Renderer{
		templates: template.Must(template.New("").Funcs(funcs).ParseGlob("views/*.html")),
	}
	// NOTE(review): the cookie-store key is hard-coded, so anyone who knows it
	// can forge sessions. Consider loading it from configuration.
	e.Use(session.Middleware(sessions.NewCookieStore([]byte("secretonymoris"))))
	e.Use(middleware.Static("../public"))

	e.GET("/initialize", getInitialize)
	e.GET("/initialize_redis", getInitializeRedis)
	e.GET("/", getIndex)
	e.GET("/register", getRegister)
	e.POST("/register", postRegister)
	e.GET("/login", getLogin)
	e.POST("/login", postLogin)
	e.GET("/logout", getLogout)
	e.GET("/channel/:channel_id", getChannel)
	e.GET("/message", getMessage)
	e.POST("/message", postMessage)
	e.GET("/fetch", fetchUnread)
	e.GET("/history/:channel_id", getHistory)
	e.GET("/profile/:user_name", getProfile)
	e.POST("/profile", postProfile)
	// Leading slash added for consistency with every other route, rather than
	// relying on the router to normalize it.
	e.GET("/add_channel", getAddChannel)
	e.POST("/add_channel", postAddChannel)
	e.GET("/icons/:file_name", getIcon)

	if path := os.Getenv("ISUBATA_UNIX_SOCKET"); path != "" {
		// Remove a socket file left by a previous run; a missing file is fine.
		os.Remove(path)
		l, err := net.Listen("unix", path)
		if err != nil {
			e.Logger.Fatal(err)
		}
		// World-accessible, presumably so a front proxy running as another
		// user can connect. TODO: confirm.
		if err := os.Chmod(path, 0777); err != nil {
			e.Logger.Fatal(err)
		}
		e.Listener = l
		e.Logger.Fatal(e.Start(""))
	} else {
		e.Logger.Fatal(e.Start(":5000"))
	}
}