<a href="https://colab.research.google.com/github/pollinations/hive/blob/main/interesting_notebooks/QuantumVisions.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## QuantumVisions
Derived from DirectVisions by Jens Goldberg / [Aransentin](https://twitter.com/aransentin)

In [None]:
#@title color quantization for init images
class Color(object):
    """
    Simple RGB(A) color container.

    Every channel is truncated to an int on construction; `alpha` stays None
    unless a value is supplied.
    """

    def __init__(self, red=0, green=0, blue=0, alpha=None):
        """
        Store red/green/blue (and optionally alpha) as integers.
        """
        self.red, self.green, self.blue = (int(channel) for channel in (red, green, blue))
        self.alpha = None if alpha is None else int(alpha)
class OctreeNode(object):
    """
    Octree Node class for color quantization.

    A node accumulates the summed channels of the colors routed to it in
    `color` and the number of colors added in `pixel_count`. A node counts as
    a leaf as soon as `pixel_count > 0`. That is either a full-depth node
    filled by `add_color`, or an inner node whose children were folded into
    it by `remove_leaves`.
    """

    def __init__(self, level, parent):
        """
        Init new Octree Node.

        `parent` is the owning OctreeQuantizer, not the parent node. Unless
        `level` is MAX_DEPTH - 1 or deeper, the node registers itself in the
        quantizer's per-level list so that it can be reduced later.
        """
        self.color = Color(0, 0, 0)
        self.pixel_count = 0
        self.palette_index = 0
        self.children = [None for _ in range(8)]
        # add node to current level
        if level < OctreeQuantizer.MAX_DEPTH - 1:
            parent.add_level_node(level, self)

    def is_leaf(self):
        """
        Check that node is leaf, i.e. it has absorbed at least one color.
        """
        return self.pixel_count > 0

    def get_leaf_nodes(self):
        """
        Get all leaf nodes below this node, depth-first in child-index order.

        Recursion stops at the first leaf on each path, so the children of a
        reduced node are not returned. A new list is built on every call.
        """
        leaf_nodes = []
        for i in range(8):
            node = self.children[i]
            if node:
                if node.is_leaf():
                    leaf_nodes.append(node)
                else:
                    leaf_nodes.extend(node.get_leaf_nodes())
        return leaf_nodes

    def get_nodes_pixel_count(self):
        """
        Get the sum of this node's pixel count and its direct children's.

        Only one level of children is included; this is not recursive.
        """
        sum_count = self.pixel_count
        for i in range(8):
            node = self.children[i]
            if node:
                sum_count += node.pixel_count
        return sum_count

    def add_color(self, color, level, parent):
        """
        Add `color` to the tree.

        Descends one level per call. Bit (7 - level) of each channel picks the
        child, and missing children are created on the way. Once `level`
        reaches MAX_DEPTH, the color's channels are added to this node's sums
        and `pixel_count` is incremented. `parent` is the OctreeQuantizer that
        newly created nodes register with.
        """
        if level >= OctreeQuantizer.MAX_DEPTH:
            self.color.red += color.red
            self.color.green += color.green
            self.color.blue += color.blue
            self.pixel_count += 1
            return
        index = self.get_color_index_for_level(color, level)
        if not self.children[index]:
            # New children register under the level being inserted at.
            self.children[index] = OctreeNode(level, parent)
        self.children[index].add_color(color, level + 1, parent)

    def get_palette_index(self, color, level):
        """
        Get palette index for `color`.

        Follows the same child path as `add_color`, going one level deeper per
        call, until a leaf is reached. If the matching child does not exist,
        the first existing child is used instead. A non-leaf node with no
        children at all falls through and returns None.
        """
        if self.is_leaf():
            return self.palette_index
        index = self.get_color_index_for_level(color, level)
        if self.children[index]:
            return self.children[index].get_palette_index(color, level + 1)
        else:
            # get palette index for a first found child node
            for i in range(8):
                if self.children[i]:
                    return self.children[i].get_palette_index(color, level + 1)

    def remove_leaves(self):
        """
        Fold all children's channel sums and pixel counts into this node.

        The children are not detached. Because this node now has a positive
        pixel_count (when any child had one), it becomes a leaf and hides
        them. Returns the number of merged children minus one: the drop in the
        leaf count when the children were leaves (-1 if there are no children).
        """
        result = 0
        for i in range(8):
            node = self.children[i]
            if node:
                self.color.red += node.color.red
                self.color.green += node.color.green
                self.color.blue += node.color.blue
                self.pixel_count += node.pixel_count
                result += 1
        return result - 1

    def get_color_index_for_level(self, color, level):
        """
        Get the child index (0-7) that `color` belongs to at `level`.

        Combines bit (7 - level) of red, green and blue into a 3-bit index,
        with red as the most significant bit.
        """
        index = 0
        mask = 0x80 >> level
        if color.red & mask:
            index |= 4
        if color.green & mask:
            index |= 2
        if color.blue & mask:
            index |= 1
        return index

    def get_color(self):
        """
        Get the average color (Color truncates each channel to int).

        Raises ZeroDivisionError when pixel_count is 0.
        """
        return Color(
            self.color.red / self.pixel_count,
            self.color.green / self.pixel_count,
            self.color.blue / self.pixel_count)


class OctreeQuantizer(object):
    """
    Octree Quantizer class for image color quantization.

    Usage:
    1. Call `add_color` for every pixel.
    2. Call `make_palette(n)` to merge nodes until at most `n` leaves remain
       (one palette entry per leaf).
    3. Call `get_palette_index` to map a color to its palette entry.

    Use MAX_DEPTH to limit the number of levels.
    """

    MAX_DEPTH = 8

    def __init__(self):
        """
        Init Octree Quantizer.

        `levels[i]` lists the nodes created with level `i` (registered by
        OctreeNode.__init__). The root is the first entry of `levels[0]`.
        """
        self.levels = {i: [] for i in range(OctreeQuantizer.MAX_DEPTH)}
        self.root = OctreeNode(0, self)

    def get_leaves(self):
        """
        Get all leaves (get_leaf_nodes already builds a fresh list).
        """
        return self.root.get_leaf_nodes()

    def add_level_node(self, level, node):
        """
        Add `node` to the nodes at `level`.
        """
        self.levels[level].append(node)

    def add_color(self, color):
        """
        Add `color` to the Octree.
        """
        # passes self value as `parent` to save nodes to levels dict
        self.root.add_color(color, 0, self)

    @staticmethod
    def _nearest_palette_index(color, palette):
        """
        Return the index of the palette entry closest to `color` (squared RGB distance).
        """
        return min(
            range(len(palette)),
            key=lambda i: (palette[i].red - color.red) ** 2
            + (palette[i].green - color.green) ** 2
            + (palette[i].blue - color.blue) ** 2,
        )

    def make_palette(self, color_count):
        """
        Make color palette with `color_count` colors maximum.

        Nodes are merged level by level, deepest level first, until the
        tracked leaf count drops to `color_count` or below. Each remaining
        leaf's average color becomes a palette entry, and the leaf stores its
        index. Returns the list of Color entries.

        A leaf that does not get its own entry is assigned the index of the
        nearest palette color. Without this it would keep the default index 0.
        Such leaves can exist because the root is registered in `levels[0]`
        before its own children. When level 0 is reduced, the root can be
        merged while those children are still empty, which lowers the tracked
        leaf count without removing any leaves.
        """
        palette = []
        palette_index = 0
        leaf_count = len(self.get_leaves())
        # reduce nodes
        # up to 8 leaves can be reduced here and the palette will have
        # only 248 colors (in worst case) instead of expected 256 colors
        print("creating palette...")
        for level in range(OctreeQuantizer.MAX_DEPTH - 1, -1, -1):
            if self.levels[level]:
                for node in self.levels[level]:
                    leaf_count -= node.remove_leaves()
                    if leaf_count <= color_count:
                        break
                if leaf_count <= color_count:
                    break
                self.levels[level] = []
        # build palette
        leaves = self.get_leaves()
        for node in leaves:
            if palette_index >= color_count:
                break
            palette.append(node.get_color())
            node.palette_index = palette_index
            palette_index += 1
        # Leaves that did not fit map to their closest palette entry.
        if palette:
            for node in leaves[palette_index:]:
                node.palette_index = self._nearest_palette_index(node.get_color(), palette)
        return palette

    def get_palette_index(self, color):
        """
        Get palette index for `color` (valid after `make_palette`).
        """
        return self.root.get_palette_index(color, 0)



In [None]:
#@title Sinkhorn distance for histogram loss
! pip install Ninja
import math
import time
import torch
import torch.utils.cpp_extension
# CUDA/C++ source for one log-domain Sinkhorn half-step, compiled below with
# torch.utils.cpp_extension.load_inline. For each batch row b (blockIdx.y) and
# column j (blockIdx.x) the kernel computes
#   log_v[b, j] = log_nu[b, j] - logsumexp_i(-dist[i, j] / lambda + log_u[b, i])
# reducing over i across the block's threads with warp shuffles. It writes -inf
# where log_nu[b, j] is -inf.
# The extension exposes only this forward op; no autograd formula is registered.
# NOTE(review): everything between the triple quotes is compiled code, so edits
# inside the string change the built extension.
cuda_source = """

#include <torch/extension.h>
#include <ATen/core/TensorAccessor.h>
#include <ATen/cuda/CUDAContext.h>

using at::RestrictPtrTraits;
using at::PackedTensorAccessor;

#if defined(__HIP_PLATFORM_HCC__)
constexpr int WARP_SIZE = 64;
#else
constexpr int WARP_SIZE = 32;
#endif

// The maximum number of threads in a block
#if defined(__HIP_PLATFORM_HCC__)
constexpr int MAX_BLOCK_SIZE = 256;
#else
constexpr int MAX_BLOCK_SIZE = 512;
#endif

// Returns the index of the most significant 1 bit in `val`.
__device__ __forceinline__ int getMSB(int val) {
  return 31 - __clz(val);
}

// Number of threads in a block given an input size up to MAX_BLOCK_SIZE
static int getNumThreads(int nElem) {
#if defined(__HIP_PLATFORM_HCC__)
  int threadSizes[5] = { 16, 32, 64, 128, MAX_BLOCK_SIZE };
#else
  int threadSizes[5] = { 32, 64, 128, 256, MAX_BLOCK_SIZE };
#endif
  for (int i = 0; i != 5; ++i) {
    if (nElem <= threadSizes[i]) {
      return threadSizes[i];
    }
  }
  return MAX_BLOCK_SIZE;
}


template <typename T>
__device__ __forceinline__ T WARP_SHFL_XOR(T value, int laneMask, int width = warpSize, unsigned int mask = 0xffffffff)
{
#if CUDA_VERSION >= 9000
    return __shfl_xor_sync(mask, value, laneMask, width);
#else
    return __shfl_xor(value, laneMask, width);
#endif
}

// While this might be the most efficient sinkhorn step / logsumexp-matmul implementation I have seen,
// this is awfully inefficient compared to matrix multiplication and e.g. NVidia cutlass may provide
// many great ideas for improvement
template <typename scalar_t, typename index_t>
__global__ void sinkstep_kernel(
  // compute log v_bj = log nu_bj - logsumexp_i 1/lambda dist_ij - log u_bi
  // for this compute maxdiff_bj = max_i(1/lambda dist_ij - log u_bi)
  // i = reduction dim, using threadIdx.x
  PackedTensorAccessor<scalar_t, 2, RestrictPtrTraits, index_t> log_v,
  const PackedTensorAccessor<scalar_t, 2, RestrictPtrTraits, index_t> dist,
  const PackedTensorAccessor<scalar_t, 2, RestrictPtrTraits, index_t> log_nu,
  const PackedTensorAccessor<scalar_t, 2, RestrictPtrTraits, index_t> log_u,
  const scalar_t lambda) {

  using accscalar_t = scalar_t;

  __shared__ accscalar_t shared_mem[2 * WARP_SIZE];

  index_t b = blockIdx.y;
  index_t j = blockIdx.x;
  int tid = threadIdx.x;

  if (b >= log_u.size(0) || j >= log_v.size(1)) {
    return;
  }
  // reduce within thread
  accscalar_t max = -std::numeric_limits<accscalar_t>::infinity();
  accscalar_t sumexp = 0;
  
  if (log_nu[b][j] == -std::numeric_limits<accscalar_t>::infinity()) {
    if (tid == 0) {
      log_v[b][j] = -std::numeric_limits<accscalar_t>::infinity();
    }
    return;
  }

  for (index_t i = threadIdx.x; i < log_u.size(1); i += blockDim.x) {
    accscalar_t oldmax = max;
    accscalar_t value = -dist[i][j]/lambda + log_u[b][i];
    max = max > value ? max : value;
    if (oldmax == -std::numeric_limits<accscalar_t>::infinity()) {
      // sumexp used to be 0, so the new max is value and we can set 1 here,
      // because we will come back here again
      sumexp = 1;
    } else {
      sumexp *= exp(oldmax - max);
      sumexp += exp(value - max); // if oldmax was not -infinity, max is not either...
    }
  }

  // now we have one value per thread. we'll make it into one value per warp
  // first warpSum to get one value per thread to
  // one value per warp
  for (int i = 0; i < getMSB(WARP_SIZE); ++i) {
    accscalar_t o_max    = WARP_SHFL_XOR(max, 1 << i, WARP_SIZE);
    accscalar_t o_sumexp = WARP_SHFL_XOR(sumexp, 1 << i, WARP_SIZE);
    if (o_max > max) { // we're less concerned about divergence here
      sumexp *= exp(max - o_max);
      sumexp += o_sumexp;
      max = o_max;
    } else if (max != -std::numeric_limits<accscalar_t>::infinity()) {
      sumexp += o_sumexp * exp(o_max - max);
    }
  }
  
  __syncthreads();
  // this writes each warps accumulation into shared memory
  // there are at most WARP_SIZE items left because
  // there are at most WARP_SIZE**2 threads at the beginning
  if (tid % WARP_SIZE == 0) {
    shared_mem[tid / WARP_SIZE * 2] = max;
    shared_mem[tid / WARP_SIZE * 2 + 1] = sumexp;
  }
  __syncthreads();
  if (tid < WARP_SIZE) {
    max = (tid < blockDim.x / WARP_SIZE ? shared_mem[2 * tid] : -std::numeric_limits<accscalar_t>::infinity());
    sumexp = (tid < blockDim.x / WARP_SIZE ? shared_mem[2 * tid + 1] : 0);
  }
  for (int i = 0; i < getMSB(WARP_SIZE); ++i) {
    accscalar_t o_max    = WARP_SHFL_XOR(max, 1 << i, WARP_SIZE);
    accscalar_t o_sumexp = WARP_SHFL_XOR(sumexp, 1 << i, WARP_SIZE);
    if (o_max > max) { // we're less concerned about divergence here
      sumexp *= exp(max - o_max);
      sumexp += o_sumexp;
      max = o_max;
    } else if (max != -std::numeric_limits<accscalar_t>::infinity()) {
      sumexp += o_sumexp * exp(o_max - max);
    }
  }

  if (tid == 0) {
    log_v[b][j] = (max > -std::numeric_limits<accscalar_t>::infinity() ?
                   log_nu[b][j] - log(sumexp) - max : 
                   -std::numeric_limits<accscalar_t>::infinity());
  }
}

template <typename scalar_t>
torch::Tensor sinkstep_cuda_template(const torch::Tensor& dist, const torch::Tensor& log_nu, const torch::Tensor& log_u,
                                     const double lambda) {
  TORCH_CHECK(dist.is_cuda(), "need cuda tensors");
  TORCH_CHECK(dist.device() == log_nu.device() && dist.device() == log_u.device(), "need tensors on same GPU");
  TORCH_CHECK(dist.dim()==2 && log_nu.dim()==2 && log_u.dim()==2, "invalid sizes");
  TORCH_CHECK(dist.size(0) == log_u.size(1) &&
           dist.size(1) == log_nu.size(1) &&
           log_u.size(0) == log_nu.size(0), "invalid sizes");
  auto log_v = torch::empty_like(log_nu);
  using index_t = int32_t;
  
  auto log_v_a = log_v.packed_accessor<scalar_t, 2, RestrictPtrTraits, index_t>();
  auto dist_a = dist.packed_accessor<scalar_t, 2, RestrictPtrTraits, index_t>();
  auto log_nu_a = log_nu.packed_accessor<scalar_t, 2, RestrictPtrTraits, index_t>();
  auto log_u_a = log_u.packed_accessor<scalar_t, 2, RestrictPtrTraits, index_t>();
  
  auto stream = at::cuda::getCurrentCUDAStream();

  int tf = getNumThreads(log_u.size(1));
  dim3 blocks(log_v.size(1), log_u.size(0));
  dim3 threads(tf);
  
  sinkstep_kernel<<<blocks, threads, 2*WARP_SIZE*sizeof(scalar_t), stream>>>(
    log_v_a, dist_a, log_nu_a, log_u_a, static_cast<scalar_t>(lambda)
    );

  return log_v;
}

torch::Tensor sinkstep_cuda(const torch::Tensor& dist, const torch::Tensor& log_nu, const torch::Tensor& log_u,
                            const double lambda) {
    return AT_DISPATCH_FLOATING_TYPES(log_u.scalar_type(), "sinkstep", [&] {
       return sinkstep_cuda_template<scalar_t>(dist, log_nu, log_u, lambda);
    });
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("sinkstep", &sinkstep_cuda, "sinkhorn step");
}

"""

# Compile the extension (Ninja is installed at the top of this cell) and build it
# in the current directory. --expt-relaxed-constexpr lets the device code call
# constexpr host functions such as std::numeric_limits<...>::infinity().
wasserstein_ext = torch.utils.cpp_extension.load_inline("wasserstein", cpp_sources="", cuda_sources=cuda_source,
                                                    extra_cuda_cflags=["--expt-relaxed-constexpr"], build_directory = "."   )

def sinkstep(dist, log_nu, log_u, lam: float):
    """One log-domain Sinkhorn half-step.

    For every batch row b computes
        log_v[b] = log_nu[b] - logsumexp_i(-dist[i, :] / lam + log_u[b, i])
    where dist is (n, m), log_u is (bs, n) and log_nu is (bs, m); the result has
    log_nu's shape. CUDA tensors are sent to the compiled `wasserstein_ext`
    kernel. Other tensors use a per-row PyTorch loop, which avoids building a
    (bs, n, m) intermediate.
    """
    if dist.is_cuda:
        return wasserstein_ext.sinkstep(dist, log_nu, log_u, lam)
    assert all(t.dim() == 2 for t in (dist, log_nu, log_u))
    n_src, n_tgt = dist.shape
    assert n_src == log_u.size(1) and n_tgt == log_nu.size(1) and log_u.size(0) == log_nu.size(0)
    neg_scaled = -dist / lam
    log_v = log_nu.clone()
    for b, log_u_row in enumerate(log_u):
        log_v[b] -= torch.logsumexp(neg_scaled + log_u_row[:, None], 0)
    return log_v

class SinkhornOT(torch.autograd.Function):
    """Batched entropy-regularised optimal-transport (Sinkhorn) distance.

    forward(mu, nu, dist, lam=1e-3, N=100) takes histograms mu of shape
    (bs, d1) and nu of shape (bs, d2) and a cost matrix dist of shape
    (d1, d2). It runs N alternating log-domain Sinkhorn updates and returns
    one distance per batch row. Only mu and nu receive gradients.
    """

    @staticmethod
    def forward(ctx, mu, nu, dist, lam=1e-3, N=100):
        assert mu.dim() == 2 and nu.dim() == 2 and dist.dim() == 2
        rows, cols = dist.size()
        assert nu.size(0) == mu.size(0) and mu.size(1) == rows and nu.size(1) == cols
        log_mu, log_nu = mu.log(), nu.log()
        # Start from uniform scalings.
        log_u = torch.full_like(mu, -math.log(rows))
        log_v = torch.full_like(nu, -math.log(cols))
        dist_t = dist.t()
        for _ in range(N):
            log_v = sinkstep(dist, log_nu, log_u, lam)
            log_u = sinkstep(dist_t, log_mu, log_v, lam)

        # sinkstep is reused as a log-space reduction. Per the original author
        # this evaluates (diag(exp(log_u))*Mt*exp(-Mt/lam)*diag(exp(log_v))).sum()
        # without materialising a (bs, d1, d2) tensor.
        log_terms = -sinkstep(-dist.log() + dist / lam, -log_v, log_u, 1.0)
        distances = log_terms.logsumexp(1).exp()
        ctx.log_u, ctx.log_v = log_u, log_v
        ctx.dist = dist
        ctx.lam = lam
        return distances

    @staticmethod
    def backward(ctx, grad_out):
        """Return lam * log_u for mu and lam * log_v for nu, each scaled per row by grad_out."""
        scale = grad_out[:, None]
        return scale * ctx.log_u * ctx.lam, scale * ctx.log_v * ctx.lam, None, None, None

def get_coupling(mu, nu, dist, lam=1e-3, N=1000):
    """Return the entropic transport plan between batched histograms.

    mu is (bs, d1), nu is (bs, d2) and dist is (d1, d2). After N alternating
    Sinkhorn updates the plan exp(log_u[b, i] - dist[i, j] / lam + log_v[b, j])
    is returned with shape (bs, d1, d2).
    """
    assert mu.dim() == 2 and nu.dim() == 2 and dist.dim() == 2
    rows, cols = dist.size()
    assert nu.size(0) == mu.size(0) and mu.size(1) == rows and nu.size(1) == cols
    log_mu = mu.log()
    log_nu = nu.log()
    log_u = torch.full_like(mu, -math.log(rows))
    log_v = torch.full_like(nu, -math.log(cols))
    dist_t = dist.t()
    for _ in range(N):
        log_v = sinkstep(dist, log_nu, log_u, lam)
        log_u = sinkstep(dist_t, log_mu, log_v, lam)
    log_plan = log_v[:, None, :] - dist / lam + log_u[:, :, None]
    return log_plan.exp()

In [None]:

from numpy.core.numeric import False_
from torch._C import LongStorageBase
! nvidia-smi -L

! rm -rf images
# ! rm *.png
! mkdir images

# Input prompts. Each prompt has a "text" and a "weight".
# Positive weights pull the image towards the prompt. Negative weights push it
# away, which is useful for discouraging specific artifacts.
# lossClip divides every prompt loss by the total absolute weight, so the
# weights act as relative importances among the CLIP prompts.
texts = [
    {
        "text": "A mermaid eating sushi",
        "weight": 1.0,
    },{ # Set to 1 for very pixel-art style images, or -1 for smoother, more natural images when using scaling mode = nearest
        "text": "#pixelart",
        "weight": 1.0,
    # },{
    #     "text": "vivid CryEngine watercolor and pencil sketch, 8k resolution",
    #     "weight": 0.5,
    # # },{
    # },{
    #     "text": "Beautiful and detailed fantasy painting.",
    #     "weight": 0.2,
    # # },{
    # # #     "text": "Full body.",
    # # #     "weight": 0.1,
    # },{ # Improves contrast, object coherence, and adds a nice depth of field effect
    #     "text": "Rendered in unreal engine, trending on artstation.",
    #     "weight": 0.2,
    # },{
    #     "text": "speedpainting matte painting",
    #     "weight": 0.2,
    # # },{
    # #     "text": "Vivid Colors",
    # #     "weight": 0.15,
    # },{ # Doesn't seem to do much, but also doesn't seem to hurt.
    #     "text": "confusing, incoherent",
    #     "weight": -0.25,
    },{ # Not really strong enough to remove all signatures... but I'm ok with small ones
        "text":"text, signature",
        "weight":-1
    }
]

# Image prompts. Each entry gives:
#   "fpath"  - path to a PNG file,
#   "weight" - prompt weight, split evenly over the cuts,
#   "cuts"   - number of random crops encoded from the image,
#   "noise"  - noise added to those crops.
images = [
          # {
          #     "fpath": "waste.png",
          #     "weight": 0.2,
          #     "cuts": 16,
          #     "noise": 0.0
          # },{
          #     "fpath": "waste_2.png",
          #     "weight": 0.2,
          #     "cuts": 16,
          #     "noise": 0.0
          # }
          ]

# Random seed. Set to None for a random seed.
seed = None

# Starting noise for the pixel and palette parameters.
# NOTE(review): these settings are not read in the visible part of the
# notebook. Presumably the parameter initialisation further down uses them.
pix_noise_scale = 2.0
pix_noise_persistence = 0.8
pix_noise_clamp = 50.0

palette_noise_scale = 1.0
palette_noise_clamp = 6.0
palette_brightness = 0
init_rgb = True

# Use an RGBA palette. Images then carry an alpha channel, are saved as RGBA
# PNGs, and are previewed composited over random noise.
use_transparent = True

# Number of times to run
images_n = 1

# Save rate for video. Setting it low slows down training.
save_interval = 1000

# Use "nearest" for pixel-art style images with very precise edges and sharp corners.
# Use "lanczos" for images with smoother edges and more natural shapes.
# "bicubic" and "bilinear" are intermediate between the two.
pyramid_scaling_mode = "nearest" # "lanczos" #'bicubic' "nearest" "bilinear"

# Optimizer notes:
# AdamW is basic and gets the job done.
# RAdam seems to work *extremely well* but may introduce some color instability; use 0.5x lr.
# Yogi is very blurry for some reason; use 5x lr or more.
# Ranger works great; use 3-4x lr.
optimizer_type = "Ranger" # "AdamW", "AccSGD","Ranger","RangerQH","RangerVA","AdaBound","AdaMod","Adafactor","AdamP","AggMo","DiffGrad","Lamb","NovoGrad","PID","QHAdam","QHM","RAdam","SGDP","SGDW","Shampoo","SWATS","Yogi"

# Images that set the initial pixel values and the color palette.
# The two can be set independently.
initial_image = None #"skeleton handd.png"
palette_image = None

# How strong is the influence of the initial image?
initial_image_strength = 4

# Aspect ratio of the output image.
# It also sets the size of the smallest pyramid level, so make it as small as possible for your target shape.
aspect_ratio = (4, 4) # (3, 4)

# Max dim of the final output image.
max_dim = 800

# Number of colors to use (overridden by a manual palette).
# Use 2 for duotone / 1-bit images,
# or try a very large number of colors for a more painterly look.
num_colors = 4

# Use this to define a manual palette (in hex codes).
# Set to None to use a randomly initialized palette.
palette = None
# ["FDFDFD", "222222", "444444", "666666", "888888", "AAAAAA", "CCCCCC", "020202"]
# [
#             "614ED9",
#             "405B73",
#             "F2BC57",
#             "F29727",
#             "D97F30",
#             "4FE3E4",
#             "E6E7B9",
# ]

# Scale applied to the palette parameters before the sigmoid.
# Higher = more extreme colors, lower = more subtle.
# Recommended = 0.1 (not extensively explored).
palette_contrast = 0.1

# How many levels of detail to use. Use at least log2(max_dim / max(aspect_ratio)) levels.
# For example, a 1024 max_dim with a (3,4) aspect ratio needs at least 8; fewer works but will probably be less detailed.
# Higher isn't necessarily better; the effect of more steps is hard to judge.
pyramid_steps = 11

# Add an extra level to the pyramid with size (1, 1).
# The model will have an easier time setting an overall color and brightness for the image,
# but may have less large-scale contrast.
add_global_color = True

# Optimizer settings for different training steps
stages = (
            { #First stage does rough detail.
        "cuts": 1,
        "cycles": 1000,
        # Higher pixel LR seems to result in less blending
        "lr_pixel": 4.0, 
        # Some decay on the pixels helps keep them from saturating and getting stuck
        "decay_pixel": 1e-5,
        # Learning rate of the colors-- set to 0 to use a fixed palette
        "lr_palette": 4.0, 
        "decay_palette": 0.0,
        # Higher noise helps increase contrast and edge sharpness
        "noise": 0.2,
        # TV loss weight. Only needed if the image is too noisy; it smooths the image.
        # May result in large flat areas of a single color if set too high.
        "denoise": 0.0,
        # Controls how much blending between palette colors is allowed.
        # Positive = blending penalised, negative = blending encouraged.
        "uncertanty_loss_scale": -0.1,#0.05, 
        # Histogram loss weight. lossHist returns 0 without a target histogram,
        # so this presumably has no effect while hist_image is None.
        "hist_loss_scale": 2,
        "checkin_interval": 100,
        # Three different ways to set the scale of LR at each resolution level:
        # "lr_persistence": 0.95, # Depends on number of pyramid levels
        "pyramid_lr_min" : 0.5, # Independent of number of pyramid levels. Defaults to 1
        # "lr_scales": [0.25,0.15,0.15,0.15,0.15,0.05,0.05,0.01,0.01], # manually set lr at each level
    }, { # 2nd stage does fine detail and polish
        # (If you change settings in stage 1 make sure to also change things here!)
        "cuts": 2,
        "cycles": 1000,
        "lr_pixel": 1.0,
        "decay_pixel": 1e-8,
        "lr_palette": 1.0,
        "decay_palette": 0.0,
        "noise": 0.2,
        "denoise": 0.00,
        "uncertanty_loss_scale": 0,# 0.1,#0.2,
        "hist_loss_scale": 2,
        "checkin_interval": 100,
        # "lr_persistence": 1.0,
        # "pyramid_lr_min" : 1
        # "lr_scales": [0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1],
    },
    # You can add new stages or remove stage 2, but two stages is often just right.
)

# Experimental color histogram matching.
# Doesn't really work as intended.
hist_image = None #"waste_4.png"
hist_bins = 100

##### END OF CONFIG #####


# Calculate layer dims.
# Pyramid level `step` has size aspect_ratio * pyramid_lacunarity**step. The
# lacunarity is chosen so that the last level's longer side reaches max_dim
# before rounding. With add_global_color, an extra [1, 1] level is prepended.
if pyramid_steps > 1:
  pyramid_lacunarity = (max_dim / max(aspect_ratio))**(1.0/(pyramid_steps-1))
else:
  pyramid_lacunarity = 1
# NOTE(review): `scales` is not used in this cell; confirm whether later cells need it.
scales = [pyramid_lacunarity**step for step in range(pyramid_steps)]
dims = []
if add_global_color:
  dims.append([1,1])
for step in range(pyramid_steps):
  scale = pyramid_lacunarity**step
  dim = [int(round(aspect_ratio[0] * scale)), int(round(aspect_ratio[1] * scale))]
  # Ensure that no two levels have the same dims: each side is forced to be
  # strictly larger than the previous level's, so the last level can end up
  # slightly larger than max_dim.
  if len(dims) > 0:
    if dim[0] <= dims[-1][0]:
      dim[0] = dims[-1][0]+1
    if dim[1] <= dims[-1][1]:
      dim[1] = dims[-1][1]+1
  dims.append(dim)
print(dims)
display_size = [i * 160 for i in aspect_ratio]
# From here on pyramid_steps counts every level, including the global-color one.
pyramid_steps = len(dims)
# Give every stage one learning-rate multiplier per pyramid level: the geometric
# series persistence**i, normalised to sum to 1.
#   "lr_persistence" sets the ratio directly.
#   "pyramid_lr_min" sets the last-to-first level ratio.
#   Stages with an explicit "lr_scales" are left untouched.
for stage in stages:
  if "lr_scales" not in stage:
    if "lr_persistence" in stage:
      persistence = stage["lr_persistence"]
    elif "pyramid_lr_min" in stage:
      if pyramid_steps > 1:
        persistence = stage["pyramid_lr_min"]**(1.0/float(pyramid_steps-1))
      else:
        persistence = 1
    else:
      persistence = 1.0  
    lrs = [persistence**i for i in range(pyramid_steps)]
    sum_lrs = sum(lrs)
    stage["lr_scales"] = [rate / sum_lrs for rate in lrs]
    print(persistence, stage["lr_scales"])

# A manual palette fixes the number of colors.
if palette is not None:
  num_colors = len(palette)

# Set to True to display the CLIP cuts at check-in cycles (see getClipTokens).
debug_clip_cuts = False

import sys, os, random, shutil, math
import torch, torchvision
from IPython import display
import numpy as np
from PIL import Image


# Interpolation-mode shortcuts (bilinear is used for the random rotations in getClipTokens).
bilinear = torchvision.transforms.functional.InterpolationMode.BILINEAR
bicubic = torchvision.transforms.functional.InterpolationMode.BICUBIC

# Gradients are disabled globally. Code that needs autograd must opt back in
# (the training cycle uses `with torch.enable_grad():`).
torch.autograd.set_grad_enabled(False)
torch.backends.cudnn.benchmark = True
# New float tensors (torch.zeros, torch.randn, ...) default to CUDA float32.
# NOTE(review): set_default_tensor_type is deprecated in recent PyTorch releases.
torch.set_default_tensor_type(torch.cuda.FloatTensor)

# Seed Python's, PyTorch's and NumPy's RNGs when a fixed seed is configured.
if seed is not None:
  random.seed(seed)
  torch.manual_seed(seed)
  # torch.use_deterministic_algorithms(mode=False)
  np.random.seed(seed)

# Install CLIP's dependencies and torch_optimizer, and clone CLIP, unless a
# CLIP checkout already exists.
if not os.path.isdir("CLIP"):
  ! pip -q install ftfy
  ! git clone https://github.com/openai/CLIP.git --depth 1
  ! pip -q install torch_optimizer

from CLIP import clip
import torch_optimizer as optim


def normalize_image(image):
  """
  Normalise RGB channels with CLIP's per-channel mean and std.

  `image` is indexed as (batch, channel, y, x). Only channels 0-2 are used,
  so any extra channel (e.g. alpha) is dropped from the result.
  """
  means = (0.48145466, 0.4578275, 0.40821073)
  stds = (0.26862954, 0.26130258, 0.27577711)
  return torch.cat(
      [(image[:, c:c + 1] - mean) / std for c, (mean, std) in enumerate(zip(means, stds))],
      dim=1)

@torch.no_grad()
def loadImage(filename):
  """
  Decode the PNG file `filename` into a float32 CUDA tensor of shape (1, 3, H, W).

  Mode 3 asks torchvision's decoder for RGB output (ImageReadMode.RGB), and
  pixel values are scaled from 0-255 to [0, 1].
  """
  with open(filename, "rb") as f:
    data = f.read()
  image = torch.ops.image.decode_png(torch.as_tensor(bytearray(data)).cpu().to(torch.uint8), 3)
  return (image.cuda().to(torch.float32) / 255.0).unsqueeze(0)

def getClipTokens(image, cuts, noise, do_checkin, perceptor):
  """
  Encode `cuts` random views of `image` with the perceptor's CLIP model.

  Each view is built in four steps:
  1. Rotate the normalised image by a random angle in [-20, 20] degrees,
     expanding the canvas.
  2. Zero-pad by 1/8 of the input height on every side.
  3. Crop a random square. Its side is drawn between perceptor["size"] - 32
     and the shorter side minus one, or is just the shorter side minus one
     for small images.
  4. Resize bilinearly to the perceptor's input size.

  Gaussian noise scaled by `noise` is added to all views before encoding.
  With debug_clip_cuts and `do_checkin`, the views are displayed first.
  Returns the model's encode_image output, presumably (cuts, embed_dim).
  """
  normalized = normalize_image(image)
  side = perceptor["size"]
  pad = normalized.size()[2] // 8
  cut_data = torch.zeros(cuts, 3, side, side)
  for c in range(cuts):
    rotated = torchvision.transforms.functional.rotate(
        normalized, angle=random.uniform(-20.0, 20.0), expand=True, interpolation=bilinear)
    padded = torch.nn.functional.pad(rotated, pad=(pad, pad, pad, pad))
    rows, cols = padded.size()[2:4]
    shortest = min(rows, cols)
    if shortest <= side - 32:
      crop = shortest - 1
    else:
      crop = random.randint(side - 32, shortest - 1)
    top = random.randrange(0, rows - crop)
    left = random.randrange(0, cols - crop)
    window = padded[:, :, top:top + crop, left:left + crop]
    cut_data[c] = torch.nn.functional.interpolate(
        window, size=(side, side), mode='bilinear', align_corners=False)
  cut_data += noise * torch.randn_like(cut_data, requires_grad=False)

  if debug_clip_cuts and do_checkin:
    displayImage(cut_data)

  return perceptor['model'].encode_image(cut_data)


def loadPerceptor(name):
  """
  Load the CLIP model `name` on CUDA and pre-encode all prompts with it.

  Returns a dict with:
    "model":  the CLIP model,
    "size":   the size of the first preprocessing transform (the crop size),
    "tokens": one text embedding per entry of `texts`, in the same order,
    "images": for each entry of `images`, the embeddings of its random cuts.
  """
  model, preprocess = clip.load(name, device="cuda")
  text_tokens = [model.encode_text(clip.tokenize(prompt["text"]).cuda()) for prompt in texts]
  perceptor = {"model": model, "size": preprocess.transforms[0].size, "tokens": text_tokens}
  # getClipTokens only needs the "model" and "size" entries built so far.
  perceptor["images"] = [
      getClipTokens(loadImage(prompt["fpath"]), prompt["cuts"], prompt["noise"], False, perceptor)
      for prompt in images
  ]
  return perceptor

# CLIP models used for scoring. lossClip compares the image cuts against the
# prompts encoded by every entry. Each loadPerceptor call also encodes the
# text and image prompts configured above.
perceptors = (
  loadPerceptor("ViT-B/32"),
  loadPerceptor("ViT-B/16"),
  # loadPerceptor("RN50x16"),
)

@torch.no_grad()
def saveImage(image, filename):
  """
  Save the first image of the batch `image` to `filename` as a PNG.

  `image` is indexed as (batch, channel, y, x). Values are clamped to [0, 1]
  and scaled to 8 bits.
  - With `use_transparent`, the channels are written through PIL as RGBA.
    PIL needs a channel-last (H, W, 4) uint8 array for this.
  - Otherwise the image is encoded with torchvision's PNG encoder
    (compression level 6).
  """
  pixels = (image[0].clamp(0, 1) * 255).to(torch.uint8)
  if use_transparent:
    rgba = pixels.permute(1, 2, 0).contiguous().cpu().numpy()
    Image.fromarray(rgba).save(filename)
  else:
    png_data = torch.ops.image.encode_png(pixels.cpu(), 6)
    with open(filename, "wb") as f:
      f.write(bytes(png_data))

# TODO: Use torchvision normalize / unnormalize
def unnormalize_image(image):
  """
  Undo CLIP's per-channel normalisation (channel * std + mean).

  `image` is indexed as (batch, channel, y, x). Only channels 0-2 are used,
  so any extra channel is dropped from the result.
  """
  means = (0.48145466, 0.4578275, 0.40821073)
  stds = (0.26862954, 0.26130258, 0.27577711)
  return torch.cat(
      [image[:, c:c + 1] * std + mean for c, (mean, std) in enumerate(zip(means, stds))],
      dim=1)


def paramsToImage(params_pyramid, params_palette, quantize = False):
  """
  Render the pyramid of palette logits into an image.

  Steps:
  1. Upsample every tensor in `params_pyramid` to the size of the last one
     and sum them. Pyramid entries are presumably (1, num_colors, h, w);
     TODO confirm.
  2. Apply a softmax over dim 1 to get per-pixel palette weights.
  3. With `quantize`, replace the weights by a one-hot argmax.
  4. Mix the colours: each pixel is palette.T @ weights, where
     palette = sigmoid(params_palette * palette_contrast).

  Upsampling uses nearest for 1x1 levels and otherwise pyramid_scaling_mode.
  "lanczos" goes through `resample`; NOTE(review): `resample` is not defined
  in the visible part of the notebook.

  Returns (image, uncertanty), where:
    image: (1, 4 if use_transparent else 3, H, W).
    uncertanty: mean of 1 - max weight, computed before quantisation.
  """
  finest = params_pyramid[-1]
  out_dim = finest.shape[2:]
  logits = torch.zeros_like(finest)
  for level in params_pyramid:
    is_global = level.shape[2] == 1 and level.shape[3] == 1
    if pyramid_scaling_mode == "lanczos" and not is_global:
      logits += resample(level, out_dim)
    elif pyramid_scaling_mode == "nearest" or is_global:
      logits += torch.nn.functional.interpolate(level, size=out_dim, mode="nearest")
    else:
      logits += torch.nn.functional.interpolate(level, size=out_dim, mode=pyramid_scaling_mode, align_corners=True)

  weights = torch.nn.functional.softmax(logits, dim=1)
  uncertanty = (1.0 - weights.max(dim=1)[0]).mean()
  if quantize:
    winners = torch.argmax(weights, dim=1).unsqueeze(1)
    weights = torch.zeros_like(weights).scatter(1, winners, 1.0)

  palette = torch.sigmoid(params_palette * palette_contrast)
  channels = 4 if use_transparent else 3
  mixed = torch.matmul(palette.T, weights.view(num_colors, -1))
  return mixed.view((1, channels, *out_dim)), uncertanty

def paletteToImage(params_palette):
  """
  Render the palette as a swatch strip 64 pixels high and max_dim pixels wide.

  `params_palette` is indexed as (color, channel). It is transposed so that
  the colors run along the width, passed through
  sigmoid(params * palette_contrast), and upscaled with nearest-neighbour.
  """
  swatches = params_palette.T[None, :, None, :] * palette_contrast
  swatches = torch.sigmoid(swatches)
  return torch.nn.functional.interpolate(swatches, size=(64, max_dim), mode="nearest")

def imageToParams(image, palette_only = False):
  """
  Octree-quantize `image` into `num_colors` colors and return optimisable parameters.

  `image` is read as image[0, channel, y, x], and values are multiplied by 255
  before building Color objects. The values are presumably in [0, 1]; TODO
  confirm against callers.

  Returns (pixels, palette_colors):
    pixels: (1, num_colors, H, W) one-hot map of each pixel's palette index.
      It is left all zeros when `palette_only` is true.
    palette_colors: (num_colors, 3) palette stored so that
      sigmoid(palette_colors * palette_contrast) gives the color in [0, 1].
      Rows beyond the generated palette hold logit(0).
      NOTE(review): the palette has 3 channels even when use_transparent
      makes paramsToImage expect 4; confirm how callers handle this.
  """
  octree = OctreeQuantizer()
  height, width = image.shape[-2:]
  pixels = torch.zeros(1, num_colors, height, width)
  # Copy the image to host memory once, rather than reading every channel of
  # every pixel from the device separately.
  colors = [[Color(*channels) for channels in row]
            for row in (image[0] * 255).permute(1, 2, 0).tolist()]
  for row in colors:
    for color in row:
      octree.add_color(color)

  palette = octree.make_palette(num_colors)

  if not palette_only and height > 0 and width > 0:
    # One-hot encode every pixel's palette index along the color dimension.
    index_map = torch.tensor(
        [[octree.get_palette_index(color) for color in row] for row in colors],
        dtype=torch.long, device=pixels.device)
    pixels[0].scatter_(0, index_map.unsqueeze(0), 1.0)

  palette_colors = torch.zeros((num_colors, 3))
  for i, color in enumerate(palette):
    palette_colors[i][0] = color.red
    palette_colors[i][1] = color.green
    palette_colors[i][2] = color.blue
  # Invert paramsToImage's sigmoid(p * palette_contrast) mapping.
  palette_colors = torch.logit(palette_colors.float() / 255.0, eps = 1e-6) / palette_contrast
  return pixels, palette_colors

@torch.no_grad()
def displayImage(image):
  """
  Show a batch of images side by side in the notebook output.

  `image` is indexed as (n, channel, y, x), and RGB values are clamped to
  [0, 1]. The strip starts as random noise, and tiles are separated by
  4-pixel gaps that keep the noise. For 4-channel input, the 4th channel is
  used as alpha to composite the RGB over that noise.
  """
  count, tile_h, tile_w = image.size(0), image.size(2), image.size(3)
  gap = 4
  strip = torch.randint(256, size=(3, tile_h, count * tile_w + (count - 1) * gap)).to(torch.float32)

  for n in range(count):
    left = n * (tile_w + gap)
    tile = strip[:, :, left:left + tile_w]
    rgb = image[n, :3, :, :].clamp(0, 1) * 255
    if image.shape[1] == 4:
      alpha = image[n, 3, :, :]
      tile *= 1.0 - alpha
      tile += rgb * alpha
    else:
      tile[...] = rgb

  png_data = torch.ops.image.encode_png(strip.to(torch.uint8).cpu(), 6)
  display.display(display.Image(bytes(png_data)))

def _prompt_loss(similarity, weight):
  """Turn a similarity into a loss.

  A positive weight pulls towards the prompt: loss is (1 - similarity) * weight.
  A weight <= 0 pushes away from it: loss is similarity * |weight|.
  """
  if weight > 0.0:
    return (1.0 - similarity) * weight
  return similarity * (-weight)


def lossClip(image, cuts, noise, do_checkin):
  """Compute CLIP prompt losses for `image` against every perceptor.

  Args:
    image, cuts, noise, do_checkin: forwarded to getClipTokens.

  Returns:
    A list of loss tensors, one per (perceptor, text prompt) pair and one per
    (perceptor, image-prompt token). Each is divided by the total |weight| across
    prompts and perceptors, so the losses share one normalization.
  """
  # Every text and image prompt contributes |weight| once per perceptor.
  max_loss = 0.0
  for prompt in [*texts, *images]:
    max_loss += abs(prompt["weight"]) * len(perceptors)
  if max_loss == 0:
    # All weights are zero, so every term is zero; avoid 0/0 = NaN.
    max_loss = 1.0

  losses = []
  for perceptor in perceptors:
    clip_tokens = getClipTokens(image, cuts, noise, do_checkin, perceptor)

    # perceptor["tokens"][t] is assumed to line up with texts[t].
    for t, tokens in enumerate(perceptor["tokens"]):
      similarity = torch.cosine_similarity(tokens, clip_tokens)
      losses.append(_prompt_loss(similarity, texts[t]["weight"]) / max_loss)

    # Visit each image prompt exactly once per perceptor, matching max_loss.
    # Its weight is split evenly over its `cuts` tokens.
    for i, prompt_tokens in enumerate(perceptor["images"]):
      weight = images[i]["weight"] / float(images[i]["cuts"])
      for token in prompt_tokens:
        similarity = torch.cosine_similarity(token.unsqueeze(0), clip_tokens)
        losses.append(_prompt_loss(similarity, weight) / max_loss)
  return losses

def lossTV(image, strength):
  """Total-variation penalty.

  Returns the mean absolute difference between vertically adjacent pixels plus
  the mean between horizontally adjacent pixels, times 0.5 * strength.
  """
  vertical = torch.diff(image, dim=2).abs().mean()
  horizontal = torch.diff(image, dim=3).abs().mean()
  return (horizontal + vertical) * 0.5 * strength

def lossHist(image, target_hist):
  """Compare the value histogram of `image` with `target_hist` using `sinkstep`.

  Returns 0 when `target_hist` is None. Otherwise it builds a `hist_bins`-bin
  histogram of `image` over [0, 1] and normalizes it to sum to 1. It then forms a
  squared-difference cost matrix, scaled so its maximum is 1, and passes the cost
  and both log-histograms to `sinkstep` with 1e-3.

  NOTE(review): empty bins give log(0) = -inf. torch.histc also appears to have
  no gradient, so this loss may not backpropagate to `image`. Confirm before
  re-enabling it in cycle().
  """
  if target_hist is None:
    return 0
  hist = torch.histc(image, bins=hist_bins, min=0, max=1)
  hist = hist / hist.sum()
  cost = (hist.unsqueeze(0) - target_hist.unsqueeze(1)) ** 2
  cost = cost / cost.max()
  return sinkstep(cost, hist.log()[None], target_hist.log()[None], 1e-3)


def cycle(c, stage, optimizer, params_pyramid, params_palette):
  """Run one optimization step of `stage` and periodically save or display progress.

  Args:
    c: zero-based cycle index within the stage.
    stage: stage settings dict. Reads "checkin_interval", "cuts", "noise",
      "denoise", "uncertanty_loss_scale" and "n".
    optimizer: optimizer over params_pyramid and params_palette.
    params_pyramid, params_palette: the parameters being optimized.

  Side effects: updates the parameters, saves frames every `save_interval`
  cycles, and prints, displays and saves check-in images every
  stage["checkin_interval"] cycles (and on the first cycle).
  """
  do_checkin = (c + 1) % stage["checkin_interval"] == 0 or c == 0
  with torch.enable_grad():
    image, uncertanty = paramsToImage(params_pyramid, params_palette)
    # Keep the parts separate. lossClip returns a variable number of terms, so
    # positional indexing into one combined list mislabels the TV loss.
    clip_losses = lossClip(image, stage["cuts"], stage["noise"], do_checkin)
    tv_loss = lossTV(image, stage["denoise"])
    uncertanty_loss = uncertanty * stage["uncertanty_loss_scale"]

    loss_total = sum(clip_losses + [tv_loss, uncertanty_loss]).sum()
    optimizer.zero_grad(set_to_none=True)
    loss_total.backward(retain_graph=False)
    optimizer.step()

  # Progress output only reads the parameters, so skip building autograd graphs.
  with torch.no_grad():
    if (c + 1) % save_interval == 0 or c == 0:
      nimg, _ = paramsToImage(params_pyramid, params_palette)
      saveImage(nimg, f"images/frame_{stage['n']:02}_{c:05}.png")

    if do_checkin:
      clip_total = sum(loss.sum().item() for loss in clip_losses)
      tv_total = tv_loss.sum().item()
      nimg, uncertanty = paramsToImage(params_pyramid, params_palette, quantize=False)
      print( "Cycle:", str(stage["n"]) + ":" + str(c), "CLIP Loss:", clip_total, "TV loss:", tv_total, "Uncertanty:", uncertanty.item())
      displayImage(paletteToImage(params_palette))
      displayImage(torch.nn.functional.interpolate(nimg, size=display_size, mode='nearest'))
      saveImage(nimg, texts[0]["text"] + f"_{stage['n']}" + ".png" )
      nimg, _ = paramsToImage(params_pyramid, params_palette, quantize = True)
      displayImage(torch.nn.functional.interpolate(nimg, size=display_size, mode='nearest'))
      saveImage(nimg, texts[0]["text"] + f"_quantized_{stage['n']}" + ".png" )
    

def sinc(x):
    """Normalized sinc: sin(pi * x) / (pi * x), with sinc(0) defined as 1."""
    scaled = math.pi * x
    return torch.where(x == 0, x.new_ones([]), torch.sin(scaled) / scaled)

def lanczos(x, a):
    """Lanczos window sinc(x) * sinc(x / a) on the open interval (-a, a).

    Taps outside the interval are zero. The result is normalized to sum to 1.
    """
    inside = (x > -a) & (x < a)
    weights = torch.where(inside, sinc(x) * sinc(x / a), x.new_zeros([]))
    return weights / weights.sum()

def ramp(ratio, width):
    """Return symmetric sample positions in steps of `ratio`, centred on 0.

    Builds 0, ratio, 2*ratio, ... (ceil(width / ratio + 1) values), mirrors them
    to the negative side, and drops the two outermost samples.
    """
    count = math.ceil(width / ratio + 1)
    positions = []
    pos = 0
    for _ in range(count):
        positions.append(pos)
        pos += ratio
    half = torch.tensor(positions, dtype=torch.get_default_dtype())
    mirrored = torch.cat([-half[1:].flip(0), half])
    return mirrored[1:-1]

def resample(input, size, align_corners=True):
    """Lanczos-prefilter `input` along H and then W, then bicubic-resize it to `size`.

    Args:
        input: 4-D tensor unpacked as (n, c, h, w).
        size: (target_h, target_w).
        align_corners: forwarded to the final bicubic interpolate.
    """
    nnf = torch.nn.functional
    batch, chans, src_h, src_w = input.shape
    dst_h, dst_w = size

    def _lanczos_pass(planes, ratio, along_h):
        # Separable 1-D Lanczos (a=2) kernel. The kernel length is odd, so reflect
        # padding of (len - 1) // 2 per side keeps the spatial size unchanged.
        kernel = lanczos(ramp(ratio, 2), 2).to(planes.device, planes.dtype)
        pad = (kernel.shape[0] - 1) // 2
        if along_h:
            planes = nnf.pad(planes, (0, 0, pad, pad), 'reflect')
            return nnf.conv2d(planes, kernel[None, None, :, None])
        planes = nnf.pad(planes, (pad, pad, 0, 0), 'reflect')
        return nnf.conv2d(planes, kernel[None, None, None, :])

    # NOTE(review): the `dh < h` / `dw < w` guards are commented out, so this
    # prefilter also runs when upscaling. It reduces to a single-tap identity only
    # when the ratio is >= 2. Confirm that this is intended.
    planes = input.reshape([batch * chans, 1, src_h, src_w])
    planes = _lanczos_pass(planes, dst_h / src_h, along_h=True)
    planes = _lanczos_pass(planes, dst_w / src_w, along_h=False)
    planes = planes.reshape([batch, chans, src_h, src_w])
    return nnf.interpolate(planes, size, mode='bicubic', align_corners=align_corners)

def init_optim(params_pyramid, params_palette, stage):
  """Build the optimizer named by `optimizer_type`, with one param group per level.

  Group i (for each entry of stage["lr_scales"]) holds params_pyramid[i], with
  lr = stage["lr_pixel"] * lr_scales[i] and weight decay stage["decay_pixel"].
  The last group holds the palette, with stage["lr_palette"] and
  stage["decay_palette"]. main() relies on this group order when it updates
  learning rates between stages.

  Raises:
    ValueError: if `optimizer_type` does not name an attribute of `optim`.
  """
  optimizer_cls = getattr(optim, optimizer_type, None)
  if optimizer_cls is None:
    raise ValueError(f"Unknown optimizer_type {optimizer_type!r}: no such attribute in optim")

  param_groups = [
    {"params": params_pyramid[i], "lr": stage["lr_pixel"] * scale, "weight_decay": stage["decay_pixel"]}
    for i, scale in enumerate(stage["lr_scales"])
  ]
  param_groups.append({"params": params_palette, "lr": stage["lr_palette"], "weight_decay": stage["decay_palette"]})
  return optimizer_cls(param_groups)

def _palette_logits_from_hex(hex_colors, channels):
  """Convert hex color strings into palette logits of shape (num_colors, channels).

  Each entry is "RRGGBB"; a leading "#" is ignored. When channels == 4, an
  optional trailing "AA" byte sets alpha, and alpha defaults to 255 (fully
  opaque) when absent. Rows past len(hex_colors) stay 0 before the logit.
  """
  palette_colors = torch.zeros((num_colors, channels))
  for row, color in enumerate(hex_colors):
    color = color.lstrip("#")
    values = [int(color[i:i + 2], 16) for i in (0, 2, 4)]
    if channels == 4:
      values.append(int(color[6:8], 16) if len(color) >= 8 else 255)
    palette_colors[row, :] = torch.tensor(values, dtype=palette_colors.dtype)
  return torch.logit(palette_colors.float() / 255.0, eps = 1e-6) / palette_contrast


def main():
  """Initialize the pixel and palette parameters, then run every stage in `stages`.

  Pixel parameters are one (1, num_colors, H, W) CUDA tensor per entry of `dims`.
  They are seeded from `initial_image` when it is set (finest level only; the
  others start near zero), otherwise from clamped Gaussian noise.

  Palette parameters come from the first source that applies:
    1. the quantized initial image;
    2. the hex `palette` list;
    3. random init.
  `palette_image` overrides all of these when it is set.

  NOTE(review): palettes derived from an image always have 3 channels, even when
  `use_transparent` is set.
  """
  params_pyramid = []
  if initial_image is not None:
    # Near-zero logits on every level; the finest level is replaced below by the
    # one-hot map of the quantized initial image.
    for dim in dims:
      pix = torch.zeros((1, num_colors, dim[0], dim[1])) + 1e-6
      params_pyramid.append(torch.nn.parameter.Parameter( pix.cuda(), requires_grad=True))
    print("Loading image")
    image = loadImage(initial_image)
    image = torch.nn.functional.interpolate(image, size=dims[-1], mode='bicubic', align_corners=False)
    pix, palette_colors = imageToParams(image)
    pix *= initial_image_strength + 1e-6
    params_pyramid[-1] = torch.nn.parameter.Parameter( pix.cuda(), requires_grad=True)
    print("loading complete")
  else:
    # Level i gets noise scaled by pix_noise_persistence**i. Both the scale and the
    # clamp bound are divided by the number of pyramid levels.
    for i, dim in enumerate(dims):
      pix_c = (torch.randn(size = (1, num_colors, dim[0], dim[1])) * pix_noise_scale * pix_noise_persistence**i / len(dims)).clamp(-pix_noise_clamp / len(dims), pix_noise_clamp / len(dims))
      params_pyramid.append(torch.nn.parameter.Parameter( pix_c.cuda(), requires_grad=True))

    channels = 4 if use_transparent else 3
    if palette is None:
      if init_rgb:
        palette_colors = torch.rand(size=(num_colors, channels))
        palette_colors = torch.logit(palette_colors, eps=1e-6) / palette_contrast  + palette_brightness
      else:
        palette_colors = (torch.randn(size = (num_colors, channels)) * palette_noise_scale).clamp(-palette_noise_clamp, palette_noise_clamp) + palette_brightness
    else:
      palette_colors = _palette_logits_from_hex(palette, channels)
      print("Palette initialized:")
      displayImage(paletteToImage(palette_colors))

  if palette_image is not None:
    image = loadImage(palette_image)
    _, palette_colors = imageToParams(image, True)
  params_palette = torch.nn.parameter.Parameter( palette_colors.cuda(), requires_grad=True)

  target_hist = None
  if hist_image is not None:
    target_hist = torch.histc(loadImage(hist_image), bins=hist_bins, min=0, max = 1)
    target_hist /= target_hist.sum()

  optimizer = init_optim(params_pyramid, params_palette, stages[0])

  for n, stage in enumerate(stages):
    stage["n"] = n
    stage["target_hist"] = target_hist
    if n > 0:
      # Keep the optimizer (and its state) across stages; only update the learning
      # rates. Groups follow init_optim's order: pyramid levels, then the palette.
      for i in range(len(dims)):
        optimizer.param_groups[i]["lr"] = stage["lr_pixel"] * stage["lr_scales"][i]
      optimizer.param_groups[-1]["lr"] = stage["lr_palette"]

    for c in range(stage["cycles"]):
      cycle( c, stage, optimizer, params_pyramid, params_palette)

# Produce `images_n` independent runs; every main() call builds fresh parameters.
for run in range(images_n):
  main()


In [None]:
auto_video = True
if auto_video:
    !rm /content/video.mp4
    # Choose frame rate and final save path (in Colab's filesystem)
    fps = 24
    save_path = '/content/video.mp4'

    # Create video
    !ffmpeg -r {fps} -pattern_type glob -i '/content/images/*.png' -vcodec libx264 -crf 15 -pix_fmt yuv420p {save_path} -hide_banner -loglevel error