In [1]:
import argparse
import random
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils
from torch.autograd import Variable
import torch.nn.functional as F
import os
import timeit

# import util
import numpy as np

# import base_module

# MMD

In [40]:
def normalize(x, dim=1):
    """Scale `x` to unit L2 norm along `dim`.

    Args:
        x: tensor to normalise.
        dim: dimension along which the L2 norm is taken (default 1).

    Returns:
        A tensor with the same shape as `x`, divided by its L2 norm along `dim`.
        Slices whose norm is zero are divided by zero (no epsilon is added).
    """
    # keepdim=True keeps the reduced axis as size 1 so the norm can be expanded
    # back to x's shape; without it the reduced dim disappears and expand_as
    # fails (or lines up with the wrong axis when x happens to be square).
    return x.div(x.norm(2, dim=dim, keepdim=True).expand_as(x))

def match(x, y, dist):
    '''
    Scalar discrepancy between corresponding points in `x` and `y`.

    `dist` selects the measure: 'L2' (mean squared difference), 'L1' (mean
    absolute difference) or 'cos' (2 minus the mean element-wise product of
    the normalised inputs). 'none' returns None; anything else trips the
    assertion.
    '''
    if dist in ('L2', 'L1'):
        delta = x - y
        pointwise = delta.pow(2) if dist == 'L2' else delta.abs()
        return pointwise.mean()
    if dist == 'cos':
        return 2 - normalize(x).mul(normalize(y)).mean()
    assert dist == 'none', 'wtf ?'

In [2]:
def _mmd2(K_XX, K_XY, K_YY, const_diagonal=False, biased=False):
    m = K_XX.size(0)    # assume X, Y are same shape

    # Get the various sums of kernels that we'll use
    # Kts drop the diagonal, but we don't need to compute them explicitly
    if const_diagonal is not False:
        diag_X = diag_Y = const_diagonal
        sum_diag_X = sum_diag_Y = m * const_diagonal
    else:
        diag_X = torch.diag(K_XX)                       # (m,)
        diag_Y = torch.diag(K_YY)                       # (m,)
        sum_diag_X = torch.sum(diag_X)
        sum_diag_Y = torch.sum(diag_Y)

    Kt_XX_sums = K_XX.sum(dim=1) - diag_X             # \tilde{K}_XX * e = K_XX * e - diag_X
    Kt_YY_sums = K_YY.sum(dim=1) - diag_Y             # \tilde{K}_YY * e = K_YY * e - diag_Y
    K_XY_sums_0 = K_XY.sum(dim=0)                     # K_{XY}^T * e

    Kt_XX_sum = Kt_XX_sums.sum()                       # e^T * \tilde{K}_XX * e
    Kt_YY_sum = Kt_YY_sums.sum()                       # e^T * \tilde{K}_YY * e
    K_XY_sum = K_XY_sums_0.sum()                       # e^T * K_{XY} * e

    if biased:
        mmd2 = ((Kt_XX_sum + sum_diag_X) / (m * m)
            + (Kt_YY_sum + sum_diag_Y) / (m * m)
            - 2.0 * K_XY_sum / (m * m))
    else:
        mmd2 = (Kt_XX_sum / (m * (m - 1))
            + Kt_YY_sum / (m * (m - 1))
            - 2.0 * K_XY_sum / (m * m))

    return mmd2

In [3]:
def _mix_rbf_kernel(X, Y, sigma_list):
    assert(X.size(0) == Y.size(0))
    m = X.size(0)

    Z = torch.cat((X, Y), 0)
    ZZT = torch.mm(Z, Z.t())
    diag_ZZT = torch.diag(ZZT).unsqueeze(1)
    Z_norm_sqr = diag_ZZT.expand_as(ZZT)
    exponent = Z_norm_sqr - 2 * ZZT + Z_norm_sqr.t()

    K = 0.0
    for sigma in sigma_list:
        gamma = 1.0 / (2 * sigma**2)
        K += torch.exp(-gamma * exponent)

    return K[:m, :m], K[:m, m:], K[m:, m:], len(sigma_list)


