In [20]:
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
import os

from PIL import Image
from torchvision.transforms import Resize, Compose, ToTensor, Normalize
import numpy as np
import skimage
import matplotlib.pyplot as plt

import time

def get_coords(*sidelengths):
    '''Generates a flattened grid of (x,y,...) coordinates in a range of -1 to 1.
    sidelen: int
    dim: int'''
    
    tensors = []
    
    for sidelen in sidelengths:
        tensors.append(torch.linspace(-1, 1, steps=sidelen))
    
    tensors = tuple(tensors)
    coords = torch.stack(torch.meshgrid(*tensors), dim=-1)
    return coords.reshape(-1, len(sidelengths))

In [21]:
def get_sorted_ind(tensor, axis):
    values, ind = torch.sort(tensor[:, axis])
    
    sorted_ind = []
    print(sorted(set(values.tolist())))
    for v in sorted(set(values.tolist())): 
        subset = ((values == v).nonzero(as_tuple=True)[0])
        sorted_ind += ind[subset].sort()[0].tolist()
    
    return sorted_ind

def sort_coords_and_pixels(coords, pixels):
    for axis in reversed(range(coords.shape[1])):
        ind = get_sorted_ind(coords, axis)
        coords = coords[ind]
        pixels = pixels[ind]
        print(axis)
        print(coords)
        print(pixels)
    
    return coords, pixels

In [22]:
coords = get_coords(2, 3, 4)

pixels = torch.randperm(24)

coords = coords[pixels]

pixels = pixels.unsqueeze(1)

In [23]:
print(coords)
print(pixels)

tensor([[ 1.0000, -1.0000,  1.0000],
        [ 1.0000,  0.0000,  0.3333],
        [-1.0000,  0.0000,  0.3333],
        [-1.0000,  1.0000,  0.3333],
        [ 1.0000,  0.0000,  1.0000],
        [-1.0000,  1.0000, -1.0000],
        [ 1.0000,  0.0000, -0.3333],
        [-1.0000,  0.0000,  1.0000],
        [-1.0000,  0.0000, -1.0000],
        [ 1.0000,  1.0000,  1.0000],
        [-1.0000, -1.0000, -1.0000],
        [ 1.0000, -1.0000, -1.0000],
        [-1.0000, -1.0000,  1.0000],
        [-1.0000,  1.0000, -0.3333],
        [-1.0000,  1.0000,  1.0000],
        [ 1.0000,  0.0000, -1.0000],
        [-1.0000, -1.0000,  0.3333],
        [-1.0000,  0.0000, -0.3333],
        [ 1.0000,  1.0000, -0.3333],
        [ 1.0000, -1.0000, -0.3333],
        [ 1.0000, -1.0000,  0.3333],
        [ 1.0000,  1.0000, -1.0000],
        [-1.0000, -1.0000, -0.3333],
        [ 1.0000,  1.0000,  0.3333]])
