In [1]:
from datetime import datetime
save_index = str('x'+datetime.now().strftime("%m%d%y-%H%M%S"))

import os
os.mkdir(save_index)

In [2]:
DECODER_DIMS = {"text": 400, "logo": 400, "bp": 400, "indus": 400}
ENCODER_DIMS = {"full": 400, "res": 50, "mgr": 200, "design": 200}
K = 20

FOLDS = 4
BATCHES = 5000
ITERS = 10

ADAM_LR = 1e-5
MIN_AF = 1e-6
ANNEALING_BATCHES = 4000
NUM_PARTICLES = 1

CENTER_BP = True

DISABLE_TQDM = False

WEIGHT_DECAY = 0.

In [3]:
import sys
sys.path.insert(1, '../')

import os

import numpy as np
import torch
import torchvision.datasets as dset
import torch.nn as nn
import torchvision.transforms as transforms

import pyro
import pyro.distributions as dist
import pyro.contrib.examples.util  # patches torchvision
from pyro.infer import SVI, Trace_ELBO
from pyro.optim import Adam
from pyro import poutine

pyro.set_rng_seed(42)

import random
random.seed(42)

import pandas as pd

from skimage import io
from sklearn import metrics
from matplotlib import pyplot as plt

from tqdm.auto import tqdm, trange

from sklearn.metrics import classification_report

from data import SplitData
from model import LogoMVAE

assert pyro.__version__.startswith('1.3.0')

from IPython.display import clear_output
import time
from IPython import display

# Helper functions:

In [4]:
def split(a, n):
    k, m = divmod(len(a), n)
    return (a[i * k + min(i, m):(i + 1) * k + min(i + 1, m)] for i in range(n))

def compute_distance(z):
    b = z.reshape(z.shape[0], 1, z.shape[1])
    return np.sqrt(np.einsum('ijk, ijk->ij', z-b, z-b))

# Data Loading

First, load text data, and apply word filter. Note on notation: `tx` stands for "true x," because the model variables are also called x.

In [5]:
textdf = pd.read_csv("../../../data/web_dtfm20_binary.csv", index_col=0)
tx_text = textdf.values
seltext = tx_text.sum(0) > 0.05
tx_text = textdf.values[:,seltext]

gt20words = tx_text.sum(1) > 20
tx_text = tx_text[gt20words,:]

words = textdf.columns[seltext]
N, V = tx_text.shape

binfeats = pd.read_csv("../../../data/y_bin_all_py2.csv", index_col=0)
tx_b = binfeats.values
tx_b = tx_b[gt20words,:]
M_b = tx_b.shape[1]

catfeats = pd.read_csv("../../../data/y_mult_ncolors_py2.csv", index_col=0)

tx_c1 = catfeats.values[:,0][gt20words]
M_c1 = len(np.unique(tx_c1))
tx_c1 = np.expand_dims(tx_c1, 1)

tx_c2 = catfeats.values[:,1][gt20words]
M_c2 = len(np.unique(tx_c2))
tx_c2 = np.expand_dims(tx_c2, 1)

tx_c3 = catfeats.values[:,2][gt20words]
M_c3 = len(np.unique(tx_c3))
tx_c3 = np.expand_dims(tx_c3, 1)

tx_c4 = catfeats.values[:,3][gt20words]
M_c4 = len(np.unique(tx_c4))
tx_c4 = np.expand_dims(tx_c4, 1)

tx_c5 = catfeats.values[:,4][gt20words]
M_c5 = len(np.unique(tx_c5))
tx_c5 = np.expand_dims(tx_c5, 1)

c1_labels = np.array(["black","blue_dark","blue_light","blue_medium","brown","green_dark",
                      "green_light","grey_dark","grey_light","orange","red","red_dark",
                      "yellow"])

c2_labels = np.array(["circle","rect-oval_medium","rect-oval_large","rect-oval_thin",
                      "square","triangle"])

c3_labels = np.array(["bad_letters","bulky_hollow_geometric","circular","dense_simple_geom",
                      "detailed_circle","hollow_circle","detailed_hor","long_hor","no_mark",
                      "simple","square","thin_vert_rect","vert_narrow","detailed","thin",
                      "hor_wispy"])

c4_labels = np.array(["nochars","sans","serif"])

c5_labels = np.array(["one_color","two_colors","three_colors","many_colors"])

bp = pd.read_csv("../../../data/bp_avg_all_traits.csv", index_col=0)

bp_labels = bp.columns

