/
middleware.py
227 lines (166 loc) · 5.73 KB
/
middleware.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
import hashlib
import hmac
import os
import time
import uuid
from flask import abort, g, redirect, request, session, url_for
import rollbar
from standardweb import app, stats
from standardweb.lib import api, csrf, geoip
from standardweb.lib import helpers as h
from standardweb.lib import player as libplayer
from standardweb.models import User, ForumBan
from standardweb.tasks.access_log import log as log_task
from sqlalchemy.orm import joinedload
@app.before_request
def user_session():
    """Populate ``g.user`` for the request and keep the session persistent.

    For endpoints that are neither static nor ``face``, and only when the
    session carries a ``user_id``, the matching ``User`` is loaded with its
    ``player`` and ``posttracking`` relationships eagerly joined. In every
    other case ``g.user`` is ``None``. A random ``client_uuid`` is stored in
    the session the first time it is missing, and the session is always
    marked permanent.
    """
    endpoint = request.endpoint
    wants_user = bool(
        endpoint
        and 'static' not in endpoint
        and endpoint != 'face'
        and session.get('user_id')
    )

    if wants_user:
        query = User.query.options(joinedload(User.player))
        query = query.options(joinedload(User.posttracking))
        g.user = query.get(session['user_id'])
    else:
        g.user = None

    # Every visitor, logged in or not, gets an identifier; access_log reads it.
    if not session.get('client_uuid'):
        session['client_uuid'] = uuid.uuid4()

    session.permanent = True
def _csrf_tokens_match(expected, supplied):
    """Return True when both CSRF tokens are present and equal.

    The comparison uses ``hmac.compare_digest`` so its running time does not
    reveal how many leading characters of a guessed token were correct.
    Missing or empty tokens never match.
    """
    if not expected or not supplied:
        return False

    def to_bytes(value):
        # compare_digest requires both operands to be the same type; encoding
        # also avoids a TypeError on non-ASCII text. Assumes tokens are
        # strings (text or bytes) -- TODO confirm against the csrf module.
        return value if isinstance(value, bytes) else value.encode('utf-8')

    return hmac.compare_digest(to_bytes(expected), to_bytes(supplied))


@app.before_request
def csrf_protect():
    """Abort POST requests whose CSRF token does not match the session token.

    Views registered in ``csrf.exempt_funcs`` and debug-toolbar endpoints are
    not checked. The submitted token comes from the ``csrf_token`` form field,
    falling back to the ``X-CSRFToken`` header. On a mismatch the event is
    reported to Rollbar, ``csrf.regenerate_token()`` is called and the request
    is aborted with HTTP 403.
    """
    if request.method != "POST":
        return

    func = app.view_functions.get(request.endpoint)
    if not func or func in csrf.exempt_funcs or 'debugtoolbar' in request.endpoint:
        return

    session_token = session.get('csrf_token')
    request_token = request.form.get('csrf_token') or request.headers.get('X-CSRFToken')

    if not _csrf_tokens_match(session_token, request_token):
        rollbar.report_message('CSRF mismatch', request=request, extra_data={
            'session_token': session_token
        })
        csrf.regenerate_token()
        abort(403)
@app.before_request
def force_auth_ssl():
    """Redirect logged-in users from plain HTTP to the HTTPS version of the URL.

    Only active when the ``SSL_REDIRECTION`` config flag is set. Moving an
    authenticated session onto HTTPS after its first non-secure request
    narrows the window for a man-in-the-middle to capture it.
    """
    if g.user and app.config.get('SSL_REDIRECTION') and not request.is_secure:
        # Replace only the leading scheme. An unbounded replace() would also
        # rewrite any 'http://' embedded in the path or query string
        # (e.g. ?next=http://example.com), corrupting the redirect target.
        return redirect(request.url.replace('http://', 'https://', 1))
