diff --git a/factory.go b/factory.go index 13d9620..6ecac1d 100644 --- a/factory.go +++ b/factory.go @@ -43,7 +43,7 @@ func createMetricsReceiver( ) (component.MetricsReceiver, error) { wireguardConfig := config.(*Config) - dsr, err := newReceiver(wireguardConfig, params, consumer) + dsr, err := newReceiver(wireguardConfig, params, consumer, nil) if err != nil { return nil, err } diff --git a/receiver.go b/receiver.go index e86c9ac..b8021aa 100644 --- a/receiver.go +++ b/receiver.go @@ -2,45 +2,53 @@ package wireguardreceiver import ( "context" - "sync" "time" "go.opentelemetry.io/collector/component" "go.opentelemetry.io/collector/consumer" "go.opentelemetry.io/collector/pdata/pmetric" - "golang.zx2c4.com/wireguard/wgctrl/wgtypes" "go.opentelemetry.io/collector/receiver/scraperhelper" ) type receiver struct { - config *Config - wgClient wireguardClient + config *Config + wgClient wireguardClient + clientFactory clientFactory } -func newReceiver(config *Config, set component.ReceiverCreateSettings, nextConsumer consumer.Metrics) (component.MetricsReceiver, error) { +func newReceiver(config *Config, set component.ReceiverCreateSettings, nextConsumer consumer.Metrics, clientFactory clientFactory) (component.MetricsReceiver, error) { err := config.Validate() if err != nil { return nil, err } - client, err := newWireguardClient() - if err != nil { - return nil, err + if clientFactory == nil { + clientFactory = newWireguardClient } recv := &receiver{ - config: config, - wgClient: client, + config: config, + clientFactory: clientFactory, } - scrp, err := scraperhelper.NewScraper("wireguardReceiver", recv.scrape) + scrp, err := scraperhelper.NewScraper(typeStr, recv.scrape, scraperhelper.WithStart(recv.start)) if err != nil { return nil, err } return scraperhelper.NewScraperControllerReceiver(&recv.config.ScraperControllerSettings, set, nextConsumer, scraperhelper.AddScraper(scrp)) } +func (r *receiver) start(_ context.Context, _ component.Host) error { + var err error + r.wgClient, err = r.clientFactory() + if err != nil { + return err + } + + return nil +} + func (r *receiver) scrape(ctx context.Context) (pmetric.Metrics, error) { md := pmetric.NewMetrics() @@ -49,24 +57,10 @@ func (r *receiver) scrape(ctx context.Context) (pmetric.Metrics, error) { return md, err } - results := make(chan pmetric.Metrics) - - wg := &sync.WaitGroup{} - wg.Add(len(devices)) for _, d := range devices { - go func(d *wgtypes.Device) { - defer wg.Done() - for _, peer := range d.Peers { - results <- peerToMetrics(time.Now(), d.Name, &peer) - } - }(d) - } - - wg.Wait() - close(results) - - for res := range results { - res.ResourceMetrics().CopyTo(md.ResourceMetrics()) + for _, peer := range d.Peers { + peerToMetrics(time.Now(), d.Name, &peer).ResourceMetrics().CopyTo(md.ResourceMetrics()) + } } return md, nil diff --git a/receiver_test.go b/receiver_test.go new file mode 100644 index 0000000..cef6490 --- /dev/null +++ b/receiver_test.go @@ -0,0 +1,84 @@ +package wireguardreceiver + +import ( + "context" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.opentelemetry.io/collector/component/componenttest" + "go.opentelemetry.io/collector/consumer" + "go.opentelemetry.io/collector/pdata/pmetric" + "golang.zx2c4.com/wireguard/wgctrl/wgtypes" +) + +func genDevice() (*wgtypes.Device, error) { + peer, err := getPeer() + if err != nil { + return nil, err + } + return &wgtypes.Device{ + Name: "wg0", + Peers: []wgtypes.Peer{*peer}, + }, nil +} + +func TestScrape(t *testing.T) { + cfg := createDefaultConfig() + cfg.CollectionInterval = 100 * time.Millisecond + + client := make(mockClient) + consumer := make(mockConsumer) + + r, err := newReceiver(cfg, componenttest.NewNopReceiverCreateSettings(), consumer, client.factory) + require.NoError(t, err) + assert.NotNil(t, r) + + device, err := genDevice() + require.NoError(t, err) + + go func() { + client <- deviceResult{ + devices: []*wgtypes.Device{device}, + err: nil, + } + }() + + assert.NoError(t, r.Start(context.Background(), componenttest.NewNopHost())) + + md := <-consumer + assert.Equal(t, md.ResourceMetrics().Len(), 1) + + assert.NoError(t, r.Shutdown(context.Background())) +} + +type deviceResult struct { + err error + devices []*wgtypes.Device +} + +type mockClient chan deviceResult + +func (c mockClient) factory() (wireguardClient, error) { + return c, nil +} + +func (c mockClient) Devices() ([]*wgtypes.Device, error) { + report := <-c + if report.err != nil { + return nil, report.err + } + return report.devices, nil +} + +type mockConsumer chan pmetric.Metrics + +func (m mockConsumer) Capabilities() consumer.Capabilities { + return consumer.Capabilities{} +} + +func (m mockConsumer) ConsumeMetrics(ctx context.Context, md pmetric.Metrics) error { + m <- md + return nil +} diff --git a/wireguard.go b/wireguard.go index 4538633..3eff0a7 100644 --- a/wireguard.go +++ b/wireguard.go @@ -15,6 +15,8 @@ var ( remoteEndpointAttributes = 9 ) +type clientFactory func() (wireguardClient, error) + type wireguardClient interface { Devices() ([]*wgtypes.Device, error) }