In [None]:
import torch.nn as nn
import tensorflow as tf
import functools
import torch
import numpy as np
from scipy import special as sf
from scipy.stats import binom as spbinom
from numba import njit,float64,int64,jit
from numba.types import UniTuple
from matplotlib import pyplot as plt
import numba_scipy
import gc
import os
from utils import save_checkpoint_withEval as save_checkpoint
from utils import restore_checkpoint_withEval as restore_checkpoint
from loadDataPipeline import generateData

In [None]:
# Shorthand for the functional ReLU; the training loop applies it to the
# network output so that the predicted rate is never negative.
relu = nn.functional.relu

In [None]:
from torch.utils.cpp_extension import load
from models import ncsnpp
from configs.vp import cifar10_ncsnpp_continuous as configLoader
from models import utils as mutils
from models.ema import ExponentialMovingAverage

### Loading the ML model from Song et al.

In [None]:
# Base configuration object returned by the imported config module.
config = configLoader.get_config()

# Training overrides: batch size plus the step intervals that the training
# loop uses for preemption checkpoints, numbered snapshots and logging.
for option, value in (
    ('batch_size', 128),
    ('snapshot_freq_for_preemption', 1000),
    ('snapshot_freq', 50000),
    ('log_freq', 5),
):
    setattr(config.training, option, value)

In [None]:
# Dataset and model overrides, applied section by section in the same order
# as a plain sequence of attribute assignments would be.
# NOTE(review): num_channels=3 presumably matches the 3x channel repeat done
# in generateBatchData for the network input -- confirm.
for section, overrides in (
    (config.data, {
        'dataset': 'MNIST',
        'image_size': 32,
        'num_channels': 3,
        'random_flip': False,
    }),
    (config.model, {
        'nf': 64,
        'name': 'ncsnpp',
        'num_scales': 1000,
    }),
):
    for option, value in overrides.items():
        setattr(section, option, value)

### Loading the dataset

In [None]:
# Build the train / eval datasets and the scaler applied to every batch.
train_ds, eval_ds, scaler = generateData(config, 'mnist')

# Iterators consumed with next(...) by the cells below.
train_iter, eval_iter = iter(train_ds), iter(eval_ds)



### Specify observation times (noise schedule)

In [None]:
tEnd = 15.  # final physical time (approximately)
T = 1000    # number of observation times, i.e. number of valid time indices

# Two uniform grids glued together: 200 points on (0, 1] and 800 points on
# (1, tEnd].  Each linspace includes its left endpoint, which [1:] drops, so
# t = 0 is never an observation time and t = 1 appears exactly once.
observationTimes1 = np.linspace(0, 1, 201)[1:]
observationTimes2 = np.linspace(1, tEnd, 801)[1:]
observationTimes = np.hstack((observationTimes1, observationTimes2))

# Later cells index observationTimes with integers drawn from range(T) and
# size the rate table with T, so the two must agree.
assert len(observationTimes) == T, 'observationTimes must contain exactly T entries'

In [None]:
# Physical time reached at each computational time step.
fig, ax = plt.subplots()
ax.plot(observationTimes)
ax.set_xlabel('computational time step')
ax.set_ylabel('physical time')

### Analytically derived reverse-time transition rate

In [None]:
# brTable[n, m, tIndex] holds the regression target for a pixel whose original
# count is n and whose current count is m: n - m when m < n, and 0 otherwise.
# The table is built with broadcasting instead of a Python triple loop over
# 256 x 256 x T entries; the resulting values are identical.
# NOTE(review): the entries do not depend on the time index (the survival
# probability exp(-t) was computed here before but never used) -- confirm that
# a time-independent table is intended.
levels = np.arange(256)
deficit = np.maximum(levels[:, None] - levels[None, :], 0)

brTable = np.zeros((256, 256, T))
brTable[:] = deficit[:, :, None]  # same 256 x 256 slice for every time index

### Noisifier

In [None]:
def generateDataLinearDegradation(img, tIndex, t):
    """Corrupt one image by binomial thinning and look up its target rate.

    Parameters
    ----------
    img : numpy.ndarray
        Array of non-negative integer-valued counts (stored as any numeric
        dtype).  Values must be valid first indices into ``brTable``
        (0..255 for the table built in this notebook).
    tIndex : int
        Index along the last axis of ``brTable``.
    t : float
        Physical time; each count survives with probability ``exp(-t)``.

    Returns
    -------
    tuple
        ``(noisy, birthRate)`` where ``noisy`` is the thinned image shifted
        and scaled by ``255/2 * exp(-t)``, and ``birthRate`` is a float32
        array of the same shape as ``img`` holding
        ``brTable[img, thinned, tIndex]`` for every pixel.
    """
    p = np.exp(-t)

    # Each pixel keeps a Binomial(img, p) number of its original counts.
    output_image = np.random.binomial(img.astype('int32'), p)

    # One fancy-indexed lookup for the whole image replaces the per-pixel
    # Python loop; the index arrays broadcast against the scalar tIndex.
    original_counts = img.astype('int64')
    current_counts = output_image.astype('int64')
    birthRate = brTable[original_counts, current_counts, tIndex].astype('float32')

    # Centre and scale with the same time-dependent factor.
    width = 255.0/2*p
    mean_v = 255.0/2*p

    return (output_image-mean_v)/width, birthRate

