-
Notifications
You must be signed in to change notification settings - Fork 0
/
model.py
26 lines (22 loc) · 810 Bytes
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
import torch
import torch.nn as nn
import torch.nn.functional as F
class Regressor(nn.Module):
    """Small MLP regressor (1 -> 40 -> 40 -> 1) with a *functional* forward.

    The forward pass can run either with the module's own parameters or with
    an externally supplied parameter dict keyed like ``state_dict()``
    (``'layers.0.weight'``, ``'layers.0.bias'``, ...). The latter is the
    usual pattern for meta-learning (e.g. MAML), where adapted weights must
    be used without mutating the module.
    """

    def __init__(self):
        super().__init__()  # zero-arg super(): Python 3 idiom
        self.layers = nn.Sequential(
            nn.Linear(1, 40),
            nn.ReLU(),
            nn.Linear(40, 40),
            nn.ReLU(),
            nn.Linear(40, 1),
        )

    def forward(self, x, params=None):
        """Run the MLP on ``x``.

        Args:
            x: input tensor whose last dimension is 1 (the single feature).
            params: optional mapping of parameter name -> tensor, keyed like
                ``state_dict()``. When ``None``, the module's own parameters
                are used.

        Returns:
            Output tensor whose last dimension is 1.
        """
        # `is None`, not `== None`: identity check is the correct/idiomatic
        # way to test for None (== delegates to __eq__ and can misfire).
        if params is None:
            params = self.state_dict()
        # Functional re-implementation of self.layers so external `params`
        # can be substituted; indices 0/2/4 are the Linear layers above.
        x = F.relu(F.linear(x, params['layers.0.weight'], params['layers.0.bias']))
        x = F.relu(F.linear(x, params['layers.2.weight'], params['layers.2.bias']))
        return F.linear(x, params['layers.4.weight'], params['layers.4.bias'])