### Finetuning on the Pokemon900 dataset

In [16]:
from improved_diffusion.script_util import create_model, create_gaussian_diffusion

# Configuration cell: every hyperparameter and path used by the cells below.
# Re-run this cell after editing any value so later cells pick up the change.

a = 1 # Temporary

# -------------- Model architecture ------------
image_size=64
num_channels=128
num_res_blocks=3
num_heads=4
num_heads_upsample=-1
attention_resolutions="16,8"
dropout=0.0
learn_sigma=True
class_cond=False
use_checkpoint=False # To do gradient checkpointing
use_scale_shift_norm=True
# `time_aware` is passed to create_model() in the model-creation cell but was
# never assigned in this notebook, so a fresh kernel failed with NameError.
# NOTE(review): False is assumed because the stock pretrained checkpoint loads
# with strict=True; confirm this is the intended value for the fork's flag.
time_aware=False

# -------------- Diffusion process ------------
sigma_small=False
diffusion_steps=4000
noise_schedule="cosine"
timestep_respacing=""
use_kl=False
predict_xstart=False
rescale_timesteps=True
rescale_learned_sigmas=True

# -------------- Sampling defaults ------------
clip_denoised=True
num_samples=1000
batch_size=16
use_ddim=False

# --------------Training related ------------

schedule_sampler="uniform" # For time-step, should it be uniform or changing based on loss function
lr=1e-4
weight_decay=0.0
lr_anneal_steps=0
microbatch=-1  # -1 disables microbatches
ema_rate="0.9999"  # comma-separated list of EMA values
log_interval=10
save_interval=10000 # Save checkpoints every X steps
use_fp16=False
fp16_scale_growth=1e-3

# ------------- PATHS -------------------
# Load pretrained model from here
load_model_path="./results/pretrained_imagenet/checkpoints/imagenet64_uncond_100M_1500K.pt"
# The dataset you want to finetune on
data_dir = './results/pokemon900/dataset/'
# If you are resuming a previously aborted training, include the path to the checkpoint here
resume_checkpoint=""
# Where to log the training loss (File does not have to exist)
loss_logger="./results/pokemon900/finetuning/trainlog.csv"
# Directory to save checkpoints in
checkpoint_dir = "./results/pokemon900/finetuning/checkpoints/"
# A batch of images is sampled every time a checkpoint is saved; they are written to this directory
save_samples_dir= "./results/pokemon900/finetuning/samples/"

### Create UNet and Diffusion model

In [17]:
import os 
import matplotlib.pyplot as plt 
import torch as th
from improved_diffusion.script_util import create_model, create_gaussian_diffusion


# Build the UNet from the architecture settings of the configuration cell.
# `time_aware` must be defined before this cell runs (it is expected to live
# alongside the other hyperparameters).
model = create_model(
    # resolution and width/depth
    image_size=image_size,
    num_channels=num_channels,
    num_res_blocks=num_res_blocks,
    # attention layout
    attention_resolutions=attention_resolutions,
    num_heads=num_heads,
    num_heads_upsample=num_heads_upsample,
    # conditioning and output parameterisation
    class_cond=class_cond,
    learn_sigma=learn_sigma,
    use_scale_shift_norm=use_scale_shift_norm,
    time_aware=time_aware,  # TIMEAWARE
    # regularisation / memory
    dropout=dropout,
    use_checkpoint=use_checkpoint,
)

# Build the Gaussian diffusion process from the diffusion settings.
diffusion = create_gaussian_diffusion(
    steps=diffusion_steps,
    noise_schedule=noise_schedule,
    timestep_respacing=timestep_respacing,
    rescale_timesteps=rescale_timesteps,
    learn_sigma=learn_sigma,
    sigma_small=sigma_small,
    rescale_learned_sigmas=rescale_learned_sigmas,
    use_kl=use_kl,
    predict_xstart=predict_xstart,
)


### Load pretrained model

In [18]:
# Restore the pretrained weights into the freshly built UNet.
model_path = load_model_path
# map_location="cpu": deserialize onto the CPU instead of whatever device the
# checkpoint was saved from; the model is only moved to CUDA in the training cell.
checkpoint = th.load(model_path, map_location="cpu")
# strict=True makes the load fail if any parameter name is missing or unexpected.
model.load_state_dict(checkpoint, strict=True)

<All keys matched successfully>

### Load the training data

In [19]:
from improved_diffusion.image_datasets import load_data
from improved_diffusion.resample import create_named_schedule_sampler
from improved_diffusion.train_util import TrainLoop

# Build the training data source from the dataset directory.
# `class_cond` comes from the configuration cell so the data pipeline always
# agrees with the setting the model was created with (previously hard-coded
# to False here).
data = load_data(
    data_dir=data_dir,
    batch_size=batch_size,
    image_size=image_size,
    class_cond=class_cond,
)
data

