# Neighbour Extraction using Gram Matrices

---

In [1]:
import sys
print(sys.executable)
print(sys.version)

/home/kshitij98/getNeighbours/venv/bin/python3
3.5.2 (default, Nov 12 2018, 13:43:14) 
[GCC 5.4.0 20160609]


In [2]:
from ipynb.fs.full.Helper import getDataLoader, getNames, dist
from ipynb.fs.full.GramMatrix import convertModel, GramMatrixLayer
from ipynb.fs.full.LabelDataset import createDirectories

import torch
import torch.nn as nn

%matplotlib inline

---

## Feature Extraction

In [3]:
loader = getDataLoader('/scratch/bam_subset_2_0_labeled', batch_size=4, shuffle=False, num_workers=4, testing=False)

In [4]:
import torchvision.models as models
vgg19 = models.vgg19(pretrained=True)

In [5]:
gramMatrixLayers = ['relu1_1', 'relu2_1', 'relu3_1', 'relu4_1', 'relu5_1']
gramMatrixWeights = [1, 1, 1, 1, 1]
vgg19, model, gram_matrices = convertModel(vgg19, gramMatrixLayers, gramMatrixWeights, testing=False)

In [6]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

model.eval()
model.to(device)

cuda:0


Sequential(
  (conv1_1): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (relu1_1): ReLU()
  (gram_matrix1_1): GramMatrixLayer(λ=1)
  (conv1_2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (relu1_2): ReLU()
  (pool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2_1): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (relu2_1): ReLU()
  (gram_matrix2_1): GramMatrixLayer(λ=1)
  (conv2_2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (relu2_2): ReLU()
  (pool2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv3_1): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (relu3_1): ReLU()
  (gram_matrix3_1): GramMatrixLayer(λ=1)
  (conv3_2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (relu3_2): ReLU()
  (conv3_3): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (relu3_3):

In [7]:
dataIter = iter(loader)

(data, classes), names = dataIter.next()
data = data.cuda()
out = model(data)
G = []
for layer in gram_matrices:
    G.append(layer.gramMatrix)
G = torch.cat(G, 1)

a, D = G.size()

print(D)

610304


Find a ‘safe’ number of components to randomly project to

The distortion introduced by a random projection p only changes the distance between two points by a factor (1 +- eps) in an euclidean space with good probability.

In [8]:
from sklearn.random_projection import johnson_lindenstrauss_min_dim

# TODO: Remove hardcoded dataset size
K = johnson_lindenstrauss_min_dim(40474, eps=0.3)

print(K)

1178


In [9]:
from sklearn.random_projection import gaussian_random_matrix

RPM = gaussian_random_matrix(K, D)
RPM = RPM.transpose()
print(RPM.shape)

(610304, 1178)


In [10]:
import os

if not os.path.exists('/scratch/kshitij98'):
    print("Creating", '/scratch/kshitij98')
    os.makedirs('/scratch/kshitij98')

RPM = torch.from_numpy(RPM)
RPM = RPM.cuda()
RPM = RPM.float()
print(RPM.shape)

torch.save(RPM, "/scratch/kshitij98/rpm")

Creating /scratch/kshitij98
torch.Size([610304, 1178])


In [11]:
import time

createDirectories('/scratch/bam_subset_2_0_features/')

dataIter = iter(loader)
t = time.time()

# TODO: Remove hardcoded dataset size
i = 1
while i <= 30250:
    (data, classes), names = dataIter.next()
    data = data.cuda()
    out = model(data)
    G = []
    for layer in gram_matrices:
        G.append(layer.gramMatrix)
    G = torch.cat(G, 1)
#     for j, gm in enumerate(G):
#         print(names[j])
#     break
#         torch.save(gm, names[j].replace('imdb_dataset', 'imdb_dataset_features'))
    G = torch.mm(G, RPM)
    for j, gm in enumerate(G):
#         print(names[j])
        torch.save(gm, names[j].replace('bam_subset_2_0_labeled', 'bam_subset_2_0_features'))
    print(i, "\tETA: ", ((time.time() - t) / ((i) * 4)) * (121000 - ((i) * 4)) * (1 / 60), "minutes", end='\r')
    i += 1
    

Creating /scratch/bam_subset_2_0_features/
121000
Creating /scratch/bam_subset_2_0_features/0
Creating /scratch/bam_subset_2_0_features/3
Creating /scratch/bam_subset_2_0_features/2
Creating /scratch/bam_subset_2_0_features/10
Creating /scratch/bam_subset_2_0_features/5
Creating /scratch/bam_subset_2_0_features/8
Creating /scratch/bam_subset_2_0_features/7
Creating /scratch/bam_subset_2_0_features/6
Creating /scratch/bam_subset_2_0_features/4
Creating /scratch/bam_subset_2_0_features/9
Creating /scratch/bam_subset_2_0_features/1
671 	ETA:  14.573956920019265 minutes

  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))


1206 	ETA:  14.314960599248748 minutes

  " Skipping tag %s" % (size, len(data), tag))


1674 	ETA:  14.06494447714966 minutess

  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))
  " Skipping 

