-
Notifications
You must be signed in to change notification settings - Fork 990
/
main.rs
232 lines (200 loc) · 7.08 KB
/
main.rs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
//! Example showing how to convert errors into responses.
//!
//! Run with
//!
//! ```not_rust
//! cargo run -p example-error-handling
//! ```
//!
//! For successful requests the log output will be
//!
//! ```ignore
//! DEBUG request{method=POST uri=/users matched_path="/users"}: tower_http::trace::on_request: started processing request
//! DEBUG request{method=POST uri=/users matched_path="/users"}: tower_http::trace::on_response: finished processing request latency=0 ms status=200
//! ```
//!
//! For requests that fail with an internal (time library) error the log output will be
//!
//! ```ignore
//! DEBUG request{method=POST uri=/users matched_path="/users"}: tower_http::trace::on_request: started processing request
//! ERROR request{method=POST uri=/users matched_path="/users"}: example_error_handling: error from time_library err=failed to get time
//! DEBUG request{method=POST uri=/users matched_path="/users"}: tower_http::trace::on_response: finished processing request latency=0 ms status=500
//! ```
use std::{
collections::HashMap,
sync::{
atomic::{AtomicU64, Ordering},
Arc, Mutex,
},
};
use axum::{
extract::{rejection::JsonRejection, FromRequest, MatchedPath, Request, State},
http::StatusCode,
response::{IntoResponse, Response},
routing::post,
Router,
};
use serde::{Deserialize, Serialize};
use time_library::Timestamp;
use tower_http::trace::TraceLayer;
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
#[tokio::main]
async fn main() {
tracing_subscriber::registry()
.with(
tracing_subscriber::EnvFilter::try_from_default_env()
.unwrap_or_else(|_| "example_error_handling=debug,tower_http=debug".into()),
)
.with(tracing_subscriber::fmt::layer())
.init();
let state = AppState::default();
let app = Router::new()
// A dummy route that accepts some JSON but sometimes fails
.route("/users", post(users_create))
.layer(
TraceLayer::new_for_http()
// Create our own span for the request and include the matched path. The matched
// path is useful for figuring out which handler the request was routed to.
.make_span_with(|req: &Request| {
let method = req.method();
let uri = req.uri();
// axum automatically adds this extension.
let matched_path = req
.extensions()
.get::<MatchedPath>()
.map(|matched_path| matched_path.as_str());
tracing::debug_span!("request", %method, %uri, matched_path)
})
// By default `TraceLayer` will log 5xx responses but we're doing our specific
// logging of errors so disable that
.on_failure(()),
)
.with_state(state);
let listener = tokio::net::TcpListener::bind("127.0.0.1:3000")
.await
.unwrap();
tracing::debug!("listening on {}", listener.local_addr().unwrap());
axum::serve(listener, app).await.unwrap();
}
/// State shared by every handler through the `State` extractor.
///
/// Both fields sit behind an `Arc`, so cloning the state is cheap and every clone sees the
/// same counter and the same user map.
#[derive(Clone, Default)]
struct AppState {
    /// Counter from which user IDs are allocated.
    next_id: Arc<AtomicU64>,
    /// In-memory store of created users, keyed by user ID.
    users: Arc<Mutex<HashMap<u64, User>>>,
}
/// JSON request body accepted by `users_create`.
#[derive(Deserialize)]
struct UserParams {
    /// Name to give the newly created user.
    name: String,
}
/// A created user. It is stored in `AppState::users` and serialized back to the client by
/// `users_create`.
// NOTE: field order determines the key order of the serialized JSON.
#[derive(Serialize, Clone)]
struct User {
    /// ID allocated from `AppState::next_id`.
    id: u64,
    /// Name taken from the request's `UserParams`.
    name: String,
    /// Creation time as reported by `time_library::Timestamp::now`.
    created_at: Timestamp,
}
/// Handler for `POST /users`: creates a user from the JSON body, stores it in the shared
/// state and echoes it back as JSON.
///
/// # Errors
///
/// - `AppError::JsonRejection` if the body cannot be extracted as `UserParams`. The
///   `AppJson` extractor produces this before the handler body runs.
/// - `AppError::TimeError` if `Timestamp::now` fails. In that case no ID is consumed and
///   nothing is stored.
async fn users_create(
    State(state): State<AppState>,
    // Make sure to use our own JSON extractor so we get input errors formatted in a way that
    // matches our application
    AppJson(params): AppJson<UserParams>,
) -> Result<AppJson<User>, AppError> {
    // We have implemented `From<time_library::Error> for AppError`, which lets `?` convert the
    // error automatically. Do this fallible step *before* allocating an ID so that failed
    // requests don't burn IDs and leave gaps in the sequence.
    let created_at = Timestamp::now()?;
    let id = state.next_id.fetch_add(1, Ordering::SeqCst);

    let user = User {
        id,
        name: params.name,
        created_at,
    };
    // The guard is a temporary, so the lock is released at the end of this statement and is
    // never held across an `.await`.
    state
        .users
        .lock()
        .expect("users mutex poisoned: a previous holder panicked")
        .insert(id, user.clone());

    Ok(AppJson(user))
}
/// Our own JSON extractor, created by wrapping `axum::Json`. This makes it easy to override
/// the rejection and provide our own, which formats errors to match our application.
///
/// `axum::Json` responds with plain text if the input is invalid. `via(axum::Json)` runs
/// that extractor, and `rejection(AppError)` converts its rejection into `AppError`, so the
/// client gets our JSON error format instead.
#[derive(FromRequest)]
#[from_request(via(axum::Json), rejection(AppError))]
struct AppJson<T>(T);
impl<T> IntoResponse for AppJson<T>
where
axum::Json<T>: IntoResponse,
{
fn into_response(self) -> Response {
axum::Json(self.0).into_response()
}
}
/// The kinds of errors we can hit in our application.
///
/// Handlers return `Result<_, AppError>`. Its `IntoResponse` impl decides how each variant is
/// logged and what the client gets to see.
///
/// `Debug` is derived (both payload types implement it) so that a `Result<_, AppError>` can
/// be `unwrap`ped or `{:?}`-formatted, e.g. in tests.
#[derive(Debug)]
enum AppError {
    /// The `AppJson` extractor rejected the request body (for example malformed JSON, or a
    /// body that doesn't match the expected shape).
    JsonRejection(JsonRejection),
    /// Some error from a third-party library we're using.
    TimeError(time_library::Error),
}
// Tell axum how an `AppError` becomes an HTTP response.
//
// Every error is sent back as a JSON `{ "message": ... }` body. This is also the single place
// where application errors get logged.
impl IntoResponse for AppError {
    fn into_response(self) -> Response {
        // Wire format shared by every error response.
        #[derive(Serialize)]
        struct ErrorResponse {
            message: String,
        }

        let (status, message) = match self {
            // Bad user input: reuse axum's own status code and explanation, and skip logging.
            Self::JsonRejection(rejection) => (rejection.status(), rejection.body_text()),
            Self::TimeError(err) => {
                // `TraceLayer` already wraps each request in a span holding the method, URI,
                // etc., so those details needn't be repeated here.
                tracing::error!(%err, "error from time_library");
                // Keep internal details away from the client.
                let generic = String::from("Something went wrong");
                (StatusCode::INTERNAL_SERVER_ERROR, generic)
            }
        };

        let body = AppJson(ErrorResponse { message });
        (status, body).into_response()
    }
}
impl From<JsonRejection> for AppError {
fn from(rejection: JsonRejection) -> Self {
Self::JsonRejection(rejection)
}
}
impl From<time_library::Error> for AppError {
fn from(error: time_library::Error) -> Self {
Self::TimeError(error)
}
}
// Imagine this is some third party library that we're using. It sometimes returns errors which we
// want to log.
mod time_library {
    use std::sync::atomic::{AtomicU64, Ordering};

