-
Notifications
You must be signed in to change notification settings - Fork 0
/
lib.rs
66 lines (60 loc) · 2.88 KB
/
lib.rs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
use proc_macro::TokenStream;
use std::fs;
use std::path::PathBuf;
use std::str::FromStr;
/// Convenience macro for checking all snapshots relative to the current package.
///
/// The annotated item is emitted unchanged, followed by one `#[test]` function per `.snap` file
/// found (recursively) under `libs/<package>/tests/snapshots`, where `<package>` is the attribute
/// argument and the path is resolved against the compiler's current working directory. Each
/// generated test asserts that
/// `docvim_snapshot::check_snapshot(<snapshot path>, &transform, false)?` returns `true`.
///
/// Test names are derived from the snapshot's path relative to `tests/snapshots`: path components
/// are joined with `_` and the `.snap` extension is dropped, so `lua/basic.snap` produces
/// `fn test_lua_basic()`. Tests are emitted in sorted path order so the expansion is reproducible.
///
/// Brittle assumptions:
///
/// - Every usage of the macro contains a package name attribute (eg. `#[check_snapshots(docvim_parser)]`).
/// - Every usage of the macro is attached to a function named `transform`.
/// - Snapshots live at "tests/snapshots" or its subdirectories.
/// - There are no symlinks under "tests/snapshots".
/// - Snapshots are named "$something.snap"; files with any other extension are ignored.
/// - File and directory names are well-formed (eg. valid Unicode) and, once joined with `_`, form a
///   valid Rust identifier (no spaces, hyphens, or extra dots).
///
/// # Panics
///
/// Panics (reported as a compile error at the macro call site) if the package name is missing,
/// the current or snapshot directory cannot be read, a symlink is found, a path is not valid
/// UTF-8, a snapshot name contains an ASCII character that cannot appear in an identifier, or the
/// generated code cannot be tokenized.
#[proc_macro_attribute]
pub fn check_snapshots(attr: TokenStream, item: TokenStream) -> TokenStream {
    use std::ffi::OsStr;
    use std::path::Path;

    /// Recursively collects every `.snap` file under `dir` into `out`, panicking on symlinks.
    fn collect_snapshots(dir: &Path, out: &mut Vec<PathBuf>) {
        let entries = fs::read_dir(dir)
            .unwrap_or_else(|err| panic!("Could not read directory {:?}: {}", dir, err));
        for entry in entries {
            let entry = entry.expect("Could not access file");
            // `DirEntry::file_type` does not follow symlinks, so a link is reported as such.
            let file_type = entry.file_type().expect("Could not get file type");
            let path = entry.path();
            if file_type.is_symlink() {
                panic!("Found symlink: {:?}", path);
            } else if file_type.is_dir() {
                collect_snapshots(&path, out);
            } else if path.extension() == Some(OsStr::new("snap")) {
                out.push(path);
            }
            // Anything else (READMEs, `.gitkeep`, editor droppings) is not a snapshot.
        }
    }

    /// Derives the test-function suffix for `snapshot`: its path relative to `base`, without the
    /// `.snap` extension, with components joined by `_` (platform-independent, unlike replacing
    /// `/` in a string).
    fn test_name(snapshot: &Path, base: &Path) -> String {
        let relative = snapshot
            .strip_prefix(base)
            .expect("Base is not prefix of path")
            .with_extension("");
        let name = relative
            .iter()
            .map(|part| part.to_str().expect("Invalid UTF-8 string"))
            .collect::<Vec<_>>()
            .join("_");
        // Any ASCII character outside [A-Za-z0-9_] can never appear in an identifier, so reject
        // it here with a clear message instead of a confusing error in the expanded code.
        // TODO: non-ASCII names are still not validated; doing that properly means checking
        // XID_Start/XID_Continue - see https://doc.rust-lang.org/reference/identifiers.html
        if let Some(bad) = name
            .chars()
            .find(|c| c.is_ascii() && !(c.is_ascii_alphanumeric() || *c == '_'))
        {
            panic!(
                "Snapshot {:?} yields invalid test name `test_{}` (contains {:?})",
                snapshot, name, bad
            );
        }
        name
    }

    let package = attr.to_string();
    if package.trim().is_empty() {
        panic!("check_snapshots requires a package name, eg. #[check_snapshots(docvim_parser)]");
    }

    let base = std::env::current_dir()
        .expect("Could not get current directory")
        .join("libs")
        .join(package)
        .join("tests")
        .join("snapshots");

    let mut snapshots = Vec::new();
    collect_snapshots(&base, &mut snapshots);
    // `read_dir` order is platform-dependent; sort so the generated code is reproducible.
    snapshots.sort();

    let mut tests = item.to_string();
    for snapshot in &snapshots {
        let name = test_name(snapshot, &base);
        let path = snapshot.to_str().expect("Invalid UTF-8 string");
        // The path is embedded as a raw string so backslashes (Windows) need no escaping.
        tests.push_str(&format!(
            "\n#[test]\nfn test_{}() -> Result<(), Box<dyn std::error::Error>> {{\n    assert!(docvim_snapshot::check_snapshot(std::path::Path::new(r####\"{}\"####), &transform, false)?);\n    Ok(())\n}}\n",
            name, path
        ));
    }
    TokenStream::from_str(&tests).expect("Could not generate token stream")
}