In [None]:
import yaml
# System libs
import os, csv, torch, numpy, scipy.io, PIL.Image, torchvision.transforms
import torch.nn as nn

# Our libs
from semseg.models import ModelBuilder, SegmentationModule
from semseg.utils import colorEncode

# Load the `colors` array stored in the colour-map .mat file.
colors = scipy.io.loadmat('data/mit_data/color150.mat')['colors']

# Build {int(first column): first ';'-separated entry of the sixth column}
# from the object-info CSV, skipping its header row.
with open('data/mit_data/object150_info.csv') as info_file:
    rows = csv.reader(info_file)
    next(rows)
    names = {int(row[0]): row[5].split(";")[0] for row in rows}

def visualize_result(img, pred, index=None):
    """Display `img` next to the colour-coded prediction `pred`.

    When `index` is given, every entry of a copy of `pred` that differs from
    `index` is set to -1 before colouring, and `names[index+1]` is printed as
    a heading. The caller's `pred` array is left untouched.

    Nothing is returned or written to disk; the combined image is shown with
    the notebook's `display`.
    """
    if index is not None:
        # Work on a copy so the caller keeps the full prediction.
        pred = pred.copy()
        other_classes = pred != index
        pred[other_classes] = -1
        heading = f'{names[index+1]}:'
        print(heading)

    # Map class ids to colours (colorEncode is imported from semseg.utils).
    encoded = colorEncode(pred, colors)
    pred_color = encoded.astype(numpy.uint8)

    # Join image and coloured prediction along axis 1 and show the result.
    side_by_side = numpy.concatenate((img, pred_color), axis=1)
    display(PIL.Image.fromarray(side_by_side))
    
    
def parse_model_config(path):
    """Load a YAML model config and locate the encoder/decoder weight files.

    The directory named by the config's top-level ``DIR`` entry is scanned for
    file names containing "encoder" / "decoder" (case-insensitive).

    Args:
        path: Path of the YAML configuration file.

    Returns:
        Tuple ``(data, encoder_path, decoder_path)``: the parsed config and the
        two weight paths, each formatted as ``"<DIR>/<file name>"``.

    Raises:
        FileNotFoundError: if no encoder file or no decoder file is found in
            ``DIR`` (``open`` / ``os.listdir`` raise it too for missing paths).
        KeyError: if the config has no ``DIR`` entry.
    """
    with open(path) as file:
        # NOTE(review): FullLoader can construct some Python objects; prefer
        # yaml.safe_load if configs can come from an untrusted source.
        data = yaml.load(file, Loader=yaml.FullLoader)

    weights_dir = data['DIR']
    encoder_path = None
    decoder_path = None

    # os.listdir order is arbitrary; sorting makes the pick deterministic when
    # several files match (the lexicographically last match wins).
    for p in sorted(os.listdir(weights_dir)):
        lowered = p.lower()
        if "encoder" in lowered:
            encoder_path = "{}/{}".format(weights_dir, p)
        elif "decoder" in lowered:
            decoder_path = "{}/{}".format(weights_dir, p)

    if encoder_path is None or decoder_path is None:
        # Raising a bare string is itself a TypeError in Python 3, which hid
        # the real cause; raise a proper exception instead.
        raise FileNotFoundError("model weights not found")

    return data, encoder_path, decoder_path

def cal_weight(tensor, l):
    """Return one similarity weight per index in ``range(l)``.

    For each ``i`` the weight is ``sum(cos(tensor[i], tensor[anchor])) / (w * h)``,
    where ``anchor`` is the most recent index with ``i % seq_len == 0`` and
    ``w, h`` are the last two entries of ``tensor.shape``.

    Relies on two module-level names: ``seq_len`` and ``cos``.
    NOTE(review): ``seq_len`` is not assigned anywhere in the visible notebook
    and ``cos`` is created in a later cell — both must exist before calling.

    Args:
        tensor: 4-D tensor (its shape is unpacked into four values).
        l: Number of leading entries of ``tensor`` to process.

    Returns:
        List of ``l`` scalar tensors.
    """
    _, _, w, h = tensor.shape
    area = w * h
    anchor = 0  # index of the reference entry for the current sequence
    weights = []
    for i in range(l):
        if i % seq_len == 0:
            # A new sequence starts here; compare against this entry from now on.
            anchor = i
        similarity = cos(tensor[i], tensor[anchor])
        weights.append(torch.sum(similarity) / area)
    return weights

# Disabled earlier draft of get_activation, kept inside a string literal so it
# is never defined here; the live definition is in a later cell.
'''
def get_activation(name, activation):
    def hook(model,input, output):
        activation[name] = output.detach()
    return hook
'''

