-
Notifications
You must be signed in to change notification settings - Fork 63
/
embedder_interface.py
94 lines (75 loc) · 2.94 KB
/
embedder_interface.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
"""
Abstract interface for Embedder.
Authors:
Christian Dallago
"""
import abc
import logging
from typing import List, Generator, Optional, Iterable, ClassVar
from numpy import ndarray
logger = logging.getLogger(__name__)


class EmbedderInterface(object, metaclass=abc.ABCMeta):
    """
    Abstract base class describing what an embedder has to provide.

    Concrete subclasses must implement `with_download`, `embed` and
    `reduce_per_protein`. `embed_batch` and `embed_many` come with default
    implementations that are built on top of `embed`.
    """

    # Only declared here (no value is assigned in this class).
    name: ClassVar[str]

    def __init__(self, **kwargs):
        """
        Initializer accepts location of a pre-trained model and options

        NOTE(review): the keyword arguments are accepted but not stored by this
        base class; `_options` is always initialised to None. Presumably
        subclasses keep what they need from `kwargs` -- confirm.
        """
        self._options = None

    @classmethod
    @abc.abstractmethod
    def with_download(cls, **kwargs):
        """ Convenience function to create an instance after downloading files. """
        raise NotImplementedError

    @abc.abstractmethod
    def embed(self, sequence: str) -> ndarray:
        """
        Returns embedding for one sequence.

        :param sequence: Valid amino acid sequence as String
        :return: An embedding of the sequence.
        """
        raise NotImplementedError

    def embed_batch(self, batch: List[str]) -> Generator[ndarray, None, None]:
        """
        Computes the embeddings from all sequences in the batch.

        The provided implementation is a dummy implementation that simply calls
        `embed` once per sequence; it should be overwritten with the appropriate
        batching method for the model.

        :param batch: The sequences to embed
        :return: A generator yielding one embedding per sequence, in batch order
        """
        for sequence in batch:
            yield self.embed(sequence)

    def embed_many(
        self, sequences: Iterable[str], batch_size: Optional[int] = None
    ) -> Generator[ndarray, None, None]:
        """
        Yields one embedding per sequence, in the same order as `sequences`.

        :param sequences: Proteins as AA strings
        :param batch_size: For embedders that profit from batching, this is the
            maximum number of AA per batch. If it is falsy (None or 0), every
            sequence is embedded on its own through `embed`.
        :return: A generator over the embeddings of the sequences.
        """
        if not batch_size:
            for sequence in sequences:
                yield self.embed(sequence)
            return

        batch: List[str] = []
        length = 0  # Total number of residues currently collected in `batch`
        for sequence in sequences:
            if len(sequence) > batch_size:
                logger.warning(
                    f"A sequence is {len(sequence)} residues long, "
                    f"which is longer than your `batch_size` parameter which is {batch_size}"
                )
                # Flush what was collected so far first: otherwise the long
                # sequence would be yielded before the sequences that precede
                # it and the output order would no longer match the input.
                if batch:
                    yield from self.embed_batch(batch)
                    batch = []
                    length = 0
                # The sequence can't share a batch, so it gets one of its own
                yield from self.embed_batch([sequence])
                continue
            # `batch_size` is a maximum, so a batch may reach it exactly; only
            # start a new batch when this sequence would exceed it. The `batch`
            # check makes sure `embed_batch` is never called with an empty list.
            if batch and length + len(sequence) > batch_size:
                yield from self.embed_batch(batch)
                batch = []
                length = 0
            batch.append(sequence)
            length += len(sequence)
        # Embed the remainder, if there is any
        if batch:
            yield from self.embed_batch(batch)

    @staticmethod
    @abc.abstractmethod
    def reduce_per_protein(embedding: ndarray) -> ndarray:
        """
        For a variable size embedding, returns a fixed size embedding encoding all information of a sequence.

        :param embedding: the embedding
        :return: A fixed size embedding (a vector of size N, where N is fixed)
        """
        raise NotImplementedError