def mix_rbf_mmd2(X, Y, sigma_list, biased=True):
    """Squared MMD between X and Y under the mixture-of-RBF kernel.

    The kernel diagonals are read from the matrices themselves
    (const_diagonal=False); the bandwidth count returned by the kernel helper
    is not used.
    """
    kernel_blocks = _mix_rbf_kernel(X, Y, sigma_list)
    K_XX, K_XY, K_YY = kernel_blocks[:3]
    return _mmd2(K_XX, K_XY, K_YY, const_diagonal=False, biased=biased)

In [35]:
# sigma for MMD
base = 1.0
# RBF bandwidths, each divided by `base`
sigma_list = [bandwidth / base for bandwidth in (1, 2, 4, 8, 16)]

# Base model

In [48]:
class Encoder(nn.Module):
    """Convolutional encoder: ``nc x isize x isize`` input -> ``k x 1 x 1`` code.

    Layers, in order:
      * one stride-2 4x4 conv (``nc -> ndf``) followed by LeakyReLU(0.2);
      * repeated stride-2 4x4 conv + BatchNorm + LeakyReLU(0.2) stages, each
        doubling the channel count, until the tracked spatial size reaches 4;
      * a final 4x4 conv with stride 1 and no padding that emits ``k`` channels.

    Args:
        isize: input height/width. Must be a multiple of 16 and a power of two.
        nc: number of input channels.
        k: number of output channels of the final conv (default 100).
        ndf: number of channels after the first conv (default 64).

    Raises:
        AssertionError: if ``isize`` is not a multiple of 16.
        ValueError: if ``isize`` is not a power of two. Such sizes never halve
            down to exactly 4, so the final 4x4 conv would fail at forward time.
    """

    def __init__(self, isize, nc, k=100, ndf=64):
        super(Encoder, self).__init__()
        assert isize % 16 == 0, "isize has to be a multiple of 16"

        # Fail fast on sizes that cannot be halved down to exactly 4.
        reachable = 4
        while reachable < isize:
            reachable *= 2
        if reachable != isize:
            raise ValueError(
                "isize has to be a power of two (got {0})".format(isize))

        # input is nc x isize x isize
        main = nn.Sequential()
        main.add_module('initial_conv_{0}-{1}'.format(nc, ndf),
                        nn.Conv2d(nc, ndf, 4, 2, 1, bias=False))
        main.add_module('initial_relu_{0}'.format(ndf),
                        nn.LeakyReLU(0.2, inplace=True))
        # Integer division keeps the tracked size an int, as Decoder does.
        csize, cndf = isize // 2, ndf

        while csize > 4:
            in_feat = cndf
            out_feat = cndf * 2
            main.add_module('pyramid_{0}-{1}_conv'.format(in_feat, out_feat),
                            nn.Conv2d(in_feat, out_feat, 4, 2, 1, bias=False))
            main.add_module('pyramid_{0}_batchnorm'.format(out_feat),
                            nn.BatchNorm2d(out_feat))
            main.add_module('pyramid_{0}_relu'.format(out_feat),
                            nn.LeakyReLU(0.2, inplace=True))
            cndf = cndf * 2
            csize = csize // 2

        # The module name says "-1" although the conv emits k channels; the
        # name is left unchanged so existing state_dict keys keep loading.
        main.add_module('final_{0}-{1}_conv'.format(cndf, 1),
                        nn.Conv2d(cndf, k, 4, 1, 0, bias=False))

        self.main = main

    def forward(self, input):
        """Run `input` through the conv stack and return the result."""
        output = self.main(input)
        return output


