In [None]:
import numpy as np 
from glob import glob
from PIL import Image
from pathlib import Path
from tqdm.auto import tqdm
from functools import partial
import matplotlib.pyplot as plt 

import torch
import torch.nn.functional as F
from torch.optim import Adam
from torch.optim import lr_scheduler
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms.functional as TF

import utils
from models import UNet2DModel 

# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
  device = "cuda"
else:
  device = "cpu"

Before we proceed, please `git clone` the [apebase](https://github.com/skogard/apebase) repo, which contains 10,000 BAYC NFT images. We then create a custom PyTorch dataset by implementing the `__len__` and `__getitem__` methods.

In the `__getitem__` method, we read the image with the PIL library, convert it into a PyTorch tensor with values in the range [0, 1], and resize it to 512x512 (3 channels) so that it is compatible with the autoencoder model we use below.

In [None]:
class CustomDataset(Dataset):
  """Image-folder dataset that yields 512x512 RGB tensors.

  Every entry matched by ``root_dir/*`` is treated as an image file.

  Args:
    root_dir: Directory holding the image files.
    transform: Optional callable applied to each image tensor after resizing
      and before the tensor is moved to ``device``. ``None`` disables it.
  """

  def __init__(self, root_dir, transform=None):
    self.root_dir = Path(root_dir)
    # glob() gives no ordering guarantee; sorting keeps the index -> file
    # mapping stable across runs and machines.
    self.files = sorted(glob(str(self.root_dir / "*")))
    self.transform = transform

  def __len__(self): 
    """Return the number of files found under ``root_dir``."""
    return len(self.files)

  def _load_image(self, path):
    """Read one file and return it as a (3, 512, 512) tensor on ``device``."""
    # The context manager releases the file handle once the pixels are copied.
    with Image.open(path) as img:
      x = TF.to_tensor(img.convert("RGB"))
    x = TF.resize(x, size=(512,512), antialias=True)
    if self.transform is not None:
      x = self.transform(x)
    # NOTE: moving to `device` here means samples are created on the GPU when
    # one is available, so this dataset should be used with num_workers=0.
    return x.to(device)

  def __getitem__(self, idx):
    """Return ``(image, "")`` for an integer index or ``(images, "")`` for a slice.

    An integer index (negative values count from the end) yields a single
    (3, 512, 512) tensor and raises ``IndexError`` when out of range. A slice
    yields the selected images stacked along a new leading axis, i.e. a
    (k, 3, 512, 512) tensor. The second element is always an empty string.
    """
    if isinstance(idx, slice):
      # stack (not cat) so each image keeps its own channel axis
      return torch.stack([self._load_image(p) for p in self.files[idx]]), ""
    return self._load_image(self.files[idx]), ""

One key idea behind Stable Diffusion is "latent diffusion", which runs the diffusion process in the "latent space", using compressed representations from an autoencoder rather than raw images. These representations are information-rich yet small enough to handle on consumer hardware. By contrast, running diffusion directly in pixel space makes high-resolution image generation very computationally expensive. 
The autoencoder is trained to squish an image down into a smaller representation (encoder) and then reconstruct the image from that compressed representation (decoder). 

In this exercise, we will use `AutoencoderKL` from Hugging Face, which takes a 512x512x3 image and compresses it into a 64x64x4 latent (a compression factor of 48!).

In [None]:
# %pip install diffusers   (run once if the package is not installed)
from diffusers import AutoencoderKL
# Load the pretrained Stable Diffusion v1.4 autoencoder weights.
vae = AutoencoderKL.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="vae")
# The VAE is only used as a fixed encoder/decoder in this notebook, so freeze
# its parameters: autograd then never tracks or stores gradients for them.
# Assigning the result also keeps the cell from echoing the module repr.
vae = vae.to(device).requires_grad_(False)

In [None]:
# directory where the model checkpoints will be saved
save_dir = Path("./weights")
save_dir.mkdir(parents=True, exist_ok=True)

# create custom dataset
bs = 16
dsd = CustomDataset(root_dir="apebase/ipfs")
# Fail early with an actionable message; an empty folder would otherwise only
# surface later as an error from inside the shuffling DataLoader.
if len(dsd) == 0:
  raise FileNotFoundError(
    f"no images found in {dsd.root_dir}; git clone the apebase repo before running this cell"
  )
# num_workers=0: CustomDataset moves tensors to `device` inside __getitem__,
# so samples are loaded in the main process.
dls = DataLoader(dsd, batch_size=bs, shuffle=True, num_workers=0)

# model instantiation
# in_channels = out_channels = 4 because the model is trained on the 4-channel
# latents produced by the VAE, not on the raw RGB images
model = UNet2DModel(in_channels=4, out_channels=4, nfs=(32,64,128,256), num_layers=2)
model = model.to(device)

# hyperparameters
epochs = 25
# Peak learning rate of the one-cycle schedule. `lr` was referenced below but
# never defined, which raised NameError on a fresh kernel; 1e-3 is a
# conservative starting point and is meant to be tuned.
lr = 1e-3
# OneCycleLR is stepped once per batch, so it needs the total batch count.
tmax = epochs * len(dls)
optimizer = Adam(model.parameters(), eps=1e-5)
sched = partial(lr_scheduler.OneCycleLR, max_lr=lr, total_steps=tmax)
schedo = sched(optimizer)

# training loop!
model.train()
train_losses = []  # per-batch losses accumulated over every epoch
for epoch in range(epochs):
  batch_train_losses = []  # per-batch losses of the current epoch only
  pbar = tqdm(dls, mininterval=2)
  for xb,_ in pbar:
    optimizer.zero_grad()

    xb = xb.to(device)
    # Only `model` is optimized, so the VAE encoding needs no autograd graph.
    with torch.no_grad():
      encoded_imgs = utils.img_to_latent(vae, xb)
    # NOTE(review): noisify is assumed to return ((noised latents, timesteps),
    # regression target), as the unpacking implies — confirm against utils.
    (noised_input, t), target = utils.noisify(encoded_imgs)
    out = model((noised_input, t))
    loss = F.mse_loss(out, target)
    loss.backward()
    optimizer.step()
    schedo.step()  # OneCycleLR advances once per batch

    # Read the scalar once; every .item() call copies the value to the host.
    batch_loss = loss.item()
    batch_train_losses.append(batch_loss)
    pbar.set_description(f"loss {batch_loss:.2f}")

  train_losses.extend(batch_train_losses)
  # Report this epoch's mean, not the running mean over all epochs so far.
  print(f"Epoch {epoch}, loss: {np.mean(batch_train_losses)}")

  # save the model weights every 4 epochs and after the final epoch
  if epoch % 4 == 0 or epoch == epochs - 1:
    model_path = save_dir / f"bayc_model_{epoch}_bs_{bs}.pth"
    torch.save(model.state_dict(), model_path)
    print(f"saved model at {model_path.absolute().as_posix()}")