In [None]:
import os
import numpy as np
import torch
from collections import OrderedDict
import json
from models.mlp import MLP_MoCo
from tqdm import tqdm

# Restrict this process to the first GPU; inside the process it is addressed as cuda:0.
# NOTE: only takes effect if CUDA has not been initialised yet in this kernel.
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
# Use the GPU when available, otherwise fall back to the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
# Uncomment to force CPU execution.
#device = torch.device('cpu')

In [None]:
query_feature_A_dir = '/nfs3-p2/zsxm/naic/preliminary/test_A/query_feature_A'
gallery_feature_A_dir = '/nfs3-p2/zsxm/naic/preliminary/test_A/gallery_feature_A'
train_feature_dir = '/nfs3-p1/zsxm/naic/preliminary/train/train_feature'

In [None]:
query_feature_A = []
for data in sorted(os.listdir(query_feature_A_dir)):
    query_feature_A.append(torch.from_numpy(np.fromfile(os.path.join(query_feature_A_dir, data), dtype='<f4')))
query_feature_A = torch.stack(query_feature_A)

In [None]:
gallery_feature_A = []
for data in sorted(os.listdir(gallery_feature_A_dir)):
    gallery_feature_A.append(torch.from_numpy(np.fromfile(os.path.join(gallery_feature_A_dir, data), dtype='<f4')))
gallery_feature_A = torch.stack(gallery_feature_A)

In [None]:
print(query_feature_A.shape, query_feature_A)
print(gallery_feature_A.shape, gallery_feature_A)
torch.save(query_feature_A, '/nfs3-p2/zsxm/naic/preliminary/test_A/query_feature_A.pt')
torch.save(gallery_feature_A, '/nfs3-p2/zsxm/naic/preliminary/test_A/gallery_feature_A.pt')

In [None]:
train_feature = []
for data in sorted(os.listdir(train_feature_dir)):
    train_feature.append(torch.from_numpy(np.fromfile(os.path.join(train_feature_dir, data), dtype='<f4')))
train_feature = torch.stack(train_feature)
print(train_feature.shape)
torch.save(train_feature, '/nfs3-p1/zsxm/naic/preliminary/train/train_feature.pt')

In [None]:
# Reload the stacked feature tensors cached by the torch.save cells above
# (avoids re-reading the raw per-item files).
query_feature_A = torch.load('/nfs3-p2/zsxm/naic/preliminary/test_A/query_feature_A.pt')
gallery_feature_A = torch.load('/nfs3-p2/zsxm/naic/preliminary/test_A/gallery_feature_A.pt')
train_feature = torch.load('/nfs3-p1/zsxm/naic/preliminary/train/train_feature.pt')
print(query_feature_A.shape, query_feature_A)
print(gallery_feature_A.shape, gallery_feature_A)
print(train_feature.shape, train_feature)

In [None]:
# Move all feature matrices onto the selected device (GPU when available).
query_feature_A = query_feature_A.to(device)
gallery_feature_A = gallery_feature_A.to(device)
train_feature = train_feature.to(device)

In [None]:
def cos_similarity(q, k, eps=1e-12):
    """Cosine-similarity matrix between the rows of `q` and the rows of `k`.

    Computes ``(q @ k.T) / (||q_i|| * ||k_j||)``. If the computation raises a
    RuntimeError (presumably a CUDA out-of-memory error), it is retried with
    both inputs moved to the CPU, so the result may live on either device.

    NOTE: the following cell redefines ``cos_similarity`` (normalize-then-matmul),
    which shadows this version once both cells have run.

    Args:
        q: 2-D tensor, one vector per row.
        k: 2-D tensor, one vector per row (same width as ``q``).
        eps: lower bound applied to every row norm so that all-zero rows give a
            similarity of 0 instead of NaN (same default as
            ``torch.nn.functional.normalize``).

    Returns:
        Tensor of shape ``(q.shape[0], k.shape[0])``.
    """
    def _cos(a, b):
        # Dot products of every (a_i, b_j) pair.
        dots = torch.mm(a, b.T)
        a_norm = torch.linalg.vector_norm(a, dim=1, keepdim=True).clamp_min(eps)
        b_norm = torch.linalg.vector_norm(b, dim=1, keepdim=True).clamp_min(eps)
        # Outer product of the row norms gives the (len(a), len(b)) denominator.
        return dots / torch.mm(a_norm, b_norm.T)

    try:
        return _cos(q, k)
    except RuntimeError:
        return _cos(q.cpu(), k.cpu())

