# Import


In [1]:
import argparse
from pathlib import Path
import os
import math
import shutil
import time
from collections import deque

import numpy as np
from tqdm.notebook import tqdm
from PIL import Image
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

from accelerate import Accelerator
from accelerate.utils import set_seed
from accelerate.logging import get_logger
from diffusers import DDIMScheduler, AutoencoderKL, UNet2DConditionModel
from diffusers.image_processor import VaeImageProcessor
from diffusers.optimization import get_scheduler
from diffusers.utils.torch_utils import randn_tensor

In [2]:
def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--pretrained_path",
        type=str,
        default="stable-diffusion-v1-5/stable-diffusion-inpainting",
    )
    parser.add_argument(
        "--vae_model_path", type=str, default="stabilityai/sd-vae-ft-mse"
    )
    parser.add_argument("--dataset_path", type=str, default="./data/zalando-hd-resized")
    parser.add_argument("--output_dir", type=str, default="./ckpt/train-result")
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--resolution_height", type=int, default=512)
    parser.add_argument("--resolution_width", type=int, default=384)
    parser.add_argument("--train_batch_size", type=int, default=4)
    parser.add_argument("--num_train_epochs", type=int, default=20)
    parser.add_argument("--max_train_steps", type=int, default=None)
    parser.add_argument("--learning_rate", type=float, default=2e-5)
    parser.add_argument("--lr_scheduler", type=str, default="cosine")
    parser.add_argument("--lr_warmup_steps", type=int, default=0)
    parser.add_argument("--adam_beta1", type=float, default=0.9)
    parser.add_argument("--adam_beta2", type=float, default=0.999)
    parser.add_argument("--adam_weight_decay", type=float, default=1e-2)
    parser.add_argument("--adam_epsilon", type=float, default=1e-08)
    parser.add_argument("--logging_dir", type=str, default="logs")
    parser.add_argument("--mixed_precision", type=str, default="no")
    parser.add_argument("--checkpointing_steps", type=int, default=1000)
    parser.add_argument("--checkpoints_total_limit", type=int, default=1)
    parser.add_argument("--resume_from_checkpoint", type=str, default=None)
    parser.add_argument("--gradient_accumulation_steps", type=int, default=1)

    args = parser.parse_args([])
    return args

In [3]:
args = parse_args()

In [4]:
# Temp args for test
args.num_train_epochs = 1

# Set Logger and Accelerator


In [5]:
logging_dir = Path(args.output_dir, args.logging_dir)
logging_dir.mkdir(parents=True, exist_ok=True)

In [6]:
accelerator = Accelerator(
    gradient_accumulation_steps=args.gradient_accumulation_steps,
    mixed_precision=args.mixed_precision,
    log_with="tensorboard",
    project_dir=logging_dir,
)

logger = get_logger("mac-train-1epoch")

if args.seed is not None:
    set_seed(args.seed)

if accelerator.is_main_process:
    os.makedirs(args.output_dir, exist_ok=True)

In [7]:
accelerator.init_trackers("mac-train-1epoch", config=vars(args))

# Dataset and DataLoader


