In [None]:
import sys
import argparse
import os

# Fake a command-line invocation: the notebook has no real CLI, so `sys.argv`
# is overwritten before parsing. NOTE(review): presumably the project modules
# imported in the next cell (e.g. `config_experiments`) also read `sys.argv`
# at import time, which is why the global is patched rather than passing a
# list to `parse_args` — TODO confirm.
sys.argv = [
    "view",
    "--config",
    "../../config/multi_task_cross_stitch.yaml",
]

parser = argparse.ArgumentParser()
parser.add_argument(
    "--config",
    type=str,
    required=True,
    help="Path to the config file",
)
# Same as `parse_args()` with no argument: parse everything after the program name.
args = parser.parse_args(sys.argv[1:])

print(args.config)

In [None]:
from dataloader import VOC08Attr
from torchvision.transforms import transforms
from config_experiments import config
from model import ObjectDetectionModel, AttributePredictionModel, CrossStitchNet
import torch
import numpy as np
import torch.onnx
import netron
from utils import set_device

In [None]:
# Work from the project root (two levels above the notebook directory) so the
# relative experiment paths used further down resolve.
# The directory the notebook started in is remembered on the first run, so
# re-running this cell (or running it after a partial Run All) always lands on
# the same root instead of climbing two more levels each time.
if "NOTEBOOK_DIR" not in globals():
    NOTEBOOK_DIR = os.getcwd()
os.chdir(os.path.join(NOTEBOOK_DIR, os.pardir, os.pardir))
os.getcwd()

In [None]:
# Preprocessing pipeline: resize, convert to a tensor, then normalize, with all
# parameters read from the `transform` section of the experiment config.
transform_cfg = config["transform"]
transform_train = transforms.Compose([
    transforms.Resize(size=transform_cfg["resize_values"]),
    transforms.ToTensor(),
    transforms.Normalize(mean=transform_cfg["mean"], std=transform_cfg["std"]),
])

# Device is hard-coded to the CPU here (the imported `set_device` helper is not called).
device = "cpu"

In [None]:
# FOR VGG16
#
# Build a CrossStitchNet from the backbones of freshly constructed single-task
# models, then copy the weights of the best single-task checkpoints (paths from
# config["model"]) into its two branches:
#   branch "a" (models_a / branch_a / model_obj_detect) <- object detection
#   branch "b" (models_b / branch_b / model_attribute)  <- attribute prediction
# NOTE(review): this cell and the AlexNet cell below are alternatives; presumably
# only the one matching the backbone selected in the config should be run —
# TODO confirm.

path_best_model_obj = config["model"]["model_obj"]
path_best_model_attr = config["model"]["model_attr"]
model_obj = ObjectDetectionModel().to(device)
model_attr = AttributePredictionModel().to(device)
model_cross = CrossStitchNet(model_obj.backbone, model_attr.backbone)

# Print parameter names/shapes of the target network and of both checkpoints,
# to check them against the mappings below.
for i, (name, params) in enumerate(model_cross.named_parameters()):
    print(i, name, params.shape)

best_model_obj = torch.load(path_best_model_obj, map_location=device)

for i, (name, params) in enumerate(best_model_obj.items()):
    print(i, name, params.shape)

best_model_attr = torch.load(path_best_model_attr, map_location=device)

for i, (name, params) in enumerate(best_model_attr.items()):
    print(i, name, params.shape)

