In [None]:
!git clone https://github.com/uko3/StyleGAN-nada.git
!pip install Ninja
!pip install git+https://github.com/openai/CLIP.git -q
%cd StyleGAN-nada

In [None]:
# Импорты
import torch
import torch.optim as optim
from torchvision.utils import save_image
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
import os
import math
import copy
import re
import warnings
import clip

# Импортируем собственные модули
from modules.stylegan_arch.model import Generator
from modules.losses import CLIPLoss, CLIPDirectionalLoss
from modules.trainer import LatentStyleTrainer
from modules.utils import freeze_layers_adaptive, generate_visualize_and_save # Import freeze_layers_adaptive directly

# Select the compute device and seed the RNGs so that runs of this notebook
# (latent sampling during training in particular) are reproducible.
SEED = 0
torch.manual_seed(SEED)  # seeds the CPU and all CUDA generators
np.random.seed(SEED)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Используется устройство: {device}")

In [None]:
# Загрузка предварительно обученной модели StyleGAN2
if not os.path.exists('stylegan2-ffhq-config-f.pt'):
    !gdown https://drive.google.com/uc?id=1EM87UquaoQmk17Q8d5kYIAHqu0dkYqdT -O stylegan2-ffhq-config-f.pt

# Параметры генератора
size = 1024
latent_dim = 512
n_mlp = 8
channel_multiplier = 2
ckpt_path = 'stylegan2-ffhq-config-f.pt'

# Инициализация генератора
generator = Generator(size, latent_dim, n_mlp, channel_multiplier=channel_multiplier).to(device)
generator.eval()
checkpoint = torch.load(ckpt_path)
generator.load_state_dict(checkpoint["g_ema"])

In [None]:
# Load CLIP ViT-B/32 together with its image preprocessing transform onto `device`.
model_clip, preprocess = clip.load('ViT-B/32', device=device)

In [None]:
# Encode the source and target text prompts with CLIP and L2-normalise them.
source_class = "photo"
target_class = "sketch"  # another target that was tried: "plasticine"


def _encode_prompt(prompt):
    """Tokenise one prompt; return (tokens on `device`, unit-norm CLIP text features)."""
    tokens = clip.tokenize([prompt]).to(device)
    with torch.no_grad():
        raw = model_clip.encode_text(tokens)
    return tokens, raw / raw.norm(dim=-1, keepdim=True)


text_source, text_features_source = _encode_prompt(source_class)
text_target, text_features_target = _encode_prompt(target_class)

# Cosine similarity between the two prompt embeddings, printed as a sanity check.
sim = torch.nn.functional.cosine_similarity(text_features_target, text_features_source)
print(f"Text sim: {sim.item():.4f}")

# Directional CLIP loss, passed to LatentStyleTrainer.
clip_directional_loss_fn = CLIPDirectionalLoss()


In [None]:
# Passed as `k` to freeze_layers_adaptive (presumably the number of layers
# left trainable per freeze step — TODO confirm in modules.utils).
top_k = 10

# Build the trainer that adapts the generator from the source text towards
# the target text.
trainer = LatentStyleTrainer(
    generator=generator,
    model_clip=model_clip,
    text_features_source=text_features_source,
    text_features_target=text_features_target,
    freeze_fn=lambda model_train, model_frozen, text_target_feat: freeze_layers_adaptive(
        model_train, model_frozen, text_target_feat, k=top_k, device=device
    ),
    clip_directional_loss=clip_directional_loss_fn,
    latent_dim=latent_dim,
    batch_size=2,
    device=device,
    lr_generator=0.008,
    lr_lambda=0.02,
    weight_decay=0.003,
    # NOTE(review): earlier comments claimed these were set to 5.0 (clip) and
    # 0.2 (l2) "to match your logs", but the values actually passed are 1.0.
    # Confirm the intended initial loss weights before comparing with old logs.
    lambda_clip_init=1.0,
    lambda_l2_init=1.0,
)

# Run training.
epochs_to_train = 41
trainer.train(epochs=epochs_to_train, freeze_each_epoch=True, reclassify=False)

In [None]:
# Plot the loss curves recorded during training.
trainer.plot_losses()

# Generate, display and save final images for a fixed set of seeds. The output
# folder follows target_class so changing the prompt does not overwrite the
# "sketch" results.
seeds = (92126, 773, 779, 373, 2112)
generate_visualize_and_save(trainer, seeds, output_dir="../validation_outputs", folder_name=target_class)

# Visualise CLIP directions for the first seed. The generator forward passes
# are display-only, so they run under no_grad to avoid building an autograd
# graph through two 1024px generators (a large, pointless GPU-memory cost).
latent_w_vis = trainer.sample_latent_w(seed=seeds[0])
with torch.no_grad():
    image_frozen_vis, _ = trainer.model["generator_frozen"]([latent_w_vis], input_is_latent=True, randomize_noise=False)
    image_styled_vis, _ = trainer.model["generator_train"]([latent_w_vis], input_is_latent=True, randomize_noise=False)
trainer.visualize_clip_directions(image_frozen=image_frozen_vis, image_styled=image_styled_vis,
                                  text_target=text_target, text_source=text_source, preprocess=preprocess)

In [None]:
# Save the fine-tuned generator weights to pretrained_models/<target_class>.pth.
# This is a bare state_dict (no "g_ema" wrapper, unlike the input checkpoint).
models_dir = "pretrained_models"
os.makedirs(models_dir, exist_ok=True)  # torch.save does not create missing directories
output_model_path = os.path.join(models_dir, target_class + ".pth")
trainer.model["generator_train"].eval()
torch.save(trainer.model["generator_train"].state_dict(), output_model_path)
print(f"Модель сохранена в: {output_model_path}")