In [8]:
class VTONDataset(Dataset):
    def __init__(self, dataset_path, split="train", height=512, width=384):
        self.dataset_path = dataset_path
        self.split = split
        self.height = height
        self.width = width

        # set directory path
        self.image_person_dir = Path(self.dataset_path).joinpath(split, "image")
        self.image_cloth_dir = Path(self.dataset_path).joinpath(split, "cloth")
        self.mask_cloth_person_dir = Path(self.dataset_path).joinpath(
            split, "agnostic-mask"
        )

        # Check if directories exist
        if not self.image_person_dir.exists():
            raise FileNotFoundError(
                f"Image person directory not found: {self.image_person_dir}"
            )
        if not self.image_cloth_dir.exists():
            raise FileNotFoundError(
                f"Image cloth directory not found: {self.image_cloth_dir}"
            )
        if not self.mask_cloth_person_dir.exists():
            raise FileNotFoundError(
                f"Mask cloth person directory not found: {self.mask_cloth_person_dir}"
            )

        self.image_persion_files = sorted(self.image_person_dir.glob("*.jpg"))

        self.valid_files = []
        for person_img_path in self.image_persion_files:
            base_name = person_img_path.stem
            cloth_img_path = Path(self.image_cloth_dir).joinpath(f"{base_name}.jpg")
            mask_img_path = Path(self.mask_cloth_person_dir).joinpath(
                f"{base_name}_mask.png"
            )

            if cloth_img_path.exists() and mask_img_path.exists():
                self.valid_files.append(
                    {
                        "person": person_img_path,
                        "cloth": cloth_img_path,
                        "mask": mask_img_path,
                    }
                )
            else:
                logger.warning(f"Skip {base_name}")

        if not self.valid_files:
            raise ValueError(f"No valid files found in {self.split} split.")

        # Image preprocessor
        self.vae_image_processor = VaeImageProcessor(
            vae_scale_factor=8, do_normalize=True, do_convert_rgb=True
        )
        # Mask preprocessor
        self.mask_processor = VaeImageProcessor(
            vae_scale_factor=8,
            do_normalize=False,
            do_binarize=True,
            do_convert_grayscale=True,
        )

        self.transform_resize = transforms.Resize((self.height, self.width))

    def __len__(self):
        return len(self.valid_files)

    def __getitem__(self, idx):
        item = self.valid_files[idx]

        try:
            # 이미지 로드
            person_image_hc = Image.open(item["person"]).convert("RGB")
            cloth_image_c = Image.open(item["cloth"]).convert("RGB")
            garment_mask_m = Image.open(item["mask"]).convert("L")

            # 리사이즈
            person_image_hc = self.transform_resize(person_image_hc)
            cloth_image_c = self.transform_resize(cloth_image_c)
            garment_mask_m = self.transform_resize(garment_mask_m)

            # VAE 프로세서를 사용하여 각 이미지를 전처리합니다.
            person_hc_processed = self.vae_image_processor.preprocess(
                person_image_hc, self.height, self.width
            )[0]
            cloth_c_processed = self.vae_image_processor.preprocess(
                cloth_image_c, self.height, self.width
            )[0]
            mask_m_processed = self.mask_processor.preprocess(
                garment_mask_m, self.height, self.width
            )[0]
            mask = self.prepare_mask_image(mask_m_processed)

            person_hm_processed = person_hc_processed * (mask < 0.5)

            # squeeze를 통해 불필요한 차원 제거
            return {
                "person_hc": person_hc_processed,  # (HC)
                "person_hm": person_hm_processed,  # (HM) - 새로 추가됨
                "cloth_c": cloth_c_processed,  # (C)
                "mask_m": mask_m_processed,  # (M)
            }

        except Exception as e:
            logger.error(
                f"Error processing item at index {idx}: ({item['person'].name}){e}"
            )

            if idx > 0:
                return self.__getitem__(idx - 1)
            else:
                dummy_hc = torch.zeros(
                    (3, self.height, self.width), dtype=torch.float32
                )
                dummy_c = torch.zeros((3, self.height, self.width), dtype=torch.float32)
                dummy_m = torch.zeros((1, self.height, self.width), dtype=torch.float32)
                print("It's dummy")
                return {"person_hc": dummy_hc, "cloth_c": dummy_c, "mask_m": dummy_m}

    def prepare_mask_image(self, mask_image):
        mask_image[mask_image < 0.5] = 0
        mask_image[mask_image >= 0.5] = 1
        return mask_image

In [9]:
train_dataset = VTONDataset(
    args.dataset_path,
    split="train",
    height=args.resolution_height,
    width=args.resolution_width,
)

train_dataloader = DataLoader(
    train_dataset,
    shuffle=True,
    batch_size=args.train_batch_size,
    num_workers=0,
)

# 모델 및 노이즈 스케줄러 설정


