diff --git a/lua/mason-core/result.lua b/lua/mason-core/result.lua index 2521fff1d..6e7f942cf 100644 --- a/lua/mason-core/result.lua +++ b/lua/mason-core/result.lua @@ -176,6 +176,38 @@ function Result.run_catching(fn) end end +---@generic V +---@param fn fun(try: fun(result: Result): any): V +---@return Result # Result +function Result.try(fn) + local thread = coroutine.create(fn) + local step + step = function(...) + local ok, result = coroutine.resume(thread, ...) + if not ok then + -- l'exception! panique!!! + error(result, 0) + end + if coroutine.status(thread) == "dead" then + if getmetatable(result) == Result then + return result + else + return Result.success(result) + end + elseif getmetatable(result) == Result then + if result:is_failure() then + return result + else + return step(result:get_or_nil()) + end + else + -- yield to parent coroutine + return step(coroutine.yield(result)) + end + end + return step(coroutine.yield) +end + function Result.pcall(fn, ...) local ok, res = pcall(fn, ...) if ok then diff --git a/tests/mason-core/result_spec.lua b/tests/mason-core/result_spec.lua index d7d629f5e..b2900753c 100644 --- a/tests/mason-core/result_spec.lua +++ b/tests/mason-core/result_spec.lua @@ -2,6 +2,7 @@ local Result = require "mason-core.result" local match = require "luassert.match" local spy = require "luassert.spy" local Optional = require "mason-core.optional" +local a = require "mason-core.async" describe("result", function() it("should create success", function() @@ -212,3 +213,109 @@ describe("result", function() ) end) end) + +describe("Result.try", function() + it("should try functions", function() + assert.same( + Result.success "Hello, world!", + Result.try(function(try) + local hello = try(Result.success "Hello, ") + local world = try(Result.success "world!") + return hello .. world + end) + ) + + assert.same( + Result.success(), + Result.try(function(try) + try(Result.success "Hello, ") + try(Result.success "world!") + end) + ) + + assert.same( + Result.failure "Trouble, world!", + Result.try(function(try) + local trouble = try(Result.success "Trouble, ") + local world = try(Result.success "world!") + return try(Result.failure(trouble .. world)) + end) + ) + + local err = assert.has_error(function() + Result.try(function(try) + local err = try(Result.success "42") + error(err) + end) + end) + assert.equals("42", err) + end) + + it( + "should allow calling async functions inside try blocks", + async_test(function() + assert.same( + Result.success "Hello, world!", + Result.try(function(try) + a.sleep(10) + local hello = try(Result.success "Hello, ") + local world = try(Result.success "world!") + return hello .. world + end) + ) + local err = assert.has_error(function() + Result.try(function(try) + a.sleep(10) + local err = try(Result.success "42") + error(err) + end) + end) + assert.equals("42", err) + end) + ) + + it("should not unwrap result values in try blocks", function() + assert.same( + Result.failure "Error!", + Result.try(function() + return Result.failure "Error!" + end) + ) + + assert.same( + Result.success "Success!", + Result.try(function() + return Result.success "Success!" + end) + ) + end) + + it("should allow nesting try blocks", function() + assert.same( + Result.success "Hello from the underworld!", + Result.try(function(try) + local greeting = try(Result.success "Hello from the %s!") + return greeting:format(try(Result.try(function(try) + return try(Result.success "underworld") + end))) + end) + ) + end) + + it( + "should allow nesting try blocks in async scope", + async_test(function() + assert.same( + Result.success "Hello from the underworld!", + Result.try(function(try) + a.sleep(10) + local greeting = try(Result.success "Hello from the %s!") + return greeting:format(try(Result.try(function(try) + a.sleep(10) + return try(Result.success "underworld") + end))) + end) + ) + end) + ) +end)