In [1]:
import os
import torch
import torchvision
import numpy as np
from torch import nn
from pathlib import Path
from GeMPooling import GeMPooling 
from d2l import torch as d2l
from mapillary_sls.datasets.msls import MSLS
from mapillary_sls.datasets.generic_dataset import ImagesFromList
from mapillary_sls.utils.utils import configure_transform
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt

In [2]:
SAMPLE_CITIES = "zurich,sf"

root_dir = Path('/datasets/msls').absolute()

# get transform
meta = {'mean': [0.485, 0.456, 0.406], 'std': [0.229, 0.224, 0.225]}
transform = configure_transform(image_dim = (480, 640), meta = meta)

In [3]:
posDistThr = 5

# negatives are defined outside a radius of 25 m
negDistThr = 25

# number of negatives per triplet
nNeg = 5

# number of cached queries
cached_queries = 6

# number of cached negatives
cached_negatives = 100

# whether to use positive sampling
positive_sampling = True

# choose the cities to load
cities = SAMPLE_CITIES

# choose task to test on [im2im, seq2im, im2seq, seq2seq]
task = 'im2im'

# choose sequence length
seq_length = 1

train_dataset = MSLS(root_dir, cities = cities, transform = transform, mode = 'train', task = task, seq_length = seq_length,
                    negDistThr = negDistThr, posDistThr = posDistThr, nNeg = nNeg, positive_sampling = positive_sampling)

=====> zurich
=====> sf
#Sideways [179/3021]; #Night; [0/3021]
Forward and Day weighted with 1.0000
Sideways and Day weighted with 17.8771


  self.pIdx = np.asarray(self.pIdx)
  self.nonNegIdx = np.asarray(self.nonNegIdx)


In [4]:
# divides dataset into smaller cache sets
train_dataset.new_epoch()

# creates triplets on the smaller cache set
train_dataset.update_subcache()

# create data loader
opt = {'batch_size': 4, 'shuffle': True}
trainDataloader = DataLoader(train_dataset, **opt)


In [5]:
def get_net(new=False):
    """get the resnet50"""
    pretrained_net = torchvision.models.resnet50(weights=torchvision.models.ResNet50_Weights.IMAGENET1K_V2)
    
    # if the pretrained_net is not good, use the net
    if new == False:
        net = pretrained_net
    else:
        net_list = list(pretrained_net.children())
        net_list[-2] = GeMPooling(net_list[-1].in_features)
        net = nn.Sequential(*net_list)
    return net

net = get_net(new=True)


In [6]:
def train(net, train_iter, num_epochs, loss, lr, optimizer, device, task):
    """train funtion"""
    net.to(device)
    loss = loss
    optimizer=optimizer
    
    for i in range(num_epochs):
        net.train()
        metric = d2l.Accumulator(2)
        for j, (sequences, labels) in enumerate(train_iter):
            if i == 0:
                N = labels.shape[1]
                q_seq_length, db_seq_length = split_seq(sequences, N, task)
            # sequences.shape=(batch_size, len(q)+len(p)+len(neg), 3, 480, 640)
            X = sequences.reshape(-1, 3, 480, 640).to(device)
            y_hat = net(X)
            y_hat = y_hat.reshape(sequences.shape[0], sequences.shape[1], -1)
            
            optimizer.zero_grad()
            anchor = y_hat[:, : q_seq_length, :].mean(1)
            positive = y_hat[:, q_seq_length: q_seq_length + db_seq_length, :].mean(1)
            negtive = y_hat[:, q_seq_length + db_seq_length:, :].mean(1)
            
            l = loss(anchor, positive, negtive)
            
            l.backward()
            optimizer.step()
            
            metric.add(l * sequences.shape[0], labels.numel())
            
            train_loss = metric[0] / metric[1]
            
            if j % 10 == 0:
                print(f'epoch:{i + 1}, loss:{train_loss:.3f}')
                print(list(net.children())[-2].p)
        print(f'epoch{i + 1} is end *****************************************')
        print(f'now loss:{train_loss:.3f}')
        print(f'epoch{i + 2} is start *****************************************')
    
    print(f'the train is end *****************************************')
    print(f'in the end, loss:{train_loss:.3f}')
    torch.save(net.state_dict(), f"train_model_epoch{num_epochs}_lr{lr}sucessfully.params")
    
                
                    
            