@app.before_request
def first_login():
    """Set ``g.first_login`` from a one-shot ``first_login`` session entry.

    For non-static, non-``face`` endpoints with a logged-in session, any
    ``first_login`` value is popped from the session (so it is seen only
    once) and exposed on ``g``. Otherwise ``g.first_login`` is ``False``.
    """
    flag = False
    endpoint = request.endpoint
    is_page = endpoint and 'static' not in endpoint and endpoint != 'face'

    if is_page and session.get('user_id') and 'first_login' in session:
        flag = session.pop('first_login')

    g.first_login = flag
@app.before_request
def track_request_time():
    """Record the wall-clock start of the request on ``g`` for access_log."""
    started_at = time.time()
    g._start_time = started_at
@app.before_request
def ensure_valid_user():
    """Ban a logged-in user who POSTs from an address flagged by ``geoip.is_nok``.

    A ``ForumBan`` is saved for the user unless one already exists, and the
    linked player (if any, and not already banned) is banned through
    ``libplayer.ban_player`` with source ``'invalid_user'``. The request
    itself is not aborted here.
    """
    if request.method != "POST" or not g.user:
        return
    if not geoip.is_nok(request.remote_addr):
        return

    user = g.user
    player = user.player

    if not user.forum_ban:
        ForumBan(user_id=user.id).save(commit=True)

    if player and not player.banned:
        libplayer.ban_player(player, source='invalid_user', commit=True)
@app.after_request
def access_log(response):
    """Report endpoint timing and enqueue an access-log record for the request.

    Nothing is recorded when ``g._start_time`` was never set, or for static
    and ``face`` endpoints. For all other requests the response time (ms) is
    sent to ``stats.timing``. Routes starting with ``/api`` stop there; the
    rest are handed to the ``log_task`` Celery task with the client UUID,
    user id, method, route rule, path (+ query), referrer, status code,
    response time, user agent and remote address.

    :param response: the outgoing response.
    :returns: ``response``, unchanged.
    """
    if not hasattr(g, '_start_time'):
        return response

    endpoint = request.endpoint
    route = request.url_rule.rule if request.url_rule else None

    if endpoint and ('static' in endpoint or endpoint == 'face'):
        return response

    response_time = int(1000 * (time.time() - g._start_time))
    stats.timing('endpoints.%s.%s' % (endpoint, request.method), response_time)

    if route and route.startswith('/api'):
        return response

    # full_path is path + '?' + query, so it ends in a lone '?' when there is
    # no query string. Use the bare path in that case instead of rstrip('?'),
    # which would also eat '?' characters ending a non-empty query string.
    path = request.full_path if request.query_string else request.path

    client_uuid = str(session.get('client_uuid'))
    user_id = g.user.id if g.user else None

    log_task.apply_async((
        client_uuid,
        user_id,
        request.method,
        route,
        path,
        request.referrer,
        response.status_code,
        response_time,
        request.headers.get('User-Agent'),
        request.remote_addr
    ))

    return response
@app.context_processor
def inject_user():
    """Expose the logged-in user (or ``None``) to templates as ``current_user``."""
    # g.user may be unset when a template renders without user_session having
    # completed (e.g. an error page after user_session itself raised). A bare
    # attribute access would raise and mask the original error.
    return dict(current_user=getattr(g, 'user', None))
@app.context_processor
def inject_h():
    """Make the ``standardweb.lib.helpers`` module available to templates as ``h``."""
    return {'h': h}
@app.context_processor
def inject_debug():
    """Expose the application's ``DEBUG`` config value to templates as ``is_debug``."""
    debug_enabled = app.config['DEBUG']
    return {'is_debug': debug_enabled}
@app.context_processor
def inject_cdn_domain():
    """Expose ``cdn_domain`` to templates.

    With ``DEBUG`` on this is an empty string. Otherwise it is a
    protocol-relative ``//<CDN_DOMAIN>`` prefix.
    """
    if app.config['DEBUG']:
        return {'cdn_domain': ''}
    return {'cdn_domain': '//%s' % app.config['CDN_DOMAIN']}
