In [1]:
import numpy as np
import matplotlib as mpl
mpl.use('agg')
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import pandas as pd
from tqdm import tqdm, trange
import torch

  from pandas.core.computation.check import NUMEXPR_INSTALLED


In [2]:
# When True, the colour-configuration cell keeps only single-word colour names
# that are present in the word-vector encoder's vocabulary and switches the
# output file stem to 'shapes_w2v'.
filter_w2v = False
# Model identifier handed to W2VEncode; only consulted when filter_w2v is True.
w2v_model = 'glove-wiki-gigaword-50'

In [3]:
# Names of the regular polygons, ordered by side count starting at three sides.
polys = ['triangle', 'rectangle', 'pentagon', 'hexagon', 'heptagon', 'octagon']

# Maps each polygon name to an (n_sides, 2) array of vertices lying on a
# circle of radius 50 around the origin.
shape_prototypes = {}
for n_sides, poly_name in enumerate(polys, start=3):
    # Evenly spaced vertex angles; odd-sided polygons get an extra quarter
    # turn so that their first vertex points straight up.
    angles = 2*np.pi/n_sides * np.arange(n_sides) + np.pi/2 * (n_sides % 2)
    vertices = 50 * np.column_stack((np.cos(angles), np.sin(angles)))
    # Rounding removes floating-point noise such as 3e-15 in place of 0.
    shape_prototypes[poly_name] = vertices.round(6)

In [4]:
def make_shape(
    shape='rectangle',
    alpha=1.0,
    color='black',
    center=(0, 0),
    scale=(1.0, 1.0),
    skew=(0, 0),
    hatch=None,
    rotation=0,
    shadow=False):
    """Rasterise one shape on a blank, axis-free canvas and return the pixels.

    Parameters
    ----------
    shape : str
        'arrow', 'ellipse', or any key of ``shape_prototypes``.
    alpha : float
        Opacity passed to the Matplotlib patch.
    color : str
        Any Matplotlib colour specification.
    center : tuple of float
        (x, y) position of the shape in data coordinates; the visible window
        spans -100..100 on both axes.
    scale : tuple of float
        (x, y) scale factors.
    skew : tuple of float
        (x, y) skew angles in degrees.
    hatch : str or None
        Matplotlib hatch pattern. When given, the patch is drawn unfilled so
        that the pattern is visible.
    rotation : float
        Rotation in degrees around ``center``.
    shadow : bool
        Whether to add a drop shadow offset by (10, -10).

    Returns
    -------
    numpy.ndarray
        uint8 array of shape (height, width, 3) holding the RGB pixels of the
        0.7 x 0.7 inch figure (70 x 70 pixels in the recorded run).

    Raises
    ------
    AssertionError
        If ``shape`` is not recognised.

    Notes
    -----
    The transform is composed as rotate-about-centre, then scale, then skew.
    Scale and skew act relative to the data origin, so a shape that is not
    centred at (0, 0) is displaced by them as well as deformed.
    """
    center = np.array(center)
    # Style shared by every patch type; a hatch replaces the solid fill.
    patch_style = dict(color=color, alpha=alpha, hatch=hatch, fill=hatch is None)

    fig = plt.figure()
    try:
        ax = fig.add_subplot(111)
        ax.axis('off')

        if shape == 'arrow':
            default_tail_head = np.array([(0, -50), (0, 50)])
            r = patches.FancyArrowPatch(*(default_tail_head + center),
                                        mutation_scale=50, **patch_style)
        elif shape == 'ellipse':
            r = patches.Ellipse(center, 50, 50, **patch_style)
        elif shape in shape_prototypes:
            r = patches.Polygon(shape_prototypes[shape] + center, **patch_style)
        else:
            # Raised explicitly (instead of `assert False`) so the check still
            # fires under `python -O`; the exception type is unchanged.
            raise AssertionError('Shape not recognized')

        t = mpl.transforms.Affine2D().rotate_deg_around(
            *center, rotation).scale(*scale).skew_deg(*skew) + ax.transData

        r.set_transform(t)
        ax.add_patch(r)

        if shadow:
            s = patches.Shadow(r, 10, -10)
            ax.add_patch(s)

        ax.set_xlim(-100, 100)
        ax.set_ylim(-100, 100)
        fig.tight_layout(pad=0)
        fig.set_size_inches(0.7, 0.7)
        fig.canvas.draw()

        # buffer_rgba() replaces tostring_rgb(), which was deprecated in
        # Matplotlib 3.8 and later removed. Drop the alpha channel and copy,
        # because the buffer belongs to the canvas that is closed below.
        rgba = np.asarray(fig.canvas.buffer_rgba())
        return rgba[..., :3].copy()
    finally:
        # Close only the figure created here (not every open figure), and do
        # so even when rendering fails, so figures never accumulate.
        plt.close(fig)
    
#     plt.show()
    

In [5]:
# Stem of the output files written later ('<name>.pt' and '<name>.csv').
name = 'shapes'
# Colour table mapping 'xkcd:<colour name>' to a hex string.
colors = mpl.colors.XKCD_COLORS

