#### Лабораторная 4. (дедлайн 26.03)
- Датасет: CelebA https://www.kaggle.com/datasets/jessicali9530/celeba-dataset
- Подготовка датасета: использовать детектор/сегментатор для предварительного вырезания лиц людей
- Можно также добавить технику выравнивания лица (не обязательно)
- Обучить VAE или GAN для задачи безусловной генерации лиц
- А затем эту же сетку преобразовать в условную и обучить для задачи условной генерации лиц по полу человека (или любому другому признаку)
- Метрики: посчитать FID и IS, показать кривые обучения
- Возможно, потребуется использовать техники и улучшения VAE/GAN, чтобы они лучше обучились! Например, вместо простого GAN лучше посмотреть в сторону WGAN (GAN Вассерштейна)

In [1]:
import os

import kagglehub

  from .autonotebook import tqdm as notebook_tqdm


#### Датасет

In [2]:
import shutil

# Download CelebA only when the local copy is missing. kagglehub stores the files in
# its per-user cache, but every path below (ATTR_PATH, DATASET_PATH) points into ./2,
# so move the downloaded version directory there; otherwise a fresh run breaks.
DATASET_ROOT = './2'
if not os.path.exists(DATASET_ROOT):
    # Download latest version
    path = kagglehub.dataset_download("jessicali9530/celeba-dataset")
    print("Path to dataset files:", path)
    shutil.move(path, DATASET_ROOT)

In [None]:
%%bash
# Manual step kept for reference: move the kagglehub cache copy into ./2, where
# ATTR_PATH / DATASET_PATH expect it. The absolute path is specific to one user's machine.
# mv /Users/zwt/.cache/kagglehub/datasets/jessicali9530/celeba-dataset/versions/2 ./

#### Подготовка датасета

##### 1. Attr - condition

In [4]:
# Raw CelebA locations (inside ./2) and the folder that receives the cropped faces.
ATTR_PATH, DATASET_PATH, CROP_PATH = (
    './2/list_attr_celeba.csv',
    './2/img_align_celeba/img_align_celeba',
    './crop_img',
)

In [5]:
import pandas as pd

# Attribute table: an 'image_id' column followed by one column per face attribute.
attr_dataframe = pd.read_csv(filepath_or_buffer=ATTR_PATH)

In [6]:
# Every column after the leading 'image_id' is a candidate conditioning attribute.
attribute_columns = attr_dataframe.columns[1:]
print(f"attrs: {attribute_columns.to_list()}")

attrs: ['5_o_Clock_Shadow', 'Arched_Eyebrows', 'Attractive', 'Bags_Under_Eyes', 'Bald', 'Bangs', 'Big_Lips', 'Big_Nose', 'Black_Hair', 'Blond_Hair', 'Blurry', 'Brown_Hair', 'Bushy_Eyebrows', 'Chubby', 'Double_Chin', 'Eyeglasses', 'Goatee', 'Gray_Hair', 'Heavy_Makeup', 'High_Cheekbones', 'Male', 'Mouth_Slightly_Open', 'Mustache', 'Narrow_Eyes', 'No_Beard', 'Oval_Face', 'Pale_Skin', 'Pointy_Nose', 'Receding_Hairline', 'Rosy_Cheeks', 'Sideburns', 'Smiling', 'Straight_Hair', 'Wavy_Hair', 'Wearing_Earrings', 'Wearing_Hat', 'Wearing_Lipstick', 'Wearing_Necklace', 'Wearing_Necktie', 'Young']


In [7]:
# Use the 'Male' attribute as the generation condition; first check the class balance.
male_share = attr_dataframe['Male'].value_counts(normalize=True)
male_share

Male
-1    0.583246
 1    0.416754
Name: proportion, dtype: float64

In [8]:
# Lookup: image file name -> 'Male' label (crops labelled 1 go to 'male/', the rest to 'female/').
male_by_image = attr_dataframe.set_index('image_id')['Male']
image_male_dict = male_by_image.to_dict()

print(image_male_dict.get('069065.jpg'))

-1


##### 2. Crop image

In [9]:
import os
import cv2
import numpy as np
from mtcnn import MTCNN
from tqdm import tqdm

def _largest_box(detections):
    """Return the (x, y, w, h) box with the largest positive area among MTCNN detections, or None."""
    boxes = [tuple(det['box']) for det in detections if det['box'][2] > 0 and det['box'][3] > 0]
    # max() keeps the first of equal-area boxes, matching a strict '>' scan.
    return max(boxes, key=lambda box: box[2] * box[3], default=None)


