In [1]:
import os
from tqdm import tnrange, tqdm_notebook, tqdm
import torch
import torchvision
from torch import nn, autograd, optim
from torch.utils.data import DataLoader
from torchvision import transforms, datasets
from torchvision.utils import save_image, make_grid
from torch.autograd import grad
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
from tensorboardX import SummaryWriter

In [2]:
# os.environ['CUDA_VISIBLE_DEVICES'] = '1'

In [3]:
from matplotlib import rcParams
rcParams['figure.figsize'] = (12, 8)

%matplotlib inline

In [4]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [5]:
torch.cuda.is_available()

True

In [6]:
device_ids = [0]

In [7]:
BATCH_SIZE = 128
num_epochs = 100

z_dimension = 100
num_feature = (32, 4, 4)

img_shape = (1, 32, 32)

In [8]:
img_transform = transforms.Compose([
    transforms.Resize(32),
    transforms.ToTensor(),
    transforms.Normalize((.5, .5, .5), (.5, .5, .5))
])

dataset = datasets.MNIST('/home/left5/datas/mnist', transform=img_transform) #, download=True)
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=4)

In [9]:
def normal_init(m, mean, std):
    if isinstance(m, nn.ConvTranspose2d) or isinstance(m, nn.Conv2d):
        m.weight.data.normal_(mean, std)
        m.bias.data.zero_()

In [10]:
def calc_gradient_penalty(netD, real_data, fake_data, real_labels, fake_labels):
    _alpha = torch.rand(1)
    alpha = _alpha.expand_as(real_data).to(device)
    
    alpha_label = _alpha.expand_as(real_labels).to(device)
    
#     alpha_labs = torch.rand(1)
#     alpha_labs = alpha_labs.expand_as(real_labels)
#     alpha_labs = alpha_labs.to(device)
    
    interpolates = alpha * real_data.detach() + ((1 - alpha) * fake_data.detach())
#     interpolates_labes = alpha_labs * real_labels.float().detach() + ((1 - alpha_labs) * fake_labels.float().detach())
    interpolate_labels = alpha_label * real_labels.detach() + ((1 - alpha_label) * fake_labels.detach())

    interpolates = interpolates.to(device)
    interpolate_labels = interpolate_labels.to(device)
    interpolates.requires_grad_(True)

    disc_interpolates, _ = netD(interpolates, interpolate_labels)

    gradients = autograd.grad(outputs=disc_interpolates, inputs=interpolates,
                              grad_outputs=torch.ones(disc_interpolates.size()).to(device),
                              create_graph=True, retain_graph=True, only_inputs=True)[0]

    gradients = gradients.view(gradients.size(0), -1)                              
    gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean() * 10
    return gradient_penalty

In [11]:
def gen_noise_label(batch_size):
    label = np.random.randint(0, 10, batch_size)
    #prefix = np.zeros((batch_size, 10))
    #prefix[:, label] = 1
    return label # prefix

In [12]:
def gen_noise(batch_size, label):
    prefix = np.zeros((batch_size, 10))
    prefix[np.arange(batch_size), label] = 1
    z = np.random.normal(0, 1, (batch_size, z_dimension))
    prefix = prefix / np.linalg.norm(prefix)
    z[:, :10] = prefix
    return torch.from_numpy(z).float().to(device)

In [13]:
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        
#         self.label_embedding = nn.Embedding(10, np.prod(img_shape))
#         self.label_fc = nn.Linear(10, np.prod(img_shape))
        self.label_fc = nn.Sequential(
            nn.Linear(10, np.prod(img_shape)),
#             nn.BatchNorm1d(num_features=np.prod(img_shape)),
        )
        
        self.conv1 = nn.Sequential(
            nn.Conv2d(2, 16, 3, padding=1, stride=2),
#             nn.BatchNorm2d(16),
            nn.LeakyReLU(.2, True),
#             nn.AvgPool2d(2, 2), 
        ) # b 16 16 16
        
        self.conv2 = nn.Sequential(
            nn.Conv2d(16, 32, 3, padding=1, stride=2),
#             nn.BatchNorm2d(32),
            nn.LeakyReLU(.2, True),
#             nn.AvgPool2d(2, 2), 
        ) # b 32 8 8
        
        self.conv3 = nn.Sequential(
            nn.Conv2d(32, 64, 3, padding=1, stride=2),
#             nn.BatchNorm2d(64),
            nn.LeakyReLU(.2, True),
#             nn.AvgPool2d(2, 2), 
        ) # b 64 4 4
        
        self.conv4 = nn.Sequential(
            nn.Conv2d(64, 128, 3, padding=1, stride=2),
#             nn.BatchNorm2d(128),
            nn.LeakyReLU(.2, True),
#             nn.AvgPool2d(2, 2), 
        ) # b 128 2 2
        
        self.fc = nn.Sequential(
            nn.Linear(128 * 2 * 2, 1024),
            nn.LeakyReLU(.2, True),
            nn.Linear(1024, 1),
            nn.Sigmoid(),
#             nn.Tanh(),
        ) # b 1
        
        self.label = nn.Sequential(
            nn.Linear(128 * 2 * 2, 10),
            nn.Sigmoid(),
#             nn.Tanh(),
        )

    def forward(self, imgs, real_labels=None, fake_labels=None): # b 1 32 32
        
