In [2]:
%reload_ext autoreload
%autoreload 2
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import transforms
import torchvision
from typing import List, Tuple, Union, Optional, Callable
import math
from abc import ABC
from os import listdir
from PIL import Image
from diffusion import Diffusion
from unet import UNet

In [3]:
class PILImageDataset(torch.utils.data.Dataset, ABC):
    """Dataset that serves every entry of a directory as a transformed PIL image.

    Each item is a ``(sample, label)`` pair where ``sample`` is whatever
    ``tsfm`` returns for the opened image and ``label`` is an empty tensor, so
    that batches can be unpacked as ``(img, lbl)`` even though no labels exist.
    """

    # Directory whose entries are treated as image files.
    directory: str
    # Paths of those entries ("<directory>/<name>"), in sorted order.
    files: List[str]
    # Declared for an optional per-item cache; this class never populates it.
    cache: List
    # Callable applied to the opened PIL image to produce the sample.
    tsfm: Callable

    def __init__(self, directory: str, tsfm: Callable):
        """Index the entries of ``directory``.

        Args:
            directory: Path of the folder to list. Every entry is assumed to be
                an image that ``PIL.Image.open`` can read.
            tsfm: Callable mapping an opened PIL image to the returned sample.

        Raises:
            FileNotFoundError: If ``directory`` does not exist (from ``listdir``).
        """
        self.directory = directory
        # listdir() yields entries in arbitrary order; sorting makes the
        # index -> file mapping stable across runs and platforms.
        self.files = [f"{directory}/{name}" for name in sorted(listdir(self.directory))]
        self.tsfm = tsfm

    def __len__(self):
        """Return the number of indexed files."""
        return len(self.files)

    def __getitem__(self, idx: int):
        """Open file ``idx``, transform it, and return ``(sample, empty_tensor)``."""
        # The context manager releases the file handle even when the transform
        # raises, which a trailing close() call would not do.
        with Image.open(self.files[idx]) as img:
            res = self.tsfm(img)
        return res, torch.Tensor([])

In [8]:
from torchvision.datasets import CIFAR10, MNIST

# Directory holding the model checkpoint. Other cells build the checkpoint path
# as f"{MODEL_SAVE_PATH}/model", so the trailing slash here yields "../model//model".
MODEL_SAVE_PATH = "../model/"

In [3]:
# Directory of training images; PILImageDataset lists every entry inside it.
DATA_PATH = "./data/anime_icons"
# Per-image preprocessing: resize to 64x64, convert to a tensor, normalize with
# mean 0.5 and std 0.5, then apply a random horizontal flip.
data_transform = transforms.Compose([
    transforms.Resize((64, 64)),
    transforms.ToTensor(),
    transforms.Normalize(0.5, 0.5),
    transforms.RandomHorizontalFlip()
])

# Listing DATA_PATH happens here, so a missing directory fails on this line.
dataset = PILImageDataset(DATA_PATH, tsfm=data_transform)

# Shuffled batches of 64 samples; iterated by train().
data_loader_train = DataLoader(dataset, batch_size=64, shuffle=True)

FileNotFoundError: [WinError 3] 系统找不到指定的路径。: './data/anime_icons'

In [9]:
# Length of the beta schedule handed to Diffusion.
num_steps = 1000
# Linearly spaced values from 1e-4 to 2e-2, one per step.
betas = torch.linspace(0.0001, 0.02, num_steps)
# Prefer the GPU, but fall back to the CPU so the notebook still runs on
# machines without CUDA instead of failing when the model is moved to `device`.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Number of epochs executed by train().
num_epoch = 100

In [10]:
# Build the network and place it on the selected device.
model = UNet(3, 3, resolute_multiplication=(1, 2, 2, 2),
             is_attention=(False, False, False, True)).to(device)
# Resume from the saved checkpoint (this cell requires one to exist).
# map_location remaps tensors that were saved from a different device, e.g. a
# CUDA checkpoint opened on a CPU-only machine, onto `device`.
model.load_state_dict(torch.load(f"{MODEL_SAVE_PATH}/model", map_location=device))
# Adam with learning rate 1e-5.
optim = torch.optim.Adam(model.parameters(), 1e-5)
diffusion = Diffusion(model, optim, betas, device)

In [None]:
from PIL import Image
from torchvision.utils import save_image, make_grid

