In [1]:
import torch
import torch.nn as nn
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import torchvision.utils as vutils
import torchvision
import matplotlib.pyplot as plt

In [2]:
class ConvolutionBlock(nn.Module):
    def __init__(self, input_dim, output_dim,kernek_size=4, stride=2, padding=1):
        super(ConvolutionBlock, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(input_dim, output_dim, kernel_size=kernek_size, stride=stride, padding=padding),
            nn.BatchNorm2d(output_dim),
            nn.LeakyReLU(0.2, inplace=True)
        )
    def forward(self, x):
        return self.conv(x)

In [3]:
class Encoder(nn.Module):
    def __init__(self, input_dim, output_dim):
        super(Encoder, self).__init__()
        self.main = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=16, kernel_size=2, stride=2, padding=0),  # 28x28 -> 13x13
            nn.LeakyReLU(0.2,True),nn.BatchNorm2d(16),
            nn.Conv2d(in_channels=16, out_channels=32, kernel_size=2, stride=2, padding=0), # 13x13 -> 6x6
            nn.LeakyReLU(0.2,True),nn.BatchNorm2d(32),
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=2, stride=2, padding=0), # 6x6 -> 2x2
            nn.LeakyReLU(0.2,True),nn.BatchNorm2d(64),
            nn.Conv2d(in_channels=64, out_channels=100, kernel_size=2, stride=2, padding=0), # 2x2 -> 1x1
            nn.Tanh()

        )
    
    def forward(self, x):
        return self.main(x)

In [4]:
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.main = nn.Sequential(
            nn.ConvTranspose2d(in_channels=100, out_channels=256, kernel_size=7, stride=1, padding=0),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2,True),

            nn.ConvTranspose2d(in_channels=256, out_channels=128, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2,True),

            nn.ConvTranspose2d(in_channels=128, out_channels=1, kernel_size=4, stride=2, padding=1),
            nn.Sigmoid()  
        )

    def forward(self,x):
        return self.main(x)


In [5]:
class Discriminator(nn.Module):
    def __init__(self, encoded_size,decoded_size):
        super(Discriminator, self).__init__()
        f = 4
        s = 2
        p = 1
        self.encoded_size = encoded_size
        self.decoded_size = decoded_size
        self.main = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=16, kernel_size=2, stride=2, padding=0),  # 28x28 -> 13x13
            nn.LeakyReLU(0.2,True),nn.BatchNorm2d(16),
            nn.Conv2d(in_channels=16, out_channels=32, kernel_size=2, stride=2, padding=0), # 13x13 -> 6x6
            nn.LeakyReLU(0.2,True),nn.BatchNorm2d(32),
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=2, stride=2, padding=0), # 6x6 -> 2x2
            nn.LeakyReLU(0.2,True),nn.BatchNorm2d(64),
            nn.Conv2d(in_channels=64, out_channels=100, kernel_size=2, stride=2, padding=0), # 2x2 -> 1x1
            nn.Tanh()
        )
        self.fc = nn.Sequential(
            nn.Linear(200, 64),
            nn.ReLU(),
            nn.Linear(64, 1),
            nn.Sigmoid()

        )
    def forward(self, encoder, decoder):
        x = self.main(decoder)
        x = nn.Flatten()(x)
        encoder = nn.Flatten()(encoder)
        x = torch.cat((encoder, x), dim=1)
        return nn.Sigmoid()(self.fc(x))

In [6]:
batch_size = 64
train_dataset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transforms.ToTensor())
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
encoder = Encoder(input_dim=28, output_dim=100).to(device)
decoder = Generator().to(device)
discriminator = Discriminator(encoded_size=100, decoded_size=28).to(device)


optimizer_E = torch.optim.Adam(encoder.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizer_G = torch.optim.Adam(decoder.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))
z_dim = 100

In [7]:
bce_loss = nn.BCELoss()

def discriminator_loss(real_pair, fake_pair):
    real_loss = bce_loss(real_pair, torch.ones_like(real_pair))
    fake_loss = bce_loss(fake_pair, torch.zeros_like(fake_pair))
    return real_loss + fake_loss

def generator_encoder_loss(fake_pair):
    return bce_loss(fake_pair, torch.ones_like(fake_pair))


In [8]:
def train_bigan(generator, encoder, discriminator, dataloader, optimizer_G, optimizer_E, optimizer_D, num_epochs):
    fix_noise = torch.randn(batch_size, z_dim,1,1).to(device)
    for epoch in range(num_epochs):
        for (real_data,_) in dataloader:
            # 从真实数据中生成隐变量
            real_data = real_data.to(device)
            z_real = encoder(real_data)

            # 从随机噪声中生成数据
            z_fake = torch.randn(batch_size, z_dim,1,1).to(device)
            fake_data = generator(z_fake)

            # 判别器在真实和伪造的 (x, z) 对上计算损失
            real_pair = discriminator(z_real, real_data)
            fake_pair = discriminator(z_fake, fake_data)
            d_loss = discriminator_loss(real_pair, fake_pair)
            
            optimizer_D.zero_grad()
            d_loss.backward()
            optimizer_D.step()

            # 更新生成器和编码器 (希望判别器将 (fake_data, z_fake) 识别为真实)
            z_fake = torch.randn(real_data.size(0), z_dim,1,1).to(device)
            fake_data = generator(z_fake)
            fake_pair = discriminator(z_fake, fake_data)
            ge_loss = generator_encoder_loss(fake_pair)
            
            optimizer_G.zero_grad()
            optimizer_E.zero_grad()
            ge_loss.backward()
            optimizer_G.step()
            optimizer_E.step()
        if (epoch+1) % 10 == 0:
            fake_images = generator(fix_noise).detach().cpu().numpy()
            plt.figure(figsize=(10, 10))
            for i in range(64):
                plt.subplot(8, 8, i+1)
                plt.imshow(fake_images[i, 0, :, :], cmap='gray')
                plt.axis('off')
            plt.savefig(f'./results/fake_images_epoch_{epoch+1}.png')

        print(f"Epoch [{epoch+1}/{num_epochs}] Loss D: {d_loss.item()}, Loss G/E: {ge_loss.item()}")
        
num_epochs = 50
train_bigan(decoder, encoder, discriminator, train_loader, optimizer_G, optimizer_E, optimizer_D, num_epochs)


Epoch [1/50] Loss D: 1.038232445716858, Loss G/E: 0.6898459792137146
Epoch [2/50] Loss D: 1.0187749862670898, Loss G/E: 0.6921552419662476


KeyboardInterrupt: 