In [1]:
import torch
import numpy as np
import time
import gc
from scipy import stats
from torch.multiprocessing import Pool
# Debug-sized stand-ins for the real inputs of `select`.
dim = 10 * 100  # size of the last axis of the candidate tensor (X.shape[-1])
rank = 10  # size of the second-to-last axis of the candidate tensor (X.shape[-2])
# Random matrices on the current CUDA device; they only exercise the code path.
fisher = torch.rand((dim, dim)).cuda()
currentInv = torch.rand((dim, dim)).cuda()
# Candidate pool: 100 items, each a (rank, dim) matrix.
xt = torch.rand((100, rank, dim)).cuda()
n = 10  # number of items to query (passed as K to `select`)
init = []  # passed as `iterates` to `select`
n_pool = 100  # debug pool size, matches xt.shape[0]
# Labelled mask is all False, so np.sum(idxs_lb) == 0 is what gets passed as nLabeled.
idxs_lb = np.zeros(n_pool, dtype=bool)
# NOTE(review): a chunk of 50 rows is half of the 100-row pool, presumably chosen
# for a 2-GPU machine -- confirm against the hardware actually used.
chunkSize = 50
NUM_GPUS = torch.cuda.device_count()
# force=True replaces a start method that may already have been set in this kernel.
torch.multiprocessing.set_start_method('spawn', force=True)

In [2]:
def betterSlice(num_gpus, gpu_id, total_len):
    """Return the slice of ``range(total_len)`` assigned to worker ``gpu_id``.

    The ``num_gpus`` slices are consecutive, do not overlap and together cover
    every index exactly once; when ``total_len`` is not divisible by
    ``num_gpus`` the shard lengths differ by at most one.

    Args:
        num_gpus: number of shards to split the range into (must be non-zero).
        gpu_id: zero-based shard index, expected in ``range(num_gpus)``.
        total_len: number of rows being partitioned.

    Returns:
        ``slice(lower, upper)`` with integer bounds.

    Raises:
        ZeroDivisionError: if ``num_gpus`` is 0.
    """
    # Integer arithmetic makes shard g end exactly where shard g + 1 begins.
    # The previous formula, total_len / (num_gpus - gpu_id), only produced a
    # valid partition for one or two GPUs.
    lower_bound = (gpu_id * total_len) // num_gpus
    upper_bound = ((gpu_id + 1) * total_len) // num_gpus
    return slice(int(lower_bound), int(upper_bound))

def trace_for_chunk(xt_, rank, num_gpus, chunkSize, currentInv, fisher, total_len, gpu_id):
    """Score every row of one shard of candidates on a single device.

    For each candidate matrix ``x`` in ``xt_`` this computes

        trace(x @ currentInv @ fisher @ currentInv @ x^T
              @ inverse(I_rank + x @ currentInv @ x^T))

    processing ``chunkSize`` candidates at a time to bound device memory.

    Args:
        xt_: shard of candidates, indexed along its first axis; each entry is
            expected to be a (rank, dim) matrix.
        rank: size of the identity added before inversion.
        num_gpus: kept for interface compatibility; no longer used for sizing.
        chunkSize: number of candidates processed per step (must be > 0).
        currentInv: (dim, dim) matrix, moved to the worker device.
        fisher: (dim, dim) matrix, moved to the worker device.
        total_len: kept for interface compatibility; no longer used for sizing.
        gpu_id: index of the CUDA device to compute on.

    Returns:
        1-D CPU tensor with one score per row of ``xt_``.
    """
    # Fall back to the CPU when CUDA is unavailable so the routine can also be
    # run (and checked) on a machine without GPUs.
    if torch.cuda.is_available():
        device = torch.device("cuda", gpu_id)
    else:
        device = torch.device("cpu")

    n_rows = len(xt_)
    # Size the output from the shard itself: shards may differ in length when
    # total_len is not divisible by num_gpus.
    traceEst = torch.zeros(n_rows)

    # Loop-invariant transfers are done once instead of once per chunk.
    fisher = fisher.to(device)
    currentInv = currentInv.to(device)
    eye = torch.eye(rank, device=device)

    with torch.no_grad():
        # Start at 0 and step by chunkSize. The former range(len(xt_), chunkSize)
        # used the shard length as the *start*, so the body was normally skipped
        # and the function returned all zeros.
        for c_idx in range(0, n_rows, chunkSize):
            xt_chunk = xt_[c_idx : c_idx + chunkSize].to(device)
            xt_chunk_t = xt_chunk.transpose(1, 2)
            # Shared prefix of both products below.
            x_inv = xt_chunk @ currentInv
            innerInv = torch.inverse(eye + x_inv @ xt_chunk_t)
            traceEst[c_idx : c_idx + chunkSize] = torch.diagonal(
                x_inv @ fisher @ currentInv @ xt_chunk_t @ innerInv,
                dim1=-2,
                dim2=-1,
            ).sum(-1).cpu()
    return traceEst

