In [1]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from torchvision.utils import save_image
from PIL import Image
import matplotlib.pyplot as plt
from tqdm import tqdm
import torchvision

from tqdm import tqdm
import matplotlib.pyplot as plt

In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Используемое устройство: {device}")

anime_path = "/kaggle/input/anime-faces/data"
real_path = "/kaggle/input/celeba-dataset/img_align_celeba/img_align_celeba"

transform = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

class CustomDataset(Dataset):
    def __init__(self, image_paths, transform=None):
        self.image_paths = image_paths
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        image = Image.open(self.image_paths[idx]).convert('RGB')
        if self.transform:
            image = self.transform(image)
        return image

# Загрузка данных
anime_images = [os.path.join(anime_path, img) for img in os.listdir(anime_path) 
                if img.lower().endswith(('.jpg', '.png', '.jpeg'))]
real_images = [os.path.join(real_path, img) for img in os.listdir(real_path)]

anime_dataset = CustomDataset(anime_images, transform=transform)
real_dataset = CustomDataset(real_images, transform=transform)

batch_size = 32
num_workers = 2

anime_loader = DataLoader(
    anime_dataset,
    batch_size=batch_size,
    shuffle=True,
    pin_memory=True,
    persistent_workers=True,
    num_workers=num_workers
)

real_loader = DataLoader(
    real_dataset,
    batch_size=batch_size,
    shuffle=True,
    pin_memory=True,
    persistent_workers=True,
    num_workers=num_workers
)

Используемое устройство: cuda


In [3]:
class ResidualBlock(nn.Module):
    def __init__(self, in_channels):
        super().__init__()
        self.block = nn.Sequential(
            nn.Conv2d(in_channels, in_channels, 3, padding=1, bias=False),
            nn.InstanceNorm2d(in_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels, in_channels, 3, padding=1, bias=False),
            nn.InstanceNorm2d(in_channels)
        )

    def forward(self, x):
        return x + self.block(x)

class Generator128(nn.Module):
    def __init__(self, num_residual=6):
        super().__init__()
        # Энкодер
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 32, 7, padding=3),  # [B, 32, 128, 128]
            nn.InstanceNorm2d(32),
            nn.ReLU(inplace=True),
            
            nn.Conv2d(32, 64, 3, stride=2, padding=1),  # [B, 64, 64, 64]
            nn.InstanceNorm2d(64),
            nn.ReLU(inplace=True),
            
            nn.Conv2d(64, 128, 3, stride=2, padding=1),  # [B, 128, 32, 32]
            nn.InstanceNorm2d(128),
            nn.ReLU(inplace=True)
        )
        
        # Residual blocks
        self.res_blocks = nn.Sequential(
            *[ResidualBlock(128) for _ in range(num_residual)]
        )
        
        # Декодер
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(128, 64, 3, stride=2, padding=1, output_padding=1),  # [B, 64, 64, 64]
            nn.InstanceNorm2d(64),
            nn.ReLU(inplace=True),
            
            nn.ConvTranspose2d(64, 32, 3, stride=2, padding=1, output_padding=1),  # [B, 32, 128, 128]
            nn.InstanceNorm2d(32),
            nn.ReLU(inplace=True),
            
            nn.Conv2d(32, 3, 7, padding=3),  # [B, 3, 128, 128]
            nn.Tanh()
        )

    def forward(self, x):
        x = self.encoder(x)
        x = self.res_blocks(x)
        x = self.decoder(x)
        return x

class Discriminator128(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
            nn.Conv2d(3, 32, 4, stride=2, padding=1),  # [B, 32, 64, 64]
            nn.LeakyReLU(0.2, inplace=True),
            
            nn.Conv2d(32, 64, 4, stride=2, padding=1),  # [B, 64, 32, 32]
            nn.InstanceNorm2d(64),
            nn.LeakyReLU(0.2, inplace=True),
            
            nn.Conv2d(64, 128, 4, stride=2, padding=1),  # [B, 128, 16, 16]
            nn.InstanceNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            
            nn.Conv2d(128, 256, 4, padding=1),  # [B, 256, 13, 13]
            nn.InstanceNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),
            
            nn.Conv2d(256, 1, 4, padding=1)  # [B, 1, 10, 10]
        )

    def forward(self, x):
        return self.model(x).view(-1, 1).squeeze(1)


G_real_to_anime = Generator128().to(device)
G_anime_to_real = Generator128().to(device)
D_real = Discriminator128().to(device)
D_anime = Discriminator128().to(device)


