# Генерация изображений животных

Импортируем необходимые библиотеки

In [8]:
import os
import cv2
import numpy as np
from sklearn.model_selection import train_test_split
import torch
import torchvision.models as models
import torchvision.transforms as transforms
import torch.nn as nn
import torch.optim as optim
from pathlib import Path
import tensorflow as tf
from tensorflow.keras import utils

Загрузим датасет

In [9]:
# Download (once, cached by Keras) and unpack the pet-faces archive.
url = 'http://www.soshnikov.com/permanent/data/petfaces.tar.gz'
dataset_path = Path(utils.get_file('petfaces', origin=url, untar=True))

IMAGE_SIZE = (64, 64)  # (width, height) as expected by cv2.resize

images = []
labels = []
# Walk directories and files in sorted order so the resulting image order --
# and therefore the seeded train/test split below -- does not depend on the
# file system's listing order. Sorting `dirs` in place steers os.walk's
# top-down traversal.
for root, dirs, files in os.walk(dataset_path):
    dirs.sort()
    for file in sorted(files):
        if not file.lower().endswith(('.jpg', '.jpeg')):
            continue
        image_path = os.path.join(root, file)
        image = cv2.imread(image_path)
        if image is None:
            # cv2.imread signals unreadable/corrupt files by returning None
            # rather than raising; skip them instead of crashing in resize.
            print(f'Skipping unreadable image: {image_path}')
            continue
        # OpenCV decodes to BGR; the ImageNet-pretrained ResNet used later
        # is normalized with RGB channel statistics.
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        images.append(cv2.resize(image, IMAGE_SIZE))
        # The class label is the name of the directory containing the file.
        labels.append(os.path.basename(root))

images = np.array(images)
labels = np.array(labels)


Разделим на тренировочный и тестовый наборы

In [10]:
# Hold out 20% of the images for testing; the fixed seed makes the split
# reproducible for a given image order.
(train_images, test_images,
 train_labels, test_labels) = train_test_split(
    images, labels,
    test_size=0.2,
    random_state=42,
)

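Загрузим предобученную ResNet-50 и извлечём с её помощью векторы признаков для тренировочных и тестовых изображений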

In [11]:
# ImageNet-pretrained ResNet-50 used as a frozen feature extractor; eval()
# (which returns the module itself) puts BatchNorm into inference mode.
# NOTE(review): newer torchvision releases deprecate `pretrained=` in favour
# of the `weights=` argument.
resnet = models.resnet50(pretrained=True).eval()

# HWC uint8 array -> CHW float tensor in [0, 1], then per-channel
# standardisation with the ImageNet mean/std.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

def extract_features(images, model=None, preprocess=None, batch_size=64):
    """Run every image through a frozen network and collect its outputs.

    Args:
        images: indexable sequence of images accepted by ``preprocess``
            (e.g. the HxWx3 uint8 arrays built by the loading cell).
        model: network (any callable on a batch tensor) to apply; defaults
            to the module-level ``resnet``.
        preprocess: per-image transform returning a CHW tensor; defaults to
            the module-level ``transform``.
        batch_size: number of images per forward pass.

    Returns:
        A list with one 1-D numpy array per input image, in input order.
        An empty input yields an empty list.
    """
    model = resnet if model is None else model
    preprocess = transform if preprocess is None else preprocess

    features = []
    # Feature extraction never needs gradients.
    with torch.no_grad():
        # Batching amortises per-call overhead versus one forward pass per image.
        for start in range(0, len(images), batch_size):
            chunk = images[start:start + batch_size]
            batch = torch.stack([preprocess(image) for image in chunk])
            output = model(batch)
            # Exactly one row per image: flatten any trailing dims into a
            # vector (unlike squeeze(), this never drops the batch axis).
            features.extend(output.reshape(len(chunk), -1).cpu().numpy())
    return features

# Precompute the frozen-network features for both splits once, train first.
train_features, test_features = (
    extract_features(split) for split in (train_images, test_images)
)



Определим генератор и дискриминатор — полносвязные сети, которые работают с векторами признаков ResNet, а не с пикселями

In [12]:
class Generator(nn.Module):
    """Fully connected generator mapping a noise vector to a vector in [-1, 1].

    Args:
        input_dim: length of the input noise vector.
        output_dim: length of the generated vector.
    """

    def __init__(self, input_dim, output_dim):
        super().__init__()
        layers = [
            nn.Linear(input_dim, 256), nn.ReLU(),
            nn.Linear(256, 512), nn.ReLU(),
            # tanh squashes every output component into [-1, 1].
            nn.Linear(512, output_dim), nn.Tanh(),
        ]
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Return generated vectors of shape (batch, output_dim) for noise x."""
        generated = self.model(x)
        return generated

class Discriminator(nn.Module):
    """Fully connected discriminator scoring vectors with a probability of being real.

    Args:
        input_dim: length of the vectors to classify.
    """

    def __init__(self, input_dim):
        super().__init__()
        widths = [input_dim, 512, 256]
        layers = []
        # Hidden Linear+ReLU pairs: input_dim -> 512 -> 256.
        for fan_in, fan_out in zip(widths, widths[1:]):
            layers += [nn.Linear(fan_in, fan_out), nn.ReLU()]
        # Single sigmoid unit: output in (0, 1), as nn.BCELoss expects.
        layers += [nn.Linear(256, 1), nn.Sigmoid()]
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Return scores of shape (batch, 1) for the input vectors x."""
        scores = self.model(x)
        return scores
    
# The noise vector and the generated vector both take the width of one
# extracted feature vector.
input_dim = output_dim = train_features[0].shape[0]

# Created generator-first, matching the original parameter-init order.
generator, discriminator = Generator(input_dim, output_dim), Discriminator(output_dim)

Определим функцию потерь и оптимизаторы, затем обучим GAN

In [13]:
loss_fn = nn.BCELoss()
optimizer_G = optim.Adam(generator.parameters(), lr=0.001)
optimizer_D = optim.Adam(discriminator.parameters(), lr=0.001)

num_epochs = 10
batch_size = 64

# Stack the per-image feature arrays into a single float32 tensor once,
# instead of rebuilding a tensor from a list of numpy arrays on every batch.
all_features = torch.from_numpy(np.stack(train_features)).float()
num_train = all_features.size(0)

for epoch in range(num_epochs):
    # Reshuffle each epoch so the batches differ from epoch to epoch.
    permutation = torch.randperm(num_train)
    for i in range(0, num_train, batch_size):
        real_features = all_features[permutation[i:i + batch_size]]
        batch_len = real_features.size(0)  # the last batch may be smaller
        real_labels = torch.ones(batch_len, 1)
        fake_labels = torch.zeros(batch_len, 1)

        # Discriminator step: push real -> 1 and fake -> 0.
        discriminator.zero_grad()
        real_output = discriminator(real_features)
        real_loss = loss_fn(real_output, real_labels)

        noise = torch.randn(batch_len, input_dim)
        fake_features = generator(noise)
        # detach(): the discriminator update must not backpropagate into the generator.
        fake_output = discriminator(fake_features.detach())
        fake_loss = loss_fn(fake_output, fake_labels)

        d_loss = real_loss + fake_loss
        d_loss.backward()
        optimizer_D.step()

        # Generator step: make the (just updated) discriminator output 1 on fakes.
        # Gradients this leaves on the discriminator are cleared by its
        # zero_grad() at the start of the next batch.
        generator.zero_grad()
        fake_output = discriminator(fake_features)
        g_loss = loss_fn(fake_output, real_labels)
        g_loss.backward()
        optimizer_G.step()

    # Losses of the epoch's last batch, as a rough progress indicator.
    print(f'Epoch {epoch + 1}/{num_epochs}: '
          f'd_loss={d_loss.item():.4f}, g_loss={g_loss.item():.4f}')


Сгенерируем несколько новых векторов признаков с помощью обученного генератора и отобразим их

In [16]:
num_samples = 10

# The GAN was trained on ResNet feature vectors, so each generated sample is
# an output_dim-long feature vector, not a picture. Display all samples
# together as one grayscale image: one horizontal band per sample, one
# column per feature component.
with torch.no_grad():
    noise = torch.randn(num_samples, input_dim)
    generated_features = generator(noise)

# Map the generator's tanh range [-1, 1] onto 8-bit intensities [0, 255].
generated_image = ((generated_features.numpy() + 1) / 2) * 255
generated_image = generated_image.astype(np.uint8)
# Repeat each row so every sample is a visible band instead of a 1-pixel line.
band_height = 20
generated_image = np.repeat(generated_image, band_height, axis=0)

cv2.imshow('Generated Image', generated_image)
cv2.waitKey(0)
cv2.destroyAllWindows()