import math
import numpy as np
from typing import Any, Tuple, Optional, Sequence, List, Union, Callable, Dict, TypedDict
from collections.abc import Sequence as SequenceCollection
try:
import torch
from torch import Tensor
import torch.nn as nn
import torch.nn.functional as F
except ModuleNotFoundError:
raise ImportError('These classes require PyTorch to be installed.')
try:
from torch_geometric.utils import scatter
except ModuleNotFoundError:
pass
from deepchem.utils.typing import OneOrMany, ActivationFn, ArrayLike
from deepchem.utils.pytorch_utils import get_activation, segment_sum
from torch.nn import init as initializers
class MultilayerPerceptron(nn.Module):
"""A simple fully connected feed-forward network, otherwise known as a multilayer perceptron (MLP).
Examples
--------
>>> model = MultilayerPerceptron(d_input=10, d_hidden=(2,3), d_output=2, dropout=0.0, activation_fn='relu')
>>> x = torch.ones(2, 10)
>>> out = model(x)
>>> print(out.shape)
torch.Size([2, 2])
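
    A skip connection from the input to the output can also be enabled:

    >>> model = MultilayerPerceptron(d_input=10, d_hidden=(2,3), d_output=2, skip_connection=True)
    >>> out = model(x)
    >>> print(out.shape)
    torch.Size([2, 2])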
"""
def __init__(self,
d_input: int,
d_output: int,
d_hidden: Optional[tuple] = None,
dropout: float = 0.0,
batch_norm: bool = False,
batch_norm_momentum: float = 0.1,
activation_fn: Union[Callable, str] = 'relu',
skip_connection: bool = False,
weighted_skip: bool = True):
"""Initialize the model.
Parameters
----------
d_input: int
the dimension of the input layer
d_output: int
the dimension of the output layer
d_hidden: tuple
the dimensions of the hidden layers
dropout: float
the dropout probability
batch_norm: bool
whether to use batch normalization
batch_norm_momentum: float
the momentum for batch normalization
activation_fn: str
the activation function to use in the hidden layers
skip_connection: bool
whether to add a skip connection from the input to the output
weighted_skip: bool
whether to add a weighted skip connection from the input to the output
"""
super(MultilayerPerceptron, self).__init__()
self.d_input = d_input
self.d_hidden = d_hidden
self.d_output = d_output
self.dropout = nn.Dropout(dropout)
self.batch_norm = batch_norm
self.batch_norm_momentum = batch_norm_momentum
self.activation_fn = get_activation(activation_fn)
self.model = nn.Sequential(*self.build_layers())
self.skip = nn.Linear(d_input, d_output) if skip_connection else None
self.weighted_skip = weighted_skip
def build_layers(self):
"""
Build the layers of the model, iterating through the hidden dimensions to produce a list of layers.
"""
layer_list = []
layer_dim = self.d_input
if self.d_hidden is not None:
for d in self.d_hidden:
layer_list.append(nn.Linear(layer_dim, d))
layer_list.append(self.dropout)
if self.batch_norm:
layer_list.append(
nn.BatchNorm1d(d, momentum=self.batch_norm_momentum))
layer_dim = d
layer_list.append(nn.Linear(layer_dim, self.d_output))
return layer_list
def forward(self, x: Tensor) -> Tensor:
"""Forward pass of the model."""
input = x
for layer in self.model:
x = layer(x)
            if isinstance(layer, nn.Linear):
                # get_activation returns a functional, so the activation is
                # applied here rather than as a module inside nn.Sequential
                x = self.activation_fn(x)
if self.skip is not None:
if not self.weighted_skip:
return x + input
else:
return x + self.skip(input)
else:
return x
class CNNModule(nn.Module):
"""A 1, 2, or 3 dimensional convolutional network for either regression or classification.
The network consists of the following sequence of layers:
- A configurable number of convolutional layers
- A global pooling layer (either max pool or average pool)
- A final fully connected layer to compute the output
    Optionally, it can compose the model from pre-activation residual blocks, as
described in https://arxiv.org/abs/1603.05027, rather than a simple stack of
convolution layers. This often leads to easier training, especially when using a
large number of layers. Note that residual blocks can only be used when
successive layers have the same output shape. Wherever the output shape changes, a
simple convolution layer will be used even if residual=True.
Examples
--------
>>> model = CNNModule(n_tasks=5, n_features=8, dims=2, layer_filters=[3,8,8,16], kernel_size=3, n_classes = 7, mode='classification', uncertainty=False, padding='same')
>>> x = torch.ones(2, 224, 224, 8)
>>> x = model(x)
>>> for tensor in x:
... print(tensor.shape)
torch.Size([2, 5, 7])
torch.Size([2, 5, 7])
"""
def __init__(self,
n_tasks: int,
n_features: int,
dims: int,
layer_filters: List[int] = [100],
kernel_size: OneOrMany[int] = 5,
strides: OneOrMany[int] = 1,
weight_init_stddevs: OneOrMany[float] = 0.02,
bias_init_consts: OneOrMany[float] = 1.0,
dropouts: OneOrMany[float] = 0.5,
activation_fns: OneOrMany[ActivationFn] = 'relu',
pool_type: str = 'max',
mode: str = 'classification',
n_classes: int = 2,
uncertainty: bool = False,
residual: bool = False,
padding: Union[int, str] = 'valid') -> None:
"""Create a CNN.
Parameters
----------
n_tasks: int
number of tasks
n_features: int
number of features
dims: int
the number of dimensions to apply convolutions over (1, 2, or 3)
layer_filters: list
the number of output filters for each convolutional layer in the network.
The length of this list determines the number of layers.
kernel_size: int, tuple, or list
a list giving the shape of the convolutional kernel for each layer. Each
element may be either an int (use the same kernel width for every dimension)
or a tuple (the kernel width along each dimension). Alternatively this may
be a single int or tuple instead of a list, in which case the same kernel
shape is used for every layer.
strides: int, tuple, or list
a list giving the stride between applications of the kernel for each layer.
Each element may be either an int (use the same stride for every dimension)
or a tuple (the stride along each dimension). Alternatively this may be a
single int or tuple instead of a list, in which case the same stride is
used for every layer.
weight_init_stddevs: list or float
the standard deviation of the distribution to use for weight initialization
of each layer. The length of this list should equal len(layer_filters)+1,
where the final element corresponds to the dense layer. Alternatively this
may be a single value instead of a list, in which case the same value is used
for every layer.
bias_init_consts: list or float
the value to initialize the biases in each layer to. The length of this
list should equal len(layer_filters)+1, where the final element corresponds
to the dense layer. Alternatively this may be a single value instead of a
list, in which case the same value is used for every layer.
dropouts: list or float
the dropout probability to use for each layer. The length of this list should equal len(layer_filters).
Alternatively this may be a single value instead of a list, in which case the same value is used for every layer
activation_fns: str or list
the torch activation function to apply to each layer. The length of this list should equal
len(layer_filters). Alternatively this may be a single value instead of a list, in which case the
same value is used for every layer, 'relu' by default
pool_type: str
the type of pooling layer to use, either 'max' or 'average'
mode: str
Either 'classification' or 'regression'
n_classes: int
the number of classes to predict (only used in classification mode)
uncertainty: bool
if True, include extra outputs and loss terms to enable the uncertainty
in outputs to be predicted
residual: bool
if True, the model will be composed of pre-activation residual blocks instead
of a simple stack of convolutional layers.
padding: str, int or tuple
the padding to use for convolutional layers, either 'valid' or 'same'
"""
super(CNNModule, self).__init__()
if dims not in (1, 2, 3):
raise ValueError('Number of dimensions must be 1, 2 or 3')
if mode not in ['classification', 'regression']:
raise ValueError(
"mode must be either 'classification' or 'regression'")
self.n_tasks = n_tasks
self.n_features = n_features
self.dims = dims
self.mode = mode
self.n_classes = n_classes
self.uncertainty = uncertainty
self.layer_filters = layer_filters
self.residual = residual
n_layers = len(layer_filters)
        # PyTorch conv layers require both input and output channel counts.
        # If there is only one layer, duplicate layer_filters so that the
        # layer-creation loop below still works.
if len(layer_filters) == 1:
layer_filters = layer_filters * 2
if not isinstance(kernel_size, SequenceCollection):
kernel_size = [kernel_size] * n_layers
if not isinstance(strides, SequenceCollection):
strides = [strides] * n_layers
if not isinstance(dropouts, SequenceCollection):
dropouts = [dropouts] * n_layers
if isinstance(
activation_fns,
str) or not isinstance(activation_fns, SequenceCollection):
activation_fns = [activation_fns] * n_layers
if not isinstance(weight_init_stddevs, SequenceCollection):
weight_init_stddevs = [weight_init_stddevs] * n_layers
if not isinstance(bias_init_consts, SequenceCollection):
bias_init_consts = [bias_init_consts] * n_layers
self.activation_fns = [get_activation(f) for f in activation_fns]
self.dropouts = dropouts
if uncertainty:
if mode != 'regression':
raise ValueError(
"Uncertainty is only supported in regression mode")
if any(d == 0.0 for d in dropouts):
raise ValueError(
'Dropout must be included in every layer to predict uncertainty'
)
        # Select the convolution class matching the number of spatial
        # dimensions; tuples are 0-indexed, hence dims - 1.
ConvLayer = (nn.Conv1d, nn.Conv2d, nn.Conv3d)[self.dims - 1]
if pool_type == 'average':
PoolLayer = (F.avg_pool1d, F.avg_pool2d,
F.avg_pool3d)[self.dims - 1]
elif pool_type == 'max':
PoolLayer = (F.max_pool1d, F.max_pool2d,
F.max_pool3d)[self.dims - 1]
else:
raise ValueError("pool_type must be either 'average' or 'max'")
self.PoolLayer = PoolLayer
self.layers = nn.ModuleList()
in_shape = n_features
for out_shape, size, stride, weight_stddev, bias_const in zip(
layer_filters, kernel_size, strides, weight_init_stddevs,
bias_init_consts):
layer = ConvLayer(in_channels=in_shape,
out_channels=out_shape,
kernel_size=size,
stride=stride,
padding=padding,
dilation=1,
groups=1,
bias=True)
nn.init.normal_(layer.weight, 0, weight_stddev)
# initializing layer bias with nn.init gives mypy typecheck error
# using the following workaround
if layer.bias is not None:
layer.bias = nn.Parameter(
torch.full(layer.bias.shape, bias_const))
self.layers.append(layer)
in_shape = out_shape
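        # Output heads are lazily initialized: the flattened feature size
        # after global pooling is only known at the first forward pass.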
self.classifier_ffn = nn.LazyLinear(self.n_tasks * self.n_classes)
self.output_layer = nn.LazyLinear(self.n_tasks)
self.uncertainty_layer = nn.LazyLinear(self.n_tasks)
def forward(self, inputs: OneOrMany[torch.Tensor]) -> List[Any]:
"""
Parameters
----------
        inputs: OneOrMany[torch.Tensor]
            Input tensor, or a tuple of the input tensor and a dropout switch.
        Returns
        -------
        List[torch.Tensor]
            Output as per use case: regression/classification
"""
if isinstance(inputs, torch.Tensor):
x, dropout_switch = inputs, None
else:
x, dropout_switch = inputs
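        # dropout_switch, when truthy, enables dropout in the conv layers
        # (typically supplied during training)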
x = torch.transpose(x, 1, -1) # n h w c -> n c h w
prev_layer = x
for layer, activation_fn, dropout in zip(self.layers,
self.activation_fns,
self.dropouts):
x = layer(x)
if dropout > 0. and dropout_switch:
x = F.dropout(x, dropout)
# residual blocks can only be used when successive layers have the same output shape
if self.residual and x.shape[1] == prev_layer.shape[1]:
x = x + prev_layer
if activation_fn is not None:
x = activation_fn(x)
prev_layer = x
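        # Global pooling: the kernel spans all remaining spatial dimensions,
        # collapsing each feature map to a single value.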
x = self.PoolLayer(x, kernel_size=x.size()[2:])
outputs = []
batch_size = x.shape[0]
x = torch.reshape(x, (batch_size, -1))
if self.mode == "classification":
logits = self.classifier_ffn(x)
logits = logits.view(batch_size, self.n_tasks, self.n_classes)
output = F.softmax(logits, dim=2)
outputs = [output, logits]
else:
output = self.output_layer(x)
output = output.view(batch_size, self.n_tasks)
if self.uncertainty:
log_var = self.uncertainty_layer(x)
log_var = log_var.view(batch_size, self.n_tasks, 1)
var = torch.exp(log_var)
outputs = [output, var, output, log_var]
else:
outputs = [output]
return outputs
class ScaleNorm(nn.Module):
"""Apply Scale Normalization to input.
    The ScaleNorm layer stores a single learnable scale parameter, initialized to `sqrt(scale)`.
    The input tensor is normalized along its last dimension and rescaled as `input_tensor * scale / norm`, where the norm is clamped to at least `eps`.
    This layer can be used instead of LayerNorm when a scaled version of the norm is required.
    Instead of performing the scaling operation (`scale / norm`) in a lambda-like layer, it is defined within this layer to make prototyping more efficient.
References
----------
.. [1] Lukasz Maziarka et al. "Molecule Attention Transformer" Graph Representation Learning workshop and Machine Learning and the Physical Sciences workshop at NeurIPS 2019. 2020. https://arxiv.org/abs/2002.08264
Examples
--------
>>> from deepchem.models.torch_models.layers import ScaleNorm
>>> scale = 0.35
>>> layer = ScaleNorm(scale)
>>> input_tensor = torch.tensor([[1.269, 39.36], [0.00918, -9.12]])
>>> output_tensor = layer(input_tensor)
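
    The input shape is preserved:

    >>> print(output_tensor.shape)
    torch.Size([2, 2])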
"""
def __init__(self, scale: float, eps: float = 1e-5):
"""Initialize a ScaleNorm layer.
Parameters
----------
scale: float
Scale magnitude.
eps: float
Epsilon value. Default = 1e-5.
"""
super(ScaleNorm, self).__init__()
self.scale = nn.Parameter(torch.tensor(math.sqrt(scale)))
self.eps = eps
def forward(self, x: torch.Tensor) -> torch.Tensor:
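        # Rescale each vector along the last dimension to norm `scale`,
        # clamping the denominator at eps to avoid division by zero.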
norm = self.scale / torch.norm(x, dim=-1,
keepdim=True).clamp(min=self.eps)
return x * norm
class MultiHeadedMATAttention(nn.Module):
"""First constructs an attention layer tailored to the Molecular Attention Transformer [1]_ and then converts it into Multi-Headed Attention.
    In multi-headed attention, the attention mechanism is applied several times in parallel, once per attention head.
    This allows different subsequences of a given sequence to be processed differently.
    The query, key and value parameters are split multiple ways and each split is passed separately through a different attention head.
References
----------
.. [1] Lukasz Maziarka et al. "Molecule Attention Transformer" Graph Representation Learning workshop and Machine Learning and the Physical Sciences workshop at NeurIPS 2019. 2020. https://arxiv.org/abs/2002.08264
Examples
--------
>>> from deepchem.models.torch_models.layers import MultiHeadedMATAttention, MATEmbedding
>>> import deepchem as dc
>>> import torch
>>> input_smile = "CC"
>>> feat = dc.feat.MATFeaturizer()
>>> input_smile = "CC"
>>> out = feat.featurize(input_smile)
>>> node = torch.tensor(out[0].node_features).float().unsqueeze(0)
>>> adj = torch.tensor(out[0].adjacency_matrix).float().unsqueeze(0)
>>> dist = torch.tensor(out[0].distance_matrix).float().unsqueeze(0)
>>> mask = torch.sum(torch.abs(node), dim=-1) != 0
>>> layer = MultiHeadedMATAttention(
... dist_kernel='softmax',
... lambda_attention=0.33,
... lambda_distance=0.33,
... h=16,
... hsize=1024,
... dropout_p=0.0)
>>> op = MATEmbedding()(node)
>>> output = layer(op, op, op, mask, adj, dist)
"""
def __init__(self,
dist_kernel: str = 'softmax',
lambda_attention: float = 0.33,
lambda_distance: float = 0.33,
h: int = 16,
hsize: int = 1024,
dropout_p: float = 0.0,
output_bias: bool = True):
"""Initialize a multi-headed attention layer.
Parameters
----------
dist_kernel: str
Kernel activation to be used. Can be either 'softmax' for softmax or 'exp' for exponential.
lambda_attention: float
Constant to be multiplied with the attention matrix.
lambda_distance: float
Constant to be multiplied with the distance matrix.
h: int
Number of attention heads.
hsize: int
Size of dense layer.
dropout_p: float
Dropout probability.
output_bias: bool
If True, dense layers will use bias vectors.
"""
super().__init__()
        if dist_kernel == "softmax":
            self.dist_kernel = lambda x: torch.softmax(-x, dim=-1)
        elif dist_kernel == "exp":
            self.dist_kernel = lambda x: torch.exp(-x)
        else:
            raise ValueError("dist_kernel must be either 'softmax' or 'exp'")
self.lambda_attention = lambda_attention
self.lambda_distance = lambda_distance
self.lambda_adjacency = 1.0 - self.lambda_attention - self.lambda_distance
self.d_k = hsize // h
self.h = h
        # separate projections for query, key and value; reusing a single
        # nn.Linear instance here would tie their weights together
        self.linear_layers = nn.ModuleList(
            [nn.Linear(hsize, hsize) for _ in range(3)])
self.dropout_p = nn.Dropout(dropout_p)
self.output_linear = nn.Linear(hsize, hsize, output_bias)
def _single_attention(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
mask: torch.Tensor,
adj_matrix: torch.Tensor,
distance_matrix: torch.Tensor,
dropout_p: float = 0.0,
eps: float = 1e-6,
inf: float = 1e12) -> Tuple[torch.Tensor, torch.Tensor]:
"""Defining and computing output for a single MAT attention layer.
Parameters
----------
query: torch.Tensor
Standard query parameter for attention.
key: torch.Tensor
Standard key parameter for attention.
value: torch.Tensor
Standard value parameter for attention.
mask: torch.Tensor
Masks out padding values so that they are not taken into account when computing the attention score.
adj_matrix: torch.Tensor
Adjacency matrix of the input molecule, returned from dc.feat.MATFeaturizer()
        distance_matrix: torch.Tensor
Distance matrix of the input molecule, returned from dc.feat.MATFeaturizer()
dropout_p: float
Dropout probability.
eps: float
Epsilon value
inf: float
Value of infinity to be used.
"""
d_k = query.size(-1)
scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
if mask is not None:
scores = scores.masked_fill(
mask.unsqueeze(1).repeat(1, query.shape[1], query.shape[2],
1) == 0, -inf)
p_attn = F.softmax(scores, dim=-1)
        adj_matrix = adj_matrix / (
            torch.sum(adj_matrix, dim=-1).unsqueeze(2) + eps)
if len(adj_matrix.shape) <= 3:
p_adj = adj_matrix.unsqueeze(1).repeat(1, query.shape[1], 1, 1)
else:
p_adj = adj_matrix.repeat(1, query.shape[1], 1, 1)
        distance_matrix = distance_matrix.squeeze().masked_fill(
            mask.repeat(1, mask.shape[-1], 1) == 0, np.inf)
distance_matrix = self.dist_kernel(distance_matrix)
p_dist = distance_matrix.unsqueeze(1).repeat(1, query.shape[1], 1, 1)
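        # Weighted combination of the three attention sources: softmax
        # attention, distance-kernel scores and normalized adjacency, mixed
        # via lambda_attention, lambda_distance and lambda_adjacency.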
p_weighted = self.lambda_attention * p_attn + self.lambda_distance * p_dist + self.lambda_adjacency * p_adj
p_weighted = self.dropout_p(p_weighted)
return torch.matmul(p_weighted.float(), value.float()), p_attn
def forward(self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
mask: torch.Tensor,
adj_matrix: torch.Tensor,
distance_matrix: torch.Tensor,
dropout_p: float = 0.0,
eps: float = 1e-6,
inf: float = 1e12) -> torch.Tensor:
"""Output computation for the MultiHeadedAttention layer.
Parameters
----------
query: torch.Tensor
Standard query parameter for attention.
key: torch.Tensor
Standard key parameter for attention.
value: torch.Tensor
Standard value parameter for attention.
mask: torch.Tensor
Masks out padding values so that they are not taken into account when computing the attention score.
adj_matrix: torch.Tensor
Adjacency matrix of the input molecule, returned from dc.feat.MATFeaturizer()
        distance_matrix: torch.Tensor
Distance matrix of the input molecule, returned from dc.feat.MATFeaturizer()
dropout_p: float
Dropout probability.
eps: float
Epsilon value
inf: float
Value of infinity to be used.
"""
if mask is not None and len(mask.shape) <= 2:
mask = mask.unsqueeze(1)
batch_size = query.size(0)
query, key, value = [
layer(x).view(batch_size, -1, self.h, self.d_k).transpose(1, 2)
for layer, x in zip(self.linear_layers, (query, key, value))
]
x, _ = self._single_attention(query, key, value, mask, adj_matrix,
distance_matrix, dropout_p, eps, inf)
x = x.transpose(1, 2).contiguous().view(batch_size, -1,
self.h * self.d_k)
return self.output_linear(x)
class MATEncoderLayer(nn.Module):
"""Encoder layer for use in the Molecular Attention Transformer [1]_.
The MATEncoder layer primarily consists of a self-attention layer (MultiHeadedMATAttention) and a feed-forward layer (PositionwiseFeedForward).
This layer can be stacked multiple times to form an encoder.
References
----------
.. [1] Lukasz Maziarka et al. "Molecule Attention Transformer" Graph Representation Learning workshop and Machine Learning and the Physical Sciences workshop at NeurIPS 2019. 2020. https://arxiv.org/abs/2002.08264
Examples
--------
>>> from rdkit import Chem
>>> import torch
>>> import deepchem
>>> from deepchem.models.torch_models.layers import MATEmbedding, MATEncoderLayer
>>> input_smile = "CC"
>>> feat = deepchem.feat.MATFeaturizer()
>>> out = feat.featurize(input_smile)
>>> node = torch.tensor(out[0].node_features).float().unsqueeze(0)
>>> adj = torch.tensor(out[0].adjacency_matrix).float().unsqueeze(0)
>>> dist = torch.tensor(out[0].distance_matrix).float().unsqueeze(0)
>>> mask = torch.sum(torch.abs(node), dim=-1) != 0
>>> layer = MATEncoderLayer()
>>> op = MATEmbedding()(node)
>>> output = layer(op, mask, adj, dist)
"""
def __init__(self,
dist_kernel: str = 'softmax',
lambda_attention: float = 0.33,
lambda_distance: float = 0.33,
h: int = 16,
sa_hsize: int = 1024,
sa_dropout_p: float = 0.0,
output_bias: bool = True,
d_input: int = 1024,
d_hidden: int = 1024,
d_output: int = 1024,
activation: str = 'leakyrelu',
n_layers: int = 1,
ff_dropout_p: float = 0.0,
encoder_hsize: int = 1024,
encoder_dropout_p: float = 0.0):
"""Initialize a MATEncoder layer.
Parameters
----------
dist_kernel: str
Kernel activation to be used. Can be either 'softmax' for softmax or 'exp' for exponential, for the self-attention layer.
lambda_attention: float
Constant to be multiplied with the attention matrix in the self-attention layer.
lambda_distance: float
Constant to be multiplied with the distance matrix in the self-attention layer.
h: int
Number of attention heads for the self-attention layer.
sa_hsize: int
Size of dense layer in the self-attention layer.
sa_dropout_p: float
Dropout probability for the self-attention layer.
output_bias: bool
If True, dense layers will use bias vectors in the self-attention layer.
d_input: int
Size of input layer in the feed-forward layer.
d_hidden: int
Size of hidden layer in the feed-forward layer.
d_output: int
Size of output layer in the feed-forward layer.
activation: str
Activation function to be used in the feed-forward layer.
Can choose between 'relu' for ReLU, 'leakyrelu' for LeakyReLU, 'prelu' for PReLU,
'tanh' for TanH, 'selu' for SELU, 'elu' for ELU and 'linear' for linear activation.
n_layers: int
Number of layers in the feed-forward layer.
        ff_dropout_p: float
            Dropout probability in the feed-forward layer.
encoder_hsize: int
Size of Dense layer for the encoder itself.
encoder_dropout_p: float
Dropout probability for connections in the encoder layer.
"""
super(MATEncoderLayer, self).__init__()
self.self_attn = MultiHeadedMATAttention(dist_kernel, lambda_attention,
lambda_distance, h, sa_hsize,
sa_dropout_p, output_bias)
self.feed_forward = PositionwiseFeedForward(d_input, d_hidden, d_output,
activation, n_layers,
ff_dropout_p)
        # create two independent sublayer connections so that the attention
        # and feed-forward sublayers do not share LayerNorm parameters
        self.sublayer = nn.ModuleList([
            SublayerConnection(size=encoder_hsize,
                               dropout_p=encoder_dropout_p) for _ in range(2)
        ])
self.size = encoder_hsize
def forward(self,
x: torch.Tensor,
mask: torch.Tensor,
adj_matrix: torch.Tensor,
distance_matrix: torch.Tensor,
sa_dropout_p: float = 0.0) -> torch.Tensor:
"""Output computation for the MATEncoder layer.
        In the MATEncoderLayer initialization, self.sublayer is defined as an nn.ModuleList of 2 layers. We will be passing our computation through these layers sequentially.
nn.ModuleList is subscriptable and thus we can access it as self.sublayer[0], for example.
Parameters
----------
x: torch.Tensor
Input tensor.
mask: torch.Tensor
Masks out padding values so that they are not taken into account when computing the attention score.
adj_matrix: torch.Tensor
Adjacency matrix of a molecule.
distance_matrix: torch.Tensor
Distance matrix of a molecule.
sa_dropout_p: float
Dropout probability for the self-attention layer (MultiHeadedMATAttention).
"""
x = self.sublayer[0](x,
self.self_attn(x,
x,
x,
mask=mask,
dropout_p=sa_dropout_p,
adj_matrix=adj_matrix,
distance_matrix=distance_matrix))
return self.sublayer[1](x, self.feed_forward(x))
class SublayerConnection(nn.Module):
"""SublayerConnection layer based on the paper `Attention Is All You Need <https://arxiv.org/abs/1706.03762>`_.
    The SublayerConnection normalizes and adds dropout to the output tensor of an arbitrary layer.
    It further adds a residual connection between the input of the arbitrary layer and the dropout-adjusted layer output.
Examples
--------
>>> from deepchem.models.torch_models.layers import SublayerConnection
>>> scale = 0.35
>>> layer = SublayerConnection(2, 0.)
>>> input_ar = torch.tensor([[1., 2.], [5., 6.]])
>>> output = layer(input_ar, input_ar)
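
    The residual output keeps the input shape:

    >>> print(output.shape)
    torch.Size([2, 2])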
"""
def __init__(self, size: int, dropout_p: float = 0.0):
"""Initialize a SublayerConnection Layer.
Parameters
----------
size: int
Size of layer.
dropout_p: float
Dropout probability.
"""
super(SublayerConnection, self).__init__()
self.norm = nn.LayerNorm(size)
self.dropout_p = nn.Dropout(dropout_p)
def forward(self, x: torch.Tensor, output: torch.Tensor) -> torch.Tensor:
"""Output computation for the SublayerConnection layer.
        Takes an input tensor x, normalizes the sublayer output, applies dropout to it and adds the result to x.
        This forms a residual connection around the layer-normalized sublayer output.
Parameters
----------
x: torch.Tensor
Input tensor.
        output: torch.Tensor
            Sublayer output tensor; it is normalized, dropout-adjusted and added to x.
"""
if x is None:
return self.dropout_p(self.norm(output))
return x + self.dropout_p(self.norm(output))
class PositionwiseFeedForward(nn.Module):
"""PositionwiseFeedForward is a layer used to define the position-wise feed-forward (FFN) algorithm for the Molecular Attention Transformer [1]_
Each layer in the MAT encoder contains a fully connected feed-forward network which applies two linear transformations and the given activation function.
This is done in addition to the SublayerConnection module.
    Note: This modified version of the `PositionwiseFeedForward` class contains a `dropout_at_input_no_act` condition to facilitate its use in defining
    the feed-forward (FFN) algorithm for the Directed Message Passing Neural Network (D-MPNN) [2]_
References
----------
.. [1] Lukasz Maziarka et al. "Molecule Attention Transformer" Graph Representation Learning workshop and Machine Learning and the Physical Sciences workshop at NeurIPS 2019. 2020. https://arxiv.org/abs/2002.08264
.. [2] Analyzing Learned Molecular Representations for Property Prediction https://arxiv.org/pdf/1904.01561.pdf
Examples
--------
>>> from deepchem.models.torch_models.layers import PositionwiseFeedForward
>>> feed_fwd_layer = PositionwiseFeedForward(d_input = 2, d_hidden = 2, d_output = 2, activation = 'relu', n_layers = 1, dropout_p = 0.1)
>>> input_tensor = torch.tensor([[1., 2.], [5., 6.]])
>>> output_tensor = feed_fwd_layer(input_tensor)
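
    With d_input equal to d_output, the output shape matches the input:

    >>> print(output_tensor.shape)
    torch.Size([2, 2])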
"""
def __init__(self,
d_input: int = 1024,
d_hidden: int = 1024,
d_output: int = 1024,
activation: str = 'leakyrelu',
n_layers: int = 1,
dropout_p: float = 0.0,
dropout_at_input_no_act: bool = False):
"""Initialize a PositionwiseFeedForward layer.
Parameters
----------
d_input: int
Size of input layer.
        d_hidden: int (same as d_input if d_hidden = 0)
Size of hidden layer.
d_output: int (same as d_input if d_output = 0)
Size of output layer.
activation: str
Activation function to be used. Can choose between 'relu' for ReLU, 'leakyrelu' for LeakyReLU, 'prelu' for PReLU,
'tanh' for TanH, 'selu' for SELU, 'elu' for ELU and 'linear' for linear activation.
n_layers: int
Number of layers.
dropout_p: float
Dropout probability.
dropout_at_input_no_act: bool
            If True, dropout is applied to the input tensor. For a single layer, the output is not passed through an activation function.
"""
super(PositionwiseFeedForward, self).__init__()
self.dropout_at_input_no_act: bool = dropout_at_input_no_act
if activation == 'relu':
self.activation: Any = nn.ReLU()
elif activation == 'leakyrelu':
self.activation = nn.LeakyReLU(0.1)
elif activation == 'prelu':
self.activation = nn.PReLU()
elif activation == 'tanh':
self.activation = nn.Tanh()
elif activation == 'selu':
self.activation = nn.SELU()
elif activation == 'elu':
self.activation = nn.ELU()
        elif activation == "linear":
            self.activation = lambda x: x
        else:
            raise ValueError("Invalid activation function: %s" % activation)
self.n_layers: int = n_layers
d_output = d_output if d_output != 0 else d_input
d_hidden = d_hidden if d_hidden != 0 else d_input
if n_layers == 1:
self.linears: Any = [nn.Linear(d_input, d_output)]
else:
self.linears = [nn.Linear(d_input, d_hidden)] + \
[nn.Linear(d_hidden, d_hidden) for _ in range(n_layers - 2)] + \
[nn.Linear(d_hidden, d_output)]
self.linears = nn.ModuleList(self.linears)
dropout_layer = nn.Dropout(dropout_p)
self.dropout_p = nn.ModuleList([dropout_layer for _ in range(n_layers)])
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Output Computation for the PositionwiseFeedForward layer.
Parameters
----------
x: torch.Tensor
Input tensor.
"""
if not self.n_layers:
return x
if self.n_layers == 1:
if self.dropout_at_input_no_act:
return self.linears[0](self.dropout_p[0](x))
else:
return self.dropout_p[0](self.activation(self.linears[0](x)))
else:
if self.dropout_at_input_no_act:
x = self.dropout_p[0](x)
for i in range(self.n_layers - 1):
x = self.dropout_p[i](self.activation(self.linears[i](x)))
return self.linears[-1](x)
class MATEmbedding(nn.Module):
"""Embedding layer to create embedding for inputs.
    In an embedding layer, each input is converted to a vector representation.
In the MATEmbedding layer, an input tensor is processed through a dropout-adjusted linear layer and the resultant vector is returned.
References
----------
.. [1] Lukasz Maziarka et al. "Molecule Attention Transformer" Graph Representation Learning workshop and Machine Learning and the Physical Sciences workshop at NeurIPS 2019. 2020. https://arxiv.org/abs/2002.08264
Examples
--------
>>> from deepchem.models.torch_models.layers import MATEmbedding
>>> layer = MATEmbedding(d_input = 3, d_output = 3, dropout_p = 0.2)
>>> input_tensor = torch.tensor([1., 2., 3.])
>>> output = layer(input_tensor)
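
    The embedding maps the last dimension from d_input to d_output:

    >>> print(output.shape)
    torch.Size([3])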
"""
def __init__(self,
d_input: int = 36,
d_output: int = 1024,
dropout_p: float = 0.0):
"""Initialize a MATEmbedding layer.
Parameters
----------
d_input: int
Size of input layer.
d_output: int
Size of output layer.
dropout_p: float
Dropout probability for layer.
"""
super(MATEmbedding, self).__init__()
self.linear_unit = nn.Linear(d_input, d_output)
self.dropout = nn.Dropout(dropout_p)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Computation for the MATEmbedding layer.
Parameters
----------
x: torch.Tensor
Input tensor to be converted into a vector.
"""
return self.dropout(self.linear_unit(x))
class MATGenerator(nn.Module):
"""MATGenerator defines the linear and softmax generator step for the Molecular Attention Transformer [1]_.
In the MATGenerator, a Generator is defined which performs the Linear + Softmax generation step.
Depending on the type of aggregation selected, the attention output layer performs different operations.
References
----------
.. [1] Lukasz Maziarka et al. "Molecule Attention Transformer" Graph Representation Learning workshop and Machine Learning and the Physical Sciences workshop at NeurIPS 2019. 2020. https://arxiv.org/abs/2002.08264
Examples
--------
>>> from deepchem.models.torch_models.layers import MATGenerator
>>> layer = MATGenerator(hsize = 3, aggregation_type = 'mean', d_output = 1, n_layers = 1, dropout_p = 0.3, attn_hidden = 128, attn_out = 4)
>>> input_tensor = torch.tensor([1., 2., 3.])
>>> mask = torch.tensor([1., 1., 1.])
>>> output = layer(input_tensor, mask)
"""
def __init__(self,
hsize: int = 1024,
aggregation_type: str = 'mean',
d_output: int = 1,
n_layers: int = 1,
dropout_p: float = 0.0,
attn_hidden: int = 128,
attn_out: int = 4):
"""Initialize a MATGenerator.
Parameters
----------
hsize: int
Size of input layer.
aggregation_type: str
Type of aggregation to be used. Can be 'grover', 'mean' or 'contextual'.
d_output: int
Size of output layer.
n_layers: int
Number of layers in MATGenerator.
dropout_p: float
Dropout probability for layer.
attn_hidden: int
Size of hidden attention layer.
attn_out: int
Size of output attention layer.
"""
super(MATGenerator, self).__init__()
if aggregation_type == 'grover':
self.att_net = nn.Sequential(
nn.Linear(hsize, attn_hidden, bias=False),