In [2]:
%reload_ext autoreload
%autoreload 2

import sys, os
os.environ["CUDA_VISIBLE_DEVICES"] = ""
sys.path.append(os.path.expanduser("~")+"/dotfiles")
import myutils

import time, random, pickle, logging, uuid, collections, copy, json

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import scipy, sklearn
from functools import partial

import torch
use_cuda, device, n_gpu = myutils.print_torch(torch)

import transformers
myutils.print_packages(transformers)
from transformers import *
from datasets import *

# Fixed seed handed to myutils.seed_everything together with the random, os, numpy and torch modules.
SEED = 42
myutils.seed_everything(SEED, random, os, np, torch)
# NOTE(review): the literal '~' is not expanded by Python itself; confirm that every
# consumer of cache_dir calls os.path.expanduser (or expand it here).
cache_dir = '~/.cache'

from mammo2text import BreastImgTextDataset, MammoImgTextDataCollator, get_target_text, get_mammo2text_model, \
    generate_predictions, AddParametersToMlflowCallback, generate, load_image_model, EncoderWrapper
from sklearn.metrics import f1_score

python: 3.6.9 (default, Jan 26 2021, 15:33:00) 
[GCC 8.4.0]
torch: 1.7.1
use_cuda: False, n_gpu: 0, device: cpu, devices:[]
transformers: 4.2.0


In [3]:
from PIL import Image
import math 
from sklearn.preprocessing import MinMaxScaler
import gc

# Eval

In [10]:
# Load the configuration stored next to the checkpoint that is being evaluated.
model_name = '1_att_cnn_niar4'
parameters = myutils.json_load(f'models/{model_name}/parameters.json')
parameters['debug'] = False

tokenizer = BertTokenizerFast.from_pretrained(parameters["tokenizer_path"])
print('Vocab size:', len(tokenizer.get_vocab()))

random_number_generator = np.random.RandomState()

exam_list_eval = myutils.json_load(parameters["eval_json_path"]) #[:10]

# Restrict the evaluation exams to the first 10 entries of `best_rated_cases`,
# keeping the order of `best_rated_cases` (and, per case, the order of the exam list).
# NOTE: `best_rated_cases` is not created in this cell; it is assigned from
# `c.index.values` in another cell, which therefore has to be executed first.
res = [
    exam
    for case_uid in best_rated_cases[:10]
    for exam in exam_list_eval
    if exam['study_uid'] == case_uid
]
exam_list_eval = res

# Dataset over the selected exams and the collator that goes with it.
eval_dataset = BreastImgTextDataset(
    exam_list=exam_list_eval,
    tokenizer=tokenizer,
    parameters=parameters,
    random_number_generator=random_number_generator,
    swap=False,
)

data_collator = MammoImgTextDataCollator(tokenizer=tokenizer)

Vocab size: 40003


In [9]:
# Labelled predictions sheet, read through the project's myutils.excel helper.
label_df = myutils.excel.read('data/label2/all_exp.xlsx')
# Values of the 'type' column used to pick the two models that are compared below.
best_model = 'pred_1_cnn_decoder_att'
model3 = 'pred_1_random_cnn_random_decoder'

In [24]:
# Rows of the three sources being compared (target text, best model, baseline model),
# reduced to the columns that end up in the exported sheet.
a = label_df.loc[
    label_df['type'].isin(['TARGET', best_model, model3]),
    ['type', 'Предсказание', 'Общая оценка от 1-10', 'case_id', 'Номер случая', 'Номер строки'],
]

# Best model: highest overall score first, indexed by case.
b = (
    a[a['type'] == best_model]
    .sort_values('Общая оценка от 1-10', ascending=False)
    .set_index('case_id')
)

# Target rows, indexed by case, every column prefixed with 'target-'.
b2 = a[a['type'] == 'TARGET'].set_index('case_id').add_prefix('target-')

# Baseline model rows, indexed by case, every column prefixed with 'model-bad-'.
b3 = a[a['type'] == model3].set_index('case_id').add_prefix('model-bad-')

# Join the three views on case_id; the row order of `b` (score, descending) is kept.
c = (
    b.merge(b3, left_index=True, right_index=True)
     .merge(b2, left_index=True, right_index=True)
)
c['picture_id'] = range(len(c))

myutils.excel.save(
    c,
    'data/label_viz/predictions2.xlsx',
    long_columns=['Предсказание', 'model-bad-Предсказание', 'target-Предсказание'],
)

In [111]:
# case_id index values of the merged table `c`, in its row order.
# The evaluation-setup cell reads this array to select exams, so this cell must run before it.
best_rated_cases = c.index.values

In [114]:
# myutils.display_df(b[['Предсказание']].iloc[:10])

