Skip to content

Commit

Permalink
fix(injected): handle inline injections (#251)
Browse files Browse the repository at this point in the history
  • Loading branch information
stevearc committed Dec 26, 2023
1 parent 7396fc0 commit f245cca
Show file tree
Hide file tree
Showing 19 changed files with 357 additions and 55 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ jobs:
# issues in my "needs triage" filter.
remove_question:
runs-on: ubuntu-latest
if: github.event.sender.login != 'stevearc'
steps:
- uses: actions/checkout@v2
- uses: actions-ecosystem/action-remove-labels@v1
Expand Down
118 changes: 96 additions & 22 deletions lua/conform/formatters/injected.lua
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,31 @@ local function apply_indent(lines, indentation)
end
end

---@class LangRange
---@field [1] string language
---@field [2] integer start lnum
---@field [3] integer start col
---@field [4] integer end lnum
---@field [5] integer end col

---@param ranges LangRange[]
---@param range LangRange
local function accum_range(ranges, range)
local last_range = ranges[#ranges]
if last_range then
if last_range[1] == range[1] and last_range[4] == range[2] and last_range[5] == range[3] then
last_range[4] = range[4]
last_range[5] = range[5]
return
end
end
table.insert(ranges, range)
end

---@class (exact) conform.InjectedFormatterOptions
---@field ignore_errors boolean
---@field lang_to_ext table<string, string>
---@field lang_to_formatters table<string, conform.FiletypeFormatter>

---@type conform.FileLuaFormatterConfig
return {
Expand All @@ -72,6 +95,26 @@ return {
options = {
-- Set to true to ignore errors
ignore_errors = false,
-- Map of treesitter language to file extension
-- A temporary file name with this extension will be generated during formatting
-- because some formatters care about the filename.
lang_to_ext = {
bash = "sh",
c_sharp = "cs",
elixir = "exs",
javascript = "js",
julia = "jl",
latex = "tex",
markdown = "md",
python = "py",
ruby = "rb",
rust = "rs",
teal = "tl",
typescript = "ts",
},
-- Map of treesitter language to formatters to use
-- (defaults to the value from formatters_by_ft)
lang_to_formatters = {},
},
condition = function(self, ctx)
local ok, parser = pcall(vim.treesitter.get_parser, ctx.buf)
Expand All @@ -93,12 +136,20 @@ return {
end
---@type conform.InjectedFormatterOptions
local options = self.options

---@param lang string
---@return nil|conform.FiletypeFormatter
local function get_formatters(lang)
return options.lang_to_formatters[lang] or conform.formatters_by_ft[lang]
end

--- Disable diagnostic to pass the typecheck github action
--- This is available on nightly, but not on stable
--- Stable doesn't have any parameters, so it's safe to always pass `false`
---@diagnostic disable-next-line: redundant-parameter
parser:parse(false)
local root_lang = parser:lang()
---@type LangRange[]
local regions = {}

for _, tree in pairs(parser:trees()) do
Expand All @@ -124,26 +175,26 @@ return {
do
---@diagnostic disable-next-line: invisible
local lang, combined, ranges = parser:_get_injection(match, metadata)
local has_formatters = conform.formatters_by_ft[lang] ~= nil
if lang and has_formatters and not combined and #ranges > 0 and lang ~= root_lang then
local start_lnum
local end_lnum
-- Merge all of the ranges into a single range
if
lang
and get_formatters(lang) ~= nil
and not combined
and #ranges > 0
and lang ~= root_lang
then
for _, range in ipairs(ranges) do
if not start_lnum or start_lnum > range[1] + 1 then
start_lnum = range[1] + 1
end
if not end_lnum or end_lnum < range[4] then
end_lnum = range[4]
end
end
if in_range(ctx.range, start_lnum, end_lnum) then
table.insert(regions, { lang, start_lnum, end_lnum })
accum_range(regions, { lang, range[1] + 1, range[2], range[4] + 1, range[5] })
end
end
end
end

if ctx.range then
regions = vim.tbl_filter(function(region)
return in_range(ctx.range, region[2], region[4])
end, regions)
end

-- Sort from largest start_lnum to smallest
table.sort(regions, function(a, b)
return a[2] > b[2]
Expand Down Expand Up @@ -171,7 +222,11 @@ return {

local formatted_lines = vim.deepcopy(lines)
for _, replacement in ipairs(replacements) do
local start_lnum, end_lnum, new_lines = unpack(replacement)
local start_lnum, start_col, end_lnum, end_col, new_lines = unpack(replacement)
local prefix = formatted_lines[start_lnum]:sub(1, start_col)
local suffix = formatted_lines[end_lnum]:sub(end_col + 1)
new_lines[1] = prefix .. new_lines[1]
new_lines[#new_lines] = new_lines[#new_lines] .. suffix
for _ = start_lnum, end_lnum do
table.remove(formatted_lines, start_lnum)
end
Expand All @@ -184,12 +239,20 @@ return {

local num_format = 0
local tmp_bufs = {}
local formatter_cb = function(err, idx, start_lnum, end_lnum, new_lines)
local formatter_cb = function(err, idx, region, input_lines, new_lines)
if err then
format_error = errors.coalesce(format_error, err)
replacements[idx] = err
else
replacements[idx] = { start_lnum, end_lnum, new_lines }
-- If the original lines started/ended with a newline, preserve that newline.
-- Many formatters will trim them, but they're important for the document structure.
if input_lines[1] == "" and new_lines[1] ~= "" then
table.insert(new_lines, 1, "")
end
if input_lines[#input_lines] == "" and new_lines[#new_lines] ~= "" then
table.insert(new_lines, "")
end
replacements[idx] = { region[2], region[3], region[4], region[5], new_lines }
end
num_format = num_format - 1
if num_format == 0 then
Expand All @@ -200,14 +263,22 @@ return {
end
end
local last_start_lnum = #lines + 1
for _, region in ipairs(regions) do
local lang, start_lnum, end_lnum = unpack(region)
for i, region in ipairs(regions) do
local lang = region[1]
local start_lnum = region[2]
local start_col = region[3]
local end_lnum = region[4]
local end_col = region[5]
-- Ignore regions that overlap (contain) other regions
if end_lnum < last_start_lnum then
num_format = num_format + 1
last_start_lnum = start_lnum
local input_lines = util.tbl_slice(lines, start_lnum, end_lnum)
local ft_formatters = conform.formatters_by_ft[lang]
input_lines[#input_lines] = input_lines[#input_lines]:sub(1, end_col)
if start_col > 0 then
input_lines[1] = input_lines[1]:sub(start_col + 1)
end
local ft_formatters = assert(get_formatters(lang))
---@type string[]
local formatter_names
if type(ft_formatters) == "function" then
Expand All @@ -226,15 +297,18 @@ return {
-- extension to determine a run mode (see https://github.com/stevearc/conform.nvim/issues/194)
-- This is using the language name as the file extension, but that is a reasonable
-- approximation for now. We can add special cases as the need arises.
local buf = vim.fn.bufadd(string.format("%s.%s", vim.api.nvim_buf_get_name(ctx.buf), lang))
local extension = options.lang_to_ext[lang] or lang
local buf =
vim.fn.bufadd(string.format("%s.%d.%s", vim.api.nvim_buf_get_name(ctx.buf), i, extension))
-- Actually load the buffer to set the buffer context which is required by some formatters such as `filetype`
vim.fn.bufload(buf)
tmp_bufs[buf] = true
local format_opts = { async = true, bufnr = buf, quiet = true }
conform.format_lines(formatter_names, input_lines, format_opts, function(err, new_lines)
log.trace("Injected %s:%d:%d formatted lines %s", lang, start_lnum, end_lnum, new_lines)
-- Preserve indentation in case the code block is indented
apply_indent(new_lines, indent)
formatter_cb(err, idx, start_lnum, end_lnum, new_lines)
vim.schedule_wrap(formatter_cb)(err, idx, region, input_lines, new_lines)
end)
end
end
Expand Down
20 changes: 20 additions & 0 deletions lua/conform/fs.lua
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,24 @@ M.join = function(...)
return table.concat({ ... }, M.sep)
end

---@param filepath string
---@return boolean
M.exists = function(filepath)
local stat = uv.fs_stat(filepath)
return stat ~= nil and stat.type ~= nil
end

---@param filepath string
---@return string?
M.read_file = function(filepath)
if not M.exists(filepath) then
return nil
end
local fd = assert(uv.fs_open(filepath, "r", 420)) -- 0644
local stat = assert(uv.fs_fstat(fd))
local content = uv.fs_read(fd, stat.size)
uv.fs_close(fd)
return content
end

return M
9 changes: 8 additions & 1 deletion lua/conform/init.lua
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ local M = {}
---@field inherit? boolean
---@field command? string|fun(self: conform.FormatterConfig, ctx: conform.Context): string
---@field prepend_args? string|string[]|fun(self: conform.FormatterConfig, ctx: conform.Context): string|string[]
---@field format? fun(self: conform.LuaFormatterConfig, ctx: conform.Context, lines: string[], callback: fun(err: nil|string, new_lines: nil|string[])) Mutually exclusive with command
---@field options? table

---@class (exact) conform.FormatterMeta
Expand Down Expand Up @@ -569,6 +570,12 @@ M.get_formatter_config = function(formatter, bufnr)
if type(override) == "function" then
override = override(bufnr)
end
if override and override.command and override.format then
local msg =
string.format("Formatter '%s' cannot define both 'command' and 'format' function", formatter)
vim.notify_once(msg, vim.log.levels.ERROR)
return nil
end

---@type nil|conform.FormatterConfig
local config = override
Expand All @@ -581,7 +588,7 @@ M.get_formatter_config = function(formatter, bufnr)
config = mod_config
end
elseif override then
if override.command then
if override.command or override.format then
config = override
else
local msg = string.format(
Expand Down
1 change: 1 addition & 0 deletions lua/conform/runner.lua
Original file line number Diff line number Diff line change
Expand Up @@ -342,6 +342,7 @@ local function run_formatter(bufnr, formatter, config, ctx, input_lines, opts, c
end
log.debug("%s exited with code %d", formatter.name, code)
log.trace("Output lines: %s", output)
log.trace("%s stderr: %s", formatter.name, stderr)
callback(nil, output)
else
log.info("%s exited with code %d", formatter.name, code)
Expand Down
8 changes: 7 additions & 1 deletion run_tests.sh
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,17 @@ else
(cd "$PLUGINS/plenary.nvim" && git pull)
fi

if [ ! -e "$PLUGINS/nvim-treesitter" ]; then
git clone --depth=1 https://github.com/nvim-treesitter/nvim-treesitter.git "$PLUGINS/nvim-treesitter"
else
(cd "$PLUGINS/nvim-treesitter" && git pull)
fi

XDG_CONFIG_HOME=".testenv/config" \
XDG_DATA_HOME=".testenv/data" \
XDG_STATE_HOME=".testenv/state" \
XDG_RUNTIME_DIR=".testenv/run" \
XDG_CACHE_HOME=".testenv/cache" \
nvim --headless -u tests/minimal_init.lua \
-c "PlenaryBustedDirectory ${1-tests} { minimal_init = './tests/minimal_init.lua' }"
-c "RunTests ${1-tests}"
echo "Success"
26 changes: 17 additions & 9 deletions tests/fake_formatter.sh
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,24 @@

set -e

if [ -e "tests/fake_formatter_output" ]; then
cat tests/fake_formatter_output
else
cat
CODE=0
if [ "$1" = "--fail" ]; then
shift
echo "failure" >&2
CODE=1
fi
if [ "$1" = "--timeout" ]; then
shift
echo "timeout" >&2
sleep 4
fi

if [ "$1" = "--fail" ]; then
echo "failure" >&2
exit 1
elif [ "$1" = "--timeout" ]; then
sleep 4
output_file="$1"

if [ -n "$output_file" ] && [ -e "$output_file" ]; then
cat "$output_file"
else
cat
fi

exit $CODE
5 changes: 5 additions & 0 deletions tests/injected/block_quote.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
text

> ```lua
> local foo = 'bar'
> ```
5 changes: 5 additions & 0 deletions tests/injected/block_quote.md.formatted
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
text

> ```lua
> |local foo = 'bar'|
> ```
14 changes: 14 additions & 0 deletions tests/injected/combined_injections.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
text

<!-- comment -->

```lua
local foo = 'bar'
```


<!-- comment -->

```lua
local foo = 'bar'
```
14 changes: 14 additions & 0 deletions tests/injected/combined_injections.md.formatted
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
text

<!-- comment -->

```lua
|local foo = 'bar'|
```


<!-- comment -->

```lua
|local foo = 'bar'|
```
5 changes: 5 additions & 0 deletions tests/injected/inline.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
foo.innerHTML = `<div> hello </div>`;

bar.innerHTML = `
<div> world </div>
`;
5 changes: 5 additions & 0 deletions tests/injected/inline.ts.formatted
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
foo.innerHTML = `|<div> hello </div>|`;

bar.innerHTML = `
|<div> world </div>|
`;
5 changes: 5 additions & 0 deletions tests/injected/simple.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
text

```lua
local foo = 'bar'
```
5 changes: 5 additions & 0 deletions tests/injected/simple.md.formatted
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
text

```lua
|local foo = 'bar'|
```

0 comments on commit f245cca

Please sign in to comment.