In [1]:
import os
import torch
import argparse
import evaluate
import torchvision
import clip
from modules.GeMPooling import GeMPooling
from torch import nn
from pathlib import Path
from mapillary_sls.datasets.msls import MSLS
from mapillary_sls.datasets.msls_clip import MSLSCLIP
from mapillary_sls.datasets.generic_dataset import ImagesFromList, ImagesText
from mapillary_sls.utils.utils import configure_transform, clip_transform
from torch.utils.data import DataLoader

In [2]:
cities = "sf,cph"
root_dir = Path('/root/autodl-tmp/msls').absolute()

In [4]:
image_dim = (224, 224)
transform = clip_transform(image_dim)
# positive are defined within a radius of 25 m 阳性定义在25米的半径范围内
posDistThr = 25

# choose task to test on [im2im, seq2im, im2seq, seq2seq]
task = 'im2im'

# choose sequence length
seq_length = 1

# choose subtask to test on [all, s2w, w2s, o2n, n2o, d2n, n2d]
subtask = 'all'

batch_size = 4

val_dataset = MSLSCLIP(root_dir, cities = cities, transform = transform, mode = 'test',
                        task = task, seq_length = seq_length, subtask = subtask, posDistThr = posDistThr)
        
opt = {'batch_size': batch_size}
# get images
qLoader = DataLoader(ImagesText(val_dataset.qImages[val_dataset.qIdx], val_dataset.qText[val_dataset.qIdx], transform), **opt)
dbLoader = DataLoader(ImagesText(val_dataset.dbImages, val_dataset.dbText, transform), **opt)

=====> sf
=====> cph


In [5]:
def predict_clip_feature(net, Loader, device, im_or_seq='im'):
    """compute the features with net trained and get indices"""
    net.to(device)
    net.eval()
    result = []
    idx = []
    with torch.no_grad():
        if im_or_seq == 'im':

            for img_txt, y in Loader:
                x, text = img_txt
                print(x.shape)
                print("***")
                print(text.shape)
                x, text = x.to(device), text.reshape(-1, text.shape[-1]).to(device)
                
                y_hat = net(x, text)
                
                result.append(y_hat)
                idx.append(y)
                break
        elif im_or_seq == 'seq':
            # type(x_list)=list, and len(x_list=seq_length)
            for x_list, y in Loader:
                y_hat_list = torch.zeros((x_list[0].shape[0], net.back[1].out_features)).to(device)
                seq_length = len(x_list)
                for x in x_list:
                    # now the shape of x is(batch_size, 3, 224, 224)
                    x = x.to(device)
                    y_hat = net(x)
                    # compute the mean of all images in the seq
                    y_hat_list += y_hat
                y_hat = y_hat_list / seq_length
                result.append(y_hat)
                idx.append(y)
    result = torch.cat(result, dim=0)
    idx = torch.cat(idx, dim=0).reshape(-1, 1)
    return result, idx  

In [6]:
net, _= clip.load("ViT-B/16")


In [7]:
device = torch.device("cuda:0")
q_feature, q_idx = predict_clip_feature(net, qLoader, device, task.split('2')[0])

torch.Size([4, 3, 224, 224])
***
torch.Size([4, 1, 77])


In [8]:
print(len(val_dataset.dbImages))
print(len(val_dataset.dbText))

18916
18871


: 