In [1]:
import os
import torch
import torch.nn as nn
from model import StyledGenerator, Discriminator, TextProcess, PixelNorm
from dataset import MultiResolutionDataset
from train import sample_data
from torch.utils.data import DataLoader
from torchvision import datasets, transforms, utils
from pytorch_pretrained_bert import BertTokenizer
from PIL import Image
from torchvision.transforms import functional as trans_fn
from torchvision import transforms
from functools import partial
import lmdb
from io import BytesIO

In [2]:
# Location of the bird dataset (presumably an LMDB directory, going by the name).
data_path = './data/birds_lmdb/'

# Preprocessing: random horizontal flip, conversion to a tensor, then in-place
# normalisation of the three channels with mean 0.5 and std 0.5.
transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True),
])

# max_length=18 is the same value handed to TextProcess when the models are built.
dataset = MultiResolutionDataset(data_path, transform, max_length=18)

# Positional arguments are assumed to be (batch_size, image_size), as in the
# sample_data definition at the end of this notebook, i.e. batches of 16 at
# resolution 16 — TODO confirm that train.sample_data has the same signature.
dataloader = sample_data(dataset, 16, 16)

In [3]:
# Pull a single batch from the loader; each batch unpacks into an image tensor
# and a caption tensor (both are moved to the GPU in a later cell).
data_iter = iter(dataloader)
image, caption = next(data_iter)

In [6]:
# Each network is wrapped in nn.DataParallel and then moved to the GPU, so the
# underlying model is reachable as `<name>.module` (used when loading weights).
text_process = nn.DataParallel(
    TextProcess(max_length=18, embedding_dim=128, condition_dim=128)
).cuda()
generator = nn.DataParallel(StyledGenerator(512)).cuda()
discriminator = nn.DataParallel(Discriminator(from_rgb_activate=True)).cuda()

In [7]:
# Restore the three networks from a saved training checkpoint.
#
# map_location='cpu' deserialises every tensor on the CPU regardless of the
# device it was saved from; load_state_dict then copies the values into the
# parameters that already live on the GPU. Without it, loading depends on the
# saving device being available in this session.
#
# NOTE(review): torch.load unpickles the file — only load checkpoints from a
# trusted source.
ckpt = torch.load(
    './output/birds_1115_sentnorm/checkpoint/train_step-2.model',
    map_location='cpu',
)

# The state dicts are loaded into `.module`, i.e. the model inside each
# nn.DataParallel wrapper. The last call is left as the final expression so
# its result is still displayed as the cell output.
text_process.module.load_state_dict(ckpt['text_process'])
generator.module.load_state_dict(ckpt['generator'])
discriminator.module.load_state_dict(ckpt['discriminator'])

<All keys matched successfully>

In [8]:
# Move the sampled batch onto the GPU.
image = image.cuda()
caption = caption.cuda()
# Run the captions through the text network. It returns five values; their
# shapes and exact meaning are defined in model.TextProcess (not shown here).
c_code, sent_emb, words_embs, mu, log_var = text_process(caption)

In [9]:
# Draw one (2, 16, 512) standard-normal tensor on the GPU and split it along
# dim 0 into two (16, 512) latent batches. unbind(0) removes the leading axis
# directly, so no separate squeeze is needed.
gen_in1, gen_in2 = torch.randn(2, 16, 512, device='cuda').unbind(0)

In [10]:
# Generate a batch of images from the first latent batch and c_code.
# step=2 matches the 'train_step-2' checkpoint file name; presumably `step`
# selects the resolution stage and `alpha` the fade-in weight — TODO confirm
# against model.StyledGenerator.
fake_image = generator(gen_in1, c_code, step=2, alpha=1)

In [11]:
# Score the generated images together with sent_emb. The discriminator returns
# two outputs, bound here as a plain prediction and a "cond" prediction
# (presumably the text-conditioned score — confirm in model.Discriminator).
fake_predict, cond_fake_predict = discriminator(fake_image, sent_emb, step=2, alpha=1)