# Keys: parameter names in model_cross. Values: names in the single-task state dicts.
# NOTE(review): the single-task backbone is named `alex` even in this VGG16 cell;
# presumably that attribute name is shared across backbones — TODO confirm.
mapping_obj = {
    "cross_stitch_net.models_a.0.0.weight": "alex.features.0.weight",
    "cross_stitch_net.models_a.0.0.bias": "alex.features.0.bias",
    "cross_stitch_net.models_a.0.2.weight": "alex.features.2.weight",
    "cross_stitch_net.models_a.0.2.bias": "alex.features.2.bias",
    "cross_stitch_net.models_a.1.0.weight": "alex.features.5.weight",
    "cross_stitch_net.models_a.1.0.bias": "alex.features.5.bias",
    "cross_stitch_net.models_a.1.2.weight": "alex.features.7.weight",
    "cross_stitch_net.models_a.1.2.bias": "alex.features.7.bias",
    "cross_stitch_net.models_a.2.0.weight": "alex.features.10.weight",
    "cross_stitch_net.models_a.2.0.bias": "alex.features.10.bias",
    "cross_stitch_net.models_a.2.2.weight": "alex.features.12.weight",
    "cross_stitch_net.models_a.2.2.bias": "alex.features.12.bias",
    "cross_stitch_net.models_a.2.4.weight": "alex.features.14.weight",
    "cross_stitch_net.models_a.2.4.bias": "alex.features.14.bias",
    "cross_stitch_net.models_a.3.0.weight": "alex.features.17.weight",
    "cross_stitch_net.models_a.3.0.bias": "alex.features.17.bias",
    "cross_stitch_net.models_a.3.2.weight": "alex.features.19.weight",
    "cross_stitch_net.models_a.3.2.bias": "alex.features.19.bias",
    "cross_stitch_net.models_a.3.4.weight": "alex.features.21.weight",
    "cross_stitch_net.models_a.3.4.bias": "alex.features.21.bias",
    "cross_stitch_net.models_a.4.0.weight": "alex.features.24.weight",
    "cross_stitch_net.models_a.4.0.bias": "alex.features.24.bias",
    "cross_stitch_net.models_a.4.2.weight": "alex.features.26.weight",
    "cross_stitch_net.models_a.4.2.bias": "alex.features.26.bias",
    "cross_stitch_net.models_a.4.4.weight": "alex.features.28.weight",
    "cross_stitch_net.models_a.4.4.bias": "alex.features.28.bias",
    "cross_stitch_classifier.branch_a.1.weight": "roi_module.classifier.1.weight",
    "cross_stitch_classifier.branch_a.1.bias": "roi_module.classifier.1.bias",
    "cross_stitch_classifier.branch_a.4.weight": "roi_module.classifier.4.weight",
    "cross_stitch_classifier.branch_a.4.bias": "roi_module.classifier.4.bias",
    "model_obj_detect.cls_score.weight": "obj_detect_head.cls_score.weight",
    "model_obj_detect.cls_score.bias": "obj_detect_head.cls_score.bias",
    "model_obj_detect.bbox.weight": "obj_detect_head.bbox.weight",
    "model_obj_detect.bbox.bias": "obj_detect_head.bbox.bias",
}

mapping_attr = {
    "cross_stitch_net.models_b.0.0.weight": "alex.features.0.weight",
    "cross_stitch_net.models_b.0.0.bias": "alex.features.0.bias",
    "cross_stitch_net.models_b.0.2.weight": "alex.features.2.weight",
    "cross_stitch_net.models_b.0.2.bias": "alex.features.2.bias",
    "cross_stitch_net.models_b.1.0.weight": "alex.features.5.weight",
    "cross_stitch_net.models_b.1.0.bias": "alex.features.5.bias",
    "cross_stitch_net.models_b.1.2.weight": "alex.features.7.weight",
    "cross_stitch_net.models_b.1.2.bias": "alex.features.7.bias",
    "cross_stitch_net.models_b.2.0.weight": "alex.features.10.weight",
    "cross_stitch_net.models_b.2.0.bias": "alex.features.10.bias",
    "cross_stitch_net.models_b.2.2.weight": "alex.features.12.weight",
    "cross_stitch_net.models_b.2.2.bias": "alex.features.12.bias",
    "cross_stitch_net.models_b.2.4.weight": "alex.features.14.weight",
    "cross_stitch_net.models_b.2.4.bias": "alex.features.14.bias",
    "cross_stitch_net.models_b.3.0.weight": "alex.features.17.weight",
    "cross_stitch_net.models_b.3.0.bias": "alex.features.17.bias",
    "cross_stitch_net.models_b.3.2.weight": "alex.features.19.weight",
    "cross_stitch_net.models_b.3.2.bias": "alex.features.19.bias",
    "cross_stitch_net.models_b.3.4.weight": "alex.features.21.weight",
    "cross_stitch_net.models_b.3.4.bias": "alex.features.21.bias",
    "cross_stitch_net.models_b.4.0.weight": "alex.features.24.weight",
    "cross_stitch_net.models_b.4.0.bias": "alex.features.24.bias",
    "cross_stitch_net.models_b.4.2.weight": "alex.features.26.weight",
    "cross_stitch_net.models_b.4.2.bias": "alex.features.26.bias",
    "cross_stitch_net.models_b.4.4.weight": "alex.features.28.weight",
    "cross_stitch_net.models_b.4.4.bias": "alex.features.28.bias",
    "cross_stitch_classifier.branch_b.1.weight": "roi_module.classifier.1.weight",
    "cross_stitch_classifier.branch_b.1.bias": "roi_module.classifier.1.bias",
    "cross_stitch_classifier.branch_b.4.weight": "roi_module.classifier.4.weight",
    "cross_stitch_classifier.branch_b.4.bias": "roi_module.classifier.4.bias",
    "model_attribute.attr_score.weight": "attribute_head.attr_score.weight",
    "model_attribute.attr_score.bias": "attribute_head.attr_score.bias",
}