def split_seq(sequences, N, task):
    """split the sequences before training according to the task"""
    if task == "im2im":
        q_seq_length, db_seq_length = 1, 1
    elif task == "seq2seq":
        seq_length = sequences.shape[1] // (N)
        q_seq_length, db_seq_length = seq_length, seq_length
    elif task == "im2seq":
        seq_length = (sequences.shape[1] - 1) // (N - 1)
        q_seq_length, db_seq_length = 1, seq_length
    elif task == "seq2im":
        seq_length = sequences.shape[1] - (N - 1)
        q_seq_length, db_seq_length = seq_length, 1
    
    return q_seq_length, db_seq_length

            

In [7]:


num_epochs, lr, device = 2, 0.0001, torch.device("cuda:4" if torch.cuda.is_available() else "cpu")
loss = nn.TripletMarginLoss(margin=train_dataset.margin, p=2)
optimizer = torch.optim.Adam(net.parameters(), lr=lr)
train(net, trainDataloader, num_epochs, loss, lr, optimizer, device, task)

epoch:1, loss:1.523
Parameter containing:
tensor([2.9999, 2.9999, 2.9999,  ..., 2.9999, 3.0001, 2.9999], device='cuda:4',
       requires_grad=True)
epoch:1, loss:1.169
Parameter containing:
tensor([3.0002, 2.9994, 2.9992,  ..., 2.9995, 2.9995, 2.9996], device='cuda:4',
       requires_grad=True)
epoch:1, loss:0.945
Parameter containing:
tensor([3.0002, 2.9995, 2.9988,  ..., 2.9994, 2.9990, 2.9997], device='cuda:4',
       requires_grad=True)
epoch:1, loss:0.857
Parameter containing:
tensor([3.0001, 2.9997, 2.9986,  ..., 2.9997, 2.9985, 2.9995], device='cuda:4',
       requires_grad=True)
epoch:1, loss:0.749
Parameter containing:
tensor([2.9998, 2.9996, 2.9984,  ..., 2.9999, 2.9983, 2.9994], device='cuda:4',
       requires_grad=True)
epoch:1, loss:0.692
Parameter containing:
tensor([2.9998, 2.9993, 2.9981,  ..., 3.0000, 2.9980, 2.9996], device='cuda:4',
       requires_grad=True)


KeyboardInterrupt: 

In [21]:
# positive are defined within a radius of 25 m 阳性定义在25米的半径范围内
posDistThr = 25

# choose task to test on [im2im, seq2im, im2seq, seq2seq]
task = 'seq2seq'

# choose sequence length
seq_length = 3

# choose subtask to test on [all, s2w, w2s, o2n, n2o, d2n, n2d]
subtask = 'all'

val_dataset = MSLS(root_dir, cities = SAMPLE_CITIES, transform = transform, mode = 'test',
                   task = task, seq_length = seq_length, subtask = subtask, posDistThr = posDistThr)

opt = {'batch_size': 3}

# get images
qLoader = DataLoader(ImagesFromList(val_dataset.qImages[val_dataset.qIdx], transform), **opt)
dbLoader = DataLoader(ImagesFromList(val_dataset.dbImages, transform), **opt)

# get positive index (we allow some more slack: default 25 m)
pIdx = val_dataset.pIdx


=====> zurich
=====> sf


In [27]:
for i, batch in enumerate(qLoader):
    x, y = batch
    # print(len(x))
    print(y)
    break
x[0].shape


3
tensor([0, 1, 2])


torch.Size([3, 3, 480, 640])

In [90]:
for i, batch in enumerate(dbLoader):
    x, y = batch
    print(len(x))
    print(x.shape, y)
    break

1
torch.Size([1, 3, 480, 640]) tensor([0])


In [66]:
y_list = []
for i, batch in enumerate(qLoader):
    y_list.append(batch[1])
len(y_list)

362

In [67]:
ydb_list = []
for i, batch in enumerate(dbLoader):
    ydb_list.append(batch[1])
len(ydb_list)

8006

In [28]:
val_dataset.qIdx.shape, val_dataset.pIdx.shape, val_dataset.qImages.shape, val_dataset.dbImages.shape

((361,), (0,), (361,), (7948,))

In [29]:
val_dataset.dbImages[batch[1]]

