Skip to content

Commit

Permalink
feat(contacts): use bitmap for skipped messages (#893)
Browse files Browse the repository at this point in the history
Co-authored-by: zimpha <zimpha@users.noreply.github.com>
  • Loading branch information
zimpha and zimpha committed Sep 7, 2023
1 parent 49b72bd commit 7559dc4
Show file tree
Hide file tree
Showing 5 changed files with 103 additions and 31 deletions.
47 changes: 34 additions & 13 deletions contracts/integration-test/L1MessageQueue.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -264,8 +264,8 @@ describe("L1MessageQueue", async () => {
});

it("should succeed", async () => {
// append 100 messages
for (let i = 0; i < 100; i++) {
// append 512 messages
for (let i = 0; i < 256 * 2; i++) {
await queue.connect(messenger).appendCrossDomainMessage(constants.AddressZero, 1000000, "0x");
}

Expand All @@ -274,7 +274,8 @@ describe("L1MessageQueue", async () => {
.to.emit(queue, "DequeueTransaction")
.withArgs(0, 50, 0);
for (let i = 0; i < 50; i++) {
expect(await queue.getCrossDomainMessage(i)).to.eq(constants.HashZero);
expect(await queue.isMessageSkipped(i)).to.eq(false);
expect(await queue.isMessageDropped(i)).to.eq(false);
}
expect(await queue.pendingQueueIndex()).to.eq(50);

Expand All @@ -284,7 +285,8 @@ describe("L1MessageQueue", async () => {
.withArgs(50, 10, 1023);
expect(await queue.pendingQueueIndex()).to.eq(60);
for (let i = 50; i < 60; i++) {
expect(BigNumber.from(await queue.getCrossDomainMessage(i))).to.gt(constants.Zero);
expect(await queue.isMessageSkipped(i)).to.eq(true);
expect(await queue.isMessageDropped(i)).to.eq(false);
}

// pop 20 messages, skip first 5
Expand All @@ -293,10 +295,27 @@ describe("L1MessageQueue", async () => {
.withArgs(60, 20, 31);
expect(await queue.pendingQueueIndex()).to.eq(80);
for (let i = 60; i < 65; i++) {
expect(BigNumber.from(await queue.getCrossDomainMessage(i))).to.gt(constants.Zero);
expect(await queue.isMessageSkipped(i)).to.eq(true);
expect(await queue.isMessageDropped(i)).to.eq(false);
}
for (let i = 65; i < 80; i++) {
expect(await queue.getCrossDomainMessage(i)).to.eq(constants.HashZero);
expect(await queue.isMessageSkipped(i)).to.eq(false);
expect(await queue.isMessageDropped(i)).to.eq(false);
}

// pop 256 messages with random skip
const bitmap = BigNumber.from("0x496525059c3f33758d17030403e45afe067b8a0ae1317cda0487fd2932cbea1a");
const tx = await queue.connect(scrollChain).popCrossDomainMessage(80, 256, bitmap);
await expect(tx).to.emit(queue, "DequeueTransaction").withArgs(80, 256, bitmap);
console.log("gas used:", (await tx.wait()).gasUsed.toString());
for (let i = 80; i < 80 + 256; i++) {
expect(await queue.isMessageSkipped(i)).to.eq(
bitmap
.shr(i - 80)
.and(1)
.eq(1)
);
expect(await queue.isMessageDropped(i)).to.eq(false);
}
});
});
Expand All @@ -308,7 +327,7 @@ describe("L1MessageQueue", async () => {
);
});

it("should revert, when drop executed message", async () => {
it("should revert, when drop non-skipped message", async () => {
// append 10 messages
for (let i = 0; i < 10; i++) {
await queue.connect(messenger).appendCrossDomainMessage(constants.AddressZero, 1000000, "0x");
Expand All @@ -318,14 +337,13 @@ describe("L1MessageQueue", async () => {
.to.emit(queue, "DequeueTransaction")
.withArgs(0, 5, 0);
for (let i = 0; i < 5; i++) {
expect(await queue.getCrossDomainMessage(i)).to.eq(constants.HashZero);
expect(await queue.isMessageSkipped(i)).to.eq(false);
expect(await queue.isMessageDropped(i)).to.eq(false);
}
expect(await queue.pendingQueueIndex()).to.eq(5);

for (let i = 0; i < 5; i++) {
await expect(queue.connect(messenger).dropCrossDomainMessage(i)).to.revertedWith(
"message already dropped or executed"
);
await expect(queue.connect(messenger).dropCrossDomainMessage(i)).to.revertedWith("drop non-skipped message");
}

// drop pending message
Expand All @@ -345,9 +363,12 @@ describe("L1MessageQueue", async () => {
.withArgs(0, 10, 0x3ff);

for (let i = 0; i < 10; i++) {
expect(BigNumber.from(await queue.getCrossDomainMessage(i))).to.gt(constants.Zero);
expect(await queue.isMessageSkipped(i)).to.eq(true);
expect(await queue.isMessageDropped(i)).to.eq(false);
await expect(queue.connect(messenger).dropCrossDomainMessage(i)).to.emit(queue, "DropTransaction").withArgs(i);
expect(await queue.getCrossDomainMessage(i)).to.eq(constants.HashZero);
await expect(queue.connect(messenger).dropCrossDomainMessage(i)).to.revertedWith("message already dropped");
expect(await queue.isMessageSkipped(i)).to.eq(true);
expect(await queue.isMessageDropped(i)).to.eq(true);
}
});
});
Expand Down
8 changes: 8 additions & 0 deletions contracts/src/L1/rollup/IL1MessageQueue.sol
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,14 @@ interface IL1MessageQueue {
bytes calldata data
) external view returns (bytes32);

/// @notice Return whether the message is skipped.
/// @param queueIndex The queue index of the message to check.
function isMessageSkipped(uint256 queueIndex) external view returns (bool);

/// @notice Return whether the message is dropped.
/// @param queueIndex The queue index of the message to check.
function isMessageDropped(uint256 queueIndex) external view returns (bool);

/*****************************
* Public Mutating Functions *
*****************************/
Expand Down
47 changes: 41 additions & 6 deletions contracts/src/L1/rollup/L1MessageQueue.sol
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
pragma solidity =0.8.16;

import {OwnableUpgradeable} from "@openzeppelin/contracts-upgradeable/access/OwnableUpgradeable.sol";
import {BitMapsUpgradeable} from "@openzeppelin/contracts-upgradeable/utils/structs/BitMapsUpgradeable.sol";

import {IL2GasPriceOracle} from "./IL2GasPriceOracle.sol";
import {IL1MessageQueue} from "./IL1MessageQueue.sol";
Expand All @@ -17,6 +18,8 @@ import {AddressAliasHelper} from "../../libraries/common/AddressAliasHelper.sol"
/// @notice This contract will hold all L1 to L2 messages.
/// Each appended message is assigned with a unique and increasing `uint256` index.
contract L1MessageQueue is OwnableUpgradeable, IL1MessageQueue {
using BitMapsUpgradeable for BitMapsUpgradeable.BitMap;

/**********
* Events *
**********/
Expand Down Expand Up @@ -61,6 +64,12 @@ contract L1MessageQueue is OwnableUpgradeable, IL1MessageQueue {
/// @notice The max gas limit of L1 transactions.
uint256 public maxGasLimit;

/// @dev The bitmap for skipped messages.
BitMapsUpgradeable.BitMap private droppedMessageBitmap;

/// @dev The bitmap for skipped messages, where `skippedMessageBitmap[i]` keeps the bits from `[i*256, (i+1)*256)`.
mapping(uint256 => uint256) private skippedMessageBitmap;

/**********************
* Function Modifiers *
**********************/
Expand Down Expand Up @@ -256,6 +265,19 @@ contract L1MessageQueue is OwnableUpgradeable, IL1MessageQueue {
return hash;
}

/// @inheritdoc IL1MessageQueue
function isMessageSkipped(uint256 _queueIndex) external view returns (bool) {
if (_queueIndex >= pendingQueueIndex) return false;

return _isMessageSkipped(_queueIndex);
}

/// @inheritdoc IL1MessageQueue
function isMessageDropped(uint256 _queueIndex) external view returns (bool) {
// it should be a skipped message first.
return _isMessageSkipped(_queueIndex) && droppedMessageBitmap.get(_queueIndex);
}

/*****************************
* Public Mutating Functions *
*****************************/
Expand Down Expand Up @@ -305,10 +327,15 @@ contract L1MessageQueue is OwnableUpgradeable, IL1MessageQueue {
require(pendingQueueIndex == _startIndex, "start index mismatch");

unchecked {
for (uint256 i = 0; i < _count; i++) {
if ((_skippedBitmap >> i) & 1 == 0) {
messageQueue[_startIndex + i] = bytes32(0);
}
// clear extra bits in `_skippedBitmap`, and if _count = 256, it's designed to overflow.
uint256 mask = (1 << _count) - 1;
_skippedBitmap &= mask;

uint256 bucket = _startIndex >> 8;
uint256 offset = _startIndex & 0xff;
skippedMessageBitmap[bucket] |= _skippedBitmap << offset;
if (offset + _count > 256) {
skippedMessageBitmap[bucket + 1] = _skippedBitmap >> (256 - offset);
}

pendingQueueIndex = _startIndex + _count;
Expand All @@ -320,9 +347,10 @@ contract L1MessageQueue is OwnableUpgradeable, IL1MessageQueue {
/// @inheritdoc IL1MessageQueue
function dropCrossDomainMessage(uint256 _index) external onlyMessenger {
require(_index < pendingQueueIndex, "cannot drop pending message");
require(messageQueue[_index] != bytes32(0), "message already dropped or executed");

messageQueue[_index] = bytes32(0);
require(_isMessageSkipped(_index), "drop non-skipped message");
require(!droppedMessageBitmap.get(_index), "message already dropped");
droppedMessageBitmap.set(_index);

emit DropTransaction(_index);
}
Expand Down Expand Up @@ -393,4 +421,11 @@ contract L1MessageQueue is OwnableUpgradeable, IL1MessageQueue {
uint256 intrinsicGas = calculateIntrinsicGasFee(_calldata);
require(_gasLimit >= intrinsicGas, "Insufficient gas limit, must be above intrinsic gas");
}

/// @dev Returns whether the bit at `index` is set.
function _isMessageSkipped(uint256 index) internal view returns (bool) {
uint256 bucket = index >> 8;
uint256 mask = 1 << (index & 0xff);
return skippedMessageBitmap[bucket] & mask != 0;
}
}
20 changes: 14 additions & 6 deletions contracts/src/test/L1ScrollMessengerTest.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -271,13 +271,15 @@ contract L1ScrollMessengerTest is L1GatewayTestBase {
assertEq(messageQueue.pendingQueueIndex(), 2);
hevm.stopPrank();
for (uint256 i = 0; i < 2; ++i) {
assertGt(uint256(messageQueue.getCrossDomainMessage(i)), 0);
assertBoolEq(messageQueue.isMessageSkipped(i), true);
assertBoolEq(messageQueue.isMessageDropped(i), false);
}
hevm.expectEmit(false, false, false, true);
emit OnDropMessageCalled(new bytes(0));
l1Messenger.dropMessage(address(this), address(0), 0, 0, new bytes(0));
for (uint256 i = 0; i < 2; ++i) {
assertEq(messageQueue.getCrossDomainMessage(i), bytes32(0));
assertBoolEq(messageQueue.isMessageSkipped(i), true);
assertBoolEq(messageQueue.isMessageDropped(i), true);
}

// send one message with nonce 2 and replay 3 times
Expand All @@ -293,9 +295,13 @@ contract L1ScrollMessengerTest is L1GatewayTestBase {
messageQueue.popCrossDomainMessage(2, 4, 0x7);
assertEq(messageQueue.pendingQueueIndex(), 6);
hevm.stopPrank();
for (uint256 i = 2; i < 6; i++) {
assertBoolEq(messageQueue.isMessageSkipped(i), i < 5);
assertBoolEq(messageQueue.isMessageDropped(i), false);
}

// message already dropped or executed, revert
hevm.expectRevert("message already dropped or executed");
// drop non-skipped message, revert
hevm.expectRevert("drop non-skipped message");
l1Messenger.dropMessage(address(this), address(0), 0, 2, new bytes(0));

// send one message with nonce 6 and replay 4 times
Expand All @@ -311,13 +317,15 @@ contract L1ScrollMessengerTest is L1GatewayTestBase {
assertEq(messageQueue.pendingQueueIndex(), 11);
hevm.stopPrank();
for (uint256 i = 6; i < 11; ++i) {
assertGt(uint256(messageQueue.getCrossDomainMessage(i)), 0);
assertBoolEq(messageQueue.isMessageSkipped(i), true);
assertBoolEq(messageQueue.isMessageDropped(i), false);
}
hevm.expectEmit(false, false, false, true);
emit OnDropMessageCalled(new bytes(0));
l1Messenger.dropMessage(address(this), address(0), 0, 6, new bytes(0));
for (uint256 i = 6; i < 11; ++i) {
assertEq(messageQueue.getCrossDomainMessage(i), bytes32(0));
assertBoolEq(messageQueue.isMessageSkipped(i), true);
assertBoolEq(messageQueue.isMessageDropped(i), true);
}

// Message already dropped, revert
Expand Down
12 changes: 6 additions & 6 deletions contracts/src/test/ScrollChain.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -328,7 +328,7 @@ contract ScrollChainTest is DSTestPlus {
assertEq(rollup.finalizedStateRoots(1), bytes32(uint256(2)));
assertEq(rollup.withdrawRoots(1), bytes32(uint256(3)));
assertEq(rollup.lastFinalizedBatchIndex(), 1);
assertEq(messageQueue.getCrossDomainMessage(0), bytes32(0));
assertBoolEq(messageQueue.isMessageSkipped(0), false);
assertEq(messageQueue.pendingQueueIndex(), 1);

// commit batch2 with two chunks, correctly
Expand Down Expand Up @@ -462,22 +462,22 @@ contract ScrollChainTest is DSTestPlus {
assertEq(messageQueue.pendingQueueIndex(), 265);
// 1 ~ 4, zero
for (uint256 i = 1; i < 4; i++) {
assertEq(messageQueue.getCrossDomainMessage(i), bytes32(0));
assertBoolEq(messageQueue.isMessageSkipped(i), false);
}
// 4 ~ 9, even is nonzero, odd is zero
for (uint256 i = 4; i < 9; i++) {
if (i % 2 == 1 || i == 8) {
assertEq(messageQueue.getCrossDomainMessage(i), bytes32(0));
assertBoolEq(messageQueue.isMessageSkipped(i), false);
} else {
assertGt(uint256(messageQueue.getCrossDomainMessage(i)), 0);
assertBoolEq(messageQueue.isMessageSkipped(i), true);
}
}
// 9 ~ 265, even is nonzero, odd is zero
for (uint256 i = 9; i < 265; i++) {
if (i % 2 == 1 || i == 264) {
assertEq(messageQueue.getCrossDomainMessage(i), bytes32(0));
assertBoolEq(messageQueue.isMessageSkipped(i), false);
} else {
assertGt(uint256(messageQueue.getCrossDomainMessage(i)), 0);
assertBoolEq(messageQueue.isMessageSkipped(i), true);
}
}
}
Expand Down

0 comments on commit 7559dc4

Please sign in to comment.