-
Notifications
You must be signed in to change notification settings - Fork 1
/
lenet.py
74 lines (62 loc) · 2.31 KB
/
lenet.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
#!/usr/bin/env python
# type: ignore
"""Implementation of the general LeNet architecture using PyTorch."""
from types import SimpleNamespace
import torch.nn as nn
from torchfl.compatibility import ACTIVATION_FUNCTIONS_BY_NAME
from torchfl.models.sota.mlp import LinearBlock
class LeNet(nn.Module):
    """LeNet base definition.

    The network is built from three sequential stages:

    - ``input_net``: 5x5 convolution (padding 2) -> activation -> 2x2 average pool.
    - ``conv_net``: 5x5 convolution -> 2x2 average pool -> 5x5 convolution to
      120 channels -> activation -> flatten -> ``LinearBlock`` down to 84 features.
    - ``output_net``: linear layer mapping 84 features to ``num_classes`` outputs.
    """

    def __init__(
        self,
        num_classes=10,
        num_channels=1,
        act_fn_name="relu",
        *,
        image_size=224,
        **kwargs,
    ) -> None:
        """Constructor

        Args:
            - num_classes (int, optional): Number of classification outputs. Defaults to 10.
            - num_channels (int, optional): Number of channels for the images in the dataset. Defaults to 1.
            - act_fn_name (str, optional): Activation function to be used. Defaults to "relu". Accepted: ["tanh", "relu", "leakyrelu", "gelu"].
            - image_size (int, optional, keyword-only): Side length of the square input images,
              used to size the first fully-connected layer. Defaults to 224, which yields the
              historical 300000 flattened features. Must be at least 28.
            - **kwargs: Accepted for interface compatibility and ignored.

        Raises:
            - KeyError: If ``act_fn_name`` is not a key of ``ACTIVATION_FUNCTIONS_BY_NAME``.
            - ValueError: If ``image_size`` is too small for the convolution stack.
        """
        # Validate before building anything so a bad size fails fast.
        num_flat_features = self._flattened_features(image_size)
        super().__init__()
        self.hparams = SimpleNamespace(
            model_name="lenet",
            num_classes=num_classes,
            num_channels=num_channels,
            act_fn_name=act_fn_name,
            act_fn=ACTIVATION_FUNCTIONS_BY_NAME[act_fn_name],
            pre_trained=False,
            feature_extract=False,
            finetune=False,
            quantized=False,
        )
        # Stored after super().__init__() so nn.Module attribute handling is set up.
        self._num_flat_features = num_flat_features
        self._create_network()

    @staticmethod
    def _flattened_features(image_size):
        """Return the width of the ``Flatten`` output for a square input of side ``image_size``.

        Traces the spatial side length through the layers built in ``_create_network``.

        Raises:
            - ValueError: If the input is too small to survive the convolution stack.
        """
        side = image_size  # conv1: kernel 5, padding 2, stride 1 -> size unchanged
        side //= 2  # avgpool: kernel 2, stride 2
        side -= 4  # conv2: kernel 5, no padding
        side //= 2  # avgpool: kernel 2, stride 2
        side -= 4  # conv3: kernel 5, no padding
        if side < 1:
            raise ValueError(
                f"image_size={image_size!r} is too small for LeNet; it must be at least 28"
            )
        # conv3 emits 120 channels, each of spatial size side x side.
        return 120 * side * side

    def _create_network(self):
        """Build ``input_net``, ``conv_net`` and ``output_net`` from ``self.hparams``."""
        self.input_net = nn.Sequential(
            nn.Conv2d(
                self.hparams.num_channels,
                6,
                kernel_size=5,
                stride=1,
                padding=2,
            ),
            self.hparams.act_fn(),
            nn.AvgPool2d(kernel_size=2, stride=2),
        )
        # NOTE(review): unlike the other convolutions, no activation follows the
        # 6->16 convolution below. Left unchanged because inserting a module would
        # shift nn.Sequential indices and break existing state_dict keys; confirm
        # whether the omission is intentional.
        self.conv_net = nn.Sequential(
            nn.Conv2d(6, 16, kernel_size=5, stride=1),
            nn.AvgPool2d(kernel_size=2, stride=2),
            nn.Conv2d(16, 120, kernel_size=5, stride=1),
            self.hparams.act_fn(),
            nn.Flatten(start_dim=1),
            # In-features depend on the input size; 300000 for the default 224x224.
            LinearBlock(self._num_flat_features, 84, self.hparams.act_fn, False),
        )
        self.output_net = nn.Sequential(
            nn.Linear(84, self.hparams.num_classes)
        )

    def forward(self, x):
        """Forward propagation

        Args:
            - x (torch.Tensor): Input tensor, presumably of shape
              (batch, num_channels, image_size, image_size); confirm with callers.

        Returns:
            - torch.Tensor: Returns the tensor after forward propagation
        """
        return self.output_net(self.conv_net(self.input_net(x)))