In [1]:
import os
import torch
import torch.nn as nn
from model import StyledGenerator, Discriminator, TextProcess, PixelNorm
from dataset import MultiResolutionDataset
from train import sample_data
from torch.utils.data import DataLoader
from torchvision import datasets, transforms, utils
from pytorch_pretrained_bert import BertTokenizer
from PIL import Image
from torchvision.transforms import functional as trans_fn
from torchvision import transforms
from functools import partial
import lmdb
from io import BytesIO

In [5]:
# Per-sample preprocessing: random horizontal flip, then ToTensor and a per-channel
# Normalize with mean 0.5 / std 0.5 (maps ToTensor's [0, 1] range to [-1, 1]).
transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True),
])

# Caption tensors come out with 24 entries, zero-padded (see the `y` output below).
data_path = './data/birds_lmdb/'
dataset = MultiResolutionDataset(data_path, transform, max_length=24)

# Positional args: presumably batch_size=16, image_size=16, matching the
# signature of the local sample_data defined at the end of this notebook.
dataloader = sample_data(dataset, 16, 16)

In [11]:
# Index the dataset directly to get one sample; y is the caption token tensor
# (shown in the next cell), x is presumably the transformed image — confirm.
x,y = dataset[0]

In [13]:
# Caption of sample 0 as token ids. The output has 24 entries ending in zeros,
# matching max_length=24 above. 101/102 look like BERT's [CLS]/[SEP] ids (a
# bert-base-uncased tokenizer is loaded further down) — confirm.
y

tensor([  101,  4743,  2038,  2829,  2303, 12261,  1010,  2317,  7388, 12261,
         1998,  2304, 23525,   102,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0])

In [7]:
# Draw one (image, caption) batch for the forward-pass checks below.
data_iter = iter(dataloader)
image, caption = next(data_iter)

In [4]:
# Build the three networks, wrap each in DataParallel and move it to the GPU.
# The checkpoint cell below loads weights into `.module` (the unwrapped network).
# NOTE(review): TextProcess is built with max_length=18, but `dataset` yields
# 24-token captions (max_length=24). Confirm which length the checkpoint was
# trained with and that TextProcess handles the mismatch.
text_process = nn.DataParallel(
    TextProcess(max_length=18, embedding_dim=128, condition_dim=128)
).cuda()
generator = nn.DataParallel(StyledGenerator(512)).cuda()
discriminator = nn.DataParallel(Discriminator(from_rgb_activate=True)).cuda()

In [18]:
# Restore trained weights into the unwrapped networks (`.module` of each
# DataParallel wrapper).
# map_location='cpu' deserializes every tensor on the CPU instead of the device
# it was saved from (which may not exist here, and would otherwise hold a second
# copy in GPU memory). load_state_dict then copies the values onto each module's
# own device.
# torch.load unpickles the file — only load checkpoints from trusted sources.
ckpt_path = './output/birds_1115_text/checkpoint/train_step-4.model'
ckpt = torch.load(ckpt_path, map_location='cpu')

# With the default strict=True, load_state_dict raises RuntimeError on missing/
# unexpected keys or size mismatches. Only the last call's result is echoed.
text_process.module.load_state_dict(ckpt['text_process'])
generator.module.load_state_dict(ckpt['generator'])
discriminator.module.load_state_dict(ckpt['discriminator'])

<All keys matched successfully>

In [19]:
# Move the sampled batch to the GPU, then run the caption encoder.
# text_process returns five tensors. Judging by the names: a conditioning code,
# a sentence embedding, word embeddings, and mu/log_var (presumably the Gaussian
# parameters the code is sampled from) — confirm in model.TextProcess.
image, caption = image.cuda(), caption.cuda()
c_code, sent_emb, words_embs, mu, log_var = text_process(caption)