array(['/datasets/msls/train_val/zurich/database/images/VBhOO_DV9AMtrCBdEg39IA.jpg,/datasets/msls/train_val/zurich/database/images/_qLwDOh1rhtPc7tVsII-wA.jpg,/datasets/msls/train_val/zurich/database/images/-sSqPMpmsbwv9iAjgKb5sQ.jpg',
       '/datasets/msls/train_val/zurich/database/images/_qLwDOh1rhtPc7tVsII-wA.jpg,/datasets/msls/train_val/zurich/database/images/-sSqPMpmsbwv9iAjgKb5sQ.jpg,/datasets/msls/train_val/zurich/database/images/5TUQ193fbsXUHn2RmJyIUQ.jpg',
       '/datasets/msls/train_val/zurich/database/images/-sSqPMpmsbwv9iAjgKb5sQ.jpg,/datasets/msls/train_val/zurich/database/images/5TUQ193fbsXUHn2RmJyIUQ.jpg,/datasets/msls/train_val/zurich/database/images/P_7zNYGjYObsCIpaM7e3Kg.jpg'],
      dtype='<U224')

In [72]:
batch[0]

tensor([[[[ 1.5125,  1.5125,  1.5297,  ..., -2.0837, -1.9980, -2.1179],
          [ 1.5125,  1.5125,  1.5297,  ..., -1.9980, -1.9124, -2.0494],
          [ 1.4954,  1.4954,  1.5125,  ..., -2.1008, -2.1008, -2.0837],
          ...,
          [-1.3130, -1.2445, -1.1418,  ..., -1.2959, -1.2788, -1.2103],
          [-1.4329, -1.3815, -1.2788,  ..., -1.3130, -1.3302, -1.3302],
          [-1.3987, -1.4158, -1.3987,  ..., -1.5870, -1.6213, -1.6898]],

         [[ 1.6583,  1.6583,  1.6758,  ..., -2.0007, -1.9132, -2.0357],
          [ 1.6583,  1.6583,  1.6758,  ..., -1.9132, -1.8256, -1.9657],
          [ 1.6408,  1.6408,  1.6583,  ..., -2.0182, -2.0182, -2.0007],
          ...,
          [-1.2479, -1.1779, -1.0728,  ..., -1.1253, -1.1078, -1.0378],
          [-1.3704, -1.3179, -1.2129,  ..., -1.1604, -1.1779, -1.1779],
          [-1.3354, -1.3529, -1.3354,  ..., -1.4405, -1.4755, -1.5455]],

         [[ 2.1520,  2.1520,  2.1694,  ..., -1.7696, -1.6824, -1.8044],
          [ 2.1520,  2.1520,  

In [73]:
img = torchvision.io.read_image("D:\\MSLS_train_val\\train_val\\trondheim\\database\\images\\9Iu7ckykQxh2KCQlhmraUg.jpg")

In [84]:
from PIL import Image
img = Image.open("D:\\MSLS_train_val\\train_val\\trondheim\\database\\images\\9Iu7ckykQxh2KCQlhmraUg.jpg")
img = transform(img)
img

tensor([[[ 1.5125,  1.5125,  1.5297,  ..., -2.0837, -1.9980, -2.1179],
         [ 1.5125,  1.5125,  1.5297,  ..., -1.9980, -1.9124, -2.0494],
         [ 1.4954,  1.4954,  1.5125,  ..., -2.1008, -2.1008, -2.0837],
         ...,
         [-1.3130, -1.2445, -1.1418,  ..., -1.2959, -1.2788, -1.2103],
         [-1.4329, -1.3815, -1.2788,  ..., -1.3130, -1.3302, -1.3302],
         [-1.3987, -1.4158, -1.3987,  ..., -1.5870, -1.6213, -1.6898]],

        [[ 1.6583,  1.6583,  1.6758,  ..., -2.0007, -1.9132, -2.0357],
         [ 1.6583,  1.6583,  1.6758,  ..., -1.9132, -1.8256, -1.9657],
         [ 1.6408,  1.6408,  1.6583,  ..., -2.0182, -2.0182, -2.0007],
         ...,
         [-1.2479, -1.1779, -1.0728,  ..., -1.1253, -1.1078, -1.0378],
         [-1.3704, -1.3179, -1.2129,  ..., -1.1604, -1.1779, -1.1779],
         [-1.3354, -1.3529, -1.3354,  ..., -1.4405, -1.4755, -1.5455]],

        [[ 2.1520,  2.1520,  2.1694,  ..., -1.7696, -1.6824, -1.8044],
         [ 2.1520,  2.1520,  2.1694,  ..., -1

In [None]:
q_result.shape, q_idx.shape

In [31]:
def predict_feature(net, Loader, device):
    net.to(device)
    net.eval()
    result = []
    idx = []
    with torch.no_grad():
        for x, y in Loader:
            x = x.to(device)
            y_hat = net(x)
            result.append(y_hat)
            idx.append(y)  
        result = torch.cat(result, dim=0)
        idx = torch.cat(idx, dim=0)
    return result, idx
#q_result, q_idx = predict_feature(net, qLoader, val_dataset, device)

In [112]:
q_result.shape, q_idx.shape

(torch.Size([362, 1000]), torch.Size([362]))

In [None]:
# db_result, db_idx = predict_feature(net_trained, dbLoader, device)
def query_to_dbIdx(q_result, db_result):
    # save the indices of the first 5 most similar images to query in db
    img_indices_list = []
    for feature in q_result:
        diff = torch.abs(feature - db_result).sum(dim=1).reshape(-1)
        # get the index of the first 5 minimum in diff
        idx = torch.argsort(diff)[:5].reshape(1, -1)
        
        img_indices_list.append(idx)
    img_idices = torch.cat(img_indices_list, dim=0)
    
    return img_idices


def find_keys(indices, val_dataset, mode="query"):
    
    if mode == "query":
        address_all = val_dataset.qImages[indices]
        # save the keys of all queries, one key for one query
        keys = []
        for address_query in address_all:
            # address_query示例：D:\\MSLS_train_val\\train_val\\zurich\\database\\images\\_qLwDOh1rhtPc7tVsII-wA.jpg
            key = address_query.split("\\")[-1].split(".")[0]
            keys.append(key)
        
        assert len(keys) == indices.shape[0]
        return keys
        
    elif mode == "database":
        # get all the addresses of the first 5 most similar imgs for all the queries
        address_all = val_dataset.dbImages[indices]
        
        # save the keys of all the queries
        keys_all = []
        for address_query in address_all:
            # save the 5 keys of one query
            keys = []
            for address in address_query:
                # address示例：D:\\MSLS_train_val\\train_val\\zurich\\database\\images\\_qLwDOh1rhtPc7tVsII-wA.jpg
                key = address.split("\\")[-1].split(".")[0]
                keys.append(key)
            keys_all.append(keys)
            
    return keys_all
    
q1_idx = torch.arange(0,109, 1)
q1 = torch.randn(size=(109, 5))
db1 = torch.randn(size=(200, 5))
db_indices = query_to_dbIdx(q1, db1)
# find_keys according to the index(index, dataset, mode="query"/"database")
db_keys = find_keys(db_indices, val_dataset, mode="database")
q_keys = find_keys(q1_idx, val_dataset, mode="query")
db_keys, q_keys

In [141]:
for i in img_address:
    a = i
num=2
a[num].split("\\")[-1].split(".")[0], a[0]

('9sE-UeCra7KgJx6rNkq9xQ',
 'D:\\MSLS_train_val\\train_val\\zurich\\database\\images\\_qLwDOh1rhtPc7tVsII-wA.jpg')

In [176]:
for i in q_keys:
    a = i.split("\\")[-1].split(".")[0]
    print(a)
    

AoD5-ZB5YrgyClbR5qmG4g
gEhnDp9LMQq4SEDWgyVvQw
mguDPMkLcTvYctYS2Zvo1w


In [181]:
def save_to_csv(q_keys, db_keys):
    # create the csv saved keys
    os.makedirs(os.path.join('.', 'files'), exist_ok=True)
    data_file = os.path.join('.', 'files', 'my_prediction_im2im.csv')
    
    with open(data_file, 'w') as f:
        # one query key match to five database keys
        for line, q_key in zip(db_keys, q_keys):
            f.write(str(q_key) + ' ')
            for db_key in line:
                f.write(str(db_key) + ' ')
            f.write('\n')

save_to_csv(q_keys, db_keys)


                

KeyboardInterrupt: 