In [None]:
def cos_similarity(q, k):
    """Cosine-similarity matrix between the rows of `q` and the rows of `k`.

    Both inputs are L2-normalised row-wise with ``torch.nn.functional.normalize``
    (all-zero rows stay zero), then multiplied as ``q @ k.T``. On a RuntimeError
    (presumably CUDA running out of memory) the work is redone on the CPU, so
    the returned tensor may live on either device.

    Returns:
        Tensor of shape ``(q.shape[0], k.shape[0])``.
    """
    normalize = torch.nn.functional.normalize
    try:
        q = normalize(q, dim=1)
        k = normalize(k, dim=1)
        return q @ k.T
    except RuntimeError:
        # Retry on the host.
        q_host = normalize(q.cpu(), dim=1)
        k_host = normalize(k.cpu(), dim=1)
        return q_host @ k_host.T

In [None]:
@torch.no_grad()
def batch_cos(q, k, batch_size=2048):
    """Cosine similarity of every row of `q` against every row of `k`.

    The queries are processed `batch_size` rows at a time through
    ``cos_similarity``. That function may fall back to the CPU for some batches
    (it retries on RuntimeError), so the per-batch results can end up on
    different devices. In that case they are all gathered on the CPU before
    concatenation, because ``torch.cat`` requires a single device.

    Returns:
        Tensor of shape ``(q.shape[0], k.shape[0])``. An empty ``(0, k.shape[0])``
        tensor is returned when `q` has no rows.
    """
    if q.shape[0] == 0:
        return q.new_empty((0, k.shape[0]))
    res = []
    for i in tqdm(range(0, q.shape[0], batch_size)):
        res.append(cos_similarity(q[i:i + batch_size], k))
    # Mixed devices -> move everything to the CPU; otherwise keep the device as-is.
    if len({r.device for r in res}) > 1:
        res = [r.cpu() for r in res]
    return torch.cat(res)

In [None]:
# Baseline: cosine similarity between raw query and gallery features.
res = batch_cos(query_feature_A, gallery_feature_A)
print(res.shape)

In [None]:
net = MLP_MoCo()
net.load_state_dict(torch.load('.details/checkpoints/MLP/01-11_11:48:10/Net_best.pth', map_location=device))
net.to(device)
# Keep only the query encoder of the MoCo model for inference.
net = net.encoder_q#.encoder
net.eval()
# Final statement keeps the cell from echoing the module repr returned by eval().
print('')

In [None]:
@torch.no_grad()
def encode(net, data, batch_size=2048):
    """Run `net` over `data` in mini-batches of `batch_size` rows, without autograd.

    `net` must return a pair per batch. Going by the variable names at the call
    sites, the first element is a reconstruction and the second a code/embedding.

    Returns:
        ``(codes, reconstructions)``, each concatenated along dim 0.
    """
    codes, reconstructions = [], []
    for start in tqdm(range(0, data.shape[0], batch_size)):
        batch_recons, batch_code = net(data[start:start + batch_size])
        codes.append(batch_code)
        reconstructions.append(batch_recons)
    return torch.cat(codes), torch.cat(reconstructions)

In [None]:
# Encode both splits with the MoCo query encoder; encode() returns (codes, reconstructions).
query_code_A, query_recons_A = encode(net, query_feature_A)
gallery_code_A, gallery_recons_A = encode(net, gallery_feature_A)

