In [None]:
import os
import torch
import torch.optim as optim
import torch.nn as nn
from matplotlib import pyplot as plt
from torchvision import transforms
from torchvision.utils import save_image, make_grid
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from tqdm import tqdm  # 导入tqdm

# Custom dataset class
class AnimeDataset(Dataset):
    """Dataset of RGB images read from a single, flat directory.

    Args:
        image_dir: Directory that holds the image files. Sub-directories are
            not searched.
        transform: Optional callable applied to every loaded PIL image before
            it is returned.
        extensions: File-name suffix (or tuple of suffixes) to accept. The
            comparison is case-insensitive, so ``.jpg`` also matches ``.JPG``.
            Defaults to ``('.jpg',)``.

    Attributes:
        image_files: Sorted list of accepted file names inside ``image_dir``.
    """

    def __init__(self, image_dir, transform=None, extensions=('.jpg',)):
        self.image_dir = image_dir
        self.transform = transform
        # A bare string such as ('.jpg') is easy to pass by accident; normalise
        # it to a tuple so both spellings behave the same.
        if isinstance(extensions, str):
            extensions = (extensions,)
        suffixes = tuple(ext.lower() for ext in extensions)
        # os.listdir yields entries in arbitrary order; sorting keeps the
        # index -> file mapping stable between runs and machines.
        self.image_files = sorted(
            name
            for name in os.listdir(image_dir)
            if name.lower().endswith(suffixes)
            and os.path.isfile(os.path.join(image_dir, name))
        )

    def __len__(self):
        """Return the number of image files that were found."""
        return len(self.image_files)

    def __getitem__(self, idx):
        """Load the ``idx``-th image as RGB and apply ``transform`` if one is set."""
        img_path = os.path.join(self.image_dir, self.image_files[idx])
        # The context manager closes the underlying file once the pixel data
        # has been copied into the converted image.
        with Image.open(img_path) as img:
            image = img.convert('RGB')
        if self.transform is not None:
            image = self.transform(image)
        return image
    
# Data source and loader settings
image_dir = '/home/yuchi/AI/anim'
image_size, batch_size = 64, 16

# Image preprocessing: resize to a square of side `image_size`, convert to a
# tensor, then normalise every channel with mean 0.5 and std 0.5.
transform = transforms.Compose([
    transforms.Resize(size=(image_size, image_size)),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])

# Build the dataset and wrap it in a loader that reshuffles every epoch
dataset = AnimeDataset(image_dir=image_dir, transform=transform)
dataloader = DataLoader(dataset=dataset, shuffle=True, batch_size=batch_size)

In [8]:
# Generator network
class Generator(nn.Module):
    """Map a latent noise vector to a 3-channel 64x64 image with values in [-1, 1].

    A linear layer expands the latent vector to a 128x8x8 feature map, which
    three stride-2 transposed convolutions upsample to 16x16, 32x32 and
    finally 64x64. The closing ``Tanh`` bounds the output to [-1, 1].

    Args:
        nz: Length of the input latent vector. Defaults to 100, the size the
            network previously had hard-coded.
    """

    def __init__(self, nz=100):
        super().__init__()
        self.nz = nz
        # Layer order (and therefore the state_dict keys model.0, model.3, ...)
        # is kept fixed so previously saved checkpoints still load.
        self.model = nn.Sequential(
            nn.Linear(nz, 128 * 8 * 8),
            nn.ReLU(),
            nn.Unflatten(1, (128, 8, 8)),           # Output: (128, 8, 8)
            nn.ConvTranspose2d(128, 64, 4, 2, 1),   # Output: (64, 16, 16)
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, 4, 2, 1),    # Output: (32, 32, 32)
            nn.ReLU(),
            nn.ConvTranspose2d(32, 3, 4, 2, 1),     # Output: (3, 64, 64)
            nn.Tanh()
        )

    def forward(self, x):
        """Generate images from a ``(batch, nz)`` noise tensor.

        Returns:
            Tensor of shape ``(batch, 3, 64, 64)``.
        """
        return self.model(x)

