This repository has been archived by the owner on Jan 5, 2024. It is now read-only.

🪻 Import from collections.abc instead of typing.
This is much better.
rentruewang committed Jan 5, 2024
1 parent b60d94c commit 57021ea
Showing 12 changed files with 277 additions and 79 deletions.
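
Taken together, the diffs below replace the deprecated typing aliases (List, Dict, Tuple, Type, Callable, Mapping, Sized) with the builtin generics and their collections.abc counterparts, the spelling PEP 585 recommends from Python 3.9 on. A rough before/after sketch of the pattern (illustrative names only, not code from this repository):

# Before: container aliases imported from typing (deprecated since Python 3.9)
from typing import Callable, Dict, List

def apply_all(fns: List[Callable[[str], str]], text: str) -> Dict[str, str]:
    return {fn.__name__: fn(text) for fn in fns}

# After: builtin generics, with the ABCs coming from collections.abc
from collections.abc import Callable

def apply_all(fns: list[Callable[[str], str]], text: str) -> dict[str, str]:
    return {fn.__name__: fn(text) for fn in fns}

For annotation-only uses this also works on interpreters older than 3.9 whenever the module starts with from __future__ import annotations, as most of the files touched here already do.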
2 changes: 1 addition & 1 deletion .github/workflows/typecheck.yaml
@@ -15,7 +15,7 @@ jobs:
- name: ⬇️ Python Poetry
uses: abatilo/actions-poetry@v2
with:
poetry-version: "1.5"
poetry-version: "1.7"

- name: ⬇️ Python Dependencies
run: poetry install
15 changes: 7 additions & 8 deletions perbert/constants.py
@@ -2,8 +2,7 @@

import os
from enum import Enum

-from typing_extensions import Self
+from typing import Any


class StrEnum(str, Enum):
@@ -13,7 +12,7 @@ def __str__(self) -> str:
def __hash__(self) -> int:
return hash(str(self))

-    def __eq__(self, other: Self | str) -> bool:
+    def __eq__(self, other: Any) -> bool:
return str(self) == str(other)


@@ -29,14 +28,14 @@ class Splits(StrEnum):


class CollatorType(StrEnum):
Token = "token"
WholeWord = "wholeword"
TOKEN = "token"
WHOLE_WORD = "wholeword"


class SchedulerAlgo(StrEnum):
Const = "constant"
Bert = "bert"
Step = "step"
CONST = "constant"
BERT = "bert"
STEP = "step"


if (_cpus := os.cpu_count()) is not None:
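
Two things change in constants.py: the enum members move to the conventional UPPER_CASE names, and __eq__ now accepts Any, matching the signature of object.__eq__ and dropping the typing_extensions.Self dependency. The string values themselves are untouched, so configuration strings such as "token" or "bert" keep resolving to the same members. A small usage sketch (illustrative, not part of the commit; it assumes StrEnum.__str__ returns the member's value, as the custom __eq__ suggests):

# Lookups by value and comparisons against plain strings still behave the same;
# only the Python-side attribute names are new.
assert CollatorType("wholeword") is CollatorType.WHOLE_WORD
assert CollatorType.TOKEN == "token"
assert SchedulerAlgo.BERT == "bert" and SchedulerAlgo.BERT != SchedulerAlgo.STEP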
12 changes: 6 additions & 6 deletions perbert/data/collators.py
@@ -1,7 +1,7 @@
from __future__ import annotations

import abc
-from typing import Any, Dict, List, Protocol, Type
+from typing import Any, Protocol

from transformers import DataCollatorForLanguageModeling, DataCollatorForWholeWordMask
from transformers.tokenization_utils import PreTrainedTokenizer
@@ -10,14 +10,14 @@

class Collator(Protocol):
@abc.abstractmethod
-    def __call__(self, encodings: List[Dict[str, Any]]) -> Dict[str, Any]:
+    def __call__(self, encodings: list[dict[str, Any]]) -> dict[str, Any]:
...


class HuggingfaceCollator(Collator):
def __init__(
self,
-        klass: Type[DataCollatorForLanguageModeling | DataCollatorForWholeWordMask],
+        klass: type[DataCollatorForLanguageModeling | DataCollatorForWholeWordMask],
*,
tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast,
mask_prob: float,
@@ -42,15 +42,15 @@ def _get_collator(
return_tensors="pt",
)

-    def __call__(self, encodings: List[Dict[str, Any]]) -> Dict[str, Any]:
+    def __call__(self, encodings: list[dict[str, Any]]) -> dict[str, Any]:
collator = self._get_collator()
return collator(encodings)


class DecayCollator(HuggingfaceCollator):
def __init__(
self,
-        klass: Type[DataCollatorForLanguageModeling | DataCollatorForWholeWordMask],
+        klass: type[DataCollatorForLanguageModeling | DataCollatorForWholeWordMask],
*,
tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast,
mask_prob: float,
@@ -69,7 +69,7 @@ def __init__(
self._eventual_mask_prob = self.mask_prob
self.mask_prob = 0

-    def __call__(self, encodings: List[Dict[str, Any]]) -> Dict[str, Any]:
+    def __call__(self, encodings: list[dict[str, Any]]) -> dict[str, Any]:
prob_diff = self._eventual_mask_prob - self.mask_prob
self.mask_prob = self._eventual_mask_prob - prob_diff * self.base
return super().__call__(encodings=encodings)
4 changes: 2 additions & 2 deletions perbert/data/datamodules.py
@@ -69,9 +69,9 @@ def _init_collator(self) -> Collator:

collator_type = CollatorType(self.cfg["model"]["lm"]["collator"])

-        if collator_type == CollatorType.Token:
+        if collator_type == CollatorType.TOKEN:
collator_cls = DataCollatorForLanguageModeling
-        elif collator_type == CollatorType.WholeWord:
+        elif collator_type == CollatorType.WHOLE_WORD:
collator_cls = DataCollatorForWholeWordMask
else:
raise ValueError(f"Collator type: {collator_type} not supported.")
19 changes: 10 additions & 9 deletions perbert/data/datasets/mappers.py
@@ -1,7 +1,8 @@
from __future__ import annotations

import abc
-from typing import Any, Callable, Dict, List, Protocol, TypeVar
+from collections.abc import Callable
+from typing import Any, Protocol, TypeVar

import loguru
import numpy as np
@@ -36,25 +37,25 @@ def map(
writer_batch_size: int = 1000,
num_proc: int | None = None,
load_from_cache_file: bool = False,
-        remove_columns: List[str] | None = None,
+        remove_columns: list[str] | None = None,
desc: str | None = None,
) -> Self:
...

@property
@abc.abstractmethod
-    def column_names(self) -> List[str] | Dict[str, List[str]]:
+    def column_names(self) -> list[str] | dict[str, list[str]]:
...

@abc.abstractmethod
-    def remove_columns(self, column_names: str | List[str]) -> Self:
+    def remove_columns(self, column_names: str | list[str]) -> Self:
...


-def _flat_column_names(mapper: Mappable) -> List[str]:
+def _flat_column_names(mapper: Mappable) -> list[str]:
columns = mapper.column_names

-    if isinstance(columns, List):
+    if isinstance(columns, list):
return columns

columns = sum(columns.values(), [])
@@ -96,10 +97,10 @@ def __init__(self, cfg: DictConfig, mapper: T) -> None:
self.max_length = data_cfg["max_length"]
self.tokenizer = AutoTokenizer.from_pretrained(data_cfg["tokenizer"])

-    def _filter_empty(self, entry: Dict[str, str]) -> bool:
+    def _filter_empty(self, entry: dict[str, str]) -> bool:
return len(entry["text"]) > 0

-    def _line_by_line(self, entries: Dict[str, str]) -> BatchEncoding:
+    def _line_by_line(self, entries: dict[str, str]) -> BatchEncoding:
return self.tokenizer(
entries["text"],
padding=self.padding,
@@ -109,7 +110,7 @@ def _line_by_line(self, entries: dict[str, str]) -> BatchEncoding:
return_tensors="np",
)

-    def _joined_lines(self, entries: Dict[str, List[str]]) -> Dict[str, List[ndarray]]:
+    def _joined_lines(self, entries: dict[str, list[str]]) -> dict[str, list[ndarray]]:
joined_line = " ".join(entries["text"])

tokenized = self.tokenizer(
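
mappers.py is the one file where the change matters at runtime rather than only in annotations: signatures such as dict[str, list[str]] are never evaluated because the module starts with from __future__ import annotations, but isinstance(columns, List) was a live runtime check and now tests against the real builtin list, the supported spelling going forward (the bare typing alias still happens to work today, but it is deprecated). A condensed sketch of that helper, paraphrased from the hunk above with the unshown tail assumed:

def _flat_column_names(mapper: Mappable) -> list[str]:
    columns = mapper.column_names      # either list[str] or dict[str, list[str]]
    if isinstance(columns, list):      # runtime check against the builtin type
        return columns
    return sum(columns.values(), [])   # flatten the per-split column lists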
5 changes: 3 additions & 2 deletions perbert/data/datasets/wrappers.py
@@ -1,7 +1,8 @@
from __future__ import annotations

import abc
-from typing import Dict, Generic, Protocol, Sized, TypeVar
+from collections.abc import Sized
+from typing import Generic, Protocol, TypeVar

import loguru
from datasets import DatasetDict
@@ -60,7 +61,7 @@ def __getitem__(self, key: int) -> T:
return self._seq[key]


-class DatasetDictWrapper(Dict[str, T]):
+class DatasetDictWrapper(dict[str, T]):
def __init__(self, dd: DatasetDict) -> None:
super().__init__(dd)

2 changes: 1 addition & 1 deletion perbert/models/init.py
@@ -1,4 +1,4 @@
-from typing import Callable
+from collections.abc import Callable

import loguru
from torch.nn import Embedding, LayerNorm, Linear, Module, init
8 changes: 4 additions & 4 deletions perbert/models/length_schedulers.py
@@ -9,7 +9,7 @@ def __init__(
max_length: int,
total_steps: int,
min_length: int = 128,
-        algo: str | SchedulerAlgo = SchedulerAlgo.Bert,
+        algo: str | SchedulerAlgo = SchedulerAlgo.BERT,
step_interval: int = 10000,
step_size: float = 2.0,
) -> None:
@@ -37,11 +37,11 @@ def step_schedule(self) -> int:
def step(self) -> int:
self.steps += 1

-        if self.algo == SchedulerAlgo.Const:
+        if self.algo == SchedulerAlgo.CONST:
return self.max_length
-        elif self.algo == SchedulerAlgo.Bert:
+        elif self.algo == SchedulerAlgo.BERT:
return self.bert_schedule()
-        elif self.algo == SchedulerAlgo.Step:
+        elif self.algo == SchedulerAlgo.STEP:
return self.step_schedule()
else:
raise NotImplementedError
12 changes: 7 additions & 5 deletions perbert/models/models.py
@@ -1,8 +1,9 @@
# pyright: reportIncompatibleMethodOverride=false
from __future__ import annotations

import typing
from enum import Enum
-from typing import Any, Dict, List, Tuple
+from typing import Any

import loguru
import torch
@@ -58,8 +59,8 @@ def __init__(self, cfg: DictConfig) -> None:
def bert_config(self) -> BertConfig:
return self.lm.bert.config

-    def forward(self, **kwargs: Any) -> BertOutput:
-        return self.lm(**kwargs)
+    def forward(self, *args: Any, **kwargs: Any) -> BertOutput:
+        return self.lm(*args, **kwargs)

def _step(self, batch: BatchEncoding, batch_idx: int, name: str) -> Tensor:
loguru.logger.trace("{} step batch: {}", name, batch_idx)
@@ -91,13 +92,14 @@ def training_step(self, batch: BatchEncoding, batch_idx: int) -> Tensor:

@torch.no_grad()
def test_step(self, batch: BatchEncoding, batch_idx: int) -> Tensor:
+        super().test_step
return self._step(batch, batch_idx=batch_idx, name="test")

@torch.no_grad()
def validation_step(self, batch: BatchEncoding, batch_idx: int) -> Tensor:
return self._step(batch, batch_idx=batch_idx, name="validation")

-    def configure_optimizers(self) -> Tuple[List[Optimizer], List[LambdaLR]]:
+    def configure_optimizers(self) -> tuple[list[Optimizer], list[LambdaLR]]:
model_cfg = self.cfg["model"]

optim_type = OptimizerType(model_cfg["optimizer"])
@@ -137,7 +139,7 @@ def configure_optimizers(self) -> tuple[list[Optimizer], list[LambdaLR]]:
loguru.logger.info("Optimizer: {}", optimizer)
return ([optimizer], [scheduler])

-    def configure_metrics(self) -> Dict[str, Metric]:
+    def configure_metrics(self) -> dict[str, Metric]:
metrics = {}

met_cfg = self.cfg["model"]["metrics"]
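
One behavioral tweak hides among the annotation changes in models.py: forward now passes positional arguments through as well, so the LightningModule can be called the same way the wrapped language model is. A hedged usage sketch (the argument names are assumptions about the underlying BERT model, not taken from this commit):

# Both call styles now reach self.lm unchanged; previously only keyword
# arguments were accepted by forward.
out = model(input_ids, attention_mask=attention_mask)
out = model(input_ids=input_ids, attention_mask=attention_mask)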
7 changes: 4 additions & 3 deletions perbert/trainer.py
@@ -1,7 +1,8 @@
from __future__ import annotations

import typing
-from typing import Any, List, Mapping
+from collections.abc import Mapping
+from typing import Any

from aim.pytorch_lightning import AimLogger
from omegaconf import DictConfig
@@ -29,7 +30,7 @@ def __init__(self, cfg: DictConfig) -> None:
)

@property
-    def __callbacks(self) -> List[Callback]:
+    def __callbacks(self) -> list[Callback]:
callback_cfg = self.cfg["callbacks"]

callbacks = []
@@ -49,7 +50,7 @@ def __callbacks(self) -> list[Callback]:
return callbacks

@property
-    def __loggers(self) -> List[Logger]:
+    def __loggers(self) -> list[Logger]:
logger_cfg = self.cfg["loggers"]

loggers = []