# 1. Preparation

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import html

import torch
from IPython.display import HTML, display

from models import get_model
from models.org import rcnn
from myutils.pytorch import module_util

# 2. Functions to compare I/O shapes

In [3]:
def check_if_shape_match(xt, xs):
    """Recursively check whether two (possibly nested) I/O structures have matching shapes.

    Tensors are compared by ``shape``; dicts must have the same keys and matching values per
    key; lists/tuples must have the same length and matching elements position-wise. Any other
    leaf values are compared with ``==``.

    Args:
        xt: teacher-side value (tensor, dict, list/tuple, or plain leaf value).
        xs: student-side value to compare against ``xt``.

    Returns:
        True if both structures share container types, keys, lengths and tensor shapes.
    """
    # Different container/leaf types never match (e.g. list vs. tuple, tensor vs. dict).
    if type(xt) != type(xs):
        return False
    if isinstance(xt, dict):
        if xt.keys() != xs.keys():
            return False
        # Pair values by key, not by insertion order: equal key sets inserted in a different
        # order would otherwise be compared value-to-wrong-value.
        return all(check_if_shape_match(xt[key], xs[key]) for key in xt)
    if isinstance(xt, (list, tuple)):
        # zip() silently truncates, so check the length explicitly to catch extra elements.
        if len(xt) != len(xs):
            return False
        return all(check_if_shape_match(t_value, s_value) for t_value, s_value in zip(xt, xs))
    if isinstance(xt, torch.Tensor):
        return xt.shape == xs.shape
    return xt == xs


def check_if_io_shape_match(teacher_ios, student_ios):
    """Compare the recorded (input, output) pairs of a teacher module and a student module.

    Args:
        teacher_ios: iterable of ``(input, output)`` pairs recorded for the teacher module.
        student_ios: iterable of ``(input, output)`` pairs recorded for the student module.

    Returns:
        ``(input_flag, output_flag)``. Each flag is True only if every paired call matched on that
        side. If the two modules were called a different number of times, both flags are False.
    """
    teacher_ios, student_ios = list(teacher_ios), list(student_ios)
    # zip() would silently drop unpaired calls; an uneven call count is itself a mismatch.
    if len(teacher_ios) != len(student_ios):
        return False, False

    input_flag = True
    output_flag = True
    for (teacher_input, teacher_output), (student_input, student_output) in zip(teacher_ios, student_ios):
        if not check_if_shape_match(teacher_input, student_input):
            input_flag = False
        if not check_if_shape_match(teacher_output, student_output):
            output_flag = False
        # Flags never flip back to True, so stop once both sides have mismatched.
        if not input_flag and not output_flag:
            break
    return input_flag, output_flag


def convert2shape(x):
    """Replace every tensor inside a nested dict/list/tuple structure with its ``shape``.

    Container kinds are preserved (dict stays dict, list stays list, tuple becomes a plain
    tuple); non-tensor leaves are returned unchanged.
    """
    if isinstance(x, torch.Tensor):
        return x.shape
    if isinstance(x, dict):
        return {key: convert2shape(value) for key, value in x.items()}
    if isinstance(x, list):
        return [convert2shape(item) for item in x]
    if isinstance(x, tuple):
        return tuple(convert2shape(item) for item in x)
    return x


def extract_if_single(x):
    """Unwrap a one-element list/tuple whose only element is a tensor; otherwise return ``x`` as-is."""
    if not isinstance(x, (list, tuple)) or len(x) != 1:
        return x
    (only_item,) = x
    return only_item if isinstance(only_item, torch.Tensor) else x


def convert2str(teacher_ios, student_ios):
    """Summarize the recorded (input, output) pairs of a teacher/student module pair as shapes.

    Despite the name, the returned values are shape structures (``torch.Size`` objects nested in
    lists/tuples/dicts), not strings; they are stringified later when the HTML table is built.

    Args:
        teacher_ios: iterable of ``(input, output)`` pairs recorded for the teacher module.
        student_ios: iterable of ``(input, output)`` pairs recorded for the student module.

    Returns:
        ``(teacher_inputs, student_inputs, teacher_outputs, student_outputs)`` with one entry per
        recorded call. Each side is summarized independently, so calls without a counterpart on
        the other side are still shown instead of being truncated away.
    """
    def summarize(ios):
        # Collect per-call inputs/outputs (unwrapping single-tensor wrappers), then map to shapes.
        inputs, outputs = list(), list()
        for module_input, module_output in ios:
            inputs.append(extract_if_single(module_input))
            outputs.append(extract_if_single(module_output))
        return convert2shape(inputs), convert2shape(outputs)

    teacher_inputs, teacher_outputs = summarize(teacher_ios)
    student_inputs, student_outputs = summarize(student_ios)
    return teacher_inputs, student_inputs, teacher_outputs, student_outputs


