In [None]:
import os
import sys
import argparse

# Run from the project root, two levels above this notebook. The relative config path
# below and the experiment paths in later cells are resolved against it.
# The notebook's own directory is remembered in NOTEBOOK_DIR, so re-running this cell
# does not keep climbing the directory tree.
NOTEBOOK_DIR = globals().get("NOTEBOOK_DIR", os.getcwd())
PROJECT_ROOT = os.path.abspath(os.path.join(NOTEBOOK_DIR, "../../"))
os.chdir(PROJECT_ROOT)
print(NOTEBOOK_DIR, "->", os.getcwd())

# Emulate a command-line invocation so argparse-style code sees the notebook's config.
# NOTE(review): presumably `config_experiments` (imported in the next cell) also reads
# sys.argv, which is why it is overwritten here rather than passed to parse_args — confirm.
sys.argv = ["view", "--config", "config/multi_task_cross_stitch.yaml"]

parser = argparse.ArgumentParser()
parser.add_argument("--config", type=str, required=True, help="Path to the config file")
args = parser.parse_args()

print(args.config)

In [None]:
from config_experiments import config
from torchvision.transforms import transforms
from dataloader import VOC08Attr
import matplotlib.pyplot as plt
from model import (
    ObjectDetectionModel,
    AttributePredictionModel,
    CrossStitchNet,
    AttributePredictionHead,
    ObjectDetectionHead,
)
from utils import set_device
import torch
from bbox_transform import resize_bounding_boxes, apply_nms
import matplotlib.patches as patches
import torchvision
import torch.nn as nn

In [None]:
# Pick the compute device from the GPU id in the "global" section of the experiment config.
device = set_device(
    config["global"]["gpu_id"]
)

In [None]:
# Best checkpoints of the two single-task models. Their weights are later copied into the
# two branches of the cross-stitch network. Paths are relative to the working directory
# (the project root after the first cell).
path_best_model_obj = "../dl_project/experiments/object_detection/2024-07-28_19-23-43/models/best_model_epoch_94.pth"
path_best_model_attr = "../dl_project/experiments/attribute_prediction/2024-07-29_17-54-07/models/best_model_epoch_20.pth"

model_obj = ObjectDetectionModel().to(device)
model_attr = AttributePredictionModel().to(device)

# The checkpoints are passed straight to load_state_dict, i.e. they are plain state dicts.
# weights_only=True restricts unpickling to tensors/containers (no arbitrary code) and
# matches the default that recent torch versions warn they are moving to.
# NOTE(review): requires torch >= 1.13.
model_obj.load_state_dict(
    torch.load(path_best_model_obj, map_location=device, weights_only=True)
)
model_attr.load_state_dict(
    torch.load(path_best_model_attr, map_location=device, weights_only=True)
)

model_cross = CrossStitchNet().to(device)


def copy_weights(src_layers, dst_layers, layer_type=nn.Conv2d):
    """Copy parameters of ``layer_type`` layers from ``src_layers`` into ``dst_layers``.

    The i-th ``layer_type`` layer in ``dst_layers`` receives the weight (and bias, when the
    source has one) of the i-th ``layer_type`` layer in ``src_layers``. Layers of any other
    type (activations, pooling, ...) are ignored on both sides, so the two containers do not
    need identical layouts. Copying stops when either side runs out of matching layers.

    Args:
        src_layers: iterable of modules to read from (e.g. an ``nn.Sequential``).
        dst_layers: iterable of modules to write into.
        layer_type: module class whose parameters are copied; must expose ``weight`` and
            ``bias``. Defaults to ``nn.Conv2d``.

    Returns:
        The number of layers whose parameters were copied.

    Raises:
        ValueError: if a paired source/destination layer has a different weight shape, or
            the source has a bias that the destination lacks or cannot hold.
    """
    src_matches = [layer for layer in src_layers if isinstance(layer, layer_type)]
    dst_matches = [layer for layer in dst_layers if isinstance(layer, layer_type)]

    copied = 0
    # Initialisation writes must not be tracked by autograd.
    with torch.no_grad():
        for idx, (src, dst) in enumerate(zip(src_matches, dst_matches)):
            if src.weight.shape != dst.weight.shape:
                raise ValueError(
                    f"{layer_type.__name__} #{idx}: source weight shape "
                    f"{tuple(src.weight.shape)} does not match destination "
                    f"{tuple(dst.weight.shape)}"
                )
            # copy_ writes into the existing Parameter rather than rebinding .data, so the
            # destination keeps its device/dtype and any references to it stay valid.
            dst.weight.copy_(src.weight)
            if src.bias is not None:
                if dst.bias is None or dst.bias.shape != src.bias.shape:
                    raise ValueError(
                        f"{layer_type.__name__} #{idx}: source bias cannot be copied "
                        f"into destination"
                    )
                dst.bias.copy_(src.bias)
            copied += 1
    return copied


# Initialise the cross-stitch network from the two single-task models:
# branch "a" <- object-detection model, branch "b" <- attribute-prediction model.
# NOTE(review): every segment in models_a / models_b is filled starting from the first
# layer of the source `alex.features`. If the segments are consecutive stages of one
# backbone (rather than full copies), later stages receive the wrong layers — verify.
backbone_sources = (model_obj.alex.features, model_attr.alex.features)
for model_a, model_b in zip(
    model_cross.cross_stitch_net.models_a,
    model_cross.cross_stitch_net.models_b,
):
    for src_features, branch in zip(backbone_sources, (model_a, model_b)):
        copy_weights(src_features, branch)

# RoI modules and task heads correspond one-to-one, so their state dicts load directly.
# Branch a (object detection):
model_cross.roi_a.load_state_dict(
    model_obj.roi_module.state_dict()
)
model_cross.model_obj_detect.load_state_dict(
    model_obj.obj_detect_head.state_dict()
)
# Branch b (attribute prediction); kept last so its load result is the cell's output.
model_cross.roi_b.load_state_dict(
    model_attr.roi_module.state_dict()
)
model_cross.model_attribute.load_state_dict(
    model_attr.attribute_head.state_dict()
)

In [None]:
# Every parameter of the attribute-prediction model with its shape, to compare
# against the cross-stitch network's parameters.
for param_name, param in model_attr.named_parameters():
    print(f"{param_name} {param.shape}")

In [None]:
# Every parameter of the cross-stitch network with its shape, to check that the
# copied single-task weights line up with the expected layers.
for param_name, param in model_cross.named_parameters():
    print(f"{param_name} {param.shape}")