Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(contacts): use bitmap for skipped messages #893

Merged
merged 8 commits into from
Sep 7, 2023
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 1 addition & 1 deletion common/version/version.go
Expand Up @@ -7,7 +7,7 @@ import (
"strings"
)

var tag = "v4.2.8"
var tag = "v4.2.9"

var commit = func() string {
if info, ok := debug.ReadBuildInfo(); ok {
Expand Down
47 changes: 34 additions & 13 deletions contracts/integration-test/L1MessageQueue.spec.ts
Expand Up @@ -264,8 +264,8 @@ describe("L1MessageQueue", async () => {
});

it("should succeed", async () => {
// append 100 messages
for (let i = 0; i < 100; i++) {
// append 512 messages
for (let i = 0; i < 256 * 2; i++) {
await queue.connect(messenger).appendCrossDomainMessage(constants.AddressZero, 1000000, "0x");
}

Expand All @@ -274,7 +274,8 @@ describe("L1MessageQueue", async () => {
.to.emit(queue, "DequeueTransaction")
.withArgs(0, 50, 0);
for (let i = 0; i < 50; i++) {
expect(await queue.getCrossDomainMessage(i)).to.eq(constants.HashZero);
expect(await queue.isMessageSkipped(i)).to.eq(false);
expect(await queue.isMessageDropped(i)).to.eq(false);
}
expect(await queue.pendingQueueIndex()).to.eq(50);

Expand All @@ -284,7 +285,8 @@ describe("L1MessageQueue", async () => {
.withArgs(50, 10, 1023);
expect(await queue.pendingQueueIndex()).to.eq(60);
for (let i = 50; i < 60; i++) {
expect(BigNumber.from(await queue.getCrossDomainMessage(i))).to.gt(constants.Zero);
expect(await queue.isMessageSkipped(i)).to.eq(true);
expect(await queue.isMessageDropped(i)).to.eq(false);
}

// pop 20 messages, skip first 5
Expand All @@ -293,10 +295,27 @@ describe("L1MessageQueue", async () => {
.withArgs(60, 20, 31);
expect(await queue.pendingQueueIndex()).to.eq(80);
for (let i = 60; i < 65; i++) {
expect(BigNumber.from(await queue.getCrossDomainMessage(i))).to.gt(constants.Zero);
expect(await queue.isMessageSkipped(i)).to.eq(true);
expect(await queue.isMessageDropped(i)).to.eq(false);
}
for (let i = 65; i < 80; i++) {
expect(await queue.getCrossDomainMessage(i)).to.eq(constants.HashZero);
expect(await queue.isMessageSkipped(i)).to.eq(false);
expect(await queue.isMessageDropped(i)).to.eq(false);
}

// pop 256 messages with random skip
const bitmap = BigNumber.from("0x496525059c3f33758d17030403e45afe067b8a0ae1317cda0487fd2932cbea1a");
const tx = await queue.connect(scrollChain).popCrossDomainMessage(80, 256, bitmap);
await expect(tx).to.emit(queue, "DequeueTransaction").withArgs(80, 256, bitmap);
console.log("gas used:", (await tx.wait()).gasUsed.toString());
for (let i = 80; i < 80 + 256; i++) {
expect(await queue.isMessageSkipped(i)).to.eq(
bitmap
.shr(i - 80)
.and(1)
.eq(1)
);
expect(await queue.isMessageDropped(i)).to.eq(false);
}
});
});
Expand All @@ -308,7 +327,7 @@ describe("L1MessageQueue", async () => {
);
});

it("should revert, when drop executed message", async () => {
it("should revert, when drop non-skipped message", async () => {
// append 10 messages
for (let i = 0; i < 10; i++) {
await queue.connect(messenger).appendCrossDomainMessage(constants.AddressZero, 1000000, "0x");
Expand All @@ -318,14 +337,13 @@ describe("L1MessageQueue", async () => {
.to.emit(queue, "DequeueTransaction")
.withArgs(0, 5, 0);
for (let i = 0; i < 5; i++) {
expect(await queue.getCrossDomainMessage(i)).to.eq(constants.HashZero);
expect(await queue.isMessageSkipped(i)).to.eq(false);
expect(await queue.isMessageDropped(i)).to.eq(false);
}
expect(await queue.pendingQueueIndex()).to.eq(5);

for (let i = 0; i < 5; i++) {
await expect(queue.connect(messenger).dropCrossDomainMessage(i)).to.revertedWith(
"message already dropped or executed"
);
await expect(queue.connect(messenger).dropCrossDomainMessage(i)).to.revertedWith("drop non-skipped message");
}

// drop pending message
Expand All @@ -345,9 +363,12 @@ describe("L1MessageQueue", async () => {
.withArgs(0, 10, 0x3ff);

for (let i = 0; i < 10; i++) {
expect(BigNumber.from(await queue.getCrossDomainMessage(i))).to.gt(constants.Zero);
expect(await queue.isMessageSkipped(i)).to.eq(true);
expect(await queue.isMessageDropped(i)).to.eq(false);
await expect(queue.connect(messenger).dropCrossDomainMessage(i)).to.emit(queue, "DropTransaction").withArgs(i);
expect(await queue.getCrossDomainMessage(i)).to.eq(constants.HashZero);
await expect(queue.connect(messenger).dropCrossDomainMessage(i)).to.revertedWith("message already dropped");
expect(await queue.isMessageSkipped(i)).to.eq(true);
expect(await queue.isMessageDropped(i)).to.eq(true);
}
});
});
Expand Down
8 changes: 8 additions & 0 deletions contracts/src/L1/rollup/IL1MessageQueue.sol
Expand Up @@ -72,6 +72,14 @@ interface IL1MessageQueue {
bytes calldata data
) external view returns (bytes32);

/// @notice Return whether the message is skipped.
/// @param queueIndex The queue index of the message to check.
function isMessageSkipped(uint256 queueIndex) external view returns (bool);

/// @notice Return whether the message is dropped.
/// @param queueIndex The queue index of the message to check.
function isMessageDropped(uint256 queueIndex) external view returns (bool);

/*****************************
* Public Mutating Functions *
*****************************/
Expand Down
64 changes: 58 additions & 6 deletions contracts/src/L1/rollup/L1MessageQueue.sol
Expand Up @@ -3,6 +3,7 @@
pragma solidity =0.8.16;

import {OwnableUpgradeable} from "@openzeppelin/contracts-upgradeable/access/OwnableUpgradeable.sol";
import {BitMapsUpgradeable} from "@openzeppelin/contracts-upgradeable/utils/structs/BitMapsUpgradeable.sol";

import {IL2GasPriceOracle} from "./IL2GasPriceOracle.sol";
import {IL1MessageQueue} from "./IL1MessageQueue.sol";
Expand All @@ -17,6 +18,8 @@ import {AddressAliasHelper} from "../../libraries/common/AddressAliasHelper.sol"
/// @notice This contract will hold all L1 to L2 messages.
/// Each appended message is assigned with a unique and increasing `uint256` index.
contract L1MessageQueue is OwnableUpgradeable, IL1MessageQueue {
using BitMapsUpgradeable for BitMapsUpgradeable.BitMap;

/**********
* Events *
**********/
Expand Down Expand Up @@ -61,6 +64,15 @@ contract L1MessageQueue is OwnableUpgradeable, IL1MessageQueue {
/// @notice The max gas limit of L1 transactions.
uint256 public maxGasLimit;

/// @dev The bitmap for skipped messages.
BitMapsUpgradeable.BitMap private droppedMessageBitmap;

/// @dev The bitmap for skipped messages, where `skippedMessageBitmap[i]` keeps the bits from `[i*256, (i+1)*256)`.
mapping(uint256 => uint256) private skippedMessageBitmap;

/// @notice The start queue index when we enable the `skippedMessageBitmap`.
uint256 public bitmapEnabledQueueIndex;
Thegaram marked this conversation as resolved.
Show resolved Hide resolved

/**********************
* Function Modifiers *
**********************/
Expand Down Expand Up @@ -94,6 +106,10 @@ contract L1MessageQueue is OwnableUpgradeable, IL1MessageQueue {
maxGasLimit = _maxGasLimit;
}

function initializeV2() external reinitializer(2) {
Thegaram marked this conversation as resolved.
Show resolved Hide resolved
bitmapEnabledQueueIndex = pendingQueueIndex;
}

/*************************
* Public View Functions *
*************************/
Expand Down Expand Up @@ -256,6 +272,29 @@ contract L1MessageQueue is OwnableUpgradeable, IL1MessageQueue {
return hash;
}

/// @inheritdoc IL1MessageQueue
/// @dev Before the `bitmapEnabledQueueIndex`, if the message is dropped, the function will return `false`.
function isMessageSkipped(uint256 _queueIndex) external view returns (bool) {
if (_queueIndex >= pendingQueueIndex) return false;

if (_queueIndex < bitmapEnabledQueueIndex) {
Thegaram marked this conversation as resolved.
Show resolved Hide resolved
return messageQueue[_queueIndex] != bytes32(0);
} else {
return _getBitmap(_queueIndex);
}
}

/// @inheritdoc IL1MessageQueue
function isMessageDropped(uint256 _queueIndex) external view returns (bool) {
if (_queueIndex < bitmapEnabledQueueIndex) {
// @note This will also include the executed messages.
return messageQueue[_queueIndex] == bytes32(0);
} else {
// it should be a skipped message first.
return _getBitmap(_queueIndex) && droppedMessageBitmap.get(_queueIndex);
}
}

/*****************************
* Public Mutating Functions *
*****************************/
Expand Down Expand Up @@ -305,10 +344,15 @@ contract L1MessageQueue is OwnableUpgradeable, IL1MessageQueue {
require(pendingQueueIndex == _startIndex, "start index mismatch");

unchecked {
for (uint256 i = 0; i < _count; i++) {
if ((_skippedBitmap >> i) & 1 == 0) {
messageQueue[_startIndex + i] = bytes32(0);
}
// clear extra bits in `_skippedBitmap`, and if _count = 256, the overflow is designed
zimpha marked this conversation as resolved.
Show resolved Hide resolved
uint256 mask = (1 << _count) - 1;
_skippedBitmap &= mask;

uint256 bucket = _startIndex >> 8;
uint256 offset = _startIndex & 0xff;
skippedMessageBitmap[bucket] |= _skippedBitmap << offset;
if (offset + _count > 256) {
Thegaram marked this conversation as resolved.
Show resolved Hide resolved
skippedMessageBitmap[bucket + 1] = _skippedBitmap >> (256 - offset);
}

pendingQueueIndex = _startIndex + _count;
Expand All @@ -320,9 +364,10 @@ contract L1MessageQueue is OwnableUpgradeable, IL1MessageQueue {
/// @inheritdoc IL1MessageQueue
function dropCrossDomainMessage(uint256 _index) external onlyMessenger {
require(_index < pendingQueueIndex, "cannot drop pending message");
require(messageQueue[_index] != bytes32(0), "message already dropped or executed");

messageQueue[_index] = bytes32(0);
require(_getBitmap(_index), "drop non-skipped message");
require(!droppedMessageBitmap.get(_index), "message already dropped");
droppedMessageBitmap.set(_index);

emit DropTransaction(_index);
}
Expand Down Expand Up @@ -393,4 +438,11 @@ contract L1MessageQueue is OwnableUpgradeable, IL1MessageQueue {
uint256 intrinsicGas = calculateIntrinsicGasFee(_calldata);
require(_gasLimit >= intrinsicGas, "Insufficient gas limit, must be above intrinsic gas");
}

/// @dev Returns whether the bit at `index` is set.
function _getBitmap(uint256 index) internal view returns (bool) {
zimpha marked this conversation as resolved.
Show resolved Hide resolved
uint256 bucket = index >> 8;
uint256 mask = 1 << (index & 0xff);
return skippedMessageBitmap[bucket] & mask != 0;
}
}
20 changes: 14 additions & 6 deletions contracts/src/test/L1ScrollMessengerTest.t.sol
Expand Up @@ -271,13 +271,15 @@ contract L1ScrollMessengerTest is L1GatewayTestBase {
assertEq(messageQueue.pendingQueueIndex(), 2);
hevm.stopPrank();
for (uint256 i = 0; i < 2; ++i) {
assertGt(uint256(messageQueue.getCrossDomainMessage(i)), 0);
assertBoolEq(messageQueue.isMessageSkipped(i), true);
assertBoolEq(messageQueue.isMessageDropped(i), false);
}
hevm.expectEmit(false, false, false, true);
emit OnDropMessageCalled(new bytes(0));
l1Messenger.dropMessage(address(this), address(0), 0, 0, new bytes(0));
for (uint256 i = 0; i < 2; ++i) {
assertEq(messageQueue.getCrossDomainMessage(i), bytes32(0));
assertBoolEq(messageQueue.isMessageSkipped(i), true);
assertBoolEq(messageQueue.isMessageDropped(i), true);
}

// send one message with nonce 2 and replay 3 times
Expand All @@ -293,9 +295,13 @@ contract L1ScrollMessengerTest is L1GatewayTestBase {
messageQueue.popCrossDomainMessage(2, 4, 0x7);
assertEq(messageQueue.pendingQueueIndex(), 6);
hevm.stopPrank();
for (uint256 i = 2; i < 6; i++) {
assertBoolEq(messageQueue.isMessageSkipped(i), i < 5);
assertBoolEq(messageQueue.isMessageDropped(i), false);
}

// message already dropped or executed, revert
hevm.expectRevert("message already dropped or executed");
// drop non-skipped message, revert
hevm.expectRevert("drop non-skipped message");
l1Messenger.dropMessage(address(this), address(0), 0, 2, new bytes(0));

// send one message with nonce 6 and replay 4 times
Expand All @@ -311,13 +317,15 @@ contract L1ScrollMessengerTest is L1GatewayTestBase {
assertEq(messageQueue.pendingQueueIndex(), 11);
hevm.stopPrank();
for (uint256 i = 6; i < 11; ++i) {
assertGt(uint256(messageQueue.getCrossDomainMessage(i)), 0);
assertBoolEq(messageQueue.isMessageSkipped(i), true);
assertBoolEq(messageQueue.isMessageDropped(i), false);
}
hevm.expectEmit(false, false, false, true);
emit OnDropMessageCalled(new bytes(0));
l1Messenger.dropMessage(address(this), address(0), 0, 6, new bytes(0));
for (uint256 i = 6; i < 11; ++i) {
assertEq(messageQueue.getCrossDomainMessage(i), bytes32(0));
assertBoolEq(messageQueue.isMessageSkipped(i), true);
assertBoolEq(messageQueue.isMessageDropped(i), true);
}

// Message already dropped, revert
Expand Down
12 changes: 6 additions & 6 deletions contracts/src/test/ScrollChain.t.sol
Expand Up @@ -328,7 +328,7 @@ contract ScrollChainTest is DSTestPlus {
assertEq(rollup.finalizedStateRoots(1), bytes32(uint256(2)));
assertEq(rollup.withdrawRoots(1), bytes32(uint256(3)));
assertEq(rollup.lastFinalizedBatchIndex(), 1);
assertEq(messageQueue.getCrossDomainMessage(0), bytes32(0));
assertBoolEq(messageQueue.isMessageSkipped(0), false);
assertEq(messageQueue.pendingQueueIndex(), 1);

// commit batch2 with two chunks, correctly
Expand Down Expand Up @@ -453,22 +453,22 @@ contract ScrollChainTest is DSTestPlus {
assertEq(messageQueue.pendingQueueIndex(), 265);
// 1 ~ 4, zero
for (uint256 i = 1; i < 4; i++) {
assertEq(messageQueue.getCrossDomainMessage(i), bytes32(0));
assertBoolEq(messageQueue.isMessageSkipped(i), false);
}
// 4 ~ 9, even is nonzero, odd is zero
for (uint256 i = 4; i < 9; i++) {
if (i % 2 == 1 || i == 8) {
assertEq(messageQueue.getCrossDomainMessage(i), bytes32(0));
assertBoolEq(messageQueue.isMessageSkipped(i), false);
} else {
assertGt(uint256(messageQueue.getCrossDomainMessage(i)), 0);
assertBoolEq(messageQueue.isMessageSkipped(i), true);
}
}
// 9 ~ 265, even is nonzero, odd is zero
for (uint256 i = 9; i < 265; i++) {
if (i % 2 == 1 || i == 264) {
assertEq(messageQueue.getCrossDomainMessage(i), bytes32(0));
assertBoolEq(messageQueue.isMessageSkipped(i), false);
} else {
assertGt(uint256(messageQueue.getCrossDomainMessage(i)), 0);
assertBoolEq(messageQueue.isMessageSkipped(i), true);
}
}
}
Expand Down