In [None]:
print(query_code_A.shape, query_recons_A.shape)
print(gallery_code_A.shape, gallery_recons_A.shape)

In [None]:
# Cosine similarity in the learned code space (compare with the raw-feature `res`).
res2 = batch_cos(query_code_A, gallery_code_A)
print(res2.shape)

In [None]:
def print_res(res, print_json=False, topk=100, json_path='./sub_a.json'):
    """Build (and optionally write) the top-`topk` gallery names for each query.

    Args:
        res: similarity matrix, higher = more similar. Row i is the query whose
            name is the i-th entry of ``sorted(os.listdir(query_feature_A_dir))``.
            Column j is the j-th entry of ``sorted(os.listdir(gallery_feature_A_dir))``.
        print_json: when True, dump the result dict as JSON to `json_path`.
        topk: number of gallery names kept per query (clamped to the gallery size).
        json_path: output path for the JSON dump.

    Returns:
        dict mapping each query filename to its top gallery filenames, best first.
    """
    query_names = sorted(os.listdir(query_feature_A_dir))
    gallery_names = sorted(os.listdir(gallery_feature_A_dir))
    n_keep = min(topk, len(gallery_names))
    res_dict = {}
    for i, name in enumerate(tqdm(query_names)):
        idx = torch.argsort(res[i], dim=-1, descending=True)[:n_keep].tolist()
        res_dict[name] = [gallery_names[j] for j in idx]
    if print_json:
        with open(json_path, 'w') as f:
            json.dump(res_dict, f)
    return res_dict

In [None]:
def compare_res(res1, res2, topk=100):
    """Average top-`topk` IoU between two score matrices with the same row count.

    For each row, the indices of the `topk` largest scores are taken from `res1`
    and from `res2`. Their Jaccard index ``|A & B| / |A | B|`` is shown live in
    the tqdm bar and averaged over all rows.
    """
    assert res1.shape[0] == res2.shape[0]

    def top_indices(scores):
        return set(torch.argsort(scores, dim=-1, descending=True)[:topk].tolist())

    total = 0
    progress = tqdm(range(res1.shape[0]))
    for row in progress:
        first = top_indices(res1[row])
        second = top_indices(res2[row])
        iou = len(first & second) / len(first | second)
        progress.set_postfix(IoU=f'{iou:.4f}')
        total += iou
    return total / res1.shape[0]

In [None]:
# Agreement (mean top-100 IoU) between the raw-feature ranking `res` and the code-space ranking `res2`.
ave = compare_res(res, res2)
print(ave)

## 马氏距离计算 (Mahalanobis distance computation)

In [None]:
def _nonzero_dims(feats):
    """Return (column sums, CPU index tensor of columns whose sum is non-zero).

    The mask is moved to the CPU before indexing the CPU `arange`; indexing a
    CPU tensor with a mask that lives on the GPU raises. The width is taken
    from `feats` instead of a hard-coded 2048.
    """
    col_sum = feats.sum(dim=0)
    return col_sum, torch.arange(feats.shape[1])[(col_sum != 0).cpu()]

qfs, nqfs = _nonzero_dims(query_feature_A)
print(nqfs.shape, nqfs)
gfs, ngfs = _nonzero_dims(gallery_feature_A)
print(ngfs.shape, ngfs)
print(nqfs.equal(ngfs))
tfs, ntfs = _nonzero_dims(train_feature)
print(ntfs.shape, ntfs)
print(nqfs.equal(ntfs))

In [None]:
# Persist the indices of query-feature columns with a non-zero column sum for reuse elsewhere.
not_zero_dim = nqfs
torch.save(not_zero_dim, '/nfs3-p1/zsxm/naic/preliminary/train/not_zero_dim.pt')

In [None]:
# Drop columns whose column sum is zero. Query and gallery use their own masks, so the
# results are column-aligned only if the masks agree (checked by the nqfs.equal(ngfs) print).
query_reshape_A = query_feature_A[:, qfs!=0]
gallery_reshape_A = gallery_feature_A[:, gfs!=0]
print(query_reshape_A.shape, gallery_reshape_A.shape)