tx_bp = bp.values
tx_bp = tx_bp[gt20words]
if CENTER_BP:
    tx_bp = (tx_bp - tx_bp.mean(0)) / tx_bp.std(0)
M_bp = tx_bp.shape[1]

indus = pd.read_csv("../../../data/industry_codes_b2bc.csv", index_col=0)
indus = indus.iloc[np.in1d(indus.index, bp.index),:]
indus = indus.sort_index()

tx_indus = indus.values.astype('int')
tx_indus = tx_indus[:, tx_indus.sum(0) > 9]
tx_indus = tx_indus[gt20words,:]
M_indus = tx_indus.shape[1]

indus_labels = indus.columns[indus.values.sum(0) > 9]

allnames = binfeats.index.values[gt20words]

x_sizes = {"text": V, 
           "bin": M_b, 
           "cat1": M_c1, 
           "cat2": M_c2, 
           "cat3": M_c3, 
           "cat4": M_c4, 
           "cat5": M_c5, 
           "bp": M_bp, 
           "indus": M_indus, 
           "logo": M_b + M_c1 + M_c2 + M_c3 + M_c4 + M_c5, 
           "all": V + M_b + M_c1 + M_c2 + M_c3 + M_c4 + M_c5 + M_bp + M_indus}

task_sizes = {"full": x_sizes["all"], 
              "res": x_sizes["logo"] + x_sizes["indus"], 
              "design": x_sizes["text"] + x_sizes["bp"] + x_sizes["indus"], 
              "mgr": x_sizes["all"] - x_sizes["bp"]}

noptions = np.array([M_c1, M_c2, M_c3, M_c4, M_c5])

# Training: Instantiate Model and Run

In [6]:
givens = pd.DataFrame(np.concatenate(([[K], list(DECODER_DIMS.values()), list(ENCODER_DIMS.values()), [BATCHES], [ITERS], [ADAM_LR], [ANNEALING_BATCHES], [NUM_PARTICLES], [CENTER_BP], [WEIGHT_DECAY], [FOLDS]]))).T
givens.columns = ["K", "text_dec", "logo_dec", "bp_dec", "indus_dec", "full_enc", "res_enc", "mgr_enc", "des_enc", "batches", "iters", "adam_lr", "annealing_batches", "num_particles", "center_bp", "weight_decay","folds"]

Create holdout and cross-validation subsets (just the indices):

In [7]:
if FOLDS > 1:
    holdout_indices = list(split(np.arange(N), FOLDS))
    holdout_indices.append(np.array([]))
    fold_indices = [np.setdiff1d(np.arange(N), holdout_indices[i]) for i in range(FOLDS)]
    fold_indices.append(np.arange(N))
else:
    holdout_indices = [np.array([])]

Set the KL annealing schedule (same across each fold):

In [8]:
schedule = np.linspace(MIN_AF, 1., ANNEALING_BATCHES)
# schedule = np.concatenate([np.linspace(MIN_AF, 1., round(ANNEALING_BATCHES/4.)) for _ in range(4)])

In [9]:
track_everything = dict()

In [None]:
# Run the model across all folds (sequentially):

fold = FOLDS

