Skip to content

Commit

Permalink
add signal test
Browse files Browse the repository at this point in the history
  • Loading branch information
mehrdadrad committed Jun 21, 2020
1 parent 4c6ba14 commit 6aab376
Showing 1 changed file with 40 additions and 16 deletions.
56 changes: 40 additions & 16 deletions client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,9 @@ xZe9KaFo+tXOg6ThEf/IFPjcGjJxNfNwYaszzdyXoS9HmM6S0GUqbrF84IjFNCqsNtnK2I
L+Ha2sPh5OB4w+j/xdvWwdevCA11HE3MDqjN6Uq0EMKfAlEbgkqePQB+uiFhSf3laAybgm
KNj5a3Q/DLNfsAAAAbbWVocmRhZEBNLU1hY0Jvb2stUHJvLmxvY2Fs
-----END OPENSSH PRIVATE KEY-----`)
sshTestSrvAddr = "127.0.0.1:5522"
testSrvStdin = ""
testSrvAddr = "127.0.0.1:5522"
testSrvStdin = ""
testSrvSig ssh.Signal
)

func init() {
Expand Down Expand Up @@ -121,7 +122,7 @@ func TestSessionsMaxOutRaceQueries(t *testing.T) {
config := GetConfigUserPass("vssh", "vssh")

vs := New().StartWithContext(ctx)
vs.AddClient(sshTestSrvAddr, config, SetMaxSessions(1), DisableRequestPty())
vs.AddClient(testSrvAddr, config, SetMaxSessions(1), DisableRequestPty())

vs.Wait(100)

Expand Down Expand Up @@ -164,7 +165,7 @@ func TestTimeout(t *testing.T) {
config := GetConfigUserPass("vssh", "vssh")

vs := New().StartWithContext(ctx)
vs.AddClient(sshTestSrvAddr, config, SetMaxSessions(1), DisableRequestPty())
vs.AddClient(testSrvAddr, config, SetMaxSessions(1), DisableRequestPty())

vs.Wait()

Expand All @@ -189,7 +190,7 @@ func TestGoroutineLeak(t *testing.T) {
config := GetConfigUserPass("vssh", "vssh")

vs := New().StartWithContext(ctx)
vs.AddClient(sshTestSrvAddr, config, SetMaxSessions(1), DisableRequestPty())
vs.AddClient(testSrvAddr, config, SetMaxSessions(1), DisableRequestPty())

vs.Wait()

Expand Down Expand Up @@ -219,7 +220,7 @@ func TestQueryWithLabel(t *testing.T) {
config := GetConfigUserPass("vssh", "vssh")

labels := map[string]string{"POP": "ORD"}
vs.AddClient(sshTestSrvAddr, config, SetLabels(labels), DisableRequestPty())
vs.AddClient(testSrvAddr, config, SetLabels(labels), DisableRequestPty())

vs.Wait(100)

Expand All @@ -245,7 +246,7 @@ func TestQueryWithLabel(t *testing.T) {
t.Fatal("expect to get true but got", ok)
}

client, ok := vs.clients.get(sshTestSrvAddr)
client, ok := vs.clients.get(testSrvAddr)
if !ok {
t.Error("expect test-client but not exist")
}
Expand All @@ -272,7 +273,7 @@ func TestOnDemand(t *testing.T) {

timeout, _ := time.ParseDuration("2s")
config := GetConfigUserPass("vssh", "vssh")
vs.AddClient(sshTestSrvAddr, config, SetMaxSessions(2), DisableRequestPty())
vs.AddClient(testSrvAddr, config, SetMaxSessions(2), DisableRequestPty())

d, _ := vs.Wait()
if d != 0 {
Expand All @@ -281,7 +282,7 @@ func TestOnDemand(t *testing.T) {

time.Sleep(time.Second)

client, ok := vs.clients.get(sshTestSrvAddr)
client, ok := vs.clients.get(testSrvAddr)
if !ok {
t.Error("expect test-client but not exist")
}
Expand Down Expand Up @@ -314,7 +315,7 @@ func TestStream(t *testing.T) {

timeout, _ := time.ParseDuration("6s")
config := GetConfigUserPass("vssh", "vssh")
vs.AddClient(sshTestSrvAddr, config, SetMaxSessions(2), DisableRequestPty())
vs.AddClient(testSrvAddr, config, SetMaxSessions(2), DisableRequestPty())

vs.Wait()
respChan := vs.Run(ctx, "ping", timeout)
Expand Down Expand Up @@ -350,7 +351,7 @@ func sshServer() {

config.AddHostKey(private)

l, err := net.Listen("tcp", sshTestSrvAddr)
l, err := net.Listen("tcp", testSrvAddr)
if err != nil {
log.Fatal(err)
}
Expand Down Expand Up @@ -390,6 +391,9 @@ func handler(newChannel ssh.NewChannel) {
go func(in <-chan *ssh.Request) {
for req := range in {
req.Reply(req.Type == "shell", nil)
if req.Type == "signal" {
testSrvSig = ssh.Signal(req.Payload[4:])
}
}
}(requests)

Expand Down Expand Up @@ -483,10 +487,10 @@ func TestClientsRace(t *testing.T) {
func TestNewSession(t *testing.T) {
vs := New().Start()
config := GetConfigUserPass("vssh", "vssh")
vs.AddClient(sshTestSrvAddr, config, SetMaxSessions(2), DisableRequestPty())
vs.AddClient(testSrvAddr, config, SetMaxSessions(2), DisableRequestPty())
vs.Wait()

client, _ := vs.clients.get(sshTestSrvAddr)
client, _ := vs.clients.get(testSrvAddr)
session, err := client.newSession()
if err != nil {
t.Fatal("unexpected error", err)
Expand Down Expand Up @@ -598,7 +602,7 @@ func TestStreamScanStderr(t *testing.T) {
}

func TestClientAttrRun(t *testing.T) {
addr := sshTestSrvAddr
addr := testSrvAddr
ch := make(chan *Response, 1)
client := &clientAttr{
addr: addr,
Expand Down Expand Up @@ -658,10 +662,10 @@ func TestResponseID(t *testing.T) {
func TestGetScanners(t *testing.T) {
vs := New().Start()
config := GetConfigUserPass("vssh", "vssh")
vs.AddClient(sshTestSrvAddr, config, SetMaxSessions(2), DisableRequestPty())
vs.AddClient(testSrvAddr, config, SetMaxSessions(2), DisableRequestPty())
vs.Wait()

client, _ := vs.clients.get(sshTestSrvAddr)
client, _ := vs.clients.get(testSrvAddr)
session, err := client.newSession()
if err != nil {
t.Fatal(err)
Expand Down Expand Up @@ -695,3 +699,23 @@ func TestClientConnect(t *testing.T) {
t.Fatal("shouldn't try connect when the max sessions is zero")
}
}

func TestStreamSignal(t *testing.T) {
vs := New().Start()

timeout, _ := time.ParseDuration("6s")
config := GetConfigUserPass("vssh", "vssh")
vs.AddClient(testSrvAddr, config, SetMaxSessions(2), DisableRequestPty())
vs.Wait()

respChan := vs.Run(context.Background(), "ping", timeout)
resp := <-respChan
stream := resp.GetStream()
defer stream.Close()

stream.Signal(ssh.SIGALRM)
time.Sleep(time.Millisecond * 100)
if testSrvSig != ssh.SIGALRM {
t.Error("expect to get SIGALRM but got", testSrvSig)
}
}

0 comments on commit 6aab376

Please sign in to comment.