diff --git a/CHANGELOG.md b/CHANGELOG.md index 1750661809..37f948d092 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -35,6 +35,7 @@ Use error group handling to ensure tests actually pass [#1535](https://github.co Fix hang on SIGTERM [#1492](https://github.com/juanfont/headscale/pull/1492) taken from [#1480](https://github.com/juanfont/headscale/pull/1480) Send logs to stderr by default [#1524](https://github.com/juanfont/headscale/pull/1524) Restore foreign keys and add constraints [#1562](https://github.com/juanfont/headscale/pull/1562) + ## 0.22.3 (2023-05-12) ### Changes diff --git a/hscontrol/auth.go b/hscontrol/auth.go index b75636597f..99c426a766 100644 --- a/hscontrol/auth.go +++ b/hscontrol/auth.go @@ -309,7 +309,11 @@ func (h *Headscale) handleAuthKey( Msg("node was already registered before, refreshing with new auth key") node.NodeKey = nodeKey - node.AuthKeyID = uint(pak.ID) + pakId := uint(pak.ID) + if pakId != 0 { + node.AuthKeyID = &pakId + } + err := h.db.NodeSetExpiry(node, registerRequest.Expiry) if err != nil { log.Error(). @@ -364,10 +368,13 @@ func (h *Headscale) handleAuthKey( Expiry: ®isterRequest.Expiry, NodeKey: nodeKey, LastSeen: &now, - AuthKeyID: uint(pak.ID), ForcedTags: pak.Proto().AclTags, } + pakId := uint(pak.ID) + if pakId != 0 { + nodeToRegister.AuthKeyID = &pakId + } node, err = h.db.RegisterNode( nodeToRegister, ) diff --git a/hscontrol/db/addresses_test.go b/hscontrol/db/addresses_test.go index 781fd896f7..4f2387e320 100644 --- a/hscontrol/db/addresses_test.go +++ b/hscontrol/db/addresses_test.go @@ -33,6 +33,7 @@ func (s *Suite) TestGetUsedIps(c *check.C) { _, err = db.GetNode("test", "testnode") c.Assert(err, check.NotNil) + pakId := uint(pak.ID) node := types.Node{ ID: 0, MachineKey: "foo", @@ -41,7 +42,7 @@ func (s *Suite) TestGetUsedIps(c *check.C) { Hostname: "testnode", UserID: user.ID, RegisterMethod: util.RegisterMethodAuthKey, - AuthKeyID: uint(pak.ID), + AuthKeyID: &pakId, IPAddresses: ips, } db.db.Save(&node) @@ -81,6 +82,7 @@ func (s *Suite) TestGetMultiIp(c *check.C) { _, err = db.GetNode("test", "testnode") c.Assert(err, check.NotNil) + pakId := uint(pak.ID) node := types.Node{ ID: uint64(index), MachineKey: "foo", @@ -89,7 +91,7 @@ func (s *Suite) TestGetMultiIp(c *check.C) { Hostname: "testnode", UserID: user.ID, RegisterMethod: util.RegisterMethodAuthKey, - AuthKeyID: uint(pak.ID), + AuthKeyID: &pakId, IPAddresses: ips, } db.db.Save(&node) @@ -171,6 +173,7 @@ func (s *Suite) TestGetAvailableIpNodeWithoutIP(c *check.C) { _, err = db.GetNode("test", "testnode") c.Assert(err, check.NotNil) + pakId := uint(pak.ID) node := types.Node{ ID: 0, MachineKey: "foo", @@ -179,7 +182,7 @@ func (s *Suite) TestGetAvailableIpNodeWithoutIP(c *check.C) { Hostname: "testnode", UserID: user.ID, RegisterMethod: util.RegisterMethodAuthKey, - AuthKeyID: uint(pak.ID), + AuthKeyID: &pakId, } db.db.Save(&node) diff --git a/hscontrol/db/node_test.go b/hscontrol/db/node_test.go index 54b1cd07b2..1c22e5f7fc 100644 --- a/hscontrol/db/node_test.go +++ b/hscontrol/db/node_test.go @@ -25,6 +25,7 @@ func (s *Suite) TestGetNode(c *check.C) { _, err = db.GetNode("test", "testnode") c.Assert(err, check.NotNil) + pakId := uint(pak.ID) node := &types.Node{ ID: 0, MachineKey: "foo", @@ -33,9 +34,10 @@ func (s *Suite) TestGetNode(c *check.C) { Hostname: "testnode", UserID: user.ID, RegisterMethod: util.RegisterMethodAuthKey, - AuthKeyID: uint(pak.ID), + AuthKeyID: &pakId, } - db.db.Save(node) + tx := db.db.Save(node) + c.Assert(tx.Error, check.IsNil) _, err = db.GetNode("test", "testnode") c.Assert(err, check.IsNil) @@ -51,6 +53,7 @@ func (s *Suite) TestGetNodeByID(c *check.C) { _, err = db.GetNodeByID(0) c.Assert(err, check.NotNil) + pakId := uint(pak.ID) node := types.Node{ ID: 0, MachineKey: "foo", @@ -59,9 +62,10 @@ func (s *Suite) TestGetNodeByID(c *check.C) { Hostname: "testnode", UserID: user.ID, RegisterMethod: util.RegisterMethodAuthKey, - AuthKeyID: uint(pak.ID), + AuthKeyID: &pakId, } - db.db.Save(&node) + tx := db.db.Save(&node) + c.Assert(tx.Error, check.IsNil) _, err = db.GetNodeByID(0) c.Assert(err, check.IsNil) @@ -80,6 +84,7 @@ func (s *Suite) TestGetNodeByNodeKey(c *check.C) { nodeKey := key.NewNode() machineKey := key.NewMachine() + pakId := uint(pak.ID) node := types.Node{ ID: 0, MachineKey: util.MachinePublicKeyStripPrefix(machineKey.Public()), @@ -88,9 +93,10 @@ func (s *Suite) TestGetNodeByNodeKey(c *check.C) { Hostname: "testnode", UserID: user.ID, RegisterMethod: util.RegisterMethodAuthKey, - AuthKeyID: uint(pak.ID), + AuthKeyID: &pakId, } - db.db.Save(&node) + tx := db.db.Save(&node) + c.Assert(tx.Error, check.IsNil) _, err = db.GetNodeByNodeKey(nodeKey.Public()) c.Assert(err, check.IsNil) @@ -111,6 +117,7 @@ func (s *Suite) TestGetNodeByAnyNodeKey(c *check.C) { machineKey := key.NewMachine() + pakId := uint(pak.ID) node := types.Node{ ID: 0, MachineKey: util.MachinePublicKeyStripPrefix(machineKey.Public()), @@ -119,9 +126,10 @@ func (s *Suite) TestGetNodeByAnyNodeKey(c *check.C) { Hostname: "testnode", UserID: user.ID, RegisterMethod: util.RegisterMethodAuthKey, - AuthKeyID: uint(pak.ID), + AuthKeyID: &pakId, } - db.db.Save(&node) + tx := db.db.Save(&node) + c.Assert(tx.Error, check.IsNil) _, err = db.GetNodeByAnyKey(machineKey.Public(), nodeKey.Public(), oldNodeKey.Public()) c.Assert(err, check.IsNil) @@ -138,9 +146,9 @@ func (s *Suite) TestHardDeleteNode(c *check.C) { Hostname: "testnode3", UserID: user.ID, RegisterMethod: util.RegisterMethodAuthKey, - AuthKeyID: uint(1), } - db.db.Save(&node) + tx := db.db.Save(&node) + c.Assert(tx.Error, check.IsNil) err = db.DeleteNode(&node) c.Assert(err, check.IsNil) @@ -159,6 +167,7 @@ func (s *Suite) TestListPeers(c *check.C) { _, err = db.GetNodeByID(0) c.Assert(err, check.NotNil) + pakId := uint(pak.ID) for index := 0; index <= 10; index++ { node := types.Node{ ID: uint64(index), @@ -168,9 +177,10 @@ func (s *Suite) TestListPeers(c *check.C) { Hostname: "testnode" + strconv.Itoa(index), UserID: user.ID, RegisterMethod: util.RegisterMethodAuthKey, - AuthKeyID: uint(pak.ID), + AuthKeyID: &pakId, } - db.db.Save(&node) + tx := db.db.Save(&node) + c.Assert(tx.Error, check.IsNil) } node0ByID, err := db.GetNodeByID(0) @@ -205,6 +215,7 @@ func (s *Suite) TestGetACLFilteredPeers(c *check.C) { c.Assert(err, check.NotNil) for index := 0; index <= 10; index++ { + pakId := uint(stor[index%2].key.ID) node := types.Node{ ID: uint64(index), MachineKey: "foo" + strconv.Itoa(index), @@ -216,9 +227,10 @@ func (s *Suite) TestGetACLFilteredPeers(c *check.C) { Hostname: "testnode" + strconv.Itoa(index), UserID: stor[index%2].user.ID, RegisterMethod: util.RegisterMethodAuthKey, - AuthKeyID: uint(stor[index%2].key.ID), + AuthKeyID: &pakId, } - db.db.Save(&node) + tx := db.db.Save(&node) + c.Assert(tx.Error, check.IsNil) } aclPolicy := &policy.ACLPolicy{ @@ -288,6 +300,7 @@ func (s *Suite) TestExpireNode(c *check.C) { _, err = db.GetNode("test", "testnode") c.Assert(err, check.NotNil) + pakId := uint(pak.ID) node := &types.Node{ ID: 0, MachineKey: "foo", @@ -296,7 +309,7 @@ func (s *Suite) TestExpireNode(c *check.C) { Hostname: "testnode", UserID: user.ID, RegisterMethod: util.RegisterMethodAuthKey, - AuthKeyID: uint(pak.ID), + AuthKeyID: &pakId, Expiry: &time.Time{}, } db.db.Save(node) @@ -345,6 +358,7 @@ func (s *Suite) TestGenerateGivenName(c *check.C) { _, err = db.GetNode("user-1", "testnode") c.Assert(err, check.NotNil) + pakId := uint(pak.ID) node := &types.Node{ ID: 0, MachineKey: "node-key-1", @@ -354,9 +368,11 @@ func (s *Suite) TestGenerateGivenName(c *check.C) { GivenName: "hostname-1", UserID: user1.ID, RegisterMethod: util.RegisterMethodAuthKey, - AuthKeyID: uint(pak.ID), + AuthKeyID: &pakId, } - db.db.Save(node) + + tx := db.db.Save(node) + c.Assert(tx.Error, check.IsNil) givenName, err := db.GenerateGivenName("node-key-2", "hostname-2") comment := check.Commentf("Same user, unique nodes, unique hostnames, no conflict") @@ -389,6 +405,7 @@ func (s *Suite) TestSetTags(c *check.C) { _, err = db.GetNode("test", "testnode") c.Assert(err, check.NotNil) + pakId := uint(pak.ID) node := &types.Node{ ID: 0, MachineKey: "foo", @@ -397,9 +414,11 @@ func (s *Suite) TestSetTags(c *check.C) { Hostname: "testnode", UserID: user.ID, RegisterMethod: util.RegisterMethodAuthKey, - AuthKeyID: uint(pak.ID), + AuthKeyID: &pakId, } - db.db.Save(node) + + tx := db.db.Save(node) + c.Assert(tx.Error, check.IsNil) // assign simple tags sTags := []string{"tag:test", "tag:foo"} @@ -572,6 +591,7 @@ func (s *Suite) TestAutoApproveRoutes(c *check.C) { // Check if a subprefix of an autoapproved route is approved route2 := netip.MustParsePrefix("10.11.0.0/24") + pakId := uint(pak.ID) node := types.Node{ ID: 0, MachineKey: "foo", @@ -580,7 +600,7 @@ func (s *Suite) TestAutoApproveRoutes(c *check.C) { Hostname: "test", UserID: user.ID, RegisterMethod: util.RegisterMethodAuthKey, - AuthKeyID: uint(pak.ID), + AuthKeyID: &pakId, HostInfo: types.HostInfo{ RequestTags: []string{"tag:exit"}, RoutableIPs: []netip.Prefix{defaultRouteV4, defaultRouteV6, route1, route2}, @@ -588,7 +608,8 @@ func (s *Suite) TestAutoApproveRoutes(c *check.C) { IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.1")}, } - db.db.Save(&node) + tx := db.db.Save(&node) + c.Assert(tx.Error, check.IsNil) err = db.SaveNodeRoutes(&node) c.Assert(err, check.IsNil) diff --git a/hscontrol/db/preauth_keys.go b/hscontrol/db/preauth_keys.go index e743988fd8..de7bb04a9d 100644 --- a/hscontrol/db/preauth_keys.go +++ b/hscontrol/db/preauth_keys.go @@ -204,9 +204,10 @@ func (hsdb *HSDatabase) ValidatePreAuthKey(k string) (*types.PreAuthKey, error) } nodes := types.Nodes{} + pakId := uint(pak.ID) if err := hsdb.db. Preload("AuthKey"). - Where(&types.Node{AuthKeyID: uint(pak.ID)}). + Where(&types.Node{AuthKeyID: &pakId}). Find(&nodes).Error; err != nil { return nil, err } diff --git a/hscontrol/db/preauth_keys_test.go b/hscontrol/db/preauth_keys_test.go index 9bf8c89271..0250e0b4d7 100644 --- a/hscontrol/db/preauth_keys_test.go +++ b/hscontrol/db/preauth_keys_test.go @@ -75,6 +75,7 @@ func (*Suite) TestAlreadyUsedKey(c *check.C) { pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) c.Assert(err, check.IsNil) + pakId := uint(pak.ID) node := types.Node{ ID: 0, MachineKey: "foo", @@ -83,9 +84,10 @@ func (*Suite) TestAlreadyUsedKey(c *check.C) { Hostname: "testest", UserID: user.ID, RegisterMethod: util.RegisterMethodAuthKey, - AuthKeyID: uint(pak.ID), + AuthKeyID: &pakId, } - db.db.Save(&node) + tx := db.db.Save(&node) + c.Assert(tx.Error, check.IsNil) key, err := db.ValidatePreAuthKey(pak.Key) c.Assert(err, check.Equals, ErrSingleUseAuthKeyHasBeenUsed) @@ -99,6 +101,7 @@ func (*Suite) TestReusableBeingUsedKey(c *check.C) { pak, err := db.CreatePreAuthKey(user.Name, true, false, nil, nil) c.Assert(err, check.IsNil) + pakId := uint(pak.ID) node := types.Node{ ID: 1, MachineKey: "foo", @@ -107,9 +110,10 @@ func (*Suite) TestReusableBeingUsedKey(c *check.C) { Hostname: "testest", UserID: user.ID, RegisterMethod: util.RegisterMethodAuthKey, - AuthKeyID: uint(pak.ID), + AuthKeyID: &pakId, } - db.db.Save(&node) + tx := db.db.Save(&node) + c.Assert(tx.Error, check.IsNil) key, err := db.ValidatePreAuthKey(pak.Key) c.Assert(err, check.IsNil) @@ -136,6 +140,7 @@ func (*Suite) TestEphemeralKey(c *check.C) { c.Assert(err, check.IsNil) now := time.Now().Add(-time.Second * 30) + pakId := uint(pak.ID) node := types.Node{ ID: 0, MachineKey: "foo", @@ -145,9 +150,10 @@ func (*Suite) TestEphemeralKey(c *check.C) { UserID: user.ID, RegisterMethod: util.RegisterMethodAuthKey, LastSeen: &now, - AuthKeyID: uint(pak.ID), + AuthKeyID: &pakId, } - db.db.Save(&node) + tx := db.db.Save(&node) + c.Assert(tx.Error, check.IsNil) _, err = db.ValidatePreAuthKey(pak.Key) // Ephemeral keys are by definition reusable diff --git a/hscontrol/db/routes_test.go b/hscontrol/db/routes_test.go index ba5882b591..b7066b2dfe 100644 --- a/hscontrol/db/routes_test.go +++ b/hscontrol/db/routes_test.go @@ -27,6 +27,7 @@ func (s *Suite) TestGetRoutes(c *check.C) { RoutableIPs: []netip.Prefix{route}, } + pakId := uint(pak.ID) node := types.Node{ ID: 0, MachineKey: "foo", @@ -35,10 +36,11 @@ func (s *Suite) TestGetRoutes(c *check.C) { Hostname: "test_get_route_node", UserID: user.ID, RegisterMethod: util.RegisterMethodAuthKey, - AuthKeyID: uint(pak.ID), + AuthKeyID: &pakId, HostInfo: types.HostInfo(hostInfo), } - db.db.Save(&node) + tx := db.db.Save(&node) + c.Assert(tx.Error, check.IsNil) err = db.SaveNodeRoutes(&node) c.Assert(err, check.IsNil) @@ -78,6 +80,7 @@ func (s *Suite) TestGetEnableRoutes(c *check.C) { RoutableIPs: []netip.Prefix{route, route2}, } + pakId := uint(pak.ID) node := types.Node{ ID: 0, MachineKey: "foo", @@ -86,10 +89,11 @@ func (s *Suite) TestGetEnableRoutes(c *check.C) { Hostname: "test_enable_route_node", UserID: user.ID, RegisterMethod: util.RegisterMethodAuthKey, - AuthKeyID: uint(pak.ID), + AuthKeyID: &pakId, HostInfo: types.HostInfo(hostInfo), } - db.db.Save(&node) + tx := db.db.Save(&node) + c.Assert(tx.Error, check.IsNil) err = db.SaveNodeRoutes(&node) c.Assert(err, check.IsNil) @@ -152,6 +156,7 @@ func (s *Suite) TestIsUniquePrefix(c *check.C) { hostInfo1 := tailcfg.Hostinfo{ RoutableIPs: []netip.Prefix{route, route2}, } + pakId := uint(pak.ID) node1 := types.Node{ ID: 1, MachineKey: "foo", @@ -160,10 +165,11 @@ func (s *Suite) TestIsUniquePrefix(c *check.C) { Hostname: "test_enable_route_node", UserID: user.ID, RegisterMethod: util.RegisterMethodAuthKey, - AuthKeyID: uint(pak.ID), + AuthKeyID: &pakId, HostInfo: types.HostInfo(hostInfo1), } - db.db.Save(&node1) + tx := db.db.Save(&node1) + c.Assert(tx.Error, check.IsNil) err = db.SaveNodeRoutes(&node1) c.Assert(err, check.IsNil) @@ -185,10 +191,11 @@ func (s *Suite) TestIsUniquePrefix(c *check.C) { Hostname: "test_enable_route_node", UserID: user.ID, RegisterMethod: util.RegisterMethodAuthKey, - AuthKeyID: uint(pak.ID), + AuthKeyID: &pakId, HostInfo: types.HostInfo(hostInfo2), } - db.db.Save(&node2) + tx = db.db.Save(&node2) + c.Assert(tx.Error, check.IsNil) err = db.SaveNodeRoutes(&node2) c.Assert(err, check.IsNil) @@ -238,6 +245,7 @@ func (s *Suite) TestSubnetFailover(c *check.C) { } now := time.Now() + pakId := uint(pak.ID) node1 := types.Node{ ID: 1, MachineKey: "foo", @@ -246,11 +254,12 @@ func (s *Suite) TestSubnetFailover(c *check.C) { Hostname: "test_enable_route_node", UserID: user.ID, RegisterMethod: util.RegisterMethodAuthKey, - AuthKeyID: uint(pak.ID), + AuthKeyID: &pakId, HostInfo: types.HostInfo(hostInfo1), LastSeen: &now, } - db.db.Save(&node1) + tx := db.db.Save(&node1) + c.Assert(tx.Error, check.IsNil) err = db.SaveNodeRoutes(&node1) c.Assert(err, check.IsNil) @@ -283,11 +292,12 @@ func (s *Suite) TestSubnetFailover(c *check.C) { Hostname: "test_enable_route_node", UserID: user.ID, RegisterMethod: util.RegisterMethodAuthKey, - AuthKeyID: uint(pak.ID), + AuthKeyID: &pakId, HostInfo: types.HostInfo(hostInfo2), LastSeen: &now, } - db.db.Save(&node2) + tx = db.db.Save(&node2) + c.Assert(tx.Error, check.IsNil) err = db.saveNodeRoutes(&node2) c.Assert(err, check.IsNil) @@ -380,6 +390,7 @@ func (s *Suite) TestDeleteRoutes(c *check.C) { } now := time.Now() + pakId := uint(pak.ID) node1 := types.Node{ ID: 1, MachineKey: "foo", @@ -388,11 +399,12 @@ func (s *Suite) TestDeleteRoutes(c *check.C) { Hostname: "test_enable_route_node", UserID: user.ID, RegisterMethod: util.RegisterMethodAuthKey, - AuthKeyID: uint(pak.ID), + AuthKeyID: &pakId, HostInfo: types.HostInfo(hostInfo1), LastSeen: &now, } - db.db.Save(&node1) + tx := db.db.Save(&node1) + c.Assert(tx.Error, check.IsNil) err = db.SaveNodeRoutes(&node1) c.Assert(err, check.IsNil) diff --git a/hscontrol/db/users_test.go b/hscontrol/db/users_test.go index 0c43b979ba..00ae4403d5 100644 --- a/hscontrol/db/users_test.go +++ b/hscontrol/db/users_test.go @@ -46,6 +46,7 @@ func (s *Suite) TestDestroyUserErrors(c *check.C) { pak, err = db.CreatePreAuthKey(user.Name, false, false, nil, nil) c.Assert(err, check.IsNil) + pakId := uint(pak.ID) node := types.Node{ ID: 0, MachineKey: "foo", @@ -54,9 +55,10 @@ func (s *Suite) TestDestroyUserErrors(c *check.C) { Hostname: "testnode", UserID: user.ID, RegisterMethod: util.RegisterMethodAuthKey, - AuthKeyID: uint(pak.ID), + AuthKeyID: &pakId, } - db.db.Save(&node) + tx := db.db.Save(&node) + c.Assert(tx.Error, check.IsNil) err = db.DestroyUser("test") c.Assert(err, check.Equals, ErrUserStillHasNodes) @@ -101,6 +103,7 @@ func (s *Suite) TestSetMachineUser(c *check.C) { pak, err := db.CreatePreAuthKey(oldUser.Name, false, false, nil, nil) c.Assert(err, check.IsNil) + pakId := uint(pak.ID) node := types.Node{ ID: 0, MachineKey: "foo", @@ -109,9 +112,10 @@ func (s *Suite) TestSetMachineUser(c *check.C) { Hostname: "testnode", UserID: oldUser.ID, RegisterMethod: util.RegisterMethodAuthKey, - AuthKeyID: uint(pak.ID), + AuthKeyID: &pakId, } - db.db.Save(&node) + tx := db.db.Save(&node) + c.Assert(tx.Error, check.IsNil) c.Assert(node.UserID, check.Equals, oldUser.ID) err = db.AssignNodeToUser(&node, newUser.Name) diff --git a/hscontrol/mapper/mapper_test.go b/hscontrol/mapper/mapper_test.go index 84cc9344f1..07e2197acc 100644 --- a/hscontrol/mapper/mapper_test.go +++ b/hscontrol/mapper/mapper_test.go @@ -176,7 +176,6 @@ func Test_fullMapResponse(t *testing.T) { UserID: 0, User: types.User{Name: "mini"}, ForcedTags: []string{}, - AuthKeyID: 0, AuthKey: &types.PreAuthKey{}, LastSeen: &lastSeen, Expiry: &expire, diff --git a/hscontrol/mapper/tail_test.go b/hscontrol/mapper/tail_test.go index 13d025c186..5a1376734e 100644 --- a/hscontrol/mapper/tail_test.go +++ b/hscontrol/mapper/tail_test.go @@ -78,7 +78,6 @@ func TestTailNode(t *testing.T) { Name: "mini", }, ForcedTags: []string{}, - AuthKeyID: 0, AuthKey: &types.PreAuthKey{}, LastSeen: &lastSeen, Expiry: &expire, diff --git a/hscontrol/types/node.go b/hscontrol/types/node.go index 11d7b2e541..4898308979 100644 --- a/hscontrol/types/node.go +++ b/hscontrol/types/node.go @@ -50,7 +50,7 @@ type Node struct { ForcedTags StringList // TODO(kradalby): This seems like irrelevant information? - AuthKeyID uint + AuthKeyID *uint `sql:"DEFAULT:NULL"` AuthKey *PreAuthKey `gorm:"constraint:OnDelete:SET NULL;"` LastSeen *time.Time