-
Notifications
You must be signed in to change notification settings - Fork 35
/
layers.py
217 lines (165 loc) · 8.32 KB
/
layers.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
# ------------------------------------------------------------------------------------
# Enhancing Transformers
# Copyright (c) 2022 Thuan H. Nguyen. All Rights Reserved.
# Licensed under the MIT License [see LICENSE for details]
# ------------------------------------------------------------------------------------
# Modified from ViT-Pytorch (https://github.com/lucidrains/vit-pytorch)
# Copyright (c) 2020 Phil Wang. All Rights Reserved.
# ------------------------------------------------------------------------------------
import math
import numpy as np
from typing import Union, Tuple, List
from collections import OrderedDict
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat
from einops.layers.torch import Rearrange
def get_2d_sincos_pos_embed(embed_dim, grid_size):
    """Return fixed 2-D sine-cosine positional embeddings for a grid of patches.

    Args:
        embed_dim: embedding width per position. It must be even here; because each half is
            encoded by the 1-D helper, which also requires an even width, it must in
            practice be a multiple of 4.
        grid_size: an int for a square grid, or a ``(height, width)`` pair. A tuple, a
            tuple subclass such as ``torch.Size``, or a list is accepted.

    Returns:
        np.ndarray of shape ``(height * width, embed_dim)``, one row per grid cell in
        row-major (height-major) order. No class-token row is included.
    """
    if isinstance(grid_size, (tuple, list)):
        grid_height, grid_width = grid_size
    else:
        grid_height = grid_width = grid_size
    grid_h = np.arange(grid_height, dtype=np.float32)
    grid_w = np.arange(grid_width, dtype=np.float32)
    # meshgrid's default 'xy' indexing with w passed first: grid[0] holds the column (w)
    # index and grid[1] the row (h) index, each with shape (height, width).
    grid = np.stack(np.meshgrid(grid_w, grid_h), axis=0)
    grid = grid.reshape([2, 1, grid_height, grid_width])
    return get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    """Build 2-D sine-cosine embeddings by concatenating two 1-D embeddings.

    Args:
        embed_dim: total output width; must be even. Each half is passed to
            ``get_1d_sincos_pos_embed_from_grid``, which also requires an even width.
        grid: indexable as ``grid[0]`` and ``grid[1]``, each holding one coordinate per
            position; both are flattened by the 1-D helper.

    Returns:
        np.ndarray of shape ``(num_positions, embed_dim)``. The first
        ``embed_dim // 2`` columns encode ``grid[0]`` and the remaining columns encode
        ``grid[1]``.
    """
    assert embed_dim % 2 == 0
    half_width = embed_dim // 2
    # Encode each coordinate axis independently, then place the halves side by side.
    halves = [get_1d_sincos_pos_embed_from_grid(half_width, grid[axis]) for axis in (0, 1)]
    return np.concatenate(halves, axis=1)
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """Return 1-D sine-cosine embeddings for a collection of scalar positions.

    Args:
        embed_dim: output dimension for each position; must be even.
        pos: positions to encode, as any array-like (list, tuple or ndarray). It is
            flattened, so M is its total number of elements.

    Returns:
        np.ndarray of shape ``(M, embed_dim)``.
        - The first ``embed_dim // 2`` columns are ``sin(pos * omega)``.
        - The remaining columns are ``cos(pos * omega)``.
        - ``omega_i = 1 / 10000 ** (i / (embed_dim / 2))``.
    """
    assert embed_dim % 2 == 0
    # `np.float` (an alias of Python float, i.e. float64) was removed in NumPy 1.24;
    # use the explicit dtype so this keeps working on current NumPy.
    omega = np.arange(embed_dim // 2, dtype=np.float64)
    omega /= embed_dim / 2.
    omega = 1. / 10000**omega  # (D/2,)
    pos = np.asarray(pos).reshape(-1)  # (M,); asarray also accepts plain lists
    out = np.einsum('m,d->md', pos, omega)  # (M, D/2), outer product
    emb_sin = np.sin(out)  # (M, D/2)
    emb_cos = np.cos(out)  # (M, D/2)
    return np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)
