In [1]:
import os
import sys
sys.path.insert(1, '../')

# --- Model sizes -------------------------------------------------------
# Per-modality decoder sizes handed to LogoMVAE
# (presumably hidden-layer widths -- confirm in model.py).
DECODER_DIMS = dict(text=400, bin=400, cat=400, bp=400, indus=400)
# Per-network encoder sizes handed to LogoMVAE (same caveat as above).
ENCODER_DIMS = dict(full=400, res=200, mgr=200, design=200)

# First positional argument of LogoMVAE (presumably the latent dimension -- confirm).
K = 20

# --- Training loop -----------------------------------------------------
FOLDS = 4        # number of hold-out splits prepared for cross-validation
BATCHES = 4000   # outer training iterations per fold
ITERS = 10       # svi.step calls per outer iteration

# --- Optimisation ------------------------------------------------------
ADAM_LR = 1e-5            # recorded in `givens`
MIN_AF = 1e-6             # first value of the annealing schedule
ANNEALING_BATCHES = 3500  # number of batches the annealing schedule spans
NUM_PARTICLES = 1         # num_particles passed to Trace_ELBO

# Standardise the bp columns (subtract column mean, divide by column std) when True.
CENTER_BP = True

WEIGHT_DECAY = 0.0        # recorded in `givens`

# Set to True to silence the fold/batch progress bars.
DISABLE_TQDM = False

In [2]:
import numpy as np
import torch
import torchvision.datasets as dset
import torch.nn as nn
import torchvision.transforms as transforms

import pyro
import pyro.distributions as dist
import pyro.contrib.examples.util  # patches torchvision
from pyro.infer import SVI, Trace_ELBO
from pyro.optim import Adam
from pyro import poutine

# Fix Pyro's random seed so repeated runs start from the same RNG state.
pyro.set_rng_seed(42)

# Seed Python's built-in `random` module with the same value.
import random
random.seed(42)

import pandas as pd

from skimage import io
from sklearn import metrics
from matplotlib import pyplot as plt

from tqdm.auto import tqdm, trange

from sklearn.metrics import classification_report

# Project-local modules, found via the sys.path entry added in the first cell.
from data import SplitData
from model import LogoMVAE

# Fail fast unless the installed Pyro version string starts with 1.3.0.
assert pyro.__version__.startswith('1.3.0')



## Helper functions


def split(a, n):
    """Cut `a` into `n` consecutive slices whose lengths differ by at most one.

    The first ``len(a) % n`` slices receive one extra element. When `n`
    exceeds ``len(a)`` the trailing slices are empty.

    Args:
        a: Any sliceable sequence with a length (list, numpy array, ...).
        n: Number of slices to produce.

    Returns:
        A generator yielding the `n` slices in order.
    """
    base, extra = divmod(len(a), n)

    def _boundary(idx):
        # Start offset of slice `idx`; the first `extra` slices are one longer.
        return idx * base + min(idx, extra)

    return (a[_boundary(idx):_boundary(idx + 1)] for idx in range(n))

def compute_distance(z):
    """Return the pairwise Euclidean distances between the rows of `z`.

    Args:
        z: 2-D array; each row is one point.

    Returns:
        Array `d` with ``d[i, j] = sqrt(sum_k (z[j, k] - z[i, k]) ** 2)``.
    """
    n_rows, n_dims = z.shape[0], z.shape[1]
    # Insert a singleton axis so that broadcasting forms every row pair.
    pivots = z.reshape(n_rows, 1, n_dims)
    diff = z - pivots
    squared = np.einsum('ijk, ijk->ij', diff, diff)
    return np.sqrt(squared)





## Data Loading

# First, load text data, and apply word filter. Note on notation: `tx` stands for "true x," because the model variables are also called x.

# Document-term table; the first CSV column is used as the row index.
textdf = pd.read_csv("../../../data/web_dtfm20_binary.csv", index_col=0)