#         if(fake_labels is None and real_labels is not None):
#             _imgs = torch.cat((imgs, self.label_embedding(real_labels).reshape(imgs.size(0), *img_shape)), 1)
#         else:
# #             _labs = self.label_embedding(real_labels).reshape(imgs.size(0), *img_shape) + self.label_embedding(real_labels).reshape(imgs.size(0), *img_shape)
#             _labs = torch.zeros_like(imgs)
#             _imgs = torch.cat((imgs, _labs), 1)
        
        _imgs = torch.cat((imgs, self.label_fc(real_labels).reshape(imgs.size(0), *img_shape)), 1)
        
        outs = self.conv1(_imgs)
        outs = self.conv2(outs)
        outs = self.conv3(outs)
        outs = self.conv4(outs)
        outs = outs.view(imgs.size(0), -1)
        img = self.fc(outs)
        lab = self.label(outs)
        
        return img, lab # b 1 1 1, b 10

In [14]:
class Generator(nn.Module):
    def __init__(self, inp_dim, num_feature):
        super(Generator, self).__init__()

#         self.label_emb = nn.Embedding(10, 10)
#         self.label_fc = nn.Linear(10, inp_dim)
        
        self.label_fc = nn.Sequential(
            nn.Linear(10, 10),
#             nn.BatchNorm1d(num_features=10),
        )
        
        self.fc = nn.Sequential(
            nn.Linear(inp_dim, np.prod(num_feature)),
#             nn.Tanh(),
#             nn.Sigmoid(),
            nn.BatchNorm1d(np.prod(num_feature)),
        ) # b *num_feature  b 32 4 4
        
        self.upsample1 = nn.Sequential(
#             nn.BatchNorm2d(128),
            nn.ConvTranspose2d(32, 16, 4, 2, 1),
#             nn.functional.interpolate(scale_factor=2),
#             nn.Upsample(scale_factor=2),
#             nn.Conv2d(32, 16, 3, padding=1),
            nn.BatchNorm2d(16),
            nn.LeakyReLU(.2, True),
#             nn.ReLU(True),
        ) # b 16 8 8
        
        self.upsample2 = nn.Sequential(
#             nn.functional.interpolate(scale_factor=2),
            nn.ConvTranspose2d(16, 8, 4, 2, 1, 0),
#             nn.Upsample(scale_factor=2),
#             nn.Conv2d(16, 8, 3, padding=1),
            nn.BatchNorm2d(8),
            nn.LeakyReLU(.2, True),
#             nn.ReLU(True),
        ) # b 8 16 16
        
        self.upsample3 = nn.Sequential(
            nn.ConvTranspose2d(8, 4, 4, 2, 1),
#             nn.functional.interpolate(scale_factor=2),
#             nn.Upsample(scale_factor=2),
#             nn.Conv2d(8, 4, 3, padding=1, stride=1),
            nn.BatchNorm2d(4),
            nn.LeakyReLU(.2, True),
#             nn.ReLU(True),
        ) # b 4 32 32
        
        self.conv = nn.Sequential(
#             nn.functional.interpolate(scale_factor=2),
            nn.Conv2d(4, 1, 3, padding=1, stride=1),
            nn.Tanh(),
        ) # b 1 32 32

    def forward(self, noise, labels):
        
#         gen_input = torch.cat((labels.reshape(noise.size(0), -1), noise), -1)
        gen_input = torch.cat((self.label_fc(labels), noise), -1)
#         gen_input = noise + self.label_fc(labels)
        
