In [1]:
import os
import torch
import torchvision
import numpy as np
from GeMPooling import GeMPooling
from torch import nn
from pathlib import Path
from d2l import torch as d2l
from mapillary_sls.datasets.msls import MSLS
from mapillary_sls.datasets.generic_dataset import ImagesFromList
from mapillary_sls.utils.utils import configure_transform
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt

In [2]:
SAMPLE_CITIES = "zurich,sf"

root_dir = Path('/datasets/msls').absolute()

# get transform
meta = {'mean': [0.485, 0.456, 0.406], 'std': [0.229, 0.224, 0.225]}
transform = configure_transform(image_dim = (480, 640), meta = meta)

In [22]:
def get_net(new=False):
    """get the resnet50"""
    pretrained_net = torchvision.models.resnet50(weights=torchvision.models.ResNet50_Weights.IMAGENET1K_V2)
    
    # if the pretrained_net is not good, use the net
    if new == False:
        net = pretrained_net
    else:
        """net = nn.Sequential(*list(pretrained_net.children())[:-2])
        net.add_module("gempooling", GeMPooling(pretrained_net.fc.in_features, output_size=(1, 1)))
        net.add_module("fc", pretrained_net.fc)"""
        net_list = list(pretrained_net.children())
        # create new net
        net = nn.Sequential()
        net.base = nn.Sequential(*net_list[:-2])
        # use an adaptiveavg-pooling in the GeMpooling,kernel_size=(1, 1)
        gem = GeMPooling(net_list[-1].in_features, output_size=(1, 1))
        net.back = nn.Sequential(gem, pretrained_net.fc)

    return net

net = get_net(new=True)
# print(net)
# num_epochs, lr, device = 20, 0.01, torch.device("cuda:6" if torch.cuda.is_available() else "cpu")


In [18]:
print(next(net.fc.parameters()).requires_grad)

True


In [24]:
print(list(net.children())[-1])

Sequential(
  (0): GeMPooling(
    (avg_pooling): AdaptiveAvgPool2d(output_size=(1, 1))
  )
  (1): Linear(in_features=2048, out_features=1000, bias=True)
)


In [25]:
x = torch.randn(size=(1, 3, 480, 640))
y = net(x)
y.shape

torch.Size([1, 1000])

In [None]:
# positive are defined within a radius of 25 m 阳性定义在25米的半径范围内
posDistThr = 25

# choose task to test on [im2im, seq2im, im2seq, seq2seq]
task = 'seq2seq'

# choose sequence length
seq_length = 3

# choose subtask to test on [all, s2w, w2s, o2n, n2o, d2n, n2d]
subtask = 'all'

val_dataset = MSLS(root_dir, cities = SAMPLE_CITIES, transform = transform, mode = 'test',
                   task = task, seq_length = seq_length, subtask = subtask, posDistThr = posDistThr)

opt = {'batch_size': 5}

# get images
qLoader = DataLoader(ImagesFromList(val_dataset.qImages[val_dataset.qIdx], transform), **opt)
dbLoader = DataLoader(ImagesFromList(val_dataset.dbImages, transform), **opt)

# get positive index (we allow some more slack: default 25 m)
pIdx = val_dataset.pIdx

=====> zurich
=====> sf


In [None]:
for i, batch in enumerate(dbLoader):
    x, y = batch
    print(len(x))
    print(y)
    break
x[1].shape

5
tensor([0, 1, 2, 3, 4])


torch.Size([3, 480, 640])

In [None]:
val_dataset.qIdx.shape, val_dataset.pIdx.shape, val_dataset.qImages.shape, val_dataset.dbImages.shape

((369,), (0,), (369,), (9306,))

In [None]:
for i, batch in enumerate(qLoader):
    x, y = batch
    print(len(x))
    print(y)
    break
type(x)



5
tensor([0, 1, 2, 3, 4])


torch.Tensor

