In [8]:
import os
import torch
from torch.utils.data import DataLoader
from torch.autograd import Variable
import numpy as np
import torch.nn.functional as F
from torchvision import datasets
from torch.utils.data import Dataset
from torchvision import transforms
import resnet
import torch.nn as nn
import torch.nn.functional as F
from scipy import stats
import gc

In [9]:
data_tr = datasets.CIFAR10('data'+ '/CIFAR10', train=True, download=True)
data_te = datasets.CIFAR10('data' + '/CIFAR10', train=False, download=True)
X_tr = data_tr.data
Y_tr = torch.from_numpy(np.array(data_tr.targets))
X_te = data_te.data
Y_te = torch.from_numpy(np.array(data_te.targets))

Files already downloaded and verified
Files already downloaded and verified


In [10]:
num_imp_per_layer = 14000
xt = np.zeros((len(Y_tr), len(np.unique(Y_tr)), num_imp_per_layer))

In [11]:
fisher = torch.zeros(xt.shape[-1], xt.shape[-1],dtype=torch.double).cuda()
iterates = torch.zeros(xt.shape[-1], xt.shape[-1], dtype=torch.double)

In [12]:
def select(X, K, fisher, iterates, lamb=1, backwardSteps=0, nLabeled=0):
    '''
    K is the number of images to be selected for labelling, 
    iterates is the fisher for images that are already labelled
    '''

    numEmbs = len(X)
    dim = X.shape[-1]
    rank = X.shape[-2]
    indsAll = []

    currentInv = torch.inverse(lamb * torch.eye(dim).cuda() + iterates.cuda() * nLabeled / (nLabeled + K))
    # what is lamb used for here?
    #X = X * np.sqrt(K / (nLabeled + K))
    fisher = fisher.cuda()

    # forward selection
    for i in range(int((backwardSteps + 1) *  K)): 
        print("Start outer loop iteration ", i)
        '''
        K corresponds to minibatch size, which is called B in the paper.
        currently we assume that backwardSteps = 0
        '''

        # xt_ = X.cuda()
        xt_ = X  
        '''
        in the calculation below, traceEst has X.shape[0] elements.
        The calculation done for computing one element of traceEst
        has no effect on the calculation done for computing other
        elements of traceEst. This suggests that we can compute  
        tracEst in chunks, rather than computing all elements in 
        one go.

        traceEst = torch.zeros(X.shape[0])
        chunkSize = 100
        for c_idx in range(0, X.shape[0], chunkSize):
            xt_chunk = xt_[c_idx * chunkSize : (c_idx + 1) * chunkSize]
            innerInv = torch.inverse(torch.eye(rank).cpu() + xt_chunk @ currentInv @ xt_chunk.transpose(1, 2)).detach()
            traceEst[c_idx * chunkSize : (c_idx + 1) * chunkSize] = torch.diagonal(
                xt_chunk @ currentInv @ fisher @ currentInv @ xt_chunk.transpose(1, 2) @ innerInv, 
                dim1=-2, 
                dim2=-1
            ).sum(-1) 
        '''


        # innerInv = torch.inverse(torch.eye(rank).cuda() + xt_ @ currentInv @ xt_.transpose(1, 2)).detach()
        # innerInv[torch.where(torch.isinf(innerInv))] = torch.sign(innerInv[torch.where(torch.isinf(innerInv))]) * np.finfo('float32').max
        
        
        
        # traceEst = torch.diagonal(
        #     xt_ @ currentInv @ fisher @ currentInv @ xt_.transpose(1, 2) @ innerInv, 
        #     dim1=-2, 
        #     dim2=-1
        # ).sum(-1)
        traceEst = np.zeros(X.shape[0]) #torch.zeros(X.shape[0]).cuda() 
        chunkSize = 100
        #print(X.shape[0])
        for c_idx in range(0, X.shape[0], chunkSize):
            if c_idx % 100 == 0:
                print(c_idx)
            xt_chunk = xt_[c_idx * chunkSize : (c_idx + 1) * chunkSize]
            xt_chunk = torch.tensor(xt_chunk).cuda() #.clone().detach()
            innerInv = torch.inverse(torch.eye(rank).cuda() + xt_chunk @ currentInv @ xt_chunk.transpose(1, 2))
            innerInv[torch.where(torch.isinf(innerInv))] = torch.sign(innerInv[torch.where(torch.isinf(innerInv))]) * np.finfo('float32').max
            traceEst[c_idx * chunkSize : (c_idx + 1) * chunkSize] = torch.diagonal(
                xt_chunk @ currentInv @ fisher @ currentInv @ xt_chunk.transpose(1, 2) @ innerInv, 
                dim1=-2, 
                dim2=-1
            ).sum(-1).detach().cpu()
        '''
        Vx^T M^-1 I(θ_L) M^-1 Vx A^-1 formula from page 5 of paper.
        currentInv corresponds to M^-1
        fisher corresponds to I(θ_L)
        xt_ corresponds to Vx^T
        innerInv corresponds to A^-1
        '''

        xt = xt_
        del xt, innerInv
        #del xt_, innerInv
        torch.cuda.empty_cache()
        gc.collect()
        torch.cuda.empty_cache()
        gc.collect()

        # traceEst = traceEst.detach().cpu().numpy() # objective value in eq (5) from the paper

        dist = traceEst - np.min(traceEst) + 1e-10
        dist = dist / np.sum(dist)
        sampler = stats.rv_discrete(values=(np.arange(len(dist)), dist))
        ind = sampler.rvs(size=1)[0]
        for j in np.argsort(dist)[::-1]:
            if j not in indsAll:
                ind = j
                break

        indsAll.append(ind)  # adding a new tilde_x to the minibatch being made
        print(i, ind, traceEst[ind], flush=True)
       
        xt_ = torch.tensor(X[ind]).unsqueeze(0).cuda()
        innerInv = torch.inverse(torch.eye(rank).cuda() + xt_ @ currentInv @ xt_.transpose(1, 2)).detach()
        currentInv = (currentInv - currentInv @ xt_.transpose(1, 2) @ innerInv @ xt_ @ currentInv).detach()[0]

    # backward pruning
    for i in range(len(indsAll) - K):

        # select index for removal
        xt_ = torch.tensor(X[indsAll]).cuda()
        innerInv = torch.inverse(-1 * torch.eye(rank).cuda() + xt_ @ currentInv @ xt_.transpose(1, 2)).detach()
        traceEst = torch.diagonal(xt_ @ currentInv @ fisher @ currentInv @ xt_.transpose(1, 2) @ innerInv, dim1=-2, dim2=-1).sum(-1)
        delInd = torch.argmin(-1 * traceEst).item()
        print(i, indsAll[delInd], -1 * traceEst[delInd].item(), flush=True)


        # compute new inverse
        xt_ = torch.tensor(X[indsAll[delInd]]).unsqueeze(0).cuda()
        innerInv = torch.inverse(-1 * torch.eye(rank).cuda() + xt_ @ currentInv @ xt_.transpose(1, 2)).detach()
        currentInv = (currentInv - currentInv @ xt_.transpose(1, 2) @ innerInv @ xt_ @ currentInv).detach()[0]

        del indsAll[delInd]

    del xt_, innerInv, currentInv
    torch.cuda.empty_cache()
    gc.collect()
    return indsAll