In [None]:
def generateBatchData(imgBatch, T, observationTimes):
    """Corrupt every image of a batch at an independently drawn time index.

    Parameters
    ----------
    imgBatch : torch.Tensor
        4-D batch with the channel axis second (the callers permute to that
        layout).  Pixel values are mapped to counts via ``(1 + x) / 2 * 255``
        -- assumes inputs are scaled to [-1, 1]; TODO confirm against scaler.
    T : int
        Number of time indices; one index per image is drawn from ``range(T)``.
    observationTimes : sequence of float
        Physical time for each time index.

    Returns
    -------
    tuple
        ``(noisy, birthRate, tIndexArray)``: the corrupted images with the
        channel axis repeated 3 times, the per-pixel target rates (channel
        axis not repeated), and the drawn time indices.
    """
    batch_np = imgBatch.detach().cpu().numpy()
    noisy_batch = np.zeros_like(batch_np)
    rate_batch = np.zeros_like(batch_np)
    tIndexArray = np.random.choice(T, size=len(noisy_batch))

    for idx, (sample, t_idx) in enumerate(zip(batch_np, tIndexArray)):
        # Channel-first sample -> channel-last integer-valued counts.
        counts = np.round(np.transpose((1+sample)/2*255, [1,2,0]))
        noisy_hwc, rate_hwc = generateDataLinearDegradation(counts, t_idx, observationTimes[t_idx])

        # Back to channel-first before storing in the batch arrays.
        noisy_batch[idx] = np.transpose(noisy_hwc, [2,0,1])
        rate_batch[idx] = np.transpose(rate_hwc, [2,0,1])

    return np.repeat(noisy_batch, 3, axis=1), rate_batch, tIndexArray

### Visualize one batch

In [None]:
# Draw one training batch, move the last axis to position 1 (channel-first,
# assuming the loader yields channel-last images) and apply the data scaler.
train_batch = torch.from_numpy(next(train_iter)['image']._numpy()).to(config.device).float()
train_batch = scaler(train_batch.permute(0, 3, 1, 2))

# Corrupt the batch and keep tensor copies of the targets / time indices.
output_image_batch, birthRate_batch, tIndexArray = generateBatchData(train_batch, T, observationTimes)
birthRate_batch_torch = torch.tensor(birthRate_batch).to(config.device)
tIndexArray = torch.tensor(tIndexArray).to(config.device)

# One row per sample: original image | corrupted image | target rate.
for i in range(20):

    clean_chw = train_batch[i].detach().cpu().numpy()
    testImage = np.round(np.transpose((1+clean_chw)/2*255, [1,2,0]))
    output_image = np.transpose((1+output_image_batch[i,:,:,:])/2*255, [1,2,0])
    birthRate = np.transpose(birthRate_batch[i,:,:,:], [1,2,0])
    targetTime = np.around(observationTimes[tIndexArray[i]], 4)

    fig, ax = plt.subplots(1,3, figsize=(4.8,1.5))

    ax[0].imshow(testImage/255.)

    # Scale the corrupted image by its maximum unless that maximum is zero.
    peak = np.amax(output_image)
    if peak == 0:
        ax[1].imshow(output_image)
    else:
        ax[1].imshow(output_image/peak)

    ax[1].set_title('$t='+str(targetTime)+'$')

    # Min-max normalise the rate map unless it is constant.
    span = np.amax(birthRate)-np.amin(birthRate)
    if span == 0:
        ax[2].imshow(birthRate)
    else:
        ax[2].imshow((birthRate-np.amin(birthRate))/span)

    # Hide tick labels on all three panels.
    for panel in ax:
        panel.set_xticklabels('')
        panel.set_yticklabels('')

    fig.tight_layout()

### Instantiate an ML model to learn the transition rate

In [None]:
# Network, its training-mode forward function, optimizer and weight EMA.
score_model = mutils.create_model(config)
score_fn = mutils.get_model_fn(score_model, train=True)
optimizer = torch.optim.Adam(score_model.parameters(), lr=config.optim.lr)
ema = ExponentialMovingAverage(score_model.parameters(), decay=config.model.ema_rate)

# Pull one batch from the training iterator (this advances train_iter).
raw_batch = next(train_iter)['image']._numpy()
train_batch = torch.from_numpy(raw_batch).to(config.device).float().permute(0, 3, 1, 2)
imgBatch = scaler(train_batch)