In [None]:
# Magnitude sanity check of the reduced features.
print(query_reshape_A.abs().mean(), gallery_reshape_A.abs().mean())
print(query_reshape_A.abs().max(), gallery_reshape_A.abs().max())

In [None]:
# Free the full-width features. NOTE(review): earlier cells that read these names
# will fail after this unless the loading cells are re-run first.
del query_feature_A, gallery_feature_A

In [None]:
def Mahalanobis(q, k, topk=100):
    """Rank gallery rows for every query by squared Mahalanobis distance.

    Both `q` (queries) and `k` (gallery) are first L2-normalised row-wise. The
    covariance is estimated from the normalised gallery and inverted on the CPU,
    and the inverse is moved to the global `device`. For query i the distance to
    every gallery row is ``diag((q_i - k) @ icov @ (q_i - k).T)``.

    Args:
        q: 2-D tensor, one query per row (sorted-filename order of query_feature_A_dir).
        k: 2-D tensor, one gallery item per row (sorted-filename order of
            gallery_feature_A_dir).
        topk: number of nearest gallery names kept per query (clamped to the gallery size).

    Returns:
        dict mapping each query filename to its nearest gallery filenames, nearest first.
    """
    query_names = sorted(os.listdir(query_feature_A_dir))
    gallery_names = sorted(os.listdir(gallery_feature_A_dir))
    q = torch.nn.functional.normalize(q, dim=1)
    k = torch.nn.functional.normalize(k, dim=1)

    # Unbiased sample covariance of the normalised gallery features.
    mean_k = k.mean(dim=0, keepdim=True)
    sk = k - mean_k
    cov = torch.mm(sk.T, sk) / (sk.shape[0] - 1)
    print(cov.shape, cov)
    cov = cov.to(torch.device('cpu'))
    icov = torch.linalg.inv(cov)
    print(icov.shape, icov)
    del cov
    icov = icov.to(device)

    n_keep = min(topk, k.shape[0])
    res_dict = {}
    for i in tqdm(range(q.shape[0])):
        sub = q[i] - k
        # Row-wise quadratic form: r[b] = sub[b] @ icov @ sub[b].
        r = torch.einsum('bi,bi->b', torch.mm(sub, icov), sub)
        assert r.shape == (k.shape[0],)
        idx = torch.argsort(r, descending=False)[:n_keep].tolist()
        res_dict[query_names[i]] = [gallery_names[j] for j in idx]
    return res_dict

In [None]:
res3 = Mahalanobis(query_reshape_A, gallery_reshape_A)

In [None]:
# Write the Mahalanobis-based submission (query filename -> ranked gallery filenames).
with open('./sub_a.json', 'w') as f:
    json.dump(res3, f)

# ReRanking

In [1]:
import os
import numpy as np
from collections import OrderedDict
import json
from tqdm import tqdm
from scipy.spatial.distance import cdist

# Raw feature directories of test split A (one '<f4' file per item, read in sorted order below).
query_feature_A_dir = '/nfs3-p2/zsxm/naic/preliminary/test_A/query_feature_A'
gallery_feature_A_dir = '/nfs3-p2/zsxm/naic/preliminary/test_A/gallery_feature_A'

In [None]:
def _load_feature_dir_np(feature_dir):
    """Stack every raw little-endian float32 file in `feature_dir` (sorted by name) into a 2-D array."""
    return np.stack([
        np.fromfile(os.path.join(feature_dir, name), dtype='<f4')
        for name in sorted(os.listdir(feature_dir))
    ])