In [13]:
n_pool = len(Y_tr)
idxs_lb = np.zeros(n_pool, dtype=bool)
idxs_unlabeled = np.arange(n_pool)[~idxs_lb]

In [14]:
chosen = select(xt[idxs_unlabeled], 1000, fisher, iterates, lamb = 1, backwardSteps = 1, nLabeled=np.sum(idxs_lb))

Start outer loop iteration  0
0
100
200
300
400
500
600
700
800
900
1000
1100
1200
1300
1400
1500
1600
1700
1800
1900
2000
2100
2200
2300
2400
2500
2600
2700
2800
2900
3000
3100
3200
3300
3400
3500
3600
3700
3800
3900
4000
4100
4200
4300
4400
4500
4600
4700
4800
4900
5000
5100
5200
5300
5400
5500
5600
5700
5800
5900
6000
6100
6200
6300
6400
6500
6600
6700
6800
6900
7000
7100
7200
7300
7400
7500
7600
7700
7800
7900
8000
8100
8200
8300
8400
8500
8600
8700
8800
8900
9000
9100
9200
9300
9400
9500
9600
9700
9800
9900
10000
10100
10200
10300
10400
10500
10600
10700
10800
10900
11000
11100
11200
11300
11400
11500
11600
11700
11800
11900
12000
12100
12200
12300
12400
12500
12600
12700
12800
12900
13000
13100
13200
13300
13400
13500
13600
13700
13800
13900
14000
14100
14200
14300
14400
14500
14600
14700
14800
14900
15000
15100
15200
15300
15400
15500
15600
15700
15800
15900
16000
16100
16200
16300
16400
16500
16600
16700
16800
16900
17000
17100
17200
17300
17400
17500
17600
17700
17800
17900
18