# Word filter: keep columns whose total exceeds 0.05.
# NOTE(review): if the entries are 0/1 (as the file name suggests), a column
# *sum* above 0.05 only drops words that never occur; a frequency cut-off
# would need a column mean instead -- confirm which was intended.
seltext = textdf.values.sum(axis=0) > 0.05
words = textdf.columns[seltext]

# Row filter: keep rows whose total over the retained words exceeds 20.
# This mask is reused to subset every other data table below.
tx_text = textdf.values[:, seltext]
gt20words = tx_text.sum(axis=1) > 20
tx_text = tx_text[gt20words, :]

# N = number of retained rows, V = number of retained words.
N, V = tx_text.shape

# Binary feature table, restricted to the rows kept by the text filter.
# The mask is applied by position, so this assumes the file has the same row
# order as the text table -- TODO confirm.
binfeats = pd.read_csv("../../../data/y_bin_all_py2.csv", index_col=0)
tx_b = binfeats.values[gt20words, :]
M_b = tx_b.shape[1]  # number of binary features

catfeats = pd.read_csv("../../../data/y_mult_ncolors_py2.csv", index_col=0)


def _load_cat_column(col):
    """Return (values as an (n, 1) array, number of distinct values) for one column.

    The column is taken from `catfeats` by position and restricted to the rows
    selected by `gt20words`. The distinct-value count is taken after filtering.
    """
    column = catfeats.values[:, col][gt20words]
    n_levels = len(np.unique(column))
    return np.expand_dims(column, 1), n_levels


# Five categorical features: tx_c<k> holds the codes, M_c<k> the number of levels.
(tx_c1, M_c1), (tx_c2, M_c2), (tx_c3, M_c3), (tx_c4, M_c4), (tx_c5, M_c5) = (
    _load_cat_column(col) for col in range(5)
)

# Human-readable names for the levels of each categorical feature
# (presumably index-aligned with the codes in tx_c1..tx_c5 -- TODO confirm).

# Feature 1: colour names.
c1_labels = np.array([
    "black",
    "blue_dark",
    "blue_light",
    "blue_medium",
    "brown",
    "green_dark",
    "green_light",
    "grey_dark",
    "grey_light",
    "orange",
    "red",
    "red_dark",
    "yellow",
])

# Feature 2: outline shapes.
c2_labels = np.array([
    "circle",
    "rect-oval_medium",
    "rect-oval_large",
    "rect-oval_thin",
    "square",
    "triangle",
])

# Feature 3: mark descriptions.
c3_labels = np.array([
    "bad_letters",
    "bulky_hollow_geometric",
    "circular",
    "dense_simple_geom",
    "detailed_circle",
    "hollow_circle",
    "detailed_hor",
    "long_hor",
    "no_mark",
    "simple",
    "square",
    "thin_vert_rect",
    "vert_narrow",
    "detailed",
    "thin",
    "hor_wispy",
])

# Feature 4: typeface classes.
c4_labels = np.array([
    "nochars",
    "sans",
    "serif",
])

# Feature 5: colour-count buckets.
c5_labels = np.array([
    "one_color",
    "two_colors",
    "three_colors",
    "many_colors",
])

# bp trait table; column names double as the trait labels.
bp = pd.read_csv("../../../data/bp_avg_all_traits.csv", index_col=0)
bp_labels = bp.columns

# Keep only the rows selected by the text filter (positional mask).
tx_bp = bp.values[gt20words]
if CENTER_BP:
    # Column-wise z-score, computed on the retained rows only.
    tx_bp = (tx_bp - tx_bp.mean(axis=0)) / tx_bp.std(axis=0)
_, M_bp = tx_bp.shape  # number of bp traits

# Industry-code table: keep only rows whose index also appears in `bp`,
# then sort by index.
# The positional mask below assumes the sorted rows line up with the other
# tables -- TODO confirm.
indus = pd.read_csv("../../../data/industry_codes_b2bc.csv", index_col=0)
# np.isin replaces the deprecated np.in1d.
indus = indus.iloc[np.isin(indus.index, bp.index), :]
indus = indus.sort_index()

