Skip to content

Commit

Permalink
interp: add support of Go generics in interpreter
Browse files Browse the repository at this point in the history
Status:
* [x] parsing code with generics
* [x] instantiate generics from concrete types
* [x] automatic type inference
* [x] support of generic recursive types 
* [x] support of generic methods
* [x] support of generic receivers in methods
* [x] support of multiple type parameters
* [x] support of generic constraints
* [x] tests (see _test/gen*.go)

Fixes #1363.
  • Loading branch information
mvertes authored Aug 3, 2022
1 parent 255b1cf commit 14bc3b5
Show file tree
Hide file tree
Showing 19 changed files with 986 additions and 123 deletions.
39 changes: 39 additions & 0 deletions _test/gen1.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
package main

import "fmt"

// SumInts adds together the values of m.
func SumInts(m map[string]int64) int64 {
var s int64
for _, v := range m {
s += v
}
return s
}

// SumFloats adds together the values of m.
func SumFloats(m map[string]float64) float64 {
var s float64
for _, v := range m {
s += v
}
return s
}

func main() {
// Initialize a map for the integer values
ints := map[string]int64{
"first": 34,
"second": 12,
}

// Initialize a map for the float values
floats := map[string]float64{
"first": 35.98,
"second": 26.99,
}

fmt.Printf("Non-Generic Sums: %v and %v\n",
SumInts(ints),
SumFloats(floats))
}
34 changes: 34 additions & 0 deletions _test/gen2.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
package main

import "fmt"

// SumIntsOrFloats sums the values of map m. It supports both int64 and float64
// as types for map values.
func SumIntsOrFloats[K comparable, V int64 | float64](m map[K]V) V {
var s V
for _, v := range m {
s += v
}
return s
}

func main() {
// Initialize a map for the integer values
ints := map[string]int64{
"first": 34,
"second": 12,
}

// Initialize a map for the float values
floats := map[string]float64{
"first": 35.98,
"second": 26.99,
}

fmt.Printf("Generic Sums: %v and %v\n",
SumIntsOrFloats[string, int64](ints),
SumIntsOrFloats[string, float64](floats))
}

// Output:
// Generic Sums: 46 and 62.97
22 changes: 22 additions & 0 deletions _test/gen3.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
package main

type Number interface {
int | int64 | ~float64
}

func Sum[T Number](numbers []T) T {
var total T
for _, x := range numbers {
total += x
}
return total
}

func main() {
xs := []int{3, 5, 10}
total := Sum(xs)
println(total)
}

// Output:
// 18
42 changes: 42 additions & 0 deletions _test/gen4.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
package main

import "fmt"

type List[T any] struct {
head, tail *element[T]
}

// A recursive generic type.
type element[T any] struct {
next *element[T]
val T
}

func (lst *List[T]) Push(v T) {
if lst.tail == nil {
lst.head = &element[T]{val: v}
lst.tail = lst.head
} else {
lst.tail.next = &element[T]{val: v}
lst.tail = lst.tail.next
}
}

func (lst *List[T]) GetAll() []T {
var elems []T
for e := lst.head; e != nil; e = e.next {
elems = append(elems, e.val)
}
return elems
}

func main() {
lst := List[int]{}
lst.Push(10)
lst.Push(13)
lst.Push(23)
fmt.Println("list:", lst.GetAll())
}

// Output:
// list: [10 13 23]
24 changes: 24 additions & 0 deletions _test/gen5.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
package main

import "fmt"

type Set[Elem comparable] struct {
m map[Elem]struct{}
}

func Make[Elem comparable]() Set[Elem] {
return Set[Elem]{m: make(map[Elem]struct{})}
}

func (s Set[Elem]) Add(v Elem) {
s.m[v] = struct{}{}
}

func main() {
s := Make[int]()
s.Add(1)
fmt.Println(s)
}

// Output:
// {map[1:{}]}
19 changes: 19 additions & 0 deletions _test/gen6.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
package main

func MapKeys[K comparable, V any](m map[K]V) []K {
r := make([]K, 0, len(m))
for k := range m {
r = append(r, k)
}
return r
}

func main() {
var m = map[int]string{1: "2", 2: "4", 4: "8"}

// Test type inference
println(len(MapKeys(m)))
}

// Output:
// 3
19 changes: 19 additions & 0 deletions _test/gen7.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
package main

func MapKeys[K comparable, V any](m map[K]V) []K {
r := make([]K, 0, len(m))
for k := range m {
r = append(r, k)
}
return r
}

func main() {
var m = map[int]string{1: "2", 2: "4", 4: "8"}

// Test type inference
println(len(MapKeys))
}

// Error:
// invalid argument for len
15 changes: 15 additions & 0 deletions _test/gen8.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
package main

type Float interface {
~float32 | ~float64
}

func add[T Float](a, b T) float64 { return float64(a) + float64(b) }

func main() {
var x, y int = 1, 2
println(add(x, y))
}

// Error:
// int does not implement main.Float
14 changes: 14 additions & 0 deletions _test/gen9.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
package main

type Float interface {
~float32 | ~float64
}

func add[T Float](a, b T) float64 { return float64(a) + float64(b) }

func main() {
println(add(1, 2))
}

// Error:
// untyped int does not implement main.Float
15 changes: 13 additions & 2 deletions interp/ast.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ const (
importSpec
incDecStmt
indexExpr
indexListExpr
interfaceType
keyValueExpr
labeledStmt
Expand Down Expand Up @@ -155,6 +156,7 @@ var kinds = [...]string{
importSpec: "importSpec",
incDecStmt: "incDecStmt",
indexExpr: "indexExpr",
indexListExpr: "indexListExpr",
interfaceType: "interfaceType",
keyValueExpr: "keyValueExpr",
labeledStmt: "labeledStmt",
Expand Down Expand Up @@ -694,7 +696,7 @@ func (interp *Interpreter) ast(f ast.Node) (string, *node, error) {
n := addChild(&root, anc, pos, funcDecl, aNop)
n.val = n
if a.Recv == nil {
// function is not a method, create an empty receiver list
// Function is not a method, create an empty receiver list.
addChild(&root, astNode{n, nod}, pos, fieldList, aNop)
}
st.push(n, nod)
Expand All @@ -706,7 +708,13 @@ func (interp *Interpreter) ast(f ast.Node) (string, *node, error) {
st.push(n, nod)

case *ast.FuncType:
st.push(addChild(&root, anc, pos, funcType, aNop), nod)
n := addChild(&root, anc, pos, funcType, aNop)
n.val = n
if a.TypeParams == nil {
// Function has no type parameters, create an empty fied list.
addChild(&root, astNode{n, nod}, pos, fieldList, aNop)
}
st.push(n, nod)

case *ast.GenDecl:
var kind nkind
Expand Down Expand Up @@ -776,6 +784,9 @@ func (interp *Interpreter) ast(f ast.Node) (string, *node, error) {
case *ast.IndexExpr:
st.push(addChild(&root, anc, pos, indexExpr, aGetIndex), nod)

case *ast.IndexListExpr:
st.push(addChild(&root, anc, pos, indexListExpr, aNop), nod)

case *ast.InterfaceType:
st.push(addChild(&root, anc, pos, interfaceType, aNop), nod)

Expand Down
Loading

0 comments on commit 14bc3b5

Please sign in to comment.