def _padded_crop_bounds(box, padding, img_height, img_width):
    """Grow `box` by `padding` x its size on every side, clipped to the image; returns (x1, y1, x2, y2)."""
    x, y, w, h = box
    x1 = max(0, int(x - padding * w))
    y1 = max(0, int(y - padding * h))
    # Clip the far edges independently so clamping the near edge doesn't shift the crop.
    x2 = min(img_width, int(x + w + padding * w))
    y2 = min(img_height, int(y + h + padding * h))
    return x1, y1, x2, y2


def crop_faces(input_dir, output_dir, target_size=(128, 128), ratio=0.5,
               padding=0.2, label_dict=None, seed=703):
    """Detect, crop and resize the largest face in a random subset of images.

    Crops are written to ``output_dir/male`` or ``output_dir/female`` under the
    original file name.

    Args:
        input_dir: folder with the source images.
        output_dir: destination root; the two class sub-folders are created if missing.
        target_size: (width, height) passed to ``cv2.resize``.
        ratio: fraction of the images in ``input_dir`` to process. The subset is drawn
            without replacement and capped at the number of available images.
        padding: fraction of the box size added on each side before cropping.
        label_dict: file name -> label. Label 1 goes to 'male'; anything else,
            including missing names, goes to 'female'. Defaults to ``image_male_dict``.
        seed: seed for the subset selection.
    """
    if label_dict is None:
        label_dict = image_male_dict
    rng = np.random.RandomState(seed)

    detector = MTCNN()
    for class_dir in ('male', 'female'):
        os.makedirs(os.path.join(output_dir, class_dir), exist_ok=True)

    # Sort first: os.listdir order is filesystem-dependent, so the seeded subset would differ
    # between machines. Sample without replacement to avoid duplicate crops.
    images = sorted(os.listdir(input_dir))
    n_selected = min(len(images), int(len(images) * ratio))
    selected = rng.choice(images, n_selected, replace=False)

    for img_name in tqdm(selected):
        img_path = os.path.join(input_dir, img_name)
        bgr = cv2.imread(img_path)
        if bgr is None:  # cv2.imread signals unreadable files by returning None
            print(f"Could not read {img_name}")
            continue
        img = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)

        box = _largest_box(detector.detect_faces(img))
        if box is None:
            print(f"No face detected in {img_name}")
            continue

        x1, y1, x2, y2 = _padded_crop_bounds(box, padding, img.shape[0], img.shape[1])
        face = img[y1:y2, x1:x2]
        if face.size == 0:  # box lay entirely outside the image
            print(f"No face detected in {img_name}")
            continue
        face = cv2.resize(face, target_size)

        class_dir = 'male' if label_dict.get(img_name) == 1 else 'female'
        output_path = os.path.join(output_dir, class_dir, img_name)
        cv2.imwrite(output_path, cv2.cvtColor(face, cv2.COLOR_RGB2BGR))

In [10]:
# Cropping runs MTCNN over a 10% sample of the aligned images; skip it if CROP_PATH already exists.
if os.path.exists(CROP_PATH):
    print('Dataset crop done.')
else:
    crop_faces(DATASET_PATH, CROP_PATH, ratio=0.1)

Dataset crop done.


##### 3. Dataloader

In [11]:
# Paths of every cropped image, walking the per-class sub-folders of CROP_PATH.
images = [
    os.path.join(dir_path, file_name)
    for dir_path, _subdirs, file_names in os.walk(CROP_PATH)
    for file_name in file_names
]

len(images)

19192

In [12]:
import random

# os.walk yields the files grouped by class folder (female/, male/), so an unshuffled
# head/tail split would put almost only one gender into the test set. Shuffle a sorted
# copy with a fixed seed so the split is both class-mixed and reproducible.
TRAIN_FRACTION = 0.9
shuffled_images = sorted(images)
random.Random(703).shuffle(shuffled_images)
n_train = int(len(shuffled_images) * TRAIN_FRACTION)
train_images, test_images = shuffled_images[:n_train], shuffled_images[n_train:]

In [13]:
tuple(len(split) for split in (train_images, test_images))

(17272, 1920)

