In [1]:
import datetime
import random
import numpy as np

import matplotlib.pyplot as plt
import matplotlib.ticker as ticker

import torch
import torchvision
from torchvision import transforms
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torch.backends import cudnn
import torch.nn as nn
from torcheval.metrics.functional import multiclass_accuracy

import segmentation_models_pytorch as smp
import segmentation_models_pytorch.utils

import cv2
from keras.utils import to_categorical

import PIL.Image

from os.path import join as pjoin
import os

from tqdm import tqdm
import json

In [2]:
from utils import test_i_sample, test_x_sample
from my_dataset import MyDataset
from tf_callback import images_to_probs, plot_classes_preds, plot_confusion_matrix, plot_to_image
from metrics import runningScore

In [3]:
# Identifier of the evaluated run; artifacts are read from / written to models/<model_name>/.
model_name = "model_0407_1009"

In [4]:
# Echo which run is being evaluated.
print(f"{model_name}")

model_0407_1009


In [5]:
# Placeholder only; it is overwritten by the params.json load in the next cell.
params_dict = None

In [6]:
# Parameters saved with the trained run (this notebook reads 'step' and 'split_test').
params_path = pjoin('models', model_name, 'params.json')
with open(params_path, 'r') as params_file:
    params_dict = json.load(params_file)

In [7]:
# Number of segmentation classes; the label volume contains class ids 0..5 (see the counts below).
num_classes = 6

In [8]:
# Seeds / learning rate, commented out and unused in this evaluation notebook (kept for reference):
# python_seed = 245
# np_seed = 123
# torch_seed = 321
# torch_cuda_seed = 111
# learning_rate = 0.001
# Stride between sampled lines; lines at multiples of `step` are excluded from the test split below.
step = params_dict['step']
split_test = params_dict['split_test']  # NOTE(review): not referenced anywhere in this notebook
path_data = 'data'  # root folder containing train/train_seismic.npy and train/train_labels.npy

In [9]:
# Extent of the first two axes of the training volume. This is hard-coded and must match
# train_seismic.npy, because the boolean masks built from it index those axes directly.
im_shape = (401, 701)
iline, xline = im_shape
# Line indices at multiples of `step` along each axis; these lines are removed from the test split.
i_locations, x_locations = (np.arange(0, n_lines, step) for n_lines in im_shape)

In [10]:
# Boolean masks (plain Python lists) selecting every line that is NOT at a multiple of `step`.
test_i_mask = (~np.isin(np.arange(iline), i_locations)).tolist()
test_x_mask = (~np.isin(np.arange(xline), x_locations)).tolist()

In [11]:
# Full labelled training volume; the held-out test region is carved out of it below.
train_dir = pjoin(path_data, 'train')
seismic = np.load(pjoin(train_dir, 'train_seismic.npy'))
labels = np.load(pjoin(train_dir, 'train_labels.npy'))

In [12]:
# Distinct class ids in the label volume and the number of elements carrying each id.
uniq, cnts = np.unique(labels, return_counts=True)

In [13]:
# Class frequencies as percentages of all labelled elements, rounded to 2 decimals.
prcnts = (100 * cnts / np.sum(cnts)).round(2)

In [14]:
# Print all percentages on one line, each followed by a tab (no trailing newline).
print(''.join(f'{p}%\t' for p in prcnts), end='')

28.09%	11.89%	48.59%	6.64%	3.28%	1.51%	

In [15]:
# Display class id -> element count.
{class_id: count for class_id, count in zip(uniq, cnts)}

{0: 20137839, 1: 8519666, 2: 34831122, 3: 4760778, 4: 2350150, 5: 1081200}

In [16]:
# Keep only the held-out lines along both of the first two axes. np.ix_ converts the two boolean
# masks into an open index mesh, so both axes are selected in one indexing step. This avoids
# first materialising a copy filtered along axis 0 only.
test_index_mesh = np.ix_(test_i_mask, test_x_mask)
test_seismic_ix = seismic[test_index_mesh]
test_labels_ix = labels[test_index_mesh]

In [17]:
# The full volumes are no longer needed; only the test sub-volumes are used from here on.
del seismic, labels

In [18]:
# Shape of the held-out volume; (392, 686, 255) for the run whose output is shown below.
test_data_shape = test_seismic_ix.shape

