In [3]:
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10
from torchvision.transforms import Compose, Normalize, ToTensor

from torch.utils.tensorboard import SummaryWriter
# Shared TensorBoard writer: the data-prep and training cells log to it, and the training cell closes it.
writer = SummaryWriter(log_dir=None)

In [4]:
## Hyperparams
BATCH_SIZE = 64
SEED = 0  # fixed seed so weight init, noise sampling and generation are reproducible on Restart & Run All

torch.manual_seed(SEED)

In [5]:
## Data Prep
DATA_DIM = 3 * 32 * 32  # CIFAR-10 images are 3-channel 32x32

dataset = CIFAR10(
    root="data/",
    train=True,
    download=True,
    # ToTensor yields values in [0, 1]; Normalize with mean=std=0.5 per channel maps them to [-1, 1].
    transform=Compose([ToTensor(), Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]),
)
# Shuffle so each epoch sees the training images in a fresh order.
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

# Sanity check: log a single batch of training images to TensorBoard.
# Undo the [-1, 1] normalization so the images are displayed in [0, 1].
sample_images, _ = next(iter(dataloader))
writer.add_images("data/samples", sample_images * 0.5 + 0.5, dataformats="NCHW")

Files already downloaded and verified


In [8]:
## Model construction

class DenoiserNet(nn.Module):
    """
    Preconditioned denoiser D(x; sigma) = c_skip * x + c_out * F(c_in * x, c_noise).

    ``self.net`` is the raw network F. It receives the scaled noisy image concatenated
    (along the channel dim) with one extra channel filled with c_noise, so it expects
    ``img_channels + 1`` input channels and produces ``img_channels`` output channels.

    sigma_data: standard deviation of the (normalized) training data.
    img_channels: number of image channels (3 for RGB).
    hidden_channels: width of the hidden convolution layers.
    """

    def __init__(self, sigma_data, img_channels=3, hidden_channels=64):
        super().__init__()
        self.sigma_data = sigma_data
        # TODO: minimal stand-in for F; replace with a proper U-Net with a noise-level embedding.
        self.net = nn.Sequential(
            nn.Conv2d(img_channels + 1, hidden_channels, kernel_size=3, padding=1),
            nn.SiLU(),
            nn.Conv2d(hidden_channels, hidden_channels, kernel_size=3, padding=1),
            nn.SiLU(),
            nn.Conv2d(hidden_channels, img_channels, kernel_size=3, padding=1),
        )

    def round_sigma(self, sigma):
        """Return ``sigma`` as a tensor; noise levels are continuous here, so no rounding is applied."""
        return torch.as_tensor(sigma)

    def forward(self, noisy_image, noise_level, class_label=None, augment_labels=None):
        """
        Return the denoised estimate D(noisy_image; noise_level).

        noisy_image: batch of noisy images, assumed (N, C, H, W); cast to float32.
        noise_level: sigma as a Python float, a scalar tensor, or one value per image
                     (shape (N,) or (N, 1, 1, 1)).
        class_label, augment_labels: accepted for interface compatibility; currently unused.
        """
        x = noisy_image.to(torch.float32)
        # Reshape sigma to (N or 1, 1, 1, 1) so it broadcasts over channels and pixels.
        sigma = torch.as_tensor(noise_level, dtype=torch.float32, device=x.device).reshape(-1, 1, 1, 1)
        sigma_data = self.sigma_data

        c_skip = sigma_data ** 2 / (sigma ** 2 + sigma_data ** 2)
        c_out = sigma * sigma_data / (sigma ** 2 + sigma_data ** 2).sqrt()
        c_in = 1 / (sigma ** 2 + sigma_data ** 2).sqrt()
        c_noise = sigma.log() / 4  # noise-level conditioning input for F

        # Give F the noise level as an extra constant channel next to the scaled image.
        noise_channel = c_noise.expand(x.shape[0], 1, x.shape[2], x.shape[3])
        net_input = torch.cat([c_in * x, noise_channel], dim=1)

        return c_skip * x + c_out * self.net(net_input)

In [9]:
## Sampling (reverse-time generation) settings
NUM_STEPS = 100
sigma_max = 80
sigma_min = 0.002
rho = 7
LATENT_DIM = 100
S_churn = 0  # 0 disables the stochastic noise increase, making the sampler deterministic
S_min = 0
S_noise = 1
S_max = torch.inf
class_labels = None


step_indices = torch.arange(NUM_STEPS, dtype=torch.float64)

# NUM_STEPS noise levels interpolated in sigma^(1/rho) space from sigma_max down to sigma_min,
# followed by a final sigma = 0 so the last sampling step lands on a clean image.
timesteps = (sigma_max ** (1 / rho) + (step_indices / (NUM_STEPS - 1)) * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho))) ** rho
timesteps = torch.cat([timesteps, torch.zeros_like(timesteps[:1])])