if filter_w2v:
    name = 'shapes_w2v'
    # Imported here because the project-local encoder is only needed for the
    # optional vocabulary filter.
    from encoders import W2VEncode
    # Keep the encoder under its own name. Rebinding `w2v_model` (the model
    # identifier string) to the encoder object made this cell fail when it was
    # run a second time, because W2VEncode would then receive an encoder.
    w2v_encoder = W2VEncode(w2v_model)
    # Colour keys carry an 'xkcd:' prefix that has to be stripped before the
    # vocabulary lookup.
    prefix_len = len('xkcd:')
    colors = {
        key: value
        for key, value in colors.items()
        if len(key[prefix_len:].split()) == 1
        and key[prefix_len:] in w2v_encoder.vectorizer
    }

color_names = list(colors.keys())

# Sampling ranges and choices used by the generation loop.
scale_range = [0.5, 1.5]       # per-axis scale factor
skew_range = [5, 30]           # skew magnitude in degrees
rotation_range = [20, 340]     # rotation in degrees
# alpha_range = [0.2, 0.9]
location_range = [-40, 40]     # centre offset in data coordinates
shadows = [True, False]
hatches = ['*', '-', 'o', '.']
shapes = polys + ['arrow', 'ellipse']

In [6]:
# Fixed seed so that Restart & Run All regenerates exactly the same dataset;
# change SEED to draw a different sample.
SEED = 0
# Number of images to render.
N_SAMPLES = 20000
rng = np.random.default_rng(SEED)

data = []        # one attribute row per rendered image
all_shapes = []  # rendered RGB arrays, in the same order as `data`
for _ in trange(N_SAMPLES):
    shape_name = rng.choice(shapes)
    color = rng.choice(color_names)
    r, g, b = mpl.colors.to_rgb(color)

    # Defaults describe an untransformed, centred, solid shape; each
    # attribute below is perturbed independently with its own probability.
    wscale = hscale = 1
    rotation = xskew = yskew = xcenter = ycenter = 0
    hatch = None
    shadow = rng.random() < 0.2

    if rng.random() < 0.3:
        rotation = rng.uniform(*rotation_range)
    if rng.random() < 0.2:
        # Random sign, so shapes lean either way.
        xskew = rng.choice([-1, 1]) * rng.uniform(*skew_range)
    if rng.random() < 0.5:
        if rng.random() < 0.5:
            # Uniform scaling keeps the aspect ratio.
            wscale = hscale = rng.uniform(*scale_range)
        else:
            wscale = rng.uniform(*scale_range)
            hscale = rng.uniform(*scale_range)
    if rng.random() < 0.6:
        # Each axis is shifted independently.
        if rng.random() < 0.5:
            xcenter = rng.uniform(*location_range)
        if rng.random() < 0.5:
            ycenter = rng.uniform(*location_range)
    if rng.random() < 0.2:
        hatch = rng.choice(hatches)

    image = make_shape(shape=shape_name,
                       rotation=rotation,
                       scale=(wscale, hscale),
                       center=(xcenter, ycenter),
                       color=color,
                       shadow=shadow,
                       hatch=hatch,
                       skew=(xskew, yskew))
    all_shapes.append(image)
    # color[5:] drops the 'xkcd:' prefix from the colour name.
    data.append([shape_name, color[5:], r, g, b,
                 wscale, hscale, rotation, xskew,
                 xcenter, ycenter, hatch, shadow])

df = pd.DataFrame(data, columns=[
    'shape', 'color', 'r', 'g', 'b',
    'wscale', 'hscale', 'rotation', 'skew',
    'xcenter', 'ycenter', 'hatch', 'shadow'])

# torch.Tensor(...) converts the stacked uint8 images to torch's default
# floating-point dtype; kept as is so the saved file format does not change.
all_shapes = torch.Tensor(np.array(all_shapes))
print(all_shapes.shape)
torch.save(all_shapes, f'{name}.pt')

100%|██████████| 20000/20000 [08:03<00:00, 41.35it/s]


torch.Size([20000, 70, 70, 3])


In [7]:
# Write the per-image attribute table under the same file stem as the image
# tensor; index=False keeps the DataFrame's row index out of the file.
df.to_csv(f'{name}.csv', index=False)

In [8]:
# Read the file back and display its first rows as a quick check of what was written.
pd.read_csv(f'{name}.csv').head()

Unnamed: 0,shape,color,r,g,b,wscale,hscale,rotation,skew,xcenter,ycenter,hatch,shadow
0,pentagon,deep teal,0.0,0.333333,0.352941,1.0,1.0,0.0,0.0,0.0,0.0,,False
1,pentagon,ruby,0.792157,0.003922,0.278431,1.44306,1.44306,176.249624,0.0,12.474021,22.318936,,False
2,rectangle,indigo blue,0.227451,0.094118,0.694118,1.144956,1.416463,0.0,0.0,0.0,0.0,,False
3,pentagon,maize,0.956863,0.815686,0.329412,1.0,1.0,0.0,0.0,0.0,0.0,,False
4,rectangle,wine,0.501961,0.003922,0.247059,1.0,1.0,0.0,0.0,-2.050478,29.824324,-,False
