Permalink
Browse files

Merge pull request #55 from Arxisos/master

Filter attributes
  • Loading branch information...
2 parents 94d9a48 + 94c3154 commit b629abcdabe000ae42a86d418e5f2cb876837404 @mythz mythz committed Jan 1, 2012
@@ -0,0 +1,22 @@
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Text;
+
+namespace ServiceStack.ServiceHost
+{
+ /// <summary>
+ /// This interface can be implemented by an attribute
+ /// which adds an request filter for the specific request DTO the attribute marked.
+ /// </summary>
+ public interface IHasRequestFilter
+ {
+ /// <summary>
+ /// The request filter is executed before the service.
+ /// </summary>
+ /// <param name="req">The http request wrapper</param>
+ /// <param name="res">The http response wrapper</param>
+ /// <param name="requestDto">The request DTO</param>
+ void RequestFilter(IHttpRequest req, IHttpResponse res, object requestDto);
+ }
+}
@@ -0,0 +1,22 @@
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Text;
+
+namespace ServiceStack.ServiceHost
+{
+ /// <summary>
+ /// This interface can be implemented by an attribute
+ /// which adds an response filter for the specific response DTO the attribute marked.
+ /// </summary>
+ public interface IHasResponseFilter
+ {
+ /// <summary>
+ /// The response filter is executed after the service
+ /// </summary>
+ /// <param name="req">The http request wrapper</param>
+ /// <param name="res">The http response wrapper</param>
+ /// <param name="requestDto">The response DTO</param>
+ void ResponseFilter(IHttpRequest req, IHttpResponse res, object responseDto);
+ }
+}
@@ -211,6 +211,8 @@
<Compile Include="SearchIndex\FullTextIndexAttribute.cs" />
<Compile Include="SearchIndex\FullTextIndexDocumentAttribute.cs" />
<Compile Include="SearchIndex\FullTextIndexFieldAttribute.cs" />
+ <Compile Include="ServiceHost\IHasRequestFilter.cs" />
+ <Compile Include="ServiceHost\IHasResponseFilter.cs" />
<Compile Include="ServiceHost\IRestPatchService.cs" />
<Compile Include="ServiceHost\EndpointAttributes.cs" />
<Compile Include="ServiceHost\Feature.cs" />
@@ -11,6 +11,7 @@
using ServiceStack.Text;
using ServiceStack.WebHost.Endpoints;
using ServiceStack.WebHost.Endpoints.Extensions;
+using ServiceStack.Text;
namespace ServiceStack.ServiceInterface.Auth
{
@@ -79,43 +80,6 @@ public static void Init(IAppHost appHost, Func<IAuthSession> sessionFactory, par
appHost.RegisterService<AuthService>();
SessionFeature.Init(appHost);
-
- appHost.RequestFilters.Add((req, res, dto) => {
- var requiresAuth = dto.GetType().FirstAttribute<AuthenticateAttribute>();
-
- if (requiresAuth != null)
- {
- ApplyTo httpMethod = req.HttpMethodAsApplyTo();
- if(requiresAuth.ApplyTo.Has(httpMethod))
- {
- var matchingOAuthConfigs = AuthConfigs.Where(x =>
- requiresAuth.Provider.IsNullOrEmpty()
- || x.Provider == requiresAuth.Provider).ToList();
-
- if (matchingOAuthConfigs.Count == 0)
- {
- res.WriteError(req, dto, "No OAuth Configs found matching {0} provider"
- .Fmt(requiresAuth.Provider ?? "any"));
- res.Close();
- return;
- }
-
- using (var cache = appHost.GetCacheClient())
- {
- var sessionId = req.GetPermanentSessionId();
- var session = sessionId != null ? cache.GetSession(sessionId) : null;
-
- if (session == null || !matchingOAuthConfigs.Any(x => session.IsAuthorized(x.Provider)))
- {
- res.StatusCode = (int)HttpStatusCode.Unauthorized;
- res.AddHeader(HttpHeaders.WwwAuthenticate, "OAuth realm=\"{0}\"".Fmt(matchingOAuthConfigs[0].AuthRealm));
- res.Close();
- return;
- }
- }
- }
- }
- });
}
private void AssertAuthProviders()
@@ -1,4 +1,13 @@
using System;
+using System.Linq;
+using System.Net;
+using ServiceStack.CacheAccess;
+using ServiceStack.Common;
+using ServiceStack.Common.Web;
+using ServiceStack.ServiceHost;
+using ServiceStack.ServiceInterface.Auth;
+using ServiceStack.Text;
+using ServiceStack.WebHost.Endpoints.Extensions;
namespace ServiceStack.ServiceInterface
{
@@ -7,30 +16,63 @@ namespace ServiceStack.ServiceInterface
/// requires authentication.
/// </summary>
[AttributeUsage(AttributeTargets.Class, Inherited = false, AllowMultiple = false)]
- public class AuthenticateAttribute : Attribute
+ public class AuthenticateAttribute : RequestFilterAttribute
{
public string Provider { get; set; }
public ApplyTo ApplyTo { get; set; }
public AuthenticateAttribute()
- : this(ApplyTo.All)
+ : base(ApplyTo.All)
{
}
public AuthenticateAttribute(string provider)
- : this(ApplyTo.All, provider)
+ : base(ApplyTo.All)
{
+ this.Provider = provider;
}
public AuthenticateAttribute(ApplyTo applyTo)
+ : base(applyTo)
{
- this.ApplyTo = applyTo;
}
public AuthenticateAttribute(ApplyTo applyTo, string provider)
+ : base(applyTo)
{
this.Provider = provider;
- this.ApplyTo = applyTo;
}
- }
+
+ public override void Execute(IHttpRequest req, IHttpResponse res, object requestDto)
+ {
+ if (AuthService.AuthConfigs == null) throw new InvalidOperationException("The AuthService must be initialized by calling "
+ + "AuthService.Init to use an authenticate attribute");
+
+ var matchingOAuthConfigs = AuthService.AuthConfigs.Where(x =>
+ this.Provider.IsNullOrEmpty()
+ || x.Provider == this.Provider).ToList();
+
+ if (matchingOAuthConfigs.Count == 0)
+ {
+ res.WriteError(req, requestDto, "No OAuth Configs found matching {0} provider"
+ .Fmt(this.Provider ?? "any"));
+ res.Close();
+ return;
+ }
+
+ using (var cache = req.GetCacheClient())
+ {
+ var sessionId = req.GetPermanentSessionId();
+ var session = sessionId != null ? cache.GetSession(sessionId) : null;
+
+ if (session == null || !matchingOAuthConfigs.Any(x => session.IsAuthorized(x.Provider)))
+ {
+ res.StatusCode = (int)HttpStatusCode.Unauthorized;
+ res.AddHeader(HttpHeaders.WwwAuthenticate, "OAuth realm=\"{0}\"".Fmt(matchingOAuthConfigs[0].AuthRealm));
+ res.Close();
+ return;
+ }
+ }
+ }
+ }
}
@@ -0,0 +1,46 @@
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Text;
+using ServiceStack.ServiceHost;
+using ServiceStack.Common;
+
+namespace ServiceStack.ServiceInterface
+{
+ /// <summary>
+ /// Base class to create request filter attributes only for specific HTTP methods (GET, POST...)
+ /// </summary>
+ public abstract class RequestFilterAttribute : Attribute, IHasRequestFilter
+ {
+ public ApplyTo ApplyTo { get; set; }
+
+ public RequestFilterAttribute()
+ {
+ ApplyTo = ApplyTo.All;
+ }
+
+ /// <summary>
+ /// Creates a new <see cref="RequestFilterAttribute"/>
+ /// </summary>
+ /// <param name="applyTo">Defines when the filter should be executed</param>
+ public RequestFilterAttribute(ApplyTo applyTo)
+ {
+ ApplyTo = applyTo;
+ }
+
+ public void RequestFilter(IHttpRequest req, IHttpResponse res, object requestDto)
+ {
+ ApplyTo httpMethod = req.HttpMethodAsApplyTo();
+ if (ApplyTo.Has(httpMethod))
+ this.Execute(req, res, requestDto);
+ }
+
+ /// <summary>
+ /// This method is only executed if the HTTP method matches the <see cref="ApplyTo"/> property.
+ /// </summary>
+ /// <param name="req">The http request wrapper</param>
+ /// <param name="res">The http response wrapper</param>
+ /// <param name="requestDto">The request DTO</param>
+ public abstract void Execute(IHttpRequest req, IHttpResponse res, object requestDto);
+ }
+}
@@ -0,0 +1,46 @@
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Text;
+using ServiceStack.ServiceHost;
+using ServiceStack.Common;
+
+namespace ServiceStack.ServiceInterface
+{
+ /// <summary>
+ /// Base class to create response filter attributes only for specific HTTP methods (GET, POST...)
+ /// </summary>
+ public abstract class ResponseFilterAttribute : Attribute, IHasResponseFilter
+ {
+ public ApplyTo ApplyTo { get; set; }
+
+ public ResponseFilterAttribute()
+ {
+ ApplyTo = ApplyTo.All;
+ }
+
+ /// <summary>
+ /// Creates a new <see cref="ResponseFilterAttribute"/>
+ /// </summary>
+ /// <param name="applyTo">Defines when the filter should be executed</param>
+ public ResponseFilterAttribute(ApplyTo applyTo)
+ {
+ ApplyTo = applyTo;
+ }
+
+ public void ResponseFilter(IHttpRequest req, IHttpResponse res, object responseDto)
+ {
+ ApplyTo httpMethod = req.HttpMethodAsApplyTo();
+ if (ApplyTo.Has(httpMethod))
+ this.Execute(req, res, responseDto);
+ }
+
+ /// <summary>
+ /// This method is only executed if the HTTP method matches the <see cref="ApplyTo"/> property.
+ /// </summary>
+ /// <param name="req">The http request wrapper</param>
+ /// <param name="res">The http response wrapper</param>
+ /// <param name="requestDto">The response DTO</param>
+ public abstract void Execute(IHttpRequest req, IHttpResponse res, object requestDto);
+ }
+}
@@ -74,6 +74,13 @@ public static ICacheClient GetCacheClient(this IAppHost appHost)
?? DefaultCache;
}
+ public static ICacheClient GetCacheClient(this IHttpRequest httpRequest)
+ {
+ return httpRequest.TryResolve<ICacheClient>()
+ ?? (ICacheClient)httpRequest.TryResolve<IRedisClientsManager>()
+ ?? DefaultCache;
+ }
+
public static void SaveSession(this IServiceBase service, IAuthSession session)
{
using (var cache = service.GetCacheClient())
@@ -188,7 +188,9 @@
<Compile Include="FluentValidation\Validators\PropertyValidatorContext.cs" />
<Compile Include="FluentValidation\Validators\RegularExpressionValidator.cs" />
<Compile Include="IServiceBase.cs" />
+ <Compile Include="RequestFilterAttribute.cs" />
<Compile Include="RequiredPermissionAttribute.cs" />
+ <Compile Include="ResponseFilterAttribute.cs" />
<Compile Include="ServiceExtensions.cs" />
<Compile Include="RestServiceBase.cs" />
<Compile Include="ServiceBase.cs" />
@@ -282,6 +282,7 @@
<Compile Include="WebHost.EndPoints\EndpointHost.cs" />
<Compile Include="WebHost.EndPoints\EndpointHostConfig.cs" />
<Compile Include="WebHost.EndPoints\Support\SoapHandler.cs" />
+ <Compile Include="WebHost.EndPoints\Utils\FilterAttributeCache.cs" />
<Compile Include="WebHost.EndPoints\WebServerType.cs" />
<Compile Include="WebHost.EndPoints\XmlSyncReplyHandler.cs" />
<Compile Include="WebHost.EndPoints\Metadata\IndexMetadataHandler.cs" />
@@ -9,6 +9,7 @@
using ServiceStack.ServiceModel.Serialization;
using ServiceStack.WebHost.EndPoints.Formats;
using ServiceStack.WebHost.Endpoints.Formats;
+using ServiceStack.WebHost.EndPoints.Utils;
namespace ServiceStack.WebHost.Endpoints
{
@@ -156,6 +157,13 @@ public static bool ApplyRequestFilters(IHttpRequest httpReq, IHttpResponse httpR
if (httpRes.IsClosed) break;
}
+ IEnumerable<IHasRequestFilter> attributes = FilterAttributeCache.GetRequestFilterAttributes(requestDto.GetType());
+ foreach (var attribute in attributes)
+ {
+ attribute.RequestFilter(httpReq, httpRes, requestDto);
+ if (httpRes.IsClosed) break;
+ }
+
return httpRes.IsClosed;
}
}
@@ -178,6 +186,13 @@ public static bool ApplyResponseFilters(IHttpRequest httpReq, IHttpResponse http
if (httpRes.IsClosed) break;
}
+ IEnumerable<IHasResponseFilter> attributes = FilterAttributeCache.GetResponseFilterAttributes(responseDto.GetType());
+ foreach (var attribute in attributes)
+ {
+ attribute.ResponseFilter(httpReq, httpRes, responseDto);
+ if (httpRes.IsClosed) break;
+ }
+
return httpRes.IsClosed;
}
}
@@ -11,6 +11,7 @@
using ServiceStack.ServiceModel.Serialization;
using ServiceStack.Text;
using ServiceStack.WebHost.Endpoints.Extensions;
+using ServiceStack.WebHost.EndPoints.Utils;
namespace ServiceStack.WebHost.Endpoints.Support
{
@@ -65,23 +66,29 @@ protected Message ExecuteMessage(Message requestMsg, EndpointAttributes endpoint
IHttpRequest httpReq = null;
IHttpResponse httpRes = null;
- var hasRequestFilters = EndpointHost.RequestFilters.Count > 0;
- var hasResponseFilters = EndpointHost.ResponseFilters.Count > 0;
- if (hasRequestFilters || hasResponseFilters)
+ var hasRequestFilters = EndpointHost.RequestFilters.Count > 0
+ || FilterAttributeCache.GetRequestFilterAttributes(request.GetType()).Count() > 0;
+ if (hasRequestFilters)
{
httpReq = HttpContext.Current != null
? new HttpRequestWrapper(requestType.Name, HttpContext.Current.Request)
- : null;
- httpRes = HttpContext.Current != null
- ? new HttpResponseWrapper(HttpContext.Current.Response)
- : null;
+ : null;
}
if (hasRequestFilters && EndpointHost.ApplyRequestFilters(httpReq, httpRes, request))
return EmptyResponse(requestMsg, requestType);
var response = ExecuteService(request, endpointAttributes, httpReq, httpRes);
+ var hasResponseFilters = EndpointHost.ResponseFilters.Count > 0
+ || FilterAttributeCache.GetResponseFilterAttributes(response.GetType()).Count() > 0;
+ if (hasResponseFilters)
+ {
+ httpRes = HttpContext.Current != null
+ ? new HttpResponseWrapper(HttpContext.Current.Response)
+ : null;
+ }
+
if (hasResponseFilters && EndpointHost.ApplyResponseFilters(httpReq, httpRes, response))
return EmptyResponse(requestMsg, requestType);
Oops, something went wrong.

0 comments on commit b629abc

Please sign in to comment.