diff --git a/logger.go b/logger.go index 2cff32f..7c7ed5a 100644 --- a/logger.go +++ b/logger.go @@ -54,8 +54,8 @@ func (logger *Logger) WithFields(fields logrus.Fields) *Entry { return NewEntry(logger).WithFields(fields) } -func (logger *Logger) With(o Identifiable) *Entry { - return NewEntry(logger).With(o) +func (logger *Logger) With(os ...Identifiable) *Entry { + return NewEntry(logger).With(os...) } func (logger *Logger) WithTime(t time.Time) *Entry { @@ -119,8 +119,17 @@ func (entry *Entry) WithFields(fields logrus.Fields) *Entry { return wrapEntry(entry.Entry.WithFields(fields)) } -func (entry *Entry) With(o Identifiable) *Entry { - return wrapEntry(entry.Entry.WithFields(o.LogIdentity())) +func (entry *Entry) With(os ...Identifiable) *Entry { + merged := map[string]any{} + + for _, o := range os { + fields := o.LogIdentity() + for k, v := range fields { + merged[k] = v + } + } + + return wrapEntry(entry.Entry.WithFields(merged)) } func (entry *Entry) WithTime(t time.Time) *Entry { diff --git a/logger_test.go b/logger_test.go index f51de81..4b990e6 100644 --- a/logger_test.go +++ b/logger_test.go @@ -71,10 +71,12 @@ func Test_Logger_WithError(t *testing.T) { a.Regexp(`^time.* level\=error msg\=oops error\="found an error: an error occurred"`, b.String()) } -type identifiable struct{} +type identifiable struct { + value string +} func (i identifiable) LogIdentity() map[string]any { - return map[string]any{"field": "value"} + return map[string]any{"field": i.value} } func Test_Logger_With(t *testing.T) { @@ -84,10 +86,24 @@ func Test_Logger_With(t *testing.T) { logrusLogger.Out = &out logger := Logger{Logger: logrusLogger} - i := identifiable{} + i := identifiable{value: "__value__"} logger.With(i).Info("message") - a.Contains(out.String(), "field=value") + a.Contains(out.String(), "field=__value__") +} + +func Test_Logger_With_MergesMultipleObjects(t *testing.T) { + a := assert.New(t) + var out strings.Builder + logrusLogger := logrus.New() + logrusLogger.Out = &out + logger := Logger{Logger: logrusLogger} + + logger. + With(identifiable{value: "__overwritten__"}, identifiable{value: "__value__"}). + Info("message") + + a.Contains(out.String(), "field=__value__") } func Test_Logger_Call(t *testing.T) {