In [1]:
import sys
sys.path.append('/home/tom/repos/SelfOrganisingWeddingInvite/SelfOrganisingCollage')

import argparse
import numpy as np
import matplotlib.pyplot as plt
import pickle
import os
import random
from tqdm import tqdm
import torch
from torchvision import models, transforms
import PIL

from photo_library import PhotoLibrary
from SOM import SelfOrganisingMap
from feature_extraction import FeatureExtractor
from draw_samples import sample_from_unit_cube
from patch_worker import Patchworker
from utils import ravel_index, simplify_shape

# Notebook stand-in for command-line arguments. argparse.Namespace is a plain
# attribute container (the same type ArgumentParser.parse_args() returns), so
# every later cell can keep reading args.<name>.
# NOTE(review): `directory` is an absolute, machine-specific path; change it
# here before running the notebook elsewhere.
args = argparse.Namespace(
    directory='/home/tom/Pictures/test_pics',  # photo folder; the pickled feature cache is stored here too
    height=48,            # first dimension of the SOM / patchwork grid
    width=27,             # second dimension of the SOM / patchwork grid
    feature='mean_color',  # name handed to FeatureExtractor.factory; also the cache file name
    epochs=20,            # number of passes over the sample while training the SOM
    reuse_penalty=100.,   # added to a sample's distances once per previous placement attempt
    sample_size=1000,     # number of samples requested from the extractor
    border=20,            # not read by any cell visible in this notebook
    magnify=20,           # canvas pixels per grid cell in the final render
)

In [2]:
# Build the photo library for the configured directory and the extractor for
# the configured feature.
library = PhotoLibrary(args.directory)
extractor = FeatureExtractor.factory(args.feature)

# Extracted features are cached next to the photos, in a pickle file named
# after the feature.
# NOTE(review): pickle.load can execute arbitrary code from the file; only use
# directories whose cache file was written by this notebook.
cache_path = os.path.join(args.directory, args.feature)

try:
    # Cache hit: reuse the features computed on a previous run.
    with open(cache_path, "rb") as cache_file:
        feature_dict = pickle.load(cache_file)

except FileNotFoundError:
    # Cache miss: process the whole library once, then persist the result.
    feature_dict = extractor.process_library(library)
    with open(cache_path, "wb") as cache_file:
        pickle.dump(feature_dict, cache_file)

In [None]:
# Self-organising map whose grid is height x width neurons, each of
# extractor.feature_dim values.
som = SelfOrganisingMap(shape=[args.height, args.width, extractor.feature_dim], sigma=5., eta=1.)

# Draw a subset of just the most interesting photos from the entire photo library.
sample, keys = extractor.draw_interesting_samples(feature_dict, args.sample_size)

# One frame is written per update step, so make sure the output folder exists.
os.makedirs('som_training', exist_ok=True)

for i in range(args.epochs):
    # Visit the samples in a fresh random order every epoch. The permutation is
    # sized from the sample itself (as the placement cell does with len(sample)),
    # so receiving fewer than args.sample_size samples cannot index out of range.
    for j, feature in enumerate(sample[torch.randperm(len(sample)), :]):
        som.update(feature.cuda())
        # The grid is combined with CUDA tensors elsewhere in this notebook, and
        # matplotlib can only read host memory, so copy it to the CPU first.
        frame = torch.transpose(som.grid, 0, 1).cpu().numpy()
        plt.imsave('som_training/som_%i_%i' % (i, j), frame)

plt.imshow(torch.transpose(som.grid, 0, 1).cpu().numpy())

In [None]:
patch = Patchworker([args.height, args.width])
# Per-sample counter of placement attempts, shaped (n_samples, 1) so it
# broadcasts across the neuron axis of the distance matrix.
times_used = torch.zeros(len(sample)).unsqueeze(1).cuda()
layout = []

# Distance from every sample to every neuron. Neither the samples nor the grid
# are modified in this cell, so the matrix is computed once rather than on
# every pass of the loop. The feature size is read from the grid instead of
# being hard-coded.
n_neurons = som.grid.shape[0] * som.grid.shape[1]
base_dist = (sample.cuda().unsqueeze(1) - som.grid.reshape(-1, som.grid.shape[-1]).unsqueeze(0)).norm(dim=2)