def init_weights(m):
    """Initialize one module in place; intended for use with ``nn.Module.apply``.

    - ``nn.Linear``: Xavier-uniform weight (following the official JAX ViT), zero bias.
    - ``nn.LayerNorm``: weight 1 and bias 0. Each step is skipped when that parameter
      does not exist (``elementwise_affine=False`` / ``bias=False``).
    - ``nn.Conv2d`` / ``nn.ConvTranspose2d``: Xavier-uniform on the weight viewed as a 2-D
      ``(shape[0], -1)`` matrix, i.e. initialized like a Linear layer. The bias is left
      untouched.

    Any other module type is left untouched.
    """
    if isinstance(m, nn.Linear):
        # we use xavier_uniform following official JAX ViT:
        torch.nn.init.xavier_uniform_(m.weight)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)
    elif isinstance(m, nn.LayerNorm):
        # A LayerNorm without affine parameters registers None here;
        # nn.init.constant_(None, ...) would raise.
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)
        if m.weight is not None:
            nn.init.constant_(m.weight, 1.0)
    elif isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
        w = m.weight.data
        # The view shares storage, so initializing it writes into the conv weight.
        torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
class PreNorm(nn.Module):
    """Wrap a module so that its input is LayerNorm-ed first (pre-norm layout).

    No residual connection is added here; the wrapper only normalizes and delegates.

    Args:
        dim: size of the last input dimension, normalized by ``nn.LayerNorm``.
        fn: module applied to the normalized tensor. Extra keyword arguments given to
            ``forward`` are passed through to it.
    """

    def __init__(self, dim: int, fn: nn.Module) -> None:
        super().__init__()
        # The attribute names `norm` and `fn` define the state_dict keys; keep them stable.
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x: torch.FloatTensor, **kwargs) -> torch.FloatTensor:
        normalized = self.norm(x)
        return self.fn(normalized, **kwargs)
class FeedForward(nn.Module):
    """Token-wise MLP: ``Linear(dim -> hidden_dim)``, then ``tanh``, then ``Linear(hidden_dim -> dim)``.

    Args:
        dim: input and output width.
        hidden_dim: width of the hidden layer.
    """

    def __init__(self, dim: int, hidden_dim: int) -> None:
        super().__init__()
        # Build the two projections in the same order as before so random init draws match;
        # the Sequential positions (0 and 2) are the state_dict keys.
        expand = nn.Linear(dim, hidden_dim)
        contract = nn.Linear(hidden_dim, dim)
        self.net = nn.Sequential(expand, nn.Tanh(), contract)

    def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
        mlp = self.net
        return mlp(x)
class Attention(nn.Module):
    """Multi-head scaled dot-product self-attention (no masking, no dropout).

    Args:
        dim: width of the input and output tokens.
        heads: number of attention heads.
        dim_head: per-head width; q, k and v are each ``heads * dim_head`` wide.
    """

    def __init__(self, dim: int, heads: int = 8, dim_head: int = 64) -> None:
        super().__init__()
        inner_dim = heads * dim_head
        self.heads = heads
        self.scale = dim_head ** -0.5
        self.attend = nn.Softmax(dim = -1)
        # A single bias-free Linear produces q, k and v concatenated on the last axis.
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
        # The output projection is dropped only when one head already has width `dim`.
        needs_projection = heads != 1 or dim_head != dim
        self.to_out = nn.Linear(inner_dim, dim) if needs_projection else nn.Identity()

    def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
        # The 'b n (h d)' pattern below requires x to be rank 3: (batch, tokens, features).
        q, k, v = (rearrange(part, 'b n (h d) -> b h n d', h = self.heads)
                   for part in self.to_qkv(x).chunk(3, dim = -1))
        scores = (q @ k.transpose(-1, -2)) * self.scale
        weights = self.attend(scores)
        merged = rearrange(weights @ v, 'b h n d -> b n (h d)')
        return self.to_out(merged)
class Transformer(nn.Module):
    """Stack of ``depth`` pre-norm blocks followed by a final LayerNorm.

    Each block applies attention and then an MLP, each wrapped in ``PreNorm`` and each
    followed by a residual add.

    Args:
        dim: token width.
        depth: number of blocks.
        heads: attention heads per block.
        dim_head: per-head width inside attention.
        mlp_dim: hidden width of each block's MLP.
    """

    def __init__(self, dim: int, depth: int, heads: int, dim_head: int, mlp_dim: int) -> None:
        super().__init__()
        # Each entry is a ModuleList [attention, mlp]; the nesting fixes the state_dict keys.
        self.layers = nn.ModuleList([
            nn.ModuleList([
                PreNorm(dim, Attention(dim, heads=heads, dim_head=dim_head)),
                PreNorm(dim, FeedForward(dim, mlp_dim)),
            ])
            for _ in range(depth)
        ])
        self.norm = nn.LayerNorm(dim)

    def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
        for attention_block, mlp_block in self.layers:
            x = x + attention_block(x)
            x = x + mlp_block(x)
        return self.norm(x)
