diff --git a/src/api/mod.rs b/src/api/mod.rs index fac5d63..d9a4aa0 100644 --- a/src/api/mod.rs +++ b/src/api/mod.rs @@ -1,8 +1,8 @@ use std::collections::{BTreeSet, btree_map}; -mod resources; -mod struct_enum; -mod types; +pub(crate) mod resources; +pub(crate) mod struct_enum; +pub(crate) mod types; use aide::openapi; use serde::{Deserialize, Serialize}; diff --git a/src/api/resources.rs b/src/api/resources.rs index c0809f7..b89780b 100644 --- a/src/api/resources.rs +++ b/src/api/resources.rs @@ -105,7 +105,7 @@ impl Resource { for operation in &self.operations { for param in &operation.query_params { - if let FieldType::SchemaRef { name } = ¶m.r#type { + if let FieldType::SchemaRef { name, inner: None } = ¶m.r#type { res.insert(name); } } @@ -125,7 +125,7 @@ impl Resource { #[derive(Deserialize, Serialize)] pub(crate) struct Operation { /// The operation ID from the spec. - id: String, + pub(crate) id: String, /// The name to use for the operation in code. pub(crate) name: String, /// Description of the operation to use for documentation. @@ -151,7 +151,7 @@ pub(crate) struct Operation { query_params: Vec, /// Name of the request body type, if any. #[serde(skip_serializing_if = "Option::is_none")] - request_body_schema_name: Option, + pub(crate) request_body_schema_name: Option, /// Some request bodies are required, but all the fields are optional (i.e. the CLI can omit /// this from the argument list). /// Only useful when `request_body_schema_name` is `Some`. diff --git a/src/api/struct_enum.rs b/src/api/struct_enum.rs index 82098f4..3f36f04 100644 --- a/src/api/struct_enum.rs +++ b/src/api/struct_enum.rs @@ -48,7 +48,10 @@ impl TypeData { if variant.properties.len() == 1 { variants.push(SimpleVariant { name: discriminator, - content: EnumVariantType::Ref { schema_ref: None }, + content: EnumVariantType::Ref { + schema_ref: None, + inner: None, + }, }); } else { let (variant_content_field, content) = get_content(variant)?; @@ -91,6 +94,7 @@ fn get_content(variant: &ObjectValidation) -> anyhow::Result<(String, EnumVarian p_name.to_owned(), EnumVariantType::Ref { schema_ref: Some(get_schema_name(Some(schema_ref.as_str())).unwrap()), + inner: None, }, )); } diff --git a/src/api/types.rs b/src/api/types.rs index 420213b..cc5e0b1 100644 --- a/src/api/types.rs +++ b/src/api/types.rs @@ -79,14 +79,14 @@ pub(crate) fn from_referenced_components( types } -#[derive(Deserialize, Serialize)] +#[derive(Deserialize, Serialize, Debug, Clone)] pub(crate) struct Type { name: String, #[serde(skip_serializing_if = "Option::is_none")] description: Option, deprecated: bool, #[serde(flatten)] - data: TypeData, + pub data: TypeData, } impl Type { @@ -159,7 +159,7 @@ fn fields_referenced_schemas(fields: &[Field]) -> BTreeSet<&str> { .collect() } -#[derive(Deserialize, Serialize)] +#[derive(Deserialize, Serialize, Debug, Clone)] #[serde(tag = "kind", rename_all = "snake_case")] pub(crate) enum TypeData { Struct { @@ -270,7 +270,7 @@ impl TypeData { } } -#[derive(Deserialize, Serialize)] +#[derive(Deserialize, Serialize, Debug, Clone)] #[serde(tag = "repr", rename_all = "snake_case")] pub(crate) enum StructEnumRepr { // add more variants here to support other enum representations @@ -295,18 +295,18 @@ impl StructEnumRepr { EnumVariantType::Struct { fields } => { fields.iter().find_map(|f| f.r#type.referenced_schema()) } - EnumVariantType::Ref { schema_ref } => schema_ref.as_deref(), + EnumVariantType::Ref { schema_ref, .. } => schema_ref.as_deref(), }) .collect(), } } } -#[derive(Deserialize, Serialize, Clone)] +#[derive(Deserialize, Serialize, Debug, Clone)] pub(crate) struct Field { name: String, #[serde(serialize_with = "serialize_field_type")] - r#type: FieldType, + pub r#type: FieldType, #[serde(skip_serializing_if = "Option::is_none")] default: Option, #[serde(skip_serializing_if = "Option::is_none")] @@ -314,6 +314,8 @@ pub(crate) struct Field { required: bool, nullable: bool, deprecated: bool, + #[serde(skip_serializing_if = "Option::is_none")] + example: Option, } impl Field { @@ -322,6 +324,7 @@ impl Field { Schema::Bool(_) => bail!("unsupported bool schema"), Schema::Object(o) => o, }; + let example = obj.extensions.get("example").cloned(); let metadata = obj.metadata.clone().unwrap_or_default(); let nullable = obj @@ -329,7 +332,6 @@ impl Field { .get("nullable") .and_then(|v| v.as_bool()) .unwrap_or(false); - Ok(Self { name, r#type: FieldType::from_schema_object(obj)?, @@ -338,11 +340,12 @@ impl Field { required, nullable, deprecated: metadata.deprecated, + example, }) } } -#[derive(Deserialize, Serialize)] +#[derive(Deserialize, Serialize, Debug, Clone)] #[serde(tag = "type", rename_all = "camelCase")] pub(crate) enum EnumVariantType { Struct { @@ -351,10 +354,12 @@ pub(crate) enum EnumVariantType { Ref { #[serde(skip_serializing_if = "Option::is_none")] schema_ref: Option, + #[serde(skip_serializing_if = "Option::is_none")] + inner: Option, }, } -#[derive(Deserialize, Serialize)] +#[derive(Deserialize, Serialize, Debug, Clone)] pub(crate) struct SimpleVariant { /// Discriminator value that identifies this variant. pub name: String, @@ -396,6 +401,8 @@ pub(crate) enum FieldType { /// The name of another schema that defines this type. SchemaRef { name: String, + #[serde(skip_serializing_if = "Option::is_none")] + inner: Option, }, /// A string constant, used as an enum discriminator value. @@ -515,7 +522,7 @@ impl FieldType { bail!("unsupported multi-typed parameter: `{types:?}`") } None => match get_schema_name(obj.reference.as_deref()) { - Some(name) => Self::SchemaRef { name }, + Some(name) => Self::SchemaRef { name, inner: None }, None => bail!("unsupported type-less parameter"), }, }; @@ -545,7 +552,7 @@ impl FieldType { Self::List { inner } | Self::Set { inner } => { format!("List<{}>", inner.to_csharp_typename()).into() } - Self::SchemaRef { name } => filter_schema_ref(name, "Object"), + Self::SchemaRef { name, .. } => filter_schema_ref(name, "Object"), Self::StringConst { .. } => "string".into(), } } @@ -565,7 +572,7 @@ impl FieldType { Self::List { inner } | Self::Set { inner } => { format!("[]{}", inner.to_go_typename()).into() } - Self::SchemaRef { name } => filter_schema_ref(name, "map[string]any"), + Self::SchemaRef { name, .. } => filter_schema_ref(name, "map[string]any"), Self::StringConst { .. } => "string".into(), } } @@ -586,7 +593,7 @@ impl FieldType { Self::JsonObject => "Map".into(), Self::List { inner } => format!("List<{}>", inner.to_kotlin_typename()).into(), Self::Set { inner } => format!("Set<{}>", inner.to_kotlin_typename()).into(), - Self::SchemaRef { name } => filter_schema_ref(name, "Map"), + Self::SchemaRef { name, .. } => filter_schema_ref(name, "Map"), Self::StringConst { .. } => "String".into(), } } @@ -606,7 +613,7 @@ impl FieldType { Self::Map { value_ty } => { format!("{{ [key: string]: {} }}", value_ty.to_js_typename()).into() } - Self::SchemaRef { name } => filter_schema_ref(name, "any"), + Self::SchemaRef { name, .. } => filter_schema_ref(name, "any"), Self::StringConst { .. } => "string".into(), } } @@ -633,14 +640,14 @@ impl FieldType { value_ty.to_rust_typename(), ) .into(), - Self::SchemaRef { name } => filter_schema_ref(name, "serde_json::Value"), + Self::SchemaRef { name,.. } => filter_schema_ref(name, "serde_json::Value"), Self::StringConst { .. } => "String".into() } } pub(crate) fn referenced_schema(&self) -> Option<&str> { match self { - Self::SchemaRef { name } => { + Self::SchemaRef { name, .. } => { // TODO(10055): the `BackgroundTaskFinishedEvent2` struct has a field with type of `Data` // this corresponds to a `#[serde(untagged)]` enum `svix_server::v1::endpoints::background_tasks::Data` // we should change this server side, but for now I am changing it here @@ -659,7 +666,7 @@ impl FieldType { Self::Int16 | Self::UInt16 | Self::Int32 | Self::Int64 | Self::UInt64 => "int".into(), Self::String => "str".into(), Self::DateTime => "datetime".into(), - Self::SchemaRef { name } => filter_schema_ref(name, "t.Dict[str, t.Any]"), + Self::SchemaRef { name, .. } => filter_schema_ref(name, "t.Dict[str, t.Any]"), Self::Uri => "str".into(), Self::JsonObject => "t.Dict[str, t.Any]".into(), Self::Set { inner } | Self::List { inner } => { @@ -690,7 +697,7 @@ impl FieldType { FieldType::Map { value_ty } => { format!("Map", value_ty.to_java_typename()).into() } - FieldType::SchemaRef { name } => filter_schema_ref(name, "Object"), + FieldType::SchemaRef { name, .. } => filter_schema_ref(name, "Object"), // backwards compat FieldType::StringConst { .. } => "TypeEnum".into(), } @@ -698,7 +705,7 @@ impl FieldType { fn to_ruby_typename(&self) -> Cow<'_, str> { match self { - FieldType::SchemaRef { name } => name.clone().into(), + FieldType::SchemaRef { name, .. } => name.clone().into(), FieldType::StringConst { .. } => { unreachable!("FieldType::const should never be exposed to template code") } @@ -745,7 +752,7 @@ impl FieldType { | FieldType::List { .. } | FieldType::Set { .. } | FieldType::Map { .. } => "array".into(), - FieldType::SchemaRef { name } => name.clone().into(), + FieldType::SchemaRef { name, .. } => name.clone().into(), } } } @@ -827,6 +834,10 @@ impl minijinja::value::Object for FieldType { ensure_no_args(args, "is_string")?; Ok(matches!(**self, Self::String).into()) } + "is_uri" => { + ensure_no_args(args, "is_uri")?; + Ok(matches!(**self, Self::Uri).into()) + } "is_bool" => { ensure_no_args(args, "is_bool")?; Ok(matches!(**self, Self::Bool).into()) @@ -870,6 +881,17 @@ impl minijinja::value::Object for FieldType { }; Ok(ty.into()) } + "inner_schema_ref_ty" => { + ensure_no_args(args, "inner_schema_ref_ty")?; + let ty = match &**self { + FieldType::SchemaRef { inner, .. } => { + let i = inner.as_ref().unwrap().clone(); + Some(minijinja::Value::from_serialize(i)) + } + _ => None, + }; + Ok(ty.into()) + } // Returns the value type of a map "value_type" => { ensure_no_args(args, "value_type")?; diff --git a/src/codesamples.rs b/src/codesamples.rs new file mode 100644 index 0000000..a4e69c8 --- /dev/null +++ b/src/codesamples.rs @@ -0,0 +1,228 @@ +use std::{ + borrow::Cow, + collections::{BTreeMap, BTreeSet}, + sync::Arc, +}; + +use crate::{ + CodegenLanguage, + api::{ + Api, Resource, + types::{EnumVariantType, Field, FieldType, StructEnumRepr, Type, TypeData}, + }, + template, +}; +use aide::openapi::OpenApi; +use anyhow::Context; +use minijinja::{Value, context}; +use serde::Serialize; + +fn codesample_env( + path_param_to_example: Arc String>, +) -> Result, minijinja::Error> { + let mut env = template::populate_env(minijinja::Environment::new())?; + env.set_debug(true); + + let path_param_fn = path_param_to_example.clone(); + env.add_filter( + // given a path param name (for example `app_id`) return the same example id each time + "path_param_example", + move |path_param: Cow<'_, str>| -> Result { + Ok(path_param_fn(path_param.to_string())) + }, + ); + + let path_param_fn = path_param_to_example.clone(); + env.add_filter( + // format a path string `/api/v1/app/{{app_id}}` => `/api/v1/app/app_1srOrx2ZWZBpBUvZwXKQmoEYga2` + "populate_path_with_examples", + move |s: Cow<'_, str>, path_params: &Vec| -> Result { + let mut path_str = s.to_string(); + for field in path_params { + let field = field.as_str().expect("Expected this to be a string"); + path_str = + path_str.replace(&format!("{{{field}}}"), &path_param_fn(field.to_string())); + } + Ok(path_str) + }, + ); + Ok(env) +} + +fn recursively_resolve_type(ty_name: &str, api: &Api) -> Type { + let mut ty = api.types.get(ty_name).unwrap().clone(); + + let update_fields = |fields: &mut Vec, api: &Api| { + for f in fields.iter_mut() { + if let FieldType::SchemaRef { name, .. } = &f.r#type { + let inner_ty = api.types.get(name).unwrap().clone(); + f.r#type = FieldType::SchemaRef { + name: name.clone(), + inner: Some(inner_ty), + }; + } + } + }; + match ty.data { + TypeData::Struct { ref mut fields } => { + update_fields(fields, api); + } + TypeData::StringEnum { .. } => (), + TypeData::IntegerEnum { .. } => (), + TypeData::StructEnum { + ref mut fields, + ref mut repr, + .. + } => { + match repr { + StructEnumRepr::AdjacentlyTagged { variants, .. } => { + for v in variants.iter_mut() { + match &mut v.content { + EnumVariantType::Struct { fields } => { + update_fields(fields, api); + } + EnumVariantType::Ref { schema_ref, inner } => { + if let Some(schema_ref) = schema_ref { + let inner_ty = api.types.get(schema_ref).unwrap().clone(); + *inner = Some(inner_ty); + } + } + } + } + } + } + + update_fields(fields, api); + } + } + ty +} + +fn generate_sample( + env: &minijinja::Environment<'static>, + samples_map: &mut BTreeMap>, + api: &Api, + resource: &Resource, + resource_parents: &Vec, + templates: &CodesampleTemplates, +) { + for operation in &resource.operations { + for SampleTemplate { + source, + label, + formatting_lang, + lang_name, + } in &templates.templates + { + let req_body_ty = operation + .request_body_schema_name + .as_ref() + .map(|req_body_name| recursively_resolve_type(req_body_name, api)); + + let ctx = context! { operation, resource_parents, req_body_ty }; + + let codesample = env.render_str(source, ctx).unwrap(); + let sample = CodeSample { + lang: lang_name.to_string(), + source: codesample, + formatting_lang: *formatting_lang, + op_id: operation.id.clone(), + label: label.clone(), + }; + + let lang_vec = match samples_map.get_mut(formatting_lang) { + Some(v) => v, + None => { + samples_map.insert(*formatting_lang, vec![]); + samples_map.get_mut(formatting_lang).unwrap() + } + }; + + lang_vec.push(sample); + } + } + + for (subresource_name, subresource) in &resource.subresources { + let mut new_parents = resource_parents.clone(); + new_parents.push(subresource_name.clone()); + + generate_sample(env, samples_map, api, subresource, &new_parents, templates); + } +} + +#[derive(Debug, Serialize, Clone)] +pub struct CodeSample { + pub source: String, + pub lang: String, + pub label: String, + #[serde(skip)] + pub op_id: String, + #[serde(skip)] + pub formatting_lang: CodegenLanguage, +} + +struct SampleTemplate { + source: String, + label: String, + lang_name: String, + formatting_lang: CodegenLanguage, +} + +#[derive(Default)] +pub struct CodesampleTemplates { + templates: Vec, +} + +impl CodesampleTemplates { + pub fn add_template>( + &mut self, + label: S, + lang_name: S, + formatting_lang: CodegenLanguage, + source: S, + ) { + self.templates.push(SampleTemplate { + formatting_lang, + lang_name: lang_name.as_ref().to_string(), + label: label.as_ref().to_string(), + source: source.as_ref().to_string(), + }); + } +} + +pub async fn generate_codesamples( + openapi_spec: &str, + templates: CodesampleTemplates, + excluded_operation_ids: BTreeSet, + path_param_example: fn(String) -> String, +) -> anyhow::Result>> { + let openapi_spec: OpenApi = + serde_json::from_str(openapi_spec).context("failed to parse OpenAPI spec")?; + + let api_ir = crate::api::Api::new( + openapi_spec + .paths + .expect("found no endpoints in input spec"), + &mut openapi_spec.components.unwrap_or_default(), + &[], + crate::IncludeMode::OnlyPublic, + &excluded_operation_ids, + &BTreeSet::new(), + )?; + + let mut samples_map = BTreeMap::new(); + + let env = codesample_env(Arc::new(path_param_example))?; + + for (resource_name, resource) in &api_ir.resources { + generate_sample( + &env, + &mut samples_map, + &api_ir, + resource, + &vec![resource_name.clone()], + &templates, + ); + } + Ok(samples_map) +} diff --git a/src/generator.rs b/src/generator.rs index 3327579..843e959 100644 --- a/src/generator.rs +++ b/src/generator.rs @@ -53,7 +53,7 @@ pub(crate) fn generate( let tpl_source = fs::read_to_string(tpl_path)?; - let mut minijinja_env = template::env( + let mut minijinja_env = template::env_with_dir( Utf8Path::new(tpl_path) .parent() .with_context(|| format!("invalid template path `{tpl_path}`"))?, diff --git a/src/lib.rs b/src/lib.rs new file mode 100644 index 0000000..c4e6c87 --- /dev/null +++ b/src/lib.rs @@ -0,0 +1,229 @@ +use std::{ + collections::BTreeSet, + io, + path::{Path, PathBuf}, +}; + +use aide::openapi::OpenApi; +use anyhow::{Context as _, bail}; +use camino::Utf8PathBuf; +use clap::{Parser, Subcommand}; +use fs_err::{self as fs}; +use schemars::schema::Schema; +use tempfile::TempDir; + +mod api; +mod codesamples; +mod generator; +mod postprocessing; +mod template; + +use self::{api::Api, generator::generate}; + +pub use crate::{ + codesamples::{CodeSample, CodesampleTemplates, generate_codesamples}, + postprocessing::CodegenLanguage, +}; + +#[derive(Parser)] +struct CliArgs { + /// Which operations to include. + #[arg(global = true, long, value_enum, default_value_t = IncludeMode::OnlyPublic)] + include_mode: IncludeMode, + + /// Ignore a specified operation id + #[arg(global = true, short, long = "exclude-op-id")] + excluded_operations: Vec, + + /// Only include specified operations + /// + /// This option only works with `--include-mode=only-specified`. + /// + /// Use this option, to run the codegen with a limited set of operations. + /// Op webhook models will be excluded from the generation + #[arg(global = true, long = "include-op-id")] + specified_operations: Vec, + + #[command(subcommand)] + command: Command, +} + +#[derive(Clone, Subcommand)] +enum Command { + /// Generate code from an OpenAPI spec. + Generate { + /// Path to a template file to use (`.jinja` extension can be omitted). + #[arg(short, long)] + template: Utf8PathBuf, + + /// Path to the input file(s). + #[arg(short, long)] + input_file: Vec, + + /// Path to the output directory. + #[arg(short, long)] + output_dir: Option, + + /// Disable automatic postprocessing of the output (formatting and automatic style fixes). + #[arg(long)] + no_postprocess: bool, + }, + /// Generate api.ron and types.ron files, for debugging. + Debug { + /// Path to the input file(s). + #[arg(short, long)] + input_file: Vec, + }, +} + +#[derive(Copy, Clone, clap::ValueEnum)] +#[clap(rename_all = "kebab-case")] +enum IncludeMode { + /// Only public options + OnlyPublic, + /// Both public operations and operations marked with `x-hidden` + PublicAndHidden, + /// Only operations marked with `x-hidden` + OnlyHidden, + /// Only operations that were specified in `--include-op-id` + OnlySpecified, +} + +pub fn run_cli_main() -> anyhow::Result<()> { + tracing_subscriber::fmt().with_writer(io::stderr).init(); + + let args = CliArgs::parse(); + + let excluded_operations = BTreeSet::from_iter(args.excluded_operations); + let specified_operations = BTreeSet::from_iter(args.specified_operations); + + let input_files = match &args.command { + Command::Generate { input_file, .. } => input_file, + Command::Debug { input_file } => input_file, + }; + + let api = input_files + .iter() + .map(|input_file| { + let input_file = Path::new(input_file); + let input_file_ext = input_file + .extension() + .context("input file must have a file extension")?; + let input_file_contents = fs::read_to_string(input_file)?; + + if input_file_ext == "json" { + let spec: OpenApi = serde_json::from_str(&input_file_contents) + .context("failed to parse OpenAPI spec")?; + + let webhooks = get_webhooks(&spec); + Api::new( + spec.paths.context("found no endpoints in input spec")?, + &mut spec.components.unwrap_or_default(), + &webhooks, + args.include_mode, + &excluded_operations, + &specified_operations, + ) + .context("converting OpenAPI spec to our own representation") + } else if input_file_ext == "ron" { + ron::from_str(&input_file_contents).context("parsing ron file") + } else { + bail!("input file extension must be .json or .ron"); + } + }) + .collect::>()?; + + match args.command { + Command::Generate { + template, + output_dir, + no_postprocess, + .. + } => { + let generated_paths = match &output_dir { + Some(path) => { + let generated_paths = generate(api, template.into(), path, no_postprocess)?; + println!("done! output written to {path}"); + generated_paths + } + None => { + let output_dir_root = PathBuf::from("out"); + if !output_dir_root.exists() { + fs::create_dir(&output_dir_root).context("failed to create out dir")?; + } + + let tpl_file_name = template + .file_name() + .context("template must have a file name")?; + let prefix = tpl_file_name + .strip_suffix(".jinja") + .unwrap_or(tpl_file_name); + + let output_dir = + TempDir::with_prefix_in(prefix.to_owned() + ".", output_dir_root) + .context("failed to create tempdir")?; + + let path = output_dir + .path() + .try_into() + .context("non-UTF8 tempdir path")?; + + let generated_paths = generate(api, template.into(), path, no_postprocess)?; + println!("done! output written to {path}"); + + // Persist the TempDir if everything was successful + _ = output_dir.keep(); + generated_paths + } + }; + let paths: Vec<&str> = generated_paths.iter().map(|p| p.as_str()).collect(); + let serialized = serde_json::to_string_pretty(&paths)?; + fs::write(".generated_paths.json", serialized)?; + } + Command::Debug { .. } => { + let serialized = ron::ser::to_string_pretty(&api, Default::default())?; + fs::write("debug.ron", serialized)?; + } + } + + Ok(()) +} + +fn get_webhooks(spec: &OpenApi) -> Vec { + let empty_obj = serde_json::Map::new(); + let mut referenced_components = std::collections::BTreeSet::::new(); + if let Some(webhooks) = spec.extensions.get("x-webhooks") { + for req in webhooks.as_object().unwrap_or(&empty_obj).values() { + for method in req.as_object().unwrap_or(&empty_obj).values() { + if let Some(schema_ref) = + method["requestBody"]["content"]["application/json"]["schema"]["$ref"].as_str() + && let Some(schema_name) = schema_ref.split('/').next_back() + { + referenced_components.insert(schema_name.to_string()); + } + } + } + } + + // also check the spec.webhooks + for (_, webhook) in &spec.webhooks { + let Some(item) = webhook.as_item() else { + continue; + }; + + for (_, op) in item.iter() { + if let Some(body) = &op.request_body + && let Some(item) = body.as_item() + && let Some(json_content) = item.content.get("application/json") + && let Some(schema) = &json_content.schema + && let Schema::Object(obj) = &schema.json_schema + && let Some(reference) = &obj.reference + && let Some(component_name) = reference.split('/').next_back() + { + referenced_components.insert(component_name.to_owned()); + } + } + } + + referenced_components.into_iter().collect::>() +} diff --git a/src/main.rs b/src/main.rs index 48bd945..bb95029 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,223 +1,3 @@ -use std::{ - collections::BTreeSet, - io, - path::{Path, PathBuf}, -}; - -use aide::openapi::OpenApi; -use anyhow::{Context as _, bail}; -use camino::Utf8PathBuf; -use clap::{Parser, Subcommand}; -use fs_err::{self as fs}; -use schemars::schema::Schema; -use tempfile::TempDir; - -mod api; -mod generator; -mod postprocessing; -mod template; - -use self::{api::Api, generator::generate}; - -#[derive(Parser)] -struct CliArgs { - /// Which operations to include. - #[arg(global = true, long, value_enum, default_value_t = IncludeMode::OnlyPublic)] - include_mode: IncludeMode, - - /// Ignore a specified operation id - #[arg(global = true, short, long = "exclude-op-id")] - excluded_operations: Vec, - - /// Only include specified operations - /// - /// This option only works with `--include-mode=only-specified`. - /// - /// Use this option, to run the codegen with a limited set of operations. - /// Op webhook models will be excluded from the generation - #[arg(global = true, long = "include-op-id")] - specified_operations: Vec, - - #[command(subcommand)] - command: Command, -} - -#[derive(Clone, Subcommand)] -enum Command { - /// Generate code from an OpenAPI spec. - Generate { - /// Path to a template file to use (`.jinja` extension can be omitted). - #[arg(short, long)] - template: Utf8PathBuf, - - /// Path to the input file(s). - #[arg(short, long)] - input_file: Vec, - - /// Path to the output directory. - #[arg(short, long)] - output_dir: Option, - - /// Disable automatic postprocessing of the output (formatting and automatic style fixes). - #[arg(long)] - no_postprocess: bool, - }, - /// Generate api.ron and types.ron files, for debugging. - Debug { - /// Path to the input file(s). - #[arg(short, long)] - input_file: Vec, - }, -} - -#[derive(Copy, Clone, clap::ValueEnum)] -#[clap(rename_all = "kebab-case")] -enum IncludeMode { - /// Only public options - OnlyPublic, - /// Both public operations and operations marked with `x-hidden` - PublicAndHidden, - /// Only operations marked with `x-hidden` - OnlyHidden, - /// Only operations that were specified in `--include-op-id` - OnlySpecified, -} - fn main() -> anyhow::Result<()> { - tracing_subscriber::fmt().with_writer(io::stderr).init(); - - let args = CliArgs::parse(); - - let excluded_operations = BTreeSet::from_iter(args.excluded_operations); - let specified_operations = BTreeSet::from_iter(args.specified_operations); - - let input_files = match &args.command { - Command::Generate { input_file, .. } => input_file, - Command::Debug { input_file } => input_file, - }; - - let api = input_files - .iter() - .map(|input_file| { - let input_file = Path::new(input_file); - let input_file_ext = input_file - .extension() - .context("input file must have a file extension")?; - let input_file_contents = fs::read_to_string(input_file)?; - - if input_file_ext == "json" { - let spec: OpenApi = serde_json::from_str(&input_file_contents) - .context("failed to parse OpenAPI spec")?; - - let webhooks = get_webhooks(&spec); - Api::new( - spec.paths.context("found no endpoints in input spec")?, - &mut spec.components.unwrap_or_default(), - &webhooks, - args.include_mode, - &excluded_operations, - &specified_operations, - ) - .context("converting OpenAPI spec to our own representation") - } else if input_file_ext == "ron" { - ron::from_str(&input_file_contents).context("parsing ron file") - } else { - bail!("input file extension must be .json or .ron"); - } - }) - .collect::>()?; - - match args.command { - Command::Generate { - template, - output_dir, - no_postprocess, - .. - } => { - let generated_paths = match &output_dir { - Some(path) => { - let generated_paths = generate(api, template.into(), path, no_postprocess)?; - println!("done! output written to {path}"); - generated_paths - } - None => { - let output_dir_root = PathBuf::from("out"); - if !output_dir_root.exists() { - fs::create_dir(&output_dir_root).context("failed to create out dir")?; - } - - let tpl_file_name = template - .file_name() - .context("template must have a file name")?; - let prefix = tpl_file_name - .strip_suffix(".jinja") - .unwrap_or(tpl_file_name); - - let output_dir = - TempDir::with_prefix_in(prefix.to_owned() + ".", output_dir_root) - .context("failed to create tempdir")?; - - let path = output_dir - .path() - .try_into() - .context("non-UTF8 tempdir path")?; - - let generated_paths = generate(api, template.into(), path, no_postprocess)?; - println!("done! output written to {path}"); - - // Persist the TempDir if everything was successful - _ = output_dir.keep(); - generated_paths - } - }; - let paths: Vec<&str> = generated_paths.iter().map(|p| p.as_str()).collect(); - let serialized = serde_json::to_string_pretty(&paths)?; - fs::write(".generated_paths.json", serialized)?; - } - Command::Debug { .. } => { - let serialized = ron::ser::to_string_pretty(&api, Default::default())?; - fs::write("debug.ron", serialized)?; - } - } - - Ok(()) -} - -fn get_webhooks(spec: &OpenApi) -> Vec { - let empty_obj = serde_json::Map::new(); - let mut referenced_components = std::collections::BTreeSet::::new(); - if let Some(webhooks) = spec.extensions.get("x-webhooks") { - for req in webhooks.as_object().unwrap_or(&empty_obj).values() { - for method in req.as_object().unwrap_or(&empty_obj).values() { - if let Some(schema_ref) = - method["requestBody"]["content"]["application/json"]["schema"]["$ref"].as_str() - && let Some(schema_name) = schema_ref.split('/').next_back() - { - referenced_components.insert(schema_name.to_string()); - } - } - } - } - - // also check the spec.webhooks - for (_, webhook) in &spec.webhooks { - let Some(item) = webhook.as_item() else { - continue; - }; - - for (_, op) in item.iter() { - if let Some(body) = &op.request_body - && let Some(item) = body.as_item() - && let Some(json_content) = item.content.get("application/json") - && let Some(schema) = &json_content.schema - && let Schema::Object(obj) = &schema.json_schema - && let Some(reference) = &obj.reference - && let Some(component_name) = reference.split('/').next_back() - { - referenced_components.insert(component_name.to_owned()); - } - } - } - - referenced_components.into_iter().collect::>() + openapi_codegen::run_cli_main() } diff --git a/src/postprocessing.rs b/src/postprocessing.rs index bd649f7..46ee1e3 100644 --- a/src/postprocessing.rs +++ b/src/postprocessing.rs @@ -2,17 +2,51 @@ use std::{io, process::Command}; use anyhow::bail; use camino::{Utf8Path, Utf8PathBuf}; +use serde::Serialize; + +#[derive(Debug, Clone, Copy, Serialize, PartialEq, Eq, PartialOrd, Ord)] +pub enum CodegenLanguage { + Python, + Rust, + Go, + Kotlin, + CSharp, + Java, + TypeScript, + Ruby, + Php, + Shell, + Unknown, +} + +impl CodegenLanguage { + pub fn ext(self) -> &'static str { + match self { + CodegenLanguage::Python => "py", + CodegenLanguage::Rust => "rs", + CodegenLanguage::Go => "go", + CodegenLanguage::Kotlin => "kt", + CodegenLanguage::CSharp => "cs", + CodegenLanguage::Java => "java", + CodegenLanguage::TypeScript => "ts", + CodegenLanguage::Ruby => "rb", + CodegenLanguage::Php => "php", + CodegenLanguage::Shell => "sh", + CodegenLanguage::Unknown => "txt", + } + } +} #[derive(Clone)] pub(crate) struct Postprocessor<'a> { files_to_process: &'a [Utf8PathBuf], - postprocessor_lang: PostprocessorLanguage, + postprocessor_lang: CodegenLanguage, output_dir: Utf8PathBuf, } impl<'a> Postprocessor<'a> { fn new( - postprocessor_lang: PostprocessorLanguage, + postprocessor_lang: CodegenLanguage, output_dir: Utf8PathBuf, files_to_process: &'a [Utf8PathBuf], ) -> Self { @@ -28,18 +62,19 @@ impl<'a> Postprocessor<'a> { files_to_process: &'a [Utf8PathBuf], ) -> Self { let lang = match ext { - "py" => PostprocessorLanguage::Python, - "rs" => PostprocessorLanguage::Rust, - "go" => PostprocessorLanguage::Go, - "kt" => PostprocessorLanguage::Kotlin, - "cs" => PostprocessorLanguage::CSharp, - "java" => PostprocessorLanguage::Java, - "ts" => PostprocessorLanguage::TypeScript, - "rb" => PostprocessorLanguage::Ruby, - "php" => PostprocessorLanguage::Php, + "py" => CodegenLanguage::Python, + "rs" => CodegenLanguage::Rust, + "go" => CodegenLanguage::Go, + "kt" => CodegenLanguage::Kotlin, + "cs" => CodegenLanguage::CSharp, + "java" => CodegenLanguage::Java, + "ts" => CodegenLanguage::TypeScript, + "rb" => CodegenLanguage::Ruby, + "php" => CodegenLanguage::Php, + "sh" => CodegenLanguage::Shell, _ => { tracing::warn!("no known postprocessing command(s) for {ext} files"); - PostprocessorLanguage::Unknown + CodegenLanguage::Unknown } }; Self::new(lang, output_dir.to_path_buf(), files_to_process) @@ -48,49 +83,35 @@ impl<'a> Postprocessor<'a> { pub(crate) fn run_postprocessor(&self) -> anyhow::Result<()> { match self.postprocessor_lang { // pass each file to postprocessor at once - PostprocessorLanguage::Java | PostprocessorLanguage::Rust => { + CodegenLanguage::Java | CodegenLanguage::Rust => { let commands = self.postprocessor_lang.postprocessing_commands(); for (command, args) in commands { execute_command(command, args, self.files_to_process)?; } } // pass output dir to postprocessor - PostprocessorLanguage::Ruby - | PostprocessorLanguage::Php - | PostprocessorLanguage::Python - | PostprocessorLanguage::Go - | PostprocessorLanguage::Kotlin - | PostprocessorLanguage::CSharp - | PostprocessorLanguage::TypeScript => { + CodegenLanguage::Ruby + | CodegenLanguage::Php + | CodegenLanguage::Python + | CodegenLanguage::Go + | CodegenLanguage::Kotlin + | CodegenLanguage::CSharp + | CodegenLanguage::TypeScript => { let commands = self.postprocessor_lang.postprocessing_commands(); for (command, args) in commands { execute_command(command, args, std::slice::from_ref(&self.output_dir))?; } } - PostprocessorLanguage::Unknown => (), + CodegenLanguage::Unknown | CodegenLanguage::Shell => (), } Ok(()) } } -#[derive(Clone, Copy)] -pub(crate) enum PostprocessorLanguage { - Python, - Rust, - Go, - Kotlin, - CSharp, - Java, - TypeScript, - Ruby, - Php, - Unknown, -} - -impl PostprocessorLanguage { +impl CodegenLanguage { fn postprocessing_commands(&self) -> &[(&'static str, &[&str])] { match self { - Self::Unknown => &[], + Self::Unknown | Self::Shell => &[], // https://github.com/astral-sh/ruff Self::Python => &[ ("ruff", &["check", "--no-respect-gitignore", "--fix"]), // First lint and remove unused imports diff --git a/src/template.rs b/src/template.rs index c5b1b63..6bc9d2d 100644 --- a/src/template.rs +++ b/src/template.rs @@ -3,15 +3,24 @@ use std::borrow::Cow; use camino::Utf8Path; use fs_err as fs; use heck::{ - ToLowerCamelCase as _, ToShoutySnakeCase as _, ToSnakeCase as _, ToUpperCamelCase as _, + ToKebabCase, ToLowerCamelCase as _, ToShoutySnakeCase as _, ToSnakeCase as _, + ToUpperCamelCase as _, }; use itertools::Itertools as _; use minijinja::{State, Value, path_loader, value::Kwargs}; +use serde::Deserialize; -pub(crate) fn env(tpl_dir: &Utf8Path) -> Result, minijinja::Error> { +pub fn env_with_dir( + tpl_dir: &Utf8Path, +) -> Result, minijinja::Error> { let mut env = minijinja::Environment::new(); env.set_loader(path_loader(tpl_dir)); + populate_env(env) +} +pub fn populate_env( + mut env: minijinja::Environment<'static>, +) -> Result, minijinja::Error> { // === Custom filters === // --- Case conversion --- @@ -25,6 +34,7 @@ pub(crate) fn env(tpl_dir: &Utf8Path) -> Result, env.add_filter("to_upper_camel_case", |s: Cow<'_, str>| { s.to_upper_camel_case() }); + env.add_filter("to_kebab_case", |s: Cow<'_, str>| s.to_kebab_case()); // --- OpenAPI utils --- env.add_filter( @@ -121,7 +131,16 @@ pub(crate) fn env(tpl_dir: &Utf8Path) -> Result, None => s.into_owned(), } }); - + env.add_filter( + "strip_trailing_str", + |s: Cow<'_, str>, trailing_str: Cow<'_, str>| match s + .trim_end() + .strip_suffix(&trailing_str.to_string()) + { + Some(stripped) => stripped.to_string(), + None => s.into_owned(), + }, + ); env.add_filter( "generate_kt_path_str", |s: Cow<'_, str>, path_params: &Vec| -> Result { @@ -221,6 +240,50 @@ pub(crate) fn env(tpl_dir: &Utf8Path) -> Result, }, ); + env.add_filter( + "format_json_string", + |s: Cow<'_, str>| -> Result { + let decoded: serde_json::Value = serde_json::from_str(&s).unwrap(); + Ok(serde_json::to_string_pretty(&decoded).unwrap()) + }, + ); + + env.add_function( + "panic", + |message: Cow<'_, str>| -> Result { + Err(minijinja::Error::new( + minijinja::ErrorKind::InvalidOperation, + message.to_string(), + )) + }, + ); + + env.add_filter( + // TODO: fix this ugly ass code :( + // TLDR: If I have a number serde_json::Value, it does not get passed the the template correctly + // see this issue: https://github.com/mitsuhiko/minijinja/issues/641 + "fix_serde_number_repr", + |num: Cow<'_, str>| -> Result { + if num.is_empty() { + // default int example value + return Ok(1); + } + if let Ok(parsed_num) = num.parse::() { + Ok(parsed_num) + } else { + #[derive(Deserialize)] + struct SerdeNum { + #[serde(rename = "$serde_json::private::Number")] + key: String, + } + let num: SerdeNum = serde_json::from_str(&num).unwrap(); + let parsed_num: i64 = num.key.parse().unwrap(); + + Ok(parsed_num) + } + }, + ); + Ok(env) }