Update pixelsnail for sampling
rosinality committed Jun 26, 2019
1 parent 3d04afe commit bd0ba96
Showing 3 changed files with 121 additions and 37 deletions.
51 changes: 41 additions & 10 deletions pixelsnail.py
@@ -41,6 +41,13 @@ def __init__(
bias=bias,
)
)

self.out_channel = out_channel

if isinstance(kernel_size, int):
kernel_size = [kernel_size, kernel_size]

self.kernel_size = kernel_size

self.activation = activation

@@ -75,6 +82,8 @@ def __init__(

if isinstance(kernel_size, int):
kernel_size = [kernel_size] * 2

self.kernel_size = kernel_size

if padding == 'downright':
pad = [kernel_size[1] - 1, 0, kernel_size[0] - 1, 0]
@@ -142,7 +151,7 @@ def __init__(
self.dropout = nn.Dropout(dropout)

self.conv2 = conv_module(channel, in_channel * 2, kernel_size)

if condition_dim > 0:
# self.condition = nn.Linear(condition_dim, in_channel * 2, bias=False)
self.condition = WNConv2d(condition_dim, in_channel * 2, 1, bias=False)
@@ -212,7 +221,7 @@ def reshape(input):
mask, start_mask = causal_mask(height * width)
mask = mask.type_as(query)
start_mask = start_mask.type_as(query)
-        attn = attn.masked_fill(mask == 0, -1e9)
+        attn = attn.masked_fill(mask == 0, -1e4)
attn = torch.softmax(attn, 3) * start_mask
attn = self.dropout(attn)

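The fill value for masked attention positions changes from -1e9 to -1e4. A plausible motivation (the commit itself does not state one) is half-precision safety: -1e9 overflows the fp16 range and becomes -inf, while -1e4 is representable and still small enough that masked positions get essentially zero weight after the softmax. A minimal sketch, assuming fp16 is the concern:

    import torch

    # -1e9 overflows the half-precision range (max magnitude ~65504); -1e4 does not.
    print(torch.tensor(-1e9, dtype=torch.float16))  # tensor(-inf, dtype=torch.float16)
    print(torch.tensor(-1e4, dtype=torch.float16))  # tensor(-10000., dtype=torch.float16)

    scores = torch.zeros(1, 1, 3, 3)
    mask = torch.tril(torch.ones(3, 3))
    attn = scores.masked_fill(mask == 0, -1e4)  # finite fill value
    print(torch.softmax(attn, 3))               # masked entries receive ~0 probability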
@@ -329,6 +338,7 @@ def __init__(
n_cond_res_block=0,
cond_res_channel=0,
cond_res_kernel=3,
n_out_res_block=0,
):
super().__init__()

@@ -375,26 +385,47 @@ def __init__(
n_class, cond_res_channel, cond_res_kernel, n_cond_res_block
)

-        self.out = nn.Sequential(nn.ELU(inplace=True), WNConv2d(channel, n_class, 1))
+        out = []
+
+        for i in range(n_out_res_block):
+            out.append(GatedResBlock(channel, res_channel, 1))
+
+        out.extend([nn.ELU(inplace=True), WNConv2d(channel, n_class, 1)])
+
+        self.out = nn.Sequential(*out)

-    def forward(self, input, condition=None):
+    def forward(self, input, condition=None, cache=None):
+        if cache is None:
+            cache = {}
+
         batch, height, width = input.shape
-        input = F.one_hot(input, self.n_class).permute(0, 3, 1, 2).float()
+        input = (
+            F.one_hot(input, self.n_class).permute(0, 3, 1, 2).type_as(self.background)
+        )
horizontal = shift_down(self.horizontal(input))
vertical = shift_right(self.vertical(input))
out = horizontal + vertical

background = self.background[:, :, :height, :].expand(batch, 2, height, width)

         if condition is not None:
-            condition = F.one_hot(condition, self.n_class).permute(0, 3, 1, 2).float()
-            condition = self.cond_resnet(condition)
-            condition = F.interpolate(condition, scale_factor=2)
-            condition = condition[:, :, :height, :]
+            if 'condition' in cache:
+                condition = cache['condition']
+                condition = condition[:, :, :height, :]
+
+            else:
+                condition = (
+                    F.one_hot(condition, self.n_class)
+                    .permute(0, 3, 1, 2)
+                    .type_as(self.background)
+                )
+                condition = self.cond_resnet(condition)
+                condition = F.interpolate(condition, scale_factor=2)
+                cache['condition'] = condition.detach().clone()
+                condition = condition[:, :, :height, :]

for block in self.blocks:
out = block(out, background, condition=condition)

out = self.out(out)

-        return out
+        return out, cache
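The cache dictionary threaded through forward() exists for autoregressive sampling: the features produced by cond_resnet depend only on the condition codes, not on the rows generated so far, so they are computed once, stored as a detached clone, and only re-sliced to the current height on later calls. A sketch of the intended calling pattern (variable names are illustrative, not from this file):

    cache = {}
    out, cache = model(codes, condition=top_codes, cache=cache)  # first call: runs cond_resnet, fills the cache
    out, cache = model(codes, condition=top_codes, cache=cache)  # later calls: reuses cache['condition']
    cache = {}  # reset before sampling with a new condition, or stale features would be reused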
66 changes: 46 additions & 20 deletions sample.py
100644 → 100755
@@ -12,10 +12,11 @@
@torch.no_grad()
def sample_model(model, device, batch, size, temperature, condition=None):
row = torch.zeros(batch, *size, dtype=torch.int64).to(device)
cache = {}

for i in tqdm(range(size[0])):
for j in range(size[1]):
-            out = model(row[:, : i + 1, :], condition=condition)
+            out, cache = model(row[:, : i + 1, :], condition=condition, cache=cache)
prob = torch.softmax(out[:, :, i, j] / temperature, 1)
sample = torch.multinomial(prob, 1).squeeze(-1)
row[:, i, j] = sample
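Each position is drawn from a temperature-scaled categorical distribution over the code classes. A worked sketch of the three sampling lines above, with illustrative shapes:

    import torch

    logits = torch.randn(8, 512)                     # out[:, :, i, j]: batch of 8, 512 classes
    temperature = 1.0                                # <1 sharpens, >1 flattens the distribution
    prob = torch.softmax(logits / temperature, 1)
    sample = torch.multinomial(prob, 1).squeeze(-1)  # shape [8], dtype torch.int64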
@@ -24,7 +25,47 @@ def sample_model(model, device, batch, size, temperature, condition=None):


def load_model(model, checkpoint, device):
-    model.load_state_dict(torch.load(os.path.join('checkpoint', checkpoint)))
+    ckpt = torch.load(os.path.join('checkpoint', checkpoint))

if 'args' in ckpt:
args = ckpt['args']

if model == 'vqvae':
model = VQVAE()

elif model == 'pixelsnail_top':
model = PixelSNAIL(
[32, 32],
512,
args.channel,
5,
4,
args.n_res_block,
args.n_res_channel,
dropout=args.dropout,
n_out_res_block=args.n_out_res_block,
)

elif model == 'pixelsnail_bottom':
model = PixelSNAIL(
[64, 64],
512,
args.channel,
5,
4,
args.n_res_block,
args.n_res_channel,
attention=False,
dropout=args.dropout,
n_cond_res_block=args.n_cond_res_block,
cond_res_channel=args.n_res_channel,
)

if 'model' in ckpt:
ckpt = ckpt['model']

model.load_state_dict(ckpt)
model = model.to(device)
model.eval()
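
load_model() now takes a model name string instead of a constructed module and rebuilds the architecture from the hyper-parameters stored in the checkpoint before loading the weights. A sketch of the new checkpoint layout written by train_pixelsnail.py below (file name illustrative):

    import torch

    ckpt = torch.load('checkpoint/pixelsnail_top_420.pt')
    print(sorted(ckpt.keys()))       # ['args', 'model'] in the new format
    print(ckpt['args'].n_res_block)  # hyper-parameters used to rebuild the network
    state_dict = ckpt['model']       # weights passed to model.load_state_dict()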

@@ -44,24 +85,9 @@ def load_model(model, checkpoint, device):

args = parser.parse_args()

-    model_vqvae = VQVAE()
-    model_top = PixelSNAIL([32, 32], 512, 256, 5, 4, 4, 256)
-    model_bottom = PixelSNAIL(
-        [64, 64],
-        512,
-        256,
-        5,
-        3,
-        4,
-        256,
-        attention=False,
-        n_cond_res_block=3,
-        cond_res_channel=256,
-    )
-
-    model_vqvae = load_model(model_vqvae, args.vqvae, device)
-    model_top = load_model(model_top, args.top, device)
-    model_bottom = load_model(model_bottom, args.bottom, device)
+    model_vqvae = load_model('vqvae', args.vqvae, device)
+    model_top = load_model('pixelsnail_top', args.top, device)
+    model_bottom = load_model('pixelsnail_bottom', args.bottom, device)

top_sample = sample_model(model_top, device, args.batch, [32, 32], args.temp)
bottom_sample = sample_model(
Expand Down
41 changes: 34 additions & 7 deletions train_pixelsnail.py
@@ -67,7 +67,14 @@ def __call__(self, input):
parser.add_argument('--epoch', type=int, default=420)
parser.add_argument('--hier', type=str, default='top')
parser.add_argument('--lr', type=float, default=3e-4)
parser.add_argument('--channel', type=int, default=256)
parser.add_argument('--n_res_block', type=int, default=4)
parser.add_argument('--n_res_channel', type=int, default=256)
parser.add_argument('--n_out_res_block', type=int, default=0)
parser.add_argument('--n_cond_res_block', type=int, default=3)
parser.add_argument('--dropout', type=float, default=0.1)
parser.add_argument('--sched', type=str)
parser.add_argument('--ckpt', type=str)
parser.add_argument('path', type=str)

args = parser.parse_args()
@@ -78,23 +85,43 @@ def __call__(self, input):

dataset = LMDBDataset(args.path)
loader = DataLoader(dataset, batch_size=args.batch, shuffle=True, num_workers=4)

ckpt = {}

if args.ckpt is not None:
ckpt = torch.load(args.ckpt)
args = ckpt['args']

if args.hier == 'top':
-    model = PixelSNAIL([32, 32], 512, 256, 5, 4, 4, 256)
+    model = PixelSNAIL(
+        [32, 32],
+        512,
+        args.channel,
+        5,
+        4,
+        args.n_res_block,
+        args.n_res_channel,
+        dropout=args.dropout,
+        n_out_res_block=args.n_out_res_block,
+    )

elif args.hier == 'bottom':
model = PixelSNAIL(
[64, 64],
512,
-        256,
+        args.channel,
         5,
-        3,
         4,
-        256,
+        args.n_res_block,
+        args.n_res_channel,
         attention=False,
-        n_cond_res_block=3,
-        cond_res_channel=256,
+        dropout=args.dropout,
+        n_cond_res_block=args.n_cond_res_block,
+        cond_res_channel=args.n_res_channel,
)

if 'model' in ckpt:
model.load_state_dict(ckpt['model'])

model = nn.DataParallel(model)
model = model.to(device)
@@ -109,6 +136,6 @@ def __call__(self, input):
for i in range(args.epoch):
train(args, i, loader, model, optimizer, scheduler, device)
torch.save(
-        model.module.state_dict(),
+        {'model': model.module.state_dict(), 'args': args},
f'checkpoint/pixelsnail_{args.hier}_{str(i + 1).zfill(3)}.pt',
)
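
Because the checkpoint now bundles the argparse namespace with the weights, resuming with --ckpt restores the saved hyper-parameters wholesale (args = ckpt['args'] in the resume branch above), so architecture flags passed on the command line alongside --ckpt are overridden by the stored values. A sketch of the round trip (file name illustrative):

    import torch

    # What the save call above writes once per epoch:
    torch.save(
        {'model': model.module.state_dict(), 'args': args},
        'checkpoint/pixelsnail_top_001.pt',
    )

    # What the resume branch reads back:
    ckpt = torch.load('checkpoint/pixelsnail_top_001.pt')
    args = ckpt['args']  # stored hyper-parameters win over fresh CLI flags
    model.load_state_dict(ckpt['model'])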