In [14]:
from dataset.dataset import CelebADataset
import torchvision.transforms as transforms

 # Preprocessing. Random horizontal flips are augmentation and belong to training only;
# the evaluation pipeline is deterministic so test images (and any metrics on them) are reproducible.
transform = transforms.Compose([
    transforms.RandomHorizontalFlip(p=0.5),  # data augmentation (training only)
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),  # [0, 1] -> [-1, 1]
])
eval_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),  # same scaling, no flip
])

train_dataset = CelebADataset(train_images, image_male_dict, transform)
test_dataset = CelebADataset(test_images, image_male_dict, eval_transform)

In [15]:
# Sanity check: tensor shape and label of the first training sample.
train_dataset[0][0].shape, train_dataset[0][1]

(torch.Size([3, 128, 128]), 1)

In [16]:
from torch.utils.data import DataLoader
batch_size = 4

# Shared settings for both loaders; only the training split is shuffled.
loader_kwargs = dict(batch_size=batch_size)
train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
valid_loader = DataLoader(test_dataset, shuffle=False, **loader_kwargs)

#### Модель

In [17]:
import torch
from gan.discriminator import Discriminator
from gan.generator import Generator

# Pick the best available backend instead of hard-coding 'mps' (which fails on any
# machine without an Apple GPU): MPS, then CUDA, then CPU.
if torch.backends.mps.is_available():
    device = 'mps'
elif torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
G = Generator().to(device)
D = Discriminator().to(device)

In [18]:
# One RMSprop optimizer per network (generator first, then critic), both with lr=5e-5.
opt_g, opt_d = (torch.optim.RMSprop(net.parameters(), lr=5e-5) for net in (G, D))

In [19]:
# WGAN loss functions
def critic_loss(real_scores, fake_scores):
    """WGAN critic loss: mean(fake_scores) - mean(real_scores).

    Minimizing it widens the gap between the critic's scores on real and generated samples.
    """
    return torch.mean(fake_scores) - torch.mean(real_scores)

def generator_loss(fake_scores):
    """WGAN generator loss: the negated mean critic score of generated samples.

    Minimizing this loss *maximizes* the critic's score on the generated samples,
    i.e. it pushes generated images toward what the critic rates as real.
    """
    return -torch.mean(fake_scores)

In [20]:
import numpy as np
from scipy.linalg import sqrtm
from torchvision.models import inception_v3
from torch.nn.functional import softmax

def calculate_inception_score(images, batch_size=32, splits=10, model=None, resize=True, eps=1e-12):
    """Compute the Inception Score (IS) as (mean, std) over ``splits`` chunks.

    Args:
        images: float tensor (N, 3, H, W). Assumed to use the same [-1, 1] scaling as the
            training tensors (see the Normalize transform) — TODO confirm this matches the
            Inception weights' expected input range.
        batch_size: images per forward pass.
        splits: number of chunks; capped at N so no chunk is empty. All N images are
            used (``np.array_split``), not only the first ``splits * (N // splits)``.
        model: classifier returning logits. Defaults to torchvision's pretrained
            Inception v3 on the notebook-level ``device``.
        resize: bilinearly resize inputs to 299x299, the Inception v3 input size.
        eps: added inside the logs so zero probabilities don't produce NaN.

    Returns:
        (mean, std) of exp(E_x[KL(p(y|x) || p(y))]) across the chunks.

    Raises:
        ValueError: if ``images`` is empty.
    """
    if len(images) == 0:
        raise ValueError("calculate_inception_score needs at least one image.")
    if model is None:
        model = inception_v3(pretrained=True, transform_input=False).eval().to(device)
    first_param = next(model.parameters(), None)
    model_device = first_param.device if first_param is not None else torch.device('cpu')

    probs = []
    with torch.no_grad():
        for start in range(0, len(images), batch_size):
            batch = images[start:start + batch_size].to(model_device)
            if resize and tuple(batch.shape[-2:]) != (299, 299):
                batch = torch.nn.functional.interpolate(
                    batch, size=(299, 299), mode='bilinear', align_corners=False
                )
            probs.append(softmax(model(batch), dim=1).cpu().numpy())
    probs = np.concatenate(probs, axis=0)

    scores = []
    for part in np.array_split(probs, min(splits, len(probs))):
        p_y = part.mean(axis=0, keepdims=True)  # marginal class distribution of the chunk
        kl = part * (np.log(part + eps) - np.log(p_y + eps))
        scores.append(np.exp(kl.sum(axis=1).mean()))

    return np.mean(scores), np.std(scores)