@torch.no_grad()  # sampling needs no gradients; avoids keeping an autograd graph across all steps
def heun_sampling(net: DenoiserNet, shape=(1, 3, 32, 32)):
    """
    Generate a sample by integrating from pure noise down the ``timesteps`` schedule with Heun's 2nd-order method.

    Each step optionally raises the noise level (controlled by S_churn / S_min / S_max / S_noise),
    takes an Euler step from t_hat to t_next, and applies a trapezoidal correction whenever
    t_next > 0.

    net: denoiser called as ``net(x, sigma, class_labels)`` that returns the denoised estimate
         of ``x`` at noise level ``sigma``, and that provides ``round_sigma``.
    shape: shape of the generated batch, (N, C, H, W); defaults to one CIFAR-10-sized image.

    Returns a float64 tensor of ``shape``.
    """
    # Start from Gaussian noise at the highest noise level, timesteps[0] (= sigma_max).
    img_next = torch.randn(shape, dtype=torch.float64) * timesteps[0]

    for t_curr, t_next in zip(timesteps[:-1], timesteps[1:]):
        img_curr = img_next

        # Temporarily increase noise; gamma is 0 (no extra noise) when S_churn == 0.
        gamma = min(S_churn / NUM_STEPS, 2 ** 0.5 - 1) if S_min <= t_curr <= S_max else 0
        t_hat = net.round_sigma(t_curr + gamma * t_curr)
        img_hat = img_curr + (t_hat ** 2 - t_curr ** 2).sqrt() * S_noise * torch.randn_like(img_curr)

        # Euler step from (img_hat, t_hat) to t_next along d = (x - D(x)) / sigma.
        denoised = net(img_hat, t_hat, class_labels).to(torch.float64)
        d_cur = (img_hat - denoised) / t_hat
        img_next = img_hat + (t_next - t_hat) * d_cur

        # 2nd-order correction; skipped when t_next == 0 to avoid dividing by zero.
        if t_next > 0:
            denoised = net(img_next, t_next, class_labels).to(torch.float64)
            d_prime = (img_next - denoised) / t_next
            img_next = img_hat + (t_next - t_hat) * (0.5 * d_cur + 0.5 * d_prime)

    return img_next

In [10]:
## loss set up
P_mean = -1.2   # mean of log(sigma) for the sampled training noise levels
P_std = 1.2     # std of log(sigma), i.e. spread of the sampled noise levels

class EDMLoss(nn.Module):
    """
    Weighted denoising loss.

    Each image is noised at its own log-normally distributed level sigma
    (log sigma ~ N(P_mean, P_std^2)) and denoised by ``net``. The squared error is
    weighted per image by lambda(sigma) = (sigma^2 + sigma_data^2) / (sigma * sigma_data)^2.
    The loss is the sum over pixels, averaged over the batch, returned as a scalar tensor.
    """

    def __init__(self, P_mean=P_mean, P_std=P_std, sigma_data=0.5):
        super().__init__()
        self.P_mean = P_mean
        self.P_std = P_std
        self.sigma_data = sigma_data

    def forward(self, net, images, labels=None, augment_pipe=None):
        """
        Return the scalar training loss for a batch.

        net: denoiser called as ``net(noisy, sigma, labels, augment_labels=...)``.
        images: batch of clean images, assumed (N, C, H, W).
        labels: optional labels forwarded to ``net``.
        augment_pipe: optional callable returning ``(augmented_images, augment_labels)``.
        """
        batch_size = images.shape[0]
        # One noise level per image, shaped (N, 1, 1, 1) to broadcast over (C, H, W).
        rnd_normal = torch.randn([batch_size, 1, 1, 1], device=images.device)
        sigma = (rnd_normal * self.P_std + self.P_mean).exp()
        weight = (sigma ** 2 + self.sigma_data ** 2) / (sigma * self.sigma_data) ** 2
        y, augment_labels = augment_pipe(images) if augment_pipe is not None else (images, None)
        n = torch.randn_like(y) * sigma
        D_yn = net(y + n, sigma, labels, augment_labels=augment_labels)
        # Weight each image's error by its own lambda(sigma) BEFORE reducing, then average over the batch.
        per_image = (weight * (D_yn - y).square()).reshape(batch_size, -1).sum(dim=1)
        return per_image.mean()

In [None]:
## training loop
EPOCHS = 10
# Std of the normalized pixel data; the network preconditioning and the loss weighting must agree on it.
SIGMA_DATA = 0.5

net = DenoiserNet(SIGMA_DATA)

optimizer = torch.optim.Adam(net.parameters())

loss_fn = EDMLoss(sigma_data=SIGMA_DATA)

global_step = 0  # per-batch step counter so TensorBoard gets one point per optimizer step

for epoch in range(EPOCHS):
    for X, _ in dataloader:  # class labels are unused (unconditional model)
        optimizer.zero_grad()
        loss = loss_fn(net, X)
        loss.backward()
        optimizer.step()
        writer.add_scalar("Loss/train", loss.item(), global_step)
        global_step += 1

writer.flush()
writer.close()