In [1]:
import os
import numpy as np
import torch
from collections import OrderedDict
import json

os.environ['CUDA_VISIBLE_DEVICES'] = '1'
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

In [6]:
query_feature_A_dir = '/nfs3-p2/zsxm/naic/preliminary/test_A/query_feature_A'
query_feature_A = []
for data in sorted(os.listdir(query_feature_A_dir)):
    query_feature_A.append(torch.from_numpy(np.fromfile(os.path.join(query_feature_A_dir, data), dtype='<f4')))
query_feature_A = torch.stack(query_feature_A).to(device)

In [3]:
gallery_feature_A_dir = '/nfs3-p2/zsxm/naic/preliminary/test_A/gallery_feature_A'
gallery_feature_A = []
for data in sorted(os.listdir(gallery_feature_A_dir)):
    gallery_feature_A.append(torch.from_numpy(np.fromfile(os.path.join(gallery_feature_A_dir, data), dtype='<f4')))
gallery_feature_A = torch.stack(gallery_feature_A).to(device)

In [4]:
print(query_feature_A.shape, query_feature_A)
print(gallery_feature_A.shape, gallery_feature_A)

torch.Size([20000, 2048]) tensor([[ 0.0805,  0.0000,  0.0000,  ...,  0.0000,  0.0646,  0.0000],
        [-0.2679,  0.0000,  0.0000,  ...,  0.0000, -0.2808,  0.0000],
        [-0.3382,  0.0000,  0.0000,  ...,  0.0000,  0.0906,  0.0000],
        ...,
        [-0.2313,  0.0000,  0.0000,  ...,  0.0000, -0.4064,  0.0000],
        [-0.1418,  0.0000,  0.0000,  ...,  0.0000,  0.0699,  0.0000],
        [ 0.0150,  0.0000,  0.0000,  ...,  0.0000, -0.2349,  0.0000]],
       device='cuda:0')
torch.Size([428794, 2048]) tensor([[ 0.1381,  0.0000,  0.0000,  ...,  0.0000, -0.0197,  0.0000],
        [-0.3590,  0.0000,  0.0000,  ...,  0.0000, -0.5844,  0.0000],
        [-0.3214,  0.0000,  0.0000,  ...,  0.0000, -0.2519,  0.0000],
        ...,
        [-0.0717,  0.0000,  0.0000,  ...,  0.0000, -0.4211,  0.0000],
        [-0.5035,  0.0000,  0.0000,  ...,  0.0000,  0.2123,  0.0000],
        [-0.2045,  0.0000,  0.0000,  ...,  0.0000, -0.6521,  0.0000]],
       device='cuda:0')


In [12]:
def cos_similarity(q, k):
    try:
        mm = torch.mm(q, k.T)
        qn = torch.linalg.vector_norm(q, dim=1, keepdim=True)
        kn = torch.linalg.vector_norm(k, dim=1, keepdim=True)
        qk = torch.mm(qn, kn.T)
        res = mm / qk
    except RuntimeError:
        q = q.cpu()
        k = k.cpu()
        mm = torch.mm(q, k.T)
        qn = torch.linalg.vector_norm(q, dim=1, keepdim=True)
        kn = torch.linalg.vector_norm(k, dim=1, keepdim=True)
        qk = torch.mm(qn, kn.T)
        res = mm / qk
    return res

In [13]:
res = cos_similarity(query_feature_A, gallery_feature_A)

In [14]:
print(res.shape)

torch.Size([20000, 428794])


In [22]:
res_dict = {}
query_names = sorted(os.listdir(query_feature_A_dir))
gallery_names = sorted(os.listdir(gallery_feature_A_dir))
for i, name in enumerate(query_names):
    idx = torch.argsort(res[i], dim=-1, descending=True)
    query_res = []
    for j in range(100):
        query_res.append(gallery_names[idx[j]])
    res_dict[name] = query_res

In [25]:
with open('./sub_a.json', 'w') as f:
    json.dump(res_dict, f)

In [1]:
from datasets.preliminary_dataset import PreliminaryDataset, PreliminaryBatchSampler, preliminary_collate_fn
import random
from torch.utils.data import DataLoader

In [2]:
dataset = PreliminaryDataset('/nfs3-p1/zsxm/naic/preliminary/train', False)