def select(X, K, fisher, iterates, savefile, alg, lamb=1, backwardSteps=0, nLabeled=0, chunkSize=200):
    '''
    Greedily pick K rows of X, scoring the candidates on every visible GPU.

    K is the number of images to be selected for labelling,
    iterates is the fisher for images that are already labelled

    Parameters:
        X: candidate tensor indexed along its first axis; X.shape[-2] is used
            as `rank` and X.shape[-1] as `dim`.
        K: number of indices to return.
        fisher: (dim, dim) matrix used in the trace score.
        iterates: currently unused, because the inverse initialisation that
            reads it is commented out below.
        savefile, alg: only referenced by the commented-out save_dist_stats call.
        lamb: only referenced by the commented-out inverse initialisation.
        backwardSteps: the forward pass picks (backwardSteps + 1) * K indices,
            the backward loop then removes indices until K remain.
        nLabeled: number of already labelled items; X is scaled by
            sqrt(K / (nLabeled + K)).
        chunkSize: upper bound on rows scored per step, capped at len(X).

    Returns:
        List of selected indices into X.
    '''
    dim = X.shape[-1]
    rank = X.shape[-2]
    indsAll = []
    # Random placeholder; the intended initialisation is kept commented out below.
    currentInv = torch.rand((dim, dim)).cuda()
    #currentInv = torch.inverse(lamb * torch.eye(dim).cuda() + iterates.cuda() * nLabeled / (nLabeled + K))
    X = X * np.sqrt(K / (nLabeled + K))
    chunkSize = min(X.shape[0], chunkSize)
    total_len = X.shape[0]
    NUM_GPUS = torch.cuda.device_count()
    # One copy of the fisher matrix and one shard of the candidates per GPU.
    fishers = [fisher.clone().detach().cuda(x) for x in range(NUM_GPUS)]
    xts = [X[betterSlice(NUM_GPUS, x, total_len)].clone().detach().cuda(x) for x in range(NUM_GPUS)]
    torch.multiprocessing.set_start_method('spawn', force=True)
    distStats = []

    # forward selection
    with Pool(processes=NUM_GPUS) as pool:
        for i in range(int((backwardSteps + 1) * K)):
            # The inverse changes after every pick, so each GPU gets a fresh copy.
            cInvs = [currentInv.clone().detach().cuda(x) for x in range(NUM_GPUS)]
            args = [(xts[x], rank, NUM_GPUS, chunkSize, cInvs[x], fishers[x], total_len, x) for x in range(NUM_GPUS)]
            tE = pool.starmap(trace_for_chunk, args)
            # starmap returns results in argument order, i.e. in GPU order.
            traceEst = torch.cat(tE).detach().cpu().numpy()
            del tE
            gc.collect()
            torch.cuda.empty_cache()

            dist = traceEst - np.min(traceEst) + 1e-10
            dist = dist / np.sum(dist)
            distStats.append([np.min(dist), np.max(dist), np.std(dist)])
            sampler = stats.rv_discrete(values=(np.arange(len(dist)), dist))
            # The sampled index is only kept when every index is already chosen;
            # otherwise the best-scoring index not yet selected replaces it.
            ind = sampler.rvs(size=1)[0]
            for j in np.argsort(dist)[::-1]:
                if j not in indsAll:
                    ind = j
                    break

            indsAll.append(ind)
            # Update the inverse on GPU 0 with the newly selected row.
            temp_xt = X[ind].unsqueeze(0).cuda()
            innerInv = torch.inverse(torch.eye(rank).cuda(0) + temp_xt @ cInvs[0] @ temp_xt.transpose(1, 2)).detach()
            currentInv = (cInvs[0] - cInvs[0] @ temp_xt.transpose(1, 2) @ innerInv @ temp_xt @ cInvs[0]).detach()[0]

    # backward pruning
    # Use the GPU-0 copy so the fisher matrix is on the same device as
    # currentInv even if the caller passed a tensor that lives elsewhere.
    fisher0 = fishers[0]
    for i in range(len(indsAll) - K):
        # select index for removal
        # X is already a tensor, so index it directly rather than re-wrapping
        # the result in torch.tensor(...).
        xt_ = X[indsAll].cuda()
        innerInv = torch.inverse(-1 * torch.eye(rank).cuda() + xt_ @ currentInv @ xt_.transpose(1, 2)).detach()
        traceEst = torch.diagonal(xt_ @ currentInv @ fisher0 @ currentInv @ xt_.transpose(1, 2) @ innerInv, dim1=-2, dim2=-1).sum(-1)
        delInd = torch.argmin(-1 * traceEst).item()

        # compute new inverse
        xt_ = X[indsAll[delInd]].unsqueeze(0).cuda()
        innerInv = torch.inverse(-1 * torch.eye(rank).cuda() + xt_ @ currentInv @ xt_.transpose(1, 2)).detach()
        currentInv = (currentInv - currentInv @ xt_.transpose(1, 2) @ innerInv @ xt_ @ currentInv).detach()[0]

        del indsAll[delInd]

    #save_dist_stats(distStats, savefile, alg)
    # Drop the references to device tensors before emptying the cache.
    # Rebinding to None cannot raise, unlike `del` on a name that was never
    # bound because one of the loops above ran zero times.
    xt_ = temp_xt = innerInv = currentInv = traceEst = None
    cInvs = fishers = fisher0 = xts = None
    gc.collect()
    torch.cuda.empty_cache()
    return indsAll