    use serde::Serialize;

    /// A (fake) point in time.
    #[derive(Serialize, Clone)]
    pub struct Timestamp(u64);

    impl Timestamp {
        /// Returns the "current" time, which is always the constant `1337`.
        ///
        /// # Errors
        ///
        /// Returns `Error::FailedToGetTime` on the 1st, 4th, 7th, … call made anywhere in the
        /// process. A shared call counter starts at 0 and the call fails whenever the counter
        /// is a multiple of 3. This simulates a flaky dependency.
        pub fn now() -> Result<Self, Error> {
            static COUNTER: AtomicU64 = AtomicU64::new(0);

            let call = COUNTER.fetch_add(1, Ordering::SeqCst);
            if call % 3 == 0 {
                Err(Error::FailedToGetTime)
            } else {
                Ok(Self(1337))
            }
        }
    }

    /// Errors returned by this library.
    #[derive(Debug)]
    pub enum Error {
        /// The current time could not be determined.
        FailedToGetTime,
    }

    impl std::fmt::Display for Error {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            // Match on the variant so a future variant can't silently reuse this message.
            match self {
                Error::FailedToGetTime => f.write_str("failed to get time"),
            }
        }
    }

    // A proper error type: usable as `Box<dyn std::error::Error>` and as an error `source()`.
    impl std::error::Error for Error {}
}