Skip to content

Commit

Permalink
feat: add reentrancy guard to accountManager (#237)
Browse files Browse the repository at this point in the history
* feat: add reentrancy gaurd to accountManager

* feat: fork reentrancy gaurd

* fix: update modifiers

* fix: move to internal _repay function
  • Loading branch information
r0ohafza committed Oct 18, 2022
1 parent 2fbdb54 commit 7bfa8d4
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 17 deletions.
50 changes: 33 additions & 17 deletions src/core/AccountManager.sol
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,14 @@ import {IRiskEngine} from "../interface/core/IRiskEngine.sol";
import {IAccountFactory} from "../interface/core/IAccountFactory.sol";
import {IAccountManager} from "../interface/core/IAccountManager.sol";
import {IControllerFacade} from "controller/core/IControllerFacade.sol";
import {ReentrancyGuard} from "../utils/ReentrancyGuard.sol";

/**
@title Account Manager
@notice Sentiment Account Manager,
All account interactions go via the account manager
*/
contract AccountManager is Pausable, IAccountManager {
contract AccountManager is ReentrancyGuard, Pausable, IAccountManager {
using Helpers for address;

/* -------------------------------------------------------------------------- */
Expand Down Expand Up @@ -67,6 +68,7 @@ contract AccountManager is Pausable, IAccountManager {
*/
function init(IRegistry _registry) external {
if (initialized) revert Errors.ContractAlreadyInitialized();
locked = 1;
initialized = true;
initPausable(msg.sender);
registry = _registry;
Expand All @@ -87,7 +89,7 @@ contract AccountManager is Pausable, IAccountManager {
Emits AccountAssigned(account, owner) event
@param owner Owner of the newly opened account
*/
function openAccount(address owner) external whenNotPaused {
function openAccount(address owner) external nonReentrant whenNotPaused {
if (owner == address(0)) revert Errors.ZeroAddress();
address account;
uint length = inactiveAccountsOf[owner].length;
Expand All @@ -110,7 +112,7 @@ contract AccountManager is Pausable, IAccountManager {
Emits AccountClosed(account, owner) event
@param _account Address of account to be closed
*/
function closeAccount(address _account) public onlyOwner(_account) {
function closeAccount(address _account) public nonReentrant onlyOwner(_account) {
IAccount account = IAccount(_account);
if (account.activationBlock() == block.number)
revert Errors.AccountDeactivationFailure();
Expand All @@ -129,6 +131,7 @@ contract AccountManager is Pausable, IAccountManager {
function depositEth(address account)
external
payable
nonReentrant
whenNotPaused
onlyOwner(account)
{
Expand All @@ -144,6 +147,7 @@ contract AccountManager is Pausable, IAccountManager {
*/
function withdrawEth(address account, uint amt)
external
nonReentrant
onlyOwner(account)
{
if(!riskEngine.isWithdrawAllowed(account, address(0), amt))
Expand All @@ -161,6 +165,7 @@ contract AccountManager is Pausable, IAccountManager {
*/
function deposit(address account, address token, uint amt)
external
nonReentrant
whenNotPaused
onlyOwner(account)
{
Expand All @@ -182,6 +187,7 @@ contract AccountManager is Pausable, IAccountManager {
*/
function withdraw(address account, address token, uint amt)
external
nonReentrant
onlyOwner(account)
{
if (!riskEngine.isWithdrawAllowed(account, token, amt))
Expand All @@ -202,6 +208,7 @@ contract AccountManager is Pausable, IAccountManager {
*/
function borrow(address account, address token, uint amt)
external
nonReentrant
whenNotPaused
onlyOwner(account)
{
Expand All @@ -226,19 +233,10 @@ contract AccountManager is Pausable, IAccountManager {
*/
function repay(address account, address token, uint amt)
public
nonReentrant
onlyOwner(account)
{
ILToken LToken = ILToken(registry.LTokenFor(token));
if (address(LToken) == address(0))
revert Errors.LTokenUnavailable();
LToken.updateState();
if (amt == type(uint256).max) amt = LToken.getBorrowBalance(account);
account.withdraw(address(LToken), token, amt);
if (LToken.collectFrom(account, amt))
IAccount(account).removeBorrow(token);
if (IERC20(token).balanceOf(account) == 0)
IAccount(account).removeAsset(token);
emit Repay(account, msg.sender, token, amt);
_repay(account, token, amt);
}

/**
Expand All @@ -247,7 +245,7 @@ contract AccountManager is Pausable, IAccountManager {
Emits AccountLiquidated(account, owner) event
@param account Address of account
*/
function liquidate(address account) external {
function liquidate(address account) external nonReentrant {
if (riskEngine.isAccountHealthy(account))
revert Errors.AccountNotLiquidatable();
_liquidate(account);
Expand All @@ -270,6 +268,7 @@ contract AccountManager is Pausable, IAccountManager {
uint amt
)
external
nonReentrant
onlyOwner(account)
{
if(address(controller.controllerFor(spender)) == address(0))
Expand All @@ -294,6 +293,7 @@ contract AccountManager is Pausable, IAccountManager {
bytes calldata data
)
external
nonReentrant
onlyOwner(account)
{
bool isAllowed;
Expand All @@ -315,10 +315,10 @@ contract AccountManager is Pausable, IAccountManager {
@notice Settles an account by repaying all the loans
@param account Address of account
*/
function settle(address account) external onlyOwner(account) {
function settle(address account) external nonReentrant onlyOwner(account) {
address[] memory borrows = IAccount(account).getBorrows();
for (uint i; i < borrows.length; i++) {
repay(account, borrows[i], type(uint).max);
_repay(account, borrows[i], type(uint).max);
}
}

Expand Down Expand Up @@ -381,6 +381,22 @@ contract AccountManager is Pausable, IAccountManager {
account.sweepTo(msg.sender);
}

function _repay(address account, address token, uint amt)
internal
{
ILToken LToken = ILToken(registry.LTokenFor(token));
if (address(LToken) == address(0))
revert Errors.LTokenUnavailable();
LToken.updateState();
if (amt == type(uint256).max) amt = LToken.getBorrowBalance(account);
account.withdraw(address(LToken), token, amt);
if (LToken.collectFrom(account, amt))
IAccount(account).removeBorrow(token);
if (IERC20(token).balanceOf(account) == 0)
IAccount(account).removeAsset(token);
emit Repay(account, msg.sender, token, amt);
}

/* -------------------------------------------------------------------------- */
/* ADMIN FUNCTIONS */
/* -------------------------------------------------------------------------- */
Expand Down
18 changes: 18 additions & 0 deletions src/utils/ReentrancyGuard.sol
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
// SPDX-License-Identifier: AGPL-3.0-only
pragma solidity >=0.8.0;

/// @notice Gas optimized reentrancy protection for smart contracts.
/// @author Modified from Solmate (https://github.com/transmissions11/solmate/blob/main/src/utils/ReentrancyGuard.sol)
contract ReentrancyGuard {
uint256 internal locked;

modifier nonReentrant() virtual {
require(locked == 1, "REENTRANCY");

locked = 2;

_;

locked = 1;
}
}

0 comments on commit 7bfa8d4

Please sign in to comment.