In [10]:
class Skip(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()

    def __call__(
        self,
        attn,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
    ):
        return hidden_states

In [11]:
def fine_tuned_modules(unet):
    trainable_modules = torch.nn.ModuleList()
    for blocks in [unet.down_blocks, unet.mid_block, unet.up_blocks]:
        if hasattr(blocks, "attentions"):
            trainable_modules.append(blocks.attentions)
        else:
            for block in blocks:
                if hasattr(block, "attentions"):
                    trainable_modules.append(block.attentions)
    return trainable_modules

In [12]:
def skip_cross_attentions(unet):
    attn_processors = {
        name: unet.attn_processors[name] if name.endswith("attn1.processor") else Skip()
        for name in unet.attn_processors.keys()
    }
    return attn_processors

In [13]:
# Noise Scheduler
noise_scheduler = DDIMScheduler.from_pretrained(
    args.pretrained_path, subfolder="scheduler"
)

In [14]:
# VAE
vae = AutoencoderKL.from_pretrained(args.vae_model_path)

# Freeze VAE
vae.requires_grad_(False)

AutoencoderKL(
  (encoder): Encoder(
    (conv_in): Conv2d(3, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (down_blocks): ModuleList(
      (0): DownEncoderBlock2D(
        (resnets): ModuleList(
          (0-1): 2 x ResnetBlock2D(
            (norm1): GroupNorm(32, 128, eps=1e-06, affine=True)
            (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (norm2): GroupNorm(32, 128, eps=1e-06, affine=True)
            (dropout): Dropout(p=0.0, inplace=False)
            (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (nonlinearity): SiLU()
          )
        )
        (downsamplers): ModuleList(
          (0): Downsample2D(
            (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2))
          )
        )
      )
      (1): DownEncoderBlock2D(
        (resnets): ModuleList(
          (0): ResnetBlock2D(
            (norm1): GroupNorm(32, 128, eps=1e-06, affine=True)
            (c

In [28]:
from accelerate import load_checkpoint_in_model
from accelerate.utils import load_state_dict

In [16]:
unet = UNet2DConditionModel.from_pretrained(args.pretrained_path, subfolder="unet")

In [25]:
unet_trainable = fine_tuned_modules(unet)

In [None]:
unet.s

AttributeError: 'ModuleList' object has no attribute 'keys'

In [21]:
set(unet_trainable.state_dict().keys())

{'0.0.norm.bias',
 '0.0.norm.weight',
 '0.0.proj_in.bias',
 '0.0.proj_in.weight',
 '0.0.proj_out.bias',
 '0.0.proj_out.weight',
 '0.0.transformer_blocks.0.attn1.to_k.weight',
 '0.0.transformer_blocks.0.attn1.to_out.0.bias',
 '0.0.transformer_blocks.0.attn1.to_out.0.weight',
 '0.0.transformer_blocks.0.attn1.to_q.weight',
 '0.0.transformer_blocks.0.attn1.to_v.weight',
 '0.0.transformer_blocks.0.attn2.to_k.weight',
 '0.0.transformer_blocks.0.attn2.to_out.0.bias',
 '0.0.transformer_blocks.0.attn2.to_out.0.weight',
 '0.0.transformer_blocks.0.attn2.to_q.weight',
 '0.0.transformer_blocks.0.attn2.to_v.weight',
 '0.0.transformer_blocks.0.ff.net.0.proj.bias',
 '0.0.transformer_blocks.0.ff.net.0.proj.weight',
 '0.0.transformer_blocks.0.ff.net.2.bias',
 '0.0.transformer_blocks.0.ff.net.2.weight',
 '0.0.transformer_blocks.0.norm1.bias',
 '0.0.transformer_blocks.0.norm1.weight',
 '0.0.transformer_blocks.0.norm2.bias',
 '0.0.transformer_blocks.0.norm2.weight',
 '0.0.transformer_blocks.0.norm3.bias',


In [30]:
loaded_checkpoint = load_state_dict(
    "/Users/seongbae/workspace/ezpz-test/try-off-anyone/ckpt-test2/unet_transformer_block.pt"
)

In [64]:
resutl = unet.load_state_dict(loaded_checkpoint, strict=False)

In [69]:
len(resutl.missing_keys)

270

In [63]:
for k_loaded in set(loaded_checkpoint.keys()):
    if "conv_in.weight" in k_loaded:
        print(k_loaded)

In [15]:
unet = UNet2DConditionModel.from_pretrained(args.pretrained_path, subfolder="unet")

# 확인용: 파라미터 이름 확인
all_trainable = set(
    name for name, param in unet.named_parameters() if param.requires_grad
)

# Freeze all U-Net parameters first
unet.requires_grad_(False)

# Unfreeze only transformer block
trainable_unet_module = fine_tuned_modules(unet)
if not trainable_unet_module:
    logger.warning("No trainable modules identified by fine_tuned_modules.")
    unet.requires_grad_(True)
else:
    for module_list in trainable_unet_module:
        for module in module_list:
            if module is not None:
                for param in module.parameters():
                    param.requires_grad = True
    params_to_optimize = []
    for module_list in trainable_unet_module:
        for module in module_list:
            if module is not None:
                params_to_optimize.extend(list(module.parameters()))

    if not params_to_optimize:
        logger.error(
            "No parameters found to optimize even after attempting to unfreeze transformer blocks."
        )
        unet.requires_grad_(True)
        params_to_optimize = unet.parameters()

# Skip cross-attentions as per
unet.set_attn_processor(skip_cross_attentions(unet))

# 예시: module_list가 [module1, module2, ...] 형태라고 가정
module_trainable = set(
    name for name, param in unet.named_parameters() if param.requires_grad
)

In [16]:
# 전체에서 module_list에 포함된 파라미터만 추출
intersection = all_trainable & module_trainable
# 전체에서 module_list에 없는 파라미터
only_in_all = all_trainable - module_trainable
# module_list에만 있는 파라미터 (거의 없음)
only_in_module = module_trainable - all_trainable

print("공통 파라미터:", len(intersection))
print("전체에만 있는 파라미터:", len(only_in_all))
print("module_list에만 있는 파라미터:", len(only_in_module))

공통 파라미터: 416
전체에만 있는 파라미터: 270
module_list에만 있는 파라미터: 0


# Optimizer


In [17]:
# Optimizer
optimizer = torch.optim.AdamW(
    params_to_optimize,
    lr=args.learning_rate,
    betas=(args.adam_beta1, args.adam_beta2),
    weight_decay=args.adam_weight_decay,
    eps=args.adam_epsilon,
)

# Learning Scheduler


In [18]:
# Scheduler
lr_scheduler = get_scheduler(
    "cosine",
    optimizer=optimizer,
    num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
    num_training_steps=(
        args.max_train_steps
        if args.max_train_steps
        else len(train_dataloader) * args.num_train_epochs
    ),
)

# Accelerator 준비


In [19]:
unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
    unet, optimizer, train_dataloader, lr_scheduler
)
weight_dtype = torch.float16 if accelerator.mixed_precision == "fp16" else torch.float32
vae.to(accelerator.device, dtype=weight_dtype)


num_update_steps_per_epoch = math.ceil(
    len(train_dataloader) / args.gradient_accumulation_steps
)
if args.max_train_steps is None:
    args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
else:
    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)

total_batch_size = (
    args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
)

print("***** Running training *****")
print(f"  Num examples = {len(train_dataset)}")
print(f"  Num Epochs = {args.num_train_epochs}")
print(f"  Instantaneous batch size per device = {args.train_batch_size}")
print(
    f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}"
)
print(f"  Gradient Accumulation steps = {args.gradient_accumulation_steps}")
print(f"  Total optimization steps = {args.max_train_steps}")

***** Running training *****
  Num examples = 11647
  Num Epochs = 1
  Instantaneous batch size per device = 4
  Total train batch size (w. parallel, distributed & accumulation) = 4
  Gradient Accumulation steps = 1
  Total optimization steps = 2912


# Resume


In [20]:
class SaveCkptsQueue:
    def __init__(self, maxlen=2):
        self.queue = deque(maxlen=maxlen)
        self.maxlen = maxlen

    def enqueue(self, item):
        if len(self.queue) >= self.maxlen:
            removed = self.queue.popleft()  # 가장 오래된 요소 제거
            print(f"Removed: {removed}")
        self.queue.append(item)
        print(f"Added: {item}, SaveCkptsQueue: {list(self.queue)}")
        try:
            return removed
        except:
            return None

    def get_queue(self):
        return list(self.queue)

In [21]:
global_step = 0
first_epoch = 0

args.resume_from_checkpoint = None
old_save_path = None

save_path_queue = SaveCkptsQueue(maxlen=args.checkpoints_total_limit)
save_path_queue.enqueue(args.resume_from_checkpoint)

if args.resume_from_checkpoint:
    old_save_path = args.resume_from_checkpoint

    if args.resume_from_checkpoint != "latest":
        path = os.path.basename(args.resume_from_checkpoint)
    else:
        dirs = os.listdir(args.output_dir)
        dirs = [d for d in dirs if d.startswith("checkpoint-")]
        dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
        path = dirs[-1] if len(dirs) > 0 else None

    if path is None:
        accelerator.print(
            f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
        )
        args.resume_from_checkpoint = None

    else:
        accelerator.print(f"Resuming from chekcpoint {path}")
        accelerator.load_state(os.path.join(args.output_dir, path))
        global_step = int(path.split("-")[1])
        resume_global_step = global_step * args.gradient_accumulation_steps
        first_epoch = global_step // num_update_steps_per_epoch
        resume_step = resume_global_step % (
            num_update_steps_per_epoch * args.gradient_accumulation_steps
        )

    print(f"global_step: {global_step}")
    print(f"first_epoch: {first_epoch}")
    print(f"resume_global_step: {resume_global_step}")
    print(f"resume_step: {resume_step}")

Added: None, SaveCkptsQueue: [None]


In [22]:
def encode(image, vae):
    image = (
        image.to(memory_format=torch.contiguous_format)
        .float()
        .to(vae.device, dtype=vae.dtype)
    )
    with torch.no_grad():
        return vae.encode(image).latent_dist.sample() * vae.config.scaling_factor

In [23]:
concat_d = -2

In [None]:
progress_bar = tqdm(
    range(global_step, args.max_train_steps),
    disable=not accelerator.is_local_main_process,
)

progress_bar.set_description("Steps")

# for epoch in range(args.num_train_epochs):
for epoch in range(first_epoch, args.num_train_epochs):
    unet.train()
    train_loss = 0.0
    for step, batch in enumerate(train_dataloader):
        if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
            if step % args.gradient_accumulation_steps == 0:
                progress_bar.update(1)
            continue

        with accelerator.accumulate(unet):
            latents_c = encode(batch["cloth_c"], vae)
            latents_hc = encode(batch["person_hc"], vae)
            latents_hm = encode(batch["person_hm"], vae)

            mask_m_resized = F.interpolate(
                batch["mask_m"].to(dtype=weight_dtype),
                size=latents_c.shape[-2:],
                mode="nearest",
            )

            latents_x = torch.cat([latents_hm, latents_hc], dim=concat_d)
            latents_m = torch.cat(
                [mask_m_resized, torch.zeros_like(mask_m_resized)], dim=concat_d
            )
            latents_cm = torch.cat([latents_c, latents_hc], dim=concat_d)

            noise = torch.randn_like(latents_cm)
            bsz = latents_cm.shape[0]

            timesteps = torch.randint(
                0,
                noise_scheduler.config.num_train_timesteps,
                (bsz,),
                device=latents_cm.device,
            ).long()

            # Add noise to the latents_c according to the noise magnitude at each timestep
            noisy_latents_cm = noise_scheduler.add_noise(latents_cm, noise, timesteps)

            # Prepare U-Net input
            model_input = torch.cat([noisy_latents_cm, latents_m, latents_x], dim=1)

            # Predict the noise residual
            noise_pred = unet(model_input, timesteps, encoder_hidden_states=None).sample

            # Calcluate loss
            criterion = nn.MSELoss(reduction="mean")
            loss = criterion(noise_pred.float(), noise.float())
            # loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean")

            # Gather the losses
            # 각 프로세스의 loss만 모아서 평균
            losses = accelerator.gather(loss)
            avg_loss = losses.mean()
            train_loss += avg_loss.item() / args.gradient_accumulation_steps

            # Backpropagate
            accelerator.backward(loss)
            if accelerator.sync_gradients:
                params_to_clip = [p for p in params_to_optimize if p.grad is not None]
                if params_to_clip:
                    accelerator.clip_grad_norm_(params_to_clip, 1.0)
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()

        if accelerator.sync_gradients:
            progress_bar.update(1)
            global_step += 1
            accelerator.log(
                {
                    "train_loss": train_loss,
                    "step_loss": loss.detach().item(),
                    "lr": lr_scheduler.get_last_lr()[0],
                },
                step=global_step,
            )
            train_loss = 0.0

    if accelerator.is_main_process:
        save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
        accelerator.save_state(save_path)
        logger.info(f"Saved checkpoint to {save_path}")
        time.sleep(5)

        # Save U-Net
        unwrapped_unet = accelerator.unwrap_model(unet)

        trainable_state_dict = {}
        for name, param in unwrapped_unet.named_parameters():
            if param.requires_grad:
                trainable_state_dict[name] = param.detach().cpu()

        if trainable_state_dict:
            torch.save(
                trainable_state_dict,
                os.path.join(save_path, "unet_transformer_block.pt"),
            )
            shutil.copytree(
                os.path.join(save_path),
                os.path.join(
                    "/content/drive/MyDrive/01_Project/EZPZ_TryOff_Test/try-off-anyone-ckpts",
                    os.path.basename(save_path),
                ),
            )
            old_save_path = save_path_queue.enqueue(save_path)
            time.sleep(5)

            if old_save_path:
                shutil.rmtree(old_save_path)
                shutil.rmtree(
                    os.path.join(
                        "/content/drive/MyDrive/01_Project/EZPZ_TryOff_Test/try-off-anyone-ckpts",
                        os.path.basename(old_save_path),
                    )
                )
                time.sleep(5)

            logger.info(
                f"Saved fine-tuned U-Net transformer block to {os.path.join(save_path, 'unet_transformer_block.pt')}"
            )
        else:
            logger.warning("No trainable parameters found.")

        logs = {
            "step_loss": loss.detach().item(),
            "lr": lr_scheduler.get_last_lr()[0],
        }
        progress_bar.set_postfix(**logs)
        if global_step >= args.max_train_steps:
            break

    if global_step >= args.max_train_steps:
        break



  0%|          | 0/2912 [00:00<?, ?it/s]

336


In [None]:
# Save the final trained_model
accelerator.wait_for_everyone()
if accelerator.is_main_process:
    unwrapped_unet = accelerator.unwrap_model(unet)
    # Save only the fine-tuned parameters
    final_trainable_state_dict = {}
    for name, param in unwrapped_unet.named_parameters():
        if param.requires_grad:
            final_trainable_state_dict[name] = param.cpu().clone()

    if final_trainable_state_dict:
        torch.save(
            final_trainable_state_dict,
            os.path.join(args.output_dir, "final_unet_transformer_block.pt"),
        )
        shutil.copyfile(
            os.path.join(args.output_dir, "final_unet_transformer_block.pt"),
            os.path.join(
                "/content/drive/MyDrive/01_Project/EZPZ_TryOff_Test/try-off-anyone-ckpts",
                "final_unet_transformer_block.pt",
            ),
        )

        logger.info(
            f"Saved_final fine-tuned U-Net transformer block to {os.path.join(args.output_dir, 'final_unet_transformer_block.pt')}"
        )
    else:
        logger.warning("No trainable parameters found.")
        unwrapped_unet.save_pretrained(os.path.join(args.output_dir, "final_unet_full"))
accelerator.end_training()



In [123]:
len(params_to_optimize)

416

In [130]:
params_to_optimize[0]

Parameter containing:
tensor([0.2295, 0.1383, 0.1957, 0.3308, 0.2764, 0.2856, 0.3503, 0.3184, 0.2834,
        0.3350, 0.2896, 0.2191, 0.3296, 0.2869, 0.2834, 0.1260, 0.2546, 0.3284,
        0.3145, 0.2174, 0.1908, 0.1902, 0.2084, 0.1967, 0.2139, 0.1903, 0.2183,
        0.2134, 0.2102, 0.2156, 0.2798, 0.2783, 0.2747, 0.2322, 0.3201, 0.2776,
        0.1564, 0.3127, 0.2986, 0.2805, 0.2515, 0.2474, 0.2383, 0.1976, 0.2296,
        0.1821, 0.2603, 0.2424, 0.1537, 0.2502, 0.2028, 0.1324, 0.2859, 0.2834,
        0.2451, 0.3274, 0.2281, 0.3359, 0.4624, 0.2595, 0.2418, 0.2213, 0.2371,
        0.2261, 0.1958, 0.1824, 0.2279, 0.2147, 0.1810, 0.2321, 0.3286, 0.3616,
        0.1765, 0.1646, 0.2942, 0.2450, 0.3550, 0.4141, 0.1387, 0.2435, 0.2450,
        0.2421, 0.2524, 0.2571, 0.1702, 0.2021, 0.2209, 0.2233, 0.2327, 0.1830,
        0.3374, 0.3323, 0.3916, 0.1970, 0.1550, 0.1902, 0.3982, 0.1771, 0.4236,
        0.3120, 0.1170, 0.2908, 0.2659, 0.2910, 0.2747, 0.1738, 0.2896, 0.3196,
        0.1595, 0.