Skip to content

Commit

Permalink
refactored ws
Browse files Browse the repository at this point in the history
  • Loading branch information
jubeless committed Nov 9, 2020
1 parent 81f4290 commit 76687fa
Show file tree
Hide file tree
Showing 6 changed files with 196 additions and 112 deletions.
129 changes: 21 additions & 108 deletions rpc/ws.go → rpc/ws/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,18 +12,18 @@
// See the License for the specific language governing permissions and
// limitations under the License.

package rpc
package ws

import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"math/rand"
"reflect"
"sync"

"github.com/dfuse-io/solana-go/rpc"
"github.com/gorilla/rpc/v2/json2"
"github.com/gorilla/websocket"
"github.com/tidwall/gjson"
Expand All @@ -32,7 +32,7 @@ import (

type result interface{}

type WSClient struct {
type Client struct {
rpcURL string
conn *websocket.Conn
lock sync.RWMutex
Expand All @@ -41,8 +41,8 @@ type WSClient struct {
reconnectOnErr bool
}

func Dial(ctx context.Context, rpcURL string) (c *WSClient, err error) {
c = &WSClient{
func Dial(ctx context.Context, rpcURL string) (c *Client, err error) {
c = &Client{
rpcURL: rpcURL,
subscriptionByRequestID: map[uint64]*Subscription{},
subscriptionByWSSubID: map[uint64]*Subscription{},
Expand All @@ -57,11 +57,11 @@ func Dial(ctx context.Context, rpcURL string) (c *WSClient, err error) {
return c, nil
}

func (c *WSClient) Close() {
func (c *Client) Close() {
c.conn.Close()
}

func (c *WSClient) receiveMessages() {
func (c *Client) receiveMessages() {
for {
_, message, err := c.conn.ReadMessage()
if err != nil {
Expand All @@ -72,7 +72,7 @@ func (c *WSClient) receiveMessages() {
}
}

func (c *WSClient) handleMessage(message []byte) {
func (c *Client) handleMessage(message []byte) {
// when receiving message with id. the result will be a subscription number.
// that number will be associated to all future message destine to this request
if gjson.GetBytes(message, "id").Exists() {
Expand All @@ -86,7 +86,7 @@ func (c *WSClient) handleMessage(message []byte) {

}

func (c *WSClient) handleNewSubscriptionMessage(requestID, subID uint64) {
func (c *Client) handleNewSubscriptionMessage(requestID, subID uint64) {
c.lock.Lock()
defer c.lock.Unlock()

Expand All @@ -100,7 +100,7 @@ func (c *WSClient) handleNewSubscriptionMessage(requestID, subID uint64) {
return
}

func (c *WSClient) handleSubscriptionMessage(subID uint64, message []byte) {
func (c *Client) handleSubscriptionMessage(subID uint64, message []byte) {
zlog.Info("received subscription message",
zap.Uint64("subscription_id", subID),
)
Expand All @@ -116,7 +116,7 @@ func (c *WSClient) handleSubscriptionMessage(subID uint64, message []byte) {
//getting and instantiate the return type for the call back.
resultType := reflect.New(sub.reflectType)
result := resultType.Interface()
err := decodeClientResponse(bytes.NewReader(message), &result)
err := decodeResponse(bytes.NewReader(message), &result)
if err != nil {
c.closeSubscription(sub.req.ID, fmt.Errorf("unable to decode client response: %w", err))
return
Expand All @@ -133,7 +133,7 @@ func (c *WSClient) handleSubscriptionMessage(subID uint64, message []byte) {
return
}

func (c *WSClient) closeAllSubscription(err error) {
func (c *Client) closeAllSubscription(err error) {
c.lock.Lock()
defer c.lock.Unlock()

Expand All @@ -145,7 +145,7 @@ func (c *WSClient) closeAllSubscription(err error) {
c.subscriptionByWSSubID = map[uint64]*Subscription{}
}

func (c *WSClient) closeSubscription(reqID uint64, err error) {
func (c *Client) closeSubscription(reqID uint64, err error) {
c.lock.Lock()
defer c.lock.Unlock()

Expand All @@ -156,7 +156,7 @@ func (c *WSClient) closeSubscription(reqID uint64, err error) {

sub.err <- err

err = c.rpcUnsubscribe(sub.subID, sub.unsubscriptionMethod)
err = c.unsubscribe(sub.subID, sub.unsubscriptionMethod)
if err != nil {
zlog.Warn("unable to send rpc unsubscribe call",
zap.Error(err),
Expand All @@ -167,8 +167,8 @@ func (c *WSClient) closeSubscription(reqID uint64, err error) {
delete(c.subscriptionByWSSubID, sub.subID)
}

func (c *WSClient) rpcUnsubscribe(subID uint64, method string) error {
req := newClientRequest([]interface{}{subID}, method, map[string]interface{}{})
func (c *Client) unsubscribe(subID uint64, method string) error {
req := newRequest([]interface{}{subID}, method, map[string]interface{}{})
data, err := req.encode()
if err != nil {
return fmt.Errorf("unable to encode unsubscription message for subID %d and method %s", subID, method)
Expand All @@ -181,57 +181,15 @@ func (c *WSClient) rpcUnsubscribe(subID uint64, method string) error {
return nil
}

type Subscription struct {
req *clientRequest
subID uint64
stream chan result
err chan error
reflectType reflect.Type
closeFunc func(err error)
unsubscriptionMethod string
}

func newSubscription(req *clientRequest, reflectType reflect.Type, closeFunc func(err error)) *Subscription {
return &Subscription{
req: req,
reflectType: reflectType,
stream: make(chan result, 200),
err: make(chan error, 1),
closeFunc: closeFunc,
}
}

func (s *Subscription) Recv() (interface{}, error) {
select {
case d := <-s.stream:
return d, nil
case err := <-s.err:
return nil, err
}
}

func (s *Subscription) Unsubscribe() {
s.unsubscribe(nil)
}

func (s *Subscription) unsubscribe(err error) {
s.closeFunc(err)

}

func (c *WSClient) ProgramSubscribe(programID string, commitment CommitmentType) (*Subscription, error) {
return c.subscribe([]interface{}{programID}, "programSubscribe", "programUnsubscribe", commitment, ProgramWSResult{})
}

func (c *WSClient) subscribe(params []interface{}, subscriptionMethod, unsubscriptionMethod string, commitment CommitmentType, resultType interface{}) (*Subscription, error) {
func (c *Client) subscribe(params []interface{}, subscriptionMethod, unsubscriptionMethod string, commitment rpc.CommitmentType, resultType interface{}) (*Subscription, error) {
conf := map[string]interface{}{
"encoding": "jsonParsed",
"encoding": "base64",
}
if commitment != "" {
conf["commitment"] = string(commitment)
}

req := newClientRequest(params, subscriptionMethod, conf)
req := newRequest(params, subscriptionMethod, conf)
data, err := req.encode()
if err != nil {
return nil, fmt.Errorf("subscribe: unable to encode subsciption request: %w", err)
Expand All @@ -254,53 +212,8 @@ func (c *WSClient) subscribe(params []interface{}, subscriptionMethod, unsubscri
return sub, nil
}

type ProgramWSResult struct {
Context struct {
Slot uint64
} `json:"context"`
Value struct {
Account Account `json:"account"`
} `json:"value"`
}

type clientRequest struct {
Version string `json:"jsonrpc"`
Method string `json:"method"`
Params interface{} `json:"params"`
ID uint64 `json:"id"`
}

func newClientRequest(params []interface{}, method string, configuration map[string]interface{}) *clientRequest {
params = append(params, configuration)
return &clientRequest{
Version: "2.0",
Method: method,
Params: params,
ID: uint64(rand.Int63()),
}
}

func (c *clientRequest) encode() ([]byte, error) {
data, err := json.Marshal(c)
if err != nil {
return nil, fmt.Errorf("encode request: json marshall: %w", err)
}
return data, nil
}

type wsClientResponse struct {
Version string `json:"jsonrpc"`
Params *wsClientResponseParams `json:"params"`
Error *json.RawMessage `json:"error"`
}

type wsClientResponseParams struct {
Result *json.RawMessage `json:"result"`
Subscription int `json:"subscription"`
}

func decodeClientResponse(r io.Reader, reply interface{}) (err error) {
var c *wsClientResponse
func decodeResponse(r io.Reader, reply interface{}) (err error) {
var c *response
if err := json.NewDecoder(r).Decode(&c); err != nil {
return err
}
Expand Down
32 changes: 28 additions & 4 deletions rpc/ws_test.go → rpc/ws/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,34 +12,58 @@
// See the License for the specific language governing permissions and
// limitations under the License.

package rpc
package ws

import (
"context"
"fmt"
"testing"

"github.com/dfuse-io/solana-go"

"go.uber.org/zap"

"github.com/stretchr/testify/require"
)

func TestWSClient_ProgramSubscribe(t *testing.T) {
func Test_AccountSubscribe(t *testing.T) {
zlog, _ = zap.NewDevelopment()

c, err := Dial(context.Background(), "ws://api.mainnet-beta.solana.com:80/rpc")
defer c.Close()
require.NoError(t, err)

accountID := solana.MustPublicKeyFromBase58("SqJP6vrvMad5XBQK5PCFEZjeuQSFi959sdpqtSNvnsX")
sub, err := c.AccountSubscribe(accountID, "")
require.NoError(t, err)

data, err := sub.Recv()
if err != nil {
fmt.Println("receive an error: ", err)
return
}
fmt.Println("data received: ", data.(*AccountResult).Value.Account.Owner)
return

}

func Test_ProgramSubscribe(t *testing.T) {
zlog, _ = zap.NewDevelopment()

c, err := Dial(context.Background(), "ws://api.mainnet-beta.solana.com:80/rpc")
defer c.Close()
require.NoError(t, err)

sub, err := c.ProgramSubscribe("EUqojwWA2rd19FZrzeBncJsm38Jm1hEhE3zsmX3bRc2o", "")
programID := solana.MustPublicKeyFromBase58("EUqojwWA2rd19FZrzeBncJsm38Jm1hEhE3zsmX3bRc2o")
sub, err := c.ProgramSubscribe(programID, "")
require.NoError(t, err)

data, err := sub.Recv()
if err != nil {
fmt.Println("receive an error: ", err)
return
}
fmt.Println("data received: ", data.(*ProgramWSResult).Value.Account.Owner)
fmt.Println("data received: ", data.(*ProgramResult).Value.Account.Owner)
return

}
29 changes: 29 additions & 0 deletions rpc/ws/logging.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
// Copyright 2020 dfuse Platform Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package ws

import (
"os"

"github.com/dfuse-io/logging"
"go.uber.org/zap"
)

var traceEnabled = os.Getenv("TRACE") == "true"
var zlog *zap.Logger

func init() {
logging.Register("github.com/dfuse-io/solana-go/rpc/ws", &zlog)
}
14 changes: 14 additions & 0 deletions rpc/ws/method.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
package ws

import (
"github.com/dfuse-io/solana-go"
"github.com/dfuse-io/solana-go/rpc"
)

func (c *Client) ProgramSubscribe(programId solana.PublicKey, commitment rpc.CommitmentType) (*Subscription, error) {
return c.subscribe([]interface{}{programId.String()}, "programSubscribe", "programUnsubscribe", commitment, ProgramResult{})
}

func (c *Client) AccountSubscribe(account solana.PublicKey, commitment rpc.CommitmentType) (*Subscription, error) {
return c.subscribe([]interface{}{account.String()}, "accountSubscribe", "accountUnsubscribe", commitment, AccountResult{})
}
41 changes: 41 additions & 0 deletions rpc/ws/subscription.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
package ws

import "reflect"

type Subscription struct {
req *request
subID uint64
stream chan result
err chan error
reflectType reflect.Type
closeFunc func(err error)
unsubscriptionMethod string
}

func newSubscription(req *request, reflectType reflect.Type, closeFunc func(err error)) *Subscription {
return &Subscription{
req: req,
reflectType: reflectType,
stream: make(chan result, 200),
err: make(chan error, 1),
closeFunc: closeFunc,
}
}

func (s *Subscription) Recv() (interface{}, error) {
select {
case d := <-s.stream:
return d, nil
case err := <-s.err:
return nil, err
}
}

func (s *Subscription) Unsubscribe() {
s.unsubscribe(nil)
}

func (s *Subscription) unsubscribe(err error) {
s.closeFunc(err)

}
Loading

0 comments on commit 76687fa

Please sign in to comment.