# input: batch_size * k * 1 * 1
# output: batch_size * nc * image_size * image_size
class Decoder(nn.Module):
    """Transposed-convolution decoder: ``k x 1 x 1`` code -> ``nc x isize x isize``.

    Layers, in order:
      * a 4x4 transposed conv (stride 1, no padding) + BatchNorm + ReLU;
      * repeated stride-2 4x4 transposed conv + BatchNorm + ReLU stages, each
        halving the channel count, while the tracked size is below ``isize // 2``;
      * a final stride-2 4x4 transposed conv to ``nc`` channels followed by Tanh.

    Args:
        isize: output height/width. Must be a multiple of 16 and a power of two.
        nc: number of output channels.
        k: number of input channels (default 100).
        ngf: channel count feeding the final transposed conv (default 64).

    Raises:
        AssertionError: if ``isize`` is not a multiple of 16.
        ValueError: if ``isize`` is not a power of two (doubling from 4 never
            lands on it).
    """

    def __init__(self, isize, nc, k=100, ngf=64):
        super(Decoder, self).__init__()
        assert isize % 16 == 0, "isize has to be a multiple of 16"

        # Work out the channel count of the first layer by doubling from 4 up
        # to isize. The loop is bounded by `<` so that a size that is not a
        # power of two (e.g. 48, or 0) is reported instead of looping forever.
        cngf, tisize = ngf // 2, 4
        while tisize < isize:
            cngf = cngf * 2
            tisize = tisize * 2
        if tisize != isize:
            raise ValueError(
                "isize has to be a power of two (got {0})".format(isize))

        main = nn.Sequential()
        main.add_module('initial_{0}-{1}_convt'.format(k, cngf), nn.ConvTranspose2d(k, cngf, 4, 1, 0, bias=False))
        main.add_module('initial_{0}_batchnorm'.format(cngf), nn.BatchNorm2d(cngf))
        main.add_module('initial_{0}_relu'.format(cngf), nn.ReLU(True))

        csize = 4
        while csize < isize // 2:
            main.add_module('pyramid_{0}-{1}_convt'.format(cngf, cngf // 2),
                            nn.ConvTranspose2d(cngf, cngf // 2, 4, 2, 1, bias=False))
            main.add_module('pyramid_{0}_batchnorm'.format(cngf // 2),
                            nn.BatchNorm2d(cngf // 2))
            main.add_module('pyramid_{0}_relu'.format(cngf // 2),
                            nn.ReLU(True))
            cngf = cngf // 2
            csize = csize * 2

        main.add_module('final_{0}-{1}_convt'.format(cngf, nc), nn.ConvTranspose2d(cngf, nc, 4, 2, 1, bias=False))
        main.add_module('final_{0}_tanh'.format(nc), nn.Tanh())

        self.main = main

    def forward(self, input):
        """Run `input` through the transposed-conv stack and return the result."""
        output = self.main(input)
        return output


def grad_norm(m, norm_type=2):
    """Overall `norm_type`-norm of the gradients held by `m`'s parameters.

    Args:
        m: module whose ``parameters()`` are inspected.
        norm_type: order of the norm (default 2).

    Returns:
        The norm taken over all parameter gradients as if they were one
        vector. Parameters without a gradient are skipped; if no parameter
        has a gradient the result is the float 0.0, otherwise a 0-dim tensor.
    """
    total_norm = 0.0
    for p in m.parameters():
        # Frozen parameters, or ones not yet reached by backward(), have no .grad.
        if p.grad is None:
            continue
        param_norm = p.grad.data.norm(norm_type)
        total_norm += param_norm ** norm_type
    total_norm = total_norm ** (1. / norm_type)
    return total_norm


def weights_init(m):
    """Re-initialise `m` in place according to its class name.

    Intended to be passed to ``Module.apply``. Matching is by substring of
    the class name, tried in this order:
      * 'Conv'      -> weight ~ N(0.0, 0.02)
      * 'BatchNorm' -> weight ~ N(1.0, 0.02), bias = 0
      * 'Linear'    -> weight ~ N(0.0, 0.1),  bias = 0
    Modules that match none of these are left untouched. A missing weight or
    bias (e.g. ``bias=False`` or ``affine=False``) is skipped, not
    dereferenced.
    """
    classname = m.__class__.__name__
    weight = getattr(m, 'weight', None)
    bias = getattr(m, 'bias', None)

    if 'Conv' in classname:
        if weight is not None:
            weight.data.normal_(0.0, 0.02)
    elif 'BatchNorm' in classname:
        if weight is not None:
            weight.data.normal_(1.0, 0.02)
        if bias is not None:
            bias.data.fill_(0)
    elif 'Linear' in classname:
        if weight is not None:
            weight.data.normal_(0.0, 0.1)
        if bias is not None:
            bias.data.fill_(0)

# Net

In [49]:
# NetG is a decoder
# input: batch_size * nz * 1 * 1
# output: batch_size * nc * image_size * image_size
class NetG(nn.Module):
    """Generator: a thin wrapper that delegates to the given decoder module."""

    def __init__(self, decoder):
        super().__init__()
        self.decoder = decoder

    def forward(self, input):
        """Return the decoder's output for `input`."""
        return self.decoder(input)

In [50]:
# NetD is an encoder + decoder
# input: batch_size * nc * image_size * image_size
# f_enc_X: batch_size * k * 1 * 1
# f_dec_X: batch_size * nc * image_size * image_size
class NetD(nn.Module):
    """Discriminator built as an auto-encoder: encoder followed by decoder."""

    def __init__(self, encoder, decoder):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, input):
        """Return (code, reconstruction), each flattened to one row per sample."""
        n_samples = input.size(0)
        code = self.encoder(input)
        reconstruction = self.decoder(code)
        return code.view(n_samples, -1), reconstruction.view(n_samples, -1)

In [51]:
class ONE_SIDED(nn.Module):
    """One-sided penalty: minus the mean of ReLU(-input)."""

    def __init__(self):
        super().__init__()
        self.main = nn.ReLU()

    def forward(self, input):
        """Return a scalar that is 0 when no entry is negative, negative otherwise."""
        clipped = self.main(-input)
        return -clipped.mean()

# Args

In [52]:
# def get_args(parser):
#     parser.add_argument('--dataset', required=True, help='mnist | cifar10 | cifar100 | lsun | imagenet | folder | lfw ')
#     parser.add_argument('--dataroot', required=True, help='path to dataset')
#     parser.add_argument('--workers', type=int, help='number of data loading workers', default=4)
#     parser.add_argument('--batch_size', type=int, default=64, help='input batch size')
#     parser.add_argument('--image_size', type=int, default=64, help='the height / width of the input image to network')
#     parser.add_argument('--nc', type=int, default=3, help='number of channel')
#     parser.add_argument('--nz', type=int, default=100, help='size of the latent z vector')
#     parser.add_argument('--max_iter', type=int, default=100, help='number of epochs to train for')
#     parser.add_argument('--lr', type=float, default=0.00005, help='learning rate, default=0.00005')
#     parser.add_argument('--gpu_device', type=int, default=0, help='using gpu device id')
#     parser.add_argument('--netG', default='', help="path to netG (to continue training)")
#     parser.add_argument('--netD', default='', help="path to netD (to continue training)")
#     parser.add_argument('--Diters', type=int, default=5, help='number of D iters per each G iter')
#     parser.add_argument('--experiment', default=None, help='Where to store samples and models')
#     return parser

In [53]:
# Run configuration, kept as a plain dict (accessed as args['key'])
args = {
    'dataset': 'mnist',
    'dataroot': '/content',
    'workers': 2,
    'batch_size': 64,
    'image_size': 32,
    'nc': 1,
    'nz': 100,
    'max_iter': 100,
    'lr': 0.00005,
    'gpu_device': 0,
    'Diters': 5,
    'experiment': None
}

print(args)

if args['experiment'] is None:
    args['experiment'] = 'samples'
# Create the output directory in-process instead of shelling out to
# `mkdir -p`: no shell parsing of the path, and failures raise an exception.
os.makedirs(args['experiment'], exist_ok=True)

if torch.cuda.is_available():
    args['cuda'] = True  # Store this info in args as well
    torch.cuda.set_device(args['gpu_device'])
    print("Using GPU device", torch.cuda.current_device())
else:
    raise EnvironmentError("GPU device not available!")

# Setting seeds for reproducibility
args['manual_seed'] = 1126
np.random.seed(seed=args['manual_seed'])
random.seed(args['manual_seed'])
torch.manual_seed(args['manual_seed'])
torch.cuda.manual_seed(args['manual_seed'])
# Let cuDNN benchmark and pick convolution algorithms for the current configuration.
# NOTE(review): auto-tuned algorithm selection can differ between runs, which
# may work against the seeding above - confirm this trade-off is intended.
cudnn.benchmark = True

{'dataset': 'mnist', 'dataroot': '/content', 'workers': 2, 'batch_size': 64, 'image_size': 32, 'nc': 1, 'nz': 100, 'max_iter': 100, 'lr': 5e-05, 'gpu_device': 0, 'Diters': 5, 'experiment': None}
Using GPU device 0


# Get data

In [54]:
# here we use MNIST

transform = transforms.Compose([
    transforms.Resize(args['image_size']),  # Resize the image if needed
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))  # Normalize for 1 channel
])

from torchvision.datasets import MNIST
# Store/download the dataset under the configured data root rather than a
# placeholder literal, so args['dataroot'] actually takes effect.
trn_dataset = MNIST(root=args['dataroot'], train=True, download=True, transform=transform)
trn_loader = torch.utils.data.DataLoader(trn_dataset,
                                         batch_size=args['batch_size'],
                                         shuffle=True,
                                         num_workers=int(args['workers']))

# Architecture

In [73]:
# Build the conv stacks: one decoder for the generator, and an
# encoder/decoder pair for the auto-encoding discriminator.
hidden_dim = args['nz']
G_decoder = Decoder(args['image_size'], args['nc'], k=args['nz'], ngf=64)
D_encoder = Encoder(args['image_size'], args['nc'], k=hidden_dim, ndf=64)
D_decoder = Decoder(args['image_size'], args['nc'], k=hidden_dim, ngf=64)

# Wrap them into the trainable networks plus the one-sided penalty module.
netG, netD, one_sided = NetG(G_decoder), NetD(D_encoder, D_decoder), ONE_SIDED()

print("netG:", netG)
print("netD:", netD)
print("oneSide:", one_sided)

# Custom initialisation, applied recursively to every submodule.
netG.apply(weights_init)
netD.apply(weights_init)
one_sided.apply(weights_init)

netG: NetG(
  (decoder): Decoder(
    (main): Sequential(
      (initial_100-256_convt): ConvTranspose2d(100, 256, kernel_size=(4, 4), stride=(1, 1), bias=False)
      (initial_256_batchnorm): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (initial_256_relu): ReLU(inplace=True)
      (pyramid_256-128_convt): ConvTranspose2d(256, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (pyramid_128_batchnorm): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (pyramid_128_relu): ReLU(inplace=True)
      (pyramid_128-64_convt): ConvTranspose2d(128, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (pyramid_64_batchnorm): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (pyramid_64_relu): ReLU(inplace=True)
      (final_64-1_convt): ConvTranspose2d(64, 1, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (final_1_tanh)

ONE_SIDED(
  (main): ReLU()
)

# Train

In [74]:
# Constants that live on the GPU: a fixed batch of 64 latent vectors
# and the +1 / -1 tensors.
fixed_noise = Variable(
    torch.cuda.FloatTensor(64, args['nz'], 1, 1).normal_(0, 1),
    requires_grad=False,
)
one = torch.cuda.FloatTensor([1])
# mone = one * -1
mone = torch.tensor(-1.0).to(args['gpu_device'])

# Move the networks to the GPU as well.
if args['cuda']:
    netG.cuda()
    netD.cuda()
    one_sided.cuda()

In [75]:
# One RMSprop optimizer per network, both using the configured learning rate
optimizerG, optimizerD = (
    torch.optim.RMSprop(net.parameters(), lr=args['lr'])
    for net in (netG, netD)
)

# Loss weights
lambda_MMD, lambda_AE_X, lambda_AE_Y, lambda_rg = 1.0, 8.0, 8.0, 16.0

- **Discriminator Loss**:
  - **MMD Loss** (`mmd2_D`): statistical distance between the distributions of real data (`f_enc_X_D`) and generated data (`f_enc_Y_D`) based on their features extracted by the discriminator's encoder.
  - **Rank Hinge Loss** (`one_side_errD`): penalizes the discriminator when the batch-mean encoding of real data (X) is lower than that of generated data (Y) in any feature dimension, i.e. wherever mean(X) < mean(Y).

  - **L2 Loss for AE Reconstruction** (`L2_AE_X_D` and `L2_AE_Y_D`): measure how well the discriminator's decoder can reconstruct the original inputs from their encoded states.
    
> one_sided(f_enc_X_D.mean(0) - f_enc_Y_D.mean(0))

    If the input tensor is [1, -2, 3, -4, 0], negating it gives [-1, 2, -3, 4, 0].
    The ReLU operation then converts it to [0, 2, 0, 4, 0].
    The mean of this result is (2 + 4) / 5 = 1.2.
    Negating the mean gives the final output, -1.2.
    The output is therefore the negated average magnitude of the negative inputs: it is 0 when no input is negative and negative otherwise.



- **Generator Loss**:
  - **MMD Loss** (`mmd2_G`): minimize the distance between the features of real and generated data.
  - **Rank Hinge Loss** (`one_side_errG`): penalizes the generator if it fails to make its generated data (fake data) rank as more realistic compared to real data, according to the discriminator's assessment.

In [76]:
# Training loop. Each pass of the inner `while` runs up to `Diters` NetD
# updates followed by up to `Giters` NetG updates, drawing every batch from
# the same epoch iterator, then logs one line and periodically saves samples.
time = timeit.default_timer()
gen_iterations = 0
for t in range(args['max_iter']):
    data_iter = iter(trn_loader)
    i = 0
    while (i < len(trn_loader)):
        # ---------------------------
        #        Optimize over NetD
        # ---------------------------
        for p in netD.parameters():
            p.requires_grad = True

        # Many more NetD steps during the first 25 generator iterations and
        # on every 500th one; otherwise the configured number per NetG step.
        if gen_iterations < 25 or gen_iterations % 500 == 0:
            Diters = 100
            Giters = 1
        else:
            Diters = args['Diters']
            Giters = 1

        for j in range(Diters):
            if i == len(trn_loader):
                break

            # clamp parameters of NetD encoder to a cube
            # do not clamp parameters of NetD decoder!!!
            for p in netD.encoder.parameters():
                p.data.clamp_(-0.01, 0.01)

            data = next(data_iter)
            i += 1
            netD.zero_grad()

            x_cpu, _ = data
            x = x_cpu.cuda()
            batch_size = x.size(0)

            f_enc_X_D, f_dec_X_D = netD(x)

            # Generate the fake batch without tracking gradients, so NetG is
            # frozen during the NetD update.
            with torch.no_grad():
                noise = torch.cuda.FloatTensor(batch_size, args['nz'], 1, 1).normal_(0, 1)
                y = netG(noise)

            f_enc_Y_D, f_dec_Y_D = netD(y)

            # compute biased MMD2 and use ReLU to prevent negative value
            mmd2_D = mix_rbf_mmd2(f_enc_X_D, f_enc_Y_D, sigma_list)
            mmd2_D = F.relu(mmd2_D)

            # compute rank hinge loss
            one_side_errD = one_sided(f_enc_X_D.mean(0) - f_enc_Y_D.mean(0))

            # compute L2-loss of AE
            L2_AE_X_D = match(x.view(batch_size, -1), f_dec_X_D, 'L2')
            L2_AE_Y_D = match(y.view(batch_size, -1), f_dec_Y_D, 'L2')

            errD = torch.sqrt(mmd2_D) + lambda_rg * one_side_errD - lambda_AE_X * L2_AE_X_D - lambda_AE_Y * L2_AE_Y_D
            # `mone` comes from an earlier cell (expected to hold -1): seeding
            # backward with it flips the gradient sign, so the optimizer step
            # increases errD.
            errD.backward(mone)
            optimizerD.step()

        # ---------------------------
        #        Optimize over NetG
        # ---------------------------
        for p in netD.parameters():
            p.requires_grad = False

        for j in range(Giters):
            if i == len(trn_loader):
                break

            data = next(data_iter)
            i += 1
            netG.zero_grad()

            x_cpu, _ = data
            x = x_cpu.cuda()
            batch_size = x.size(0)

            f_enc_X, f_dec_X = netD(x)

            noise = torch.cuda.FloatTensor(batch_size, args['nz'], 1, 1).normal_(0, 1)
            y = netG(noise)

            f_enc_Y, f_dec_Y = netD(y)

            # compute biased MMD2 and use ReLU to prevent negative value
            mmd2_G = mix_rbf_mmd2(f_enc_X, f_enc_Y, sigma_list)
            mmd2_G = F.relu(mmd2_G)

            # compute rank hinge loss
            one_side_errG = one_sided(f_enc_X.mean(0) - f_enc_Y.mean(0))

            errG = torch.sqrt(mmd2_G) + lambda_rg * one_side_errG
            errG.backward()
            optimizerG.step()

            gen_iterations += 1

        run_time = (timeit.default_timer() - time) / 60.0

        # NOTE(review): when the NetG loop above breaks without stepping (end
        # of an epoch), errG and NetG's gradients are those of the previous
        # pass; if that happens on the very first pass (a loader with no more
        # batches than the initial Diters) errG is undefined here - confirm
        # the loader is always long enough.
        print('[%3d/%3d][%3d/%3d] [%5d] (%.2f m) MMD2_D %.6f hinge %.6f L2_AE_X %.6f L2_Ae_Y %.6f loss_D %.6f Loss_G %.6f f_X %.6f f_Y %.6f |gD| %.4f |gG| %.4f'
              % (t, args['max_iter'], i, len(trn_loader), gen_iterations, run_time,
                 mmd2_D.item(), one_side_errD.item(),
                 L2_AE_X_D.item(), L2_AE_Y_D.item(),
                 errD.item(), errG.item(),
                 f_enc_X_D.mean().item(), f_enc_Y_D.mean().item(),
                 grad_norm(netD), grad_norm(netG)))

        if gen_iterations % 500 == 0:
            # Sampling is for inspection only: skip autograd bookkeeping and
            # build rescaled copies instead of overwriting `.data` in place.
            with torch.no_grad():
                y_fixed = netG(fixed_noise)
                y_fixed = y_fixed.mul(0.5).add(0.5)  # undo Normalize((0.5,), (0.5,))
                f_dec_X_D = f_dec_X_D.detach().view(f_dec_X_D.size(0), args['nc'], args['image_size'], args['image_size'])
                f_dec_X_D = f_dec_X_D.mul(0.5).add(0.5)
            vutils.save_image(y_fixed, '{0}/fake_samples_{1}.png'.format(args['experiment'], gen_iterations))
            vutils.save_image(f_dec_X_D, '{0}/decode_samples_{1}.png'.format(args['experiment'], gen_iterations))

    # Checkpoint both networks every 50 epochs (including epoch 0).
    if t % 50 == 0:
        torch.save(netG.state_dict(), '{0}/netG_iter_{1}.pth'.format(args['experiment'], t))
        torch.save(netD.state_dict(), '{0}/netD_iter_{1}.pth'.format(args['experiment'], t))


[1;30;43m流式输出内容被截断，只能显示最后 5000 行内容。[0m
[ 66/100][786/938] [ 9675] (30.32 m) MMD2_D 0.896261 hinge -0.000000 L2_AE_X 0.011130 L2_Ae_Y 0.010119 loss_D 0.776716 Loss_G 0.942656 f_X 0.093013 f_Y 0.002562 |gD| 99.5588 |gG| 5.3656
[ 66/100][792/938] [ 9676] (30.33 m) MMD2_D 1.033207 hinge -0.000000 L2_AE_X 0.010781 L2_Ae_Y 0.008055 loss_D 0.865782 Loss_G 0.977699 f_X 0.019904 f_Y -0.077558 |gD| 100.7911 |gG| 5.9268
[ 66/100][798/938] [ 9677] (30.33 m) MMD2_D 1.065634 hinge -0.000000 L2_AE_X 0.012732 L2_Ae_Y 0.010886 loss_D 0.843356 Loss_G 0.945085 f_X 0.036063 f_Y -0.063976 |gD| 61.2907 |gG| 4.9217
[ 66/100][804/938] [ 9678] (30.34 m) MMD2_D 1.000409 hinge -0.000000 L2_AE_X 0.010181 L2_Ae_Y 0.013433 loss_D 0.811293 Loss_G 0.993400 f_X 0.042423 f_Y -0.053790 |gD| 51.0259 |gG| 5.7862
[ 66/100][810/938] [ 9679] (30.34 m) MMD2_D 0.826744 hinge -0.000000 L2_AE_X 0.014486 L2_Ae_Y 0.010006 loss_D 0.713322 Loss_G 0.865440 f_X 0.090979 f_Y 0.003911 |gD| 140.7933 |gG| 6.1232
[ 66/100][816/938] [ 968