# class ResidualBlock(nn.Module):
#     def __init__(self, in_features):
#         super().__init__()
#         self.block = nn.Sequential(
#             nn.ReflectionPad2d(1),
#             nn.Conv2d(in_features, in_features, 3),
#             nn.InstanceNorm2d(in_features),
#             nn.ReLU(inplace=True),
#             nn.ReflectionPad2d(1),
#             nn.Conv2d(in_features, in_features, 3),
#             nn.InstanceNorm2d(in_features)
#         )

#     def forward(self, x):
#         return x + self.block(x)

# class Generator(nn.Module):
#     def __init__(self, num_residual_blocks=9):
#         super().__init__()
#         # Кодировщик
#         self.encoder = nn.Sequential(
#             nn.ReflectionPad2d(3),
#             nn.Conv2d(3, 64, 7),
#             nn.InstanceNorm2d(64),
#             nn.ReLU(inplace=True),
#             nn.Conv2d(64, 128, 3, stride=2, padding=1),
#             nn.InstanceNorm2d(128),
#             nn.ReLU(inplace=True),
#             nn.Conv2d(128, 256, 3, stride=2, padding=1),
#             nn.InstanceNorm2d(256),
#             nn.ReLU(inplace=True)
#         )
        
#         # Residual blocks
#         res_blocks = []
#         for _ in range(num_residual_blocks):
#             res_blocks.append(ResidualBlock(256))
#         self.res_blocks = nn.Sequential(*res_blocks)
        
#         # Декодировщик
#         self.decoder = nn.Sequential(
#             nn.ConvTranspose2d(256, 128, 3, stride=2, padding=1, output_padding=1),
#             nn.InstanceNorm2d(128),
#             nn.ReLU(inplace=True),
#             nn.ConvTranspose2d(128, 64, 3, stride=2, padding=1, output_padding=1),
#             nn.InstanceNorm2d(64),
#             nn.ReLU(inplace=True),
#             nn.ReflectionPad2d(3),
#             nn.Conv2d(64, 3, 7),
#             nn.Tanh()
#         )

#     def forward(self, x):
#         x = self.encoder(x)
#         x = self.res_blocks(x)
#         x = self.decoder(x)
#         return x


# class Discriminator(nn.Module):
#     def __init__(self):
#         super().__init__()
#         self.model = nn.Sequential(
#             nn.Conv2d(3, 64, 4, stride=2, padding=1),
#             nn.LeakyReLU(0.2, inplace=True),
            
#             nn.Conv2d(64, 128, 4, stride=2, padding=1),
#             nn.InstanceNorm2d(128),
#             nn.LeakyReLU(0.2, inplace=True),
            
#             nn.Conv2d(128, 256, 4, stride=2, padding=1),
#             nn.InstanceNorm2d(256),
#             nn.LeakyReLU(0.2, inplace=True),
            
#             nn.Conv2d(256, 512, 4, padding=1),
#             nn.InstanceNorm2d(512),
#             nn.LeakyReLU(0.2, inplace=True),
            
#             nn.Conv2d(512, 1, 4, padding=1)
#         )

#     def forward(self, x):
#         return self.model(x)

# G_real_to_anime = Generator().to(device)
# G_anime_to_real = Generator().to(device)
# D_real = Discriminator().to(device)
# D_anime = Discriminator().to(device)

In [None]:
checkpoint_path = "/kaggle/input/cyclegan6/pytorch/default/1/cyclegan_128_120.pth"
checkpoint = torch.load(checkpoint_path)
G_real_to_anime.load_state_dict(checkpoint['G_real_to_anime'])
G_anime_to_real.load_state_dict(checkpoint['G_anime_to_real'])
D_real.load_state_dict(checkpoint['D_real'])
D_anime.load_state_dict(checkpoint['D_anime'])

G_real_to_anime.train()
G_anime_to_real.train()
D_real.train()
D_anime.train()

# Функции потерь
criterion_GAN = nn.MSELoss()
criterion_cycle = nn.L1Loss()
criterion_identity = nn.L1Loss()

# Оптимизаторы
optimizer_G = optim.Adam(
    list(G_real_to_anime.parameters()) + list(G_anime_to_real.parameters()),
    lr=0.00005,
    betas=(0.5, 0.999)
)
optimizer_D_real = optim.Adam(D_real.parameters(), lr=0.0001, betas=(0.5, 0.999))
optimizer_D_anime = optim.Adam(D_anime.parameters(), lr=0.0001, betas=(0.5, 0.999))

# Mixed precision
scaler = torch.cuda.amp.GradScaler()

# Настройка обучения
start_epoch = 120
num_epochs = 140

