Skip to content

Commit

Permalink
create session content in the context if do not exist yet
Browse files Browse the repository at this point in the history
  • Loading branch information
xiaokangwang committed Mar 6, 2021
1 parent b585f22 commit 867bbb4
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 5 deletions.
2 changes: 1 addition & 1 deletion app/dispatcher/default.go
Original file line number Diff line number Diff line change
Expand Up @@ -295,7 +295,7 @@ func (d *DefaultDispatcher) routedDispatch(ctx context.Context, link *transport.
var handler outbound.Handler

if forcedOutboundTag := session.GetForcedOutboundTagFromContext(ctx); forcedOutboundTag != "" {
session.SetForcedOutboundTagToContext(ctx, "")
ctx = session.SetForcedOutboundTagToContext(ctx, "")
if h := d.ohm.GetHandler(forcedOutboundTag); h != nil {
newError("taking platform initialized detour [", forcedOutboundTag, "] for [", destination, "]").WriteToLog(session.ExportIDToError(ctx))
handler = h
Expand Down
16 changes: 13 additions & 3 deletions common/session/context.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package session

import "context"
import (
"context"
)

type sessionKey int

Expand Down Expand Up @@ -92,8 +94,12 @@ func GetTransportLayerProxyTagFromContext(ctx context.Context) string {
return ContentFromContext(ctx).Attribute("transportLayerOutgoingTag")
}

func SetTransportLayerProxyTagToContext(ctx context.Context, tag string) {
func SetTransportLayerProxyTagToContext(ctx context.Context, tag string) context.Context {
if contentFromContext := ContentFromContext(ctx); contentFromContext == nil {
ctx = ContextWithContent(ctx, &Content{})
}
ContentFromContext(ctx).SetAttribute("transportLayerOutgoingTag", tag)
return ctx
}

func GetForcedOutboundTagFromContext(ctx context.Context) string {
Expand All @@ -103,6 +109,10 @@ func GetForcedOutboundTagFromContext(ctx context.Context) string {
return ContentFromContext(ctx).Attribute("forcedOutboundTag")
}

func SetForcedOutboundTagToContext(ctx context.Context, tag string) {
func SetForcedOutboundTagToContext(ctx context.Context, tag string) context.Context {
if contentFromContext := ContentFromContext(ctx); contentFromContext == nil {
ctx = ContextWithContent(ctx, &Content{})
}
ContentFromContext(ctx).SetAttribute("forcedOutboundTag", tag)
return ctx
}
2 changes: 1 addition & 1 deletion transport/internet/tagged/taggedimpl/impl.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ func DialTaggedOutbound(ctx context.Context, dest net.Destination, tag string) (
content.SkipDNSResolve = true

ctx = session.ContextWithContent(ctx, content)
session.SetForcedOutboundTagToContext(ctx, tag)
ctx = session.SetForcedOutboundTagToContext(ctx, tag)

r, err := dispatcher.Dispatch(ctx, dest)
if err != nil {
Expand Down

0 comments on commit 867bbb4

Please sign in to comment.