In [None]:
def select(X, K, fisher, iterates, lamb=1, backwardSteps=0, nLabeled=0):
    """Greedily pick K rows of X on a single device.

    The forward pass selects (backwardSteps + 1) * K indices, each time taking
    the best-scoring index not yet chosen and updating the running inverse;
    the backward pass then removes indices until K remain.

    NOTE: this notebook also defines a multi-GPU `select` that takes extra
    `savefile`, `alg` and `chunkSize` arguments. Running this cell rebinds the
    name, after which calls written for that signature fail.

    Args:
        X: candidate tensor indexed along its first axis; X.shape[-2] is used
            as `rank` and X.shape[-1] as `dim`.
        K: number of indices to return.
        fisher: (dim, dim) matrix used in the trace score.
        iterates: (dim, dim) tensor added, scaled by nLabeled / (nLabeled + K),
            to lamb * I before the initial inversion.
        lamb: multiplier of the identity in the initial inverse.
        backwardSteps: controls how many extra indices are picked and pruned.
        nLabeled: number of already labelled items.

    Returns:
        List of selected indices into X.
    """
    # Use the GPU when present; otherwise run on the CPU so the routine also
    # works on a machine without CUDA.
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    float32_max = float(np.finfo('float32').max)

    indsAll = []
    dim = X.shape[-1]
    rank = X.shape[-2]

    currentInv = torch.inverse(lamb * torch.eye(dim, device=device) + iterates.to(device) * nLabeled / (nLabeled + K))
    X = X * np.sqrt(K / (nLabeled + K))
    fisher = fisher.to(device)

    # forward selection
    for i in range(int((backwardSteps + 1) * K)):

        xt_ = X.to(device)
        innerInv = torch.inverse(torch.eye(rank, device=device) + xt_ @ currentInv @ xt_.transpose(1, 2)).detach()
        # Replace +/-inf entries by the largest finite float32 of the same sign.
        inf_mask = torch.isinf(innerInv)
        innerInv[inf_mask] = torch.sign(innerInv[inf_mask]) * float32_max
        traceEst = torch.diagonal(xt_ @ currentInv @ fisher @ currentInv @ xt_.transpose(1, 2) @ innerInv, dim1=-2, dim2=-1).sum(-1)

        # Release the batched inverse before emptying the cache. The earlier
        # host copy of xt_ freed no device memory and has been removed.
        del innerInv, inf_mask
        gc.collect()
        torch.cuda.empty_cache()

        traceEst = traceEst.detach().cpu().numpy()

        dist = traceEst - np.min(traceEst) + 1e-10
        dist = dist / np.sum(dist)
        sampler = stats.rv_discrete(values=(np.arange(len(dist)), dist))
        # The sampled index is only kept when every index is already chosen;
        # otherwise the best-scoring index not yet selected replaces it.
        ind = sampler.rvs(size=1)[0]
        for j in np.argsort(dist)[::-1]:
            if j not in indsAll:
                ind = j
                break

        indsAll.append(ind)
        print(i, ind, traceEst[ind], flush=True)

        xt_ = X[ind].unsqueeze(0).to(device)
        innerInv = torch.inverse(torch.eye(rank, device=device) + xt_ @ currentInv @ xt_.transpose(1, 2)).detach()
        currentInv = (currentInv - currentInv @ xt_.transpose(1, 2) @ innerInv @ xt_ @ currentInv).detach()[0]

    # backward pruning
    for i in range(len(indsAll) - K):

        # select index for removal
        xt_ = X[indsAll].to(device)
        innerInv = torch.inverse(-1 * torch.eye(rank, device=device) + xt_ @ currentInv @ xt_.transpose(1, 2)).detach()
        traceEst = torch.diagonal(xt_ @ currentInv @ fisher @ currentInv @ xt_.transpose(1, 2) @ innerInv, dim1=-2, dim2=-1).sum(-1)
        delInd = torch.argmin(-1 * traceEst).item()
        print(i, indsAll[delInd], -1 * traceEst[delInd].item(), flush=True)

        # compute new inverse
        xt_ = X[indsAll[delInd]].unsqueeze(0).to(device)
        innerInv = torch.inverse(-1 * torch.eye(rank, device=device) + xt_ @ currentInv @ xt_.transpose(1, 2)).detach()
        currentInv = (currentInv - currentInv @ xt_.transpose(1, 2) @ innerInv @ xt_ @ currentInv).detach()[0]

        del indsAll[delInd]

    # Rebinding to None cannot raise, unlike `del` on a name that was never
    # bound because both loops ran zero times.
    xt_ = innerInv = currentInv = None
    gc.collect()
    torch.cuda.empty_cache()
    return indsAll