In [20]:
# Two independent batches of 512-d Gaussian noise, one row per caption.
# The batch size is read from `caption` rather than hard-coded as 16, so this
# cell stays consistent with the loader if its batch size changes. For a batch
# of 16 it draws exactly the same tensor as before.
# unbind(0) yields the two (batch, 512) slices directly (same as chunk + squeeze).
# Only gen_in1 is used below; gen_in2 is presumably a second latent for style
# mixing and is unused in this notebook.
noise_batch = caption.size(0)
gen_in1, gen_in2 = torch.randn(2, noise_batch, 512, device='cuda').unbind(0)

In [21]:
# Generate a batch of images from the first noise batch, conditioned on c_code.
# NOTE(review): step=2 / alpha=1 presumably selects the 16x16 stage at full blend,
# matching image_size=16 in the loader — confirm against StyledGenerator.
fake_image = generator(gen_in1, c_code, step=2, alpha=1)

In [22]:
# Score the generated batch with the same step/alpha as the generator call.
# Two outputs, one value per sample (see the output below). By their names: an
# unconditional score and a score conditioned on sent_emb.
fake_predict, cond_fake_predict = discriminator(fake_image, sent_emb, step=2, alpha=1)

In [23]:
# Unconditional scores first, conditional scores second.
print(fake_predict, cond_fake_predict)

tensor([[-12.3413],
        [-11.3501],
        [-10.0159],
        [ -4.0069],
        [ -4.0809],
        [-12.1158],
        [-15.8684],
        [-10.9860],
        [ -7.3656],
        [ -7.0174],
        [-11.1419],
        [ -8.3112],
        [ -7.6498],
        [-16.0772],
        [-10.2059],
        [-10.7565]], device='cuda:0', grad_fn=<AddmmBackward>) tensor([[  2.7255],
        [ -5.6749],
        [  6.7460],
        [  2.1424],
        [ -3.1646],
        [-11.3415],
        [-13.9169],
        [  1.9149],
        [  5.4847],
        [-14.3393],
        [ -3.0558],
        [ -2.6986],
        [  5.4023],
        [-37.3936],
        [  4.0854],
        [ -0.8057]], device='cuda:0', grad_fn=<AddmmBackward>)


In [24]:
# Same check on the real batch, paired with the sent_emb of its own captions.
real_predict, cond_real_predict = discriminator(image, sent_emb, step=2, alpha=1)

In [25]:
# Compare with the fake-batch scores above: unconditional first, conditional second.
print(real_predict, cond_real_predict)

tensor([[-6.9548],
        [-7.8173],
        [-7.3097],
        [-6.0937],
        [-6.8984],
        [-3.6984],
        [-7.0426],
        [-5.4220],
        [-7.0415],
        [-7.2569],
        [-6.4480],
        [-7.3371],
        [-6.2449],
        [-6.5143],
        [-5.1907],
        [-3.3722]], device='cuda:0', grad_fn=<AddmmBackward>) tensor([[ 3.5470],
        [ 0.1580],
        [ 3.5053],
        [ 6.1817],
        [ 4.9314],
        [ 5.2373],
        [ 0.9385],
        [-0.9460],
        [ 5.5512],
        [-0.1245],
        [ 4.8126],
        [ 5.8424],
        [ 7.3368],
        [ 3.5590],
        [ 7.4234],
        [-0.0569]], device='cuda:0', grad_fn=<AddmmBackward>)


In [26]:
# Inspect the conditional head's parameters. The output shows a single-element
# parameter followed by a 2-D weight tensor.
list(discriminator.module.cond_linear.parameters())

