diff --git a/app/demo/dispatch/dispatch.go b/app/demo/dispatch/dispatch.go index bcdb8616..4f848e32 100644 --- a/app/demo/dispatch/dispatch.go +++ b/app/demo/dispatch/dispatch.go @@ -12,18 +12,16 @@ import ( "encoding/json" "flag" "fmt" - "io" - "net" - "net/http" - "os" - "strings" - "time" - "github.com/q191201771/lal/app/demo/dispatch/datamanager" "github.com/q191201771/lal/pkg/base" "github.com/q191201771/naza/pkg/nazahttp" "github.com/q191201771/naza/pkg/nazalog" "github.com/q191201771/naza/pkg/unique" + "io" + "net" + "net/http" + "os" + "strings" ) // @@ -122,21 +120,7 @@ func OnSubStartHandler(w http.ResponseWriter, r *http.Request) { nazalog.Assert(true, exist) // 向汇报节点,发送pull级联拉流的命令,其中包含pub所在节点信息 - // TODO(chef): start_relay_pull封装成函数,所有的http请求都应该封装成函数 202405 - // TODO(chef): 还没有测试新的接口start_relay_pull,只是保证可以编译通过 - url := fmt.Sprintf("http://%s/api/ctrl/start_relay_pull", reqServer.ApiAddr) - var b base.ApiCtrlStartRelayPullReq - b.Url = fmt.Sprintf("%s://%s/%s/%s?%s", "rtmp", pubServer.RtmpAddr, info.AppName, info.StreamName, config.PullSecretParam) - //b.Protocol = base.ProtocolRtmp - //b.Addr = pubServer.RtmpAddr - //b.AppName = info.AppName - //b.StreamName = info.StreamName - //b.UrlParam = config.PullSecretParam - - nazalog.Infof("[%s] ctrl pull. send to %s with %+v", id, reqServer.ApiAddr, b) - if _, err := nazahttp.PostJson(url, b, nil); err != nil { - nazalog.Errorf("[%s] post json error. err=%+v", id, err) - } + startRelayPull(id, reqServer.ApiAddr, pubServer.RtmpAddr, info.AppName, info.StreamName) } func OnSubStopHandler(w http.ResponseWriter, r *http.Request) { @@ -169,65 +153,9 @@ func OnUpdateHandler(w http.ResponseWriter, r *http.Request) { } dataManager.UpdatePub(info.ServerId, streamNameList) - if config.MaxSubSessionPerIp > 0 { - ip2SubSessions := make(map[string][]base.StatSub) - sessionId2StreamName := make(map[string]string) - for _, g := range info.Groups { - for _, sub := range g.StatSubs { - host, _, err := net.SplitHostPort(sub.RemoteAddr) - if err != nil { - nazalog.Warnf("split host port failed. remote addr=%s", sub.RemoteAddr) - continue - } - ip2SubSessions[host] = append(ip2SubSessions[host], sub) - sessionId2StreamName[sub.SessionId] = g.StreamName - } - } - for ip, subs := range ip2SubSessions { - if len(subs) > config.MaxSubSessionPerIp { - nazalog.Debugf("close session. ip=%s, session count=%d", ip, len(subs)) - for _, sub := range subs { - if sub.Protocol == base.SessionProtocolHlsStr { - host, _, err := net.SplitHostPort(sub.RemoteAddr) - if err != nil { - nazalog.Warnf("split host port failed. remote addr=%s", sub.RemoteAddr) - continue - } - addIpBlacklist(info.ServerId, host, 60) - } else { - kickSession(info.ServerId, sessionId2StreamName[sub.SessionId], sub.SessionId) - } - } - } - } - } + securityMaxSubSessionPerIp(info) - if config.MaxSubDurationSec > 0 { - now := time.Now() - for _, g := range info.Groups { - for _, sub := range g.StatSubs { - st, err := base.ParseReadableTime(sub.StartTime) - if err != nil { - nazalog.Warnf("parse readable time failed. start time=%s, err=%+v", sub.StartTime, err) - continue - } - diff := int(now.Sub(st).Seconds()) - if diff > config.MaxSubDurationSec { - nazalog.Infof("close session. sub session start time=%s, diff=%d", sub.StartTime, diff) - if sub.Protocol == base.SessionProtocolHlsStr { - host, _, err := net.SplitHostPort(sub.RemoteAddr) - if err != nil { - nazalog.Warnf("split host port failed. remote addr=%s", sub.RemoteAddr) - continue - } - addIpBlacklist(info.ServerId, host, 60) - } else { - kickSession(info.ServerId, g.StreamName, sub.SessionId) - } - } - } - } - } + securityMaxSubDurationSec(info) } func logHandler(w http.ResponseWriter, r *http.Request) { diff --git a/app/demo/dispatch/dispatch__security.go b/app/demo/dispatch/dispatch__security.go new file mode 100644 index 00000000..31fb7ca5 --- /dev/null +++ b/app/demo/dispatch/dispatch__security.go @@ -0,0 +1,80 @@ +// Copyright 2024, Chef. All rights reserved. +// https://github.com/q191201771/lal +// +// Use of this source code is governed by a MIT-style license +// that can be found in the License file. +// +// Author: Chef (191201771@qq.com) + +package main + +import ( + "github.com/q191201771/lal/pkg/base" + "github.com/q191201771/naza/pkg/nazalog" + "net" + "time" +) + +func securityMaxSubSessionPerIp(info base.UpdateInfo) { + if config.MaxSubSessionPerIp > 0 { + ip2SubSessions := make(map[string][]base.StatSub) + sessionId2StreamName := make(map[string]string) + for _, g := range info.Groups { + for _, sub := range g.StatSubs { + host, _, err := net.SplitHostPort(sub.RemoteAddr) + if err != nil { + nazalog.Warnf("split host port failed. remote addr=%s", sub.RemoteAddr) + continue + } + ip2SubSessions[host] = append(ip2SubSessions[host], sub) + sessionId2StreamName[sub.SessionId] = g.StreamName + } + } + for ip, subs := range ip2SubSessions { + if len(subs) > config.MaxSubSessionPerIp { + nazalog.Debugf("close session. ip=%s, session count=%d", ip, len(subs)) + for _, sub := range subs { + if sub.Protocol == base.SessionProtocolHlsStr { + host, _, err := net.SplitHostPort(sub.RemoteAddr) + if err != nil { + nazalog.Warnf("split host port failed. remote addr=%s", sub.RemoteAddr) + continue + } + addIpBlacklist(info.ServerId, host, 60) + } else { + kickSession(info.ServerId, sessionId2StreamName[sub.SessionId], sub.SessionId) + } + } + } + } + } +} + +func securityMaxSubDurationSec(info base.UpdateInfo) { + if config.MaxSubDurationSec > 0 { + now := time.Now() + for _, g := range info.Groups { + for _, sub := range g.StatSubs { + st, err := base.ParseReadableTime(sub.StartTime) + if err != nil { + nazalog.Warnf("parse readable time failed. start time=%s, err=%+v", sub.StartTime, err) + continue + } + diff := int(now.Sub(st).Seconds()) + if diff > config.MaxSubDurationSec { + nazalog.Infof("close session. sub session start time=%s, diff=%d", sub.StartTime, diff) + if sub.Protocol == base.SessionProtocolHlsStr { + host, _, err := net.SplitHostPort(sub.RemoteAddr) + if err != nil { + nazalog.Warnf("split host port failed. remote addr=%s", sub.RemoteAddr) + continue + } + addIpBlacklist(info.ServerId, host, 60) + } else { + kickSession(info.ServerId, g.StreamName, sub.SessionId) + } + } + } + } + } +} diff --git a/app/demo/dispatch/http_api_client.go b/app/demo/dispatch/http_api_client.go index 743a72c6..c921e90c 100644 --- a/app/demo/dispatch/http_api_client.go +++ b/app/demo/dispatch/http_api_client.go @@ -31,7 +31,6 @@ func kickSession(serverId, streamName, sessionId string) { if _, err := nazahttp.PostJson(url, b, nil); err != nil { nazalog.Errorf("[%s] post json error. err=%+v", serverId, err) } - return } func addIpBlacklist(serverId, ip string, durationSec int) { @@ -50,5 +49,21 @@ func addIpBlacklist(serverId, ip string, durationSec int) { if _, err := nazahttp.PostJson(url, b, nil); err != nil { nazalog.Errorf("[%s] post json error. err=%+v", serverId, err) } - return +} + +func startRelayPull(reqId, reqApiAddr, pubRtmpAddr, appName, streamName string) { + // TODO(chef): 还没有测试新的接口start_relay_pull,只是保证可以编译通过 + url := fmt.Sprintf("http://%s/api/ctrl/start_relay_pull", reqApiAddr) + var b base.ApiCtrlStartRelayPullReq + b.Url = fmt.Sprintf("%s://%s/%s/%s?%s", "rtmp", pubRtmpAddr, appName, streamName, config.PullSecretParam) + //b.Protocol = base.ProtocolRtmp + //b.Addr = pubServer.RtmpAddr + //b.AppName = info.AppName + //b.StreamName = info.StreamName + //b.UrlParam = config.PullSecretParam + + nazalog.Infof("[%s] startRelayPull. send to %s with %+v", reqId, reqApiAddr, b) + if _, err := nazahttp.PostJson(url, b, nil); err != nil { + nazalog.Errorf("[%s] post json error. err=%+v", reqId, err) + } }