# Preprocessing for the reference image given to generate_by(): resize to
# 64x64, convert to a tensor, normalize with mean 0.5 and std 0.5.
# No random flip is applied here: that is a training-time augmentation, and on
# a single reference image it would make the starting point (and the first tile
# of the saved grid) differ from the file on disk at random.
img_tsfm = transforms.Compose([
    transforms.Resize((64, 64)),
    transforms.ToTensor(),
    transforms.Normalize(0.5, 0.5),
])

def generate_by(img_path: str, times: int = 10,
                frm: Callable[[int], int] = lambda x: 300, stp: int = 10) -> None:
    """Repeatedly regenerate an image with the diffusion model, saving each round.

    Starting from the image at ``img_path``, every round feeds the current
    image to ``diffusion.generate_from`` and continues from the last element of
    its result. Files written to the working directory:

    * ``diffusion_{i}.png`` - grid of every ``stp``-th of the last ``frm(i)``
      elements returned in round ``i``;
    * ``generated_{i}.png`` - the final element of round ``i``;
    * ``generating_process.png`` - the source image followed by the final
      element of every round.

    Args:
        img_path: Path of the source image.
        times: Number of rounds.
        frm: Maps the 1-based round number to the value passed as the second
            argument of ``diffusion.generate_from`` (presumably the step to
            start from - confirm against ``Diffusion``).
        stp: Stride used when picking intermediate results for the grid.
    """
    # Transform inside the context manager so the file handle is released, and
    # force three channels: RGBA / grayscale / palette files would otherwise
    # not fit the (1, 3, 64, 64) view.
    with Image.open(img_path) as source:
        img = img_tsfm(source.convert("RGB")).to(device).view(1, 3, 64, 64)

    imgs = [img]
    for i in range(1, times + 1):
        print(f"\r{i}", end="")
        frming = frm(i)
        # NOTE(review): the result is indexed and concatenated like a sequence
        # of image tensors - confirm against Diffusion.generate_from.
        img_gen = diffusion.generate_from(img, frming)
        # Tensors are rescaled by hand with * 0.5 + 0.5, so make_grid's
        # normalize / value_range options are not used.
        save_image(
            make_grid(torch.cat(img_gen[-frming::stp], dim=0) * 0.5 + 0.5, nrow=4),
            f"diffusion_{i}.png"
        )
        save_image(img_gen[-1] * 0.5 + 0.5, f"generated_{i}.png")
        imgs.append(img_gen[-1])
        # The next round starts from this round's final result.
        img = img_gen[-1]
    print()
    save_image(
        make_grid(torch.cat(imgs, dim=0) * 0.5 + 0.5, nrow=4),
        "generating_process.png"
    )

In [12]:
from torchvision.utils import save_image, make_grid


def train():
    """Run ``num_epoch`` epochs, saving a sample grid and a checkpoint after each.

    Per epoch: every batch of ``data_loader_train`` is passed to
    ``diffusion.train`` and a progress line is rewritten in place; afterwards
    the mean of the returned losses is printed, eight samples are drawn and
    saved next to the first eight images of the last batch as
    ``./ddpm_sample_{epoch}.png``, and the model weights are written to
    ``{MODEL_SAVE_PATH}/model``.
    """
    for epoch in range(1, num_epoch + 1):
        n_batches = len(data_loader_train)
        running_loss = 0.
        seen = 0

        for seen, (img, _) in enumerate(data_loader_train, start=1):
            loss = diffusion.train(img)
            running_loss += loss

            # "\r" with end="" keeps the progress on a single console line.
            print(f"\r{epoch}: {seen}/{n_batches}; {loss}", end="")

        print(f"\r{epoch}: {seen}/{n_batches}; {running_loss / n_batches}")

        # `img` still refers to the last batch of the epoch here.
        samples = diffusion.sample((8, 3, 64, 64))
        side_by_side = torch.cat([img[:8].to(device), samples[-1]], dim=0)
        grid = make_grid(side_by_side * 0.5 + 0.5, value_range=(-1, 1), nrow=4)
        save_image(grid, f"./ddpm_sample_{epoch}.png")

        torch.save(model.state_dict(), f"{MODEL_SAVE_PATH}/model")

def gen():
    """Draw 16 samples and save them, four per row, to ``./ddpm_sample.png``."""
    trajectory = diffusion.sample((16, 3, 64, 64), prog=True)
    # Only the last element of the returned sequence is saved.
    final = trajectory[-1].to(device)
    save_image(
        make_grid(final * 0.5 + 0.5, value_range=(-1, 1), nrow=4),
        "./ddpm_sample.png",
    )

In [13]:
# Sample with the loaded model; writes ./ddpm_sample.png.
gen()

4/1000

KeyboardInterrupt: 