In [1]:
# comet_ml must be imported before torch so its auto-logging hooks attach.
from comet_ml import OfflineExperiment

import argparse
import json
import os
import random

import numpy as np
import torch
from torch.utils.data import DataLoader

# Local wildcard imports keep their original relative order, and stay after the
# explicit imports above, so any names they re-export shadow exactly as before.
from models import *
from models.clustering import *
from utils.ali_utils import *
from utils.utils import *
from utils.utils import load_datasets
from utils.constants import Constants
from data.dataset import HoromaDataset




In [2]:



# Resolve the experiment configuration, seed every RNG, and open the Comet run.
datapath = Constants.DATAPATH

# `config_key` selects the section of the JSON config file; `config` is the
# run label used in the Comet experiment name.
config_key = 'HALI'
config = 'HALI'

with open(Constants.CONFIG_PATH, 'r') as f:
    configuration = json.load(f)[config_key]

# Parse configuration file
# NOTE(review): `clustering_model` is reassigned to SVMClustering(seed) in the
# data-loading cell, so this configured value is currently ignored.
clustering_model = configuration['cluster_model']
encoding_model = configuration['enc_model']
batch_size = configuration['batch_size']
seed = configuration['seed']
n_epochs = configuration['n_epochs']
train_subset = configuration['train_subset']
train_split = configuration['train_split']
valid_split = configuration['valid_split']
train_labeled_split = configuration['train_labeled_split']
encode = configuration['encode']  # train the encoder (True) or restore it from checkpoints (False)
cluster = configuration['cluster']  # fit/evaluate the clustering model on embeddings
flattened = False  # Default
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Set all seeds for full reproducibility: Python's `random` was previously
# left unseeded despite this intent.
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
# Trade cuDNN autotuning speed for deterministic kernels.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

latent_dim = configuration['Zdim']

experiment = OfflineExperiment(project_name='general',
                               workspace='timothynest',  # Replace this with appropriate comet workspace
                               offline_directory="experiments")
experiment.set_name(
    name=config + "_dim={}_overlapped={}".format(latent_dim, train_split))
experiment.log_parameters(configuration)

In [130]:

# Initialize necessary objects: the clustering model plus the training,
# labeled and validation splits, all sharing the same subset/flattening.
clustering_model = SVMClustering(seed)

split_options = {'subset': train_subset, 'flattened': flattened}
train = HoromaDataset(datapath, split=train_split, **split_options)
labeled = HoromaDataset(datapath, split=train_labeled_split, **split_options)
valid_data = HoromaDataset(datapath, split=valid_split, **split_options)

# NOTE(review): despite their names these hold each dataset's `targets`
# (class labels), not positional indices.
train_label_indices = labeled.targets
valid_indices = valid_data.targets

for shape_caption, shown_dataset in (("Shape of training set: ", train),
                                     ("Shape of validation set: ", valid_data)):
    print(shape_caption, shown_dataset.data.shape)

Shape of training set:  (152228, 3, 32, 32)
Shape of validation set:  (252, 3, 32, 32)


array([ 6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,
        6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,
        6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,
        6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6, 11, 11, 11, 11, 16, 16,
       16, 16, 16, 16, 16, 16, 16, 16,  5,  5,  5,  5,  5,  5,  5,  5,  5,
        5,  5,  5,  5,  5,  5,  5, 13, 13, 13, 13, 13, 13, 13, 13,  0,  0,
        0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
        0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
        0, 14, 14, 14, 14, 14, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15,
       15, 15,  9,  9,  9,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,
       10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10,
       10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10,  1,  1,  1,  1,  1,
        1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  4,  4,  4,
        4,  4,  4,  4,  4

In [137]:
# Shuffled preview loader over the labeled split (6 samples per batch),
# used only to inspect a single batch below.
loader = DataLoader(
    labeled,
    shuffle=True,
    batch_size=6,
)

In [141]:
# Pull one (images, labels) batch from the preview loader.
first_batch = next(iter(loader))
img, l = first_batch

In [142]:
# Shape of the preview batch's label tensor (`.shape` is the same torch.Size as `.size()`).
l.shape

torch.Size([6, 1])

ok


In [None]:

# Train (encode=True) or restore from checkpoints (encode=False) the ALI/HALI
# encoder, then optionally fit the clustering model on the embeddings,
# evaluate it on the validation split, and save it.
#
# Fixes vs. the previous version of this cell:
#   * the config dict is `configuration`; `configs` was never defined;
#   * embeddings were taken at indices equal to class labels (`.targets`),
#     and validation embeddings were drawn from `labeled` instead of `valid_data`;
#   * `train_labeled_enc` was undefined (now `train_enc`);
#   * the saved clustering model path was missing a separator after MODEL_PATH.


def _checkpoint_path(name):
    """Return the checkpoint path of generator `name` at `load_from_epoch`."""
    return os.path.join(configuration['MODEL_PATH'],
                        '{}-{}.pth'.format(name, configuration['load_from_epoch']))


def _restore(module, name):
    """Load the saved state dict of generator `name` into `module` in place."""
    # map_location='cpu' lets GPU-saved checkpoints load on CPU-only machines;
    # load_state_dict then copies the weights onto the module's own device.
    module.load_state_dict(torch.load(_checkpoint_path(name), map_location='cpu'))


if encoding_model == "hali":
    (Gx1, Gx2, Gz1, Gz2, Disc, z_pred1, z_pred2,
     optim_g, optim_d, train_loader, cuda) = initialize_hali(configuration, train)
    if encode:
        # Train and apply encoding model
        training_loop_hali(Gz1, Gz2, Gx1, Gx2, Disc, optim_d, optim_g, train_loader,
                           configuration, experiment, cuda, z_pred1, z_pred2)
    else:
        _restore(Gz1, 'Gz1')
        _restore(Gz2, 'Gz2')
        _restore(Gx1, 'Gx1')
        _restore(Gx2, 'Gx2')
else:  # default to ALI
    Gx, Gz, Disc, z_pred, optim_g, optim_d, train_loader, cuda = initialize_ali(configuration, train)
    if encode:
        training_loop_ali(Gz, Gx, Disc, optim_d, optim_g, train_loader,
                          configuration, experiment, cuda, z_pred)
    else:
        _restore(Gz, 'Gz')
        _restore(Gx, 'Gx')

if cluster:
    if encoding_model != "hali":
        # Previously this path crashed with a NameError on `train_enc`.
        raise NotImplementedError(
            "Clustering is only implemented for the 'hali' encoding model, "
            "got {!r}".format(encoding_model))

    # Encode every labeled training sample and every validation sample.
    # Integer-array indexing of the datasets is kept from the original call;
    # the labels are the datasets' own targets in that same order.
    train_enc, train_labels = get_hali_embeddings(
        Gz1, Gz2, labeled[np.arange(len(labeled))])
    valid_enc, val_labels = get_hali_embeddings(
        Gz1, Gz2, valid_data[np.arange(len(valid_data))])

    # Train and apply clustering model
    clustering_model.train(train_enc)
    cluster_labels = assign_labels_to_clusters(clustering_model, train_enc,
                                               labeled.targets)
    _, accuracy, f1 = eval_model_predictions(clustering_model, valid_enc,
                                             valid_data.targets, cluster_labels)
    experiment.log_metric('accuracy', accuracy)
    experiment.log_metric('f1-score', f1)

    # Save models
    model = {'cluster': clustering_model, 'cluster_labels': cluster_labels}
    torch.save(model, os.path.join(configuration['MODEL_PATH'],
                                   str(experiment.get_key()) + '.pth'))

