Skip to content

Commit

Permalink
feat: add closure compilation support
Browse files Browse the repository at this point in the history
This commit adds closure compilations support. The idea is fairly
simple. When we resolve a new symbol, if we can find it in the outer
symbol table, we should add this value to the free symbol vector also
record the mapping into the table.
  • Loading branch information
shejialuo committed Apr 6, 2023
1 parent 751e763 commit 1db7f9e
Show file tree
Hide file tree
Showing 8 changed files with 207 additions and 9 deletions.
5 changes: 5 additions & 0 deletions code/code.cpp
Expand Up @@ -33,6 +33,7 @@ const Opcode Ops::OpGetLocal{22};
const Opcode Ops::OpSetLocal{23};
const Opcode Ops::OpGetBuiltin{24};
const Opcode Ops::OpClosure{25};
const Opcode Ops::OpGetFree{26};

const std::unordered_map<Opcode, Definition> Code::definitions{
// For OpConstant, we store the index not the number itself
Expand Down Expand Up @@ -142,6 +143,10 @@ const std::unordered_map<Opcode, Definition> Code::definitions{
// the number of free variables
Definition{"OpClosure", {2, 1}},
},
{
Ops::OpGetFree,
Definition{"OpGetFree", {1}},
},
};

Instructions Code::make(const Opcode &op, const std::vector<int> &operands) {
Expand Down
1 change: 1 addition & 0 deletions code/code.hpp
Expand Up @@ -57,6 +57,7 @@ struct Ops {
static const Opcode OpSetLocal;
static const Opcode OpGetBuiltin;
static const Opcode OpClosure;
static const Opcode OpGetFree;
};

/**
Expand Down
27 changes: 19 additions & 8 deletions compiler/compiler.cpp
Expand Up @@ -158,13 +158,7 @@ void Compiler::compile(Node *node) {
spdlog::error("identifier not found: {}", identifier->value);
return;
}
if (symbol.value().get().symbolScope == Symbol::localScope) {
emit(Ops::OpGetLocal, {symbol.value().get().index});
} else if (symbol.value().get().symbolScope == Symbol::globalScope) {
emit(Ops::OpGetGlobal, {symbol.value().get().index});
} else {
emit(Ops::OpGetBuiltin, {symbol.value().get().index});
}
loadSymbol(symbol.value());
}

StringLiteral *stringLiteral = dynamic_cast<StringLiteral *>(node);
Expand Down Expand Up @@ -205,14 +199,19 @@ void Compiler::compile(Node *node) {
replaceLastPopWithReturn();
}

auto &freeSymbols = symbolTable->getFreeSymbols();
int numLocals = currentSymbolTable()->getNumDefinition();
Instructions instructions = leaveScope();

for (auto &&symbol : freeSymbols) {
loadSymbol(symbol);
}

std::unique_ptr<Object> compiledFunction = std::make_unique<CompiledFunction>(std::move(instructions), numLocals);

int functionIndex = addConstant(compiledFunction);

emit(Ops::OpClosure, {functionIndex, 0});
emit(Ops::OpClosure, {functionIndex, static_cast<int>(freeSymbols.size())});
}

ReturnStatement *returnStatement = dynamic_cast<ReturnStatement *>(node);
Expand Down Expand Up @@ -310,3 +309,15 @@ void Compiler::replaceLastPopWithReturn() {
replaceInstruction(lastPosition, Code::make(Ops::OpReturnValue, {}));
currentScope().lastInstruction.op = Ops::OpReturnValue;
}

void Compiler::loadSymbol(std::reference_wrapper<Symbol> symbol) {
if (symbol.get().symbolScope == Symbol::localScope) {
emit(Ops::OpGetLocal, {symbol.get().index});
} else if (symbol.get().symbolScope == Symbol::globalScope) {
emit(Ops::OpGetGlobal, {symbol.get().index});
} else if (symbol.get().symbolScope == Symbol::builtinScope) {
emit(Ops::OpGetBuiltin, {symbol.get().index});
} else {
emit(Ops::OpGetFree, {symbol.get().index});
}
}
7 changes: 7 additions & 0 deletions compiler/compiler.hpp
Expand Up @@ -170,6 +170,13 @@ class Compiler {
*/
void replaceLastPopWithReturn();

/**
* @brief load the symbol to the stack
*
* @param symbol
*/
void loadSymbol(std::reference_wrapper<Symbol> symbol);

inline int getScopeIndex() { return scopeIndex; }

inline CompilationScope &currentScope() { return scopes[scopeIndex]; }
Expand Down
21 changes: 20 additions & 1 deletion compiler/symbolTable.cpp
Expand Up @@ -6,6 +6,7 @@
std::string Symbol::globalScope{"GLOBAL"};
std::string Symbol::localScope{"LOCAL"};
std::string Symbol::builtinScope{"BUILTIN"};
std::string Symbol::freeScope{"FREE"};

Symbol &SymbolTable::define(const std::string &name) {
Symbol symbol{name, Symbol::globalScope, numDefinitions};
Expand All @@ -25,12 +26,30 @@ Symbol &SymbolTable::defineBuiltin(int index, const std::string &name) {
return store[name];
}

Symbol &SymbolTable::defineFree(const Symbol &freeSymbol) {
freeSymbols.push_back(freeSymbol);

Symbol symbol{freeSymbol.name, Symbol::freeScope, static_cast<int>(freeSymbols.size() - 1)};

store[freeSymbol.name] = symbol;
return store[freeSymbol.name];
}

std::optional<std::reference_wrapper<Symbol>> SymbolTable::resolve(const std::string &name) {
if (store.find(name) != store.end()) {
return store[name];
} else {
if (outer != nullptr) {
return outer->resolve(name);
auto obj = outer->resolve(name);
if (obj.has_value()) {
if (obj.value().get().symbolScope == Symbol::globalScope ||
obj.value().get().symbolScope == Symbol::builtinScope) {
return obj;
}

auto &freeSymbol = defineFree(obj.value().get());
return freeSymbol;
}
}
}
return {};
Expand Down
5 changes: 5 additions & 0 deletions compiler/symbolTable.hpp
Expand Up @@ -6,11 +6,13 @@
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

struct Symbol {
static std::string globalScope;
static std::string localScope;
static std::string builtinScope;
static std::string freeScope;

std::string name;
std::string symbolScope;
Expand All @@ -31,16 +33,19 @@ class SymbolTable {
std::shared_ptr<SymbolTable> outer{};
std::unordered_map<std::string, Symbol> store{};
int numDefinitions{};
std::vector<Symbol> freeSymbols{};

public:
SymbolTable() = default;
SymbolTable(std::shared_ptr<SymbolTable> &o) : outer{o} {}

inline std::shared_ptr<SymbolTable> &getOuter() { return outer; }
inline int getNumDefinition() { return numDefinitions; }
inline std::vector<Symbol> &getFreeSymbols() { return freeSymbols; }

Symbol &define(const std::string &name);
Symbol &defineBuiltin(int index, const std::string &name);
Symbol &defineFree(const Symbol &freeSymbol);
std::optional<std::reference_wrapper<Symbol>> resolve(const std::string &name);
};

Expand Down
79 changes: 79 additions & 0 deletions compiler/tests/compilerTest.cpp
Expand Up @@ -784,3 +784,82 @@ TEST(Compiler, TestBuiltins) {
EXPECT_TRUE(testInstructions(test.expectedInstructions, instructions));
}
}

TEST(compiler, TestClosures) {
std::vector<CompilerTestCase<std::variant<int, std::vector<Instructions>>>> tests{
{
"fn (a) {fn (b) {a + b}}",
{
std::variant<int, std::vector<Instructions>>{
std::in_place_index<1>,
{
Code::make(Ops::OpGetFree, {0}),
Code::make(Ops::OpGetLocal, {0}),
Code::make(Ops::OpAdd, {}),
Code::make(Ops::OpReturnValue, {}),
},
},
std::variant<int, std::vector<Instructions>>{
std::in_place_index<1>,
{
Code::make(Ops::OpGetLocal, {0}),
Code::make(Ops::OpClosure, {0, 1}),
Code::make(Ops::OpReturnValue, {}),
},
},
},
{
Code::make(Ops::OpClosure, {1, 0}),
Code::make(Ops::OpPop, {}),

},
},
{
"fn(a) {fn(b) {fn(c) { a + b + c} }}",
{
std::variant<int, std::vector<Instructions>>{
std::in_place_index<1>,
{
Code::make(Ops::OpGetFree, {0}),
Code::make(Ops::OpGetFree, {1}),
Code::make(Ops::OpAdd, {}),
Code::make(Ops::OpGetLocal, {}),
Code::make(Ops::OpAdd, {}),
Code::make(Ops::OpReturnValue, {}),
},
},
std::variant<int, std::vector<Instructions>>{
std::in_place_index<1>,
{
Code::make(Ops::OpGetFree, {0}),
Code::make(Ops::OpGetLocal, {0}),
Code::make(Ops::OpClosure, {0, 2}),
Code::make(Ops::OpReturnValue, {}),
},
},
std::variant<int, std::vector<Instructions>>{
std::in_place_index<1>,
{
Code::make(Ops::OpGetLocal, {0}),
Code::make(Ops::OpClosure, {1, 1}),
Code::make(Ops::OpReturnValue, {}),
},
},
},
{
Code::make(Ops::OpClosure, {2, 0}),
Code::make(Ops::OpPop, {}),

},
},

};

for (auto &&test : tests) {
auto program = parse(test.input);
Compiler compiler;
compiler.compile(program.get());
auto instructions = compiler.getBytecode().instructions;
EXPECT_TRUE(testInstructions(test.expectedInstructions, instructions));
}
}
71 changes: 71 additions & 0 deletions compiler/tests/symbolTableTest.cpp
Expand Up @@ -136,3 +136,74 @@ TEST(SymbolTable, TestDefineResolveBuiltins) {
}
}
}

TEST(SymbolTable, TestResolveFree) {
auto global = std::make_shared<SymbolTable>();
auto firstLocal = std::make_shared<SymbolTable>(global);
auto secondLocal = std::make_shared<SymbolTable>(firstLocal);

global->define("a");
global->define("b");

firstLocal->define("c");
firstLocal->define("d");

secondLocal->define("e");
secondLocal->define("f");

struct TestData {
std::shared_ptr<SymbolTable> table;
std::vector<Symbol> expectedSymbols;
std::vector<Symbol> expectedFreeSymbols;

TestData() = default;
};

std::vector<TestData> tests{
{
firstLocal,
{
{"a", Symbol::globalScope, 0},
{"b", Symbol::globalScope, 1},
{"c", Symbol::localScope, 0},
{"d", Symbol::localScope, 1},
},
{},
},
{
secondLocal,
{
{"a", Symbol::globalScope, 0},
{"b", Symbol::globalScope, 1},
{"c", Symbol::freeScope, 0},
{"d", Symbol::freeScope, 1},
{"e", Symbol::localScope, 0},
{"f", Symbol::localScope, 1},
},
{
{"c", Symbol::localScope, 0},
{"d", Symbol::localScope, 1},
},
},
};

for (auto &&test : tests) {
for (auto &&symbol : test.expectedSymbols) {
auto result = test.table->resolve(symbol.name);
if (!result.has_value()) {
FAIL();
}

if (result.value().get() != symbol) {
FAIL();
}
}

for (int i = 0; i < test.expectedFreeSymbols.size(); i++) {
auto &&result = test.table->getFreeSymbols()[i];
if (result != test.expectedFreeSymbols[i]) {
FAIL();
}
}
}
}

0 comments on commit 1db7f9e

Please sign in to comment.