In [3]:

# Run the multi-GPU selection on the debug tensors from the setup cell.
# The positional "savefile"/"FISH" arguments and the chunkSize keyword only match
# the `select` signature that takes (…, savefile, alg, …, chunkSize); the other
# `select` defined in this notebook does not accept them.
chosen = select(xt, n, fisher, init, "savefile", "FISH", lamb=1, backwardSteps=0, nLabeled=np.sum(idxs_lb), chunkSize=chunkSize)

In [1]:
""" def betterSlice(num_gpus, gpu_id, total_len):
    upper_bound = int(total_len/(num_gpus-gpu_id))
    lower_bound = int(upper_bound-(total_len/num_gpus))
    return slice(lower_bound, upper_bound) """

def betterSlice(num_gpus, gpu_id, total_len):
    """Return the slice of ``range(total_len)`` assigned to worker ``gpu_id``.

    The ``num_gpus`` slices are consecutive, do not overlap and together cover
    every index exactly once; when ``total_len`` is not divisible by
    ``num_gpus`` the shard lengths differ by at most one.

    Args:
        num_gpus: number of shards to split the range into (must be non-zero).
        gpu_id: zero-based shard index, expected in ``range(num_gpus)``.
        total_len: number of rows being partitioned.

    Returns:
        ``slice(lower, upper)`` with integer bounds.

    Raises:
        ZeroDivisionError: if ``num_gpus`` is 0.
    """
    # Floor division on integers keeps the bounds exact. Deriving the lower
    # bound by subtracting one float quotient from another and truncating
    # does not guarantee that neighbouring shards meet.
    lower_bound = (gpu_id * total_len) // num_gpus
    upper_bound = ((gpu_id + 1) * total_len) // num_gpus
    return slice(int(lower_bound), int(upper_bound))

# Visual check of betterSlice: print the shard assigned to each GPU for a range
# of pool sizes. The sweep over GPU counts is disabled; a fixed count is used.
#for i in range(1,5):
    #print(f"-----------{i} gpus-----------")
i = 3  # number of GPUs to simulate
# Pool sizes from 24000 down to 2000 in steps of 2000.
for k in range(24000,0,-2000):
    print(f"-------size {k}-------")
    for j in range(i):
        sliced = betterSlice(i,j,k)
        print(f"gpu {j}: ({sliced})")

-------size 24000-------
gpu 0: (slice(0, 8000, None))
gpu 1: (slice(8000, 16000, None))
gpu 2: (slice(16000, 24000, None))
-------size 22000-------
gpu 0: (slice(0, 7333, None))
gpu 1: (slice(7333, 14666, None))
gpu 2: (slice(14666, 22000, None))
-------size 20000-------
gpu 0: (slice(0, 6666, None))
gpu 1: (slice(6666, 13333, None))
gpu 2: (slice(13333, 20000, None))
-------size 18000-------
gpu 0: (slice(0, 6000, None))
gpu 1: (slice(6000, 12000, None))
gpu 2: (slice(12000, 18000, None))
-------size 16000-------
gpu 0: (slice(0, 5333, None))
gpu 1: (slice(5333, 10666, None))
gpu 2: (slice(10666, 16000, None))
-------size 14000-------
gpu 0: (slice(0, 4666, None))
gpu 1: (slice(4666, 9333, None))
gpu 2: (slice(9333, 14000, None))
-------size 12000-------
gpu 0: (slice(0, 4000, None))
gpu 1: (slice(4000, 8000, None))
gpu 2: (slice(8000, 12000, None))
-------size 10000-------
gpu 0: (slice(0, 3333, None))
gpu 1: (slice(3333, 6666, None))
gpu 2: (slice(6666, 10000, None))
-------size 80