# Discriminator network
class Discriminator(nn.Module):
    """Score each 3x64x64 image with a single value in (0, 1).

    Three stride-2 convolutions halve the spatial size at every step
    (64 -> 32 -> 16 -> 8); the resulting 128x8x8 map is flattened and reduced
    to one sigmoid-activated output per image.
    """

    def __init__(self):
        super().__init__()
        # Every convolution shares the same 4x4 kernel, stride 2, padding 1.
        conv_args = dict(kernel_size=4, stride=2, padding=1)
        leak = 0.2
        layers = [
            nn.Conv2d(3, 64, **conv_args),      # -> (64, 32, 32)
            nn.LeakyReLU(leak),
            nn.Conv2d(64, 128, **conv_args),    # -> (128, 16, 16)
            nn.LeakyReLU(leak),
            nn.Conv2d(128, 128, **conv_args),   # -> (128, 8, 8)
            nn.LeakyReLU(leak),
            nn.Flatten(),
            nn.Linear(128 * 8 * 8, 1),          # 128 * 8 * 8 = 8192 input features
            nn.Sigmoid(),
        ]
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Return a ``(batch, 1)`` tensor of scores for the image batch ``x``."""
        scores = self.model(x)
        return scores

# Hyperparameters
nz = 100            # length of the latent noise vector
num_epochs = 20
lr = 0.0001         # Adam learning rate, shared by both optimizers
beta1 = 0.5         # first Adam beta coefficient, shared by both optimizers

# Pick the device before anything is created so tensors and networks end up
# in the same place. To use a GPU when one is available, switch to:
#     torch.device("cuda" if torch.cuda.is_available() else "cpu")
device = "cpu"

# Fixed batch of noise, so generated samples can be compared across epochs
fixed_noise = torch.randn(64, nz, device=device)

# Create the networks and move them to the chosen device
generator = Generator().to(device)
discriminator = Discriminator().to(device)

# Loss function and optimizers. The optimizers read `lr` and `beta1`, so
# editing the hyperparameters above actually takes effect.
criterion = nn.BCELoss()
optimizerGen = optim.Adam(generator.parameters(), lr=lr, betas=(beta1, 0.999))
optimizerDis = optim.Adam(discriminator.parameters(), lr=lr, betas=(beta1, 0.999))

In [None]:
# Training loop
checkpoint_dir = '/home/yuchi/AI/DCGAN/model'
# torch.save does not create missing directories; make sure the target exists
# up front, not after a full epoch of training.
os.makedirs(checkpoint_dir, exist_ok=True)

for epoch in range(num_epochs):
    progress_bar = tqdm(dataloader, total=len(dataloader),
                        desc=f"Epoch [{epoch+1}/{num_epochs}]")
    for real_images in progress_bar:
        real_images = real_images.to(device)
        # Use a local name: the last batch may be smaller than the configured
        # `batch_size`, and overwriting that global would corrupt any later
        # cell that reads it.
        n_samples = real_images.size(0)
        labels_real = torch.ones(n_samples, device=device)
        labels_fake = torch.zeros(n_samples, device=device)

        # ---- Discriminator step ----
        optimizerDis.zero_grad()

        # Real images should be scored as 1
        output_real = discriminator(real_images).view(-1)
        loss_real = criterion(output_real, labels_real)
        loss_real.backward()

        # Generated images should be scored as 0. detach() keeps this update
        # from propagating gradients into the generator.
        noise = torch.randn(n_samples, nz, device=device)
        fake_images = generator(noise)
        output_fake = discriminator(fake_images.detach()).view(-1)
        loss_fake = criterion(output_fake, labels_fake)
        loss_fake.backward()
        optimizerDis.step()

        # ---- Generator step ----
        # The generator is rewarded when the freshly updated discriminator
        # labels its images as real, hence the "real" targets.
        optimizerGen.zero_grad()
        output_fake = discriminator(fake_images).view(-1)
        loss_gen = criterion(output_fake, labels_real)
        loss_gen.backward()
        optimizerGen.step()

        # Show the current losses next to the progress bar
        progress_bar.set_postfix(
            loss_D=f"{loss_real.item() + loss_fake.item():.4f}",
            loss_G=f"{loss_gen.item():.4f}",
        )

    # Save a checkpoint of both networks after every epoch
    torch.save(generator.state_dict(),
               os.path.join(checkpoint_dir, f'generator_{epoch+1}.pth'))
    torch.save(discriminator.state_dict(),
               os.path.join(checkpoint_dir, f'discriminator_{epoch+1}.pth'))

Epoch [1/20]: 100%|██████████| 3973/3973 [08:38<00:00,  7.67it/s]
Epoch [2/20]: 100%|██████████| 3973/3973 [08:40<00:00,  7.63it/s]
Epoch [3/20]: 100%|██████████| 3973/3973 [09:15<00:00,  7.15it/s]
Epoch [4/20]: 100%|██████████| 3973/3973 [09:40<00:00,  6.85it/s]
Epoch [5/20]: 100%|██████████| 3973/3973 [09:51<00:00,  6.72it/s]
Epoch [6/20]: 100%|██████████| 3973/3973 [09:55<00:00,  6.67it/s]
Epoch [7/20]: 100%|██████████| 3973/3973 [09:11<00:00,  7.20it/s]
Epoch [8/20]: 100%|██████████| 3973/3973 [09:31<00:00,  6.95it/s]
Epoch [9/20]: 100%|██████████| 3973/3973 [13:39<00:00,  4.85it/s]
Epoch [10/20]: 100%|██████████| 3973/3973 [11:48<00:00,  5.61it/s]
Epoch [11/20]: 100%|██████████| 3973/3973 [11:42<00:00,  5.65it/s]
Epoch [12/20]: 100%|██████████| 3973/3973 [12:00<00:00,  5.51it/s]
Epoch [13/20]: 100%|██████████| 3973/3973 [09:39<00:00,  6.86it/s]
Epoch [14/20]: 100%|██████████| 3973/3973 [09:13<00:00,  7.17it/s]
Epoch [15/20]: 100%|██████████| 3973/3973 [12:58<00:00,  5.11it/s]
Epoc