In [3]:
batchsampler = PreliminaryBatchSampler(dataset, 150)

In [4]:
dataloader = DataLoader(dataset, batch_sampler=batchsampler, num_workers=8, collate_fn=preliminary_collate_fn, pin_memory=True)

In [5]:
random.seed(2)#2078
test_count = 0
while True:
    batchs = []
    t_batch = []
    for i, b in enumerate(batchsampler):
        #print(i, b)
        batchs.append(b)
        t_batch.extend(b)
        b_len = 0
        for idx in b:
            b_len += dataset.idx2len[idx]
        assert 0 < b_len <= batchsampler.batch_size, str(b_len)+str(b)+str(i)
    assert len(t_batch) == 15000, len(t_batch)
    t_batch_set = set(t_batch)
    assert len(t_batch) == len(t_batch_set), len(t_batch_set)
    test_count +=1
    print(test_count)

1
2


KeyboardInterrupt: 

In [5]:
count = 0
for q, k, q_label, k_label in dataloader:
    print(count)
    print(q.shape)
    print(k.shape)
    print(q_label.shape)
    print(k_label.shape)
    count += 1

0
torch.Size([150, 2048])
torch.Size([158, 2048])
torch.Size([150])
torch.Size([158])
1
torch.Size([150, 2048])
torch.Size([158, 2048])
torch.Size([150])
torch.Size([158])
2
torch.Size([150, 2048])
torch.Size([157, 2048])
torch.Size([150])
torch.Size([157])
3
torch.Size([150, 2048])
torch.Size([161, 2048])
torch.Size([150])
torch.Size([161])
4
torch.Size([150, 2048])
torch.Size([162, 2048])
torch.Size([150])
torch.Size([162])
5
torch.Size([150, 2048])
torch.Size([160, 2048])
torch.Size([150])
torch.Size([160])
6
torch.Size([150, 2048])
torch.Size([161, 2048])
torch.Size([150])
torch.Size([161])
7
torch.Size([150, 2048])
torch.Size([158, 2048])
torch.Size([150])
torch.Size([158])
8
torch.Size([150, 2048])
torch.Size([159, 2048])
torch.Size([150])
torch.Size([159])
9
torch.Size([150, 2048])
torch.Size([158, 2048])
torch.Size([150])
torch.Size([158])
10
torch.Size([150, 2048])
torch.Size([162, 2048])
torch.Size([150])
torch.Size([162])
11
torch.Size([150, 2048])
torch.Size([161, 2048])
to

96
torch.Size([150, 2048])
torch.Size([155, 2048])
torch.Size([150])
torch.Size([155])
97
torch.Size([150, 2048])
torch.Size([160, 2048])
torch.Size([150])
torch.Size([160])
98
torch.Size([150, 2048])
torch.Size([153, 2048])
torch.Size([150])
torch.Size([153])
99
torch.Size([150, 2048])
torch.Size([160, 2048])
torch.Size([150])
torch.Size([160])
100
torch.Size([150, 2048])
torch.Size([159, 2048])
torch.Size([150])
torch.Size([159])
101
torch.Size([150, 2048])
torch.Size([155, 2048])
torch.Size([150])
torch.Size([155])
102
torch.Size([150, 2048])
torch.Size([166, 2048])
torch.Size([150])
torch.Size([166])
103
torch.Size([150, 2048])
torch.Size([158, 2048])
torch.Size([150])
torch.Size([158])
104
torch.Size([150, 2048])
torch.Size([158, 2048])
torch.Size([150])
torch.Size([158])
105
torch.Size([150, 2048])
torch.Size([158, 2048])
torch.Size([150])
torch.Size([158])
106
torch.Size([150, 2048])
torch.Size([161, 2048])
torch.Size([150])
torch.Size([161])
107
torch.Size([150, 2048])
torch.Si