<generator object load_data at 0x000001FE2943F4C0>

### Train the model

In [None]:
from improved_diffusion import dist_util

# Fine-tune the pretrained model on the GPU.
model.to('cuda')

# Build the timestep sampler from the configured name. It gets its own variable
# so the `schedule_sampler` string in the configuration cell is not overwritten
# with an object, which keeps this cell safe to re-run.
timestep_sampler = create_named_schedule_sampler(schedule_sampler, diffusion)

TrainLoop(
    model=model,
    diffusion=diffusion,
    data=data,
    batch_size=batch_size,
    microbatch=microbatch,
    lr=lr,
    ema_rate=ema_rate,
    log_interval=log_interval,
    save_interval=save_interval,
    resume_checkpoint=resume_checkpoint,
    use_fp16=use_fp16,
    fp16_scale_growth=fp16_scale_growth,
    schedule_sampler=timestep_sampler,
    weight_decay=weight_decay,
    lr_anneal_steps=lr_anneal_steps,
    # next 2 For logging
    loss_logger=loss_logger,
    checkpoint_dir = checkpoint_dir,
    # next 4 For sampling
    sample = True, # Doing sampling for a batch in training every time saving
    use_ddim=use_ddim,
    save_samples_dir=save_samples_dir,
    image_size=image_size
).run_loop()

0it [00:00, ?it/s]

saving model 0...
saving model 0.9999...
sampling 16 images


10000it [2:26:39,  1.29it/s]

saving model 0...
saving model 0.9999...
sampling 16 images


20000it [4:53:19,  1.26it/s] 

saving model 0...
saving model 0.9999...
sampling 16 images


30000it [7:20:00,  1.28it/s] 

saving model 0...
saving model 0.9999...
sampling 16 images


40000it [9:46:45,  1.28it/s] 

saving model 0...
saving model 0.9999...
sampling 16 images


50000it [12:13:22,  1.29it/s]

saving model 0...
saving model 0.9999...
sampling 16 images


60000it [14:40:46,  1.14it/s] 

saving model 0...
saving model 0.9999...
sampling 16 images


70000it [17:09:28,  1.30it/s] 

saving model 0...
saving model 0.9999...
sampling 16 images


80000it [19:36:33,  1.14it/s] 

saving model 0...
saving model 0.9999...
sampling 16 images


87155it [21:26:45,  1.22it/s] 

### Sampling

In [None]:
# Sample `num_samples` images from a fine-tuned checkpoint and save them as
# individual JPEG files named 0.jpg, 1.jpg, ... in `save_samples_dir`.

clip_denoised=True
num_samples=1000
batch_size=20
use_ddim=False
# NOTE(review): the directory says "_25000" while the checkpoint is
# "model250000.pt" - confirm which step count is intended.
save_samples_dir ="./results/pokemon900/a3ft/samples/_25000/" # Change these
model_path = "./results/pokemon900/a3ft/checkpoints/model250000.pt" # Change these

# Load on the CPU first; load_state_dict copies the values into the model's
# own parameters wherever they currently live.
checkpoint = th.load(model_path, map_location="cpu")
model.load_state_dict(checkpoint)

model.to('cuda')
model.eval()

# plt.imsave does not create missing directories.
os.makedirs(save_samples_dir, exist_ok=True)

# The sampler choice does not change between batches, so pick it once.
# (The original referenced `self.diffusion`, which does not exist in a notebook.)
sample_fn = diffusion.ddim_sample_loop if use_ddim else diffusion.p_sample_loop

all_images = []
all_labels = []  # stays empty: no class labels are passed to the sampler below
num_saved = 0  # running count of images written; also the next file index
while num_saved < num_samples:

    print(f"sampling {batch_size} images")
    sample = sample_fn(
        model,
        (batch_size, 3, image_size , image_size),
        clip_denoised=clip_denoised,
        model_kwargs={}, # This is not needed, just class conditional stuff
        progress=True
    )
    # Map from [-1, 1] to uint8 [0, 255] and move channels last for imsave.
    sample = ((sample + 1) * 127.5).clamp(0, 255).to(th.uint8)
    sample = sample.permute(0, 2, 3, 1)
    sample = sample.contiguous().cpu().numpy()

    # Keep only as many images as are still needed, so exactly `num_samples`
    # files are written even when it is not a multiple of `batch_size`.
    sample = sample[: num_samples - num_saved]
    all_images.append(sample)

    # Save images
    for s in sample:
        plt.imsave(os.path.join(save_samples_dir, f'{num_saved}.jpg'), s)
        num_saved += 1