# Copy the weights: object detection first, then attribute classification.
# Every mapping entry is validated BEFORE anything is copied. Without this check,
# a target name that is not a parameter of model_cross would be skipped silently,
# leaving that layer with its constructor-time weights. A missing source key or a
# shape mismatch would otherwise abort halfway through, after part of the network
# had already been overwritten. Shapes are compared explicitly because `copy_`
# would also accept a wrong-shaped but broadcastable source.
cross_params = dict(model_cross.named_parameters())
copy_plan = [(mapping_obj, best_model_obj), (mapping_attr, best_model_attr)]

problems = []
for mapping, source_state in copy_plan:
    for target_name, source_name in mapping.items():
        if target_name not in cross_params:
            problems.append(f"{target_name}: not a parameter of model_cross")
        elif source_name not in source_state:
            problems.append(f"{source_name}: not found in the source checkpoint")
        elif cross_params[target_name].shape != source_state[source_name].shape:
            problems.append(
                f"{target_name} {tuple(cross_params[target_name].shape)} <- "
                f"{source_name} {tuple(source_state[source_name].shape)}: shape mismatch"
            )
if problems:
    raise ValueError("Cannot transfer single-task weights:\n" + "\n".join(problems))

with torch.no_grad():
    for mapping, source_state in copy_plan:
        for target_name, source_name in mapping.items():
            cross_params[target_name].copy_(source_state[source_name])

In [None]:
# FOR ALEXNET
#
# Same transfer as the VGG16 cell, but from AlexNet single-task checkpoints whose
# paths are hard-coded to specific experiment runs rather than read from the config:
#   branch "a" (models_a / branch_a / model_obj_detect) <- object detection
#   branch "b" (models_b / branch_b / model_attribute)  <- attribute prediction
# NOTE(review): this cell and the VGG16 cell above are alternatives; presumably
# only the one matching the backbone selected in the config should be run —
# TODO confirm.

path_best_model_obj = "experiments/object_detection/2024-07-24_09-59-58/models/best_model_epoch_93.pth"
path_best_model_attr = "experiments/attribute_prediction/2024-08-02_11-55-19/models/best_model_epoch_60.pth"
model_obj = ObjectDetectionModel().to(device)
model_attr = AttributePredictionModel().to(device)
model_cross = CrossStitchNet(model_obj.backbone, model_attr.backbone)

# Print parameter names/shapes of the target network and of both checkpoints,
# to check them against the mappings below.
for i, (name, params) in enumerate(model_cross.named_parameters()):
    print(i, name, params.shape)

# Load the single-task checkpoints once. The original cell loaded each one a
# second time for no reason.
best_model_obj = torch.load(path_best_model_obj, map_location=device)

for i, (name, params) in enumerate(best_model_obj.items()):
    print(i, name, params.shape)

best_model_attr = torch.load(path_best_model_attr, map_location=device)

for i, (name, params) in enumerate(best_model_attr.items()):
    print(i, name, params.shape)