def build_header(header_list=('Teacher path', 'Student Path', 'Input shape', 'Output shape')):
    """Build the header-row cells for the I/O shape comparison table.

    Every title except the last two gets a single centered cell. The last two titles (input and
    output shape) each span three columns, sitting above the ``teacher <-> student`` cell triples.

    Args:
        header_list: sequence of column titles (a tuple default avoids a shared mutable default).

    Returns:
        HTML string of ``<td>`` cells, without the enclosing ``<tr>``.
    """
    single_cells = ''.join(
        '<td style="text-align:center">{}</td>'.format(html.escape(str(title), quote=False))
        for title in header_list[:-2])
    multi_column_cells = ''.join(
        '<td colspan="3" style="text-align:center;">{}</td>'.format(html.escape(str(title), quote=False))
        for title in header_list[-2:])
    return single_cells + multi_column_cells


def build_color_row(flag, teacher_sample, student_sample, match_color, error_color):
    """Build the three ``teacher <-> student`` cells for one comparison, colored by match status.

    Args:
        flag: True if the teacher and student shapes match.
        teacher_sample: teacher-side shape summary, rendered via ``str()``.
        student_sample: student-side shape summary, rendered via ``str()``.
        match_color: CSS color used when ``flag`` is truthy.
        error_color: CSS color used otherwise.

    Returns:
        HTML string of three ``<td>`` cells.
    """
    color = match_color if flag else error_color
    cell_css = 'style="background-color:{};"'.format(color)
    # The middle (arrow) cell additionally gets nowrap to keep the arrow on one line.
    arrow_css = 'style="background-color:{};white-space: nowrap;"'.format(color)
    # Escape cell text: reprs of non-tensor objects look like "<pkg.Cls object at 0x...>",
    # which a browser would otherwise swallow as an unknown HTML tag.
    cells = [(cell_css, teacher_sample), (arrow_css, '<->'), (cell_css, student_sample)]
    return ''.join('<td {}>{}</td>'.format(css, html.escape(str(text), quote=False)) for css, text in cells)


def build_row(row, match_color='green', error_color='red'):
    """Build the HTML cells for one row of the comparison table.

    Args:
        row: ``[teacher_path, student_path, input_tuple, output_tuple]``, where each tuple is
            ``(flag, teacher_sample, student_sample)`` as accepted by ``build_color_row``.
            All entries before the last two are rendered as plain-text cells.
        match_color: CSS color for matching shapes.
        error_color: CSS color for mismatching shapes.

    Returns:
        HTML string of ``<td>`` cells, without the enclosing ``<tr>``.
    """
    *text_cells, input_tuple, output_tuple = row
    # Escape plain-text cells (module paths) so they can never be parsed as markup.
    text_html = ''.join('<td>{}</td>'.format(html.escape(str(cell), quote=False)) for cell in text_cells)
    input_html = build_color_row(*input_tuple, match_color, error_color)
    output_html = build_color_row(*output_tuple, match_color, error_color)
    return text_html + input_html + output_html
    

def compare_io_shapes(sample_batch, teacher_model, student_model, ts_path_dict):
    """Run both models on ``sample_batch`` and display an HTML table comparing paired module I/O shapes.

    Args:
        sample_batch: input fed to both models via ``module_util.extract_intermediate_io``.
        teacher_model: teacher ``torch.nn.Module``.
        student_model: student ``torch.nn.Module``.
        ts_path_dict: maps teacher module paths (e.g. ``'backbone.fpn'``) to student module paths.

    Both models are put in eval mode for the forward passes and then restored to the train/eval
    mode their root module had before the call, so the caller's models are not left modified.
    """
    teacher_was_training = teacher_model.training
    student_was_training = student_model.training
    teacher_model.eval()
    student_model.eval()
    try:
        # Only shapes are inspected, so skip building autograd graphs.
        with torch.no_grad():
            teacher_io_dict = module_util.extract_intermediate_io(sample_batch, teacher_model,
                                                                  list(ts_path_dict.keys()))
            student_io_dict = module_util.extract_intermediate_io(sample_batch, student_model,
                                                                  list(ts_path_dict.values()))
    finally:
        teacher_model.train(teacher_was_training)
        student_model.train(student_was_training)

    row_list = []
    for teacher_path, teacher_ios in teacher_io_dict.items():
        student_path = ts_path_dict[teacher_path]
        student_ios = student_io_dict[student_path]
        input_flag, output_flag = check_if_io_shape_match(teacher_ios, student_ios)
        teacher_input, student_input, teacher_output, student_output = convert2str(teacher_ios, student_ios)
        row_list.append([teacher_path, student_path,
                         (input_flag, teacher_input, student_input),
                         (output_flag, teacher_output, student_output)])

    header = build_header()
    body = '</tr><tr>'.join(build_row(row) for row in row_list)
    display(HTML('<table><tr style="font-weight:bold">{}</tr><tr>{}</tr></table>'.format(header, body)))

