In [2]:
import argparse
import numpy as np
import matplotlib.pyplot as plt
import pickle
import os
import random
from tqdm import tqdm
import torch
from torchvision import models, transforms
import PIL

from photo_library import PhotoLibrary
from SOM import SelfOrganisingMap
from feature_extraction import FeatureExtractor
from draw_samples import sample_from_unit_cube
from patch_worker import Patchworker
from utils import ravel_index, simplify_shape

ModuleNotFoundError: No module named 'torch'

In [3]:
parser = argparse.ArgumentParser(description='Make a lovely colourful collage from your photo library')
parser.add_argument('--directory', default='/home/tom/Pictures/test_pics', type=str)
parser.add_argument('--width', default=50, type=int)
parser.add_argument('--height', default=35, type=int)
parser.add_argument('--feature', default='mean_color', type=str)
parser.add_argument('--epochs', default=20, type=int)
parser.add_argument('--reuse_penalty', default=100., type=float)
parser.add_argument('--sample_size', default=10, type=int)
parser.add_argument('--border', default=20, type=int)
parser.add_argument('--magnify', default=200, type=int)
args = parser.parse_args()

library = PhotoLibrary(args.directory)
extractor = FeatureExtractor.factory(args.feature)
som = SelfOrganisingMap(shape=[args.height, args.width, extractor.feature_dim], sigma=5., eta=1.)
patch = Patchworker([args.height, args.width])

usage: ipykernel_launcher.py [-h] [--directory DIRECTORY] [--width WIDTH]
                             [--height HEIGHT] [--feature FEATURE]
                             [--epochs EPOCHS] [--reuse_penalty REUSE_PENALTY]
                             [--sample_size SAMPLE_SIZE] [--border BORDER]
                             [--magnify MAGNIFY]
ipykernel_launcher.py: error: unrecognized arguments: -f /run/user/1000/jupyter/kernel-fd948f6d-123c-44ef-a5ed-364aac7c0593.json


SystemExit: 2

  warn("To exit: use 'exit', 'quit', or Ctrl-D.", stacklevel=1)


In [4]:
try:
    with open(os.path.join(args.directory, args.feature), "rb") as f:
        feature_dict = pickle.load(f)

except FileNotFoundError:
    feature_dict = extractor.process_library(library)
    with open(os.path.join(args.directory, args.feature), "wb") as f:
        pickle.dump(feature_dict, f)

# draw a subset of just the most interesting the entire photo library
sample, keys = sample_from_unit_cube(feature_dict, args.sample_size)

for _ in range(args.epochs):
    for feature in sample[torch.randperm(args.sample_size),:]:
        som.update(feature.cuda())
        
plt.imshow(som.grid)

NameError: name 'args' is not defined

In [None]:
# now find the best vibing photo and stick it at its BMU
times_used = torch.zeros(len(sample)).unsqueeze(1).cuda()
layout = []

# I can omit the occupied parts of the grid to speed this up
while not patch.full():
    dist = (sample.cuda().unsqueeze(1) - som.grid.reshape(-1,3).unsqueeze(0)).norm(dim=2)
    dist += (times_used * args.reuse_penalty) # penalise the samples that have already been placed on the grid
    dist[:, patch.occupied.reshape(-1)] = dist.max() # omit any neurons that already have a photo covering them

    # find the neuron and sample that match most closely
    winning_sample, winning_neuron = ravel_index(dist.argmin(), (som.grid.shape[0] * som.grid.shape[1]))
    target_coord = torch.Tensor(ravel_index(winning_neuron, som.grid.shape[1]))
    target_shape = simplify_shape(library[keys[winning_sample]].size)

    # add the photo to the patch work
    coord, shape = patch.add_patch(target_coord, target_shape)

    layout.append([keys[winning_sample], coord, shape])

    times_used[winning_sample] += 1

In [None]:
canvas = PIL.Image.new("RGB", tuple(np.array(som.grid.shape[:2]) * args.magnify + args.border), "white")

for key, coord, shape in layout:
    img = library[key]

    size_tup = (shape * args.magnify - args.border).numpy()
    coord_tup = (coord * args.magnify + args.border - shape / 2).numpy()

    crop = transforms.Compose([transforms.Resize(int(size_tup.min())), transforms.CenterCrop((size_tup[1],size_tup[0]))])

    canvas.paste(crop(img), coord_tup)

plt.imshow(canvas)
plt.show()

canvas.save('invite.png')