@app.context_processor
def inject_new_messages():
    """Expose the current user's unread message count as ``new_messages``.

    The count is 0 for anonymous requests, and also when ``g.user`` was never
    set (e.g. a template rendered after ``user_session`` failed). In that case
    a bare ``g.user`` access would raise AttributeError.
    """
    user = getattr(g, 'user', None)
    new_messages = user.get_unread_message_count() if user else 0
    return dict(new_messages=new_messages)
@app.context_processor
def inject_new_notifications():
    """Expose the current user's unread notification count as ``new_notifications``.

    The count is 0 for anonymous requests, and also when ``g.user`` was never
    set (e.g. a template rendered after ``user_session`` failed). In that case
    a bare ``g.user`` access would raise AttributeError.
    """
    user = getattr(g, 'user', None)
    new_notifications = user.get_unread_notification_count() if user else 0
    return dict(new_notifications=new_notifications)
def _dated_url_for(endpoint, **values):
    """Build a URL like ``url_for``, adding an mtime cache-buster for static files.

    For the ``static`` endpoint with a non-empty ``filename``, the file under
    ``<app.root_path>/static/`` is stat'ed. Its modification time (whole
    seconds) is added as the ``t`` query parameter, so the URL changes
    whenever the file does. If the file cannot be stat'ed, the URL is built
    without ``t``.

    :param endpoint: Flask endpoint name.
    :param values: keyword arguments forwarded to ``url_for``.
    :returns: the URL string produced by ``url_for``.
    """
    if endpoint == 'static':
        filename = values.get('filename')
        if filename:
            file_path = os.path.join(app.root_path, endpoint, filename)
            try:
                values['t'] = int(os.stat(file_path).st_mtime)
            except (OSError, ValueError, TypeError):
                # Best effort only: a missing/unreadable file (OSError) or an
                # unusable path such as one with a NUL byte (ValueError on
                # Python 3, TypeError on Python 2) just skips the cache-buster.
                # Unlike the previous bare except, KeyboardInterrupt and
                # SystemExit are no longer swallowed.
                pass
    return url_for(endpoint, **values)
@app.context_processor
def rts_auth_data():
    """Expose real-time-server (RTS) settings and signed auth data to templates.

    ``rts_base_url`` and ``rts_prefix`` come straight from config. For a
    logged-in user, ``rts_auth_data`` contains:
    - the user id;
    - a display name (the linked player's username, else the account username);
    - the player UUID (empty string without a player);
    - admin/moderator flags as 0/1 ints;
    - ``token``, an HMAC-SHA256 hex digest keyed with ``RTS_SECRET``.
    The token is computed over those five fields joined with '-'. For
    anonymous requests, or when ``g.user`` is unset, ``rts_auth_data`` is an
    empty dict.
    """
    def to_bytes(value):
        # hmac operates on bytes: text is required to be encoded on Python 3,
        # and on Python 2 this avoids implicit ASCII encoding of unicode
        # values. ASCII input yields the same digest as before.
        return value if isinstance(value, bytes) else value.encode('utf-8')

    data = {}
    user = getattr(g, 'user', None)

    if user:
        player = user.player
        user_id = user.id
        username = player.username if player else user.username
        player_uuid = player.uuid if player else ''
        admin = user.admin
        moderator = user.moderator

        # Field order and '-' separator define the signed payload; whatever
        # verifies the token presumably rebuilds it the same way -- keep in sync.
        content = '-'.join([
            str(user_id), username, player_uuid, str(int(admin)), str(int(moderator))
        ])
        token = hmac.new(
            to_bytes(app.config['RTS_SECRET']),
            msg=to_bytes(content),
            digestmod=hashlib.sha256
        ).hexdigest()

        data = {
            'user_id': user_id,
            'username': username,
            'uuid': player_uuid,
            'is_superuser': int(admin),
            'is_moderator': int(moderator),
            'token': token
        }

    return {
        'rts_base_url': app.config['RTS_BASE_URL'],
        'rts_prefix': app.config['RTS_PREFIX'],
        'rts_auth_data': data
    }