# 3. Faster R-CNN

In [4]:
# One Faster R-CNN per ResNet backbone; the positional flags are passed through exactly as before
# (see rcnn.get_model for their meaning).
faster_rcnn_resnet18, faster_rcnn_resnet34, faster_rcnn_resnet50 = [
    rcnn.get_model('faster_rcnn', False, backbone_name, False)
    for backbone_name in ('resnet18', 'resnet34', 'resnet50')
]

## 3.1 ResNet-34 vs. ResNet-18

In [5]:
# Teacher and student share the same module paths, so each path maps to itself.
path_dict = {module_path: module_path
             for module_path in ('backbone.fpn', 'roi_heads.box_roi_pool', 'roi_heads.box_predictor')}
compare_io_shapes(torch.rand(1, 3, 400, 600), faster_rcnn_resnet34, faster_rcnn_resnet18, path_dict)

0,1,2,3,4,5,6,7
Teacher path,Student Path,Input shape,Input shape,Input shape,Output shape,Output shape,Output shape
backbone.fpn,backbone.fpn,"[({0: torch.Size([1, 64, 200, 304]), 1: torch.Size([1, 128, 100, 152]), 2: torch.Size([1, 256, 50, 76]), 3: torch.Size([1, 512, 25, 38])},)]",<->,"[({0: torch.Size([1, 64, 200, 304]), 1: torch.Size([1, 128, 100, 152]), 2: torch.Size([1, 256, 50, 76]), 3: torch.Size([1, 512, 25, 38])},)]","[{0: torch.Size([1, 256, 200, 304]), 1: torch.Size([1, 256, 100, 152]), 2: torch.Size([1, 256, 50, 76]), 3: torch.Size([1, 256, 25, 38]), 'pool': torch.Size([1, 256, 13, 19])}]",<->,"[{0: torch.Size([1, 256, 200, 304]), 1: torch.Size([1, 256, 100, 152]), 2: torch.Size([1, 256, 50, 76]), 3: torch.Size([1, 256, 25, 38]), 'pool': torch.Size([1, 256, 13, 19])}]"
roi_heads.box_roi_pool,roi_heads.box_roi_pool,"[({0: torch.Size([1, 256, 200, 304]), 1: torch.Size([1, 256, 100, 152]), 2: torch.Size([1, 256, 50, 76]), 3: torch.Size([1, 256, 25, 38]), 'pool': torch.Size([1, 256, 13, 19])}, [torch.Size([1000, 4])], [(800, 1200)])]",<->,"[({0: torch.Size([1, 256, 200, 304]), 1: torch.Size([1, 256, 100, 152]), 2: torch.Size([1, 256, 50, 76]), 3: torch.Size([1, 256, 25, 38]), 'pool': torch.Size([1, 256, 13, 19])}, [torch.Size([1000, 4])], [(800, 1200)])]","[torch.Size([1000, 256, 7, 7])]",<->,"[torch.Size([1000, 256, 7, 7])]"
roi_heads.box_predictor,roi_heads.box_predictor,"[torch.Size([1000, 1024])]",<->,"[torch.Size([1000, 1024])]","[(torch.Size([1000, 91]), torch.Size([1000, 364]))]",<->,"[(torch.Size([1000, 91]), torch.Size([1000, 364]))]"


## 3.2 ResNet-50 vs. ResNet-18

In [6]:
# Teacher and student share the same module paths, so each path maps to itself.
path_dict = {module_path: module_path
             for module_path in ('backbone.fpn', 'roi_heads.box_roi_pool', 'roi_heads.box_predictor')}
compare_io_shapes(torch.rand(1, 3, 400, 600), faster_rcnn_resnet50, faster_rcnn_resnet18, path_dict)

