# Simple diffusion
This is a simple diffusion model heavily based on [minDiffusion](https://github.com/cloneofsimo/minDiffusion).

In [1]:
                        import os, random, time, math
import torch
from torch import nn
from torch.optim import Adam
from torch.utils.data import Dataset, DataLoader, SubsetRandomSampler
from torchvision import transforms
from torchvision.datasets import CIFAR10
from torchvision.utils import save_image, make_grid
from tqdm import tqdm

In [2]:
import numpy as np
import cv2
from PIL import Image
# Runtime configuration: compute device, batch size and output folders.
DEVICE = "cpu" if not torch.cuda.is_available() else "cuda"
BATCH_SIZE = 4
MODEL_OUT_DIR = "./models"
SAMPLE_OUT_DIR = "./samples"

# Diffusion: number of noising steps and the first/last beta values.
NOISE_STEPS = 200
BETA_START = 0.0001
BETA_END = 0.02

# Training: initial learning rate.
INIT_LR = 0.00001

def is_image_file(filename):
    """Return True when *filename* ends in .png, .jpg or .jpeg (case-sensitive)."""
    return filename.endswith((".png", ".jpg", ".jpeg"))


def load_img(filepath):
    """Open *filepath* as an RGB PIL image resized to 256x256 with bicubic filtering."""
    rgb = Image.open(filepath).convert('RGB')
    return rgb.resize((256, 256), Image.BICUBIC)


def save_img(image_tensor, filename):
    """Save the first image of a batch tensor to *filename* via PIL.

    Assumes ``image_tensor`` is (N, C, H, W) with values in [-1, 1] (the range
    produced by Normalize((0.5,)*3, (0.5,)*3)) — TODO confirm for all callers.
    Values are mapped to [0, 255], clipped and stored as uint8.

    Prints the CHW shape of the saved image and a confirmation line.
    """
    # The dead `imput_numpy = ....astype(np.float)` line was removed: its result
    # was never used, and `np.float` no longer exists in NumPy >= 1.24.
    image_numpy = image_tensor[0].detach().cpu().numpy()
    print("shape of numpy_image", image_numpy.shape)
    # CHW -> HWC, then map [-1, 1] onto [0, 255].
    image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0
    image_numpy = image_numpy.clip(0, 255).astype(np.uint8)
    Image.fromarray(image_numpy).save(filename)
    print("Image saved as {}".format(filename))


In [3]:
from os import listdir
from os.path import join
import random

from PIL import Image
import torch
import torch.utils.data as data
import torchvision.transforms as transforms

class DatasetFromFolder(data.Dataset):
    """Paired images read from ``<image_dir>/a`` and ``<image_dir>/b``.

    Samples are the image files listed in ``a`` (filtered by ``is_image_file``).
    The partner in ``b`` has the same filename with every "png" replaced by
    "jpg". Both images are resized to 256x256 (bicubic), converted to CHW
    tensors normalized to [-1, 1], and flipped horizontally together with
    probability 0.5.

    Args:
        image_dir: directory holding the ``a`` and ``b`` subfolders.
        direction: ``"a2b"`` yields ``(a, b)``; any other value yields ``(b, a)``.
    """

    def __init__(self, image_dir, direction):
        super(DatasetFromFolder, self).__init__()
        self.direction = direction
        self.a_path = join(image_dir, "a")
        self.b_path = join(image_dir, "b")
        self.image_filenames = [x for x in listdir(self.a_path) if is_image_file(x)]

        # ToTensor gives [0, 1]; Normalize(0.5, 0.5) maps that to [-1, 1].
        transform_list = [transforms.ToTensor(),
                          transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
        self.transform = transforms.Compose(transform_list)

    @staticmethod
    def _load_resized(path, mode=None):
        """Open *path*, optionally convert to *mode*, and return a 256x256 bicubic copy.

        The ``with`` block closes the file handle; ``resize`` returns a new,
        fully loaded image, so the result remains usable afterwards.
        """
        with Image.open(path) as img:
            if mode is not None:
                img = img.convert(mode)
            return img.resize((256, 256), Image.BICUBIC)

    def __getitem__(self, index):
        name = self.image_filenames[index]
        # NOTE(review): 'a' is opened without convert('RGB'); Normalize's
        # 3-channel statistics assume the 'a' files are already 3-channel.
        a = self.transform(self._load_resized(join(self.a_path, name)))
        # Swap "png" -> "jpg" in the filename only. Replacing on the joined path
        # also rewrote any directory component that contained "png".
        b_name = name.replace("png", "jpg")
        b = self.transform(self._load_resized(join(self.b_path, b_name), 'RGB'))

        if random.random() < 0.5:
            # Flip both images along the width axis (dim 2 of a CHW tensor).
            a = torch.flip(a, dims=[2])
            b = torch.flip(b, dims=[2])

        if self.direction == "a2b":
            return a, b
        return b, a

    def __len__(self):
        return len(self.image_filenames)
    

In [4]:
from os.path import join


def get_training_set(root_dir, direction):
    """Build the paired ``DatasetFromFolder`` rooted at ``<root_dir>/train``."""
    return DatasetFromFolder(join(root_dir, "train"), direction)


def get_test_set(root_dir, direction):
    """Build the paired ``DatasetFromFolder`` rooted at ``<root_dir>/test``."""
    split_dir = join(root_dir, "test")
    dataset = DatasetFromFolder(split_dir, direction)
    return dataset


In [5]:
from torch.utils.data import DataLoader
root_path = "./data/"
print('hhhhh')
train_set = get_training_set(root_path, 'b2a')
print(train_set)
test_set = get_test_set(root_path, 'b2a')
train_data_loader = DataLoader(dataset=train_set, batch_size=BATCH_SIZE, shuffle=True)
testing_data_loader = DataLoader(dataset=test_set, batch_size=BATCH_SIZE, shuffle=True)

# Fall back to the CPU when CUDA is unavailable, so later .to(device) calls
# also work on CPU-only machines.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


hhhhh
<__main__.DatasetFromFolder object at 0x70290af23890>


## Model
The denoising model is a simple U-Net structure based on [minDiffusion](https://github.com/cloneofsimo/minDiffusion). It uses sinusoidal position embeddings to encode time steps.

In [6]:
class SinusoidalPositionEmbeddings(nn.Module):
    """Sinusoidal timestep embeddings.

    Adapted verbatim from https://huggingface.co/blog/annotated-diffusion.
    For a 1-D tensor of timesteps, returns ``[sin(t * f), cos(t * f)]``
    concatenated on the last axis, where ``f`` are ``dim // 2`` geometrically
    spaced frequencies from 1 down to 1/10000.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, time):
        half = self.dim // 2
        scale = math.log(10000) / (half - 1)
        freqs = torch.exp(torch.arange(half, device=time.device) * -scale)
        angles = time.unsqueeze(1) * freqs.unsqueeze(0)
        return torch.cat((angles.sin(), angles.cos()), dim=-1)

In [7]:
import os
import cv2
import sys
import time
import random
import numpy as np
import scipy
import torch
from skimage import util
import torch.nn.functional as FF
import scipy.stats as st
import matplotlib.pyplot as plt
from PIL import Image, ImageDraw
from skimage.color import rgb2gray, gray2rgb

def create_dir(dir):
    """Create directory *dir* (and any missing parents) if it does not exist.

    ``exist_ok=True`` replaces the exists-then-create check, which raced with
    concurrent creators between the check and ``makedirs``.
    """
    os.makedirs(dir, exist_ok=True)

def same_padding(images, ksizes, strides, rates):
    """Zero-pad a 4-D ``images`` tensor TensorFlow-'SAME'-style.

    The padding is chosen so that a following window of size ``ksizes``,
    stride ``strides`` and dilation ``rates`` produces ``ceil(size / stride)``
    outputs per spatial dim. Any odd padding goes to the bottom/right.
    """
    assert len(images.size()) == 4
    _, _, in_rows, in_cols = images.size()

    def _pad_amount(size, k, s, r):
        # Total padding required along one spatial dimension.
        out = (size + s - 1) // s
        span = (k - 1) * r + 1
        return max(0, (out - 1) * s + span - size)

    pad_rows = _pad_amount(in_rows, ksizes[0], strides[0], rates[0])
    pad_cols = _pad_amount(in_cols, ksizes[1], strides[1], rates[1])
    top = pad_rows // 2
    left = pad_cols // 2
    return torch.nn.ZeroPad2d((left, pad_cols - left, top, pad_rows - top))(images)

def reduce_mean(x, axis=None, keepdim=False):
    """Average tensor *x* over several dimensions, one dimension at a time.

    Args:
        x: input tensor.
        axis: a single int, an iterable of ints, or None/empty to reduce every
            dimension. A bare int used to raise TypeError in ``sorted``.
        keepdim: keep reduced dimensions with size 1.

    Dimensions are reduced from highest to lowest index, so the remaining
    (non-negative) indices stay valid when ``keepdim`` is False.
    """
    if isinstance(axis, int):
        axis = [axis]
    if not axis:
        axis = range(len(x.shape))
    for i in sorted(axis, reverse=True):
        x = torch.mean(x, dim=i, keepdim=keepdim)
    return x

def reduce_sum(x, axis=None, keepdim=False):
    """Sum tensor *x* over several dimensions, one dimension at a time.

    Args:
        x: input tensor.
        axis: a single int, an iterable of ints, or None/empty to reduce every
            dimension. A bare int used to raise TypeError in ``sorted``.
        keepdim: keep reduced dimensions with size 1.

    Dimensions are reduced from highest to lowest index, so the remaining
    (non-negative) indices stay valid when ``keepdim`` is False.
    """
    if isinstance(axis, int):
        axis = [axis]
    if not axis:
        axis = range(len(x.shape))
    for i in sorted(axis, reverse=True):
        x = torch.sum(x, dim=i, keepdim=keepdim)
    return x

def extract_image_patches(images, ksizes, strides, padding='same'):
    """Unfold sliding ``ksizes`` windows of a 4-D tensor into columns.

    With ``padding='same'`` the input is first zero-padded by ``same_padding``
    (dilation 1). Returns a tensor of shape [N, C*kh*kw, L], with L the number
    of window positions.
    """
    assert len(images.size()) == 4
    assert padding in ['same', 'valid']

    if padding == 'same':
        images = same_padding(images, ksizes, strides, [1, 1])
    elif padding != 'valid':
        raise NotImplementedError('Unsupported padding type: {}.\
            Only "same" or "valid" are supported.'.format(padding))

    extractor = torch.nn.Unfold(kernel_size=ksizes, padding=0, stride=strides)
    return extractor(images)

def np_free_form_mask(h, w, maxVertex, maxLength, maxBrushWidth, maxAngle):
    """Draw one random free-form brush stroke into an (h, w) float32 mask.

    The stroke starts at a random pixel and walks through up to ``maxVertex``
    segments. Each segment has a random length in [0, ``maxLength``] and a
    random angle in [0, ``maxAngle``] degrees (mirrored on even vertices), and
    is drawn with an even brush width between 10 and ``maxBrushWidth``.
    Segment pixels are set to 1; vertex circle outlines are drawn with value 2,
    so callers should clip the result.

    Returns:
        np.ndarray: (h, w) float32 mask.
    """
    mask = np.zeros((h, w), np.float32)
    numVertex = np.random.randint(maxVertex + 1)
    startY = np.random.randint(h)
    startX = np.random.randint(w)
    brushWidth = 0
    for i in range(numVertex):
        angle = np.random.randint(maxAngle + 1)
        angle = angle / 360.0 * 2 * np.pi
        if i % 2 == 0:
            angle = 2 * np.pi - angle
        length = np.random.randint(maxLength + 1)
        brushWidth = np.random.randint(10, maxBrushWidth + 1) // 2 * 2
        nextY = startY + length * np.cos(angle)
        nextX = startX + length * np.sin(angle)

        # Clamp into the image and convert to int. np.int was removed in NumPy 1.24.
        nextY = int(np.clip(nextY, 0, h - 1))
        nextX = int(np.clip(nextX, 0, w - 1))

        # OpenCV points are (x, y) = (column, row). Passing (row, col) transposed
        # the stroke and could leave the image when h != w.
        cv2.line(mask, (startX, startY), (nextX, nextY), 1, brushWidth)
        cv2.circle(mask, (startX, startY), brushWidth // 2, 2)

        startY, startX = nextY, nextX
    cv2.circle(mask, (startX, startY), brushWidth // 2, 2)
    return mask

def free_form_mask(h, w, parts=8, maxVertex=16, maxLength=80, maxBrushWidth=20, maxAngle=360):
    """Union of ``parts`` random strokes from ``np_free_form_mask``, clipped to [0, 1].

    Returns an (h, w) float32 mask.
    """
    total = np.zeros((h, w), np.float32)
    for _ in range(parts):
        stroke = np_free_form_mask(h, w, maxVertex, maxLength, maxBrushWidth, maxAngle)
        total = total + stroke
    return np.minimum(total, 1.0)

def generate_mask_stroke(im_size, parts=16, maxVertex=24, maxLength=100, maxBrushWidth=24, maxAngle=360):
    """Union of ``parts`` random strokes as an (h, w, 1) float32 mask clipped to [0, 1].

    Args:
        im_size: sequence whose first two entries are (h, w).
    """
    h, w = im_size[:2]
    mask = np.zeros((h, w, 1), dtype=np.float32)
    for _ in range(parts):
        stroke = np_free_form_mask(h, w, maxVertex, maxLength, maxBrushWidth, maxAngle)
        # Add the trailing channel axis. (h, w, 1) + (h, w) broadcast to
        # (h, w, w), or failed outright when h != w.
        mask = mask + stroke[:, :, np.newaxis]
    mask = np.minimum(mask, 1.0)
    return mask


def generate_noise(image, noise_type="gauss"):
    """Return a noisy copy of the NumPy array *image* as uint8.

    Args:
        image: input array. "speckle" requires shape (rows, cols, channels).
        noise_type: one of
            "gauss"   - additive N(0, 50/255) noise,
            "salt"    - skimage salt noise (amount 0.2),
            "poisson" - Poisson noise scaled by the number of unique values,
            "speckle" - multiplicative N(0, 1) noise,
            "s&p"     - skimage salt-and-pepper noise (amount 0.2).

    Raises:
        ValueError: for an unknown ``noise_type``.
    """
    if noise_type == "gauss":
        noise = np.random.normal(0.0, 50/255.0, image.shape)
        out = noise + image
    elif noise_type == "salt":
        out = util.random_noise(image=image, mode='salt', clip=True, amount=0.2)
    elif noise_type == "poisson":
        vals = len(np.unique(image))
        vals = 2 ** np.ceil(np.log2(vals))
        out = np.random.poisson(image * vals) / float(vals)
    elif noise_type == "speckle":
        row, col, ch = image.shape
        gauss = np.random.randn(row, col, ch)
        out = image + image * gauss
    elif noise_type == "s&p":
        out = util.random_noise(image=image, mode='s&p', clip=True, amount=0.2, salt_vs_pepper=0.5)
    else:
        raise ValueError("Unsupported noise_type: {}".format(noise_type))

    # Return the noisy image. Previously this returned `noise`, which exists
    # only in the gauss branch. Clip first: casting out-of-range floats to
    # uint8 is undefined.
    return np.uint8(np.clip(out, 0, 255))

def generate_rectangle(h, w):
    """Return an (h, w) mask of ones with a random (h//2 x h//2) square hole of zeros."""
    side = h // 2
    mask = np.ones((h, w))
    top = np.random.randint(0, h - side)
    left = np.random.randint(0, w - side)
    mask[top:top + side, left:left + side] = 0
    return mask
    
def generate_graffiti(h, w, noise):
    """Return an (h, w) mask that is 0 where *noise* is black, 1 elsewhere.

    A pixel counts as black when its first three channels are all 0.
    The previous ``(c0 == c1) == c2`` chain was an XNOR, not an AND: it also
    masked pixels such as (5, 5, 0) and (0, 5, 5).

    Args:
        noise: array of shape (h, w, >=3).
    """
    mask = np.ones((h, w))
    black = (noise[:, :, 0] == 0) & (noise[:, :, 1] == 0) & (noise[:, :, 2] == 0)
    mask[black] = 0
    return mask

def stitch_images(inputs, *outputs, img_per_row=2):
    """Tile input images and their matching outputs into one PIL RGB image.

    Each sample ``ix`` becomes a strip of ``len(outputs) + 1`` panels: the
    input first, then each output in order. ``img_per_row`` strips share a row,
    separated horizontally by a 5-pixel gap. Images are indexed HWC
    (``inputs[0][:, :, 0]`` gives height x width) and are cast to uint8
    directly, so they presumably already hold 0-255 values — TODO confirm.

    Returns:
        PIL.Image.Image: the stitched canvas.
    """
    gap = 5
    columns = len(outputs) + 1

    height, width = inputs[0][:, :, 0].shape
    # Round up so a final, partially filled row is not cut off the canvas.
    rows = (len(inputs) + img_per_row - 1) // img_per_row
    canvas_w = width * img_per_row * columns + gap * (img_per_row - 1)
    img = Image.new('RGB', (canvas_w, height * rows))
    images = [inputs, *outputs]

    for ix in range(len(inputs)):
        col = ix % img_per_row
        xoffset = col * width * columns + col * gap
        yoffset = (ix // img_per_row) * height

        for cat, group in enumerate(images):
            im = np.array(group[ix].cpu()).astype(np.uint8).squeeze()
            img.paste(Image.fromarray(im), (xoffset + cat * width, yoffset))

    return img

# def random_crop(npdata, crop_size, datatype):
    
#     height, width = npdata.shape[0:2]
#     mask = np.ones((height, width))

#     if datatype == 1:
#         h = random.randint(0, height - crop_size)
#         w = random.randint(0, width - crop_size)
#         mask[h: h+crop_size, w: w+crop_size] = 0
#         crop_image = npdata[h: h+crop_size, w: w+crop_size]
    
#     if datatype == 2:
#         h = 0
#         w = random.randint(0, width - crop_size)
#         mask[:, w: w+crop_size] = 0
#         crop_image = npdata[:, w: w+crop_size] 
#     return crop_image, (w, h), mask

    
def gauss_kernel(size=21, sigma=3):
    """Build a normalized 2-D Gaussian-like kernel shaped (size, size, 1, 1), float32.

    1-D weights are normal-CDF differences over ``size`` equal bins spanning
    roughly [-sigma, sigma]. The 2-D kernel is the square root of their outer
    product, normalized to sum to 1.
    """
    step = (2 * sigma + 1.0) / size
    edges = np.linspace(-sigma - step / 2, sigma + step / 2, size + 1)
    weights_1d = np.diff(st.norm.cdf(edges))
    raw = np.sqrt(np.outer(weights_1d, weights_1d))
    normalized = raw / raw.sum()
    return np.array(normalized, dtype=np.float32).reshape((size, size, 1, 1))

def same_padding(images, ksizes, strides, rates):
    """Zero-pad an NCHW tensor so a strided, dilated window yields 'SAME' output size.

    For each spatial dim the output size is ``ceil(size / stride)``. The total
    padding is split as evenly as possible, with the extra pixel on the
    bottom/right.
    """
    assert len(images.size()) == 4
    spatial = images.size()[2:]

    pads = []
    for axis in (0, 1):
        size, stride = spatial[axis], strides[axis]
        out_size = (size + stride - 1) // stride
        span = (ksizes[axis] - 1) * rates[axis] + 1
        pads.append(max(0, (out_size - 1) * stride + span - size))

    pad_rows, pad_cols = pads
    top, left = pad_rows // 2, pad_cols // 2
    pad = torch.nn.ZeroPad2d((left, pad_cols - left, top, pad_rows - top))
    return pad(images)

def random_crop(npdata, crop_size, datatype, count, pos, known_mask=None):
    """Cut a crop out of *npdata* and return it with its position and hole mask.

    Args:
        npdata: array whose first two axes are (height, width).
        crop_size: side length (datatype 1) or width (datatype 2) of the crop.
        datatype: 1 for a square crop, 2 for a full-height vertical band.
        count: for datatype 1, a fresh random position is drawn only when
            ``count == 0`` and no known mask is given; otherwise ``pos`` is reused.
        pos: (w, h) top-left position reused for datatype 1.
        known_mask: optional mask. Compared with ``is None`` / ``is False``,
            because truth-testing a NumPy array (``not known_mask``) raises.

    Returns:
        (crop_image, (w, h), mask): ``mask`` is an (height, width) float array
        of ones with zeros over the cropped region.

    Raises:
        ValueError: if ``datatype`` is not 1 or 2.
    """
    height, width = npdata.shape[0:2]
    mask = np.ones((height, width))

    if datatype == 1:
        no_known_mask = known_mask is None or known_mask is False
        if count == 0 and no_known_mask:
            h = random.randint(0, height - crop_size)
            w = random.randint(0, width - crop_size)
        else:
            w, h = pos[0], pos[1]
        mask[h: h+crop_size, w: w+crop_size] = 0
        crop_image = npdata[h: h+crop_size, w: w+crop_size]
    elif datatype == 2:
        h = 0
        w = random.randint(0, width - crop_size)
        mask[:, w: w+crop_size] = 0
        crop_image = npdata[:, w: w+crop_size]
    else:
        raise ValueError("Unsupported datatype: {} (expected 1 or 2)".format(datatype))
    return crop_image, (w, h), mask

def center_crop(npdata, crop_size):
    """Cut a centered ``crop_size`` x ``crop_size`` square out of *npdata*.

    The offset used to be hard-coded to 64, which is centered only for
    256-pixel inputs with ``crop_size=128``. That case is unchanged.

    Returns:
        (crop_image, (w, h), mask): ``mask`` is an (height, width) float array
        of ones with zeros over the cropped square.
    """
    height, width = npdata.shape[0:2]
    mask = np.ones((height, width))
    w = (width - crop_size) // 2
    h = (height - crop_size) // 2
    mask[h: h+crop_size, w: w+crop_size] = 0

    crop_image = npdata[h: h+crop_size, w: w+crop_size]
    return crop_image, (w, h), mask
    
def side_crop(data, crop_size):
    """Keep a centered vertical band of width ``crop_size``; zero both side strips.

    Returns:
        ((w, 0), mask): ``w`` is the band's left column, and ``mask`` is an
        (height, width) float array of ones inside the band and zeros outside.
    """
    height, width = data.shape[:2]
    left = (width - crop_size) // 2
    right = left + crop_size

    mask = np.ones((height, width))
    mask[:, :left] = 0.
    mask[:, right:] = 0.
    return (left, 0), mask

def imshow(img, title=''):
    """Show *img* in the current pyplot figure with axes hidden and no interpolation.

    ``FigureCanvasBase.set_window_title`` was removed in Matplotlib 3.6; the
    title now goes through the figure manager, which can be None on some
    non-interactive backends.
    """
    fig = plt.gcf()
    manager = getattr(fig.canvas, 'manager', None)
    if manager is not None:
        manager.set_window_title(title)
    plt.axis('off')
    plt.imshow(img, interpolation='none')
    plt.show()


def imsave(img, path):
    """Write tensor *img* to *path* as an image, after a uint8 cast and squeeze."""
    pixels = img.cpu().numpy().astype(np.uint8).squeeze()
    Image.fromarray(pixels).save(path)

def savetxt(arr, path):
    """Write tensor *arr* (squeezed, on CPU) to the text file *path* with 2 decimals."""
    values = arr.cpu().numpy().squeeze()
    np.savetxt(path, values, fmt='%.2f')
    
def template_match(target, source):
    """Locate each ``target[i]`` inside ``source[i]`` with OpenCV template matching.

    Each pair is permuted CHW -> HWC, converted RGB -> gray and matched with
    ``cv2.TM_CCOEFF``; the location of the maximum response is used. Assumes
    both tensors are (N, 3, H, W) with values in [0, 1], since the source is
    scaled by 255 for drawing — TODO confirm.

    Returns:
        (annotated, locs): a float32 tensor stacking each source image (HWC,
        0-255) with the match rectangle drawn, and the list of top-left (x, y)
        locations.
    """
    locs = []
    annotated = []
    for i in range(target.shape[0]):
        src = source[i].detach().cpu().permute(1, 2, 0).numpy()
        tar = target[i].detach().cpu().permute(1, 2, 0).numpy()

        src_gray = cv2.cvtColor(src, cv2.COLOR_RGB2GRAY)
        tar_gray = cv2.cvtColor(tar, cv2.COLOR_RGB2GRAY)
        h, w = tar_gray.shape

        res = cv2.matchTemplate(src_gray, tar_gray, cv2.TM_CCOEFF)
        _, _, _, loc = cv2.minMaxLoc(res)
        locs.append(loc)

        im = Image.fromarray((src * 255).astype(np.uint8).squeeze())
        ImageDraw.Draw(im).rectangle([loc, (loc[0] + w, loc[1] + h)], outline=0)
        annotated.append(np.array(im))

    if not annotated:
        return torch.Tensor(annotated), locs
    # Stack first: torch.Tensor(list_of_ndarrays) takes a very slow
    # element-wise path and emits a UserWarning.
    return torch.from_numpy(np.stack(annotated)).float(), locs


def make_mask(data, pdata, pos, device):
    """Paste square patches into blank canvases shaped like *data*.

    Args:
        data: 4-D tensor whose shape the outputs copy.
        pdata: per-sample patches; ``pdata.shape[3]`` is the patch side length.
        pos: sequence of (w, h) top-left corners, one per sample.
        device: target device for the outputs.

    Returns:
        (mask_with_pdata, mask_with_ones): zeros holding the pasted patches,
        and ones with zeros over the patch areas.
    """
    side = pdata.shape[3]
    with_patches = torch.zeros(data.shape).to(device)
    with_holes = torch.ones(data.shape).to(device)

    for idx, corner in enumerate(pos):
        w, h = corner[0], corner[1]
        rows = slice(h, h + side)
        cols = slice(w, w + side)
        with_patches[idx, :, rows, cols] = pdata[idx]
        with_holes[idx, :, rows, cols] = 0

    return with_patches, with_holes
    

class Progbar(object):
    """Displays a progress bar.

    Arguments:
        target: Total number of steps expected, None if unknown.
        width: Progress bar width on screen.
        verbose: Verbosity mode, 0 (silent), 1 (verbose), 2 (semi-verbose)
        stateful_metrics: Iterable of string names of metrics that
            should *not* be averaged over time. Metrics in this list
            will be displayed as-is. All others will be averaged
            by the progbar before display.
        interval: Minimum visual progress update interval (in seconds).
    """

    def __init__(self, target, width=25, verbose=1, interval=0.05,
                 stateful_metrics=None):
        self.target = target
        self.width = width
        self.verbose = verbose
        self.interval = interval
        if stateful_metrics:
            self.stateful_metrics = set(stateful_metrics)
        else:
            self.stateful_metrics = set()

        # Redraw in place (backspace + carriage return) when output looks like
        # a terminal or notebook; otherwise write one line per update.
        self._dynamic_display = ((hasattr(sys.stdout, 'isatty') and
                                  sys.stdout.isatty()) or
                                 'ipykernel' in sys.modules or
                                 'posix' in sys.modules)
        self._total_width = 0
        self._seen_so_far = 0
        # Averaged metrics map to [weighted_sum, step_count]; stateful metrics
        # map to their latest raw value. We use a dict + list to avoid garbage
        # collection issues found in OrderedDict.
        self._values = {}
        self._values_order = []
        self._start = time.time()
        self._last_update = 0

    def _format_metric(self, k):
        """Return the ' <value>' suffix displayed for metric *k*.

        Averaged metrics show their per-step mean (fixed-point when
        |mean| > 1e-3, scientific otherwise). Stateful metrics are raw values,
        not [sum, count] lists, and are shown with ``%s``.
        """
        value = self._values[k]
        if not isinstance(value, list):
            return ' %s' % value
        avg = np.mean(value[0] / max(1, value[1]))
        if abs(avg) > 1e-3:
            return ' %.4f' % avg
        return ' %.4e' % avg

    def update(self, current, values=None):
        """Updates the progress bar.

        Arguments:
            current: Index of current step.
            values: List of tuples:
                `(name, value_for_last_step)`.
                If `name` is in `stateful_metrics`,
                `value_for_last_step` will be displayed as-is.
                Else, an average of the metric over time will be displayed.
        """
        values = values or []
        for k, v in values:
            if k not in self._values_order:
                self._values_order.append(k)
            if k not in self.stateful_metrics:
                # Weight by the steps since the last update so the displayed
                # value is a per-step average.
                if k not in self._values:
                    self._values[k] = [v * (current - self._seen_so_far),
                                       current - self._seen_so_far]
                else:
                    self._values[k][0] += v * (current - self._seen_so_far)
                    self._values[k][1] += (current - self._seen_so_far)
            else:
                self._values[k] = v
        self._seen_so_far = current

        now = time.time()
        info = ' - %.0fs' % (now - self._start)
        if self.verbose == 1:
            # Throttle redraws to one per `interval` seconds, except the last step.
            if (now - self._last_update < self.interval and
                    self.target is not None and current < self.target):
                return

            prev_total_width = self._total_width
            if self._dynamic_display:
                sys.stdout.write('\b' * prev_total_width)
                sys.stdout.write('\r')
            else:
                sys.stdout.write('\n')

            if self.target is not None:
                numdigits = int(np.floor(np.log10(self.target))) + 1
                barstr = '%%%dd/%d [' % (numdigits, self.target)
                bar = barstr % current
                prog = float(current) / self.target
                prog_width = int(self.width * prog)
                if prog_width > 0:
                    bar += ('=' * (prog_width - 1))
                    if current < self.target:
                        bar += '>'
                    else:
                        bar += '='
                bar += ('.' * (self.width - prog_width))
                bar += ']'
            else:
                bar = '%7d/Unknown' % current

            self._total_width = len(bar)
            sys.stdout.write(bar)

            if current:
                time_per_unit = (now - self._start) / current
            else:
                time_per_unit = 0
            if self.target is not None and current < self.target:
                eta = time_per_unit * (self.target - current)
                if eta > 3600:
                    eta_format = '%d:%02d:%02d' % (eta // 3600,
                                                   (eta % 3600) // 60,
                                                   eta % 60)
                elif eta > 60:
                    eta_format = '%d:%02d' % (eta // 60, eta % 60)
                else:
                    eta_format = '%ds' % eta

                info = ' - ETA: %s' % eta_format
            else:
                if time_per_unit >= 1:
                    info += ' %.0fs/step' % time_per_unit
                elif time_per_unit >= 1e-3:
                    info += ' %.0fms/step' % (time_per_unit * 1e3)
                else:
                    info += ' %.0fus/step' % (time_per_unit * 1e6)

            for k in self._values_order:
                info += ' - %s:' % k
                info += self._format_metric(k)

            self._total_width += len(info)
            # Pad with spaces to erase leftovers of a longer previous line.
            if prev_total_width > self._total_width:
                info += (' ' * (prev_total_width - self._total_width))

            if self.target is not None and current >= self.target:
                info += '\n'

            sys.stdout.write(info)
            sys.stdout.flush()

        elif self.verbose == 2:
            # Semi-verbose: a single summary line once the target is reached.
            if self.target is None or current >= self.target:
                for k in self._values_order:
                    info += ' - %s:' % k
                    # Shared formatter: stateful metrics are raw values, so
                    # indexing them as [sum, count] lists raised TypeError.
                    info += self._format_metric(k)
                info += '\n'

                sys.stdout.write(info)
                sys.stdout.flush()

        self._last_update = now

    def add(self, n, values=None):
        """Advance the bar by *n* steps relative to the last seen step."""
        self.update(self._seen_so_far + n, values)

import math
from torch.optim.optimizer import Optimizer

class Adam16(Optimizer):
  """Adam that keeps an fp32 master copy of every parameter.

  Moments (``exp_avg``, ``exp_avg_sq``) and the master weights (``fp32_p``)
  live in fp32. After each step the master copy is written back into the
  parameter in the parameter's own dtype, so fp16 parameters stay fp16.
  """

  def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0.0):
    defaults = dict(lr=lr, betas=betas, eps=eps,
            weight_decay=weight_decay)
    params = list(params)
    super(Adam16, self).__init__(params, defaults)

  # Safety modification: Optimizer.load_state_dict may cast floating-point
  # state to each parameter's dtype, so cast it back to fp32 here.
  def load_state_dict(self, state_dict):
    super(Adam16, self).load_state_dict(state_dict)
    for group in self.param_groups:
      for p in group['params']:
        state = self.state[p]
        # Parameters that never took a step have no state to restore
        # (previously a KeyError).
        if not state:
          continue
        state['exp_avg'] = state['exp_avg'].float()
        state['exp_avg_sq'] = state['exp_avg_sq'].float()
        state['fp32_p'] = state['fp32_p'].float()

  def step(self, closure=None):
    """Performs a single optimization step.

    Arguments:
      closure (callable, optional): A closure that reevaluates the model
        and returns the loss.

    Returns:
      The closure's loss, or None when no closure is given.
    """
    loss = None
    if closure is not None:
      loss = closure()

    for group in self.param_groups:
      beta1, beta2 = group['betas']
      for p in group['params']:
        if p.grad is None:
          continue

        grad = p.grad.data.float()
        state = self.state[p]

        # State initialization
        if len(state) == 0:
          state['step'] = 0
          # Exponential moving averages of gradient values / squared values
          state['exp_avg'] = torch.zeros_like(grad)
          state['exp_avg_sq'] = torch.zeros_like(grad)
          # Fp32 copy of the weights
          state['fp32_p'] = p.data.float()

        exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
        fp32_p = state['fp32_p']

        state['step'] += 1

        if group['weight_decay'] != 0:
          grad = grad.add(fp32_p, alpha=group['weight_decay'])

        # Decay the first and second moment running average coefficient.
        # Keyword alpha=/value= forms replace the deprecated scalar-first overloads.
        exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
        exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
        denom = exp_avg_sq.sqrt().add_(group['eps'])

        bias_correction1 = 1 - beta1 ** state['step']
        bias_correction2 = 1 - beta2 ** state['step']
        step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1

        fp32_p.addcdiv_(exp_avg, denom, value=-step_size)
        # Write back in the parameter's own dtype. The previous
        # `p.data = fp32_p.float()` silently turned fp16 params into fp32.
        p.data.copy_(fp32_p)

    return loss


def cus_sample(feat, **kwargs):
    """Bilinearly resample a 4-D feature map (align_corners=False).

    :param feat: input feature map
    :param kwargs: exactly one of ``size`` or ``scale_factor``
    """
    keys = list(kwargs.keys())
    assert len(keys) == 1 and keys[0] in ["size", "scale_factor"]
    return FF.interpolate(feat, mode="bilinear", align_corners=False, **kwargs)


def upsample_add(*xs):
    """Add every earlier map, bilinearly resized, onto the last map's spatial size.

    Returns ``xs[-1] + sum(resize(x) for x in xs[:-1])``.
    """
    *others, result = xs
    for feat in others:
        resized = FF.interpolate(feat, size=result.size()[2:], mode="bilinear", align_corners=False)
        result = result + resized
    return result


def extract_image_patches(images, ksizes, strides, rates, padding='same'):
    """
    Extract patches from images and put them in the C output dimension.
    :param images: [batch, channels, in_rows, in_cols]. A 4-D Tensor.
    :param ksizes: [ksize_rows, ksize_cols]. The size of the sliding window for
     each dimension of images
    :param strides: [stride_rows, stride_cols]
    :param rates: [dilation_rows, dilation_cols]
    :param padding: 'same' (zero-pad via same_padding first) or 'valid' (no padding)
    :return: A Tensor of shape [N, C*k*k, L], L being the number of window positions
    """
    assert len(images.size()) == 4
    assert padding in ['same', 'valid']

    if padding == 'valid':
        padded = images
    elif padding == 'same':
        padded = same_padding(images, ksizes, strides, rates)
    else:
        raise NotImplementedError('Unsupported padding type: {}.\
                Only "same" or "valid" are supported.'.format(padding))

    extractor = torch.nn.Unfold(kernel_size=ksizes, dilation=rates, padding=0, stride=strides)
    return extractor(padded)



def PositionEmbeddingSine(opt):
    """2-D sine/cosine positional encoding for the downsampled feature map.

    Reads ``opt.crop_size``, ``opt.n_downsample`` and ``opt.ngf``. The map is
    ``crop_size // 2**n_downsample`` pixels per side. Positions are 1-based
    row/column indices; the y encoding fills the first half of the channels and
    the x encoding the second, each interleaving sin (even) and cos (odd)
    channels.

    Returns:
        Tensor of shape (2 * num_pos_feats, side, side), where
        ``num_pos_feats = ngf * 2**n_downsample // 2``.
    """
    scale = 2 ** opt.n_downsample
    side = opt.crop_size // scale
    half_channels = opt.ngf * scale // 2

    grid = torch.ones((side, side))
    y_pos = grid.cumsum(0, dtype=torch.float32)
    x_pos = grid.cumsum(1, dtype=torch.float32)

    idx = torch.arange(half_channels, dtype=torch.float32)
    freq = 10000 ** (2 * (idx // 2) / half_channels)

    def _encode(coords):
        # Interleave sin of even channels with cos of odd channels.
        angles = coords[:, :, None] / freq
        return torch.stack((angles[:, :, 0::2].sin(), angles[:, :, 1::2].cos()), dim=3).flatten(2)

    return torch.cat((_encode(y_pos), _encode(x_pos)), dim=2).permute(2, 0, 1)


def PatchPositionEmbeddingSine(ksize, stride, img_size=256, dim=256):
    """2-D sine/cosine positional encoding over a grid of image patches.

    Args:
        ksize: patch side length.
        stride: patch stride.
        img_size: side of the square image being tiled. Defaults to 256, the
            previously hard-coded value.
        dim: total number of encoding channels; the first half encodes y and
            the second half x. Defaults to 256, the previously hard-coded value.

    The grid has ``int((img_size - ksize) / stride) + 1`` patches per side,
    with positions as 1-based row/column indices.

    Returns:
        Tensor of shape (dim, feature_h, feature_h).
    """
    temperature = 10000
    feature_h = int((img_size - ksize) / stride) + 1
    num_pos_feats = dim // 2
    mask = torch.ones((feature_h, feature_h))
    y_embed = mask.cumsum(0, dtype=torch.float32)
    x_embed = mask.cumsum(1, dtype=torch.float32)

    dim_t = torch.arange(num_pos_feats, dtype=torch.float32)
    dim_t = temperature ** (2 * (dim_t // 2) / num_pos_feats)

    pos_x = x_embed[:, :, None] / dim_t
    pos_y = y_embed[:, :, None] / dim_t
    # Interleave sin of even channels with cos of odd channels.
    pos_x = torch.stack((pos_x[:, :, 0::2].sin(), pos_x[:, :, 1::2].cos()), dim=3).flatten(2)
    pos_y = torch.stack((pos_y[:, :, 0::2].sin(), pos_y[:, :, 1::2].cos()), dim=3).flatten(2)
    pos = torch.cat((pos_y, pos_x), dim=2).permute(2, 0, 1)
    return pos

In [8]:
import copy
from typing import Optional, List

import torch
import torch.nn.functional as F
from torch import nn, Tensor


class TransformerEncoders(nn.Module):
    """Stack of ``num_encoder_layers`` cloned ``TransformerEncoderLayer`` blocks.

    A final LayerNorm is applied only in pre-norm mode (``normalize_before``).
    Every parameter with more than one dimension is Xavier-uniform initialized.
    ``return_intermediate_dec`` is accepted but unused.
    """

    def __init__(self, d_model=512, nhead=8, num_encoder_layers=6, dim_feedforward=2048, dropout=0.0,
                 activation="relu", normalize_before=False,
                 return_intermediate_dec=False, withCIA=False):
        super().__init__()

        layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward, dropout,
                                        activation, normalize_before, withCIA)
        final_norm = None
        if normalize_before:
            final_norm = nn.LayerNorm(d_model)
        self.encoder = TransformerEncoder(layer, num_encoder_layers, final_norm)

        self._reset_parameters()

        self.d_model = d_model
        self.nhead = nhead

    def _reset_parameters(self):
        matrices = (p for p in self.parameters() if p.dim() > 1)
        for weight in matrices:
            nn.init.xavier_uniform_(weight)

    def forward(self, src, src_key_padding_mask=None, src_pos=None):
        return self.encoder(src, pos=src_pos, src_key_padding_mask=src_key_padding_mask)

class TransformerEncoder(nn.Module):
    """Apply ``num_layers`` clones of ``encoder_layer`` in sequence, then optional ``norm``.

    ``withCIA`` is accepted for signature compatibility and ignored.
    """

    def __init__(self, encoder_layer, num_layers, norm=None, withCIA=None):
        super().__init__()
        self.layers = _get_clones(encoder_layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm

    def forward(self, src,
                mask: Optional[Tensor] = None,
                src_key_padding_mask: Optional[Tensor] = None,
                pos: Optional[Tensor] = None):
        hidden = src
        for block in self.layers:
            hidden = block(hidden, src_mask=mask,
                           src_key_padding_mask=src_key_padding_mask, pos=pos)
        return hidden if self.norm is None else self.norm(hidden)



class TransformerEncoderLayer(nn.Module):
    """Encoder layer: global attention (nn.MultiheadAttention) then a local token mixer.

    Args:
        d_model: embedding size.
        nhead: number of attention heads.
        dim_feedforward, activation: accepted for API compatibility but unused;
            the feed-forward sub-layer is replaced by ``Local_Token_Mixer``.
        dropout: dropout inside attention and on both residual branches.
        normalize_before: use pre-norm (``forward_pre``) instead of post-norm.
        withCIA: reweight the attention output with ``CIA`` before the residual add.
    """

    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
                 activation="relu", normalize_before=False, withCIA=False):
        super().__init__()
        self.global_token_mixer = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        # Replaces the usual feed-forward sub-layer.
        self.local_token_mixer = Local_Token_Mixer(dim=d_model)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.accelerator = CIA()

        self.normalize_before = normalize_before
        self.withCIA = withCIA

    def with_pos_embed(self, tensor, pos: Optional[Tensor]):
        return tensor if pos is None else tensor + pos

    def forward_post(self,
                     src,
                     src_mask: Optional[Tensor] = None,
                     src_key_padding_mask: Optional[Tensor] = None,
                     pos: Optional[Tensor] = None):
        q = k = self.with_pos_embed(src, pos)
        src2 = self.global_token_mixer(q, k, value=src, attn_mask=src_mask,
                                       key_padding_mask=src_key_padding_mask)[0]
        if self.withCIA:
            src2 = self.accelerator(src, src2)
        src = src + self.dropout1(src2)
        src = self.norm1(src)
        src2 = self.local_token_mixer(src)
        src = src + self.dropout2(src2)
        src = self.norm2(src)
        return src

    def forward_pre(self, src,
                    src_mask: Optional[Tensor] = None,
                    src_key_padding_mask: Optional[Tensor] = None,
                    pos: Optional[Tensor] = None):
        # Pre-norm counterpart of forward_post, built from the sub-modules this
        # layer actually defines (self_attn/linear1/linear2/activation/dropout
        # do not exist here, so normalize_before=True used to raise AttributeError).
        src2 = self.norm1(src)
        q = k = self.with_pos_embed(src2, pos)
        attn = self.global_token_mixer(q, k, value=src2, attn_mask=src_mask,
                                       key_padding_mask=src_key_padding_mask)[0]
        if self.withCIA:
            attn = self.accelerator(src2, attn)
        src = src + self.dropout1(attn)
        src2 = self.norm2(src)
        src = src + self.dropout2(self.local_token_mixer(src2))
        return src

    def forward(self, src,
                src_mask: Optional[Tensor] = None,
                src_key_padding_mask: Optional[Tensor] = None,
                pos: Optional[Tensor] = None):
        if self.normalize_before:
            return self.forward_pre(src, src_mask, src_key_padding_mask, pos)
        return self.forward_post(src, src_mask, src_key_padding_mask, pos)


class CIA(nn.Module):
    """Cosine-similarity gate for the attention output.

    For every token, the cosine similarity between its pre-attention feature
    (``x_pre``) and the attention output (``x_pos``) is passed through a
    sigmoid, and ``x_pos`` is scaled by ``1 - sigmoid(similarity)``. Outputs
    that closely match their input are therefore damped.
    """

    def __init__(self):
        super().__init__()
        self.sigmoid = nn.Sigmoid()

    def forward(self, x_pre, x_pos):
        """Gate ``x_pos`` token-wise.

        Args:
            x_pre: features before global mixing, seq-first ``(N, B, C)``.
            x_pos: global-mixer output, same shape as ``x_pre``.

        Returns:
            ``x_pos * (1 - sigmoid(cos(x_pre, x_pos)))``, shape ``(N, B, C)``.
        """
        fea_pred = F.normalize(x_pre, dim=2)
        fea_later = F.normalize(x_pos, dim=2)
        # Per-token dot product of the unit vectors. This is exactly the
        # diagonal of a (B, N, N) bmm, computed without materializing N x N,
        # and it stays in (N, B, 1) layout so each gate lines up with its own
        # token and batch element.
        dis = (fea_pred * fea_later).sum(dim=2, keepdim=True)
        weight = 1 - self.sigmoid(dis)
        return x_pos * weight
 

class Local_Token_Mixer(nn.Module):
    """Local token mixing stage: a thin wrapper around one
    ``Cross_correlation_Token_Interaction`` over ``dim`` channels."""

    def __init__(self, dim):
        super().__init__()
        self.CTI = Cross_correlation_Token_Interaction(dim=dim)

    def forward(self, x):
        """Return ``self.CTI(x)``."""
        return self.CTI(x)


class Cross_correlation_Token_Interaction(nn.Module):
    """Blend every token with its most similar peer inside local windows.

    ``forward`` takes seq-first tokens ``(N, B, C)`` and splits the ``N`` tokens
    into windows of ``local_size``. Inside each window, every token is paired
    with the *other* token of highest cosine similarity, and the two are mixed
    with softmax weights from ``self.linear``. The windows are then stitched
    back into ``(B, C, N)`` and passed through ``self.depthwise``.

    Args:
        dim: channel dimension ``C``.
        local_size: window length (default 1024). The window is clipped to
            the sequence length, and ``N`` must be divisible by it.
    """

    def __init__(self, dim, local_size=1024):
        super().__init__()

        self.dim = dim
        self.local_size = local_size
        self.softmax = nn.Softmax(dim=-1)
        self.linear = nn.Linear(dim, 2, bias=False)
        self.depthwise = Depthwise_Conv(dim)

    def interaction(self, x):
        """Mix each token of ``x`` (shape ``(B, C, N)``) with its nearest peer.

        Returns a tensor of the same shape:
        ``w0 * x + w1 * partner``, where ``partner[:, :, j]`` is the token with
        maximal cosine similarity to token ``j`` (self excluded), and
        ``(w0, w1)`` is a per-token softmax of ``linear(x + partner)``.
        """
        channels = x.shape[1]
        # The similarity matrix only feeds argmax, which has no gradient, so
        # build it outside autograd. This avoids keeping the (B, N, N)
        # intermediates alive for backward.
        with torch.no_grad():
            dots = torch.bmm(x.permute(0, 2, 1), x)
            norms = torch.sqrt(torch.sum(x.pow(2) + 1e-6, dim=1, keepdim=True))
            sim = dots / torch.bmm(norms.permute(0, 2, 1), norms)
            # Exclude self-matches by zeroing the diagonal.
            torch.diagonal(sim, dim1=1, dim2=2).zero_()
            partner_idx = sim.argmax(dim=2)  # (B, N)

        # corr_seq[b, :, j] = x[b, :, partner_idx[b, j]]. Gradients flow into x
        # through gather, and the result is created on x's device/dtype.
        index = partner_idx.unsqueeze(1).expand(-1, channels, -1)
        corr_seq = torch.gather(x, 2, index)

        fus = (x + corr_seq).permute(0, 2, 1)
        weight = self.softmax(self.linear(fus)).permute(0, 2, 1)
        return x * weight[:, 0:1, :] + corr_seq * weight[:, 1:2, :]

    def token_partition(self, x, local_size):
        """Split seq-first ``(N, B, C)`` into windows ``((N // local_size) * B, C, local_size)``.

        Row ``w * B + b`` of the result holds tokens
        ``w * local_size ... (w + 1) * local_size - 1`` of batch element ``b``.
        """
        [N, B, C] = x.shape
        x = x.view(N // local_size, local_size, B, C)
        local = x.permute(1, 0, 2, 3).reshape(local_size, -1, C).permute(1, 2, 0)
        return local

    def token_reverse(self, x, local_size, token_size):
        """Inverse of ``token_partition``: windows back to ``(B, C, token_size)``.

        The window axis of ``x`` is ordered window-major (``w * B + b``), so it
        is split as ``(windows, B)`` and the windows are moved next to the
        in-window position before flattening. This keeps every token in its
        own batch element, channel and sequence position.
        """
        num_windows = token_size // local_size
        batch = x.shape[0] // num_windows
        channels = x.shape[1]
        x = x.reshape(num_windows, batch, channels, local_size)
        return x.permute(1, 2, 0, 3).reshape(batch, channels, token_size)

    def forward(self, x):
        """Apply windowed interaction and the depthwise stage to ``(N, B, C)`` tokens."""
        token_size = x.shape[0]
        local_size = min(self.local_size, token_size)
        local_seq = self.token_partition(x, local_size=local_size)
        intered_seq = self.interaction(local_seq)
        s_intered_seq = self.token_reverse(intered_seq, local_size=local_size, token_size=token_size)
        return self.depthwise(s_intered_seq)


class Depthwise_Conv(nn.Module):
    """Depthwise spatial convolution followed by LayerNorm and a pointwise MLP.

    ``(B, C, N)`` tokens are laid out row-major on a square
    ``sqrt(N) x sqrt(N)`` grid and filtered by a per-channel
    ``kernel_size`` convolution (``groups=dim``). They are then returned
    token-first as ``(N, B, C)`` after
    LayerNorm -> Linear(C, 2C) -> GELU -> Linear(2C, C).

    Args:
        dim: channel count ``C``.
        bias: whether the conv and linear layers use a bias.
        kernel_size, padding: depthwise conv geometry. The defaults (7, 3)
            preserve the grid size.
    """

    def __init__(self, dim, bias=False, kernel_size=7, padding=3):
        super().__init__()
        med_channels = int(dim * 2)
        self.dwconv = nn.Conv2d(
                    dim, dim, kernel_size=kernel_size,
                    padding=padding, groups=dim, bias=bias)
        self.norm = nn.LayerNorm(dim, eps=1e-6)
        self.pwconv1 = nn.Linear(dim, med_channels, bias=bias)
        self.act1 = nn.GELU()
        self.pwconv2 = nn.Linear(med_channels, dim, bias=bias)

    def forward(self, x):
        """Map ``(B, C, N)`` to ``(N, B, C)``. ``N`` must be a perfect square.

        Raises:
            ValueError: if ``N`` is not a perfect square.
        """
        B, C, N = x.shape
        # Grid side derived from the token count (64 for N == 4096).
        side = int(round(N ** 0.5))
        if side * side != N:
            raise ValueError(
                f"Depthwise_Conv expects a square number of tokens, got {N}")
        # reshape (not view): the input may be a non-contiguous permutation.
        x = self.dwconv(x.reshape(B, C, side, side))
        x = x.permute(2, 3, 0, 1).reshape(N, B, C)
        x = self.norm(x)
        x = self.pwconv1(x)
        x = self.act1(x)
        return self.pwconv2(x)



def _get_clones(module, N):
    return nn.ModuleList([copy.deepcopy(module) for i in range(N)])

In [9]:
import torch
import torch.nn as nn
import functools
import torch.nn.functional as F
from collections import OrderedDict
import numpy as np
import torch.nn.utils.spectral_norm as spectral_norm
import math
from einops.layers.torch import Rearrange
from torch.nn.parameter import Parameter
from torch.nn.modules.utils import _single, _pair, _triple
#from . import transformer
#import PatchPositionEmbeddingSine

class BaseNetwork(nn.Module):
    """``nn.Module`` base class with an ``init_weights`` helper."""

    def __init__(self):
        super(BaseNetwork, self).__init__()

    def init_weights(self, init_type='xavier', gain=0.02):
        """Re-initialize Conv*/Linear and BatchNorm2d sub-modules in place.

        Sub-modules are selected by class name. Any class whose name contains
        'Conv' or 'Linear' and that has a ``weight`` attribute is treated as
        Conv/Linear.

        Conv/Linear weights:
            'normal'     -> N(0, gain)
            'xavier'     -> xavier_normal_ with ``gain``
            'kaiming'    -> kaiming_normal_ (a=0, fan_in); ``gain`` is ignored
            'orthogonal' -> orthogonal_ with ``gain``
            any other value leaves the weights untouched.
        Their biases (if any) are set to 0.

        BatchNorm2d: weight ~ N(1, gain), bias = 0. Layers built with
        ``affine=False`` have no weight/bias and are skipped.

        Args:
            init_type: one of the names above.
            gain: std / gain passed to the chosen initializer.
        """
        def init_func(m):
            classname = m.__class__.__name__
            if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
                if init_type == 'normal':
                    nn.init.normal_(m.weight.data, 0.0, gain)
                elif init_type == 'xavier':
                    nn.init.xavier_normal_(m.weight.data, gain=gain)
                elif init_type == 'kaiming':
                    nn.init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
                elif init_type == 'orthogonal':
                    nn.init.orthogonal_(m.weight.data, gain=gain)

                if hasattr(m, 'bias') and m.bias is not None:
                    nn.init.constant_(m.bias.data, 0.0)

            elif classname.find('BatchNorm2d') != -1:
                # affine=False BatchNorm has weight/bias set to None.
                if getattr(m, 'weight', None) is not None:
                    nn.init.normal_(m.weight.data, 1.0, gain)
                if getattr(m, 'bias', None) is not None:
                    nn.init.constant_(m.bias.data, 0.0)

        self.apply(init_func)

class FFA(nn.Module):#TransCNN_Plus(nn.Module):
    """Patch-embedding transformer encoder followed by a CNN decoder.

    The input is cut into 4x4 patches, each flattened (4*4*3 values, so a
    3-channel input) and linearly embedded to 256 channels. The patches are
    encoded by ``TransformerEncoders`` with a sine positional embedding and
    decoded to 3 channels by ``CNNDecoder``.

    NOTE(review): ``forward`` reshapes the token sequence to
    ``sqrt(L) x sqrt(L)``, so inputs are assumed square with sides divisible
    by 4. Confirm the shape ``PatchPositionEmbeddingSine`` returns (defined
    elsewhere).

    Args:
        DEVICE: device the positional embedding is first created on.
    """

    def __init__(self, DEVICE):
        super(FFA, self).__init__()
        dim = 256
        self.patch_to_embedding = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = 4, p2 = 4),
            nn.Linear(4*4*3, dim)
        )
        self.transformer_enc = TransformerEncoders(dim, nhead=2, num_encoder_layers=9, dim_feedforward=dim*2, activation='gelu', withCIA=True)
        self.cnn_dec = CNNDecoder(256, 3, 'ln', 'lrelu', 'reflect')

        # Built for a batch of one; the old MODE switch always selected b = 1.
        input_pos = PatchPositionEmbeddingSine(ksize=4, stride=4)
        input_pos = input_pos.unsqueeze(0).repeat(1, 1, 1, 1).to(DEVICE)
        # Non-persistent buffer: follows model.to(...) / .cuda() like a
        # parameter would, but stays out of state_dict, so checkpoints saved
        # without it still load.
        self.register_buffer('input_pos', input_pos.flatten(2).permute(2, 0, 1), persistent=False)

    def forward(self, inputs):
        """Map an image batch to a 3-channel output batch."""
        patch_embedding = self.patch_to_embedding(inputs)  # (bs, L, dim)
        bs, L, C = patch_embedding.size()
        side = int(math.sqrt(L))  # square patch grid assumed
        # The encoder is fed seq-first tokens (L, bs, dim).
        content = self.transformer_enc(patch_embedding.permute(1, 0, 2), src_pos=self.input_pos)
        content = content.permute(1, 2, 0).view(bs, C, side, side)
        return self.cnn_dec(content)


class CNNDecoder(nn.Module):
    """Decode a feature map to ``output_dim`` channels at 4x its spatial size.

    Stages:
        conv1: x2 upsample -> 3x3 Conv2dBlock, input_dim -> input_dim // 2
        conv2: x2 upsample -> 3x3 Conv2dBlock, -> (input_dim // 2) // 2
        conv3: 5x5 Conv2dBlock to ``output_dim``, no norm, tanh activation
    ``norm``/``activ``/``pad_type`` configure the first two blocks; ``pad_type``
    is also used by the last.
    """

    def __init__(self, input_dim, output_dim, norm, activ, pad_type):
        super(CNNDecoder, self).__init__()
        self.model = []  # never populated here; kept as an attribute
        block_opts = dict(norm=norm, activation=activ, pad_type=pad_type)
        half = input_dim // 2
        quarter = half // 2
        self.conv1 = nn.Sequential(
            nn.Upsample(scale_factor=2),
            Conv2dBlock(input_dim, half, 3, 1, 1, **block_opts),
        )
        self.conv2 = nn.Sequential(
            nn.Upsample(scale_factor=2),
            Conv2dBlock(half, quarter, 3, 1, 1, **block_opts),
        )
        self.conv3 = Conv2dBlock(quarter, output_dim, 5, 1, 2, norm='none', activation='tanh', pad_type=pad_type)

    def forward(self, x):
        """Apply conv1, conv2 and conv3 in sequence."""
        return self.conv3(self.conv2(self.conv1(x)))


class Conv2dBlock(nn.Module):
    """Pad -> Conv2d -> optional normalization -> optional activation.

    Args:
        input_dim, output_dim: conv in/out channels.
        kernel_size, stride: conv geometry (the conv itself uses no padding).
        padding: amount applied by the explicit padding layer.
        norm: 'bn', 'in', 'ln', 'adain', 'adain_ori', 'remove_render', 'grp',
            'sn' (spectral-normalized conv, no norm layer) or 'none'.
        activation: 'relu', 'lrelu' (slope 0.2), 'prelu', 'selu', 'tanh',
            'sigmoid' or 'none'.
        pad_type: 'reflect', 'replicate' or 'zero'.
        groupcount: number of groups for 'grp' (GroupNorm).

    Raises:
        ValueError: for an unsupported ``pad_type``, ``norm`` or ``activation``.

    NOTE(review): the 'adain', 'adain_ori', 'remove_render' and 'sn' options
    use AdaptiveInstanceNorm2d, AdaptiveInstanceNorm2d_IN, RemoveRender and
    SpectralNorm, which are not defined in this view. Verify they exist
    before selecting those options.
    """

    def __init__(self, input_dim ,output_dim, kernel_size, stride,
                 padding=0, norm='none', activation='relu', pad_type='zero', groupcount=16):
        super(Conv2dBlock, self).__init__()
        self.use_bias = True
        self.norm_type = norm

        # initialize padding
        if pad_type == 'reflect':
            self.pad = nn.ReflectionPad2d(padding)
        elif pad_type == 'replicate':
            self.pad = nn.ReplicationPad2d(padding)
        elif pad_type == 'zero':
            self.pad = nn.ZeroPad2d(padding)
        else:
            # Explicit raise: `assert` is stripped under `python -O`.
            raise ValueError("Unsupported padding type: {}".format(pad_type))

        # initialize normalization
        norm_dim = output_dim
        if norm == 'bn':
            self.norm = nn.BatchNorm2d(norm_dim)
        elif norm == 'in':
            self.norm = nn.InstanceNorm2d(norm_dim)
        elif norm == 'ln':
            self.norm = LayerNorm(norm_dim)
        elif norm == 'adain':
            self.norm = AdaptiveInstanceNorm2d(norm_dim)
        elif norm == 'adain_ori':
            self.norm = AdaptiveInstanceNorm2d_IN(norm_dim)
        elif norm == 'remove_render':
            self.norm = RemoveRender(norm_dim)
        elif norm == 'grp':
            self.norm = nn.GroupNorm(groupcount, norm_dim)
        elif norm == 'none' or norm == 'sn':
            self.norm = None
        else:
            raise ValueError("Unsupported normalization: {}".format(norm))

        # initialize activation
        if activation == 'relu':
            self.activation = nn.ReLU(inplace=True)
        elif activation == 'lrelu':
            self.activation = nn.LeakyReLU(0.2, inplace=True)
        elif activation == 'prelu':
            self.activation = nn.PReLU()
        elif activation == 'selu':
            self.activation = nn.SELU(inplace=True)
        elif activation == 'tanh':
            self.activation = nn.Tanh()
        elif activation == 'sigmoid':
            self.activation = nn.Sigmoid()
        elif activation == 'none':
            self.activation = None
        else:
            raise ValueError("Unsupported activation: {}".format(activation))

        # initialize convolution
        if norm == 'sn':
            self.conv = SpectralNorm(nn.Conv2d(input_dim, output_dim, kernel_size, stride, bias=self.use_bias))
        else:
            self.conv = nn.Conv2d(input_dim, output_dim, kernel_size, stride, bias=self.use_bias)

    def forward(self, x):
        """Pad, convolve, then apply the norm and activation that were configured."""
        x = self.conv(self.pad(x))
        if self.norm is not None:
            x = self.norm(x)
        if self.activation is not None:
            x = self.activation(x)
        return x


class LayerNorm(nn.Module):
    """Per-sample normalization over all non-batch elements, with per-channel affine.

    Each sample is normalized by the mean and (unbiased) standard deviation of
    all of its elements. ``eps`` is added to the std, not the variance. With
    ``affine``, ``gamma``/``beta`` (length ``num_features``) scale and shift
    dimension 1.

    Args:
        num_features: size of dimension 1 (channels).
        eps: added to the standard deviation for stability.
        affine: learn ``gamma`` (initialized uniform on [0, 1)) and ``beta``
            (initialized to zeros).
    """

    def __init__(self, num_features, eps=1e-5, affine=True):
        super(LayerNorm, self).__init__()
        self.num_features = num_features
        self.affine = affine
        self.eps = eps

        if self.affine:
            self.gamma = nn.Parameter(torch.Tensor(num_features).uniform_())
            self.beta = nn.Parameter(torch.zeros(num_features))

    def forward(self, x):
        """Normalize ``x`` (batch first, channels at dim 1) and return a tensor of the same shape."""
        # reshape (not view) so non-contiguous inputs work too. One code path
        # covers batch size 1 as well.
        flat = x.reshape(x.size(0), -1)
        stat_shape = [-1] + [1] * (x.dim() - 1)
        mean = flat.mean(1).view(*stat_shape)
        std = flat.std(1).view(*stat_shape)

        x = (x - mean) / (std + self.eps)

        if self.affine:
            param_shape = [1, -1] + [1] * (x.dim() - 2)
            x = x * self.gamma.view(*param_shape) + self.beta.view(*param_shape)
        return x

In [10]:
# Build the network and move all of its parameters to DEVICE.
model = FFA(DEVICE)
model = model.to(DEVICE)

In [11]:
# NOTE(review): neither object is passed to training_loop below; that call
# builds its own Adam optimizer, and training_loop creates its own MSELoss.
loss_fn = nn.MSELoss()
opt = Adam(model.parameters(), lr=INIT_LR)

## Training
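Training passes each input batch through the model and minimizes the MSE between the prediction and the target image `y`. Unlike the diffusion setup this notebook started from, no timesteps or noise are sampled. The test model was trained by overfitting on a single batch to check model functionality.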

In [12]:
def training_loop(model, loader, n_epochs, optim, device, display=False, store_path="simplenet.pt"):
    """Train ``model`` to regress the target ``y`` from the input ``x0`` with MSE.

    Args:
        model: network mapping an input batch to a prediction shaped like ``y``.
        loader: DataLoader yielding ``(x0, y)`` pairs, iterated once per epoch.
            ``len(loader.dataset)`` is used to average the epoch loss.
        n_epochs: number of passes over ``loader``.
        optim: optimizer already bound to ``model.parameters()``.
        device: device every batch is moved to.
        display: unused; kept for backward compatibility.
        store_path: file that receives ``model.state_dict()`` whenever the
            epoch loss improves. Its parent directory is created if missing.

    On even epochs (0, 2, ...), the predictions and targets of the epoch's
    last batch are written to ``modelsimple512_{epoch}.png`` and
    ``tarsimple_{epoch}.png`` in the working directory.
    """
    mse = nn.MSELoss()
    best_loss = float("inf")

    store_dir = os.path.dirname(store_path)
    if store_dir:
        os.makedirs(store_dir, exist_ok=True)

    model.train()  # the evaluation cell may have left the model in eval mode
    for epoch in tqdm(range(n_epochs), desc="Training progress", colour="#00ff00"):
        epoch_loss = 0.0
        last_pred = last_target = None
        for x0, y in tqdm(loader):
            x0 = x0.to(device)
            y = y.to(device)

            # Model prediction for this batch.
            pred = model(x0)

            # Optimize the MSE between the prediction and the target.
            loss = mse(pred, y)
            optim.zero_grad()
            loss.backward()
            optim.step()

            last_pred, last_target = pred.detach(), y
            epoch_loss += loss.item() * len(x0) / len(loader.dataset)

        # Only the last batch's images would survive per-batch overwrites, so
        # write them once per epoch instead of after every batch.
        if epoch % 2 == 0 and last_pred is not None:
            grid2 = make_grid(last_pred, normalize=True, value_range=(-1, 1), nrow=4)
            save_image(grid2, f"modelsimple512_{epoch}.png")
            grid3 = make_grid(last_target, normalize=True, value_range=(-1, 1), nrow=4)
            save_image(grid3, f"tarsimple_{epoch}.png")

        log_string = f"Loss at epoch {epoch + 1}: {epoch_loss:.3f}"

        # Keep the checkpoint with the lowest epoch loss seen so far.
        if best_loss > epoch_loss:
            best_loss = epoch_loss
            torch.save(model.state_dict(), store_path)
            log_string += " --> Best model ever (stored)"

        print(log_string)

In [13]:
def show_images(images, title=""):
    """Shows the provided images as sub-pictures in a square"""

    # Converting images to CPU numpy arrays
    if type(images) is torch.Tensor:
        images = images.detach().cpu().numpy()

    # Defining number of rows and columns
    fig = plt.figure(figsize=(8, 8))
    rows = int(len(images) ** (1 / 2))
    cols = round(len(images) / rows)

    # Populating figure with sub-plots
    idx = 0
    for r in range(rows):
        for c in range(cols):
            fig.add_subplot(rows, cols, idx + 1)

            if idx < len(images):
                plt.imshow(images[idx][0], cmap="gray")
                idx += 1
    fig.suptitle(title, fontsize=30)

    # Showing the figure
    plt.show()

In [14]:
# Training
# Run configuration. store_path is where training_loop saves the best
# state_dict; the evaluation cell below loads the same path.
store_path ='./samples/models/simplenet.pt'
device = "cuda" if torch.cuda.is_available() else "cpu"
# NOTE(review): these paths and the diffusion constants below repeat the
# values set in the configuration cell at the top of the notebook. The
# diffusion constants are not used by training_loop.
MODEL_OUT_DIR = r"./models"
SAMPLE_OUT_DIR = r"./samples"

# Diffusion
NOISE_STEPS = 200
BETA_START = 1e-4
BETA_END = 2e-2


# Training
lr = 1e-5  # same value as INIT_LR in the configuration cell

EPOCHS = 200

# Low effort reproducibility
# Seeds the torch and Python `random` RNGs (NumPy's RNG is not seeded here).
SEED = 1
torch.manual_seed(SEED)
random.seed(SEED)

# train_data_loader is created in an earlier cell that is not visible here.
loader = train_data_loader

# A fresh Adam optimizer is built for this call; the earlier `opt` is unused.
# NOTE(review): the captured output below shows this call failing with a CUDA
# out-of-memory error before the first batch finished.
training_loop(model, loader, EPOCHS, optim=Adam(model.parameters(), lr), device=device, store_path=store_path)

Training progress:   0%|[38;2;0;255;0m                                [0m| 0/200 [00:00<?, ?it/s][0m
  0%|                                                   | 0/667 [00:00<?, ?it/s][A
Training progress:   0%|[38;2;0;255;0m                                [0m| 0/200 [00:00<?, ?it/s][0m


OutOfMemoryError: CUDA out of memory. Tried to allocate 512.00 MiB. GPU 0 has a total capacity of 23.66 GiB of which 315.56 MiB is free. Process 41026 has 18.56 GiB memory in use. Including non-PyTorch memory, this process has 3.94 GiB memory in use. Of the allocated memory 3.15 GiB is allocated by PyTorch, and 561.86 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

In [None]:
from torch.autograd import Variable
from matplotlib import pyplot as plt

from matplotlib import cm as c

#model = SimpleNet(3,3,128).to(DEVICE)
#model.load_state_dict(torch.load(f"{MODEL_OUT_DIR}/diff41689934001018.pt"))
# Restore the best checkpoint written by training_loop. map_location lets a
# GPU-saved checkpoint load in a CPU-only session.
model.load_state_dict(torch.load("./samples/models/simplenet.pt", map_location=DEVICE))
# Inference mode: disable dropout, and (below) skip autograd bookkeeping that
# would hold activations for a backward pass that never happens.
model.eval()
os.makedirs(SAMPLE_OUT_DIR, exist_ok=True)

print(testing_data_loader)
pbar = tqdm(testing_data_loader)
with torch.no_grad():
    for i, (x_0, y) in enumerate(pbar, start=1):
        x_0 = x_0.to(DEVICE)

        # Model prediction for this batch.
        eta_theta = model(x_0)

        # CPU float32 copy of the target.
        target = y.type(torch.FloatTensor)

        # Save up to the first 4 samples; the final batch may be smaller.
        for im in range(min(4, eta_theta.shape[0])):
            print(im)
            grid = make_grid(eta_theta[im], normalize=True, value_range=(-1, 1), nrow=4)
            save_image(grid, f"{SAMPLE_OUT_DIR}/pre512_{im}{i // 1}.png")
            grid0 = make_grid(target[im], normalize=True, value_range=(-1, 1), nrow=4)
            save_image(grid0, f"{SAMPLE_OUT_DIR}/targ_{im}{i // 1}.png")


print(f"Training loop finished")


## Loading
Loading the saved model.

In [None]:
# model = SimpleNet(3,3,128).to(DEVICE)
# model.load_state_dict(torch.load(f"{MODEL_OUT_DIR}/1686675513310.pt"))