# Checkpoint locations: numbered snapshots go in checkpoint_dir, while
# checkpoint_meta_dir is the single file used to resume after preemption.
workdir = 'linearDegradation-mnist'
checkpoint_dir = os.path.join(workdir, "checkpoints")
checkpoint_meta_dir = os.path.join(workdir, "checkpoints-meta", "checkpoint.pth")
for directory in (checkpoint_dir, os.path.dirname(checkpoint_meta_dir)):
    tf.io.gfile.makedirs(directory)

# Everything that is checkpointed, then overwritten by the restore call below.
state = {
    'optimizer': optimizer,
    'model': score_model,
    'ema': ema,
    'lossHistory': [],
    'evalLossHistory': [],
    'step': 0,
}
state = restore_checkpoint(checkpoint_meta_dir, state, config.device)

initial_step = int(state['step'])
lossHistory, evalLossHistory = state['lossHistory'], state['evalLossHistory']

### Training

In [None]:
# Forward function for evaluation.
# NOTE(review): assumes get_model_fn(..., train=False) puts the model in eval
# mode for the call, mirroring the train=True variable used for training --
# confirm against models.utils.
eval_score_fn = mutils.get_model_fn(score_model, train=False)
n_iters = config.training.n_iters

for step in range(initial_step, n_iters):

    # ---- one optimisation step on a freshly corrupted training batch ----
    train_batch = torch.from_numpy(next(train_iter)['image']._numpy()).to(config.device).float()
    train_batch = train_batch.permute(0, 3, 1, 2)
    train_batch = scaler(train_batch)

    output_image_batch, birthRate_batch, tIndexArray = generateBatchData(train_batch, T, observationTimes)
    birthRate_batch_torch = torch.from_numpy(birthRate_batch).to(config.device)
    output_image_batch_torch = torch.from_numpy(output_image_batch).to(config.device)
    tIndexArray = torch.from_numpy(tIndexArray).to(config.device)

    # ReLU keeps the predicted rate non-negative.
    y = relu(score_fn(output_image_batch_torch, tIndexArray))

    optimizer.zero_grad()

    loss = torch.mean(torch.square(y-birthRate_batch_torch))

    loss.backward()
    optimizer.step()

    # Update the EMA only after the optimizer step so that it averages the
    # weights produced by this step rather than the previous ones.
    state['ema'].update(state['model'].parameters())

    lossHistory.append(loss.detach().cpu().numpy())

    # ---- periodic evaluation with the EMA weights ----
    if np.mod(step, config.training.log_freq)==0:

        eval_batch = torch.from_numpy(next(eval_iter)['image']._numpy()).to(config.device).float()
        eval_batch = eval_batch.permute(0, 3, 1, 2)
        eval_batch = scaler(eval_batch)

        output_image_batch, birthRate_batch, tIndexArray = generateBatchData(eval_batch, T, observationTimes)
        birthRate_batch_torch = torch.from_numpy(birthRate_batch).to(config.device)
        output_image_batch_torch = torch.from_numpy(output_image_batch).to(config.device)
        tIndexArray = torch.from_numpy(tIndexArray).to(config.device)

        # Temporarily swap in the EMA weights; always swap the training
        # weights back, even if the evaluation forward pass raises.
        ema.store(score_model.parameters())
        ema.copy_to(score_model.parameters())
        try:
            # No gradients are needed for evaluation.
            with torch.no_grad():
                y = relu(eval_score_fn(output_image_batch_torch, tIndexArray))
                loss = torch.mean(torch.square(y-birthRate_batch_torch))
        finally:
            ema.restore(score_model.parameters())

        evalLossHistory.append(loss.detach().cpu().numpy())

        print(f'current iter: {step}, loss: {lossHistory[-1]}, eval loss: {evalLossHistory[-1]}')

    # ---- bookkeeping, done before checkpointing so that the saved state
    # describes the weights it contains; 'step' is the next step to run, so a
    # resumed run does not repeat steps that were already applied ----
    state['step'] = step + 1
    state['lossHistory'] = lossHistory
    state['evalLossHistory'] = evalLossHistory

    # Frequently overwritten checkpoint used to resume after preemption.
    if step != 0 and step % config.training.snapshot_freq_for_preemption == 0:
        save_checkpoint(checkpoint_meta_dir, state)

    # Numbered snapshot every snapshot_freq steps and on the last iteration
    # (the last value produced by range(...) is n_iters - 1).
    if (step != 0 and step % config.training.snapshot_freq == 0) or step == n_iters - 1:
        save_step = step // config.training.snapshot_freq
        save_checkpoint(os.path.join(checkpoint_dir, f'checkpoint_{save_step}.pth'), state)

    gc.collect()
    torch.cuda.empty_cache()