In [15]:
from hipposlam.trainVAE import EmbeddingImageDatasetAll
from pycircstat.descriptive import cdiff
import numpy as np

In [12]:
# Annotation CSV and directory of precomputed embeddings; labels/embeddings are returned as numpy.
load_annotations_pth, load_embed_dir = 'data/VAE/annotations2.csv', 'data/VAE/embeds2'
dataset = EmbeddingImageDatasetAll(
    load_annotations_pth,
    load_embed_dir,
    to_numpy=True,
)

In [13]:
embeds, labels = dataset.get_all()
# Label columns used below: 0 -> x, 1 -> y, 2 -> heading angle (presumably radians; cdiff expects radians).
x, y, a = labels[:, 0], labels[:, 1], labels[:, 2]
xy = labels[:, :2]
xmin, ymin = xy.min(axis=0)
xmax, ymax = xy.max(axis=0)

In [20]:
# Min-max normalise positions to [0, 1] so both axes contribute equally to the distance.
xnorm = (x - xmin) / (xmax - xmin)
ynorm = (y - ymin) / (ymax - ymin)
# cdiff returns a *signed* circular difference in [-pi, pi]. Take its magnitude so that a large
# negative difference (e.g. -0.9*pi) is not treated as small, then scale to [0, 1].
adiff = np.abs(cdiff(a.reshape(-1, 1), a.reshape(1, -1))) / np.pi
# Pairwise Euclidean distance in the normalised unit square, divided by its diagonal -> [0, 1].
posdist = np.hypot(
    xnorm.reshape(-1, 1) - xnorm.reshape(1, -1),
    ynorm.reshape(-1, 1) - ynorm.reshape(1, -1),
) / np.sqrt(2)


In [23]:
dist_thresh = 0.1   # max normalised position distance (fraction of the unit-square diagonal)
adiff_thresh = 0.1  # max normalised heading difference (fraction of pi)
# abs() guards against a signed circular difference: without it every negative difference
# would pass the threshold. It is a no-op if adiff is already non-negative.
sim_mask = (posdist < dist_thresh) & (np.abs(adiff) < adiff_thresh)

In [27]:
# Exclude trivial self-pairs (i, i) from the similar set, matching ContrastiveEmbeddingDataset.
# Self-pairs are still "close", so they also stay out of the dissimilar set.
not_self = ~np.eye(sim_mask.shape[0], dtype=bool)
sim_rids, sim_cids = np.nonzero(sim_mask & not_self)
dissim_rids, dissim_cids = np.nonzero(~sim_mask)


(6805058,)

In [58]:
import torch
from torch.utils.data import Dataset
from os.path import join
import pandas as pd