In [1]:
# Rebuild the architecture from the stored parameters, then restore the trained weights.
model = get_mammo2text_model(parameters, tokenizer, device, load_image_model, EncoderWrapper)
# map_location remaps the saved tensors onto `device`. Without it torch.load tries to
# restore each tensor on the device it was saved from, which fails for a GPU checkpoint
# in this notebook because CUDA is hidden via CUDA_VISIBLE_DEVICES="".
model.load_state_dict(
    torch.load(f'models/{parameters["model_name"]}/pytorch_model.bin', map_location=device)
)
# Inference mode: disables dropout / batch-norm updates.
model.eval()
model = model.to(device)

In [None]:
def register_hooks(model):
    """Attach forward hooks that record the output of selected sub-modules.

    Args:
        model: object exposing ``named_modules()``; every module whose
            qualified name is in the watch list below gets a forward hook.

    Returns:
        A ``collections.defaultdict(list)``. After a hooked module has run,
        its qualified name maps to that module's most recent output moved
        to the CPU (each forward pass overwrites the previous value).
    """
    watched_names = {
        'decoder.bert.encoder.layer.10.crossattention.self.dropout',
        'decoder.bert.encoder.layer.11.crossattention.self.dropout',
        # Encoder nodes, currently disabled:
        #   'encoder.model.four_view_net.net._conv_head',
        #   'encoder.model.four_view_net.net._conv_head.static_padding',
        #   'encoder.model.four_view_net.net._bn1'
    }
    activations = collections.defaultdict(list)

    def make_recorder(module_name):
        # Bind the module name so one hook body can serve every watched module.
        def record(mod, inp, out):
            activations[module_name] = out.cpu()
        return record

    for qualified_name, module in model.named_modules():
        if qualified_name not in watched_names:
            continue
        module.register_forward_hook(make_recorder(qualified_name))
    return activations
    
# Hooks are registered on a deep copy, so the original `model` object gets no forward hooks.
model_with_hooks = copy.deepcopy(model)
# Dictionary that register_hooks returns; its hooks write into it on forward passes of `model_with_hooks`.
activations = register_hooks(model_with_hooks)

## Feature maps

Order of the four views along the last axis of the stacked tensor:

`hh = torch.stack([h[VIEWS.L_CC], h[VIEWS.R_CC], h[VIEWS.L_MLO], h[VIEWS.R_MLO]], -1)`

In [47]:
def plot_hist(a, bins=50):
    """Display a histogram of ``a``.

    Args:
        a: values to plot; a ``torch.Tensor`` is first moved to the CPU,
            detached and converted to a NumPy array, anything else is
            passed to ``plt.hist`` unchanged.
        bins: forwarded to ``plt.hist``.
    """
    values = a.cpu().detach().numpy() if isinstance(a, torch.Tensor) else a
    plt.hist(values, bins=bins)
    plt.show()
    
def get_named_modules(model):
    """Return the names yielded by ``model.named_modules()``, in iteration order."""
    names = []
    for name, _module in model.named_modules():
        names.append(name)
    return names

In [48]:
def viz_array(t, save=False):
    """Draw ``t`` as a grayscale image.

    Args:
        t: image data accepted by ``plt.imshow``; a ``torch.Tensor`` is
            first moved to the CPU, detached and converted to NumPy.
        save: when truthy, write the figure to ``viz/picture.png`` and close
            it instead of showing it.
    """
    pixels = t.cpu().detach().numpy() if isinstance(t, torch.Tensor) else t
    fig = plt.figure()
    plt.imshow(pixels, cmap='gray')
    if not save:
        plt.show()
        return
    plt.savefig('viz/picture.png')
    plt.close(fig)

In [2]:
# Compare the image file of one evaluation exam with what the dataset returns for the same exam.
sample_id=0
sample = eval_dataset[sample_id]
# First entry under the exam's 'L-CC' key, used as a file path; joining with '' leaves it unchanged.
image = os.path.join('', exam_list_eval[sample_id]['L-CC'][0])
image_array = np.array(Image.open(image))
viz_array(image_array)
# NOTE(review): this indexes the dataset a second time instead of reusing `sample`;
# if the dataset draws from its random_number_generator the two items may differ — confirm.
viz_array(eval_dataset[sample_id]['L-CC'][0])

In [17]:
def viz_in_subplots(img_array, file_name='viz/subplots.png'):
    """Save all images of ``img_array`` into one square grid figure.

    The grid has ``ceil(sqrt(len(img_array)))`` rows and columns; images are
    placed row by row, drawn in grayscale with a fixed intensity range of
    0..12, and axis ticks are removed.

    Args:
        img_array: indexable collection of images accepted by ``plt.imshow``.
        file_name: output path handed to ``plt.savefig`` (written at dpi=1000).
    """
    plt.ioff()
    fig = plt.figure()
    n_images = len(img_array)
    side = math.ceil(n_images ** (1/2))
    for position in range(1, n_images + 1):
        # Subplot positions are 1-based, image indices 0-based.
        ax = plt.subplot(side, side, position)
        ax.set_xticks([])
        ax.set_yticks([])
        plt.imshow(img_array[position - 1], cmap='gray', vmin=0, vmax=12)
    plt.savefig(file_name, dpi=1000)
    plt.close(fig)
    print(f'Saved as {file_name}')