192
torch.Size([150, 2048])
torch.Size([159, 2048])
torch.Size([150])
torch.Size([159])
193
torch.Size([150, 2048])
torch.Size([164, 2048])
torch.Size([150])
torch.Size([164])
194
torch.Size([150, 2048])
torch.Size([157, 2048])
torch.Size([150])
torch.Size([157])
195
torch.Size([150, 2048])
torch.Size([161, 2048])
torch.Size([150])
torch.Size([161])
196
torch.Size([150, 2048])
torch.Size([161, 2048])
torch.Size([150])
torch.Size([161])
197
torch.Size([150, 2048])
torch.Size([157, 2048])
torch.Size([150])
torch.Size([157])
198
torch.Size([150, 2048])
torch.Size([158, 2048])
torch.Size([150])
torch.Size([158])
199
torch.Size([150, 2048])
torch.Size([163, 2048])
torch.Size([150])
torch.Size([163])
200
torch.Size([150, 2048])
torch.Size([164, 2048])
torch.Size([150])
torch.Size([164])
201
torch.Size([150, 2048])
torch.Size([155, 2048])
torch.Size([150])
torch.Size([155])
202
torch.Size([150, 2048])
torch.Size([158, 2048])
torch.Size([150])
torch.Size([158])
203
torch.Size([150, 2048])
torc

Exception in thread Thread-4:
Traceback (most recent call last):
  File "/home/zsxm/miniconda3/envs/pytorch/lib/python3.8/threading.py", line 932, in _bootstrap_inner
    self.run()
  File "/home/zsxm/miniconda3/envs/pytorch/lib/python3.8/threading.py", line 870, in run
    self._target(*self._args, **self._kwargs)
  File "/home/zsxm/miniconda3/envs/pytorch/lib/python3.8/site-packages/torch/utils/data/_utils/pin_memory.py", line 28, in _pin_memory_loop
    r = in_queue.get(timeout=MP_STATUS_CHECK_INTERVAL)
  File "/home/zsxm/miniconda3/envs/pytorch/lib/python3.8/multiprocessing/queues.py", line 116, in get
    return _ForkingPickler.loads(res)
  File "/home/zsxm/miniconda3/envs/pytorch/lib/python3.8/site-packages/torch/multiprocessing/reductions.py", line 289, in rebuild_storage_fd
    fd = df.detach()
  File "/home/zsxm/miniconda3/envs/pytorch/lib/python3.8/multiprocessing/resource_sharer.py", line 57, in detach
    with _resource_sharer.get_connection(self._id) as conn:
  File "/home

288
torch.Size([150, 2048])
torch.Size([162, 2048])
torch.Size([150])
torch.Size([162])
289
torch.Size([150, 2048])
torch.Size([162, 2048])
torch.Size([150])
torch.Size([162])
290
torch.Size([150, 2048])
torch.Size([155, 2048])
torch.Size([150])
torch.Size([155])
291
torch.Size([150, 2048])
torch.Size([159, 2048])
torch.Size([150])
torch.Size([159])
292
torch.Size([150, 2048])
torch.Size([158, 2048])
torch.Size([150])
torch.Size([158])
293
torch.Size([150, 2048])
torch.Size([158, 2048])
torch.Size([150])
torch.Size([158])
294

KeyboardInterrupt: 

In [10]:
k_label

tensor([10774, 10774,  7723,  7723,  7723,  7723,  7723,  7723,  7723,  7723,
         7723,  7723,  7723,  7723,  7723,  7723,  7723,  9413,  9413,  9413,
         9413,  9413,  9413,  9413,  9413,  8551,  8551,  8551,  8551,  8551,
         8551,  8551,  8551,  8551, 10287, 10287,  3390,  3390,  3390,  3390,
         3390,  3390,  3390,  3390,  3390,  3390,  3390,  3390,  3390,  3390,
         3390,  8157,  8157,  8157,  8157,  8157,  8157,  8157,  6046,  6046,
         6046,  6046,  6046,  6046,  6046,  6297,  6297,  6297,  6297,  6297,
         6297,  6297,  6297,  6297,  6297,  3594,  3594,  3594,  3594,  3594,
         3594,  3594,  3594,  3594,  1222,  1222,  1222,  1222,  1222,  1222,
         1222,  1222,  1222,  1222,  1604,  1604,  1604,  1604,  1604,  1604,
         6265,  6265,  6265,  6265,  6265,  6265,  6265,  6265,  6265,  4755,
         4755,  4755,  4755,  4755,  4755,  4755,  4755,  7889,  7889,  7889,
         7889,  7889,  7889, 14219, 14219,  7755,  7755,  7755, 