Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement splat operator #49

Merged
merged 5 commits into from
Apr 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions slox/Expression.swift
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ indirect enum Expression: Equatable {
case subscriptGet(Expression, Expression)
case subscriptSet(Expression, Expression, Expression)
case dictionary([(Expression, Expression)])
case splat(Expression)

static func == (lhs: Expression, rhs: Expression) -> Bool {
switch (lhs, rhs) {
Expand Down
37 changes: 27 additions & 10 deletions slox/Interpreter.swift
Original file line number Diff line number Diff line change
Expand Up @@ -317,6 +317,8 @@ class Interpreter {
return try handleSubscriptSetExpression(collectionExpr: listExpr,
indexExpr: indexExpr,
valueExpr: valueExpr)
case .splat(let listExpr):
return try handleSplatExpression(listExpr: listExpr)
case .dictionary(let kvPairs):
return try handleDictionary(kvExprPairs: kvPairs)
}
Expand Down Expand Up @@ -477,16 +479,12 @@ class Interpreter {
throw RuntimeError.notACallableObject
}

let argValues = try evaluateAndFlatten(exprs: args)

guard let parameterList = actualCallable.parameterList else {
fatalError()
}
try parameterList.checkArity(argCount: args.count)

var argValues: [LoxValue] = []
for arg in args {
let argValue = try evaluate(expr: arg)
argValues.append(argValue)
}
try parameterList.checkArity(argCount: argValues.count)

return try actualCallable.call(interpreter: self, args: argValues)
}
Expand Down Expand Up @@ -553,9 +551,7 @@ class Interpreter {
}

private func handleListExpression(elements: [ResolvedExpression]) throws -> LoxValue {
let elementValues = try elements.map { element in
return try evaluate(expr: element)
}
let elementValues = try evaluateAndFlatten(exprs: elements)

return try makeList(elements: elementValues)
}
Expand Down Expand Up @@ -604,6 +600,10 @@ class Interpreter {
return value
}

private func handleSplatExpression(listExpr: ResolvedExpression) throws -> LoxValue {
return try evaluate(expr: listExpr)
}

private func handleDictionary(kvExprPairs: [(ResolvedExpression, ResolvedExpression)]) throws -> LoxValue {
var kvPairs: [LoxValue: LoxValue] = [:]

Expand All @@ -621,6 +621,23 @@ class Interpreter {
return .instance(dictionary)
}

// Utility functions
private func evaluateAndFlatten(exprs: [ResolvedExpression]) throws -> [LoxValue] {
let values = try exprs.flatMap { expr in
if case .splat = expr {
guard case .instance(let list as LoxList) = try evaluate(expr: expr) else {
throw RuntimeError.notAList
}
return list.elements
} else {
let elementValue = try evaluate(expr: expr)
return [elementValue]
}
}

return values
}

func makeList(elements: [LoxValue]) throws -> LoxValue {
guard case .instance(let listClass as LoxClass) = try environment.getValue(name: "List") else {
fatalError()
Expand Down
7 changes: 6 additions & 1 deletion slox/Parser.swift
Original file line number Diff line number Diff line change
Expand Up @@ -405,7 +405,7 @@ struct Parser {
// comparison → term ( ( ">" | ">=" | "<" | "<=" ) term )* ;
// term → factor ( ( "-" | "+" ) factor )* ;
// factor → unary ( ( "/" | "*" | "%" ) unary )* ;
// unary → ( "!" | "-" ) unary
// unary → ( "!" | "-" | "*" ) unary
// | postfix ;
// postfix → primary ( "(" arguments? ")" | "." IDENTIFIER | "[" logicOr "]" )* ;
// primary → NUMBER | STRING | "true" | "false" | "nil"
Expand Down Expand Up @@ -541,6 +541,11 @@ struct Parser {
return .unary(oper, expr)
}

if currentTokenMatchesAny(types: [.star]) {
let expr = try parseUnary()
return .splat(expr)
}

return try parsePostfix()
}

Expand Down
1 change: 1 addition & 0 deletions slox/ResolvedExpression.swift
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ indirect enum ResolvedExpression: Equatable {
case subscriptGet(ResolvedExpression, ResolvedExpression)
case subscriptSet(ResolvedExpression, ResolvedExpression, ResolvedExpression)
case dictionary([(ResolvedExpression, ResolvedExpression)])
case splat(ResolvedExpression)

static func == (lhs: ResolvedExpression, rhs: ResolvedExpression) -> Bool {
switch (lhs, rhs) {
Expand Down
31 changes: 31 additions & 0 deletions slox/Resolver.swift
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,17 @@ struct Resolver {
case loop
}

private enum ArgumentListType {
case none
case functionCall
case listInitializer
}

private var scopeStack: [[String: Bool]] = []
private var currentFunctionType: FunctionType = .none
private var currentClassType: ClassType = .none
private var currentLoopType: LoopType = .none
private var currentArgumentListType: ArgumentListType = .none

// Main point of entry
mutating func resolve(statements: [Statement]) throws -> [ResolvedStatement] {
Expand Down Expand Up @@ -334,6 +341,8 @@ struct Resolver {
return try handleSubscriptGet(listExpr: listExpr, indexExpr: indexExpr)
case .subscriptSet(let listExpr, let indexExpr, let valueExpr):
return try handleSubscriptSet(listExpr: listExpr, indexExpr: indexExpr, valueExpr: valueExpr)
case .splat(let listExpr):
return try handleSplat(listExpr: listExpr)
case .dictionary(let kvPairs):
return try handleDictionary(kvPairs: kvPairs)
}
Expand Down Expand Up @@ -373,6 +382,12 @@ struct Resolver {
mutating private func handleCall(calleeExpr: Expression,
rightParenToken: Token,
args: [Expression]) throws -> ResolvedExpression {
let previousArgumentListType = currentArgumentListType
currentArgumentListType = .functionCall
defer {
currentArgumentListType = previousArgumentListType
}

let resolvedCalleeExpr = try resolve(expression: calleeExpr)

let resolvedArgs = try args.map { arg in
Expand Down Expand Up @@ -470,6 +485,12 @@ struct Resolver {
}

mutating private func handleList(elements: [Expression]) throws -> ResolvedExpression {
let previousArgumentListType = currentArgumentListType
currentArgumentListType = .listInitializer
defer {
currentArgumentListType = previousArgumentListType
}

let resolvedElements = try elements.map { element in
return try resolve(expression: element)
}
Expand All @@ -494,6 +515,16 @@ struct Resolver {
return .subscriptSet(resolvedListExpr, resolvedIndexExpr, resolvedValueExpr)
}

mutating private func handleSplat(listExpr: Expression) throws -> ResolvedExpression {
if currentArgumentListType == .none {
throw ResolverError.cannotUseSplatOperatorOutOfContext
}

let resolvedListExpr = try resolve(expression: listExpr)

return .splat(resolvedListExpr)
}

mutating private func handleDictionary(kvPairs: [(Expression, Expression)]) throws -> ResolvedExpression {
var resolvedKVPairs: [(ResolvedExpression, ResolvedExpression)] = []

Expand Down
3 changes: 3 additions & 0 deletions slox/ResolverError.swift
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ enum ResolverError: CustomStringConvertible, Equatable, LocalizedError {
case cannotBreakOutsideLoop
case cannotContinueOutsideLoop
case functionsMustHaveAParameterList
case cannotUseSplatOperatorOutOfContext

var description: String {
switch self {
Expand Down Expand Up @@ -50,6 +51,8 @@ enum ResolverError: CustomStringConvertible, Equatable, LocalizedError {
return "Can only `continue` from inside a `while` or `for` loop"
case .functionsMustHaveAParameterList:
return "Functions must have a parameter list"
case .cannotUseSplatOperatorOutOfContext:
return "Cannot use splat operator in this context"
}
}
}
50 changes: 50 additions & 0 deletions sloxTests/InterpreterTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -300,6 +300,37 @@ avg(1, 2, 3, 4, 5)
XCTAssertEqual(actual, expected)
}

func testInterpretSplattingIntoAnArgumentList() throws {
let input = """
var foo = [1, 2, 3];
fun sum(a, b, c) {
return a+b+c;
}
sum(*foo)
"""

let interpreter = Interpreter()
let actual = try interpreter.interpretRepl(source: input)
let expected: LoxValue = .int(6)
XCTAssertEqual(actual, expected)
}

func testInterpretSplattingWorksProperlyWithArityChecker() throws {
let input = """
fun sum(a, b, c) {
return a+b+c;
}
var foo = [1, 2, 3];
sum(*foo, 4, 5)
"""

let interpreter = Interpreter()
let expectedError = RuntimeError.wrongArity(3, 5)
XCTAssertThrowsError(try interpreter.interpretRepl(source: input)!) { actualError in
XCTAssertEqual(actualError as! RuntimeError, expectedError)
}
}

func testInterpretClassDeclarationAndInstantiation() throws {
let input = """
class Person {}
Expand Down Expand Up @@ -724,6 +755,25 @@ foo.reduce(0, fun(acc, n) { return acc+n; })
XCTAssertEqual(actual, expected)
}

func testInterpretSplattingAListIntoAnotherList() throws {
let input = """
var foo = [1, 2, 3];
[*foo, 4, 5, 6]
"""

let interpreter = Interpreter()
let actual = try interpreter.interpretRepl(source: input)
let expected = try interpreter.makeList(elements: [
.int(1),
.int(2),
.int(3),
.int(4),
.int(5),
.int(6),
])
XCTAssertEqual(actual, expected)
}

func testInterpretForLoopWithBreakStatement() throws {
let input = """
var sum = 0;
Expand Down
35 changes: 35 additions & 0 deletions sloxTests/ParserTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -1064,6 +1064,41 @@ final class ParserTests: XCTestCase {
XCTAssertEqual(actual, expected)
}

func testParseFunctionCallWithSplattedArgument() throws {
// add(*[1, 2, 3])
let tokens: [Token] = [
Token(type: .identifier, lexeme: "add", line: 1),
Token(type: .leftParen, lexeme: "(", line: 1),
Token(type: .star, lexeme: "*", line: 1),
Token(type: .leftBracket, lexeme: "[", line: 1),
Token(type: .int, lexeme: "1", line: 1),
Token(type: .comma, lexeme: ",", line: 1),
Token(type: .int, lexeme: "2", line: 1),
Token(type: .comma, lexeme: ",", line: 1),
Token(type: .int, lexeme: "3", line: 1),
Token(type: .rightBracket, lexeme: "]", line: 1),
Token(type: .rightParen, lexeme: ")", line: 1),
Token(type: .eof, lexeme: "", line: 1),
]

var parser = Parser(tokens: tokens)
let actual = try parser.parse()
let expected: [Statement] = [
.expression(
.call(
.variable(Token(type: .identifier, lexeme: "add", line: 1)),
Token(type: .rightParen, lexeme: ")", line: 1),
[
.splat(
.list([
.literal(.int(1)),
.literal(.int(2)),
.literal(.int(3)),
]))
])),
]
}

func testParseLambdaExpression() throws {
// fun (a, b) { return a + b; }
let tokens: [Token] = [
Expand Down
19 changes: 19 additions & 0 deletions sloxTests/ResolverTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -521,4 +521,23 @@ final class ResolverTests: XCTestCase {
XCTAssertEqual(actualError as! ResolverError, expectedError)
}
}

func testResolveTopLevelExpressionWithSplatOperator() throws {
// *[1, 2, 3]
let statements: [Statement] = [
.expression(
.splat(
.list([
.literal(.int(1)),
.literal(.int(2)),
.literal(.int(3)),
])))
]

var resolver = Resolver()
let expectedError = ResolverError.cannotUseSplatOperatorOutOfContext
XCTAssertThrowsError(try resolver.resolve(statements: statements)) { actualError in
XCTAssertEqual(actualError as! ResolverError, expectedError)
}
}
}