In [None]:
!pip install gradio
!pip install tqdm



In [393]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader
import os
import matplotlib.pyplot as plt
from torchvision.utils import save_image, make_grid
from tqdm import tqdm

In [394]:

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cpu')

In [426]:
image_size = 28
batch_size = 128
latent_size = 100
num_classes = 10
stats = (0.5,), (0.5,)

# Dataset
train_transform = transforms.Compose([
    transforms.Resize(image_size),
    transforms.ToTensor(),
    transforms.Normalize(*stats)
])

train_dataset = MNIST(root="./data", transform=train_transform)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

In [427]:
class DiscriminatorModel(nn.Module):
    def __init__(self):
        super(DiscriminatorModel, self).__init__()
        input_dim = 784 + 10
        output_dim = 1
        self.label_embedding = nn.Embedding(10, 10)

        self.hidden_layer1 = nn.Sequential(
            nn.Linear(input_dim, 1024),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3)
        )

        self.hidden_layer2 = nn.Sequential(
            nn.Linear(1024, 512),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3)
        )

        self.hidden_layer3 = nn.Sequential(
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3)
        )

        self.hidden_layer4 = nn.Linear(256, output_dim)

    def forward(self, x, labels):
        c = self.label_embedding(labels).view(labels.shape[0], -1)
        x = torch.cat([x, c], 1)
        output = self.hidden_layer1(x)
        output = self.hidden_layer2(output)
        output = self.hidden_layer3(output)
        output = self.hidden_layer4(output)

        return output


class GeneratorModel(nn.Module):
    def __init__(self):
        super(GeneratorModel, self).__init__()
        input_dim = 100 + 10
        output_dim = 784
        self.label_embedding = nn.Embedding(10, 10)
        
        self.hidden_layer1 = nn.Sequential(
            nn.Linear(input_dim, 256),
            nn.LeakyReLU(0.2)
        )

        self.hidden_layer2 = nn.Sequential(
            nn.Linear(256, 512),
            nn.LeakyReLU(0.2)
        )

        self.hidden_layer3 = nn.Sequential(
            nn.Linear(512, 1024),
            nn.LeakyReLU(0.2)
        )

        self.hidden_layer4 = nn.Sequential(
            nn.Linear(1024, output_dim),
            nn.Tanh()
        )
    
    def forward(self, x, labels):
        c = self.label_embedding(labels).view(labels.shape[0], -1)
        x = torch.cat([x, c], 1)
        output = self.hidden_layer1(x)
        output = self.hidden_layer2(output)
        output = self.hidden_layer3(output)
        output = self.hidden_layer4(output)
        return output


In [428]:
discriminator = DiscriminatorModel()
generator = GeneratorModel()
discriminator.to(device)
generator.to(device)

GeneratorModel(
  (label_embedding): Embedding(10, 10)
  (hidden_layer1): Sequential(
    (0): Linear(in_features=110, out_features=256, bias=True)
    (1): LeakyReLU(negative_slope=0.2)
  )
  (hidden_layer2): Sequential(
    (0): Linear(in_features=256, out_features=512, bias=True)
    (1): LeakyReLU(negative_slope=0.2)
  )
  (hidden_layer3): Sequential(
    (0): Linear(in_features=512, out_features=1024, bias=True)
    (1): LeakyReLU(negative_slope=0.2)
  )
  (hidden_layer4): Sequential(
    (0): Linear(in_features=1024, out_features=784, bias=True)
    (1): Tanh()
  )
)

In [429]:
opt_d = torch.optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))
opt_g = torch.optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))

In [430]:
# Функция денормализации изображений
def denorm(img_tensors):
    return img_tensors * stats[1][0] + stats[0][0]

# Функция сохранения сгенерированных изображений
sample_dir = 'generated'
os.makedirs(sample_dir, exist_ok=True)

def save_samples(index, latent_tensors, labels, show=True):
    latent_tensors = latent_tensors.view(latent_tensors.size(0), -1)
    if latent_tensors.size(1) != labels.size(1):
        labels = labels.view(labels.size(0), -1)
    fake_images = generator(latent_tensors, labels).to(device)
    fake_fname = 'image-{0:0=4d}.png'.format(index)
    save_image(denorm(fake_images), os.path.join(sample_dir, fake_fname), nrow=8)
    if show:
        fig, ax = plt.subplots(figsize=(8, 8))
        ax.set_xticks([]); ax.set_yticks([])
        ax.imshow(make_grid(fake_images.cpu().detach(), nrow=8).permute(1, 2, 0))