[Parameter containing:
 tensor([0.0028], device='cuda:0', requires_grad=True), Parameter containing:
 tensor([[ 4.6028e-03,  2.0716e-03, -7.0309e-02, -1.8585e+00,  1.3172e-02,
           3.4134e-03, -2.0789e-03,  1.4790e-01,  1.3579e-03,  4.1447e-03,
          -8.5507e-02,  3.0530e-03, -7.8137e-04,  3.9972e-01, -1.9104e-03,
           2.1885e-01,  2.5385e-03, -2.5111e-03, -1.0920e-03, -3.6429e-03,
          -7.4963e-03, -7.9946e-04, -7.1822e-03,  3.9212e-04, -5.8045e-03,
          -1.1394e-03, -1.0435e-02,  1.8483e-01,  2.3120e-01, -2.1802e-03,
          -5.2651e-03, -5.5230e-03, -3.2060e-03, -7.0342e-03, -1.4212e-03,
           1.7932e-03,  6.3728e-02, -6.4072e-04, -3.2959e-03, -7.5518e-02,
           8.2140e-01, -4.9785e-03,  6.3329e-03, -1.1312e-02, -1.2374e-02,
           3.8444e-04, -4.0006e-01, -7.8164e-03, -1.6973e-02, -7.3576e-03,
          -4.7296e-03, -1.3644e-02, -1.1120e-02, -1.4048e-04, -4.3126e-03,
          -5.4504e-03,  2.0524e-01,  1.7369e-03,  6.3939e-03, -1.0401e-03,

In [27]:
# Same inspection for the unconditional head (compare weight magnitudes with
# cond_linear above).
list(discriminator.module.uncond_linear.parameters())

[Parameter containing:
 tensor([0.], device='cuda:0', requires_grad=True), Parameter containing:
 tensor([[ 0.5347,  1.0063,  0.8361, -0.3474, -0.7262,  1.0974, -0.4119, -0.0503,
           0.0978,  0.1145, -0.2498,  0.0968, -1.6943,  0.6792,  0.3447, -0.1313,
           2.3334, -1.2555,  0.8176, -1.5670, -1.4045, -0.6904, -0.6859, -1.5966,
          -0.8214,  1.0558,  0.6464, -1.1312, -0.3776,  0.2721,  0.0648, -1.3137,
          -0.6502,  0.0554, -0.2081,  0.1633,  1.2382,  1.8609, -1.9676, -1.1435,
           1.5995,  0.1522,  0.3746, -0.5345, -0.2971, -0.3713,  0.3923, -0.8274,
          -2.1498,  0.0765, -1.7068, -1.4490, -1.2354, -0.6065,  0.7193, -0.0184,
           2.1649,  0.4736,  1.8154,  1.0189, -0.0141,  1.4525,  0.6307,  1.4188,
          -0.0201,  1.5840, -0.1345,  0.5022,  0.9447,  1.1047, -0.4026,  1.0160,
           0.3055, -0.6138,  0.1949, -0.1231, -2.2208, -0.5889, -0.9105, -0.3540,
          -1.0980,  0.9768,  0.0883, -0.7209,  0.8490,  1.9694,  0.3660, -1.4909,
 

In [32]:
# Print each cond_linear parameter with its own print call (full repr each).
# NOTE(review): these values differ from the In[26] listing of the same
# parameters, so the model state changed between those executions in a way not
# reproducible from the cells shown.
for cond_param in discriminator.module.cond_linear.parameters():
    print(cond_param)

Parameter containing:
tensor([-0.3497], device='cuda:0', requires_grad=True)
Parameter containing:
tensor([[ 4.2548e+00, -4.2975e+00,  4.2328e+00, -8.6010e+00, -3.6583e+00,
         -4.7534e+00,  6.1431e+00, -4.5836e+00,  4.9216e+00, -8.6113e+00,
          3.7165e+00,  3.0080e+00,  8.2881e+00,  3.2673e+00, -4.0672e+00,
         -3.9369e+00,  3.9243e+00, -4.5492e+00,  3.4317e+00,  3.3498e+00,
         -2.9930e-01, -1.0317e+01,  5.7217e+00, -5.5745e+00, -8.3253e+00,
          5.2831e+00,  6.6102e+00,  3.7642e+00,  6.8925e+00,  3.9787e+00,
          6.8428e+00, -9.8916e+00,  5.7453e+00, -4.2589e+00, -5.4853e+00,
          2.8553e+00,  7.1142e-01, -5.0571e+00,  7.0481e+00, -1.7355e+00,
         -5.6823e+00, -9.2463e+00,  6.6507e+00,  6.0433e+00, -6.2733e+00,
          6.3619e+00,  7.3124e+00, -8.3420e+00,  7.0957e+00, -1.1324e+01,
          4.5173e+00, -6.4903e+00, -4.1134e+00,  5.7684e+00,  3.6925e+00,
         -1.0550e+01, -1.0216e+01,  4.3370e+00,  7.2894e+00,  6.6203e+00,
         -3.4

In [54]:
# Read one raw (image, caption) record straight from the LMDB store.
# `env` was never opened anywhere in this notebook, so this cell raised
# NameError on a fresh kernel. It is now opened here, read-only.
# lock=False avoids clashing with the lock of another handle on the same store
# in this process (e.g. the one `dataset` may hold).
# NOTE(review): assumes the store to inspect is the one at `data_path` — confirm.
record_idx = 0
env = lmdb.open(data_path, readonly=True, lock=False, readahead=False, meminit=False)
try:
    with env.begin(write=False) as txn:
        # NOTE(review): the '4-' image prefix presumably selects the resolution-4
        # copy, and the meaning of '1' in 'txt-1-' is unclear — confirm against
        # the script that wrote the store.
        img_key = f'4-{record_idx:05d}'.encode('utf-8')
        txt_key = f'txt-1-{record_idx:05d}'.encode('utf-8')
        # Default (non-buffer) transactions return bytes copies, which stay
        # valid after the environment is closed.
        txt_bytes = txn.get(txt_key)
        img_bytes = txn.get(img_key)
finally:
    env.close()

# txn.get returns None for a missing key; fail with a clear message instead of
# an AttributeError on None.
if txt_bytes is None or img_bytes is None:
    raise KeyError(f'missing LMDB record: {img_key!r} / {txt_key!r}')
txt = txt_bytes.decode('utf-8')
buffer = BytesIO(img_bytes)
img = Image.open(buffer)

In [2]:
# Pretrained bert-base-uncased tokenizer (not referenced by the other cells shown).
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

In [187]:
# A freshly initialised TextProcess with default arguments, presumably for
# inspecting the module in isolation. It is bound to its own name because
# re-binding `text_process` here silently replaced the checkpoint-loaded,
# DataParallel-wrapped encoder used by the cells above (re-running them after
# this cell would have used untrained weights).
text_process_default = TextProcess()

In [14]:
# Data-loader smoke test against the './birds/' store at image size 8.
# It reuses the `transform` pipeline from the top of the notebook; the original
# cell re-declared an identical copy.
# It uses its own names: the original re-bound `transform`, `dataset` and
# `data_iter`, so re-running `x,y = dataset[0]` or `next(data_iter)` afterwards
# silently read from this dataset/loader instead.
# NOTE(review): max_length is left at MultiResolutionDataset's default, whereas
# the dataset above passes max_length=24 — confirm this is intended.
test_dataset = MultiResolutionDataset('./birds/', transform)
test_loader = sample_data(test_dataset, 16, 8)
test_data_iter = iter(test_loader)

In [55]:
def sample_data(dataset, batch_size, image_size=4, shuffle=True, num_workers=4):
    """Set the dataset's working resolution and wrap it in a DataLoader.

    NOTE: this definition shadows ``sample_data`` imported from ``train`` in the
    first cell. Cells executed after this one use this version.

    Args:
        dataset: dataset exposing a writable ``resolution`` attribute
            (e.g. MultiResolutionDataset). It is mutated in place.
        batch_size: samples per batch.
        image_size: value assigned to ``dataset.resolution``.
        shuffle: reshuffle the data each epoch (default True, as before).
        num_workers: DataLoader worker processes (default 4, as before).

    Returns:
        A ``torch.utils.data.DataLoader`` over ``dataset``.
    """
    # Set before iterating: each worker gets its own copy of the dataset when
    # an iterator is created, so later changes would not reach running workers.
    dataset.resolution = image_size
    return DataLoader(
        dataset, shuffle=shuffle, batch_size=batch_size, num_workers=num_workers
    )