# I can omit the occupied parts of the grid to speed this up
while not patch.full():
    # penalise the samples that have already been placed on the grid
    dist = base_dist + times_used * args.reuse_penalty
    # omit any neurons that already have a photo covering them
    dist[:, patch.occupied.reshape(-1)] = float('inf')

    # find the neuron and sample that match most closely
    winning_sample, winning_neuron = ravel_index(dist.argmin(), n_neurons)
    target_coord = torch.Tensor(ravel_index(winning_neuron, som.grid.shape[1]))

    try:
        target_shape = simplify_shape(library[keys[winning_sample]].size, target_area=8)

        # add the photo to the patch work
        coord, shape = patch.add_patch(target_coord, target_shape)

        layout.append([keys[winning_sample], coord, shape, sample[winning_sample]])
    except Exception:
        # Best effort: a photo that cannot be loaded or placed is skipped.
        # Catching Exception (not a bare except) keeps KeyboardInterrupt
        # working, so this potentially long loop can still be stopped.
        pass

    # Count the attempt even when it failed, so the penalty steers the next
    # iteration towards a different sample.
    times_used[winning_sample] += 1

# Preview: one dot per placed photo, drawn in that photo's feature colour.
for key, coord, shape, feature in layout:
    plt.scatter(coord[1], -coord[0], c=feature.numpy()[None,:])

plt.show()

In [None]:
# Low-resolution preview of the layout: every placed patch is filled with the
# feature (colour) of the photo assigned to it.
patchwork = torch.zeros([args.height, args.width, 3])

for key, coord, shape, feature in layout:
    # `coord - shape / 2` is a float tensor (true division), which cannot be
    # used as a slice bound. Convert the corners to plain ints, truncating the
    # same way the canvas cell does when it pastes the photos.
    # NOTE(review): assumes coord is the patch centre and shape its extent in
    # grid cells, as the canvas cell treats them - confirm against Patchworker.
    tl = (coord - shape / 2).long().tolist()
    extent = shape.long().tolist()
    br = [tl[0] + extent[0], tl[1] + extent[1]]

    patchwork[tl[0]:br[0], tl[1]:br[1], :] = feature

plt.imshow(torch.transpose(patchwork,0,1))

In [None]:
# Use the trained SOM grid, scaled to 8-bit and magnified, as the background.
# `Image` is never imported on its own in this notebook, so go through the
# `PIL` module like the rest of this cell does.
# NOTE(review): assumes grid values lie in [0, 1] before the *255 scaling.
canvas = PIL.Image.fromarray((som.grid.cpu().numpy() * 255).astype(np.uint8))
canvas = canvas.resize((canvas.size[0]*args.magnify, canvas.size[1]*args.magnify),PIL.Image.BILINEAR)
canvas = canvas.transpose(PIL.Image.TRANSPOSE)

# Intermediate frames and the final image are written here.
os.makedirs('canvas_layout', exist_ok=True)

# Wrap the list (not the enumerate object) so tqdm knows the total.
for i, (key, centroid, shape, feature) in enumerate(tqdm(layout)):
    # Save the canvas as it looks before this photo is pasted.
    canvas.save('canvas_layout/canvas_%i.png' % i)

    img = library[key]

    # Patch extent and top-left corner in canvas pixels.
    size = shape.numpy() * args.magnify
    corner = (centroid - shape / 2).numpy() * args.magnify

    target_aspect = size[0] / size[1]
    img_aspect = img.size[0] / img.size[1]

    # Largest centred crop of the photo with the patch's aspect ratio.
    # CenterCrop takes (height, width); PIL's img.size is (width, height).
    # Sizes are rounded to ints because the transforms expect integer sizes.
    if target_aspect < img_aspect:
        crop_size = (img.size[1], int(round(img.size[1] * target_aspect)))
    else:
        crop_size = (int(round(img.size[0] / target_aspect)), img.size[0])

    crop = transforms.Compose([
        transforms.CenterCrop(crop_size),
        transforms.Resize((int(round(size[1])), int(round(size[0]))))
    ])

    # np.int no longer exists in current NumPy; truncate with the builtin int.
    canvas.paste(crop(img), tuple(int(c) for c in corner))

plt.imshow(canvas)

canvas.save('canvas_layout/final.png')

In [None]:
# Scratch example: build a 2x3 sparse COO tensor.
# Row 0 of `i` holds the row index and row 1 the column index of each value in
# `v`, so the dense equivalent is [[0, 0, 3], [4, 0, 5]].
# The legacy typed constructors (torch.LongTensor / torch.FloatTensor /
# torch.sparse.FloatTensor) are replaced by their current equivalents.
i = torch.tensor([[0, 1, 1], [2, 0, 2]], dtype=torch.long)
v = torch.tensor([3, 4, 5], dtype=torch.float32)
ten = torch.sparse_coo_tensor(i, v, torch.Size([2, 3]))