# Keep industry columns with more than 9 positive entries. The mask is
# computed once, on the integer-cast values, and shared by the data matrix and
# its labels so the two cannot drift apart.
indus_int = indus.values.astype('int')
keep_indus = indus_int.sum(axis=0) > 9

# Column filter uses all rows; the row filter is applied afterwards.
tx_indus = indus_int[:, keep_indus]
tx_indus = tx_indus[gt20words, :]
M_indus = tx_indus.shape[1]  # number of retained industry columns

indus_labels = indus.columns[keep_indus]

# Row labels of the retained observations, taken from the binary-feature index.
allnames = binfeats.index.values[gt20words]

# Number of levels of each categorical feature, in feature order.
_cat_sizes = [M_c1, M_c2, M_c3, M_c4, M_c5]
# "logo" = binary features plus all categorical levels.
_logo_size = M_b + sum(_cat_sizes)

# Width of each data block, plus two aggregate widths ("logo" and "all").
x_sizes = {
    "text": V,
    "bin": M_b,
    "cat1": M_c1,
    "cat2": M_c2,
    "cat3": M_c3,
    "cat4": M_c4,
    "cat5": M_c5,
    "bp": M_bp,
    "indus": M_indus,
    "logo": _logo_size,
    "all": V + _logo_size + M_bp + M_indus,
}

# Width associated with each task name, built from the block widths above:
#   full   = every block
#   res    = logo + industry
#   design = text + bp + industry
#   mgr    = every block except bp
task_sizes = {
    "full": x_sizes["all"],
    "res": x_sizes["logo"] + x_sizes["indus"],
    "design": x_sizes["text"] + x_sizes["bp"] + x_sizes["indus"],
    "mgr": x_sizes["all"] - x_sizes["bp"],
}

noptions = np.array(_cat_sizes)


## Training: Instantiate Model and Run

# One-row table recording the run's hyperparameters. Every value is stored as
# a float (so CENTER_BP appears as 1.0 / 0.0).
# The decoder/encoder columns follow the insertion order of DECODER_DIMS and
# ENCODER_DIMS: "logo_enc" holds ENCODER_DIMS["res"] and "des_enc" holds
# ENCODER_DIMS["design"].
_given_names = [
    "K",
    "text_dec", "bin_dec", "cat_dec", "bp_dec", "indus_dec",
    "full_enc", "logo_enc", "mgr_enc", "des_enc",
    "batches", "iters", "adam_lr", "annealing_batches",
    "num_particles", "center_bp", "weight_decay",
]
_given_values = [
    K,
    *DECODER_DIMS.values(),
    *ENCODER_DIMS.values(),
    BATCHES, ITERS, ADAM_LR, ANNEALING_BATCHES,
    NUM_PARTICLES, CENTER_BP, WEIGHT_DECAY,
]
givens = pd.DataFrame(
    np.asarray(_given_values, dtype=float).reshape(1, -1),
    columns=_given_names,
)


# Create holdout and cross-validation subsets (just the indices).
#
# holdout_indices[f] holds the row indices withheld in fold f, and
# fold_indices[f] holds the complementary training rows. The final entry of
# each list is the "no holdout" case: an empty holdout and all N rows.
# NOTE(review): np.array([]) has float dtype; this relies on SplitData treating
# an empty test_indices as "no test set" -- confirm in data.py.
if FOLDS > 1:
    holdout_indices = list(split(np.arange(N), FOLDS))
    holdout_indices.append(np.array([]))
    fold_indices = [np.setdiff1d(np.arange(N), holdout_indices[i]) for i in range(FOLDS)]
    fold_indices.append(np.arange(N))
else:
    holdout_indices = [np.array([])]
    # Define the training indices here too, so that fold_indices exists
    # whatever FOLDS is (previously it was undefined on this branch).
    fold_indices = [np.arange(N)]
    
    
# KL annealing schedule, shared by every fold: ANNEALING_BATCHES evenly spaced
# factors rising from MIN_AF to 1.
schedule = np.linspace(start=MIN_AF, stop=1.0, num=ANNEALING_BATCHES)
# Disabled alternative: four consecutive ramps, each a quarter as long.
# schedule = np.concatenate([np.linspace(MIN_AF, 1., round(ANNEALING_BATCHES/4.)) for _ in range(4)])