In [None]:
# Disabled: config-driven model construction via parse_model_config, wrapped in
# a string literal so none of it runs. Only the print() at the end executes.
'''
model_config, encoder_path, decoder_path = parse_model_config("config/bodypart-hrnetv2.yaml")
net_encoder = ModelBuilder.build_encoder(
    arch = model_config["MODEL"]['arch_encoder'],
    fc_dim = model_config['MODEL']['fc_dim'],
    weights = encoder_path)
net_decoder = ModelBuilder.build_decoder(
    arch = model_config["MODEL"]['arch_decoder'],
    fc_dim = model_config['MODEL']['fc_dim'],
    num_class = model_config['DATASET']['num_class'],
    weights = decoder_path,
    use_softmax=True)

crit = torch.nn.NLLLoss(ignore_index=-1)
segmentation_module = SegmentationModule(net_encoder, net_decoder, crit)
segmentation_module.eval()
segmentation_module.cuda()
'''
print()

In [None]:
# Network Builders
# Builds the HRNetV2 encoder and C1 decoder from the epoch-30 checkpoints and
# wraps them in a SegmentationModule in eval mode on the CPU.
net_encoder = ModelBuilder.build_encoder(
    arch='hrnetv2',
    fc_dim=2048,
    weights='ckpt/bodypart-hrnetv2-c1/encoder_epoch_30.pth')
net_decoder = ModelBuilder.build_decoder(
    arch='c1',
    fc_dim=2048,
    # NOTE(review): hard-coded here, while the disabled config-driven cell
    # reads num_class from the YAML — confirm 150 matches this checkpoint.
    num_class=150,
    # The decoder must load its own checkpoint; it was being handed the
    # encoder's file. NOTE(review): assumes the decoder weights sit next to
    # the encoder's as decoder_epoch_30.pth — confirm the file name.
    weights='ckpt/bodypart-hrnetv2-c1/decoder_epoch_30.pth',
    use_softmax=True)

# Loss that skips targets equal to -1.
crit = torch.nn.NLLLoss(ignore_index=-1)
segmentation_module = SegmentationModule(net_encoder, net_decoder, crit, batch_size=2)
segmentation_module.eval();  # trailing ';' hides the module repr in the cell output
# segmentation_module.cuda();

In [None]:
# NOTE(review): `tmp` is only assigned in later cells (the encoder/decoder
# forward passes), so this cell fails on a fresh top-to-bottom run.
len(tmp)

In [None]:
# Cosine similarity reduced over dim 0; cal_weight() reads this as a global,
# so this cell must run before cal_weight is called.
cos = nn.CosineSimilarity(dim=0, eps=1e-6)

In [None]:
# Print each direct child of the encoder: its name followed by the module repr.
for name, m in segmentation_module.encoder.named_children():
    print(name, m)

In [None]:
# Child-module names to hook; consumed by the `name in hidden_layer_names`
# test in the decoder-hook cell.
# NOTE(review): 'cbr' is listed twice — harmless for a membership test, but
# possibly a typo for another layer name.
hidden_layer_names = ['cbr', 'conv_last', 'cbr']

In [None]:
# Scratch: checks that '{}'.format(x) simply renders x.
print("{}".format("hello"))

In [1]:
# Three equal-length toy lists, also gathered as rows of `d`.
a, b, c = [1, 2, 3], [4, 5, 6], [1, 2, 3]
d = [a, b, c]

# Lazy column-wise pairing of the rows (equivalent to unpacking d into zip).
zf = zip(a, b, c)

In [2]:
# In-place: appends b's items to `a` (and so to d[0], which is the same list).
# Re-running this cell extends `a` again, so the recorded output only holds
# for a single run.
a.extend(b)
a

[1, 2, 3, 4, 5, 6]

In [4]:
# In-place: adds 10 to the end of `a`; each re-run appends another 10.
a.append(10)
a

[1, 2, 3, 4, 5, 6, 10]

In [None]:
# zip stops at its shortest input (b and c have 3 items), so three tuples are
# printed even though `a` has grown. A zip object is exhausted after one pass:
# re-running this cell prints nothing.
for val in zf:
    print(val)

In [None]:
# Hook the decoder children listed in hidden_layer_names; each hook stores that
# child's output in `activation` under the child's name.
# NOTE(review): get_activation is defined in a later cell (the copy in the
# first cell sits inside a string), so this fails on a fresh top-to-bottom run.
activation = {}
for name, m in segmentation_module.decoder.named_children():
    if name in hidden_layer_names:
        m.register_forward_hook(get_activation('{}'.format(name), activation))
        print(name)

