# model.py
import torch as tc
import torch.nn as nn
import torch.nn.init as init
from config import NROW,NCOL
from core import DEVICE
class Conv(nn.Module):
    """Conv2d -> BatchNorm2d -> LeakyReLU(0.1) unit.

    The convolution uses circular padding of ``ker_sz // 2``, which keeps the
    spatial size unchanged for odd kernel sizes (stride 1) and wraps values
    around the spatial edges.

    Args:
        chn_in: number of input channels.
        chn_out: number of output channels.
        ker_sz: square kernel size (default 3).
    """

    def __init__(self, chn_in, chn_out, ker_sz=3):
        super().__init__()
        half = ker_sz // 2
        # No conv bias: the BatchNorm right after re-centres the output anyway.
        self.c = nn.Conv2d(
            chn_in,
            chn_out,
            ker_sz,
            padding=half,
            padding_mode="circular",
            bias=False,
        )
        self.b = nn.BatchNorm2d(chn_out)
        self.a = nn.LeakyReLU(0.1)

    def forward(self, x):
        out = self.c(x)
        out = self.b(out)
        return self.a(out)
class Resi(nn.Module):
    """Residual block: (conv-BN-LeakyReLU-conv-BN)(x) + x, then LeakyReLU(0.1).

    Both convolutions keep the channel count ``chn`` and use circular padding
    of ``ker_sz // 2`` without bias.

    Args:
        chn: number of input and output channels.
        ker_sz: square kernel size (default 3).
    """

    def __init__(self, chn, ker_sz=3):
        super().__init__()

        def _circ_conv():
            # Same-channel conv; bias is left to the following BatchNorm.
            return nn.Conv2d(
                chn, chn, ker_sz,
                padding=ker_sz // 2, padding_mode="circular", bias=False,
            )

        # Arguments are evaluated left to right, so layers are built (and
        # registered as pre.0 .. pre.4) in this exact order.
        self.pre = nn.Sequential(
            _circ_conv(), nn.BatchNorm2d(chn), nn.LeakyReLU(0.1),
            _circ_conv(), nn.BatchNorm2d(chn),
        )
        self.post = nn.LeakyReLU(0.1)

    def forward(self, x):
        branch = self.pre(x)
        # Identity skip connection, activated after the sum.
        return self.post(branch + x)
class Full(nn.Module):
    """Fully-connected unit: Linear -> optional Dropout -> optional activation.

    Args:
        N_in: number of input features.
        N_out: number of output features.
        afunc: activation module applied last. Pass ``None`` for no
            activation. When omitted, a fresh ``nn.ReLU()`` is created for
            this layer (previously a single module instance was shared by
            every default-constructed layer).
        drop_out: if True, apply Dropout after the linear layer.
        drop_p: dropout probability, used only when ``drop_out`` is True.
    """

    # Sentinel that tells "argument omitted" apart from an explicit ``None``,
    # which callers use to disable the activation.
    _DEFAULT_AFUNC = object()

    def __init__(self, N_in, N_out, afunc=_DEFAULT_AFUNC, drop_out=False,
                 drop_p=0.5):
        super().__init__()
        if afunc is self._DEFAULT_AFUNC:
            afunc = nn.ReLU()
        self.l = nn.Linear(N_in, N_out)
        self.drop_out = drop_out
        if self.drop_out:
            self.d = nn.Dropout(drop_p)
        self.a = afunc

    def forward(self, x):
        x = self.l(x)
        if self.drop_out:
            x = self.d(x)
        # Explicit None check: a Module's truthiness is not a reliable
        # "is set" test (e.g. an empty nn.Sequential has length 0).
        if self.a is not None:
            x = self.a(x)
        return x
class SnakeNet(nn.Module):
    """Dueling Q-network: shared conv trunk feeding advantage and value heads.

    The input is reshaped to ``(batch, 4, NROW, NCOL)``. The output has shape
    ``(batch, 4)``: per-sample Q-values built as
    ``stval + (adv - mean_over_actions(adv))``.
    """

    def __init__(self):
        super().__init__()
        self.chn_in = 4
        self.chn_mid = 64
        self.chn_out = 10
        self.feature = nn.Sequential(
            Conv(self.chn_in, self.chn_mid),
            Resi(self.chn_mid),
            Resi(self.chn_mid),
            Conv(self.chn_mid, self.chn_out),
            nn.Flatten(),
        )
        # Convs keep the spatial size, so the flattened trunk output has
        # chn_out * NROW * NCOL features.
        n_feat = self.chn_out * NROW * NCOL
        self.adv = nn.Sequential(
            Full(n_feat, 256),
            Full(256, 4, None),
        )
        self.stval = nn.Sequential(
            Full(n_feat, 256),
            Full(256, 1, None),
        )
        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.Linear)):
                # torch.nn.init functions already run under no_grad.
                init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    init.zeros_(m.bias)

    def forward(self, x):
        x = x.reshape(-1, self.chn_in, NROW, NCOL)
        feat = self.feature(x)
        adv = self.adv(feat)
        stval = self.stval(feat)
        # Centre the advantages per sample (over the action dimension), so
        # one state's Q-values never depend on the other samples in the batch.
        return adv - adv.mean(dim=1, keepdim=True) + stval