In [19]:
# Sanity check of the held-out volume shape.
print(test_data_shape)

(392, 686, 255)


In [20]:
# U-Net with a ResNet-18 encoder and single-channel input (the seismic sections).
# encoder_weights=None: smp.Unet otherwise defaults to downloading "imagenet" encoder weights.
# Every parameter and buffer is overwritten by the checkpoint loaded a few cells below
# (strict load, "All keys matched"), so that download is wasted work and fails offline.
model = smp.Unet(
    encoder_name="resnet18",
    encoder_weights=None,
    in_channels=1,
    classes=num_classes,  # number of segmentation classes; must match the checkpoint
)

In [21]:
# Run inference on the GPU when one is visible, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [22]:
# Move the model's parameters and buffers to the selected device.
model = model.to(device)

In [23]:
# Restore this run's best checkpoint. map_location remaps tensors onto `device`, so a checkpoint
# saved from a GPU also loads on a CPU-only host. torch.load unpickles the file: only load
# checkpoints you trust.
checkpoint_path = pjoin('models', model_name, f'best_0_{model_name}.pth')
load_result = model.load_state_dict(torch.load(checkpoint_path, map_location=device))
# Inference only: switch BatchNorm/Dropout to evaluation behaviour. test_i_sample/test_x_sample
# live in utils (not visible here) and may not do this themselves; the call is harmless if they do.
model.eval()
load_result

<All keys matched successfully>

In [24]:
# Display the held-out volume shape again before allocating the prediction buffers.
test_data_shape

(392, 686, 255)

In [25]:
# Per-element class scores from the axis-0 (i-direction) sweep: test_data_shape + (num_classes,).
# NOTE(review): default dtype is float64, about 3.3 GB for the (392, 686, 255, 6) shape shown below.
pred_labels_cat_i = np.zeros((*test_data_shape, num_classes))

In [26]:
# Predict every slice along axis 0 with the i-direction sampler and store its class scores.
for i in tqdm(range(test_data_shape[0])):
    im = test_seismic_ix[i]
    lbl = to_categorical(test_labels_ix[i], num_classes=num_classes)
    # test_i_sample (from utils) returns a 3-tuple; its third element is stored as the
    # prediction and must broadcast to (n_x, depth, num_classes).
    _, lbl_cat, pred_lbl_cat = test_i_sample(model, im, lbl, False)
    pred_labels_cat_i[i] = pred_lbl_cat

100%|██████████| 392/392 [06:22<00:00,  1.02it/s]


In [27]:
# Cache the i-direction scores so the slow inference loop can be replaced by np.load on re-runs.
pred_i_path = pjoin('models', model_name, 'pred_labels_cat_i.npy')
np.save(pred_i_path, pred_labels_cat_i)

In [28]:
# Per-element class scores from the axis-1 (x-direction) sweep: test_data_shape + (num_classes,).
# NOTE(review): default dtype is float64, about 3.3 GB for the (392, 686, 255, 6) shape.
pred_labels_cat_x = np.zeros((*test_data_shape, num_classes))

In [29]:
# Predict every slice along axis 1 with the x-direction sampler and store its class scores.
for i in tqdm(range(test_data_shape[1])):
    im = test_seismic_ix[:, i]
    lbl = to_categorical(test_labels_ix[:, i], num_classes=num_classes)
    # test_x_sample (from utils) returns a 3-tuple; its third element is stored as the
    # prediction and must broadcast to (n_i, depth, num_classes).
    _, lbl_cat, pred_lbl_cat = test_x_sample(model, im, lbl, False)
    pred_labels_cat_x[:, i] = pred_lbl_cat

100%|██████████| 686/686 [05:38<00:00,  2.03it/s]


In [30]:
# Cache the x-direction scores so the slow inference loop can be replaced by np.load on re-runs.
pred_x_path = pjoin('models', model_name, 'pred_labels_cat_x.npy')
np.save(pred_x_path, pred_labels_cat_x)

In [31]:
# Dice + cross-entropy loss pair and an IoU metric.
# NOTE(review): only the commented-out TensorBoard logging cells below reference these.
loss = [
    smp.utils.losses.DiceLoss(),
    nn.CrossEntropyLoss(),
]
metrics = [
    smp.utils.metrics.IoU(threshold=0.5),
]