# Training functions
def train_discriminator(real_images, labels):
    opt_d.zero_grad()

    real_images = real_images.view(-1, 784).to(device)
    labels = labels.to(device)

    real_preds = discriminator(real_images, labels)
    real_loss = F.binary_cross_entropy_with_logits(real_preds, torch.ones_like(real_preds))
    real_score = torch.mean(torch.sigmoid(real_preds)).item()  # Применяем сигмоиду к real_preds

    latent = torch.randn(batch_size, latent_size).to(device)
    fake_labels = torch.randint(0, num_classes, (batch_size,)).to(device)
    fake_images = generator(latent, fake_labels).detach().to(device)

    fake_preds = discriminator(fake_images, fake_labels)
    fake_loss = F.binary_cross_entropy_with_logits(fake_preds, torch.zeros_like(fake_preds))
    fake_score = torch.mean(torch.sigmoid(fake_preds)).item()  # Применяем сигмоиду к fake_preds

    loss = real_loss + fake_loss
    loss.backward()
    opt_d.step()

    return loss.item(), real_score, fake_score

def train_generator():
    opt_g.zero_grad()

    latent = torch.randn(batch_size, latent_size).to(device)
    labels = torch.randint(0, num_classes, (batch_size,)).to(device)

    fake_labels = labels
    fake_images = generator(latent, fake_labels)
    fake_preds = discriminator(fake_images, fake_labels)
    loss = F.binary_cross_entropy_with_logits(fake_preds, torch.ones_like(fake_preds))
    loss.backward()
    opt_g.step()

    return loss.item(), latent

def fit(epochs=10, start_idx=1):
    torch.cuda.empty_cache()

    losses_g = []
    losses_d = []
    real_scores = []
    fake_scores = []

    for epoch in range(epochs):  # Iterate over the range of epochs
        for real_images, labels in tqdm(train_loader):
            real_images = real_images.to(device)
            labels = labels.to(device)
            loss_d, real_score, fake_score = train_discriminator(real_images, labels)
            loss_g, latent = train_generator()

        losses_g.append(loss_g)
        losses_d.append(loss_d)
        real_scores.append(real_score)
        fake_scores.append(fake_score)

        print(f"[{epoch+1}/{epochs}], loss_g: {loss_g:.4f}, loss_d: {loss_d:.4f}, real_score: {real_score:.4f}, fake_score: {fake_score:.4f}")

        # save_samples(epoch+start_idx, latent, labels, show=False)

    return losses_g, losses_d, latent, fake_scores



In [431]:
losses_g, losses_d, latent, fake_scores = fit(epochs=12, start_idx=1)

100%|██████████| 469/469 [00:53<00:00,  8.71it/s]


[1/12], loss_g: 1.7666, loss_d: 0.3318, real_score: 0.8137, fake_score: 0.0838


100%|██████████| 469/469 [00:53<00:00,  8.76it/s]


[2/12], loss_g: 4.0560, loss_d: 0.2708, real_score: 0.9425, fake_score: 0.1561


100%|██████████| 469/469 [00:53<00:00,  8.76it/s]


[3/12], loss_g: 2.9062, loss_d: 0.4146, real_score: 0.9653, fake_score: 0.2488


100%|██████████| 469/469 [00:54<00:00,  8.68it/s]


[4/12], loss_g: 2.2213, loss_d: 0.5825, real_score: 0.8057, fake_score: 0.1651


100%|██████████| 469/469 [00:54<00:00,  8.61it/s]


[5/12], loss_g: 1.5142, loss_d: 0.6921, real_score: 0.6629, fake_score: 0.0819


100%|██████████| 469/469 [00:54<00:00,  8.61it/s]


[6/12], loss_g: 1.8435, loss_d: 0.5887, real_score: 0.8036, fake_score: 0.1921


100%|██████████| 469/469 [00:53<00:00,  8.75it/s]


[7/12], loss_g: 2.3665, loss_d: 0.5687, real_score: 0.8537, fake_score: 0.2447


100%|██████████| 469/469 [00:54<00:00,  8.66it/s]


[8/12], loss_g: 1.2331, loss_d: 0.6864, real_score: 0.6867, fake_score: 0.1320


100%|██████████| 469/469 [00:53<00:00,  8.78it/s]


[9/12], loss_g: 1.9921, loss_d: 0.7326, real_score: 0.7865, fake_score: 0.3134


100%|██████████| 469/469 [00:54<00:00,  8.68it/s]


[10/12], loss_g: 1.4760, loss_d: 0.9328, real_score: 0.6470, fake_score: 0.2670


100%|██████████| 469/469 [00:53<00:00,  8.74it/s]


[11/12], loss_g: 2.0437, loss_d: 1.2769, real_score: 0.8015, fake_score: 0.5496


100%|██████████| 469/469 [00:53<00:00,  8.72it/s]

[12/12], loss_g: 1.5111, loss_d: 0.9639, real_score: 0.6518, fake_score: 0.2960





In [432]:

# Сохраняем обученную модель
torch.save(generator.state_dict(), 'CGAN.pth')
