Skip to content

Commit

Permalink
feat(auth): add google oidc device flow
Browse files Browse the repository at this point in the history
  • Loading branch information
ymgyt committed Mar 15, 2024
1 parent 19fe8c4 commit 24ccb77
Show file tree
Hide file tree
Showing 4 changed files with 154 additions and 3 deletions.
4 changes: 1 addition & 3 deletions crates/synd_auth/src/device_flow/github.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,9 @@ use tracing::debug;

use crate::device_flow::{
DeviceAccessTokenErrorResponse, DeviceAccessTokenRequest, DeviceAccessTokenResponse,
DeviceAuthorizationRequest, DeviceAuthorizationResponse,
DeviceAuthorizationRequest, DeviceAuthorizationResponse, USER_AGENT,
};

const USER_AGENT: &str = concat!(env!("CARGO_PKG_NAME"), "/", env!("CARGO_PKG_VERSION"));

/// <https://docs.github.com/en/apps/oauth-apps/building-oauth-apps/authorizing-oauth-apps#device-flow>
#[derive(Clone)]
pub struct DeviceFlow {
Expand Down
120 changes: 120 additions & 0 deletions crates/synd_auth/src/device_flow/google.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
use std::{borrow::Cow, time::Duration};

use http::StatusCode;
use reqwest::Client;
use tracing::debug;

use crate::device_flow::{
DeviceAccessTokenErrorResponse, DeviceAccessTokenRequest, DeviceAccessTokenResponse,
DeviceAuthorizationRequest, DeviceAuthorizationResponse, USER_AGENT,
};

pub struct DeviceFlow {
client: Client,
client_id: Cow<'static, str>,
client_secret: Cow<'static, str>,
}

impl DeviceFlow {
const DEVICE_AUTHORIZATION_ENDPOINT: &'static str = "https://oauth2.googleapis.com/device/code";
const TOKEN_ENDPOINT: &'static str = "https://oauth2.googleapis.com/token";
/// <https://developers.google.com/identity/gsi/web/guides/devices#obtain_an_id_token_and_refresh_token>
const GRANT_TYPE: &'static str = "http://oauth.net/grant_type/device/1.0";

pub fn new(
client_id: impl Into<Cow<'static, str>>,
client_secret: impl Into<Cow<'static, str>>,
) -> Self {
let client = reqwest::ClientBuilder::new()
.user_agent(USER_AGENT)
.timeout(Duration::from_secs(5))
.build()
.unwrap();

Self {
client,
client_id: client_id.into(),
client_secret: client_secret.into(),
}
}

#[tracing::instrument(skip(self))]
pub async fn device_authorize_request(&self) -> anyhow::Result<DeviceAuthorizationResponse> {
// https://developers.google.com/identity/gsi/web/guides/devices#obtain_a_user_code_and_verification_url
let scope = "email";
let response = self
.client
.post(Self::DEVICE_AUTHORIZATION_ENDPOINT)
.header(http::header::ACCEPT, "application/json")
.form(&DeviceAuthorizationRequest {
client_id: self.client_id.clone(),
scope: scope.into(),
})
.send()
.await?
.error_for_status()?
.json::<DeviceAuthorizationResponse>()
.await?;

Ok(response)
}

pub async fn poll_device_access_token(
&self,
device_code: String,
interval: Option<i64>,
) -> anyhow::Result<DeviceAccessTokenResponse> {
// poll to check if user authorized the device
macro_rules! continue_or_abort {
( $response_bytes:ident ) => {{
let err_response = serde_json::from_slice::<DeviceAccessTokenErrorResponse>(&$response_bytes)?;
if err_response.error.should_continue_to_poll() {
debug!(error_code=?err_response.error,interval, "Continue to poll");

let interval = interval.unwrap_or(5);

tokio::time::sleep(Duration::from_secs(interval as u64)).await;
} else {
anyhow::bail!(
"Failed to authenticate. authorization server respond with {err_response:?}"
)
}
}};
}

let response = loop {
let response = self
.client
.post(Self::TOKEN_ENDPOINT)
.header(http::header::ACCEPT, "application/json")
.form(
&DeviceAccessTokenRequest::new(&device_code, self.client_id.as_ref())
.with_grant_type(Self::GRANT_TYPE)
.with_client_secret(self.client_secret.clone()),
)
.send()
.await?;

match response.status() {
StatusCode::OK => {
let full = response.bytes().await?;
if let Ok(response) = serde_json::from_slice::<DeviceAccessTokenResponse>(&full)
{
break response;
}
continue_or_abort!(full);
}
StatusCode::BAD_REQUEST => {
let full = response.bytes().await?;
continue_or_abort!(full);
}
other => {
let error_msg = response.text().await.unwrap_or_default();
anyhow::bail!("Failed to authenticate. authorization server respond with {other} {error_msg}")
}
}
};

Ok(response)
}
}
31 changes: 31 additions & 0 deletions crates/synd_auth/src/device_flow/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@ use http::Uri;
use serde::{Deserialize, Serialize};

pub mod github;
pub mod google;

const USER_AGENT: &str = concat!(env!("CARGO_PKG_NAME"), "/", env!("CARGO_PKG_VERSION"));

/// <https://datatracker.ietf.org/doc/html/rfc8628#section-3.1>
#[derive(Serialize, Deserialize, Debug)]
Expand Down Expand Up @@ -42,6 +45,11 @@ pub struct DeviceAccessTokenRequest<'s> {
/// The device verification code, "device_code" from the device authorization response
pub device_code: Cow<'s, str>,
pub client_id: Cow<'s, str>,

// vendor extensions
/// Google require client secret
/// <https://developers.google.com/identity/gsi/web/guides/devices#obtain_an_id_token_and_refresh_token>
pub client_secret: Option<Cow<'s, str>>,
}

impl<'s> DeviceAccessTokenRequest<'s> {
Expand All @@ -53,6 +61,25 @@ impl<'s> DeviceAccessTokenRequest<'s> {
grant_type: Self::GRANT_TYPE.into(),
device_code: device_code.into(),
client_id: client_id.into(),
client_secret: None,
}
}

/// Configure `grant_type`
#[must_use]
pub fn with_grant_type(self, grant_type: impl Into<Cow<'static, str>>) -> Self {
Self {
grant_type: grant_type.into(),
..self
}
}

/// Configure `client_secret`
#[must_use]
pub fn with_client_secret(self, client_secret: impl Into<Cow<'s, str>>) -> Self {
Self {
client_secret: Some(client_secret.into()),
..self
}
}
}
Expand All @@ -66,6 +93,10 @@ pub struct DeviceAccessTokenResponse {
pub token_type: String,
/// the lifetime in seconds of the access token
pub expires_in: Option<i64>,

// OIDC usecase
pub refresh_token: Option<String>,
pub id_token: Option<String>,
}

/// <https://datatracker.ietf.org/doc/html/rfc6749#section-5.2>
Expand Down
2 changes: 2 additions & 0 deletions crates/synd_test/src/mock.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@ async fn device_access_token(
access_token: "gh_dummy_access_token".into(),
token_type: String::new(),
expires_in: None,
refresh_token: None,
id_token: None,
};

Ok(Json(res))
Expand Down

0 comments on commit 24ccb77

Please sign in to comment.