In [13]:
import os
import pickle
import random
from pymilvus import (
    connections,
    utility,
    FieldSchema, CollectionSchema, DataType,
    Collection,
)

# Select the GPU *before* the project modules are imported: CUDA_VISIBLE_DEVICES
# is only honoured if it is set before CUDA gets initialised, and the imports
# below may pull in the GPU stack (assumption -- their contents are not shown).
os.environ['CUDA_VISIBLE_DEVICES'] = '1'

# Work from the MAAF checkout: the project imports below follow this chdir and
# later cells use paths relative to it (e.g. './temp/opt.pkl').
WORK_DIR = '/home/lishi/workspace/MAAF/'
os.chdir(WORK_DIR)

from datasets.datasets import load_dataset
from main import create_model_and_optimizer
from test_retrieval import compute_query_features
from test_retrieval import compute_db_features

In [2]:
MILVUS_HOST = '10.112.14.63'
MILVUS_PORT = '19530'
COLLECTION_NAME = 'fashioniq'
EMBEDDING_DIM = 512
PARTITION_NAMES = ('val', 'test')

connections.connect("default", host=MILVUS_HOST, port=MILVUS_PORT)

# Get-or-create the collection so this cell also works against a fresh Milvus
# instance (previously the creation code was commented out and the cell failed
# unless the collection already existed).
if utility.has_collection(COLLECTION_NAME):
    collection = Collection(COLLECTION_NAME)
else:
    fields = [
        FieldSchema(name="id", dtype=DataType.INT64, is_primary=True),
        FieldSchema(name="embeddings", dtype=DataType.FLOAT_VECTOR, dim=EMBEDDING_DIM),
    ]
    schema = CollectionSchema(fields=fields, description="composed image retrieval demo schema")
    collection = Collection(name=COLLECTION_NAME, schema=schema)

# The insert cells write into one partition per evaluation split; create any
# that are missing and leave existing ones untouched.
for partition_name in PARTITION_NAMES:
    if not collection.has_partition(partition_name):
        collection.create_partition(partition_name)

In [5]:
OPT_PATH = './temp/opt.pkl'

# Restore the pickled options object that drives dataset and model creation.
with open(OPT_PATH, 'rb') as opt_file:
    opt = pickle.load(opt_file)

# dataset_dict is indexed by split name ("train" here, 'val'/'test' in later cells).
dataset_dict = load_dataset(opt)

# The model factory is given every text of the training split.
train_texts = dataset_dict["train"].get_all_texts()
model, optimizer = create_model_and_optimizer(opt, train_texts)

Reading dataset  fashioniq
0  files not found in  train
0  files not found in  val
0  files not found in  test
train size 45429
val size 15415
test size 15417
Creating model and optimizer for sequence_concat_attention
Setting up sequence concatenation attention model
Using ResNet50
spatial
vocab size 3947


100%|██████████| 15415/15415 [02:04<00:00, 123.55it/s]


In [6]:
# Compute the database-side features for the 'val' split.
all_imgs, all_captions = compute_db_features(opt, model, testset=dataset_dict['val'])

# Column-oriented insert payload: first column is all_captions, second is all_imgs.
# NOTE(review): presumably these map onto the collection's id / embedding fields
# in that order -- confirm against the collection schema.
data = [list(all_captions), list(all_imgs)]
collection.insert(data=data, partition_name='val')

(insert count: 15415, delete count: 0, upsert count: 0, timestamp: 432185532578267137)

In [10]:
# Compute the database-side features for the 'test' split.
all_imgs_test, all_captions_test = compute_db_features(opt, model, testset=dataset_dict['test'])

# Column-oriented insert payload: first column is all_captions_test, second is all_imgs_test.
# NOTE(review): presumably these map onto the collection's id / embedding fields
# in that order -- confirm against the collection schema.
data_test = [list(all_captions_test), list(all_imgs_test)]
collection.insert(data=data_test, partition_name='test')

(insert count: 15417, delete count: 0, upsert count: 0, timestamp: 432185675564449795)

In [14]:
def generate_a_random_query():
    """Pick one random query from a randomly chosen evaluation split.

    Only the 'val' and 'test' splits are candidates; 'train' is never chosen.

    Returns:
        tuple: ``(testset, test_queries)`` where ``testset`` is
        ``dataset_dict[split]`` for the chosen split and ``test_queries`` is a
        one-element list holding a random entry of ``testset.test_queries``.

    Raises:
        IndexError: if the chosen split has no test queries.
    """
    split = random.choice(('val', 'test'))
    testset = dataset_dict[split]
    # random.choice always stays within bounds. The earlier
    # random.randint(0, len(...)) is inclusive at both ends, so it could pick
    # index == len(...) and raise IndexError.
    test_queries = [random.choice(testset.test_queries)]

    return testset, test_queries

# Draw one random query and compute its features with the loaded model.
testset, test_queries = generate_a_random_query()
computed_queries = compute_query_features(
    opt=opt,
    model=model,
    testset=testset,
    test_queries=test_queries,
)