In [None]:
def get_activation(name, activation):
    """Build a forward hook that records a module's output in `activation`.

    Args:
        name: Key under which the output is stored.
        activation: Dict that receives the output; it is mutated by the hook.

    Returns:
        A callable with the ``hook(module, input, output)`` signature used by
        ``register_forward_hook``. If `output` has a ``detach`` method, the
        detached result is stored; otherwise `output` is iterated and a list
        of its detached elements is stored.

    Errors raised by ``detach()`` propagate to the caller instead of being
    silently replaced by an element-wise fallback.
    """
    def hook(model, input, output):
        if hasattr(output, 'detach'):
            # Single tensor-like output.
            activation[name] = output.detach()
        else:
            # Sequence of outputs (e.g. a list or tuple): detach each element.
            activation[name] = [out.detach() for out in output]
    return hook

def register_hooks(model, module_names, activation, show=False):
    """Attach activation-recording forward hooks to selected direct children.

    Args:
        model: Object whose ``named_children()`` pairs are scanned.
        module_names: Container of child names that should be hooked.
        activation: Dict the hooks write into, keyed by child name.
        show: When true, print the name of every child that gets a hook.
    """
    for child_name, child in model.named_children():
        if child_name not in module_names:
            continue
        recorder = get_activation(f'{child_name}', activation)
        child.register_forward_hook(recorder)
        if show:
            print(child_name)

In [None]:
# Hook every direct child of the encoder, then run one forward pass on random
# CPU input of shape (6, 3, 50, 50) so the hooks fill `activation`.
# NOTE(review): the hook handles are not kept, so re-running this cell
# registers additional hooks on the same modules.
activation = {}
for name, module in segmentation_module.encoder.named_children():
    print(name)
    module.register_forward_hook(get_activation('{}'.format(name), activation))

tmp = segmentation_module.encoder(torch.rand(6, 3, 50, 50,  device='cpu'))

In [None]:
# Names under which hook outputs have been stored so far.
activation.keys()

In [None]:
# Print the shape of each element stored under 'stage4'.
# Presumably that module returns several tensors (stored as a list by the
# hook) — confirm.
for val in activation['stage4']:
    print(val.shape)

In [None]:
# Start a fresh dict and hook every direct child of the decoder.
# NOTE(review): hooks registered in earlier cells captured the previous dict
# object and keep writing to it, not to this new one.
activation = {}
for name, module in segmentation_module.decoder.named_children():
    print(name)
    module.register_forward_hook(get_activation('{}'.format(name), activation))

In [None]:
# One decoder forward pass on random CPU input of shape (512, 2048, 3, 3);
# rebinds `tmp`, replacing the encoder output stored earlier.
tmp = segmentation_module.decoder(torch.rand(512, 2048, 3, 3,  device='cpu'))

In [None]:
#decoder.cbr.register_forward_hook(get_activation('cbr'))
#decoder.conv_last.register_forward_hook(get_activation('conv_last'))

In [None]:
# Print each direct child of the decoder: name, module repr, then a blank line.
# The loop variable `m` stays bound to the last child after the loop.
for name, m in segmentation_module.decoder.named_children():
    print(name)
    print(m)
    print()

In [None]:
# Shows whatever `m` was left bound to by the most recently run loop — depends
# on cell execution order.
m

In [None]:
# Comma-separated child-module names; split into lists in the cells below.
encoder_weight_type = "transition3,stage4"
decoder_weight_type = "cbr,conv_last"

In [None]:
# Turn the comma-separated decoder names into a list and show it.
decoder_wt = decoder_weight_type.split(',')
decoder_wt

In [None]:
# Hook the decoder children named in decoder_wt, writing into the current
# `activation` dict; the final True makes register_hooks print each hooked name.
register_hooks(segmentation_module.decoder, decoder_wt, activation, True)

In [None]:
# Turn the comma-separated encoder names into a list and show it.
encoder_wt = encoder_weight_type.split(',')
encoder_wt

In [None]:
# Scratch: splitting a string that has no comma yields a one-element list.
"tmp".split(',')

In [None]:
# Scratch: zip is lazy, so this displays the zip object's repr, not its tuples.
zip([1,2],[1,2],[1,2])

In [None]:
# Scratch tensors for the torch.cat check below: a is 2x2, b is 2x3.
# Rebinds `a` and `b`, which held plain lists in the earlier zip cells.
a = torch.tensor([[1., -1.], [1., -1.]])
b = torch.tensor([[1., -1.,2], [1., -1.,2]])
a.shape

In [None]:
# Shape of the 2x3 scratch tensor.
b.shape

In [None]:
# Concatenate along dim 1: widths 2 + 3 + 2 give a 2x7 result. Rebinds `c`.
c = torch.cat([a, b, a], 1)
c.shape