Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement callbacks needed for VPN client #526

Merged
merged 3 commits into from
Sep 28, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 30 additions & 8 deletions cmd/apps/vpn-client/vpn-client.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,22 +71,36 @@ func main() {
}
}

var directIPsCh = make(chan net.IP, 100)
var directIPsCh, nonDirectIPsCh = make(chan net.IP, 100), make(chan net.IP, 100)
defer close(directIPsCh)
defer close(nonDirectIPsCh)

eventSub := appevent.NewSubscriber()

eventSub.OnTCPDial(func(data appevent.TCPDialData) {
ip, ok, err := vpn.ParseIP(data.RemoteAddr)
parseIP := func(addr string) net.IP {
ip, ok, err := vpn.ParseIP(addr)
if err != nil {
log.WithError(err).Errorf("Failed to parse IP %s", data.RemoteAddr)
return
log.WithError(err).Errorf("Failed to parse IP %s", addr)
return nil
}
if !ok {
log.Errorf("Failed to parse IP %s", data.RemoteAddr)
return
log.Errorf("Failed to parse IP %s", addr)
return nil
}

directIPsCh <- ip
return ip
}

eventSub.OnTCPDial(func(data appevent.TCPDialData) {
if ip := parseIP(data.RemoteAddr); ip != nil {
directIPsCh <- ip
}
})

eventSub.OnTCPClose(func(data appevent.TCPCloseData) {
if ip := parseIP(data.RemoteAddr); ip != nil {
nonDirectIPsCh <- ip
}
})

appClient := app.NewClient(eventSub)
Expand Down Expand Up @@ -145,6 +159,14 @@ func main() {
}
}()

go func() {
for ip := range nonDirectIPsCh {
if err := vpnClient.RemoveDirectRoute(ip); err != nil {
log.WithError(err).Errorf("Failed to remove direct route to %s", ip.String())
}
}
}()

if err := vpnClient.Serve(); err != nil {
log.WithError(err).Fatalln("Error serving VPN")
}
Expand Down
39 changes: 33 additions & 6 deletions internal/vpn/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,25 @@ func (c *Client) AddDirectRoute(ip net.IP) error {
return nil
}

// RemoveDirectRoute removes direct route. Packets destined to `ip` will
// go through VPN.
func (c *Client) RemoveDirectRoute(ip net.IP) error {
c.directIPSMu.Lock()
defer c.directIPSMu.Unlock()

for i, storedIP := range c.directIPs {
if ip.Equal(storedIP) {
c.directIPs = append(c.directIPs[:i], c.directIPs[i+1:]...)

if err := c.removeDirectRoute(ip); err != nil {
return err
}
}
}

return nil
}

// Serve performs handshake with the server, sets up routing and starts handling traffic.
func (c *Client) Serve() error {
tunIP, tunGateway, err := c.shakeHands()
Expand Down Expand Up @@ -248,17 +267,25 @@ func (c *Client) setupDirectRoute(ip net.IP) error {
return nil
}

func (c *Client) removeDirectRoute(ip net.IP) error {
if !ip.IsLoopback() {
c.log.Infof("Removing direct route to %s", ip.String())
if err := DeleteRoute(ip.String()+directRouteNetmaskCIDR, c.defaultGateway.String()); err != nil {
return err
}
}

return nil
}

func (c *Client) removeDirectRoutes() {
c.directIPSMu.Lock()
defer c.directIPSMu.Unlock()

for _, ip := range c.directIPs {
if !ip.IsLoopback() {
c.log.Infof("Removing direct route to %s", ip.String())
if err := DeleteRoute(ip.String()+directRouteNetmaskCIDR, c.defaultGateway.String()); err != nil {
// shouldn't return, just keep on trying the other IPs
c.log.WithError(err).Errorf("Error removing direct route to %s", ip.String())
}
if err := c.removeDirectRoute(ip); err != nil {
// shouldn't return, just keep on trying the other IPs
c.log.WithError(err).Warnf("Error removing direct route to %s", ip.String())
}
}
}
Expand Down
62 changes: 47 additions & 15 deletions pkg/transport/managed_transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,16 @@ const (
tpFactor = 1.3
)

// ManagedTransportConfig is a configuration for managed transport.
type ManagedTransportConfig struct {
Net *snet.Network
DC DiscoveryClient
LS LogStore
RemotePK cipher.PubKey
NetName string
AfterClosed TPCloseCallback
}

// ManagedTransport manages a direct line of communication between two visor nodes.
// There is a single underlying connection between two edges.
// Initial dialing can be requested by either edge of the connection.
Expand Down Expand Up @@ -75,22 +85,26 @@ type ManagedTransport struct {
once sync.Once
wg sync.WaitGroup

remoteAddrs []string
remoteAddr string

afterClosedMu sync.RWMutex
afterClosed TPCloseCallback
}

// NewManagedTransport creates a new ManagedTransport.
func NewManagedTransport(n *snet.Network, dc DiscoveryClient, ls LogStore, rPK cipher.PubKey, netName string) *ManagedTransport {
func NewManagedTransport(conf ManagedTransportConfig) *ManagedTransport {
mt := &ManagedTransport{
log: logging.MustGetLogger(fmt.Sprintf("tp:%s", rPK.String()[:6])),
rPK: rPK,
netName: netName,
n: n,
dc: dc,
ls: ls,
Entry: makeEntry(n.LocalPK(), rPK, netName),
LogEntry: new(LogEntry),
connCh: make(chan struct{}, 1),
done: make(chan struct{}),
log: logging.MustGetLogger(fmt.Sprintf("tp:%s", conf.RemotePK.String()[:6])),
rPK: conf.RemotePK,
netName: conf.NetName,
n: conf.Net,
dc: conf.DC,
ls: conf.LS,
Entry: makeEntry(conf.Net.LocalPK(), conf.RemotePK, conf.NetName),
LogEntry: new(LogEntry),
connCh: make(chan struct{}, 1),
done: make(chan struct{}),
afterClosed: conf.AfterClosed,
}
mt.wg.Add(2)
return mt
Expand Down Expand Up @@ -203,6 +217,12 @@ func (mt *ManagedTransport) Serve(readCh chan<- routing.Packet) {
}
}

func (mt *ManagedTransport) onAfterClosed(f TPCloseCallback) {
mt.afterClosedMu.Lock()
mt.afterClosed = f
mt.afterClosedMu.Unlock()
}

func (mt *ManagedTransport) isServing() bool {
select {
case <-mt.done:
Expand All @@ -226,9 +246,21 @@ func (mt *ManagedTransport) Close() (err error) {
return err
}

// close stops serving the transport and ensures that transport status is updated to DOWN.
// It also waits until mt.Serve returns if specified.
func (mt *ManagedTransport) close() {
mt.disconnect()

mt.afterClosedMu.RLock()
afterClosed := mt.afterClosed
mt.afterClosedMu.RUnlock()

if afterClosed != nil {
afterClosed(mt.netName, mt.remoteAddr)
}
}

// disconnect stops serving the transport and ensures that transport status is updated to DOWN.
// It also waits until mt.Serve returns if specified.
func (mt *ManagedTransport) disconnect() {
mt.once.Do(func() { close(mt.done) })
_ = mt.updateStatus(false, 1) //nolint:errcheck
}
Expand Down Expand Up @@ -314,7 +346,7 @@ func (mt *ManagedTransport) redial(ctx context.Context) error {

// If the error is not temporary, it most likely means that the transport is no longer registered.
// Hence, we should close the managed transport.
mt.close()
mt.disconnect()
mt.log.
WithError(err).
Warn("Transport closed due to redial failure. Transport is likely no longer in discovery.")
Expand Down
45 changes: 40 additions & 5 deletions pkg/transport/manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,9 @@ const (
TrustedVisorsDelay = 5 * time.Second
)

// TPCloseCallback triggers after a session is closed.
type TPCloseCallback func(network, addr string)

// ManagerConfig configures a Manager.
type ManagerConfig struct {
PubKey cipher.PubKey
Expand Down Expand Up @@ -53,6 +56,8 @@ type Manager struct {
serveOnce sync.Once // ensure we only serve once.
closeOnce sync.Once // ensure we only close once.
done chan struct{}

afterTPClosed TPCloseCallback
}

// NewManager creates a Manager with the provided configuration and transport factories.
Expand All @@ -73,6 +78,19 @@ func NewManager(log *logging.Logger, n *snet.Network, config *ManagerConfig) (*M
return tm, nil
}

// OnAfterTPClosed sets callback which will fire after transport gets closed.
func (tm *Manager) OnAfterTPClosed(f TPCloseCallback) {
tm.mx.Lock()
defer tm.mx.Unlock()

tm.afterTPClosed = f

// set callback for all already known tps
for _, tp := range tm.tps {
tp.onAfterClosed(f)
}
}

// Serve runs listening loop across all registered factories.
func (tm *Manager) Serve(ctx context.Context) {
tm.serveOnce.Do(func() {
Expand Down Expand Up @@ -212,7 +230,14 @@ func (tm *Manager) acceptTransport(ctx context.Context, lis *snet.Listener) erro
if !ok {
tm.Logger.Debugln("No TP found, creating new one")

mTp = NewManagedTransport(tm.n, tm.Conf.DiscoveryClient, tm.Conf.LogStore, conn.RemotePK(), lis.Network())
mTp = NewManagedTransport(ManagedTransportConfig{
Net: tm.n,
DC: tm.Conf.DiscoveryClient,
LS: tm.Conf.LogStore,
RemotePK: conn.RemotePK(),
NetName: lis.Network(),
AfterClosed: tm.afterTPClosed,
})

go func() {
mTp.Serve(tm.readCh)
Expand Down Expand Up @@ -303,13 +328,23 @@ func (tm *Manager) saveTransport(remote cipher.PubKey, netName string) (*Managed
return oldMTp, nil
}

mTp := NewManagedTransport(tm.n, tm.Conf.DiscoveryClient, tm.Conf.LogStore, remote, netName)
afterTPClosed := tm.afterTPClosed

mTp := NewManagedTransport(ManagedTransportConfig{
Net: tm.n,
DC: tm.Conf.DiscoveryClient,
LS: tm.Conf.LogStore,
RemotePK: remote,
NetName: netName,
AfterClosed: afterTPClosed,
})

if mTp.netName == tptypes.STCPR {
ar := mTp.n.Conf().ARClient
if ar != nil {
visorData, err := ar.Resolve(context.Background(), mTp.netName, remote)
if err == nil {
mTp.remoteAddrs = append(mTp.remoteAddrs, visorData.RemoteAddr)
mTp.remoteAddr = visorData.RemoteAddr
} else {
if err != arclient.ErrNoEntry {
return nil, fmt.Errorf("failed to resolve %s: %w", remote, err)
Expand Down Expand Up @@ -337,8 +372,8 @@ func (tm *Manager) STCPRRemoteAddrs() []string {
defer tm.mx.RUnlock()

for _, tp := range tm.tps {
if tp.Entry.Type == tptypes.STCPR && len(tp.remoteAddrs) > 0 {
addrs = append(addrs, tp.remoteAddrs...)
if tp.Entry.Type == tptypes.STCPR && tp.remoteAddr != "" {
addrs = append(addrs, tp.remoteAddr)
}
}

Expand Down
10 changes: 10 additions & 0 deletions pkg/visor/init.go
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,16 @@ func initTransport(v *Visor) bool {
return report(fmt.Errorf("failed to start transport manager: %w", err))
}

tpM.OnAfterTPClosed(func(network, addr string) {
if network == tptypes.STCPR && addr != "" {
data := appevent.TCPCloseData{RemoteNet: network, RemoteAddr: addr}
event := appevent.NewEvent(appevent.TCPClose, data)
if err := v.ebc.Broadcast(context.Background(), event); err != nil {
v.log.WithError(err).Errorln("Failed to broadcast TCPClose event")
}
}
})

ctx, cancel := context.WithCancel(context.Background())
wg := new(sync.WaitGroup)
wg.Add(1)
Expand Down