In [12]:
# Recorded output: fake_predict stays within roughly +/-0.05, whereas
# cond_fake_predict lies between about 215 and 310 for every sample.
print(fake_predict, cond_fake_predict)

tensor([[ 0.0107],
        [ 0.0078],
        [ 0.0295],
        [ 0.0340],
        [ 0.0154],
        [ 0.0197],
        [ 0.0303],
        [ 0.0285],
        [ 0.0141],
        [-0.0121],
        [ 0.0001],
        [ 0.0235],
        [-0.0486],
        [ 0.0229],
        [-0.0068],
        [ 0.0140]], device='cuda:0', grad_fn=<AddmmBackward>) tensor([[294.3565],
        [295.3235],
        [275.7867],
        [262.7184],
        [310.3739],
        [287.1729],
        [302.3688],
        [279.0579],
        [237.2812],
        [263.8758],
        [283.0998],
        [287.5467],
        [214.9000],
        [286.3645],
        [302.3677],
        [302.2542]], device='cuda:0', grad_fn=<AddmmBackward>)


In [13]:
# Same discriminator call as for the generated images, but on the real batch
# (with the same sent_emb), so the two sets of scores can be compared.
real_predict, cond_real_predict = discriminator(image, sent_emb, step=2, alpha=1)

In [14]:
# Recorded output: real_predict lies between about -0.07 and 0.17, whereas
# cond_real_predict swings from about -506 to +252 across the batch.
print(real_predict, cond_real_predict)

tensor([[ 0.0470],
        [ 0.0384],
        [-0.0314],
        [ 0.1020],
        [ 0.1676],
        [ 0.0348],
        [ 0.1095],
        [ 0.0342],
        [-0.0216],
        [ 0.0590],
        [ 0.0539],
        [ 0.0271],
        [ 0.0904],
        [-0.0711],
        [ 0.0158],
        [ 0.0355]], device='cuda:0', grad_fn=<AddmmBackward>) tensor([[-123.4869],
        [ 251.5795],
        [-125.8931],
        [-272.5009],
        [  51.3055],
        [ 208.0517],
        [-506.0524],
        [ 210.1932],
        [-267.4745],
        [ 195.9772],
        [ 168.3169],
        [ 169.7793],
        [-118.1396],
        [-403.6581],
        [-365.6588],
        [ 241.0832]], device='cuda:0', grad_fn=<AddmmBackward>)


In [20]:
# Tile each sentence embedding over a 4x4 spatial grid and show the resulting
# shape: (B, ...) -> (B, D, 1, 1) -> (B, D, 4, 4).
# The batch dimension is read from the tensor instead of being hard-coded as
# 16, so a different batch size can no longer be silently folded into the
# feature dimension by view().
sent_emb.view(sent_emb.size(0), -1, 1, 1).repeat(1, 1, 4, 4).shape

torch.Size([16, 128, 4, 4])

In [34]:
# Show the shape of the sentence embedding. The original cell ended in a
# dangling `.`, which is a SyntaxError; the recorded output is a torch.Size,
# so `.shape` is the attribute that was displayed.
sent_emb.shape

torch.Size([16, 256])

In [32]:
# Dump every parameter of the discriminator's cond_linear layer.
# Recorded output: a 1-element parameter followed by a 2-D parameter whose
# entries reach magnitudes of about 11 (the saved output is truncated).
# NOTE(review): these large values may be related to the large "cond" scores
# printed earlier — worth confirming.
for param in discriminator.module.cond_linear.parameters():
    print(param)