for trial in tqdm(range(20)):
    track_z = dict()
    track_neighbors = dict()
    track_names = dict()

    for scale_zero in ["text","logo","bp","indus","full","full2"]:

        domain_scaling = {"text": 1., 
                          "logo": 1., 
                          "bp": 1., 
                          "indus": 1.}

        if scale_zero is not "full" and scale_zero is not "full2":
            domain_scaling[scale_zero] = 1e-8

        pyro.clear_param_store()

        data = SplitData(tx_text, tx_b, tx_c1, tx_c2, tx_c3, tx_c4, tx_c5, tx_bp, tx_indus, 
                         allnames, noptions, test_indices = holdout_indices[fold])   

        lmvae = LogoMVAE(K, ENCODER_DIMS, DECODER_DIMS, x_sizes, task_sizes, use_cuda = True, domain_scaling = domain_scaling)
        optimizer = Adam({"lr": ADAM_LR}) #, "weight_decay": 0.4})
        svi = SVI(lmvae.model, lmvae.guide, optimizer, loss=Trace_ELBO(num_particles = NUM_PARTICLES))

        for i in tqdm(range(BATCHES), desc="Batches", leave=False, disable=DISABLE_TQDM):

            if i < ANNEALING_BATCHES:
                annealing_factor = schedule[i]
            else:
                annealing_factor = 1.

            data.training.shuffle()

            for j in tqdm(range(ITERS), desc="Iters", leave=False, disable=True):
                svi.step(data.training, annealing_factor)

        # Final save of stats
        lmvae.eval()

        lmvae.predict(data.training)

        z = lmvae.pred.z.z_loc.cpu().numpy()
        end_names = data.training.names
        # z_est = z_est[:,z_est.std(0) > 0.5]

        dist_z = compute_distance(z)

        test_firms = ['itw','harman-intl','lilly','goldman-sachs','21st-century-fox','facebook','gucci','old-navy','3m','actavis','mcdonalds', 'kfc']
        test_neighbors = [end_names[dist_z[np.where(end_names == test_firms[i])[0][0],:].argsort()][1:5] for i in range(len(test_firms))]
        test_dist = [np.sort(dist_z[np.where(end_names == test_firms[i])[0][0],:].round(2))[1:5] for i in range(len(test_firms))]
        formatted_neighbors = [", ".join(test_neighbors[i].tolist()) for i in range(len(test_neighbors))]

        neighbors_df = pd.DataFrame(test_neighbors)
        neighbors_df.index = test_firms
        neighbors_df.columns = np.arange(1,5)

        track_z[scale_zero] = z
        track_names[scale_zero] = end_names
        track_neighbors[scale_zero] = neighbors_df
        
    track_everything[trial] = {'z': track_z, 'neighbors': track_neighbors, 'names': track_names}

HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, description='Batches', max=5000.0, style=ProgressStyle(description_wid…

HBox(children=(FloatProgress(value=0.0, description='Batches', max=5000.0, style=ProgressStyle(description_wid…

HBox(children=(FloatProgress(value=0.0, description='Batches', max=5000.0, style=ProgressStyle(description_wid…

HBox(children=(FloatProgress(value=0.0, description='Batches', max=5000.0, style=ProgressStyle(description_wid…

HBox(children=(FloatProgress(value=0.0, description='Batches', max=5000.0, style=ProgressStyle(description_wid…

HBox(children=(FloatProgress(value=0.0, description='Batches', max=5000.0, style=ProgressStyle(description_wid…

HBox(children=(FloatProgress(value=0.0, description='Batches', max=5000.0, style=ProgressStyle(description_wid…

HBox(children=(FloatProgress(value=0.0, description='Batches', max=5000.0, style=ProgressStyle(description_wid…

HBox(children=(FloatProgress(value=0.0, description='Batches', max=5000.0, style=ProgressStyle(description_wid…

HBox(children=(FloatProgress(value=0.0, description='Batches', max=5000.0, style=ProgressStyle(description_wid…

HBox(children=(FloatProgress(value=0.0, description='Batches', max=5000.0, style=ProgressStyle(description_wid…

HBox(children=(FloatProgress(value=0.0, description='Batches', max=5000.0, style=ProgressStyle(description_wid…

HBox(children=(FloatProgress(value=0.0, description='Batches', max=5000.0, style=ProgressStyle(description_wid…

HBox(children=(FloatProgress(value=0.0, description='Batches', max=5000.0, style=ProgressStyle(description_wid…

HBox(children=(FloatProgress(value=0.0, description='Batches', max=5000.0, style=ProgressStyle(description_wid…

HBox(children=(FloatProgress(value=0.0, description='Batches', max=5000.0, style=ProgressStyle(description_wid…

HBox(children=(FloatProgress(value=0.0, description='Batches', max=5000.0, style=ProgressStyle(description_wid…

HBox(children=(FloatProgress(value=0.0, description='Batches', max=5000.0, style=ProgressStyle(description_wid…

HBox(children=(FloatProgress(value=0.0, description='Batches', max=5000.0, style=ProgressStyle(description_wid…

HBox(children=(FloatProgress(value=0.0, description='Batches', max=5000.0, style=ProgressStyle(description_wid…

HBox(children=(FloatProgress(value=0.0, description='Batches', max=5000.0, style=ProgressStyle(description_wid…

HBox(children=(FloatProgress(value=0.0, description='Batches', max=5000.0, style=ProgressStyle(description_wid…

HBox(children=(FloatProgress(value=0.0, description='Batches', max=5000.0, style=ProgressStyle(description_wid…

HBox(children=(FloatProgress(value=0.0, description='Batches', max=5000.0, style=ProgressStyle(description_wid…

HBox(children=(FloatProgress(value=0.0, description='Batches', max=5000.0, style=ProgressStyle(description_wid…

HBox(children=(FloatProgress(value=0.0, description='Batches', max=5000.0, style=ProgressStyle(description_wid…

HBox(children=(FloatProgress(value=0.0, description='Batches', max=5000.0, style=ProgressStyle(description_wid…

HBox(children=(FloatProgress(value=0.0, description='Batches', max=5000.0, style=ProgressStyle(description_wid…

HBox(children=(FloatProgress(value=0.0, description='Batches', max=5000.0, style=ProgressStyle(description_wid…

HBox(children=(FloatProgress(value=0.0, description='Batches', max=5000.0, style=ProgressStyle(description_wid…

HBox(children=(FloatProgress(value=0.0, description='Batches', max=5000.0, style=ProgressStyle(description_wid…

HBox(children=(FloatProgress(value=0.0, description='Batches', max=5000.0, style=ProgressStyle(description_wid…

HBox(children=(FloatProgress(value=0.0, description='Batches', max=5000.0, style=ProgressStyle(description_wid…

HBox(children=(FloatProgress(value=0.0, description='Batches', max=5000.0, style=ProgressStyle(description_wid…

HBox(children=(FloatProgress(value=0.0, description='Batches', max=5000.0, style=ProgressStyle(description_wid…

HBox(children=(FloatProgress(value=0.0, description='Batches', max=5000.0, style=ProgressStyle(description_wid…

HBox(children=(FloatProgress(value=0.0, description='Batches', max=5000.0, style=ProgressStyle(description_wid…

HBox(children=(FloatProgress(value=0.0, description='Batches', max=5000.0, style=ProgressStyle(description_wid…

HBox(children=(FloatProgress(value=0.0, description='Batches', max=5000.0, style=ProgressStyle(description_wid…

HBox(children=(FloatProgress(value=0.0, description='Batches', max=5000.0, style=ProgressStyle(description_wid…

In [None]:
import pickle

tracked_things = track_everything

with open("track_results_big.dat", "wb") as f:
    pickle.dump(tracked_things, f)

In [None]:
import pickle
infile = open("track_results_big.dat",'rb')
tracked_things = pickle.load(infile)
infile.close()

In [None]:
results = pd.DataFrame()

In [27]:
for s in range(20):
    full_z = tracked_things[s]['z']['full']
    dist_full = compute_distance(full_z)

    top10_full = np.array([np.argsort(dist_full[i])[1:11] for i in range(dist_full.shape[0])])
    top10_full_names = tracked_things[s]['names']['full'][top10_full]
    top10_full_names_ordered = top10_full_names[np.argsort(tracked_things[s]['names']['full'])]
    
    full2_z = tracked_things[s]['z']['full2']
    dist_full2 = compute_distance(full2_z)

    top10_full2 = np.array([np.argsort(dist_full2[i])[1:11] for i in range(dist_full2.shape[0])])
    top10_full2_names = tracked_things[s]['names']['full2'][top10_full]
    top10_full2_names_ordered = top10_full2_names[np.argsort(tracked_things[s]['names']['full2'])]

    no_logo_z = tracked_things[s]['z']['logo']
    dist_no_logo = compute_distance(no_logo_z)

    top10_no_logo = np.array([np.argsort(dist_no_logo[i])[1:11] for i in range(dist_no_logo.shape[0])])
    top10_no_logo_names = tracked_things[s]['names']['logo'][top10_no_logo]
    top10_no_logo_names_ordered = top10_no_logo_names[np.argsort(tracked_things[s]['names']['logo'])]

    no_text_z = tracked_things[s]['z']['text']
    dist_no_text = compute_distance(no_text_z)

    top10_no_text = np.array([np.argsort(dist_no_text[i])[1:11] for i in range(dist_no_text.shape[0])])
    top10_no_text_names = tracked_things[s]['names']['text'][top10_no_text]
    top10_no_text_names_ordered = top10_no_text_names[np.argsort(tracked_things[s]['names']['text'])]

    no_bp_z = tracked_things[s]['z']['bp']
    dist_no_bp = compute_distance(no_bp_z)

    top10_no_bp = np.array([np.argsort(dist_no_bp[i])[1:11] for i in range(dist_no_bp.shape[0])])
    top10_no_bp_names = tracked_things[s]['names']['bp'][top10_no_bp]
    top10_no_bp_names_ordered = top10_no_bp_names[np.argsort(tracked_things[s]['names']['bp'])]

    no_indus_z = tracked_things[s]['z']['indus']
    dist_no_indus = compute_distance(no_indus_z)

    top10_no_indus = np.array([np.argsort(dist_no_indus[i])[1:11] for i in range(dist_no_indus.shape[0])])
    top10_no_indus_names = tracked_things[s]['names']['indus'][top10_no_indus]
    top10_no_indus_names_ordered = top10_no_indus_names[np.argsort(tracked_things[s]['names']['indus'])]
    
    if s == 0:
        results = pd.DataFrame({'full2': np.array([np.isin(top10_full2_names_ordered[i], top10_full_names_ordered[i]) for i in range(dist_full2.shape[0])]).sum(1).mean(),
                                'logo': np.array([np.isin(top10_no_logo_names_ordered[i], top10_full_names_ordered[i]) for i in range(dist_full.shape[0])]).sum(1).mean(),
                                'text': np.array([np.isin(top10_no_text_names_ordered[i], top10_full_names_ordered[i]) for i in range(dist_full.shape[0])]).sum(1).mean(),
                                'bp': np.array([np.isin(top10_no_bp_names_ordered[i], top10_full_names_ordered[i]) for i in range(dist_full.shape[0])]).sum(1).mean(),
                                'indus': np.array([np.isin(top10_no_indus_names_ordered[i], top10_full_names_ordered[i]) for i in range(dist_full.shape[0])]).sum(1).mean()},
                               index = [0])
    else:
        temp = pd.DataFrame({'full2': np.array([np.isin(top10_full2_names_ordered[i], top10_full_names_ordered[i]) for i in range(dist_full2.shape[0])]).sum(1).mean(),
                             'logo': np.array([np.isin(top10_no_logo_names_ordered[i], top10_full_names_ordered[i]) for i in range(dist_full.shape[0])]).sum(1).mean(),
                             'text': np.array([np.isin(top10_no_text_names_ordered[i], top10_full_names_ordered[i]) for i in range(dist_full.shape[0])]).sum(1).mean(),
                             'bp': np.array([np.isin(top10_no_bp_names_ordered[i], top10_full_names_ordered[i]) for i in range(dist_full.shape[0])]).sum(1).mean(),
                             'indus': np.array([np.isin(top10_no_indus_names_ordered[i], top10_full_names_ordered[i]) for i in range(dist_full.shape[0])]).sum(1).mean()},
                             index = [s])
        results = results.append(temp)

In [28]:
results

Unnamed: 0,full2,logo,text,bp,indus
0,0.140227,4.668555,2.72238,3.685552,5.847025
1,0.140227,4.668555,2.72238,3.685552,5.847025
2,0.140227,4.668555,2.72238,3.685552,5.847025
3,0.140227,4.668555,2.72238,3.685552,5.847025
4,0.140227,4.668555,2.72238,3.685552,5.847025
5,0.140227,4.668555,2.72238,3.685552,5.847025
6,0.140227,4.668555,2.72238,3.685552,5.847025
7,0.140227,4.668555,2.72238,3.685552,5.847025
8,0.140227,4.668555,2.72238,3.685552,5.847025
9,0.140227,4.668555,2.72238,3.685552,5.847025


In [26]:
tracked_things[1]

{'z': {'text': array([[-0.98406446, -1.3058614 ,  0.5451891 , ...,  0.76194596,
           0.0218945 , -0.92350864],
         [ 0.23579463,  0.31325474, -1.6017869 , ...,  0.44085103,
           0.31370008,  0.89894575],
         [-0.2747689 ,  0.1701757 ,  0.6431251 , ...,  0.13568465,
           0.12741002,  1.7813652 ],
         ...,
         [ 0.74947035, -0.03597721,  1.924873  , ..., -0.8110901 ,
          -1.2328434 , -0.9932321 ],
         [-1.7213477 ,  0.14785708,  1.7714406 , ...,  0.75371766,
          -0.66321766, -0.5214674 ],
         [ 0.6025658 ,  0.02221411, -2.090066  , ..., -0.75964963,
          -0.02249207,  1.3496554 ]], dtype=float32),
  'logo': array([[-0.54542446, -1.0141017 ,  0.30068758, ..., -2.1954021 ,
           1.5750016 , -2.1292183 ],
         [-0.48202243, -1.0954057 ,  0.8479999 , ..., -2.1275017 ,
           0.01437005,  0.51638556],
         [ 0.5931289 ,  0.25300363, -0.46347943, ..., -0.02884687,
          -0.21792991, -0.12029386],
         ...