# MuseNet

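This notebook tests [MuseNet](https://github.com/wtyhub/MuseNet). It evaluates a pretrained checkpoint on drone→satellite retrieval, first on the mixed-weather test set and then separately for each weather condition.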


In [3]:
import os

import torch
import torch.nn as nn
import yaml

from utils.common import setup_seed
from utils.loader import environments
from utils.loader import init_dataset_train, init_dataset_test
from utils.metrics import metrics

from MuseNet.model import three_view_net
from MuseNet.utils import extract_feature, get_id, extract_feature

# Make torch.hub use the working directory as its cache root.
os.environ.update(TORCH_HOME='./')

class MuseNet_:
    """Evaluation harness for a pretrained MuseNet checkpoint.

    The checkpoint files and the training-time ``opts.yaml`` are read from
    ``<cwd>/model/MuseNet``. ``test`` builds the ``three_view_net`` described by
    that config and loads the weights. It then replaces the classifier layers
    with empty ``nn.Sequential`` modules so the network emits features, and
    reports retrieval metrics for a query/gallery view pair under the requested
    ``style``.
    """

    def __init__(self, seed=2024) -> None:
        """Seed the RNGs and locate the model directory.

        Args:
            seed: value passed to ``setup_seed``. The default of 2024 is the
                value this notebook was originally run with.
        """
        self.seed = seed
        setup_seed(self.seed)
        self.model_dir = os.path.join(os.getcwd(), 'model', 'MuseNet')
        # torch.hub cache root -> working directory.
        os.environ['TORCH_HOME'] = './'

    @staticmethod
    def _latest_checkpoint(filenames):
        """Return the ``.pth`` name that sorts last in natural (numeric-aware) order.

        A plain ``sorted`` compares digits character by character, so
        ``net_99.pth`` would rank after ``net_209.pth``. Here, digit runs are
        compared as integers, so the highest epoch number wins. Equal keys
        (e.g. ``net_9`` vs ``net_09``) fall back to the plain name so the
        result does not depend on directory-listing order.

        Args:
            filenames: iterable of file names, typically ``os.listdir`` output.

        Returns:
            The selected file name.

        Raises:
            FileNotFoundError: if ``filenames`` contains no ``.pth`` entry.
        """
        import re

        def natural_key(name):
            # re.split with a capturing group alternates text (even index) and
            # digit runs (odd index). Equal positions therefore always hold the
            # same type, and str/int are never compared.
            parts = re.split(r'([0-9]+)', name)
            return [int(tok) if i % 2 else tok for i, tok in enumerate(parts)]

        candidates = [f for f in filenames if f.endswith('.pth')]
        if not candidates:
            raise FileNotFoundError('no .pth checkpoint found')
        return max(candidates, key=lambda name: (natural_key(name), name))

    def _resolve_checkpoint(self, pth):
        """Map the ``pth`` argument of ``test`` to a checkpoint file name in ``model_dir``."""
        if pth is None:
            return self._latest_checkpoint(os.listdir(self.model_dir))
        # Accept both 'net_209' (the original calling form) and 'net_209.pth'.
        return pth if pth.endswith('.pth') else pth + '.pth'

    def _build_model(self, config, model_file):
        """Instantiate ``three_view_net`` from ``config`` and load the weights in ``model_file``."""
        # 701 was hard-coded in the original notebook. Prefer the value recorded in
        # opts.yaml when it exists, so the classifier width matches the checkpoint.
        nclasses = config.get('nclasses', 701)
        model = three_view_net(nclasses, config['droprate'], stride=config['stride'], pool=config['pool'],
                               share_weight=config['share'], norm=config['norm'], adain=config['adain'],
                               btnk=config['btnk'], conv_norm=config['conv_norm'],
                               VGG16=config['use_vgg'], Dense=config['use_dense'])
        # Load onto CPU first so restoring does not depend on the GPU index the
        # checkpoint was saved from. The model is moved to CUDA by the caller.
        state_dict = torch.load(os.path.join(self.model_dir, model_file), map_location='cpu')
        model.load_state_dict(state_dict)
        return model

    @staticmethod
    def _strip_classifiers(model, config):
        """Replace the last classifier layer(s) with an empty ``nn.Sequential`` (pass-through)."""
        if config['LPN']:
            # LPN models carry one head per block: classifier0 .. classifier{block-1}.
            heads = [getattr(model, 'classifier' + str(i)) for i in range(config['block'])]
        else:
            heads = [model.classifier]
        for head in heads:
            head.classifier = nn.Sequential()

    def test(self, pth=None, query='drone', gallery='satellite', multiple_scale=None, batchsize=128, style='mixed'):
        """Evaluate a checkpoint on query→gallery retrieval and print the metrics.

        Args:
            pth: checkpoint name inside ``model_dir``, with or without the
                ``.pth`` suffix. ``None`` selects the highest-numbered ``.pth`` file.
            query: view name. It selects the ``'query_<query>'`` split and is
                passed as ``view`` to ``extract_feature``.
            gallery: view name. It selects the ``'gallery_<gallery>'`` split and
                is passed as ``view`` to ``extract_feature``.
            multiple_scale: passed as ``ms`` to ``extract_feature``. ``None``
                means ``[1]``.
            batchsize: passed to ``init_dataset_test``.
            style: passed to ``init_dataset_test`` (e.g. ``'mixed'`` or one of
                ``utils.loader.environments``).

        Returns:
            The value returned by ``metrics``. Indices 0-4 are printed as
            Recall@1, Recall@5, Recall@10, Recall@top1 and Recall@AP.

        Note:
            The model is moved to CUDA, so a GPU is required.
        """
        if multiple_scale is None:
            multiple_scale = [1]  # fresh list per call; avoids a shared mutable default

        with open(os.path.join(self.model_dir, 'opts.yaml'), 'r') as stream:
            config = yaml.safe_load(stream)

        image_datasets, dataloaders, _ = init_dataset_test(batchsize=batchsize, style=style,
                                                           w=config['w'], h=config['h'])
        gallery_name = 'gallery_' + gallery
        query_name = 'query_' + query
        gallery_label = get_id(image_datasets[gallery_name].imgs)
        query_label = get_id(image_datasets[query_name].imgs)

        model_file = self._resolve_checkpoint(pth)
        print("load model: {}".format(model_file))
        model = self._build_model(config, model_file)
        self._strip_classifiers(model, config)
        model = model.cuda().eval()

        with torch.no_grad():
            query_feature = extract_feature(model, dataloaders[query_name], view=query, ms=multiple_scale,
                                            LPN=config['LPN'], block=config['block'])
            gallery_feature = extract_feature(model, dataloaders[gallery_name], view=gallery, ms=multiple_scale,
                                              LPN=config['LPN'], block=config['block'])

        m = metrics(query_feature, query_label, gallery_feature, gallery_label)
        labels = ('Recall@1', 'Recall@5', 'Recall@10', 'Recall@top1', 'Recall@AP')
        for i, label in enumerate(labels):
            print("{}: {:.2f}".format(label, m[i]))
        return m


In [4]:
# Build the evaluator (seeds the RNGs, reads ./model/MuseNet) and score the
# default checkpoint (pth=None) on drone->satellite retrieval, mixed-weather split.
a = MuseNet_()
a.test(query='drone', gallery='satellite', style='mixed')

load model: net_209.pth
load ibn params:-----------------


Extract drone feature: 100%|██████████| 296/296 [41:02<00:00,  8.32s/it]
Extract satellite feature: 100%|██████████| 8/8 [00:33<00:00,  4.13s/it]
Evaluate metrics: 100%|██████████| 37855/37855 [00:12<00:00, 2955.11it/s]

Recall@1: 61.36
Recall@5: 79.94
Recall@10: 85.83
Recall@top1: 86.52
Recall@AP: 65.64





In [6]:
# Re-evaluate the same checkpoint once per weather condition listed in
# utils.loader.environments (imported in the top import cell). Each run prints
# the condition name before its metrics.
for style in environments:
    print(style)
    a.test(style=style)

normal
load model: net_209.pth
load ibn params:-----------------


Extract drone feature: 100%|██████████| 296/296 [10:48<00:00,  2.19s/it]
Extract satellite feature: 100%|██████████| 8/8 [00:31<00:00,  3.90s/it]
Evaluate metrics: 100%|██████████| 37855/37855 [00:12<00:00, 3046.37it/s]


Recall@1: 73.24
Recall@5: 88.82
Recall@10: 92.57
Recall@top1: 93.00
Recall@AP: 76.73
dark
load model: net_209.pth
load ibn params:-----------------


Extract drone feature: 100%|██████████| 296/296 [11:55<00:00,  2.42s/it]
Extract satellite feature: 100%|██████████| 8/8 [00:36<00:00,  4.52s/it]
Evaluate metrics: 100%|██████████| 37855/37855 [00:12<00:00, 2999.47it/s]


Recall@1: 68.96
Recall@5: 86.08
Recall@10: 90.63
Recall@top1: 91.15
Recall@AP: 72.81
fog
load model: net_209.pth
load ibn params:-----------------


Extract drone feature: 100%|██████████| 296/296 [21:08<00:00,  4.28s/it]
Extract satellite feature: 100%|██████████| 8/8 [00:30<00:00,  3.86s/it]
Evaluate metrics: 100%|██████████| 37855/37855 [00:12<00:00, 3045.23it/s]


Recall@1: 67.07
Recall@5: 84.82
Recall@10: 89.80
Recall@top1: 90.40
Recall@AP: 71.10
rain
load model: net_209.pth
load ibn params:-----------------


Extract drone feature: 100%|██████████| 296/296 [1:30:15<00:00, 18.30s/it]
Extract satellite feature: 100%|██████████| 8/8 [00:35<00:00,  4.39s/it]
Evaluate metrics: 100%|██████████| 37855/37855 [00:12<00:00, 3001.11it/s]


Recall@1: 62.05
Recall@5: 80.59
Recall@10: 86.12
Recall@top1: 86.78
Recall@AP: 66.29
snow
load model: net_209.pth
load ibn params:-----------------


Extract drone feature: 100%|██████████| 296/296 [49:33<00:00, 10.05s/it]
Extract satellite feature: 100%|██████████| 8/8 [00:30<00:00,  3.84s/it]
Evaluate metrics: 100%|██████████| 37855/37855 [00:13<00:00, 2891.90it/s]


Recall@1: 57.98
Recall@5: 77.58
Recall@10: 83.88
Recall@top1: 84.58
Recall@AP: 62.49
fog_rain
load model: net_209.pth
load ibn params:-----------------


Extract drone feature: 100%|██████████| 296/296 [48:30<00:00,  9.83s/it]
Extract satellite feature: 100%|██████████| 8/8 [00:31<00:00,  3.88s/it]
Evaluate metrics: 100%|██████████| 37855/37855 [00:13<00:00, 2885.93it/s]


Recall@1: 58.87
Recall@5: 78.52
Recall@10: 84.54
Recall@top1: 85.33
Recall@AP: 63.38
fog_snow
load model: net_209.pth
load ibn params:-----------------


Extract drone feature: 100%|██████████| 296/296 [32:44<00:00,  6.64s/it]
Extract satellite feature: 100%|██████████| 8/8 [00:30<00:00,  3.86s/it]
Evaluate metrics: 100%|██████████| 37855/37855 [00:12<00:00, 2926.00it/s]


Recall@1: 48.50
Recall@5: 69.28
Recall@10: 76.67
Recall@top1: 77.61
Recall@AP: 53.35
rain_snow
load model: net_209.pth
load ibn params:-----------------


Extract drone feature: 100%|██████████| 296/296 [1:13:55<00:00, 14.99s/it]
Extract satellite feature: 100%|██████████| 8/8 [00:30<00:00,  3.85s/it]
Evaluate metrics: 100%|██████████| 37855/37855 [00:12<00:00, 2918.95it/s]


Recall@1: 59.56
Recall@5: 78.81
Recall@10: 84.64
Recall@top1: 85.38
Recall@AP: 63.97
light
load model: net_209.pth
load ibn params:-----------------


Extract drone feature: 100%|██████████| 296/296 [11:35<00:00,  2.35s/it]
Extract satellite feature: 100%|██████████| 8/8 [00:31<00:00,  3.89s/it]
Evaluate metrics: 100%|██████████| 37855/37855 [00:14<00:00, 2529.69it/s]


Recall@1: 57.17
Recall@5: 77.40
Recall@10: 84.07
Recall@top1: 84.82
Recall@AP: 61.83
wind
load model: net_209.pth
load ibn params:-----------------


Extract drone feature: 100%|██████████| 296/296 [15:07<00:00,  3.07s/it]
Extract satellite feature: 100%|██████████| 8/8 [00:30<00:00,  3.82s/it]
Evaluate metrics: 100%|██████████| 37855/37855 [00:13<00:00, 2862.33it/s]


Recall@1: 57.63
Recall@5: 77.74
Recall@10: 84.40
Recall@top1: 85.26
Recall@AP: 62.25