In [None]:
def predict_feature(net, Loader, device, im_or_seq='im'):
    """create the features and indices"""
    net.to(device)
    net.eval()
    result = []
    idx = []
    i=0
    with torch.no_grad():
        if im_or_seq == 'im':
            for i, (x, y) in enumerate(Loader):
                x = x.to(device)
                y_hat = net(x)
                print(y_hat.shape)
                result.append(y_hat)
                idx.append(y)
                if i == 4:
                    break
                i += 1
        elif im_or_seq == 'seq':
            # type(x_list)=list, and len(x_list=seq_length)
            for x_list, y in Loader:
                y_hat_list = torch.zeros((x_list[0].shape[0], net.fc.out_features)).to(device)
                seq_length = len(x_list)
                for x in x_list:
                    # now the shape of x is(batch_size, 3, 224, 224)
                    x = x.to(device)
                    y_hat = net(x)
                    # compute the mean of all images in the seq
                    y_hat_list += y_hat
                y_hat = y_hat_list / seq_length
                result.append(y_hat)
                idx.append(y)  
                
                if i == 4:
                    break
                i += 1
        result = torch.cat(result, dim=0)
        idx = torch.cat(idx, dim=0).reshape(-1, 1)
    return result, idx

q_result, q_idx = predict_feature(net, qLoader, device, task.split("2")[0])
db_result, _ = predict_feature(net, dbLoader, device, task.split("2")[1])

In [None]:
q_result.shape, q_idx.shape, db_result.shape

(torch.Size([25, 1000]), torch.Size([25, 1]), torch.Size([25, 1000]))

In [None]:
# db_result, db_idx = predict_feature(net_trained, dbLoader, device)
def query_to_dbIdx(qfeature, dbfeature):
    # for L2 norm:
    qfeature_normed = qfeature / torch.norm(qfeature, dim=1, keepdim=True)
    dbfeature_normed = dbfeature / torch.norm(dbfeature, dim=1, keepdim=True)

    # cos<a,b> of two vecter: a·b / (|a|*|b|)  == similarity of two vector
    sim = torch.matmul(qfeature_normed, dbfeature_normed.transpose(0, 1))
    # get the index of the first 20 maximum in sim

    idx = torch.argsort(sim, dim=1, descending=True)[:, :20]
    
    return idx


def find_keys(indices, val_dataset):
    address_all = val_dataset.dbImages[indices]
    # save the keys of all the queries
    keys_all = []
    for address_query in address_all:
        # save the 5 keys of one query
        keys = []
        for address in address_query:
            # address：array(['/datasets/msls/train_val/zurich/query/images/EOC7T_l63Z4LLTSSY6zkkg.jpg,
            # /datasets/msls/train_val/zurich/query/images/AoD5-ZB5YrgyClbR5qmG4g.jpg,
            # /datasets/msls/train_val/zurich/query/images/8VwFgahokEl-0uG-0Yoshg.jpg'], dtype='<U215')
            # key = address.split("/")[-1].split(".")[0]
            address_seq = address.split(",")
            key_seq = [key_1.split("/")[-1].split(".")[0] for key_1 in address_seq]
            # need a ',' between keys in the same seq
            # we add it when save to the .csv
            #if len(key_seq) > 1:
            #    key_seq = [key + ',' if i < len(key_seq) - 1 else key for i, key in enumerate(key_seq)]
                      
            keys.append(key_seq)
        keys_all.append(keys)
            
    return keys_all
    
# q1_idx = torch.arange(0,109, 1)
# q1 = torch.randn(size=(109, 5))
# db1 = torch.randn(size=(200, 5))
q_idx.cpu()
db_idx = query_to_dbIdx(q_result, db_result).cpu()
db_idx.shape

torch.Size([25, 20])

In [None]:
q_keys = find_keys(q_idx, val_dataset)
q_keys