KeyboardInterrupt: 

In [None]:
class DataHandler3(Dataset):
    def __init__(self, X, Y, transform= None):
        self.X = X
        self.Y = Y
        self.transform = transform

    def __getitem__(self, index):
        x, y = self.X[index], self.Y[index]
        if self.transform is not None:
            x = Image.fromarray(x)
            x = self.transform(x)
        return x, y, index
    def __len__(self):
        return len(self.Y)

In [None]:
def get_handler(name):
    if name == 'CIFAR10':
        return DataHandler3

In [None]:
handler = get_handler('CIFAR10')

In [None]:
print(type(get_handler))

<class 'function'>


In [None]:
args={'n_epoch': 3, 'transform': transforms.Compose([ 
                     transforms.RandomCrop(32, padding=4),
                     transforms.RandomHorizontalFlip(),
                     transforms.ToTensor(),
                     transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616))
                 ]),
                 'loader_tr_args':{'batch_size': 128, 'num_workers': 1},
                 'loader_te_args':{'batch_size': 1000, 'num_workers': 1}, # change back to 1000
                 'optimizer_args':{'lr': 0.05, 'momentum': 0.3},
                 'transformTest': transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616))])}

In [None]:
def predict_prob(X, Y, model, exp=True):
    # if type(model) == list: model = clf

    # loader_te = DataLoader(handler(X, Y, transform=args['transformTest']), shuffle=False, **args['loader_te_args'])
    # print(type(loader_te))
    loader_te = DataLoader(X_tr, batch_size = 100, shuffle = False)
    model = model.eval()
    probs = torch.zeros([len(Y), len(np.unique(Y))])
    with torch.no_grad():
        for idx, x in enumerate(loader_te):
            x = Variable(x.cuda())
            out, e1 = model(x)
            if exp: out = F.softmax(out, dim=1)
            probs[idx] = out.cpu().data
            # probs = out.cpu().data
    
    return probs

In [None]:
innerInv = torch.tensor(())

In [None]:
X_tr[:100].shape[0]

100

In [None]:
loader_te = DataLoader(X_tr, batch_size = 100, shuffle = False)
for idx,x in enumerate(loader_te):
    print(x.shape)

torch.Size([100, 32, 32, 3])
torch.Size([100, 32, 32, 3])
torch.Size([100, 32, 32, 3])
torch.Size([100, 32, 32, 3])
torch.Size([100, 32, 32, 3])
torch.Size([100, 32, 32, 3])
torch.Size([100, 32, 32, 3])
torch.Size([100, 32, 32, 3])
torch.Size([100, 32, 32, 3])
torch.Size([100, 32, 32, 3])
torch.Size([100, 32, 32, 3])
torch.Size([100, 32, 32, 3])
torch.Size([100, 32, 32, 3])
torch.Size([100, 32, 32, 3])
torch.Size([100, 32, 32, 3])
torch.Size([100, 32, 32, 3])
torch.Size([100, 32, 32, 3])
torch.Size([100, 32, 32, 3])
torch.Size([100, 32, 32, 3])
torch.Size([100, 32, 32, 3])
torch.Size([100, 32, 32, 3])
torch.Size([100, 32, 32, 3])
torch.Size([100, 32, 32, 3])
torch.Size([100, 32, 32, 3])
torch.Size([100, 32, 32, 3])
torch.Size([100, 32, 32, 3])
torch.Size([100, 32, 32, 3])
torch.Size([100, 32, 32, 3])
torch.Size([100, 32, 32, 3])
torch.Size([100, 32, 32, 3])
torch.Size([100, 32, 32, 3])
torch.Size([100, 32, 32, 3])
torch.Size([100, 32, 32, 3])
torch.Size([100, 32, 32, 3])
torch.Size([10

In [None]:
model = resnet.ResNet18()

In [None]:
phat = predict_prob(X_tr[idxs_unlabeled], Y_tr[idxs_unlabeled],model)

RuntimeError: Given groups=1, weight of size [16, 3, 3, 3], expected input[100, 32, 32, 3] to have 3 channels, but got 32 channels instead

In [None]:
print('all probs: ' + 
                str(str(torch.mean(torch.max(phat, 1)[0]).item())) + ' ' + 
                str(str(torch.mean(torch.min(phat, 1)[0]).item())) + ' ' + 
                str(str(torch.mean(torch.std(phat,1)).item())), flush=True)