In [32]:
# pred_labels_cat_i = np.load(pjoin('models', model_name, 'pred_labels_cat_i.npy'))
# pred_labels_cat_x = np.load(pjoin('models', model_name, 'pred_labels_cat_x.npy'))

In [33]:
# Sanity check: held-out volume shape plus a trailing class axis.
pred_labels_cat_i.shape

(392, 686, 255, 6)

In [34]:
from torch.utils.tensorboard import SummaryWriter

# TensorBoard writer under models/<model_name>/runs.
# NOTE(review): every add_scalar call is commented out below and the writer is never closed,
# so nothing is written through it.
runs_dir = pjoin('models', model_name, 'runs')
writer = SummaryWriter(runs_dir)

In [35]:
# for i in tqdm(range(test_data_shape[0])):
#     lbl = torch.from_numpy(np.expand_dims(np.moveaxis(to_categorical(test_labels_ix[i],num_classes=6), -1, 0), axis=0))
#     pred_i = torch.from_numpy(np.expand_dims(np.moveaxis(pred_labels_cat_i[i], -1, 0), axis=0))
#     pred_x = torch.from_numpy(np.expand_dims(np.moveaxis(pred_labels_cat_x[i], -1, 0), axis=0))

#     li = torch.round(loss[0](pred_i, lbl) + loss[1](pred_i, lbl), decimals=2)
#     mi = torch.round(metrics[0](pred_i, lbl), decimals=2)
#     lx = torch.round(loss[0](pred_x, lbl) + loss[1](pred_x, lbl), decimals=2)
#     mx = torch.round(metrics[0](pred_x, lbl), decimals=2)

#     writer.add_scalar('test/(i_algo) loss along i', li, i)
#     writer.add_scalar('test/(i_algo) metric along i', mi, i)
#     writer.add_scalar('test/(x_algo) loss along i', lx, i)
#     writer.add_scalar('test/(x_algo) metric along i', mx, i)

In [36]:
# for i in tqdm(range(test_data_shape[1])):
#     lbl = torch.from_numpy(np.expand_dims(np.moveaxis(to_categorical(test_labels_ix[:,i],num_classes=6), -1, 0), axis=0))
#     pred_i = torch.from_numpy(np.expand_dims(np.moveaxis(pred_labels_cat_i[:,i], -1, 0), axis=0))
#     pred_x = torch.from_numpy(np.expand_dims(np.moveaxis(pred_labels_cat_x[:,i], -1, 0), axis=0))

#     li = torch.round(loss[0](pred_i, lbl) + loss[1](pred_i, lbl), decimals=2)
#     mi = torch.round(metrics[0](pred_i, lbl), decimals=2)
#     lx = torch.round(loss[0](pred_x, lbl) + loss[1](pred_x, lbl), decimals=2)
#     mx = torch.round(metrics[0](pred_x, lbl), decimals=2)

#     writer.add_scalar('test/(i_algo) loss along x', li, i)
#     writer.add_scalar('test/(i_algo) metric along x', mi, i)
#     writer.add_scalar('test/(x_algo) loss along x', lx, i)
#     writer.add_scalar('test/(x_algo) metric along x', mx, i)
    

In [37]:
# lbl_i = torch.from_numpy(np.moveaxis(to_categorical(test_labels_ix,num_classes=6), -1, 1))
# pred_i = torch.from_numpy(np.moveaxis(pred_labels_cat_i, -1, 1))
# lbl_x = torch.from_numpy(np.moveaxis(to_categorical(test_labels_ix,num_classes=6), (1, -1), (0, 1)))
# pred_x = torch.from_numpy(np.moveaxis(pred_labels_cat_x, (1, -1), (0, 1)))

# li = torch.round(loss[0](pred_i, lbl_i) + loss[1](pred_i, lbl_i), decimals=2)
# mi = torch.round(metrics[0](pred_i, lbl_i), decimals=2)
# lx = torch.round(loss[0](pred_x, lbl_x) + loss[1](pred_x, lbl_x), decimals=2)
# mx = torch.round(metrics[0](pred_x, lbl_x), decimals=2)

# print(f'i_algo: loss={li}, metric={mi}')
# print(f'x_algo: loss={lx}, metric={mx}')

In [38]:
# Labels and both score volumes must agree on the first three axes.
test_labels_ix.shape, pred_labels_cat_i.shape, pred_labels_cat_x.shape