1971 	ETA:  13.78783342721558 minutess	ETA:  13.829463022981479 minutes

  tag, len(values)))


2097 	ETA:  13.68057325357967 minutess

  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))


2840 	ETA:  12.635172978659192 minutes12.675000954180776 minutes

  'to RGBA images')


2899 	ETA:  12.567464083496887 minutes



3082 	ETA:  12.40185014265532 minutess

  'to RGBA images')


3143 	ETA:  12.367522165658732 minutes

  'to RGBA images')


3308 	ETA:  12.208725704823902 minutes

  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))


3571 	ETA:  11.978544079349408 minutes

  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))


3591 	ETA:  11.959254000090207 minutes

  'to RGBA images')


4808 	ETA:  10.99172058000673 minutess

  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))
  " Skipping 

5150 	ETA:  10.770272137509194 minutes



5614 	ETA:  10.350667872087447 minutes

  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))


5647 	ETA:  10.334932122879456 minutes

  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))


5729 	ETA:  10.294462772423374 minutes

  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))


5945 	ETA:  10.19821216349579 minutess



6153 	ETA:  10.101260050943804 minutes

  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))


6260 	ETA:  10.050376520814478 minutes

  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))


6697 	ETA:  9.823476006704533 minutess

  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))


6761 	ETA:  9.789089619733112 minutes

  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))


6762 	ETA:  9.79080730838018 minutes6763 	ETA:  9.789837971767271 minutes6764 	ETA:  9.79011167930518 minutes6765 	ETA:  9.789144002809758 minutes6766 	ETA:  9.788838692512027 minutes6767 	ETA:  9.787876769942748 minutes6768 	ETA:  9.787631885919916 minutes6769 	ETA:  9.78667244729976 minutes6770 	ETA:  9.785781792218554 minutes6771 	ETA:  9.784823029065233 minutes

  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))


7062 	ETA:  9.638034288608397 minutes

  " Skipping tag %s" % (size, len(data), tag))


7258 	ETA:  9.540469372229786 minutes

  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))


7377 	ETA:  9.482064543256897 minutes

  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))
  " Skipping 

8426 	ETA:  8.920953475381106 minutes



8523 	ETA:  8.88126330097894 minutess

  " Skipping tag %s" % (size, len(data), tag))


9851 	ETA:  8.312619754127192 minutes

  " Skipping tag %s" % (size, len(data), tag))


10020 	ETA:  8.239880275520102 minutes

  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))


10383 	ETA:  8.08384989381445 minutess

  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))


10434 	ETA:  8.063447473940473 minutes

  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))
  " Skipping 

17077 	ETA:  5.2029096361470675 minutes	ETA:  7.652981995141574 minutes 6.849393722420447 minutes6.131565811488263 minutes

  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))


17767 	ETA:  4.989915046600912 minutess

  " Skipping tag %s" % (size, len(data), tag))


19577 	ETA:  4.165499843577375 minutess

  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))


19578 	ETA:  4.1650549600108695 minutes19579 	ETA:  4.164607664474982 minutes19580 	ETA:  4.1641591489984755 minutes19581 	ETA:  4.163709533196305 minutes19582 	ETA:  4.163260802659178 minutes19583 	ETA:  4.162811420232453 minutes19584 	ETA:  4.162362492088967 minutes19585 	ETA:  4.161920126570346 minutes19586 	ETA:  4.1614713115150685 minutes19587 	ETA:  4.161012292219469 minutes19588 	ETA:  4.160562126151342 minutes19589 	ETA:  4.160113212285044 minutes

  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))


