Skip to content

Commit

Permalink
Merge pull request #133 from qianbin/ws-cors
Browse files Browse the repository at this point in the history
feat(api): websocket cors
  • Loading branch information
qianbin authored Aug 29, 2018
2 parents 3ac449a + a5e1fad commit 8817785
Show file tree
Hide file tree
Showing 4 changed files with 40 additions and 21 deletions.
16 changes: 13 additions & 3 deletions api/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,10 @@ package api

import (
"net/http"
"strings"

assetfs "github.com/elazarl/go-bindata-assetfs"
"github.com/gorilla/handlers"
"github.com/gorilla/mux"
"github.com/vechain/thor/api/accounts"
"github.com/vechain/thor/api/blocks"
Expand All @@ -25,7 +27,12 @@ import (
)

//New return api router
func New(chain *chain.Chain, stateCreator *state.Creator, txPool *txpool.TxPool, logDB *logdb.LogDB, nw node.Network) (http.HandlerFunc, func()) {
func New(chain *chain.Chain, stateCreator *state.Creator, txPool *txpool.TxPool, logDB *logdb.LogDB, nw node.Network, allowedOrigins string) (http.HandlerFunc, func()) {
origins := strings.Split(strings.TrimSpace(allowedOrigins), ",")
for i, o := range origins {
origins[i] = strings.ToLower(strings.TrimSpace(o))
}

router := mux.NewRouter()

// to serve api doc and swagger-ui
Expand Down Expand Up @@ -58,8 +65,11 @@ func New(chain *chain.Chain, stateCreator *state.Creator, txPool *txpool.TxPool,
Mount(router, "/transactions")
node.New(nw).
Mount(router, "/node")
subs := subscriptions.New(chain)
subs := subscriptions.New(chain, origins)
subs.Mount(router, "/subscriptions")

return router.ServeHTTP, subs.Close // subscriptions handles hijacked conns, which need to be closed
return handlers.CORS(
handlers.AllowedOrigins(origins),
handlers.AllowedHeaders([]string{"content-type"}))(router).ServeHTTP,
subs.Close // subscriptions handles hijacked conns, which need to be closed
}
33 changes: 25 additions & 8 deletions api/subscriptions/subscriptions.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,22 +19,39 @@ import (
)

type Subscriptions struct {
chain *chain.Chain
done chan struct{}
wg sync.WaitGroup
chain *chain.Chain
upgrader *websocket.Upgrader
done chan struct{}
wg sync.WaitGroup
}

type msgReader interface {
Read() (msgs []interface{}, hasMore bool, err error)
}

var (
upgrader = websocket.Upgrader{}
log = log15.New("pkg", "subscriptions")
log = log15.New("pkg", "subscriptions")
)

func New(chain *chain.Chain) *Subscriptions {
return &Subscriptions{chain: chain, done: make(chan struct{})}
func New(chain *chain.Chain, allowedOrigins []string) *Subscriptions {
return &Subscriptions{
chain: chain,
upgrader: &websocket.Upgrader{
CheckOrigin: func(r *http.Request) bool {
origin := r.Header.Get("Origin")
if origin == "" {
return true
}
for _, allowedOrigin := range allowedOrigins {
if allowedOrigin == origin || allowedOrigin == "*" {
return true
}
}
return false
},
},
done: make(chan struct{}),
}
}

func (s *Subscriptions) handleBlockReader(w http.ResponseWriter, req *http.Request) (*blockReader, error) {
Expand Down Expand Up @@ -138,7 +155,7 @@ func (s *Subscriptions) handleSubject(w http.ResponseWriter, req *http.Request)
return utils.HTTPError(errors.New("not found"), http.StatusNotFound)
}

conn, err := upgrader.Upgrade(w, req, nil)
conn, err := s.upgrader.Upgrade(w, req, nil)
// since the conn is hijacked here, no error should be returned in lines below
if err != nil {
log.Debug("upgrade to websocket", "err", err)
Expand Down
4 changes: 2 additions & 2 deletions cmd/thor/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ func defaultAction(ctx *cli.Context) error {

p2pcom := newP2PComm(ctx, chain, txPool, instanceDir)

apiHandler, apiCloser := api.New(chain, state.NewCreator(mainDB), txPool, logDB, p2pcom.comm)
apiHandler, apiCloser := api.New(chain, state.NewCreator(mainDB), txPool, logDB, p2pcom.comm, ctx.String(apiCorsFlag.Name))
defer func() { log.Info("closing API..."); apiCloser() }()

apiURL, srvCloser := startAPIServer(ctx, apiHandler, chain.GenesisBlock().Header().ID())
Expand Down Expand Up @@ -178,7 +178,7 @@ func soloAction(ctx *cli.Context) error {
txPool := txpool.New(chain, state.NewCreator(mainDB), defaultTxPoolOptions)
defer func() { log.Info("closing tx pool..."); txPool.Close() }()

apiHandler, apiCloser := api.New(chain, state.NewCreator(mainDB), txPool, logDB, solo.Communicator{})
apiHandler, apiCloser := api.New(chain, state.NewCreator(mainDB), txPool, logDB, solo.Communicator{}, ctx.String(apiCorsFlag.Name))
defer func() { log.Info("closing API..."); apiCloser() }()

apiURL, srvCloser := startAPIServer(ctx, apiHandler, chain.GenesisBlock().Header().ID())
Expand Down
8 changes: 0 additions & 8 deletions cmd/thor/must.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ import (
"net/http"
"os"
"path/filepath"
"strings"
"time"

"github.com/ethereum/go-ethereum/common"
Expand All @@ -22,7 +21,6 @@ import (
ethlog "github.com/ethereum/go-ethereum/log"
"github.com/ethereum/go-ethereum/p2p/nat"
"github.com/ethereum/go-ethereum/rlp"
"github.com/gorilla/handlers"
"github.com/inconshreveable/log15"
"github.com/vechain/thor/chain"
"github.com/vechain/thor/cmd/thor/node"
Expand Down Expand Up @@ -265,12 +263,6 @@ func startAPIServer(ctx *cli.Context, handler http.Handler, genesisID thor.Bytes
fatal(fmt.Sprintf("listen API addr [%v]: %v", addr, err))
}

if origins := ctx.String(apiCorsFlag.Name); origins != "" {
handler = handlers.CORS(
handlers.AllowedOrigins(strings.Split(origins, ",")),
handlers.AllowedHeaders([]string{"content-type"}),
)(handler)
}
handler = handleXGenesisID(handler, genesisID)
handler = requestBodyLimit(handler)
srv := &http.Server{Handler: handler}
Expand Down

0 comments on commit 8817785

Please sign in to comment.