In [None]:
!pip install torch-lucent

In [None]:
from torch._C import device
import torch
from torch import nn
from adet.config import get_cfg
from modules.solov2 import SOLOv2
from modules.reconstructor import Reconstructor
import matplotlib.pyplot as plt
import argparse
import os
import warnings
from detectron2.utils.logger import setup_logger
import glob
import time
import torchvision.transforms as transforms
from detectron2.checkpoint import DetectionCheckpointer
from detectron2.data.detection_utils import read_image
from run import Editor
from utils import visualize_kernels
from lucent.modelzoo.util import get_model_layers
from lucent.optvis import render, param, transform, objectives

In [None]:
def setup_cfg(args):
    """Build the model config from a YAML file plus ``KEY VALUE`` overrides.

    Args:
        args: Parsed arguments exposing ``config_file`` (path to the YAML
            config) and ``opts`` (flat list alternating config keys and values).

    Returns:
        The merged config object, frozen via ``freeze()`` before returning.
    """
    config = get_cfg()
    # The file is merged first; the override list is merged afterwards.
    config.merge_from_file(args.config_file)
    config.merge_from_list(args.opts)
    config.freeze()
    return config

## Before running, update the argument defaults below (config file, model weights, editor checkpoint) to match your local paths!
def get_parser():
    """Create the argument parser for the SOLOv2 editor notebook.

    Options:
        --config-file: YAML config path (default ``../configs/R50_3x.yaml``).
        --input: one or more test image paths (default ``['../inputs/bg.jpg']``).
        --opts: everything after this flag is kept as ``KEY VALUE`` config
            overrides; defaults to pointing ``MODEL.WEIGHTS`` at the SOLOv2
            checkpoint.
        --PATH: path to the saved editor state dict.

    Returns:
        A configured ``argparse.ArgumentParser``.
    """
    parser = argparse.ArgumentParser(description="SOLOv2 Editor")
    add = parser.add_argument

    add("--config-file",
        metavar="FILE",
        default="../configs/R50_3x.yaml",
        help="path to config file")
    add("--input",
        nargs="+",
        default=['../inputs/bg.jpg'],
        help="A list of space separated test images")
    # REMAINDER: every token after --opts is collected verbatim.
    add("--opts",
        nargs=argparse.REMAINDER,
        default=['MODEL.WEIGHTS', '../SOLOv2_R50_3x.pth'],
        help="Modify config options using the command-line 'KEY VALUE' pairs")
    add("--PATH",
        type=str,
        default='../checkpoints/editor_grouped.pth',
        help="Path of the saved editor")

    return parser

In [None]:
# Parse the argument defaults; `parse_known_args` ignores any arguments the
# parser does not recognize (e.g. ones the kernel itself was launched with).
args, unknown = get_parser().parse_known_args()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
cfg = setup_cfg(args)

# Build SOLOv2 on the target device and load the weights named by cfg.MODEL.WEIGHTS.
solo = SOLOv2(cfg=cfg).to(device)
checkpointer = DetectionCheckpointer(solo)
checkpointer.load(cfg.MODEL.WEIGHTS)

# Freeze SOLOv2. The loop variable is deliberately NOT called `param`: that
# would shadow the `lucent.optvis.param` module imported at the top.
for p in solo.parameters():
    p.requires_grad = False

# One dummy forward pass, used only to read dimension 1 of SOLOv2's first
# output, which becomes the Reconstructor's `in_channels`. The probe image is
# created on `device` so it matches the model's placement.
image = torch.rand(3, 64, 64, device=device)
batched_input = [image]
with torch.no_grad():  # shape probe only; no autograd graph needed
    r, _ = solo(batched_input)

reconstructor = Reconstructor(in_channels=r.shape[1])

# Wrap SOLOv2 + reconstructor and restore the trained editor weights.
# map_location lets a checkpoint saved on GPU load on a CPU-only machine.
editor_demo = Editor(solo, reconstructor)
editor_demo.load_state_dict(torch.load(args.PATH, map_location=device))
editor_demo.to(device).eval()

# List layer names that lucent objectives can target.
# NOTE(review): the offset 198 presumably skips past the SOLOv2 layers so only
# the later (reconstructor) layers are shown — confirm and name this constant.
layers = get_model_layers(editor_demo)[198:]

for layer_name in layers:
    print(layer_name)

# Feature visualization with lucent: the objective string is "layer:channel",
# here channel 45 of `reconstructor_encoder_conv1`.
_ = render.render_vis(editor_demo, "reconstructor_encoder_conv1:45", show_inline=True)