0,1,2,3,4,5,6,7
Teacher path,Student Path,Input shape,Input shape,Input shape,Output shape,Output shape,Output shape
backbone.fpn,backbone.fpn,"[({0: torch.Size([1, 256, 200, 304]), 1: torch.Size([1, 512, 100, 152]), 2: torch.Size([1, 1024, 50, 76]), 3: torch.Size([1, 2048, 25, 38])},)]",<->,"[({0: torch.Size([1, 64, 200, 304]), 1: torch.Size([1, 128, 100, 152]), 2: torch.Size([1, 256, 50, 76]), 3: torch.Size([1, 512, 25, 38])},)]","[{0: torch.Size([1, 256, 200, 304]), 1: torch.Size([1, 256, 100, 152]), 2: torch.Size([1, 256, 50, 76]), 3: torch.Size([1, 256, 25, 38]), 'pool': torch.Size([1, 256, 13, 19])}]",<->,"[{0: torch.Size([1, 256, 200, 304]), 1: torch.Size([1, 256, 100, 152]), 2: torch.Size([1, 256, 50, 76]), 3: torch.Size([1, 256, 25, 38]), 'pool': torch.Size([1, 256, 13, 19])}]"
roi_heads.box_roi_pool,roi_heads.box_roi_pool,"[({0: torch.Size([1, 256, 200, 304]), 1: torch.Size([1, 256, 100, 152]), 2: torch.Size([1, 256, 50, 76]), 3: torch.Size([1, 256, 25, 38]), 'pool': torch.Size([1, 256, 13, 19])}, [torch.Size([1000, 4])], [(800, 1200)])]",<->,"[({0: torch.Size([1, 256, 200, 304]), 1: torch.Size([1, 256, 100, 152]), 2: torch.Size([1, 256, 50, 76]), 3: torch.Size([1, 256, 25, 38]), 'pool': torch.Size([1, 256, 13, 19])}, [torch.Size([1000, 4])], [(800, 1200)])]","[torch.Size([1000, 256, 7, 7])]",<->,"[torch.Size([1000, 256, 7, 7])]"
roi_heads.box_predictor,roi_heads.box_predictor,"[torch.Size([1000, 1024])]",<->,"[torch.Size([1000, 1024])]","[(torch.Size([1000, 91]), torch.Size([1000, 364]))]",<->,"[(torch.Size([1000, 91]), torch.Size([1000, 364]))]"


## 3.3 ResNet-50 vs. ResNet-34

In [7]:
# Teacher and student share the same module paths, so each path maps to itself.
path_dict = {module_path: module_path
             for module_path in ('backbone.fpn', 'roi_heads.box_roi_pool', 'roi_heads.box_predictor')}
compare_io_shapes(torch.rand(1, 3, 400, 600), faster_rcnn_resnet50, faster_rcnn_resnet34, path_dict)

0,1,2,3,4,5,6,7
Teacher path,Student Path,Input shape,Input shape,Input shape,Output shape,Output shape,Output shape
backbone.fpn,backbone.fpn,"[({0: torch.Size([1, 256, 200, 304]), 1: torch.Size([1, 512, 100, 152]), 2: torch.Size([1, 1024, 50, 76]), 3: torch.Size([1, 2048, 25, 38])},)]",<->,"[({0: torch.Size([1, 64, 200, 304]), 1: torch.Size([1, 128, 100, 152]), 2: torch.Size([1, 256, 50, 76]), 3: torch.Size([1, 512, 25, 38])},)]","[{0: torch.Size([1, 256, 200, 304]), 1: torch.Size([1, 256, 100, 152]), 2: torch.Size([1, 256, 50, 76]), 3: torch.Size([1, 256, 25, 38]), 'pool': torch.Size([1, 256, 13, 19])}]",<->,"[{0: torch.Size([1, 256, 200, 304]), 1: torch.Size([1, 256, 100, 152]), 2: torch.Size([1, 256, 50, 76]), 3: torch.Size([1, 256, 25, 38]), 'pool': torch.Size([1, 256, 13, 19])}]"
roi_heads.box_roi_pool,roi_heads.box_roi_pool,"[({0: torch.Size([1, 256, 200, 304]), 1: torch.Size([1, 256, 100, 152]), 2: torch.Size([1, 256, 50, 76]), 3: torch.Size([1, 256, 25, 38]), 'pool': torch.Size([1, 256, 13, 19])}, [torch.Size([1000, 4])], [(800, 1200)])]",<->,"[({0: torch.Size([1, 256, 200, 304]), 1: torch.Size([1, 256, 100, 152]), 2: torch.Size([1, 256, 50, 76]), 3: torch.Size([1, 256, 25, 38]), 'pool': torch.Size([1, 256, 13, 19])}, [torch.Size([1000, 4])], [(800, 1200)])]","[torch.Size([1000, 256, 7, 7])]",<->,"[torch.Size([1000, 256, 7, 7])]"
roi_heads.box_predictor,roi_heads.box_predictor,"[torch.Size([1000, 1024])]",<->,"[torch.Size([1000, 1024])]","[(torch.Size([1000, 91]), torch.Size([1000, 364]))]",<->,"[(torch.Size([1000, 91]), torch.Size([1000, 364]))]"