In [21]:
def calculate_fid(real_images, fake_images, batch_size=32, model=None, resize=True, eps=1e-6):
    """Compute the Frechet Inception Distance (FID) between two image sets.

    The FID is ||mu1 - mu2||^2 + Tr(S1 + S2 - 2 (S1 S2)^(1/2)), computed on feature
    means/covariances of the two sets.

    Args:
        real_images, fake_images: float tensors (N, 3, H, W), each with N >= 2.
        batch_size: images per forward pass.
        model: feature extractor. Defaults to torchvision's pretrained Inception v3 with
            its ``fc`` classifier replaced by Identity, so features are the 2048-d pooled
            activations rather than class logits. It runs on the notebook-level ``device``.
        resize: bilinearly resize inputs to 299x299, the Inception v3 input size.
        eps: diagonal offset added to the covariances if the matrix square root is not
            finite (near-singular covariances, e.g. with few samples).

    Returns:
        The FID as a float (0 for identical feature statistics).

    Raises:
        ValueError: if either set has fewer than two images (covariance undefined).
    """
    if len(real_images) < 2 or len(fake_images) < 2:
        raise ValueError("calculate_fid needs at least two real and two fake images.")
    if model is None:
        model = inception_v3(pretrained=True, transform_input=False)
        model.fc = torch.nn.Identity()  # keep the pooled features, drop the classifier
        model = model.eval().to(device)
    first_param = next(model.parameters(), None)
    model_device = first_param.device if first_param is not None else torch.device('cpu')

    def get_features(images):
        """Run `model` over `images` in batches and stack the outputs as float64 numpy."""
        features = []
        with torch.no_grad():
            for start in range(0, len(images), batch_size):
                batch = images[start:start + batch_size].to(model_device)
                if resize and tuple(batch.shape[-2:]) != (299, 299):
                    batch = torch.nn.functional.interpolate(
                        batch, size=(299, 299), mode='bilinear', align_corners=False
                    )
                features.append(model(batch).cpu().numpy().astype(np.float64))
        return np.concatenate(features, axis=0)

    real_feat = get_features(real_images)
    fake_feat = get_features(fake_images)

    mu1, sigma1 = real_feat.mean(axis=0), np.atleast_2d(np.cov(real_feat, rowvar=False))
    mu2, sigma2 = fake_feat.mean(axis=0), np.atleast_2d(np.cov(fake_feat, rowvar=False))

    ssdiff = np.sum((mu1 - mu2) ** 2)
    covmean = sqrtm(sigma1.dot(sigma2))
    if not np.isfinite(covmean).all():
        offset = np.eye(sigma1.shape[0]) * eps
        covmean = sqrtm((sigma1 + offset).dot(sigma2 + offset))
    if np.iscomplexobj(covmean):
        covmean = covmean.real  # drop numerical imaginary residue
    fid = ssdiff + np.trace(sigma1 + sigma2 - 2 * covmean)
    return fid

In [24]:
import torch
import numpy as np
from tqdm import tqdm
from torchvision.utils import save_image
import matplotlib.pyplot as plt
from pytorch_fid import fid_score
from torchmetrics.image.inception import InceptionScore

# 训练函数封装
def _to_unit_range(images):
    """Map tensors scaled to [-1, 1] (the Normalize transform's range) to [0, 1], clipping overshoot."""
    return ((images + 1) / 2).clamp(0, 1)


def _collect_real_images(loader, n_images):
    """Return up to ``n_images`` images from an ``(images, labels)`` loader as one CPU tensor."""
    collected, total = [], 0
    for real_imgs, _ in loader:
        collected.append(real_imgs.cpu())
        total += real_imgs.size(0)
        if total >= n_images:
            break
    if not collected:
        raise ValueError("train_loader yielded no images; cannot compute FID.")
    return torch.cat(collected, dim=0)[:n_images]


def _generate_samples(G, n_samples, latent_dim, device, chunk_size=100):
    """Draw ``n_samples`` images from ``G`` without gradients, returned on the CPU."""
    chunks = []
    with torch.no_grad():
        for start in range(0, n_samples, chunk_size):
            z = torch.randn(min(chunk_size, n_samples - start), latent_dim).to(device)
            chunks.append(G(z).cpu())
    return torch.cat(chunks, dim=0)