#         gen_input = torch.cat((labels, noise), -1)
    
        outs = self.fc(gen_input)
        outs = outs.view(noise.size(0), *num_feature)
        outs = self.upsample1(outs)
        outs = self.upsample2(outs)
        outs = self.upsample3(outs)

        outs = self.conv(outs)
        
        return outs


In [15]:
D = Discriminator().to(device)
G = Generator(z_dimension + 10, num_feature).to(device)

# D.weight_init(.0, 0.02)
# G.weight_init(.0, 0.02)

D = nn.DataParallel(D, device_ids=device_ids).to(device)
G = nn.DataParallel(G, device_ids=device_ids).to(device)

criterion = nn.CrossEntropyLoss()

d_optimezer = optim.Adam(D.parameters(), lr=2e-4, betas=(0, 0.9))
g_optimezer = optim.Adam(G.parameters(), lr=2e-4, betas=(0, 0.9))

one = torch.FloatTensor([1])
mone = one * -1
one = one.to(device)
mone = mone.to(device)

In [16]:
writer = SummaryWriter('./log/cnn_condition_wgan_gp')

In [17]:
img_path = "save_images/cnn_condition_wgan_img"
if not os.path.exists(img_path): os.makedirs(img_path)

In [18]:
# condition_label = torch.from_numpy(np.arange(10)).cuda()
condition_label = np.arange(10)
condition_noise = torch.randn(condition_label.shape[0], z_dimension).cuda()

condition_label_one_hot = np.zeros((condition_label.shape[0], 10))
condition_label_one_hot[np.arange(condition_label.shape[0]), condition_label] = 1
condition_label_one_hot = torch.from_numpy(condition_label_one_hot).float().cuda()

In [None]:
total_count = len(dataloader)
for epoch in tqdm_notebook(range(num_epochs)):
    _step = epoch * total_count
    
    d_loss_total = .0
    g_loss_total = .0
    w_loss_total = .0
    for i, (imgs, labs) in enumerate(dataloader):
        
        real_imgs = imgs.cuda()
        real_labs = labs.cuda()
        real_labs_one_hot = np.zeros((imgs.size(0), 10))
        real_labs_one_hot[np.arange(imgs.size(0)), labs] = 1
        real_labs_one_hot = torch.from_numpy(real_labs_one_hot).float().cuda()
        
        z = torch.randn(imgs.size(0), z_dimension).cuda()
        fake_labels = torch.from_numpy(np.random.randint(0, 10, imgs.size(0))).cuda()

        prefix = np.random.randint(0, 10, imgs.size(0))
        fake_labels_one_hot = np.zeros((imgs.size(0), 10))
        fake_labels_one_hot[np.arange(imgs.size(0)), prefix] = 1
        fake_labels_one_hot = torch.from_numpy(fake_labels_one_hot).float().cuda()
        
        ########## G ##########
#         for p in D.parameters():
#             p.requires_grad_(False)

        fake_imgs = G(z, fake_labels_one_hot)
        fake_out, fake_out_labels = D(fake_imgs, fake_labels_one_hot)

#         fake_labels_loss = criterion(fake_out_labels, fake_labels).mean()
#         fake_cost = -fake_out.mean()
#         g_loss = 1 * fake_labels_loss + fake_cost
#         g_loss = fake_out.mean()

        g_loss = adversarial_loss(fake_out, real)

        g_optimezer.zero_grad()
#         g_loss.backward(mone)
#         g_loss = -g_loss
        g_loss.backward()
        g_optimezer.step()
        #######################
        
        
        ########## D ##########
#         for p in D.parameters():
#             p.requires_grad_(True)
        
        real_out, real_labels = D(real_imgs, real_labs_one_hot)
#         d_loss_real_lab = criterion(real_labels, real_labs).mean()
        d_loss_real = real_out.mean()
        real_scores = real_out
        
        fake_out, fake_out_labels = D(fake_imgs.detach(), fake_labels_one_hot)
#         d_loss_fake_lab = criterion(fake_out_labels, fake_labels).mean()
        d_loss_fake = fake_out.mean()
        fake_scores = fake_out
        