23116 	ETA:  2.72537937261774 minutesss 3.760189202633369 minutes

  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))


23197 	ETA:  2.6948267363066116 minutes

  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))


23342 	ETA:  2.641274501983003 minutess

  " Skipping tag %s" % (size, len(data), tag))


24930 	ETA:  2.0238599105479964 minutes

  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))


25068 	ETA:  1.9747688401154992 minutes



27695 	ETA:  0.9852525073375599 minutes

  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))


27866 	ETA:  0.9195086274529997 minutes

  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))


28312 	ETA:  0.7478521862620181 minutes

  " Skipping tag %s" % (size, len(data), tag))


28899 	ETA:  0.5221006968379528 minutes

  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))


28900 	ETA:  0.521756564395238 minutes28901 	ETA:  0.5213641416158535 minutes28902 	ETA:  0.5209716235208497 minutes28903 	ETA:  0.5205792927648791 minutes28904 	ETA:  0.5201961811345621 minutes28905 	ETA:  0.5198038279408781 minutes28906 	ETA:  0.5194114772422103 minutes28907 	ETA:  0.5190190882537145 minutes28908 	ETA:  0.5186604233334839 minutes28909 	ETA:  0.5182679677134115 minutes

  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))


29059 	ETA:  0.4603119885758616 minutess

  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))


29285 	ETA:  0.37305740982092245 minutes

  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))


30250 	ETA:  0.0 minutes6239329565 minutes

---

## Get Neighbours

In [None]:
import time

t = time.time()

names = getNames('/scratch/bam_subset_2_0_features/')
X = []
for i, fileName in enumerate(names):
    X.append(torch.load(fileName))
    print("ETA: ", (len(names) - i - 1) * ((time.time() - t) / (i+1)), end='\r')
X = torch.stack(X, 0)
print(X.shape)

In [None]:
createDirectories('/scratch/bam_subset_2_0_top_neighbours/')
createDirectories('/scratch/bam_subset_2_0_bottom_neighbours/')

In [None]:
import torch

C = 0.25 * torch.ones(1)
C = C.cuda()

def bDist(n1, n2, C):
    assert(n1.size(0) == 1)

    s1 = n1.size(0)
    s2 = n2.size(0)
    d = n1.size(1)
    
    n1 = n1.expand(s2, d)
    C = C.expand(s2, d)
    n2 = torch.sum((C - torch.min(torch.abs(torch.sub(n1, n2)), C)) ** 4, 1)

    return n2

if __name__ == '__main__':
    A = torch.ones(1, 10)
    B = 1.5* torch.ones(7, 10)
    A = A.cuda()
    B = B.cuda()
    print(bDist(A, B, C))

In [None]:
# import numpy as np
# # from ipynb.fs.full.Helper import bDist


# k = 15
# t = time.time()
# names = getNames('/scratch/bam_subset_2_0_features/')



# for i, source in enumerate(X):
# #     if i < 25000:
# #         continue
    
# #     if i > 25000 + 200:
# #         break
    
#     source = torch.unsqueeze(source, 0)

# #     d = dist(source, X)
# #     d, indices = d.sort()

#     bD = bDist(source, X, C)
#     bD, bIndices = bD.sort()

# #     topIds = indices[0, 1:k+1]
#     # Note: Negative slicing is not supported
#     bottomIds = bIndices[0:k]

# #     top = []
# #     for idx in topIds:
# # #         print(idx)
# #         top.append(names[idx].replace('bam_subset_2_0_features', 'bam_subset_2_0_labeled'))
# #     top = np.asarray(top)
    
#     bottom = []
#     for idx in bottomIds:
#         bottom.append(names[idx])
#     bottom = np.asarray(bottom)
    
# #     np.save(names[i].replace('bam_subset_2_0_features', 'bam_subset_2_0_top_neighbours'), top)
#     np.save(names[i].replace('bam_subset_2_0_features', 'bam_subset_2_0_bottom_neighbours'), bottom)

