In [1]:
import numpy as np

In [2]:
from tensorflow import keras
from PIL import Image
from torchvision import models, transforms
import torch.nn as nn
import torch.optim as optim
import torch
import random
from torchvision import models, transforms
from scipy.spatial import distance_matrix
from tqdm import tqdm
import os

import json
import datetime

import time

In [3]:
import sys
# Relative path prefix (two directories up from the working directory). Besides the
# import path below, it is also prepended to the feature and result file paths later on.
root = '../../'
# Make the project-local HelpfulFunctions package importable from this notebook.
sys.path.append(root)
from HelpfulFunctions.batchCreation import createBatch
from HelpfulFunctions.metrics import meanAveragePrecision

In [4]:
# Device used for the model and every batch.
# The run is pinned to the CPU on purpose; the previous version detected CUDA and
# then silently overwrote the result with "cpu". FORCE_CPU keeps that behaviour but
# makes it an explicit, searchable switch: set it to False to use CUDA when available.
FORCE_CPU = True
device = torch.device("cuda" if (torch.cuda.is_available() and not FORCE_CPU) else "cpu")
print(device)

cpu


In [5]:
# Pre-computed features and labels, stored as .npy files under <root>/Features.
# Going by the file names these are VGG16 features for CIFAR-10 (presumably
# 4096-dimensional, matching the first Linear layer in tripletLoss -- TODO confirm).
# Features are converted to torch tensors; labels stay NumPy arrays.
X_train = torch.tensor( np.load( root + "Features/train_features_vgg16_cifar10.npy" ) )
y_train = np.load( root + "Features/train_labels_vgg16_cifar10.npy" )

# Held-out split, used only for the final evaluation in tripletLoss.
X_test = torch.tensor( np.load( root + "Features/test_features_vgg16_cifar10.npy" ) )
y_test = np.load( root + "Features/test_labels_vgg16_cifar10.npy" )

------

In [6]:
def getAPN(a, pos, dMatrix, margin):
    """Pick one negative for every (anchor, positive) pair of a batch.

    a       -- row index of the anchor in dMatrix
    pos     -- list of row indices treated as positives for the anchor
    dMatrix -- square table of pairwise distances, read as dMatrix[i][j]
    margin  -- width of the band beyond the positive distance in which a
               negative is accepted

    For a positive p, a candidate negative is any index that is neither the
    anchor nor a positive and whose distance d to the anchor satisfies
    dist(a, p) < d <= dist(a, p) + margin. Among the candidates the one
    closest to the anchor is chosen; ties go to the lowest index. Positives
    without any candidate are skipped.

    Returns a list of (anchor, positive, negative) index tuples.
    """
    excluded = [a] + pos
    candidates = [idx for idx in range(len(dMatrix)) if idx not in excluded]
    anchor_row = dMatrix[a]

    triplets = []
    for p in pos:
        d_pos = anchor_row[p]
        upper = d_pos + margin

        # negatives farther away than the positive, but by no more than the margin
        in_band = [idx for idx in candidates if d_pos < anchor_row[idx] <= upper]
        if not in_band:
            continue

        # min() returns the first minimal element, i.e. the lowest index on ties
        n = min(in_band, key=lambda idx: anchor_row[idx])
        triplets.append((a, p, n))

    return triplets

In [7]:
def customBatch(X_train, y_train, pos_sample_bal, batchSize, pos_label):
    """Draw a random batch in which a fixed share of samples carries one label.

    X_train        -- indexable collection of feature tensors
    y_train        -- labels aligned with X_train
    pos_sample_bal -- fraction of the batch taken from samples labelled pos_label
    batchSize      -- total number of samples in the batch
    pos_label      -- label whose samples fill the first part of the batch

    int(batchSize * pos_sample_bal) samples are drawn without replacement from
    the rows labelled pos_label, the remainder from all other rows. random.sample
    raises ValueError if either pool is smaller than the requested amount.

    Returns (X_sample, y_sample): the stacked feature tensor and the list of
    labels, pos_label samples first.
    """
    n_total = len(X_train)
    with_label = [idx for idx in range(n_total) if y_train[idx] == pos_label]
    without_label = list(set(range(n_total)) - set(with_label))

    n_pos = int(batchSize * pos_sample_bal)
    n_neg = batchSize - n_pos

    # positives are sampled before negatives, so the order of RNG calls is fixed
    chosen = random.sample(with_label, n_pos)
    chosen += random.sample(without_label, n_neg)

    X_sample = torch.stack([X_train[idx] for idx in chosen])
    y_sample = [y_train[idx] for idx in chosen]
    return X_sample, y_sample

