Skip to content

Commit

Permalink
fix: detect default comm clause in select from AST (#678)
Browse files Browse the repository at this point in the history
* fix: detect default comm clause in select from AST

The heuristic to distinguish a default comm clause was too weak.
Make it robust by using AST.

Fixes #646.

* rename test to avoid conflict
  • Loading branch information
mvertes committed Jun 10, 2020
1 parent 0ef7f8f commit 82b499a
Show file tree
Hide file tree
Showing 3 changed files with 113 additions and 82 deletions.
16 changes: 16 additions & 0 deletions _test/select13.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
package main

func main() {
var c interface{} = int64(1)
q := make(chan struct{})
select {
case q <- struct{}{}:
println("unexpected")
default:
_ = c.(int64)
}
println("bye")
}

// Output:
// bye
156 changes: 81 additions & 75 deletions interp/ast.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ const (
caseClause
chanType
commClause
commClauseDefault
compositeLitExpr
constDecl
continueStmt
Expand Down Expand Up @@ -93,80 +94,81 @@ const (
)

var kinds = [...]string{
undefNode: "undefNode",
addressExpr: "addressExpr",
arrayType: "arrayType",
assignStmt: "assignStmt",
assignXStmt: "assignXStmt",
basicLit: "basicLit",
binaryExpr: "binaryExpr",
blockStmt: "blockStmt",
branchStmt: "branchStmt",
breakStmt: "breakStmt",
callExpr: "callExpr",
caseBody: "caseBody",
caseClause: "caseClause",
chanType: "chanType",
commClause: "commClause",
compositeLitExpr: "compositeLitExpr",
constDecl: "constDecl",
continueStmt: "continueStmt",
declStmt: "declStmt",
deferStmt: "deferStmt",
defineStmt: "defineStmt",
defineXStmt: "defineXStmt",
ellipsisExpr: "ellipsisExpr",
exprStmt: "exprStmt",
fallthroughtStmt: "fallthroughStmt",
fieldExpr: "fieldExpr",
fieldList: "fieldList",
fileStmt: "fileStmt",
forStmt0: "forStmt0",
forStmt1: "forStmt1",
forStmt2: "forStmt2",
forStmt3: "forStmt3",
forStmt3a: "forStmt3a",
forStmt4: "forStmt4",
forRangeStmt: "forRangeStmt",
funcDecl: "funcDecl",
funcType: "funcType",
funcLit: "funcLit",
goStmt: "goStmt",
gotoStmt: "gotoStmt",
identExpr: "identExpr",
ifStmt0: "ifStmt0",
ifStmt1: "ifStmt1",
ifStmt2: "ifStmt2",
ifStmt3: "ifStmt3",
importDecl: "importDecl",
importSpec: "importSpec",
incDecStmt: "incDecStmt",
indexExpr: "indexExpr",
interfaceType: "interfaceType",
keyValueExpr: "keyValueExpr",
labeledStmt: "labeledStmt",
landExpr: "landExpr",
lorExpr: "lorExpr",
mapType: "mapType",
parenExpr: "parenExpr",
rangeStmt: "rangeStmt",
returnStmt: "returnStmt",
selectStmt: "selectStmt",
selectorExpr: "selectorExpr",
selectorImport: "selectorImport",
sendStmt: "sendStmt",
sliceExpr: "sliceExpr",
starExpr: "starExpr",
structType: "structType",
switchStmt: "switchStmt",
switchIfStmt: "switchIfStmt",
typeAssertExpr: "typeAssertExpr",
typeDecl: "typeDecl",
typeSpec: "typeSpec",
typeSwitch: "typeSwitch",
unaryExpr: "unaryExpr",
valueSpec: "valueSpec",
varDecl: "varDecl",
undefNode: "undefNode",
addressExpr: "addressExpr",
arrayType: "arrayType",
assignStmt: "assignStmt",
assignXStmt: "assignXStmt",
basicLit: "basicLit",
binaryExpr: "binaryExpr",
blockStmt: "blockStmt",
branchStmt: "branchStmt",
breakStmt: "breakStmt",
callExpr: "callExpr",
caseBody: "caseBody",
caseClause: "caseClause",
chanType: "chanType",
commClause: "commClause",
commClauseDefault: "commClauseDefault",
compositeLitExpr: "compositeLitExpr",
constDecl: "constDecl",
continueStmt: "continueStmt",
declStmt: "declStmt",
deferStmt: "deferStmt",
defineStmt: "defineStmt",
defineXStmt: "defineXStmt",
ellipsisExpr: "ellipsisExpr",
exprStmt: "exprStmt",
fallthroughtStmt: "fallthroughStmt",
fieldExpr: "fieldExpr",
fieldList: "fieldList",
fileStmt: "fileStmt",
forStmt0: "forStmt0",
forStmt1: "forStmt1",
forStmt2: "forStmt2",
forStmt3: "forStmt3",
forStmt3a: "forStmt3a",
forStmt4: "forStmt4",
forRangeStmt: "forRangeStmt",
funcDecl: "funcDecl",
funcType: "funcType",
funcLit: "funcLit",
goStmt: "goStmt",
gotoStmt: "gotoStmt",
identExpr: "identExpr",
ifStmt0: "ifStmt0",
ifStmt1: "ifStmt1",
ifStmt2: "ifStmt2",
ifStmt3: "ifStmt3",
importDecl: "importDecl",
importSpec: "importSpec",
incDecStmt: "incDecStmt",
indexExpr: "indexExpr",
interfaceType: "interfaceType",
keyValueExpr: "keyValueExpr",
labeledStmt: "labeledStmt",
landExpr: "landExpr",
lorExpr: "lorExpr",
mapType: "mapType",
parenExpr: "parenExpr",
rangeStmt: "rangeStmt",
returnStmt: "returnStmt",
selectStmt: "selectStmt",
selectorExpr: "selectorExpr",
selectorImport: "selectorImport",
sendStmt: "sendStmt",
sliceExpr: "sliceExpr",
starExpr: "starExpr",
structType: "structType",
switchStmt: "switchStmt",
switchIfStmt: "switchIfStmt",
typeAssertExpr: "typeAssertExpr",
typeDecl: "typeDecl",
typeSpec: "typeSpec",
typeSwitch: "typeSwitch",
unaryExpr: "unaryExpr",
valueSpec: "valueSpec",
varDecl: "varDecl",
}

func (k nkind) String() string {
Expand Down Expand Up @@ -565,7 +567,11 @@ func (interp *Interpreter) ast(src, name string) (string, *node, error) {
st.push(addChild(&root, anc, pos, chanType, aNop), nod)

case *ast.CommClause:
st.push(addChild(&root, anc, pos, commClause, aNop), nod)
kind := commClause
if a.Comm == nil {
kind = commClauseDefault
}
st.push(addChild(&root, anc, pos, kind, aNop), nod)

case *ast.CommentGroup:
return false
Expand Down
23 changes: 16 additions & 7 deletions interp/cfg.go
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,9 @@ func (interp *Interpreter) cfg(root *node, pkgID string) ([]*node, error) {
nod.typ = typ
}

case commClauseDefault:
sc = sc.pushBloc()

case commClause:
sc = sc.pushBloc()
if len(n.child) > 0 && n.child[0].action == aAssign {
Expand Down Expand Up @@ -903,19 +906,25 @@ func (interp *Interpreter) cfg(root *node, pkgID string) ([]*node, error) {
case caseClause:
sc = sc.pop()

case commClauseDefault:
wireChild(n)
sc = sc.pop()
if len(n.child) == 0 {
return
}
n.start = n.child[0].start
n.lastChild().tnext = n.anc.anc // exit node is selectStmt

case commClause:
wireChild(n)
switch len(n.child) {
case 0:
sc.pop()
sc = sc.pop()
if len(n.child) == 0 {
return
case 1:
n.start = n.child[0].start // default clause
default:
}
if len(n.child) > 1 {
n.start = n.child[1].start // Skip chan operation, performed by select
}
n.lastChild().tnext = n.anc.anc // exit node is selectStmt
sc = sc.pop()

case compositeLitExpr:
wireChild(n)
Expand Down

0 comments on commit 82b499a

Please sign in to comment.