In [3]:
# Per-checkpoint histories filled by the training loop. Each name gets its own
# independent empty list.
(
    track_training,  # training loss
    track_test,      # test loss
    track_mgr_bp,    # mean bp MSE, "mgr" network
    track_des_bin,   # macro-average F1 on binary features, "des" network
    track_res_bp,    # mean bp MSE, "res" network
) = ([] for _ in range(5))

In [None]:
# Train the model fold by fold (sequentially).
# NOTE(review): the fold loop covers a single iteration, so only fold 0 is
# trained even though FOLDS hold-out splits were prepared -- confirm this is
# intentional.
# NOTE(review): the optimiser is Adam starting at lr=1e-7 under an
# ExponentialLR scheduler (gamma=1.1) that is stepped once per checkpoint;
# ADAM_LR and WEIGHT_DECAY are not used here. This looks like a
# learning-rate sweep -- confirm.
for fold in trange(1, desc="Folds", disable=DISABLE_TQDM):

    # Start the fold from an empty Pyro parameter store.
    pyro.clear_param_store()

    data = SplitData(
        tx_text, tx_b, tx_c1, tx_c2, tx_c3, tx_c4, tx_c5, tx_bp, tx_indus,
        allnames, noptions,
        test_indices=holdout_indices[fold],
    )

    # A `test` attribute on the split signals that a hold-out set exists.
    has_test = hasattr(data, 'test')
    if has_test:
        data.test.make_torch()

    lmvae = LogoMVAE(K, ENCODER_DIMS, DECODER_DIMS, x_sizes, task_sizes, use_cuda=True)

    scheduler = pyro.optim.ExponentialLR({
        'optimizer': torch.optim.Adam,
        'optim_args': {'lr': 1e-7},
        'gamma': 1.1,
    })
    svi = SVI(
        lmvae.model,
        lmvae.guide,
        scheduler,
        loss=Trace_ELBO(num_particles=NUM_PARTICLES),
    )

    for i in trange(BATCHES, desc="Batches", leave=False, disable=DISABLE_TQDM):

        # Follow the annealing schedule while it lasts, then hold the factor at 1.
        annealing_factor = schedule[i] if i < ANNEALING_BATCHES else 1.

        data.training.shuffle()

        # ITERS optimisation steps on the freshly shuffled training data.
        for j in trange(ITERS, desc="Iters", leave=False, disable=True):
            svi.step(data.training, annealing_factor)

        # Checkpoint every 50th batch and on the final batch; skip otherwise.
        if i % 50 != 0 and i != BATCHES - 1:
            continue

        track_training.append(svi.evaluate_loss(data.training, annealing_factor))

        if has_test:
            track_test.append(svi.evaluate_loss(data.test, annealing_factor))

            # Switch to evaluation mode for the held-out predictions.
            lmvae.eval()

            # "res" network: mean bp MSE across features.
            lmvae.predict(data.test, network="res")
            track_res_bp.append(lmvae.pred.metrics.bp_mse.features.mean())

            # "des" network: macro-average F1 on the binary features.
            # NOTE(review): the size dicts use the key "design", not "des" --
            # confirm which name LogoMVAE.predict expects.
            lmvae.predict(data.test, network="des")
            track_des_bin.append(lmvae.pred.metrics.bin_report['macro avg']['f1-score'])

            # "mgr" network: mean bp MSE across features.
            lmvae.predict(data.test, network="mgr")
            track_mgr_bp.append(lmvae.pred.metrics.bp_mse.features.mean())

            # Back to training mode.
            lmvae.train()

        # Advance the learning-rate schedule once per checkpoint.
        scheduler.step()

HBox(children=(FloatProgress(value=0.0, description='Folds', max=1.0, style=ProgressStyle(description_width='i…

HBox(children=(FloatProgress(value=0.0, description='Batches', max=4000.0, style=ProgressStyle(description_wid…