In [None]:
# Fetch the clip2latent sources. No commit/tag is given, so this checks out whatever the default branch currently is; pin a commit for reproducibility.
!git clone https://github.com/justinpinkney/clip2latent/

In [None]:
# Relative path: re-running this cell after it succeeded would try to enter clip2latent/clip2latent.
%cd clip2latent

In [None]:
%%writefile custom_requirements.txt

wandb==0.12.16
ninja==1.10.2.3
dalle2-pytorch==0.2.38
hydra-core==1.3.2
typer==0.4.1
joblib==1.1.0
webdataset==0.2.5
gradio==3.4
protobuf==3.20.1
-e .

In [None]:
!pip install -r custom_requirements.txt

In [None]:
%%writefile scripts/custom_generate_dataset.py
from multiprocessing import Process
import multiprocessing as mp
import math
from functools import partial
from pathlib import Path

import numpy as np
import torch
import typer
from joblib import Parallel, delayed
from PIL import Image
from tqdm import tqdm
import torch.nn.functional as F

from clip2latent.models import Clipper, load_sg


import multiprocessing as mp
# The dataset workers below move models to ``cuda:<n>`` inside child processes;
# PyTorch's multiprocessing notes require the 'spawn' (or 'forkserver') start
# method for CUDA in subprocesses.
try:
    mp.set_start_method('spawn')
except RuntimeError:
    # The start method was already fixed for this interpreter (e.g. the module was
    # re-imported from a notebook). Keep it instead of failing the import, but do
    # not swallow unrelated errors such as KeyboardInterrupt.
    pass