class ViTEncoder(nn.Module):
    """Split an image into patches with a strided conv and encode them with a Transformer.

    Args:
        image_size: input size as ``(height, width)``, or an int for square images.
        patch_size: patch size as ``(height, width)``, or an int for square patches.
        dim: token width.
        depth, heads, mlp_dim, dim_head: forwarded to ``Transformer``.
        channels: number of input image channels.
    """

    def __init__(self, image_size: Union[Tuple[int, int], int], patch_size: Union[Tuple[int, int], int],
                 dim: int, depth: int, heads: int, mlp_dim: int, channels: int = 3, dim_head: int = 64) -> None:
        super().__init__()
        if isinstance(image_size, tuple):
            image_height, image_width = image_size
        else:
            image_height = image_width = image_size
        if isinstance(patch_size, tuple):
            patch_height, patch_width = patch_size
        else:
            patch_height = patch_width = patch_size
        assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'

        rows, cols = image_height // patch_height, image_width // patch_width
        # Fixed sin-cos table with one row per patch; stored as a frozen Parameter below.
        pos_table = get_2d_sincos_pos_embed(dim, (rows, cols))

        self.num_patches = rows * cols
        self.patch_dim = channels * patch_height * patch_width
        # kernel == stride, so each output position sees one non-overlapping patch;
        # the Rearrange flattens the patch grid into a token sequence (b, rows*cols, dim).
        self.to_patch_embedding = nn.Sequential(
            nn.Conv2d(channels, dim, kernel_size=patch_size, stride=patch_size),
            Rearrange('b c h w -> b (h w) c'),
        )
        self.en_pos_embedding = nn.Parameter(torch.from_numpy(pos_table).float().unsqueeze(0), requires_grad=False)
        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim)
        self.apply(init_weights)

    def forward(self, img: torch.FloatTensor) -> torch.FloatTensor:
        tokens = self.to_patch_embedding(img)
        # In-place is fine here: `tokens` is produced inside this module, not the caller's tensor.
        tokens += self.en_pos_embedding
        return self.transformer(tokens)
class ViTDecoder(nn.Module):
    """Decode a token sequence back into an image.

    Adds fixed positional embeddings, runs a Transformer, then folds the tokens back into
    pixels with a patch-wise transposed conv.

    Args:
        image_size: output size as ``(height, width)``, or an int for square images.
        patch_size: patch size as ``(height, width)``, or an int for square patches.
        dim: token width.
        depth, heads, mlp_dim, dim_head: forwarded to ``Transformer``.
        channels: number of output image channels.
    """

    def __init__(self, image_size: Union[Tuple[int, int], int], patch_size: Union[Tuple[int, int], int],
                 dim: int, depth: int, heads: int, mlp_dim: int, channels: int = 3, dim_head: int = 64) -> None:
        super().__init__()
        image_height, image_width = image_size if isinstance(image_size, tuple) \
            else (image_size, image_size)
        patch_height, patch_width = patch_size if isinstance(patch_size, tuple) \
            else (patch_size, patch_size)
        assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'
        # Fixed sin-cos table with one row per patch; stored as a frozen Parameter below.
        de_pos_embedding = get_2d_sincos_pos_embed(dim, (image_height // patch_height, image_width // patch_width))
        self.num_patches = (image_height // patch_height) * (image_width // patch_width)
        self.patch_dim = channels * patch_height * patch_width
        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim)
        self.de_pos_embedding = nn.Parameter(torch.from_numpy(de_pos_embedding).float().unsqueeze(0), requires_grad=False)
        self.to_pixel = nn.Sequential(
            # Fold the row-major token sequence back into a (b, dim, H/ph, W/pw) grid.
            Rearrange('b (h w) c -> b c h w', h=image_height // patch_height),
            # kernel == stride: each token expands into one non-overlapping pixel patch.
            nn.ConvTranspose2d(dim, channels, kernel_size=patch_size, stride=patch_size)
        )
        self.apply(init_weights)

    def forward(self, token: torch.FloatTensor) -> torch.FloatTensor:
        # assumes token is (batch, num_patches, dim) — required by the Rearrange pattern; TODO confirm at callers
        # Out-of-place add: an in-place `+=` would silently modify the caller's tensor and
        # raises for leaf tensors that require grad. The cast keeps the input's dtype.
        x = token + self.de_pos_embedding.to(token.dtype)
        x = self.transformer(x)
        x = self.to_pixel(x)
        return x

    def get_last_layer(self) -> nn.Parameter:
        """Return the weight of the final ConvTranspose2d, the last layer before pixel output."""
        return self.to_pixel[-1].weight