Parameter containing:
tensor([-0.3497], device='cuda:0', requires_grad=True)
Parameter containing:
tensor([[ 4.2548e+00, -4.2975e+00,  4.2328e+00, -8.6010e+00, -3.6583e+00,
         -4.7534e+00,  6.1431e+00, -4.5836e+00,  4.9216e+00, -8.6113e+00,
          3.7165e+00,  3.0080e+00,  8.2881e+00,  3.2673e+00, -4.0672e+00,
         -3.9369e+00,  3.9243e+00, -4.5492e+00,  3.4317e+00,  3.3498e+00,
         -2.9930e-01, -1.0317e+01,  5.7217e+00, -5.5745e+00, -8.3253e+00,
          5.2831e+00,  6.6102e+00,  3.7642e+00,  6.8925e+00,  3.9787e+00,
          6.8428e+00, -9.8916e+00,  5.7453e+00, -4.2589e+00, -5.4853e+00,
          2.8553e+00,  7.1142e-01, -5.0571e+00,  7.0481e+00, -1.7355e+00,
         -5.6823e+00, -9.2463e+00,  6.6507e+00,  6.0433e+00, -6.2733e+00,
          6.3619e+00,  7.3124e+00, -8.3420e+00,  7.0957e+00, -1.1324e+01,
          4.5173e+00, -6.4903e+00, -4.1134e+00,  5.7684e+00,  3.6925e+00,
         -1.0550e+01, -1.0216e+01,  4.3370e+00,  7.2894e+00,  6.6203e+00,
         -3.4

In [54]:
# Read one image record and one caption record from the database.
# NOTE(review): `env` is not defined in any cell shown here — presumably an
# lmdb environment opened in a since-removed cell (lmdb is imported at the
# top); it must exist before this cell can run.
with env.begin(write=False) as txn:
    # Keys evaluate to b'4-00000' and b'txt-1-00000': a prefix followed by a
    # zero-padded 5-digit index (index 0 here).
    img_key = f'4-{str(0).zfill(5)}'.encode('utf-8')
    txt_key = f'txt-1-{str(0).zfill(5)}'.encode('utf-8')
    # Caption: stored as UTF-8 bytes.
    # NOTE(review): if txn.get returns None for a missing key, .decode will
    # raise AttributeError — confirm the key exists.
    txt_bytes = txn.get(txt_key)
    txt = txt_bytes.decode('utf-8')
    # Image: raw encoded bytes, opened with PIL through an in-memory buffer.
    img_bytes = txn.get(img_key)
    buffer = BytesIO(img_bytes)
    img = Image.open(buffer)

In [2]:
# Build a BERT tokenizer from the 'bert-base-uncased' pretrained name.
# NOTE(review): `tokenizer` is not used by any other cell shown in this notebook.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

In [187]:
# Instantiate TextProcess with its default constructor arguments.
# NOTE(review): this rebinds `text_process`, replacing the nn.DataParallel /
# CUDA instance (max_length=18, embedding_dim=128, condition_dim=128) built in
# the model-construction cell; cells that call text_process afterwards will see
# this bare, un-wrapped module instead.
text_process = TextProcess()

In [14]:
# data loader test
# Same preprocessing pipeline as the main dataset cell: random horizontal flip,
# tensor conversion, in-place normalisation with mean 0.5 / std 0.5 per channel.
# NOTE: this cell rebinds `transform`, `dataset` and `data_iter`.
transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True),
])

# No max_length is passed here, so this relies on whatever default
# MultiResolutionDataset declares for it.
dataset = MultiResolutionDataset('./birds/', transform)

# Going by the sample_data(dataset, batch_size, image_size) definition in this
# notebook: batches of 16 at resolution 8.
loader = sample_data(dataset, 16, 8)
data_iter = iter(loader)

In [55]:
# NOTE(review): this definition shadows the `sample_data` imported from `train`
# at the top of the notebook once the cell has been executed.
def sample_data(dataset, batch_size, image_size=4, num_workers=4):
    """Build a shuffling DataLoader over ``dataset`` at a given resolution.

    Args:
        dataset: Dataset to draw batches from. Its ``resolution`` attribute is
            overwritten with ``image_size`` (the dataset is mutated in place).
        batch_size: Number of samples per batch.
        image_size: Value assigned to ``dataset.resolution``. Defaults to 4.
        num_workers: Number of DataLoader worker processes. Defaults to 4, the
            value that used to be hard-coded; pass 0 to load in the main
            process (useful for debugging).

    Returns:
        A ``DataLoader`` over ``dataset`` with ``shuffle=True``.
    """
    dataset.resolution = image_size
    return DataLoader(
        dataset,
        shuffle=True,
        batch_size=batch_size,
        num_workers=num_workers,
    )