((392, 686, 255), (392, 686, 255, 6), (392, 686, 255, 6))

In [39]:
# Hard class labels: index of the highest score along the trailing class axis.
pred_labels_i, pred_labels_x = (
    np.argmax(scores, axis=-1) for scores in (pred_labels_cat_i, pred_labels_cat_x)
)

In [40]:
# Both hard-label volumes should match test_labels_ix.shape.
pred_labels_i.shape, pred_labels_x.shape

((392, 686, 255), (392, 686, 255))

In [41]:
# Accumulating scorer from metrics.runningScore; each evaluation cell below resets it after use.
running_metrics_test = runningScore(num_classes)

In [42]:
# Aggregate metrics reported by runningScore.get_scores(), in display order.
metric_names = [
    'Pixel Acc',
    'Mean Class Acc',
    'Freq Weighted IoU',
    'Mean IoU',
]

In [43]:
# Score the i-direction predictions against the held-out labels.
running_metrics_test.update(test_labels_ix, pred_labels_i)
score, class_iu = running_metrics_test.get_scores()
running_metrics_test.reset()  # clear accumulated counts so the scorer can be reused

test_res_dict_i = {}

print(model_name)

# Aggregate metrics; runningScore keys them as '<name>: '.
for m in metric_names:
    _s = score[f'{m}: ']
    print(f'{m}: \t{_s}')
    test_res_dict_i[m] = _s

# Per-class accuracy and IoU, interleaved both in the printout and in the saved dict.
for i, _ca in enumerate(score['Class Accuracy: ']):
    _iu = class_iu[i]
    print(f'Class[{i}] Accuracy:\t{_ca}')
    test_res_dict_i[f'Class[{i}] Accuracy'] = _ca
    print(f'Class[{i}] IoU:     \t{_iu}')
    test_res_dict_i[f'Class[{i}] IoU'] = _iu

model_0407_1009
Pixel Acc: 	0.97503898060682
Mean Class Acc: 	0.9406687536408507
Freq Weighted IoU: 	0.9526618634792456
Mean IoU: 	0.890301082017122
Class[0] Accuracy:	0.9848572928154772
Class[0] IoU:     	0.973280509435295
Class[1] Accuracy:	0.9634516348766319
Class[1] IoU:     	0.9008507661921294
Class[2] Accuracy:	0.9860039206468098
Class[2] IoU:     	0.9783819197357092
Class[3] Accuracy:	0.9276622663291406
Class[3] IoU:     	0.8764050879081702
Class[4] Accuracy:	0.9172693426384813
Class[4] IoU:     	0.8029769541838312
Class[5] Accuracy:	0.8647680645385631
Class[5] IoU:     	0.8099112546475968


In [44]:
# Score the x-direction predictions against the held-out labels.
running_metrics_test.update(test_labels_ix, pred_labels_x)
score, class_iu = running_metrics_test.get_scores()
running_metrics_test.reset()  # clear accumulated counts so the scorer can be reused

test_res_dict_x = {}

print(model_name)

# Aggregate metrics; runningScore keys them as '<name>: '.
for m in metric_names:
    _s = score[f'{m}: ']
    print(f'{m}: \t{_s}')
    test_res_dict_x[m] = _s

# Per-class accuracy and IoU, interleaved both in the printout and in the saved dict.
for i, _ca in enumerate(score['Class Accuracy: ']):
    _iu = class_iu[i]
    print(f'Class[{i}] Accuracy:\t{_ca}')
    test_res_dict_x[f'Class[{i}] Accuracy'] = _ca
    print(f'Class[{i}] IoU:     \t{_iu}')
    test_res_dict_x[f'Class[{i}] IoU'] = _iu

model_0407_1009
Pixel Acc: 	0.9744187325075803
Mean Class Acc: 	0.9175861639222856
Freq Weighted IoU: 	0.9513348829192126
Mean IoU: 	0.8721514802682048
Class[0] Accuracy:	0.9885908059999469
Class[0] IoU:     	0.9807383544956961
Class[1] Accuracy:	0.9642730705434569
Class[1] IoU:     	0.898167216086171
Class[2] Accuracy:	0.985776121263594
Class[2] IoU:     	0.9739216316204488
Class[3] Accuracy:	0.9311549817075911
Class[3] IoU:     	0.8837516039708136
Class[4] Accuracy:	0.9354134908324562
Class[4] IoU:     	0.8219182850130268
Class[5] Accuracy:	0.7003085131866689
Class[5] IoU:     	0.6744117904230719