100%|██████████| 1/1 [00:00<00:00, 13.14it/s]


In [20]:
# Display the query features as nested Python lists -- the same form that is
# passed as `data` to collection.search().
computed_queries.tolist()

[[-0.21712519228458405,
  -0.587888777256012,
  -1.6805307865142822,
  0.8688580989837646,
  -0.22129948437213898,
  -0.16953909397125244,
  -0.5506154298782349,
  0.1262018233537674,
  0.20483772456645966,
  1.0819557905197144,
  0.3061200976371765,
  0.1666945219039917,
  0.12820270657539368,
  1.4313902854919434,
  -1.2735557556152344,
  0.5390727519989014,
  -0.5966591835021973,
  -0.23992139101028442,
  0.11315709352493286,
  -0.5121455788612366,
  -1.2028133869171143,
  0.6203965544700623,
  -0.598185122013092,
  -1.2485160827636719,
  -0.1040552482008934,
  0.7548115849494934,
  1.8889105319976807,
  -0.36713218688964844,
  -0.3719121515750885,
  0.6336674690246582,
  -0.3208863139152527,
  -0.6653966903686523,
  -0.11983291059732437,
  -1.4167134761810303,
  -0.09879478067159653,
  0.9427832365036011,
  0.6527234315872192,
  -0.8664104342460632,
  0.201171875,
  -1.0562942028045654,
  -0.289209246635437,
  0.48626571893692017,
  0.5795271396636963,
  0.0416078083217144,
  -0.82

In [22]:
# The collection has to be loaded before it can be searched.
collection.load()

search_params = {"metric_type": "L2", "params": {"nprobe": 10}}
query_vectors = computed_queries.tolist()

# Search only the partition named after the split the query was drawn from.
results = collection.search(
    data=query_vectors,
    anns_field="embeddings",
    param=search_params,
    limit=100,
    expr=None,
    consistency_level="Strong",
    partition_names=[testset.split],
)

In [26]:
# Only the first result entry is used (a single query was searched).
top_hits = results[0]

# Translate each returned primary key through the split's id2asin mapping.
candidates = [testset.id2asin[hit_id] for hit_id in top_hits.ids]
scores = top_hits.distances

In [28]:
# Display the values looked up in testset.id2asin for the returned hits, in the
# order the search returned them.
candidates

['B006ZIOL60',
 'B00C3X34K0',
 'B007K4Y21W',
 'B0053A0PYI',
 'B004IK9XRY',
 'B00CI1EVWW',
 'B008ZAYOLI',
 'B001AOILUG',
 'B002TLTIMY',
 'B000GLO63K',
 'B008FLPFKG',
 'B00CY3AKLK',
 'B0096TN2MO',
 'B003UM1Y8W',
 'B000IBFDR6',
 'B00AYDWU42',
 'B00FEZN66G',
 'B00CJ9XL92',
 'B00CM2RPAW',
 'B000F41CI0',
 'B00CHSYIL0',
 'B00B35S11G',
 'B0072BD2VO',
 'B007FHZWAY',
 'B00APB7ADY',
 'B000Q584Q6',
 'B0007Y5JFA',
 'B0030ISRKE',
 'B005LTBO32',
 'B002OO142S',
 'B00F93DRH6',
 'B0074Z8FMY',
 'B00ESLTUXA',
 'B00FFGQ86O',
 'B00AZH3P00',
 'B005LTPCP8',
 'B007JSEF16',
 'B008FFR63Q',
 'B004K3FXC8',
 'B000I207NA',
 'B00DAX9D9S',
 'B00F8JQKRA',
 'B006ZP6OYK',
 'B007TMBTRU',
 'B0016NMA80',
 'B005CV19Q6',
 'B00AFRTQDK',
 'B00CI9FT16',
 'B00EP48GS0',
 'B00A8E6NJK',
 'B008N6Y2BG',
 'B002UNMPMG',
 'B006ZPQ1QG',
 'B00C30H8OG',
 'B008DS3MXI',
 'B008O8EUUQ',
 'B0055B76M4',
 'B00D8BPLYS',
 'B006W2ONVW',
 'B002CJ5L88',
 'B004B30PO8',
 'B006BB1KYC',
 'B007ZTC3SG',
 'B0096IXG08',
 'B009XBJ224',
 'B004GTLAYG',
 'B00850TW

In [18]:
# Show which split the random query came from; the same value is used as the
# partition name in the search cell.
testset.split

'test'

In [13]:
# NOTE(review): `inference` and `dataset` are not defined or imported in any
# visible cell, and this cell's execution count (13) is out of order. It will
# probably raise NameError on Restart & Run All -- confirm whether it is a
# stale leftover or needs an import and a `dataset` definition.
nn_result, sorted_sims = inference(opt=opt, model=model, test_queries=test_queries, dataset=dataset)

100%|██████████| 15415/15415 [02:02<00:00, 126.13it/s]