#         gradient_penalty = calc_gradient_penalty(D, real_imgs, fake_imgs, real_labs_one_hot, fake_labels_one_hot)
        
        d_loss = (d_loss_fake + d_loss_real) / 2 # + gradient_penalty + 1 * d_loss_real_lab
        
        d_optimezer.zero_grad()
        d_loss.backward()
        d_optimezer.step()
        #######################
        
        w_dist = d_loss_fake - d_loss_real
        
        d_loss_total += d_loss.item() * imgs.size(0)
        g_loss_total += g_loss.item() * imgs.size(0)
        w_loss_total += w_dist * imgs.size(0)
        
        step = _step + i + 1
        
        if (i + 1) % 100 == 0:
            writer.add_scalar('Discriminator Real Loss', d_loss_real.item(), step)
            writer.add_scalar('Discriminator Fake Loss', d_loss_fake.item(), step)
            writer.add_scalar('Discriminator Loss', d_loss.item(), step)
            writer.add_scalar('Generator Loss', g_loss.item(), step)
            writer.add_scalar('Wasserstein Distance', w_dist.item(), step)
        
        
        if (i + 1) % 300 == 0:
            tqdm.write('  Epoch[{}/{}], Step: {:6d}, d_loss: {:.6f}, g_loss: {:.6f} real_scores: {:.6f}' \
', fake_scores: {:.6f}, W: {:.6f}'.format(epoch+1, num_epochs, 
                                          (i+1) * BATCH_SIZE, 
                                          d_loss, g_loss, 
                                          real_scores.mean(), 
                                          fake_scores.mean(), w_dist))
    
    
    setp = (epoch + 1) * total_count
    _d_loss_total = d_loss_total / (total_count * (epoch + 1))
    _g_loss_total = g_loss_total / (total_count * (epoch + 1))
    _w_loss_total = w_loss_total / (total_count * (epoch + 1))
    
    writer.add_scalar('Discriminator Total Loss', _d_loss_total, step)
    writer.add_scalar('Generator Total Loss', _g_loss_total, step)
    
    tqdm.write("Finish Epoch [{}/{}], D Loss: {:.6f}, G Loss: {:.6f}, W: {:.6f}".format(epoch+1, 
                                                                             num_epochs, 
                                                                             _d_loss_total,
                                                                             _g_loss_total,
                                                                             _w_loss_total, ))
    
    writer.add_image('Generator Image', make_grid(fake_imgs.view(-1, 1, 32, 32).cpu().data, normalize=True, scale_each=True), step)
    condition_imgs = G(condition_noise, condition_label_one_hot)
    writer.add_image('Condition Generator Image', make_grid(condition_imgs.view(-1, 1, 32, 32).cpu().data, normalize=True, scale_each=True), step)
    
    if epoch == 0:
        real_images = real_imgs.view(-1, 1, 32, 32).cpu().data
        save_image(real_images, os.path.join(img_path, 'real_images.png'))
    
    
    fake_images = fake_imgs.view(-1, 1, 32, 32).cpu().data
    save_image(fake_images, os.path.join(img_path, 'fake_images-{}.png'.format(epoch+1)))
    save_image(condition_imgs, os.path.join(img_path, 'condition_images-{}.png'.format(epoch+1)))

HBox(children=(IntProgress(value=0), HTML(value='')))

  Epoch[1/100], Step:  38400, d_loss: 0.000000, g_loss: 2.309500 real_scores: 0.000000, fake_scores: 0.000000, W: 0.000000
Finish Epoch [1/100], D Loss: 1.589986, G Loss: 293.206078, W: 0.271360
  Epoch[2/100], Step:  38400, d_loss: 0.000000, g_loss: 2.310557 real_scores: 0.000000, fake_scores: 0.000000, W: 0.000000
Finish Epoch [2/100], D Loss: 0.000000, G Loss: 147.524826, W: 0.000000


In [20]:
writer.close()

In [21]:
torch.save(D.state_dict(), './ser/condition_wgan_gp_discriminator_3.pt')
torch.save(G.state_dict(), './ser/condition_wgan_gp_generator_3.pt')

In [None]:
D.load_state_dict(torch.load('./ser/condition_wgan_gp_discriminator_3.pt'))
G.load_state_dict(torch.load('./ser/condition_wgan_gp_generator_3.pt'))

In [1]:
clabel = np.arange(10)
z = torch.randn(clabel.shape[0], z_dimension).cuda()
images = G(z, torch.from_numpy(clabel).cuda())
# save_image(images, 'xx.png')
plt.imshow(Image.fromarray(make_grid(images).mul(255).clamp(0, 255).byte().permute(1, 2, 0).cpu().numpy()))
plt.show()

NameError: name 'np' is not defined