query_feature_A = _load_feature_dir_np(query_feature_A_dir)
gallery_feature_A = _load_feature_dir_np(gallery_feature_A_dir)
print(query_feature_A.shape, gallery_feature_A.shape)
# Indices of columns whose column sum is non-zero (width taken from the data, not hard-coded).
qfs = query_feature_A.sum(axis=0)
nqfs = np.arange(query_feature_A.shape[1])[qfs != 0]
print(nqfs.shape, nqfs)
gfs = gallery_feature_A.sum(axis=0)
ngfs = np.arange(gallery_feature_A.shape[1])[gfs != 0]
print(ngfs.shape, ngfs)
# array_equal also copes with different lengths; `.all` without () printed a bound method.
print(np.array_equal(nqfs, ngfs))
query_reshape_A = query_feature_A[:, qfs != 0]
gallery_reshape_A = gallery_feature_A[:, gfs != 0]
print(query_reshape_A.shape, gallery_reshape_A.shape)
np.save('/nfs3-p2/zsxm/naic/preliminary/test_A/query_feature_A.npy', query_feature_A)
np.save('/nfs3-p2/zsxm/naic/preliminary/test_A/gallery_feature_A.npy', gallery_feature_A)
np.save('/nfs3-p2/zsxm/naic/preliminary/test_A/query_reshape_A.npy', query_reshape_A)
np.save('/nfs3-p2/zsxm/naic/preliminary/test_A/gallery_reshape_A.npy', gallery_reshape_A)

In [2]:
# Reload the zero-column-stripped feature arrays saved by the previous cell.
query_reshape_A = np.load('/nfs3-p2/zsxm/naic/preliminary/test_A/query_reshape_A.npy')
gallery_reshape_A = np.load('/nfs3-p2/zsxm/naic/preliminary/test_A/gallery_reshape_A.npy')

In [3]:
print(query_reshape_A.shape, gallery_reshape_A.shape)

(20000, 463) (428794, 463)


In [2]:
from re_ranking.re_ranking_pytable import re_ranking

In [None]:
# The first two arguments match the query / gallery counts printed above (20000, 428794).
# NOTE(review): meaning of 100, 30, 0.3 and 1000 is defined in re_ranking_pytable — TODO document;
# 1000 presumably matches the 1000-row chunk files (original_dist-{t}.npy) assembled below.
res = re_ranking(20000, 428794, 100, 30, 0.3, 1000)

starting re_ranking:  10%|█         | 45827/448794 [3:13:15<23:18:47,  4.80it/s]

In [None]:
import os
import numpy as np
import tables
from tqdm import tqdm
# Directory holding the chunked distance / rank .npy files written by re-ranking.
dis_path = '/nfs3-p2/zsxm/naic/preliminary/test_A/dis'

In [None]:
hdf5_path = os.path.join(dis_path, 'reranking.hdf5')
# NOTE(review): mode='w' creates the file afresh, discarding any arrays a previous run stored there.
hdf5_file = tables.open_file(hdf5_path, mode='w')

In [None]:
# Default PyTables filters (no compression configured).
filters = tables.Filters()

In [None]:
# First chunk, used only to discover dtype and row width.
temp = np.load(os.path.join(dis_path, 'original_dist-0.npy'))

In [None]:
# Row width of the distance chunks; presumably the total query + gallery count — TODO confirm.
all_num = temp.shape[1]
print(all_num)

In [None]:
# Extendable (0, all_num) array that the chunks are appended to row-wise.
original_dist = hdf5_file.create_earray(hdf5_file.root, 
                                        'original_dist', 
                                        tables.Atom.from_dtype(temp.dtype), 
                                        shape=(0, temp.shape[1]), 
                                        filters=filters, 
                                        expectedrows=temp.shape[1])

In [None]:
# Append chunk files original_dist-0.npy, -1.npy, ... in order; one chunk per 1000 rows
# (assumes each file holds 1000 rows, the last possibly fewer). `i` is unused.
for t, i in enumerate(tqdm(range(0, all_num, 1000))):
    ori = np.load(os.path.join(dis_path, f'original_dist-{t}.npy'))
    original_dist.append(ori)

In [None]:
# First rank chunk, used only to discover the dtype.
temp = np.load(os.path.join(dis_path, 'initial_rank-0.npy'))

