Skip to content

Commit

Permalink
Update DisjointSet with generic type constraints
Browse files Browse the repository at this point in the history
  • Loading branch information
subpop committed Dec 22, 2023
1 parent 9cb93b5 commit 397db04
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 12 deletions.
14 changes: 7 additions & 7 deletions disjoint_set.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,15 @@ package adt

// DisjointSet is a data structure that tracks a set of elements partitioned
// into a number of non-overlapping (disjoint) subsets.
type DisjointSet struct {
Value interface{}
parent *DisjointSet
type DisjointSet[V any] struct {
Value V
parent *DisjointSet[V]
size int
}

// NewDisjointSet returns a disjoint-set initialized to contain only value v.
func NewDisjointSet(v interface{}) *DisjointSet {
d := &DisjointSet{
func NewDisjointSet[V any](v V) *DisjointSet[V] {
d := &DisjointSet[V]{
Value: v,
size: 1,
}
Expand All @@ -19,7 +19,7 @@ func NewDisjointSet(v interface{}) *DisjointSet {
}

// Find finds the root (or representative element) of disjoint-set d.
func (d *DisjointSet) Find() *DisjointSet {
func (d *DisjointSet[V]) Find() *DisjointSet[V] {
if d.parent != d {
d.parent = d.parent.Find()
}
Expand All @@ -28,7 +28,7 @@ func (d *DisjointSet) Find() *DisjointSet {

// Union finds thes the representative element of x and y and merges the
// smaller of the two sets into the other.
func Union(x, y *DisjointSet) *DisjointSet {
func Union[V any](x, y *DisjointSet[V]) *DisjointSet[V] {
xRoot := x.Find()
yRoot := y.Find()

Expand Down
10 changes: 5 additions & 5 deletions disjoint_set_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,15 @@ import (

func TestDisjointSet(t *testing.T) {
tests := []struct {
input []*DisjointSet
want *DisjointSet
input []*DisjointSet[string]
want *DisjointSet[string]
}{
{
input: []*DisjointSet{
input: []*DisjointSet[string]{
NewDisjointSet("a"),
NewDisjointSet("b"),
},
want: func() *DisjointSet {
want: func() *DisjointSet[string] {
c := NewDisjointSet("c")
c.size = 3
return c
Expand All @@ -30,7 +30,7 @@ func TestDisjointSet(t *testing.T) {
got = Union(got, s)
}

if !cmp.Equal(got, test.want, cmp.AllowUnexported(DisjointSet{})) {
if !cmp.Equal(got, test.want, cmp.AllowUnexported(DisjointSet[string]{})) {
t.Errorf("%+v != %+v", got, test.want)
}
}
Expand Down

0 comments on commit 397db04

Please sign in to comment.