# Map from parameter names in the multi-task model (keys) to names in the
# single-task state dicts (values).
mapping_obj = {
    "cross_stitch_net.models_a.0.0.weight": "alex.features.0.weight",
    "cross_stitch_net.models_a.0.0.bias": "alex.features.0.bias",
    "cross_stitch_net.models_a.1.0.weight": "alex.features.3.weight",
    "cross_stitch_net.models_a.1.0.bias": "alex.features.3.bias",
    "cross_stitch_net.models_a.2.0.weight": "alex.features.6.weight",
    "cross_stitch_net.models_a.2.0.bias": "alex.features.6.bias",
    "cross_stitch_net.models_a.2.2.weight": "alex.features.8.weight",
    "cross_stitch_net.models_a.2.2.bias": "alex.features.8.bias",
    "cross_stitch_net.models_a.2.4.weight": "alex.features.10.weight",
    "cross_stitch_net.models_a.2.4.bias": "alex.features.10.bias",
    "cross_stitch_classifier.branch_a.1.weight": "roi_module.classifier.1.weight",
    "cross_stitch_classifier.branch_a.1.bias": "roi_module.classifier.1.bias",
    "cross_stitch_classifier.branch_a.4.weight": "roi_module.classifier.4.weight",
    "cross_stitch_classifier.branch_a.4.bias": "roi_module.classifier.4.bias",
    "model_obj_detect.cls_score.weight": "obj_detect_head.cls_score.weight",
    "model_obj_detect.cls_score.bias": "obj_detect_head.cls_score.bias",
    "model_obj_detect.bbox.weight": "obj_detect_head.bbox.weight",
    "model_obj_detect.bbox.bias": "obj_detect_head.bbox.bias",
}

mapping_attr = {
    "cross_stitch_net.models_b.0.0.weight": "alex.features.0.weight",
    "cross_stitch_net.models_b.0.0.bias": "alex.features.0.bias",
    "cross_stitch_net.models_b.1.0.weight": "alex.features.3.weight",
    "cross_stitch_net.models_b.1.0.bias": "alex.features.3.bias",
    "cross_stitch_net.models_b.2.0.weight": "alex.features.6.weight",
    "cross_stitch_net.models_b.2.0.bias": "alex.features.6.bias",
    "cross_stitch_net.models_b.2.2.weight": "alex.features.8.weight",
    "cross_stitch_net.models_b.2.2.bias": "alex.features.8.bias",
    "cross_stitch_net.models_b.2.4.weight": "alex.features.10.weight",
    "cross_stitch_net.models_b.2.4.bias": "alex.features.10.bias",
    "cross_stitch_classifier.branch_b.1.weight": "roi_module.classifier.1.weight",
    "cross_stitch_classifier.branch_b.1.bias": "roi_module.classifier.1.bias",
    "cross_stitch_classifier.branch_b.4.weight": "roi_module.classifier.4.weight",
    "cross_stitch_classifier.branch_b.4.bias": "roi_module.classifier.4.bias",
    "model_attribute.attr_score.weight": "attribute_head.attr_score.weight",
    "model_attribute.attr_score.bias": "attribute_head.attr_score.bias",
}

# Copy the weights: object detection first, then attribute classification.
# Every mapping entry is validated BEFORE anything is copied. Without this check,
# a target name that is not a parameter of model_cross would be skipped silently,
# leaving that layer with its constructor-time weights. A missing source key or a
# shape mismatch would otherwise abort halfway through, after part of the network
# had already been overwritten. Shapes are compared explicitly because `copy_`
# would also accept a wrong-shaped but broadcastable source.
cross_params = dict(model_cross.named_parameters())
copy_plan = [(mapping_obj, best_model_obj), (mapping_attr, best_model_attr)]

problems = []
for mapping, source_state in copy_plan:
    for target_name, source_name in mapping.items():
        if target_name not in cross_params:
            problems.append(f"{target_name}: not a parameter of model_cross")
        elif source_name not in source_state:
            problems.append(f"{source_name}: not found in the source checkpoint")
        elif cross_params[target_name].shape != source_state[source_name].shape:
            problems.append(
                f"{target_name} {tuple(cross_params[target_name].shape)} <- "
                f"{source_name} {tuple(source_state[source_name].shape)}: shape mismatch"
            )
if problems:
    raise ValueError("Cannot transfer single-task weights:\n" + "\n".join(problems))

with torch.no_grad():
    for mapping, source_state in copy_plan:
        for target_name, source_name in mapping.items():
            cross_params[target_name].copy_(source_state[source_name])