## Full

In [49]:
# temp_att = np.array([
#                 [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
#                 [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
#                 [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
#                 [0, 0, 0, 0, 0, 0, 0.9, 0.9, 0, 0],
#                 [0, 0, 0, 0, 0, 0.9, 1, 1, 0, 0],
#                 [0, 0, 0, 0, 0, 0.9, 1, 1, 0, 0],
#                 [0, 0, 0, 0, 0, 0, 0.9, 0.9, 0, 0],
#                 [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
#                 [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
#                 [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
#             ])

In [12]:
gl = None
def plot_attention(images, words, attention_plot, file_name='viz/attention.png', 
                   cmap='jet', interpolation='spline16', alpha=None, quality_factor=1, 
                   use_add_alpha=True, image_id=0):
    """Overlay attention maps on the input images and save the grid as a PNG.

    Layout: one row per image. Column 0 shows the plain image; every further
    column shows the image again with the attention map of one entry of
    ``words`` drawn on top. The figure is written to ``viz/{image_id}_att.png``.

    Args:
        images: sequence of images accepted by ``imshow``.
        words: sequence of labels; entry ``k`` is the title of overlay column ``k``.
        attention_plot: attention weights. They are summed over axis 0, a new
            leading axis of size 1 is added, and the result is indexed as
            ``[word, :, :, image]``.
            NOTE(review): after that reduction the leading axis has size 1, so
            passing more than one entry in ``words`` raises IndexError —
            confirm whether the sum over axis 0 is meant to stay.
        file_name: currently unused; the output path is derived from ``image_id``.
        cmap: colour map of the overlay.
        interpolation: interpolation mode of the overlay.
        alpha: base opacity of the overlay. ``None`` is passed through to
            ``imshow`` when ``use_add_alpha`` is false and is treated as 1.0
            when ``use_add_alpha`` is true.
        quality_factor: multiplier applied to the figure size.
        use_add_alpha: when true, a per-pixel offset in [-0.2, 0.1] derived
            from the attention map is added to the base opacity (result is
            clipped to [0, 1]).
        image_id: identifier used in the output file name.
    """
    plt.ioff()
    len_words = len(words)
    len_images = len(images)
    layout_words = len_words + 1
    layout_images = len_images
    fig = plt.figure(figsize=(layout_words*1.5*quality_factor, layout_images*3*quality_factor))
    # NOTE(review): tqdm is not imported explicitly in the visible cells; it is
    # expected to already be in the notebook namespace.
    pbar = tqdm(total=len_words*len_images)
    try:
        # First column: the plain images.
        for img_index in range(len_images):
            i = layout_words*img_index + 1
            ax = fig.add_subplot(layout_images, layout_words, i)
            ax.imshow(images[img_index], cmap='gray')
            ax.axis('off')

        attention_plot = attention_plot.sum(0)
        attention_plot = np.expand_dims(attention_plot, 0)
        print(attention_plot.shape)

        for word_index in range(len_words):
            # Shared colour range for all images of this word, taken before thresholding.
            word_min = attention_plot[word_index,:,:,:].min()
            word_max = attention_plot[word_index,:,:,:].max()
            for img_index in range(len_images):
                temp_image = images[img_index]
                temp_att = attention_plot[word_index,:,:, img_index]
                i = layout_words*img_index + word_index + 2
                ax = fig.add_subplot(layout_images, layout_words, i)
                ax.axis('off')
                if img_index == 0:
                    # Title comes from the `words` argument (previously read an
                    # unrelated global named `result`).
                    ax.set_title(words[word_index])
                img = ax.imshow(temp_image, cmap='gray')

                # Flatten the lower half of the map: everything below the median is raised to it.
                p = np.percentile(temp_att, 50)
                temp_att[temp_att < p] = p

                _alpha = alpha
                if use_add_alpha:
                    # alpha=None used to crash in `np.ones(...) * None`; treat it as opaque.
                    base_alpha = 1.0 if alpha is None else alpha
                    scaler = MinMaxScaler((-0.2, 0.1))
                    # NOTE(review): MinMaxScaler rescales every column of the map
                    # independently; confirm that per-column scaling is intended.
                    additional_alpha = scaler.fit_transform(temp_att)
                    # Keep the per-pixel opacity inside the valid [0, 1] range.
                    _alpha = np.clip(base_alpha + additional_alpha, 0.0, 1.0)

                ax.imshow(temp_att, interpolation=interpolation, cmap=cmap, alpha=_alpha,
                          extent=img.get_extent(), vmin=word_min, vmax=word_max)
                pbar.update(1)
        fig.tight_layout()
        fig.savefig(f'viz/{image_id}_att.png', dpi=100)
    finally:
        # Release the progress bar and the figure even when plotting fails part-way.
        pbar.close()
        plt.close(fig)