#     print("ETA: ", (len(X) - i - 1) * ((time.time() - t) / (i+1)) * (1 / 60), end='\r')

In [None]:
import numpy as np

k = 15
t = time.time()

X = X ** 0.8

for i, source in enumerate(X):
    source = torch.unsqueeze(source, 0)

    d = dist(source, X)
    d, indices = d.sort()

    topIds = indices[0, 1:k+1]
    # Note: Negative slicing is not supported
    bottomIds = indices[0, -k:]

    top = []
    for idx in topIds:
        top.append(names[idx].replace('bam_subset_2_0_features', 'bam_subset_2_0'))
    top = np.asarray(top)
    
    bottom = []
    for idx in bottomIds:
        bottom.append(names[idx])
    bottom = np.asarray(bottom)
    
    np.save(names[i].replace('bam_subset_2_0_features', 'bam_subset_2_0_top_neighbours'), top)
    np.save(names[i].replace('bam_subset_2_0_features', 'bam_subset_2_0_bottom_neighbours'), bottom)

    print("ETA: ", (len(X) - i - 1) * ((time.time() - t) / (i+1)) * (1 / 60), end='\r')

In [None]:
A = torch.rand(3, 3)
print(A)
print(A**0.8)
print(A*)

---

## Statistics

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
%matplotlib inline

names = getNames('/scratch/bam_subset_2_0_top_neighbours/', shuffle=True)
bNames = getNames('/scratch/bam_subset_2_0_bottom_neighbours/', shuffle=True)

def showImages(images):
    fig = plt.figure(figsize=(32, 32))
    columns = 5
    rows = 4
    for i in range(1, len(images) + 1):
        img = images[i-1]
        fig.add_subplot(rows, columns, i)
        plt.imshow(img)
    plt.show()

idx = 0
def showNeighbours(index = None):
    global idx
    if index is not None:
        idx = index
    nbs = np.load(names[idx])
    bNbs = np.load(bNames[idx])
    images = []
    nImages = []
    print(names[idx])
    images.append(Image.open(names[idx].replace('_top_neighbours', '_labeled').replace('.npy', '')))
    nImages.append(Image.open(bNames[idx].replace('_bottom_neighbours', '_labeled').replace('.npy', '')))
    for nb in nbs:
        print("NB", nb)
        nb = nb.replace('bam_subset_2_0', 'bam_subset_2_0_labeled')
        images.append(Image.open(nb))
    for nb in bNbs:
        print("NB", nb)
        nb = nb.replace('bam_subset_2_0_features', 'bam_subset_2_0_labeled')
        nImages.append(Image.open(nb))
    showImages(images)
    showImages(nImages)
    idx += 1


#### Qualitatively test results

In [None]:
showNeighbours()

In [None]:
import pickle

labels = None
with open('../bam_2_0_image_style_labels.pkl', 'rb') as handle:
    labels = pickle.load(handle)

print(len(labels))

idx2 = 0

names = getNames('/scratch/bam_subset_2_0_top_neighbours/', shuffle=True)

def getTopAccuracy(n):
    total = 0
    correct = 0
    mini = 1000000
    worst = None
    for i in range(n):
        truth = labels[names[i].split('/')[-1].replace('.npy', '')]
        nbs = np.load(names[i])
#         print(truth)
        curr = 0
        l = []
        for nb in nbs:
            nb = nb.split('/')[-1]
            l.append(labels[nb])
            if truth == labels[nb]:
                curr += 1
#             print(labels[nb])
            total += 1
        if curr < mini:
            mini = curr
            worst = i
            actual = truth
            nb_labels = l
        correct += curr
    showNeighbours(worst)
    print(mini, " / 15 are correct neighbours")
    print("Actual label", actual)
    print("Neighbour labels: ", nb_labels)
    return str((correct / total) * 100) + " %"

In [None]:
print(getTopAccuracy(1000))

In [None]:
mini = 10000
maxi = -10000
for i, x in enumerate(X):
#     print(i, x)
#     print(min(x), max(x))
    mini = min(min(x), mini)
    maxi = max(max(x), maxi)
    if i == 1000:
        break
print(mini, maxi)