tensor([[15],
        [18],
        [ 6],
        [10],
        [19],
        [ 8],
        [17],
        [ 7],

In [24]:
coords, pixels = sort_coords_and_pixels(coords, pixels)

[-1.0, -0.3333333134651184, 0.3333333134651184, 1.0]
2
tensor([[-1.0000,  1.0000, -1.0000],
        [-1.0000,  0.0000, -1.0000],
        [-1.0000, -1.0000, -1.0000],
        [ 1.0000, -1.0000, -1.0000],
        [ 1.0000,  0.0000, -1.0000],
        [ 1.0000,  1.0000, -1.0000],
        [ 1.0000,  0.0000, -0.3333],
        [-1.0000,  1.0000, -0.3333],
        [-1.0000,  0.0000, -0.3333],
        [ 1.0000,  1.0000, -0.3333],
        [ 1.0000, -1.0000, -0.3333],
        [-1.0000, -1.0000, -0.3333],
        [ 1.0000,  0.0000,  0.3333],
        [-1.0000,  0.0000,  0.3333],
        [-1.0000,  1.0000,  0.3333],
        [-1.0000, -1.0000,  0.3333],
        [ 1.0000, -1.0000,  0.3333],
        [ 1.0000,  1.0000,  0.3333],
        [ 1.0000, -1.0000,  1.0000],
        [ 1.0000,  0.0000,  1.0000],
        [-1.0000,  0.0000,  1.0000],
        [ 1.0000,  1.0000,  1.0000],
        [-1.0000, -1.0000,  1.0000],
        [-1.0000,  1.0000,  1.0000]])
tensor([[ 8],
        [ 4],
        [ 0],
        [12],


In [25]:
print(coords)
print(pixels)

tensor([[-1.0000, -1.0000, -1.0000],
        [-1.0000, -1.0000, -0.3333],
        [-1.0000, -1.0000,  0.3333],
        [-1.0000, -1.0000,  1.0000],
        [-1.0000,  0.0000, -1.0000],
        [-1.0000,  0.0000, -0.3333],
        [-1.0000,  0.0000,  0.3333],
        [-1.0000,  0.0000,  1.0000],
        [-1.0000,  1.0000, -1.0000],
        [-1.0000,  1.0000, -0.3333],
        [-1.0000,  1.0000,  0.3333],
        [-1.0000,  1.0000,  1.0000],
        [ 1.0000, -1.0000, -1.0000],
        [ 1.0000, -1.0000, -0.3333],
        [ 1.0000, -1.0000,  0.3333],
        [ 1.0000, -1.0000,  1.0000],
        [ 1.0000,  0.0000, -1.0000],
        [ 1.0000,  0.0000, -0.3333],
        [ 1.0000,  0.0000,  0.3333],
        [ 1.0000,  0.0000,  1.0000],
        [ 1.0000,  1.0000, -1.0000],
        [ 1.0000,  1.0000, -0.3333],
        [ 1.0000,  1.0000,  0.3333],
        [ 1.0000,  1.0000,  1.0000]])
tensor([[ 0],
        [ 1],
        [ 2],
        [ 3],
        [ 4],
        [ 5],
        [ 6],
        [ 7],

In [56]:
def prod(val) :  
    res = 1 
    for ele in val:  
        res *= ele  
    return res   

def image_to_array(image): 
    length = prod(image.shape)
    return get_coords(*image.shape), image.view(1, length)
    print(length)
    
    
def array_to_image(coords, pixels): 
    size = list()
    for dim in range(coords.shape[1]): 
        i = len(set(coords[:, dim].tolist()))
        size.append(i)
    image = pixels.view(*size)
    return image

In [57]:
coords, pixels = image_to_array(image)
print(coords)
print(pixels)

tensor([[-1.0000, -1.0000, -1.0000],
        [-1.0000, -1.0000, -0.3333],
        [-1.0000, -1.0000,  0.3333],
        [-1.0000, -1.0000,  1.0000],
        [-1.0000,  0.0000, -1.0000],
        [-1.0000,  0.0000, -0.3333],
        [-1.0000,  0.0000,  0.3333],
        [-1.0000,  0.0000,  1.0000],
        [-1.0000,  1.0000, -1.0000],
        [-1.0000,  1.0000, -0.3333],
        [-1.0000,  1.0000,  0.3333],
        [-1.0000,  1.0000,  1.0000],
        [ 1.0000, -1.0000, -1.0000],
        [ 1.0000, -1.0000, -0.3333],
        [ 1.0000, -1.0000,  0.3333],
        [ 1.0000, -1.0000,  1.0000],
        [ 1.0000,  0.0000, -1.0000],
        [ 1.0000,  0.0000, -0.3333],
        [ 1.0000,  0.0000,  0.3333],
        [ 1.0000,  0.0000,  1.0000],
        [ 1.0000,  1.0000, -1.0000],
        [ 1.0000,  1.0000, -0.3333],
        [ 1.0000,  1.0000,  0.3333],
        [ 1.0000,  1.0000,  1.0000]])
tensor([[ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
         18, 19, 20, 21, 22, 2

In [63]:
image1 = array_to_image(coords, pixels)

In [64]:
print(image1)

tensor([[[ 0,  1,  2,  3],
         [ 4,  5,  6,  7],
         [ 8,  9, 10, 11]],

        [[12, 13, 14, 15],
         [16, 17, 18, 19],
         [20, 21, 22, 23]]])


In [49]:
dim = 0 
len(set(coords[:, dim].tolist()))

2