Skip to content
This repository has been archived by the owner on Jan 5, 2024. It is now read-only.

Commit

Permalink
🪡 Fix typing issues.
Browse files Browse the repository at this point in the history
This should silence warnings.
  • Loading branch information
rentruewang committed Jan 5, 2024
1 parent bb27b7b commit 954af8d
Show file tree
Hide file tree
Showing 4 changed files with 7 additions and 4,457 deletions.
5 changes: 2 additions & 3 deletions perbert/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,7 @@

import os
from enum import Enum

from typing_extensions import Self
from typing import Any


class StrEnum(str, Enum):
Expand All @@ -13,7 +12,7 @@ def __str__(self) -> str:
def __hash__(self) -> int:
return hash(str(self))

def __eq__(self, other: Self | str) -> bool:
def __eq__(self, other: Any) -> bool:
return str(self) == str(other)


Expand Down
6 changes: 4 additions & 2 deletions perbert/models/models.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
# pyright: reportIncompatibleMethodOverride=false
from __future__ import annotations

import typing
Expand Down Expand Up @@ -58,8 +59,8 @@ def __init__(self, cfg: DictConfig) -> None:
def bert_config(self) -> BertConfig:
return self.lm.bert.config

def forward(self, **kwargs: Any) -> BertOutput:
return self.lm(**kwargs)
def forward(self, *args: Any, **kwargs: Any) -> BertOutput:
return self.lm(*args, **kwargs)

def _step(self, batch: BatchEncoding, batch_idx: int, name: str) -> Tensor:
loguru.logger.trace("{} step batch: {}", name, batch_idx)
Expand Down Expand Up @@ -91,6 +92,7 @@ def training_step(self, batch: BatchEncoding, batch_idx: int) -> Tensor:

@torch.no_grad()
def test_step(self, batch: BatchEncoding, batch_idx: int) -> Tensor:
super().test_step
return self._step(batch, batch_idx=batch_idx, name="test")

@torch.no_grad()
Expand Down
Loading

0 comments on commit 954af8d

Please sign in to comment.