In [8]:
def tripletLoss(bits, margin, batchSize, pos_sample_bal):
    """Train a small hashing network with a triplet loss and evaluate it.

    A two-layer MLP (4096 -> 1024 -> bits, sigmoid output) is trained for at most
    20000 iterations. In every iteration a batch is drawn, the pairwise Manhattan
    distances between the batch outputs are computed, getAPN mines
    (anchor, positive, negative) triplets for every sample of the batch, and one
    Adam step is taken on nn.TripletMarginLoss (p=1) over those triplets.
    Iterations in which no triplet is found perform no update and log no loss.

    Early stopping: once more than movingAvg_window iterations have run, the mean
    of the most recent movingAvg_window logged losses is tracked. Training stops
    after more than noImprove_breakVal consecutive iterations without a new lowest
    mean. An iteration for which no loss has been logged at all counts as "no
    improvement" (it used to raise ZeroDivisionError).

    Reads the notebook globals X_train, y_train, X_test, y_test and device.

    Parameters
        bits           -- number of output units, i.e. length of the hash code
        margin         -- margin of the triplet loss, also passed to getAPN
        batchSize      -- batch size for customBatch; if None the batch comes from
                          createBatch(X_train, y_train, batchSize) instead
        pos_sample_bal -- share of each customBatch batch drawn from one randomly
                          chosen label

    Returns
        (model, loss_list, map, i, delta_t): the trained model, the loss of every
        update step, the value returned by meanAveragePrecision for the binarised
        test codes against the binarised train codes, the index of the last
        iteration, and the elapsed wall-clock time in seconds.
    """
    t_start = time.time()
    movingAvg_window = 2000      # number of most recent losses that are averaged
    noImprove_breakVal = 2000    # tolerated consecutive iterations without a new best mean

    recent_losses = []           # sliding window used for early stopping
    mean_loss = 0

    model = nn.Sequential(nn.Linear(4096, 1024),
                          nn.ReLU(),
                          nn.Linear(1024, bits),
                          nn.Sigmoid())
    model = model.to(device)

    optimizer = optim.Adam(model.parameters())
    criterion = nn.TripletMarginLoss(p=1,  # Manhattan distance
                                     margin=margin)

    loss_list = []               # full loss history, returned to the caller
    lowest_loss = float("inf")
    no_improves = 0

    y_unique = list(set(y_train))
    for i in tqdm(range(20000)):
        if batchSize is not None:
            xBatch, yBatch = customBatch(X_train, y_train, pos_sample_bal, batchSize,
                                         random.sample(y_unique, 1)[0])
        else:
            xBatch, yBatch = createBatch(X_train, y_train, batchSize)

        xBatch = xBatch.to(device)

        results = model(xBatch)
        # Triplet mining works on a detached NumPy copy; gradients flow through `results`.
        results_np = results.cpu().detach().numpy()
        dMatrix = distance_matrix(results_np, results_np, p=1)

        APN_list = []
        for label in set(yBatch):
            same_label = [j for j in range(len(yBatch)) if yBatch[j] == label]
            for anchor_i in same_label:
                positives = [j for j in same_label if j != anchor_i]
                APN_list += getAPN(anchor_i, positives, dMatrix, margin)

        if APN_list:
            # `results` already lives on `device`, so the stacked rows do as well.
            anchors = torch.stack([results[apn[0]] for apn in APN_list])
            positives = torch.stack([results[apn[1]] for apn in APN_list])
            negatives = torch.stack([results[apn[2]] for apn in APN_list])

            # === Improve Model ===
            optimizer.zero_grad()
            loss = criterion(anchors, positives, negatives)
            loss.backward()
            optimizer.step()

            loss_value = float(loss)
            loss_list.append(loss_value)
            recent_losses.append(loss_value)

        if i > movingAvg_window:
            # Keep only the newest movingAvg_window entries (no-op while shorter).
            del recent_losses[:-movingAvg_window]

            improved = False
            if recent_losses:  # empty when no triplet has been mined so far
                mean_loss = sum(recent_losses) / len(recent_losses)
                improved = mean_loss < lowest_loss

            if improved:
                lowest_loss = mean_loss
                no_improves = 0
            else:
                no_improves += 1

            if no_improves > noImprove_breakVal:
                break

        if (i % 500) == 0:
            print(i, mean_loss , no_improves)

    # Binarise the sigmoid outputs at 0.5. no_grad avoids building an autograd
    # graph for the full train and test sets.
    with torch.no_grad():
        hash_train = (model(X_train.to(device)).cpu().numpy() > 0.5).astype(int)
        hash_test = (model(X_test.to(device)).cpu().numpy() > 0.5).astype(int)
    map_score = meanAveragePrecision(hash_test, hash_train, y_test, y_train)

    delta_t = time.time() - t_start

    return model, loss_list, map_score, i, delta_t

#bits = 32
#margin = random.randint(0,5) # int(bits / 10)
#batchSize = 100
#pos_sample_bal = 0.50

#model, loss_list, map, i, delta_t = tripletLoss(bits, margin, batchSize, pos_sample_bal)