[[['VBhOO_DV9AMtrCBdEg39IA,',
   '_qLwDOh1rhtPc7tVsII-wA,',
   '-sSqPMpmsbwv9iAjgKb5sQ']],
 [['_qLwDOh1rhtPc7tVsII-wA,',
   '-sSqPMpmsbwv9iAjgKb5sQ,',
   '5TUQ193fbsXUHn2RmJyIUQ']],
 [['-sSqPMpmsbwv9iAjgKb5sQ,',
   '5TUQ193fbsXUHn2RmJyIUQ,',
   'P_7zNYGjYObsCIpaM7e3Kg']],
 [['5TUQ193fbsXUHn2RmJyIUQ,',
   'P_7zNYGjYObsCIpaM7e3Kg,',
   '5G1h6AKQ4boqkcFazUR_Bw']],
 [['P_7zNYGjYObsCIpaM7e3Kg,',
   '5G1h6AKQ4boqkcFazUR_Bw,',
   '9sE-UeCra7KgJx6rNkq9xQ']],
 [['5G1h6AKQ4boqkcFazUR_Bw,',
   '9sE-UeCra7KgJx6rNkq9xQ,',
   'ynz1bXUulgXoOClYciv6Og']],
 [['9sE-UeCra7KgJx6rNkq9xQ,',
   'ynz1bXUulgXoOClYciv6Og,',
   'vnhuNjWARNoORJ1IhnBgLQ']],
 [['ynz1bXUulgXoOClYciv6Og,',
   'vnhuNjWARNoORJ1IhnBgLQ,',
   'S7F1HKQ7S6iTmoxTmuuY3g']],
 [['vnhuNjWARNoORJ1IhnBgLQ,',
   'S7F1HKQ7S6iTmoxTmuuY3g,',
   'nCmGP3LI96HAP1VfYlE40g']],
 [['S7F1HKQ7S6iTmoxTmuuY3g,',
   'nCmGP3LI96HAP1VfYlE40g,',
   'I-CSP3R29tWtcVwhEAKuKw']],
 [['nCmGP3LI96HAP1VfYlE40g,',
   'I-CSP3R29tWtcVwhEAKuKw,',
   'rboEQxSjLSidG8DaoxCeGQ']],

: 

In [None]:
address_all = val_dataset.qImages[q_idx]
key=address_all[0][0]
key

'/datasets/msls/train_val/zurich/query/images/AoD5-ZB5YrgyClbR5qmG4g.jpg'

In [None]:
address_all = val_dataset.qImages[db_idx]

In [None]:
key_3= key.split(",")
key_3

['/datasets/msls/train_val/zurich/query/images/AoD5-ZB5YrgyClbR5qmG4g.jpg']

In [None]:
keys = [key_1.split("/")[-1].split(".")[0] for key_1 in key_3]
keys

['AoD5-ZB5YrgyClbR5qmG4g']

In [None]:
db_keys = find_keys(db_idx, val_dataset)
q_keys = find_keys(q_idx, val_dataset)


In [None]:
def write_key_seq(key_col, f):
    """write one key seq for a image/seq to csv"""
    for i, q_key in enumerate(key_col):
        if i > 0:
            f.write(',' + str(q_key))
        else:
            f.write(str(q_key))


def save_to_csv(q_keys, db_keys, path):
    """save the keys in csv, seq_length q_keys match to 20*seq_length db_keys"""
    # create the csv saved keys
    os.makedirs(os.path.join('.', 'files'), exist_ok=True)
    data_file = path
    
    with open(data_file, 'w') as f:
        # one query key match to 5 database keys
        # db_keys size(query_num, 20, seq_length)
        # q_keys size(query_num, 1, seq_length)
        for db_one, q_one in zip(db_keys, q_keys):
            # 20 db col after 1 q_col
            for q_col in q_one:
                # len(q_col) = seq_length
                write_key_seq(q_col, f)
                f.write(' ')
                for db_col in db_one:
                    write_key_seq(db_col, f)
                    f.write(' ')
            f.write("\n")
path = Path('./results_seq.csv')
save_to_csv(q_keys, db_keys, path)