def _gradient_penalty(D, real_imgs, fake_imgs):
    """WGAN-GP term: mean of (||grad_x D(x)||_2 - 1)^2 at random real/fake interpolates."""
    # One mixing weight per sample, broadcast over (C, H, W).
    alpha = torch.rand(real_imgs.size(0), 1, 1, 1).to(real_imgs.device)
    interpolates = (alpha * real_imgs + (1 - alpha) * fake_imgs).requires_grad_(True)
    d_interpolates = D(interpolates)
    gradients = torch.autograd.grad(
        outputs=d_interpolates,
        inputs=interpolates,
        grad_outputs=torch.ones_like(d_interpolates),
        create_graph=True,  # the penalty itself is backpropagated into D
        retain_graph=True,
    )[0]
    gradients = gradients.reshape(gradients.size(0), -1)
    return ((gradients.norm(2, dim=1) - 1) ** 2).mean()


def train_wgan(
    G, D,
    train_loader,
    opt_g, opt_d,
    device,
    n_epochs=100,
    latent_dim=100,
    n_critic=5,
    lambda_gp=10,
    use_gp=True,
    use_sn=False,
    eval_interval=5,
    sample_interval=10,
    n_eval_samples=1000,
    metric_device='cpu',
    sample_dir='samples',
    checkpoint_dir='checkpoints',
):
    """Train an unconditional WGAN (optionally WGAN-GP) and track losses, FID and IS.

    Args:
        G, D: generator and critic. ``G`` is called as ``G(z)`` with ``z`` of shape
            ``(batch, latent_dim)``. Its output is assumed to share the [-1, 1] scaling
            of the training images (TODO confirm against the Generator's final activation).
        train_loader: yields ``(images, labels)``; labels are ignored.
        opt_g, opt_d: optimizers for ``G`` and ``D``.
        device: device the models and training batches live on.
        n_epochs: passes over ``train_loader``.
        latent_dim: length of the noise vector.
        n_critic: critic updates per generator update, all on the same real batch.
        lambda_gp: gradient-penalty weight.
        use_gp: add the gradient penalty to the critic loss.
        use_sn: not used by this function; spectral normalization would have to be
            built into ``D``. Kept for call compatibility.
        eval_interval: compute FID/IS and save checkpoints every this many epochs.
        sample_interval: save a sample grid every this many batches.
        n_eval_samples: number of generated (and real) images used for FID/IS.
        metric_device: device for the torchmetrics Inception networks. Defaults to CPU
            because torchmetrics' FID accumulates features in float64, which MPS lacks.
        sample_dir, checkpoint_dir: output folders, created if missing.

    Returns:
        dict with per-epoch ``g_loss`` / ``d_loss``. For evaluated epochs it also holds
        ``fid``, ``is_score`` and the 0-based ``epochs`` index.
    """
    # Same package as InceptionScore. FID needs real-image statistics, which are
    # computed from train_loader on the first evaluation.
    from torchmetrics.image.fid import FrechetInceptionDistance

    os.makedirs(sample_dir, exist_ok=True)
    os.makedirs(checkpoint_dir, exist_ok=True)

    history = {'g_loss': [], 'd_loss': [], 'fid': [], 'is_score': [], 'epochs': []}

    # normalize=True: both metrics take float images in [0, 1] (their default is uint8).
    inception_score = InceptionScore(normalize=True).to(metric_device)
    # reset_real_features=False: the real statistics survive reset() and are computed once.
    fid_metric = FrechetInceptionDistance(
        feature=2048, normalize=True, reset_real_features=False
    ).to(metric_device)
    real_stats_ready = False

    n_batches = max(len(train_loader), 1)  # avoid division by zero on an empty loader
    for epoch in range(n_epochs):
        G.train()
        D.train()
        epoch_g_loss = 0.0
        epoch_d_loss = 0.0

        pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{n_epochs}")
        for batch_idx, (real_imgs, _labels) in enumerate(pbar):
            real_imgs = real_imgs.to(device)
            batch_size = real_imgs.size(0)

            # ---- critic ----
            d_losses = []
            for _ in range(n_critic):
                z = torch.randn(batch_size, latent_dim).to(device)
                fake_imgs = G(z).detach()  # no generator gradients during critic updates
                # Negated Wasserstein estimate: mean(D(real)) - mean(D(fake)) is maximized.
                d_loss = -torch.mean(D(real_imgs)) + torch.mean(D(fake_imgs))
                if use_gp:
                    d_loss = d_loss + lambda_gp * _gradient_penalty(D, real_imgs, fake_imgs)
                opt_d.zero_grad()
                d_loss.backward()
                opt_d.step()
                d_losses.append(d_loss.item())

            # ---- generator ----
            z = torch.randn(batch_size, latent_dim).to(device)
            fake_imgs = G(z)
            g_loss = -torch.mean(D(fake_imgs))
            opt_g.zero_grad()
            g_loss.backward()
            opt_g.step()

            mean_d_loss = float(np.mean(d_losses))
            epoch_g_loss += g_loss.item()
            epoch_d_loss += mean_d_loss
            pbar.set_postfix({'g_loss': g_loss.item(), 'd_loss': mean_d_loss})

            if batch_idx % sample_interval == 0:
                save_image(
                    fake_imgs[:16],
                    os.path.join(sample_dir, f"epoch_{epoch}_batch_{batch_idx}.png"),
                    nrow=4,
                    normalize=True,
                )

        history['g_loss'].append(epoch_g_loss / n_batches)
        history['d_loss'].append(epoch_d_loss / n_batches)

        # ---- evaluation ----
        if (epoch + 1) % eval_interval == 0:
            G.eval()
            fake_eval = _to_unit_range(_generate_samples(G, n_eval_samples, latent_dim, device))

            if not real_stats_ready:
                real_eval = _to_unit_range(_collect_real_images(train_loader, n_eval_samples))
                for chunk in real_eval.split(100):
                    fid_metric.update(chunk.to(metric_device), real=True)
                real_stats_ready = True

            # Fresh fake statistics for this checkpoint only; without reset() the scores
            # would mix in samples from earlier evaluations.
            inception_score.reset()
            fid_metric.reset()
            for chunk in fake_eval.split(100):
                chunk = chunk.to(metric_device)
                inception_score.update(chunk)
                fid_metric.update(chunk, real=False)
            is_mean, is_std = inception_score.compute()
            fid = fid_metric.compute().item()

            history['fid'].append(fid)
            history['is_score'].append(is_mean.item())
            history['epochs'].append(epoch)

            print(f"\nEpoch {epoch+1} | FID: {fid:.2f} | IS: {is_mean:.2f}±{is_std:.2f}")

            torch.save(G.state_dict(), os.path.join(checkpoint_dir, f"G_epoch_{epoch}.pth"))
            torch.save(D.state_dict(), os.path.join(checkpoint_dir, f"D_epoch_{epoch}.pth"))

        plot_training_curves(history)

    return history