In [45]:
# Persist both single-direction result dicts next to the model.
for result_file, result_dict in (
    ('test_res_dict_i.json', test_res_dict_i),
    ('test_res_dict_x.json', test_res_dict_x),
):
    with open(pjoin('models', model_name, result_file), 'w') as f:
        json.dump(result_dict, f)

In [47]:
# Release the single-direction hard labels before building the ensemble.
del pred_labels_i, pred_labels_x

In [48]:
# Ensemble the two sweep directions by summing their per-class scores and taking the argmax.
# The former "*0.5" averaging is dropped: scaling by a positive power of two preserves ordering,
# so argmax is unchanged. Dropping it avoids one extra full-size float64 temporary
# (about 3.3 GB for the (392, 686, 255, 6) volumes).
pred_labels_ix = np.argmax(pred_labels_cat_i + pred_labels_cat_x, axis=-1)

In [49]:
# Score the i+x ensemble predictions against the held-out labels.
running_metrics_test.update(test_labels_ix, pred_labels_ix)
score, class_iu = running_metrics_test.get_scores()
running_metrics_test.reset()  # clear accumulated counts so the scorer can be reused

test_res_dict_ix = {}

print(model_name)

# Aggregate metrics; runningScore keys them as '<name>: '.
for m in metric_names:
    _s = score[f'{m}: ']
    print(f'{m}: \t{_s}')
    test_res_dict_ix[m] = _s

# Per-class accuracy and IoU, interleaved both in the printout and in the saved dict.
for i, _ca in enumerate(score['Class Accuracy: ']):
    _iu = class_iu[i]
    print(f'Class[{i}] Accuracy:\t{_ca}')
    test_res_dict_ix[f'Class[{i}] Accuracy'] = _ca
    print(f'Class[{i}] IoU:     \t{_iu}')
    test_res_dict_ix[f'Class[{i}] IoU'] = _iu

model_0407_1009
Pixel Acc: 	0.9769334555979826
Mean Class Acc: 	0.9292122315851173
Freq Weighted IoU: 	0.9559228976346963
Mean IoU: 	0.8872687127896667
Class[0] Accuracy:	0.9893585713170517
Class[0] IoU:     	0.9802952320898461
Class[1] Accuracy:	0.965960240992094
Class[1] IoU:     	0.907123297174404
Class[2] Accuracy:	0.9874406566697688
Class[2] IoU:     	0.9771343510075393
Class[3] Accuracy:	0.935316682874205
Class[3] IoU:     	0.8990066040231115
Class[4] Accuracy:	0.9402194691525891
Class[4] IoU:     	0.8281842110834328
Class[5] Accuracy:	0.7569777685049958
Class[5] IoU:     	0.7318685813596661


In [50]:
# Persist the ensemble result dict next to the model.
ix_results_path = pjoin('models', model_name, 'test_res_dict_ix.json')
with open(ix_results_path, 'w') as f:
    json.dump(test_res_dict_ix, f)

In [None]:
# from torchsummary import summary

In [None]:
# _s = summary(model, (1,256,256))

In [None]:
# x = torch.zeros(1, 1, 256, 256, dtype=torch.float, requires_grad=False)
# out = model(x)

In [None]:
# from torchviz import make_dot

In [None]:
# from torchview import draw_graph

In [None]:
# model_graph = draw_graph(model, input_size=(1,1,256,256), expand_nested=True)

In [None]:
# input_data = torch.randn(1, 1, 256, 256)  # Example input data

# # Run the model forward to get an output from which the graph is built
# output_data = model(input_data)

# # Build the graph
# dot_graph = make_dot(output_data, params=dict(model.named_parameters()))
# dot_graph.view()  # Opens a window showing the model graph

In [None]:
# make_dot(output_data).render("attached", format="png")

In [None]:
# model_graph = draw_graph(model, input_size=(1,1,256,256), expand_nested=True)

In [None]:
# model_graph.visual_graph

In [None]:
# dummy_input = torch.randn(1, 1, 256, 256)
# torch.onnx.export(model, dummy_input, 'model.onnx')

In [None]:
# import netron
# netron.start('model.onnx')