class ContrastiveEmbeddingDataset:
    """Serve batches of similar / dissimilar embedding pairs for contrastive training.

    Two samples i != j form a *similar* pair when their normalised position distance is below
    ``dist_thresh`` AND their normalised absolute heading difference is below ``adiff_thresh``.
    Every pair that fails either test is *dissimilar*. Both masks are symmetric, so each
    pair appears as (i, j) and (j, i).

    Parameters
    ----------
    load_annotation_pth : str
        CSV with a header row; columns 1, 2, 3 are used as x, y and heading angle
        (presumably radians -- TODO confirm).
    load_embed_dir : str
        Directory containing ``all.pt``, an embedding tensor whose rows align with the CSV rows.
    datainds : (start, stop) or None
        Row range applied to both embeddings and annotations. ``None`` (default) uses all rows.
    dist_thresh, adiff_thresh : float
        Similarity thresholds in normalised [0, 1] units.
    """

    def __init__(self, load_annotation_pth, load_embed_dir, datainds=None, dist_thresh=0.1, adiff_thresh=0.1):
        self.load_annotation_pth = load_annotation_pth
        self.load_embed_dir = load_embed_dir
        self.dist_thresh = dist_thresh
        self.adiff_thresh = adiff_thresh
        rows = slice(None) if datainds is None else slice(datainds[0], datainds[1])
        # NOTE(review): torch.load unpickles the file -- only load trusted embedding files.
        self.embed_all = torch.load(join(self.load_embed_dir, 'all.pt'))[rows]
        # Cast to float so numpy ufuncs (sqrt, exp) work even if pandas yields an object array,
        # e.g. because the CSV also contains a non-numeric column.
        self.img_labels = pd.read_csv(load_annotation_pth, header=0).iloc[rows, [1, 2, 3]].to_numpy(dtype=float)
        self.sim_rcids, self.dissim_rcids = self._compute_contrastive_labels()
        self.sim_size, self.dissim_size = self.sim_rcids.shape[0], self.dissim_rcids.shape[0]

    def iterate(self, batchsize):
        """Yield one epoch of ``((sim_a, sim_b), (dissim_a, dissim_b))`` embedding batches.

        The epoch length is set by the number of similar pairs. Each similar batch is paired
        with an equally sized batch of dissimilar pairs drawn from the *whole* dissimilar set.
        """
        assert (batchsize < self.sim_size) and (batchsize < self.dissim_size)
        sim_order = np.random.permutation(self.sim_size)
        # Draw over all dissimilar pairs (not just the first sim_size of them). Sample with
        # replacement only when there are fewer dissimilar than similar pairs, so batch sizes match.
        dissim_order = np.random.choice(
            self.dissim_size, size=self.sim_size, replace=self.dissim_size < self.sim_size
        )
        for start in range(0, self.sim_size, batchsize):
            sim_pairs = self.sim_rcids[sim_order[start:start + batchsize]]
            dissim_pairs = self.dissim_rcids[dissim_order[start:start + batchsize]]
            yield ((self.embed_all[sim_pairs[:, 0]], self.embed_all[sim_pairs[:, 1]]),
                   (self.embed_all[dissim_pairs[:, 0]], self.embed_all[dissim_pairs[:, 1]]))

    @staticmethod
    def _minmax(v):
        """Scale v to [0, 1]; a constant vector maps to zeros instead of NaN."""
        vmin = v.min()
        span = v.max() - vmin
        return (v - vmin) / span if span > 0 else np.zeros_like(v)

    def _compute_contrastive_labels(self):
        """Return (sim_rcids, dissim_rcids), each an (n_pairs, 2) array of (row, col) indices."""
        labels = np.asarray(self.img_labels, dtype=float)
        xnorm = self._minmax(labels[:, 0])
        ynorm = self._minmax(labels[:, 1])
        a = labels[:, 2]
        # Absolute circular difference in [0, pi] scaled to [0, 1] (== |pycircstat.cdiff| / pi).
        # The signed difference must not be thresholded directly: negative values would all pass.
        adiff = np.abs(np.angle(np.exp(1j * (a[:, None] - a[None, :])))) / np.pi
        # Distance in the normalised unit square divided by its diagonal -> [0, 1].
        posdist = np.hypot(xnorm[:, None] - xnorm[None, :], ynorm[:, None] - ynorm[None, :]) / np.sqrt(2)
        close_mask = (posdist < self.dist_thresh) & (adiff < self.adiff_thresh)
        # Self-pairs are trivially close: drop them from the similar set. They also stay out
        # of the dissimilar set.
        sim_mask = close_mask.copy()
        np.fill_diagonal(sim_mask, False)
        return np.argwhere(sim_mask), np.argwhere(~close_mask)




In [59]:
load_annotations_pth = 'data/VAE/annotations2.csv'
load_embed_dir = 'data/VAE/embeds2'
# `datainds` is a (start, stop) row range; (0, None) selects every row. Omitting it raised a
# TypeError with the original constructor signature.
# NOTE(review): the pairwise masks are N x N, so narrow this range if memory is tight.
dataset = ContrastiveEmbeddingDataset(load_annotations_pth, load_embed_dir, datainds=(0, None))

In [67]:
# Pull only the first batch of one epoch for inspection.
(sim_embed1, sim_embed2), dissim = next(dataset.iterate(32))

In [73]:
margin = 10
# Row-wise Euclidean distance between the paired similar embeddings.
D = (sim_embed1 - sim_embed2).square().sum(dim=1).sqrt()
# NOTE(review): relu(margin - D) is the dissimilar-pair hinge of a contrastive loss; applied to
# *similar* pairs here, presumably just to inspect D against the margin -- confirm intent.
torch.nn.functional.relu(margin - D)

tensor([ 9.8988, 13.9596,  8.3449, 14.6587, 10.4811, 12.0691, 11.8705,  8.5749,
         8.3619, 11.5178, 12.8127,  8.2108, 13.8534,  8.8696,  6.4656,  8.3833,
         8.1345, 10.0703,  9.4563, 12.2103,  7.9625, 11.2243,  9.4736,  9.6146,
        10.9514,  8.0056,  6.9203, 14.2253, 13.5121,  7.9144,  6.9854,  8.2553])


tensor([0.1012, 0.0000, 1.6551, 0.0000, 0.0000, 0.0000, 0.0000, 1.4251, 1.6381,
        0.0000, 0.0000, 1.7892, 0.0000, 1.1304, 3.5344, 1.6167, 1.8655, 0.0000,
        0.5437, 0.0000, 2.0375, 0.0000, 0.5264, 0.3854, 0.0000, 1.9944, 3.0797,
        0.0000, 0.0000, 2.0856, 3.0146, 1.7447])