def plot_training_curves(history, path='training_curves.png'):
    """Save loss, FID and Inception Score curves from a ``train_wgan`` history to ``path``.

    Losses are plotted against their list index (one value per epoch). FID/IS are
    plotted against ``history['epochs']``, the epochs at which they were evaluated.
    The figure is closed after saving, so calling this once per epoch doesn't leave
    figures open.
    """
    fig, (ax_loss, ax_fid, ax_is) = plt.subplots(1, 3, figsize=(12, 4))

    ax_loss.plot(history['g_loss'], label='Generator Loss')
    ax_loss.plot(history['d_loss'], label='Critic Loss')
    ax_loss.set(title='Losses', xlabel='Epoch', ylabel='Loss')
    ax_loss.legend()

    ax_fid.plot(history['epochs'], history['fid'], 'r-')
    ax_fid.set(title='FID', xlabel='Epoch', ylabel='FID')

    ax_is.plot(history['epochs'], history['is_score'], 'g-')
    ax_is.set(title='Inception Score', xlabel='Epoch', ylabel='Inception Score')

    fig.tight_layout()
    fig.savefig(path)
    plt.close(fig)

In [25]:
# Start WGAN-GP training: 100 epochs, FID/IS evaluated every 5 epochs.
training_config = dict(n_epochs=100, use_gp=True, eval_interval=5)
history = train_wgan(G, D, train_loader, opt_g, opt_d, device, **training_config)

Epoch 1/100: 100%|██████████| 4318/4318 [46:50<00:00,  1.54it/s, g_loss=4.23, d_loss=-1.22]    
Epoch 2/100:   2%|▏         | 82/4318 [00:57<49:55,  1.41it/s, g_loss=5.44, d_loss=-0.67] 


KeyboardInterrupt: 