diff --git a/code/code.cpp b/code/code.cpp index 929071a..12fba2e 100644 --- a/code/code.cpp +++ b/code/code.cpp @@ -33,6 +33,7 @@ const Opcode Ops::OpGetLocal{22}; const Opcode Ops::OpSetLocal{23}; const Opcode Ops::OpGetBuiltin{24}; const Opcode Ops::OpClosure{25}; +const Opcode Ops::OpGetFree{26}; const std::unordered_map Code::definitions{ // For OpConstant, we store the index not the number itself @@ -142,6 +143,10 @@ const std::unordered_map Code::definitions{ // the number of free variables Definition{"OpClosure", {2, 1}}, }, + { + Ops::OpGetFree, + Definition{"OpGetFree", {1}}, + }, }; Instructions Code::make(const Opcode &op, const std::vector &operands) { diff --git a/code/code.hpp b/code/code.hpp index 8c0b6d3..e2cb2fa 100644 --- a/code/code.hpp +++ b/code/code.hpp @@ -57,6 +57,7 @@ struct Ops { static const Opcode OpSetLocal; static const Opcode OpGetBuiltin; static const Opcode OpClosure; + static const Opcode OpGetFree; }; /** diff --git a/compiler/compiler.cpp b/compiler/compiler.cpp index b33e935..676762e 100644 --- a/compiler/compiler.cpp +++ b/compiler/compiler.cpp @@ -158,13 +158,7 @@ void Compiler::compile(Node *node) { spdlog::error("identifier not found: {}", identifier->value); return; } - if (symbol.value().get().symbolScope == Symbol::localScope) { - emit(Ops::OpGetLocal, {symbol.value().get().index}); - } else if (symbol.value().get().symbolScope == Symbol::globalScope) { - emit(Ops::OpGetGlobal, {symbol.value().get().index}); - } else { - emit(Ops::OpGetBuiltin, {symbol.value().get().index}); - } + loadSymbol(symbol.value()); } StringLiteral *stringLiteral = dynamic_cast(node); @@ -205,14 +199,19 @@ void Compiler::compile(Node *node) { replaceLastPopWithReturn(); } + auto &freeSymbols = symbolTable->getFreeSymbols(); int numLocals = currentSymbolTable()->getNumDefinition(); Instructions instructions = leaveScope(); + for (auto &&symbol : freeSymbols) { + loadSymbol(symbol); + } + std::unique_ptr compiledFunction = std::make_unique(std::move(instructions), numLocals); int functionIndex = addConstant(compiledFunction); - emit(Ops::OpClosure, {functionIndex, 0}); + emit(Ops::OpClosure, {functionIndex, static_cast(freeSymbols.size())}); } ReturnStatement *returnStatement = dynamic_cast(node); @@ -310,3 +309,15 @@ void Compiler::replaceLastPopWithReturn() { replaceInstruction(lastPosition, Code::make(Ops::OpReturnValue, {})); currentScope().lastInstruction.op = Ops::OpReturnValue; } + +void Compiler::loadSymbol(std::reference_wrapper symbol) { + if (symbol.get().symbolScope == Symbol::localScope) { + emit(Ops::OpGetLocal, {symbol.get().index}); + } else if (symbol.get().symbolScope == Symbol::globalScope) { + emit(Ops::OpGetGlobal, {symbol.get().index}); + } else if (symbol.get().symbolScope == Symbol::builtinScope) { + emit(Ops::OpGetBuiltin, {symbol.get().index}); + } else { + emit(Ops::OpGetFree, {symbol.get().index}); + } +} diff --git a/compiler/compiler.hpp b/compiler/compiler.hpp index 59fadac..e2717ff 100644 --- a/compiler/compiler.hpp +++ b/compiler/compiler.hpp @@ -170,6 +170,13 @@ class Compiler { */ void replaceLastPopWithReturn(); + /** + * @brief load the symbol to the stack + * + * @param symbol + */ + void loadSymbol(std::reference_wrapper symbol); + inline int getScopeIndex() { return scopeIndex; } inline CompilationScope ¤tScope() { return scopes[scopeIndex]; } diff --git a/compiler/symbolTable.cpp b/compiler/symbolTable.cpp index 3d57a1a..8f7d47b 100644 --- a/compiler/symbolTable.cpp +++ b/compiler/symbolTable.cpp @@ -6,6 +6,7 @@ std::string Symbol::globalScope{"GLOBAL"}; std::string Symbol::localScope{"LOCAL"}; std::string Symbol::builtinScope{"BUILTIN"}; +std::string Symbol::freeScope{"FREE"}; Symbol &SymbolTable::define(const std::string &name) { Symbol symbol{name, Symbol::globalScope, numDefinitions}; @@ -25,12 +26,30 @@ Symbol &SymbolTable::defineBuiltin(int index, const std::string &name) { return store[name]; } +Symbol &SymbolTable::defineFree(const Symbol &freeSymbol) { + freeSymbols.push_back(freeSymbol); + + Symbol symbol{freeSymbol.name, Symbol::freeScope, static_cast(freeSymbols.size() - 1)}; + + store[freeSymbol.name] = symbol; + return store[freeSymbol.name]; +} + std::optional> SymbolTable::resolve(const std::string &name) { if (store.find(name) != store.end()) { return store[name]; } else { if (outer != nullptr) { - return outer->resolve(name); + auto obj = outer->resolve(name); + if (obj.has_value()) { + if (obj.value().get().symbolScope == Symbol::globalScope || + obj.value().get().symbolScope == Symbol::builtinScope) { + return obj; + } + + auto &freeSymbol = defineFree(obj.value().get()); + return freeSymbol; + } } } return {}; diff --git a/compiler/symbolTable.hpp b/compiler/symbolTable.hpp index 54cf089..2b61d87 100644 --- a/compiler/symbolTable.hpp +++ b/compiler/symbolTable.hpp @@ -6,11 +6,13 @@ #include #include #include +#include struct Symbol { static std::string globalScope; static std::string localScope; static std::string builtinScope; + static std::string freeScope; std::string name; std::string symbolScope; @@ -31,6 +33,7 @@ class SymbolTable { std::shared_ptr outer{}; std::unordered_map store{}; int numDefinitions{}; + std::vector freeSymbols{}; public: SymbolTable() = default; @@ -38,9 +41,11 @@ class SymbolTable { inline std::shared_ptr &getOuter() { return outer; } inline int getNumDefinition() { return numDefinitions; } + inline std::vector &getFreeSymbols() { return freeSymbols; } Symbol &define(const std::string &name); Symbol &defineBuiltin(int index, const std::string &name); + Symbol &defineFree(const Symbol &freeSymbol); std::optional> resolve(const std::string &name); }; diff --git a/compiler/tests/compilerTest.cpp b/compiler/tests/compilerTest.cpp index 3ddf5e9..5b18d88 100644 --- a/compiler/tests/compilerTest.cpp +++ b/compiler/tests/compilerTest.cpp @@ -784,3 +784,82 @@ TEST(Compiler, TestBuiltins) { EXPECT_TRUE(testInstructions(test.expectedInstructions, instructions)); } } + +TEST(compiler, TestClosures) { + std::vector>>> tests{ + { + "fn (a) {fn (b) {a + b}}", + { + std::variant>{ + std::in_place_index<1>, + { + Code::make(Ops::OpGetFree, {0}), + Code::make(Ops::OpGetLocal, {0}), + Code::make(Ops::OpAdd, {}), + Code::make(Ops::OpReturnValue, {}), + }, + }, + std::variant>{ + std::in_place_index<1>, + { + Code::make(Ops::OpGetLocal, {0}), + Code::make(Ops::OpClosure, {0, 1}), + Code::make(Ops::OpReturnValue, {}), + }, + }, + }, + { + Code::make(Ops::OpClosure, {1, 0}), + Code::make(Ops::OpPop, {}), + + }, + }, + { + "fn(a) {fn(b) {fn(c) { a + b + c} }}", + { + std::variant>{ + std::in_place_index<1>, + { + Code::make(Ops::OpGetFree, {0}), + Code::make(Ops::OpGetFree, {1}), + Code::make(Ops::OpAdd, {}), + Code::make(Ops::OpGetLocal, {}), + Code::make(Ops::OpAdd, {}), + Code::make(Ops::OpReturnValue, {}), + }, + }, + std::variant>{ + std::in_place_index<1>, + { + Code::make(Ops::OpGetFree, {0}), + Code::make(Ops::OpGetLocal, {0}), + Code::make(Ops::OpClosure, {0, 2}), + Code::make(Ops::OpReturnValue, {}), + }, + }, + std::variant>{ + std::in_place_index<1>, + { + Code::make(Ops::OpGetLocal, {0}), + Code::make(Ops::OpClosure, {1, 1}), + Code::make(Ops::OpReturnValue, {}), + }, + }, + }, + { + Code::make(Ops::OpClosure, {2, 0}), + Code::make(Ops::OpPop, {}), + + }, + }, + + }; + + for (auto &&test : tests) { + auto program = parse(test.input); + Compiler compiler; + compiler.compile(program.get()); + auto instructions = compiler.getBytecode().instructions; + EXPECT_TRUE(testInstructions(test.expectedInstructions, instructions)); + } +} diff --git a/compiler/tests/symbolTableTest.cpp b/compiler/tests/symbolTableTest.cpp index dd2ac6a..4ee5cb0 100644 --- a/compiler/tests/symbolTableTest.cpp +++ b/compiler/tests/symbolTableTest.cpp @@ -136,3 +136,74 @@ TEST(SymbolTable, TestDefineResolveBuiltins) { } } } + +TEST(SymbolTable, TestResolveFree) { + auto global = std::make_shared(); + auto firstLocal = std::make_shared(global); + auto secondLocal = std::make_shared(firstLocal); + + global->define("a"); + global->define("b"); + + firstLocal->define("c"); + firstLocal->define("d"); + + secondLocal->define("e"); + secondLocal->define("f"); + + struct TestData { + std::shared_ptr table; + std::vector expectedSymbols; + std::vector expectedFreeSymbols; + + TestData() = default; + }; + + std::vector tests{ + { + firstLocal, + { + {"a", Symbol::globalScope, 0}, + {"b", Symbol::globalScope, 1}, + {"c", Symbol::localScope, 0}, + {"d", Symbol::localScope, 1}, + }, + {}, + }, + { + secondLocal, + { + {"a", Symbol::globalScope, 0}, + {"b", Symbol::globalScope, 1}, + {"c", Symbol::freeScope, 0}, + {"d", Symbol::freeScope, 1}, + {"e", Symbol::localScope, 0}, + {"f", Symbol::localScope, 1}, + }, + { + {"c", Symbol::localScope, 0}, + {"d", Symbol::localScope, 1}, + }, + }, + }; + + for (auto &&test : tests) { + for (auto &&symbol : test.expectedSymbols) { + auto result = test.table->resolve(symbol.name); + if (!result.has_value()) { + FAIL(); + } + + if (result.value().get() != symbol) { + FAIL(); + } + } + + for (int i = 0; i < test.expectedFreeSymbols.size(); i++) { + auto &&result = test.table->getFreeSymbols()[i]; + if (result != test.expectedFreeSymbols[i]) { + FAIL(); + } + } + } +}