# Generator registry: each value is a zero-argument callable (``load_sg`` bound to a
# checkpoint URL/path) whose result is moved to the device and used as ``G``.
generators = {
    "sg2-ffhq-1024": partial(load_sg, 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan2/versions/1/files/stylegan2-ffhq-1024x1024.pkl'),
    "sg3-lhq-256": partial(load_sg, 'https://huggingface.co/justinpinkney/stylegan3-t-lhq-256/resolve/main/lhq-256-stylegan3-t-25Mimg.pkl'),
    "my_model": partial(load_sg, '/kaggle/input/stylegan2-fashion-4000kimg/pytorch/default/1/network-snapshot-004000.pkl')
}

def mix_styles(w_batch, space):
    """Build style-mixed W+ latents by splicing per-sample w vectors together.

    The synthesis layers are split into consecutive groups whose sizes depend on
    ``space`` (only "w3" is defined: 4, 4 and 10 layers). For every output sample,
    each group is filled with the layer-0 vector of a randomly chosen sample of
    ``w_batch``.

    Args:
        w_batch: tensor indexed as ``w_batch[sample, layer, :]``
            (assumed (batch, num_ws, dim) - TODO confirm).
        space: key of the group-size table; unknown keys raise ``KeyError``.

    Returns:
        ``(mixed, group_sizes)``. ``mixed`` has one row per input sample and
        ``sum(group_sizes)`` layers. ``group_sizes`` is a 1-D tensor on
        ``w_batch.device``.
    """
    group_sizes = torch.tensor({"w3": (4, 4, 10)}[space]).to(w_batch.device)
    n_groups = len(group_sizes)
    n_samples = w_batch.shape[0]

    # One random donor sample per (group, output sample) pair.
    donors = torch.randint(0, n_samples, (n_groups, n_samples)).to(w_batch.device)
    donor_vectors = w_batch[donors, 0, :]            # (n_groups, n_samples, dim)
    per_sample = donor_vectors.permute(1, 0, 2)      # (n_samples, n_groups, dim)
    mixed = per_sample.repeat_interleave(group_sizes, dim=1)
    return mixed, group_sizes

@torch.no_grad()
def run_folder_list(
    device_index,
    out_dir,
    generator_name,
    feature_extractor_name,
    out_image_size,
    batch_size,
    n_save_workers,
    samples_per_folder,
    folder_indexes,
    space="w",
    save_im=True,
    seed=0,
    ):
    """Generate folders of (latent, CLIP image embedding[, image]) samples on one GPU.

    For every index in ``folder_indexes`` the sub-directory ``out_dir/<index:05d>``
    is created and filled with ``samples_per_folder`` samples via ``process_and_save``.

    Args:
        device_index: CUDA ordinal used by this worker (also the tqdm row).
        out_dir: Path of the dataset root. It must already exist, because only the
            per-folder directory is created here.
        generator_name: key into the module-level ``generators`` dict.
        feature_extractor_name: CLIP variant passed to ``Clipper``.
        out_image_size: side length images are resized to (area interpolation)
            before embedding and saving.
        batch_size: synthesis batch size.
        n_save_workers: number of threads writing files.
        samples_per_folder: samples generated per folder.
        folder_indexes: iterable of integer folder indices handled by this worker.
        space: "w" stores a single w per sample. Any other value (e.g. "w3")
            style-mixes via ``mix_styles`` and stores one w per style group.
        save_im: also save the generated images as JPEGs.
        seed: base RNG seed. Folder ``i`` is generated from ``seed + i``, so the
            dataset is reproducible and folders produced by different worker
            processes never share noise.
    """
    latent_dim = 512
    device = f"cuda:{device_index}"
    typer.echo(device_index)

    typer.echo("Loading generator")
    G = generators[generator_name]().to(device).eval()
    # Sample z with the generator's own input width when it exposes one;
    # ``latent_dim`` is only the fallback.
    z_dim = getattr(G, "z_dim", latent_dim)

    typer.echo("Loading feature extractor")
    feature_extractor = Clipper(feature_extractor_name).to(device)

    typer.echo("Generating samples")
    typer.echo(f"using space {space}")

    with Parallel(n_jobs=n_save_workers, prefer="threads") as parallel:
        for i_folder in folder_indexes:
            i_folder = int(i_folder)  # np.array_split hands over numpy integers
            folder_name = out_dir/f"{i_folder:05d}"
            folder_name.mkdir(exist_ok=True)

            # Deterministic, folder-specific noise (also drives mix_styles' randint).
            torch.manual_seed(seed + i_folder)
            z = torch.randn(samples_per_folder, z_dim, device=device)
            w = G.mapping(z, c=None)
            ds = torch.utils.data.TensorDataset(w)
            loader = torch.utils.data.DataLoader(ds, batch_size=batch_size, shuffle=False, drop_last=False)
            for batch_idx, batch in enumerate(tqdm(loader, position=device_index)):
                if space == "w":
                    this_w = batch[0].to(device)
                    # Only layer 0 is stored; this assumes G.mapping repeats the same w
                    # across all layers (TODO confirm for custom generators).
                    latents = this_w[:,0,:].cpu().numpy()
                else:
                    this_w, group_sizes = mix_styles(batch[0].to(device), space)
                    # First layer of each style group (sizes 4,4,10 -> layers 0,4,8).
                    # Indexing with the sizes themselves would read layers 4,4,10,
                    # which duplicates group 2 and never stores group 1.
                    group_starts = torch.cumsum(group_sizes, dim=0) - group_sizes
                    latents = this_w[:, group_starts, :].cpu().numpy()

                out = G.synthesis(this_w)

                out = F.interpolate(out, (out_image_size,out_image_size), mode="area")
                image_features = feature_extractor.embed_image(out)
                image_features = image_features.cpu().numpy()

                if save_im:
                    # [-1, 1] NCHW -> uint8 NHWC for PIL
                    out = out.permute(0,2,3,1).clamp(-1,1)
                    out = (255*(out*0.5 + 0.5).cpu().numpy()).astype(np.uint8)
                else:
                    out = [None]*len(latents)
                parallel(
                    delayed(process_and_save)(batch_size, folder_name, batch_idx, idx, latent, im, image_feature, save_im)
                    for idx, (latent, im, image_feature) in enumerate(zip(latents, out, image_features))
                    )

            typer.echo(f"finished folder {i_folder:05d}")


def process_and_save(batch_size, folder_name, batch_idx, idx, latent, im, image_feature, save_im):
    """Persist a single generated sample inside ``folder_name``.

    The sample's position in the folder is ``batch_idx * batch_size + idx``. Files are
    named ``<folder stem><position, zero-padded to 4>`` plus ``.latent.npy``
    (``latent``) and ``.img_feat.npy`` (``image_feature``). When ``save_im`` is true,
    ``im`` is also written as a ``.gen.jpg`` at JPEG quality 95.
    """
    position = idx + batch_idx * batch_size
    stem_path = folder_name / f"{folder_name.stem}{position:04}"

    arrays = {".latent.npy": latent, ".img_feat.npy": image_feature}
    for suffix, payload in arrays.items():
        np.save(stem_path.with_suffix(suffix), payload)

    if save_im:
        Image.fromarray(im).save(stem_path.with_suffix(".gen.jpg"), quality=95)

def make_webdataset(in_dir, out_dir):
    """Pack each sample folder of ``in_dir`` into a ``<folder name>.tar`` shard.

    Only sub-directories of ``in_dir`` become shards, processed in sorted order.
    Inside a shard, regular files are added in sorted order, so all files of one
    sample stay adjacent. Each member is stored as ``<folder name>/<file name>``
    (relative to ``in_dir``) rather than under its absolute path.

    Args:
        in_dir: directory holding one sub-directory per folder of samples.
        out_dir: directory for the shards. It is created here and must not
            already exist.

    Raises:
        FileExistsError: ``out_dir`` already exists.
    """
    import tarfile
    in_dir = Path(in_dir)
    # ``is_dir`` must be *called*: the bound method itself is always truthy, which
    # would turn stray top-level files into empty shards.
    in_folders = sorted(x for x in in_dir.iterdir() if x.is_dir())
    out_dir = Path(out_dir)
    out_dir.mkdir()
    for folder in in_folders:
        filename = out_dir/f"{folder.stem}.tar"
        # Regular files only: adding a directory would recursively add its contents again.
        files_to_add = sorted(f for f in folder.rglob("*") if f.is_file())

        with tarfile.open(filename, "w") as tar:
            for f in files_to_add:
                tar.add(f, arcname=f.relative_to(in_dir).as_posix())


def main(
    out_dir:Path,
    n_samples:int=1_000_000,
    generator_name:str="sg2-ffhq-1024", # Key into `generators` dict
    feature_extractor_name:str="ViT-B/32",
    n_gpus:int=2,
    out_image_size:int=256,
    batch_size:int=32,
    n_save_workers:int=16,
    space:str="w",
    samples_per_folder:int=10_000,
    save_im:bool=False, # Save the generated images?
    ):
    """Generate a dataset of generator latents and CLIP image embeddings.

    ``ceil(n_samples / samples_per_folder)`` folders are created under ``out_dir`` and
    split as evenly as possible across up to ``n_gpus`` worker processes. Worker ``i``
    runs ``run_folder_list`` on ``cuda:i``.

    Raises:
        FileExistsError: ``out_dir`` already exists (runs are never mixed).
        RuntimeError: a worker process exited with a non-zero exit code.
    """
    typer.echo("starting")

    out_dir = Path(out_dir)
    out_dir.mkdir(parents=True)

    n_folders = math.ceil(n_samples/samples_per_folder)
    folder_indexes = range(n_folders)

    sub_indexes = np.array_split(folder_indexes, n_gpus)

    processes = []
    for dev_idx, folder_list in enumerate(sub_indexes):
        if len(folder_list) == 0:
            # More GPUs than folders: don't spawn a worker that would only load models.
            continue
        p = Process(
            target=run_folder_list,
            args=(
                dev_idx,
                out_dir,
                generator_name,
                feature_extractor_name,
                out_image_size,
                batch_size,
                n_save_workers,
                samples_per_folder,
                folder_list,
                space,
                save_im,
                ),
            )
        p.start()
        processes.append(p)

    for p in processes:
        p.join()

    # A crashed worker (e.g. CUDA OOM) must not be reported as success.
    failed_codes = [p.exitcode for p in processes if p.exitcode != 0]
    if failed_codes:
        raise RuntimeError(f"{len(failed_codes)} dataset worker(s) failed with exit codes {failed_codes}")

    typer.echo("finished all")

if __name__ == "__main__":
    typer.run(main)

In [None]:
from pathlib import Path
from scripts.custom_generate_dataset import main  # change this to your script's module name, e.g. generate_dataset.py

# Generate the dataset.
# n_samples / samples_per_folder = 500000 / 5000 = 100 folders (00000..00099); the
# webdataset path in my_config.yaml assumes exactly this range.
# main() creates out_dir with mkdir(parents=True) and raises if it already exists,
# so this cell cannot simply be re-run over an existing dataset.
main(
    out_dir=Path("/kaggle/working/my_dataset"),
    n_samples=500000,
    generator_name="my_model",      # key in the generators dict
    feature_extractor_name="ViT-B/32",
    n_gpus=1,
    out_image_size=256,
    batch_size=16,
    n_save_workers=8,
    space="w",
    samples_per_folder=5000,
    save_im=False
)


In [None]:
import sys
from pathlib import Path
# NOTE(review): the import below goes through the `scripts.` package path, so this
# sys.path entry looks unnecessary - confirm nothing else relies on it before removing.
sys.path.append(str(Path("scripts").resolve()))

from scripts.custom_generate_dataset import make_webdataset
# Pack each generated folder into one .tar shard; these shards are what data.path in
# my_config.yaml points at. The output directory must not exist yet.
make_webdataset("/kaggle/working/my_dataset", "/kaggle/working/my_dataset_tar")


In [None]:
# Sanity check: my_config.yaml expects shards 00000.tar .. 00099.tar here.
!ls /kaggle/working/my_dataset_tar/

In [None]:
%%writefile /kaggle/working/my_config.yaml
# Training config for scripts/custom_train1.py; loaded with OmegaConf.load in the training cell.
model:
  network:
    dim: 512
    num_timesteps: 1000
    depth: 12
    dim_head: 64
    heads: 12
  diffusion:
    image_embed_dim: 512
    timesteps: 1000
    cond_drop_prob: 0.2
    image_embed_scale: 1.0
    text_embed_scale: 1.0
    beta_schedule: cosine
    predict_x_start: true

data:
  bs: 128
  format: webdataset
  # Brace range must match the shards written by make_webdataset:
  # 500000 samples / 5000 per folder = 100 folders -> 00000..00099.
  path: /kaggle/working/my_dataset_tar/{00000..00099}.tar
  embed_noise_scale: 1.0
  # Same checkpoint as the "my_model" entry of the generators dict used to build the dataset.
  sg_pkl: /kaggle/input/stylegan2-fashion-4000kimg/pytorch/default/1/network-snapshot-004000.pkl
  # Must match feature_extractor_name used when generating the dataset.
  clip_variant: ViT-B/32
  # The dataset was generated with space="w", which stores one latent per sample.
  n_latents: 1
  latent_dim: 512
  # NOTE(review): presumably the number of W+ layers the single latent is repeated
  # over (14 would fit a 256x256 StyleGAN2) - confirm against the generator's num_ws.
  latent_repeats: 
  - 14 
  val_im_samples: 4
  # One caption per line, written by the validation-text cell.
  val_text_samples: /kaggle/working/val_text.txt
  val_samples_per_text: 1

train:
  znorm_embed: false
  znorm_latent: true
  max_it: 100000
  # Validation interval; custom_train1.main also uses it as the checkpoint interval.
  val_it: 10000
  lr: 1e-4
  weight_decay: 0.01
  ema_update_every: 10
  ema_beta: 0.9999
  ema_power: 0.75

# One of: wandb, console, null (no logging).
logging: console
wandb_project: clip2latent
wandb_entity: null
name: null

device: cuda
# Optional checkpoint whose state_dict is loaded before training (~ = start fresh).
resume: ~


In [None]:
import pandas as pd
from pathlib import Path

# Validation prompts for training: one caption per line.
# NOTE(review): every prompt is sampled and logged at each validation step
# (val_samples_per_text times), so this count directly drives validation cost.
N_VAL_TEXTS = 2000
val_text_path = Path("/kaggle/working/val_text.txt")

captions_df = pd.read_csv("/kaggle/input/new-text-fashion-blip/data_BLIP_refined.csv")
val_captions = (
    captions_df["caption"]
    .head(N_VAL_TEXTS)
    .dropna()                                   # NaN would become an empty prompt
    .astype(str)
    .str.replace(r"\s+", " ", regex=True)       # embedded newlines would split one caption in two
    .str.strip()
)
val_captions = val_captions[val_captions != ""]

# Write plain lines instead of using to_csv: CSV quoting would wrap captions that
# contain commas in double quotes, and the quotes would end up in the CLIP prompt.
val_text_path.write_text("\n".join(val_captions) + "\n", encoding="utf-8")
val_captions.head()

In [None]:
%%writefile scripts/custom_train1.py
import logging
from datetime import datetime
from functools import partial
from pathlib import Path

import hydra
import numpy as np
import torch
from omegaconf import OmegaConf
from tqdm.auto import tqdm

import wandb
from clip2latent.data import load_data
from clip2latent.models import load_models
from clip2latent.train_utils import (compute_val, make_grid,
                                     make_image_val_data, make_text_val_data)

logger = logging.getLogger(__name__)


def noop(*args, **kwargs):
    """Accept any arguments and do nothing; the default logging sink."""
    return None


# Active logging callable. ``train`` and ``validation`` look this module global up
# at call time, and ``main`` rebinds it according to ``cfg.logging``.
logfun = noop

class Checkpointer:
    """Periodically save a model's ``state_dict`` (plus ``model.cfg`` if present).

    A checkpoint is written only when the iteration is a multiple of
    ``checkpoint_its``. The file is ``<directory>/<iteration // 1000, 6 digits>.ckpt``.
    NOTE(review): if ``checkpoint_its`` is not a multiple of 1000, different
    iterations can map to the same filename and overwrite each other.
    """
    def __init__(self, directory, checkpoint_its):
        self.directory = Path(directory)
        self.checkpoint_its = checkpoint_its
        if not self.directory.exists():
            self.directory.mkdir(parents=True)

    def save_checkpoint(self, model, iteration):
        """Save ``model`` when ``iteration`` is a multiple of ``checkpoint_its``; otherwise do nothing."""
        if iteration % self.checkpoint_its != 0:
            return

        target = self.directory / f"{iteration // 1000:06}.ckpt"
        payload = {"state_dict": model.state_dict()}
        if hasattr(model, "cfg"):
            payload["cfg"] = model.cfg

        print(f"Saving checkpoint to {target}")
        torch.save(payload, target)



def validation(current_it, device, diffusion_prior, G, clip_model, val_data, samples_per_text):
    """Evaluate the prior on held-out image and text embeddings and log the results.

    Four ``compute_val`` runs are made:

    * ``image-similarity``: all validation image embeddings, once each.
    * ``image-vars``: the first validation image embedding, 8 times.
    * ``text2im``: all validation text embeddings, ``samples_per_text`` times each.
    * ``text2im-super2``: as ``text2im`` but with ``cond_scale=2.0``.

    For each run the mean cosine similarity is logged as ``val/<key>``. Images are
    logged as grids: for the text runs, one grid per caption, titled with that
    caption's mean similarity; otherwise, grids of at most 16 images.

    Args:
        current_it: iteration used as the logging step.
        device: passed through to ``compute_val``.
        diffusion_prior: model being evaluated.
        G: generator, passed through to ``compute_val``.
        clip_model: CLIP wrapper, passed through to ``compute_val``.
        val_data: dict with ``val_im`` and ``val_text`` (each holding
            ``clip_features``) and ``val_caption`` (one caption per ``val_text`` row).
        samples_per_text: samples generated per caption in the text runs.
    """
    single_im = {"clip_features": val_data["val_im"]["clip_features"][0].unsqueeze(0)}
    captions = val_data["val_caption"]

    for input_data, key, cond_scale, repeats in zip(
        [val_data["val_im"], single_im, val_data["val_text"], val_data["val_text"]],
        ["image-similarity", "image-vars", "text2im", "text2im-super2"],
        [1.0, 1.0, 1.0, 2.0],
        [1, 8, samples_per_text, samples_per_text],
    ):
        tiled_data = input_data["clip_features"].repeat_interleave(repeats, dim=0)
        cos_sim, ims = compute_val(diffusion_prior, tiled_data, G, clip_model, device, cond_scale=cond_scale)
        logfun({f'val/{key}':cos_sim.mean()}, step=current_it)

        if key.startswith("text"):
            # repeat_interleave keeps each caption's ``repeats`` samples contiguous, so
            # consecutive groups of exactly ``repeats`` map one-to-one onto captions.
            # (``chunk(n)`` only bounds the chunk size and can return a different
            # number of groups than requested.)
            for idx, (sim, im_chunk, caption) in enumerate(zip(
                cos_sim.split(repeats),
                ims.split(repeats),
                captions,
                )):
                im = wandb.Image(make_grid(im_chunk), caption=f'{sim.mean():.2f} - {caption}')
                logfun({f'val/image/{key}/{idx}': im}, step=current_it)
        else:
            for idx, im in enumerate(ims.chunk(int(np.ceil(ims.shape[0]/16)))):
                logfun({f'val/image/{key}/{idx}': wandb.Image(make_grid(im))}, step=current_it)

    logger.info("Validation done.")

def train_step(diffusion_prior, device, batch):
    """Compute the loss on one ``(z, w)`` batch and backpropagate it.

    The prior is switched to train mode first. No optimizer step is taken here.
    Returns the loss tensor.
    NOTE(review): not called anywhere in this script - ``train`` drives the trainer directly.
    """
    diffusion_prior.train()
    batch_z, batch_w = batch
    batch_z, batch_w = batch_z.to(device), batch_w.to(device)

    loss = diffusion_prior(batch_z, batch_w)
    loss.backward()
    return loss


def train(trainer, loader, device, val_it, validate, save_checkpoint, max_it, print_it=50):
    """Train ``trainer`` for exactly ``max_it`` iterations, cycling over ``loader``.

    Each iteration runs these steps in order:

    1. Validate when ``current_it % val_it == 0``, so also before the first update.
    2. Take one training step.
    3. Log the loss every ``print_it`` iterations.
    4. Call ``trainer.update()``.
    5. Pass the 1-based iteration count to ``save_checkpoint``.

    Each pass over ``loader`` is one epoch. Training stops as soon as ``max_it``
    iterations are done, even mid-epoch.

    Args:
        trainer: called as ``trainer(image_embed=..., text_embed=...)`` to get the
            loss; must also provide ``train()`` and ``update()``.
        loader: iterable of ``(clip_embedding, latent)`` batches, re-iterated per epoch.
        device: device the batch tensors are moved to.
        val_it: validation interval in iterations.
        validate: callable ``validate(current_it, device, trainer)``.
        save_checkpoint: callable ``save_checkpoint(trainer, current_it)``.
        max_it: total number of training iterations.
        print_it: loss logging interval.

    Raises:
        RuntimeError: ``loader`` yields no batches, so training could never finish.
    """
    current_it = 0
    current_epoch = 0

    while current_it < max_it:

        logfun({'epoch': current_epoch}, step=current_it)
        batches_seen = 0
        with tqdm(loader) as pbar:
            for batch in pbar:
                batches_seen += 1
                if current_it % val_it == 0:
                    with torch.no_grad():
                        validate(current_it, device, trainer)
                    torch.cuda.empty_cache()

                trainer.train()
                batch_clip, batch_latent = batch

                input_args = {
                    "image_embed": batch_latent.to(device),
                    "text_embed": batch_clip.to(device)
                }
                loss = trainer(**input_args)

                if (current_it % print_it == 0):
                    logfun({'loss': loss}, step=current_it)

                trainer.update()
                current_it += 1
                pbar.set_postfix({"loss": loss, "epoch": current_epoch, "it": current_it})

                save_checkpoint(trainer, current_it)

                if current_it >= max_it:
                    # Stop mid-epoch; otherwise training overshoots max_it by up to an epoch.
                    break

        if batches_seen == 0:
            raise RuntimeError("Training loader yielded no batches")
        current_epoch += 1


@hydra.main(config_path="../config", config_name="config")
def main(cfg):
    """Train a diffusion prior mapping CLIP embeddings to generator latents.

    Steps:

    1. Choose the logging sink from ``cfg.logging``: "wandb", "console", or ``None``.
    2. Load the data and models, and build the validation data.
    3. Optionally load weights from ``cfg.resume``.
    4. Run ``train``, checkpointing to ``checkpoints/<timestamp>`` every
       ``cfg.train.val_it`` iterations.

    Raises:
        NotImplementedError: unknown ``cfg.logging`` value.
    """
    # ``train``/``validation`` read the module-level ``logfun``; rebind it here.
    global logfun
    if cfg.logging == "wandb":
        wandb.init(
            project=cfg.wandb_project,
            config=OmegaConf.to_container(cfg),
            entity=cfg.wandb_entity,
            name=cfg.name,
        )
        logfun = wandb.log
    elif cfg.logging == "console":
        # Print each logged dict to the console.
        logfun = lambda d, step=None: print(f"[Step {step}] {d}" if step is not None else d)
        print("Console logging enabled.")
    elif cfg.logging is None:
        # Reset explicitly: in a notebook, an earlier call may have left another sink bound.
        logfun = noop
        logger.info("Not logging")
    else:
        raise NotImplementedError(f"Logging type {cfg.logging} not implemented")

    device = cfg.device
    stats, loader = load_data(cfg.data)
    G, clip_model, trainer = load_models(cfg, device, stats)

    text_embed, text_samples = make_text_val_data(G, clip_model, hydra.utils.to_absolute_path(cfg.data.val_text_samples))
    val_data = {
        "val_im": make_image_val_data(G, clip_model, cfg.data.val_im_samples, device),
        "val_text": text_embed,
        "val_caption": text_samples,
    }

    if 'resume' in cfg and cfg.resume is not None:
        # Does not load previous iteration count
        logger.info(f"Resuming from {cfg.resume}")
        trainer.load_state_dict(torch.load(cfg.resume, map_location="cpu")["state_dict"])

    checkpoint_dir = f"checkpoints/{datetime.now():%Y%m%d-%H%M%S}"
    checkpointer = Checkpointer(checkpoint_dir, cfg.train.val_it)
    validate1 = partial(validation,
        G=G,
        clip_model=clip_model,
        val_data=val_data,
        samples_per_text=cfg.data.val_samples_per_text,
        )

    train(trainer, loader, device,
        val_it=cfg.train.val_it,
        max_it=cfg.train.max_it,
        validate=validate1,
        save_checkpoint=checkpointer.save_checkpoint,
        )

if __name__ == "__main__":
    main()


In [None]:
import hydra  # NOTE(review): not used directly in this cell
from omegaconf import OmegaConf
from scripts.custom_train1 import main as train_main  # the @hydra.main-decorated training entry point

cfg_path = "/kaggle/working/my_config.yaml"
cfg = OmegaConf.load(cfg_path)
print(cfg)
# Run training.
# NOTE(review): relies on hydra.main accepting a config passed directly to the decorated
# function, which skips composing ../config/config.yaml. Confirm this for the installed
# hydra-core version (pinned to 1.3.2 in custom_requirements.txt).
train_main(cfg)