In [9]:
def hpo():
    """Run one random-search trial of tripletLoss and store the outcome as JSON.

    Samples margin in [0, 5], batchSize in [20, 120] and pos_sample_bal in
    [0.05, 0.5] (bits is fixed at 32), trains via tripletLoss, and writes the
    hyper-parameters, loss history, mAP value, last iteration index and run time
    to <root>Results/HPO/TripletLoss/<host>/<timestamp>.json.

    <host> is the COMPUTERNAME environment variable when it is set (Windows);
    otherwise the network name reported by platform.node() is used, so the
    function also works on hosts that do not define COMPUTERNAME.
    """
    import platform  # local import: only needed for the host-name fallback

    bits = 32
    margin = random.uniform(0, 5)
    batchSize = random.randint(20, 120)
    pos_sample_bal = random.uniform(0.05, 0.5)

    # A sampled margin of exactly 0 is replaced by 1.
    if margin == 0: margin = 1

    hp = {"bits":bits,"margin":margin, "batchSize": batchSize, "pos_sample_bal":pos_sample_bal}
    print(hp)

    model, loss_list, map_score, i, delta_t = tripletLoss(bits, margin, batchSize, pos_sample_bal)

    res = {
        "hp": hp,
        "loss_list": loss_list,
        "map": map_score,
        "i": i,
        "delta_t": delta_t,
    }

    # Same "YYYY-MM-DD_HH-MM-SS" format as before, but built with strftime: slicing
    # str(now) at the "." failed whenever the microsecond field happened to be 0.
    now = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")

    host = os.environ.get("COMPUTERNAME") or platform.node() or "unknown-host"

    newpath = f"{root}Results/HPO/TripletLoss/{host}"
    os.makedirs(newpath, exist_ok=True)

    filePath = f"{newpath}/{now}.json"

    with open(filePath, "w") as fp:
        json.dump(res , fp)

In [10]:
# Random-search driver: every hpo() call samples one hyper-parameter set, trains a
# model and writes one JSON result file. The loop has no exit condition, so this
# cell keeps running until the kernel is interrupted.
while True:
    hpo()

{'bits': 32, 'margin': 1.2986597551323598, 'batchSize': 34, 'pos_sample_bal': 0.4820865008961892}


  0%|          | 3/20000 [00:00<11:47, 28.27it/s]

0 0 0


  3%|▎         | 507/20000 [00:16<10:21, 31.38it/s]

500 0 0


  5%|▌         | 1007/20000 [00:32<10:03, 31.45it/s]

1000 0 0


  8%|▊         | 1507/20000 [00:48<09:51, 31.27it/s]

1500 0 0


 10%|█         | 2003/20000 [01:04<09:55, 30.21it/s]

2000 0 0


 13%|█▎        | 2506/20000 [01:20<09:20, 31.19it/s]

2500 1.2194434243887662 0


 15%|█▌        | 3004/20000 [01:36<09:00, 31.46it/s]

3000 1.193735620304942 0


 18%|█▊        | 3506/20000 [01:52<08:47, 31.25it/s]

3500 1.1789282300174235 0


 20%|██        | 4007/20000 [02:08<08:34, 31.07it/s]

4000 1.1735587587058545 0


 23%|██▎       | 4506/20000 [02:24<08:24, 30.72it/s]

4500 1.1794432884156705 283


 25%|██▌       | 5006/20000 [02:40<07:44, 32.28it/s]

5000 1.1800064705610276 783


 28%|██▊       | 5506/20000 [02:56<07:50, 30.83it/s]

5500 1.1676735567450522 40


 30%|███       | 6004/20000 [03:12<07:13, 32.26it/s]

6000 1.16099528414011 237


 33%|███▎      | 6505/20000 [03:28<07:15, 31.01it/s]

6500 1.1456386283785105 0


 35%|███▌      | 7005/20000 [03:44<06:43, 32.23it/s]

7000 1.1269530166685582 66


 38%|███▊      | 7505/20000 [04:00<06:41, 31.11it/s]

7500 1.1358912203907967 566


 40%|████      | 8005/20000 [04:15<06:18, 31.67it/s]

8000 1.114435644492507 51


 43%|████▎     | 8504/20000 [04:31<06:01, 31.76it/s]

8500 1.1172093024551868 551


 45%|████▌     | 9006/20000 [04:48<05:50, 31.41it/s]

9000 1.1050064563006163 0


 48%|████▊     | 9506/20000 [05:04<05:22, 32.55it/s]

9500 1.0830490877181291 165


 50%|█████     | 10006/20000 [05:19<05:20, 31.14it/s]

10000 1.1037763662338256 665


 53%|█████▎    | 10506/20000 [05:36<05:01, 31.52it/s]

10500 1.1016120964586735 1165


 55%|█████▌    | 11006/20000 [05:52<04:44, 31.64it/s]

11000 1.1138787551224232 1665


 57%|█████▋    | 11336/20000 [06:02<04:37, 31.25it/s]
100%|██████████| 10000/10000 [01:14<00:00, 133.52it/s]