# Цикл обучения
for epoch in range(start_epoch, num_epochs):
    progress_bar = tqdm(zip(real_loader, anime_loader), 
                       total=min(len(real_loader), len(anime_loader)),
                       desc=f"Эпоха [{epoch + 1}/{num_epochs}]")
    
    for real_batch, anime_batch in progress_bar:
        real_images = real_batch.to(device, non_blocking=True)
        anime_images = anime_batch.to(device, non_blocking=True)
        
        # Обучение генераторов
        with torch.cuda.amp.autocast():
            fake_anime = G_real_to_anime(real_images)
            fake_real = G_anime_to_real(anime_images)
            cycled_real = G_anime_to_real(fake_anime)
            cycled_anime = G_real_to_anime(fake_real)
            identity_real = G_anime_to_real(real_images)
            identity_anime = G_real_to_anime(anime_images)
            
            loss_GAN_R2A = criterion_GAN(D_anime(fake_anime), torch.ones_like(D_anime(fake_anime)))
            loss_GAN_A2R = criterion_GAN(D_real(fake_real), torch.ones_like(D_real(fake_real)))
            loss_cycle = criterion_cycle(cycled_real, real_images) + criterion_cycle(cycled_anime, anime_images)
            loss_identity = criterion_identity(identity_real, real_images) + criterion_identity(identity_anime, anime_images)
            
            total_loss_G = loss_GAN_R2A + loss_GAN_A2R + 15.0 * loss_cycle + 2.5 * loss_identity
        
        optimizer_G.zero_grad(set_to_none=True)
        scaler.scale(total_loss_G).backward()
        scaler.step(optimizer_G)
        
        # Обучение дискриминаторов
        for disc_optim, real_imgs, fake_imgs, disc_model in [
            (optimizer_D_real, real_images, fake_real.detach(), D_real),
            (optimizer_D_anime, anime_images, fake_anime.detach(), D_anime)
        ]:
            with torch.cuda.amp.autocast():
                real_loss = criterion_GAN(disc_model(real_imgs), torch.ones_like(disc_model(real_imgs)))
                fake_loss = criterion_GAN(disc_model(fake_imgs), torch.zeros_like(disc_model(fake_imgs)))
                loss_D = (real_loss + fake_loss) * 0.5
            
            disc_optim.zero_grad(set_to_none=True)
            scaler.scale(loss_D).backward()
            scaler.step(disc_optim)
        
        scaler.update()
        
        # Очистка памяти
        del fake_anime, fake_real, cycled_real, cycled_anime, identity_real, identity_anime
        torch.cuda.empty_cache()
        
        progress_bar.set_postfix({
            "Loss G": f"{total_loss_G.item():.4f}",
            "Loss D": f"{loss_D.item():.4f}"
        })
    
    # Сохранение моделей каждые 5 эпох
    if (epoch + 1) % 5 == 0:
        torch.save({
            'G_real_to_anime': G_real_to_anime.state_dict(),
            'G_anime_to_real': G_anime_to_real.state_dict(),
            'D_real': D_real.state_dict(),
            'D_anime': D_anime.state_dict()
        }, f"cyclegan_128_{epoch+1}.pth")
        
        # Генерация примеров
        with torch.no_grad():
            test_images = real_images[:8]
            fake_anime = G_real_to_anime(test_images)
            
            grid = torchvision.utils.make_grid(
                torch.cat([test_images*0.5+0.5, fake_anime*0.5+0.5], 0),
                nrow=8
            )
            torchvision.utils.save_image(grid, f"results_epoch_{epoch+1}.png")




  checkpoint = torch.load(checkpoint_path)
  scaler = torch.cuda.amp.GradScaler()
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
Эпоха [121/140]: 100%|██████████| 674/674 [08:21<00:00,  1.34it/s, Loss G=3.1628, Loss D=0.2034]  
Эпоха [122/140]: 100%|██████████| 674/674 [08:20<00:00,  1.35it/s, Loss G=3.2658, Loss D=0.0888]
Эпоха [123/140]: 100%|██████████| 674/674 [08:20<00:00,  1.35it/s, Loss G=3.8279, Loss D=0.0907]
Эпоха [124/140]: 100%|██████████| 674/674 [08:20<00:00,  1.35it/s, Loss G=2.8202, Loss D=0.0742] 
Эпоха [125/140]: 100%|██████████| 674/674 [08:20<00:00,  1.35it/s, Loss G=3.0800, Loss D=0.0792]
Эпоха [126/140]: 100%|██████████| 674/674 [08:21<00:00,  1.34it/s, Loss G=2.4980, Loss D=0.3162]  
Эпоха [127/140]: 100%|██████████| 674/674 [08:21<00:00,  1.34it/s, Loss G=2.7276, Loss D=0.2370]
Эпоха [128/140]: 100%|██████████| 674/674 [08:21<00:00,  1.34it/s, Loss G=3.2432, Loss D=0.1183]
Эпоха [129/140]: 100%|██████████| 674/674 [08:21<00:00,  1.34it/s, Lo