In [None]:
# Keep only the first 101 columns of each rank row (presumably self + top-100 neighbours — TODO confirm).
initial_rank = hdf5_file.create_earray(hdf5_file.root, 
                                        'initial_rank', 
                                        tables.Atom.from_dtype(temp.dtype), 
                                        shape=(0, 101), 
                                        filters=filters, 
                                        expectedrows=all_num)

In [None]:
# Append the truncated rank chunks in file order; `i` is unused.
for t, i in enumerate(tqdm(range(0, all_num, 1000))):
    ori = np.load(os.path.join(dis_path, f'initial_rank-{t}.npy'))
    initial_rank.append(ori[:, :101])

In [None]:
hdf5_file.close()

In [None]:
# NOTE(review): hdf5_file is closed by the previous cell, so on a top-to-bottom run this
# raises; it must have been executed out of order. The (all_num, all_num) float32 array is
# roughly 0.8 TB for all_num ≈ 448794, and the node is removed again two cells below.
V = hdf5_file.create_carray(hdf5_file.root, 
                            'V', 
                            tables.Atom.from_dtype(np.zeros(1, dtype=np.float32).dtype), 
                            shape=(all_num, all_num), 
                            filters=filters)

In [None]:
hdf5_file.root

In [None]:
type(V)

In [None]:
hdf5_file.remove_node(hdf5_file.root, 'V')

In [None]:
# Reopen the assembled store read-only for inspection.
hdf5_file = tables.open_file(hdf5_path, mode='r')

In [None]:
hdf5_file.root

In [None]:
original_dist = hdf5_file.root.original_dist

In [None]:
# Spot-check a single row of the on-disk distance matrix.
original_dist[400000]

In [None]:
hdf5_file.close()

In [None]:
from datasets.preliminary_dataset import PreliminaryDataset, PreliminaryBatchSampler, preliminary_collate_fn
import random
from torch.utils.data import DataLoader

In [None]:
# NOTE(review): meaning of the boolean flag is defined in PreliminaryDataset — TODO document.
dataset = PreliminaryDataset('/nfs3-p1/zsxm/naic/preliminary/train', False)

In [None]:
# 150 presumably sets batchsampler.batch_size, the per-batch budget on summed item
# lengths checked in the stress test below — TODO confirm.
batchsampler = PreliminaryBatchSampler(dataset, 150)

In [None]:
dataloader = DataLoader(dataset, batch_sampler=batchsampler, num_workers=8, collate_fn=preliminary_collate_fn, pin_memory=True)

In [None]:
# Stress-test PreliminaryBatchSampler. On every pass, each batch's summed item lengths
# must fit in batch_size, and the batches together must cover EXPECTED_NUM_SAMPLES
# indices with no duplicates. The number of passes is bounded (not `while True`)
# so the cell terminates under Restart & Run All.
NUM_SAMPLER_CHECKS = 100
EXPECTED_NUM_SAMPLES = 15000  # presumably len(dataset) -- TODO confirm
random.seed(2)  # alternative seed noted by the author: 2078
for test_count in range(1, NUM_SAMPLER_CHECKS + 1):
    batchs = []
    t_batch = []
    for i, b in enumerate(batchsampler):
        batchs.append(b)
        t_batch.extend(b)
        b_len = sum(dataset.idx2len[idx] for idx in b)
        assert 0 < b_len <= batchsampler.batch_size, str(b_len)+str(b)+str(i)
    assert len(t_batch) == EXPECTED_NUM_SAMPLES, len(t_batch)
    t_batch_set = set(t_batch)
    assert len(t_batch) == len(t_batch_set), len(t_batch_set)
    print(test_count)

In [None]:
# Smoke-test the DataLoader: iterate one full pass and print the tensor shapes of every batch.
# NOTE(review): prints five lines per batch, so a full pass produces a lot of output.
count = 0
for q, k, q_label, k_label in dataloader:
    print(count)
    print(q.shape)
    print(k.shape)
    print(q_label.shape)
    print(k_label.shape)
    count += 1