In [None]:
import torch
import torch.nn as nn

from sklearn.utils import shuffle
import matplotlib.pyplot as plt
import numpy as np
import cv2 as cv

In [None]:
# LOAD THE TRAINING IMAGES
def input_data(root='images/brain_tumor_dataset/', show=True):
    """Load the grayscale MRI images of both classes as one normalized tensor.

    Reads ``<root>yes/1 (i).jpg`` for i = 1..122, then ``<root>no/1 (i).jpg``
    for i = 1..81 (in that order). Each image is loaded as grayscale, resized
    to 572x572 with ``cv.INTER_AREA``, transposed, and normalized per image as
    ``(x - mean) / max``.

    Args:
        root: dataset directory prefix; must end with a path separator.
        show: if True (the default), display every preprocessed image and
            print its path while loading.

    Returns:
        float32 tensor of shape (N, 1, 572, 572).

    Raises:
        FileNotFoundError: if OpenCV cannot read one of the image files.
    """
    size = 572
    # (folder name, exclusive upper bound of the file index) for each class
    parts = (('yes', 123), ('no', 82))

    samples = []
    for name_part, n in parts:
        for i in range(1, n):
            # Load the image as single-channel grayscale (flag 0)
            s = root + name_part + '/1 (' + str(i) + ').jpg'
            img = cv.imread(s, 0)
            # cv.imread returns None instead of raising on a missing/unreadable file
            if img is None:
                raise FileNotFoundError(f'Could not read image: {s}')

            output = cv.resize(img, (size, size), interpolation=cv.INTER_AREA)
            # Transpose the 2-D array, then add batch and channel dims -> (1, 1, 572, 572)
            output = torch.from_numpy(np.ascontiguousarray(output.T)).float()[None, None]

            # Per-image normalization; an all-black image (max == 0) would yield NaN
            output_ = (output - output.mean()) / output.max()

            if show:
                plt.imshow(output_[0][0].T, cmap='gray')
                plt.show()
                print(s)
            samples.append(output_)

    # Concatenate once at the end: torch.cat inside the loop is quadratic.
    if not samples:
        return torch.zeros((0, 1, size, size))
    return torch.cat(samples)

data = input_data()

In [None]:
# Number of loaded images (size of the leading/batch dimension).
data.shape[0]

In [None]:
# LOAD THE TARGET MASKS
def mask_data(root='images/brain_tumor_dataset/', show=True):
    """Load the colour target masks of both classes as one normalized tensor.

    Reads ``<root>mask/1 (i).jpg`` for i = 1..122, then ``<root>no/1 (i).jpg``
    for i = 1..81 (in that order). Each file is loaded as a 3-channel colour
    image, resized to 388x388 with ``cv.INTER_AREA``, transposed, and
    normalized per image as ``(x - mean) / max``.

    Args:
        root: dataset directory prefix; must end with a path separator.
        show: if True (the default), display every preprocessed mask and
            print its path while loading.

    Returns:
        float32 tensor of shape (N, 3, 388, 388).

    Raises:
        FileNotFoundError: if OpenCV cannot read one of the image files.
    """
    size = 388
    # (folder name, exclusive upper bound of the file index) for each class.
    # NOTE(review): the second entry reads the 'no' folder -- the same files
    # input_data() uses as *inputs* -- so the targets for that class are the
    # MRI images themselves rather than masks. Confirm this is intended.
    parts = (('mask', 123), ('no', 82))

    samples = []
    for name_part, n in parts:
        for i in range(1, n):
            # Load the image as a 3-channel colour array (flag 1), i.e. rank 3
            s = root + name_part + '/1 (' + str(i) + ').jpg'
            img = cv.imread(s, 1)
            # cv.imread returns None instead of raising on a missing/unreadable file
            if img is None:
                raise FileNotFoundError(f'Could not read image: {s}')

            # Bring every mask to a common resolution
            output = cv.resize(img, (size, size), interpolation=cv.INTER_AREA)
            # Reverse the axes (channels first), then add a batch dim -> (1, 3, 388, 388)
            output = torch.from_numpy(np.ascontiguousarray(output.T)).float()[None]

            # Per-image normalization; an all-black image (max == 0) would yield NaN
            output_ = (output - output.mean()) / output.max()

            if show:
                plt.imshow(output_[0].T)
                plt.show()
                print(s)
            samples.append(output_)

    # Concatenate once at the end: torch.cat inside the loop is quadratic.
    if not samples:
        return torch.zeros((0, 3, size, size))
    return torch.cat(samples)

data_mask = mask_data()

In [None]:
# Number of loaded masks (size of the leading/batch dimension).
data_mask.shape[0]

In [None]:
# Shuffle images and masks with the same permutation so each pair stays aligned;
# a fixed random_state makes the order (and hence the mini-batches) reproducible.
data, data_mask = shuffle(data, data_mask, random_state=0)

In [None]:
def double_conv(in_c, out_c):
    """Build two (unpadded 3x3 Conv2d -> in-place ReLU) pairs as one module.

    The first conv maps ``in_c`` -> ``out_c`` channels, the second
    ``out_c`` -> ``out_c``. With no padding each conv shrinks H and W by 2,
    so the whole block shrinks them by 4.
    """
    layers = [
        nn.Conv2d(in_c, out_c, kernel_size=3),
        nn.ReLU(inplace=True),
        nn.Conv2d(out_c, out_c, kernel_size=3),
        nn.ReLU(inplace=True),
    ]
    return nn.Sequential(*layers)

In [None]:
def crop_img(tensor, target_tensor):
    """Center-crop ``tensor`` so its spatial size matches ``target_tensor``.

    Both arguments are indexed as 4-D ``(N, C, H, W)`` tensors; only dims 2
    (height) and 3 (width) are cropped, each independently, so non-square
    inputs and odd size differences produce exactly the target size. When
    the difference is odd, the extra row/column is dropped from the end.

    Args:
        tensor: tensor to crop (e.g. an encoder skip feature map).
        target_tensor: tensor whose dims 2 and 3 give the desired size.

    Returns:
        A view of ``tensor`` with shape ``(N, C, target_H, target_W)``.

    Raises:
        ValueError: if the target is larger than ``tensor`` in H or W.
    """
    target_h, target_w = target_tensor.size()[2:4]
    tensor_h, tensor_w = tensor.size()[2:4]
    if target_h > tensor_h or target_w > tensor_w:
        raise ValueError(
            f'cannot crop {tensor_h}x{tensor_w} to larger {target_h}x{target_w}')
    top = (tensor_h - target_h) // 2
    left = (tensor_w - target_w) // 2
    return tensor[:, :, top:top + target_h, left:left + target_w]

In [None]:
class UNet(nn.Module):
    """U-Net style encoder-decoder producing a dense per-pixel output.

    Encoder: four (double_conv -> 2x2 max-pool) stages with 64/128/256/512
    channels, then a 1024-channel double_conv bottleneck. Decoder: four
    (2x2 stride-2 transposed conv -> concat with the centre-cropped encoder
    map of the same level -> double_conv) stages, then a 1x1 conv. All 3x3
    convolutions are unpadded, so the output is spatially smaller than the
    input: a 572x572 input yields a 388x388 output.

    Args:
        in_channels: channels of the input image (default 1).
        out_channels: channels produced by the final 1x1 conv (default 3).
    """

    def __init__(self, in_channels=1, out_channels=3):
        super().__init__()

        self.max_pool_2x2 = nn.MaxPool2d(kernel_size=2, stride=2)
        # Encoder: each stage doubles the channel count.
        self.down_conv_1 = double_conv(in_channels, 64)
        self.down_conv_2 = double_conv(64, 128)
        self.down_conv_3 = double_conv(128, 256)
        self.down_conv_4 = double_conv(256, 512)
        self.down_conv_5 = double_conv(512, 1024)

        # Decoder: each transposed conv halves the channels and doubles H/W;
        # the following double_conv takes twice the channels because of the
        # skip-connection concatenation.
        self.up_trans_1 = nn.ConvTranspose2d(
            in_channels=1024, out_channels=512, kernel_size=2, stride=2)
        self.up_conv_1 = double_conv(1024, 512)

        self.up_trans_2 = nn.ConvTranspose2d(
            in_channels=512, out_channels=256, kernel_size=2, stride=2)
        self.up_conv_2 = double_conv(512, 256)

        self.up_trans_3 = nn.ConvTranspose2d(
            in_channels=256, out_channels=128, kernel_size=2, stride=2)
        self.up_conv_3 = double_conv(256, 128)

        self.up_trans_4 = nn.ConvTranspose2d(
            in_channels=128, out_channels=64, kernel_size=2, stride=2)
        self.up_conv_4 = double_conv(128, 64)

        # 1x1 conv mapping the 64 decoder channels to the output channels.
        self.out = nn.Conv2d(
            in_channels=64, out_channels=out_channels, kernel_size=1)

    def forward(self, image):
        """Run the network.

        Args:
            image: tensor of shape (N, in_channels, H, W), e.g. (N, 1, 572, 572)
                as used in this notebook.

        Returns:
            Tensor of shape (N, out_channels, H', W'); (N, 3, 388, 388) for a
            572x572 input with the default channels.
        """
        # encoder; x1, x3, x5, x7 are kept for the skip connections
        x1 = self.down_conv_1(image)
        x2 = self.max_pool_2x2(x1)
        x3 = self.down_conv_2(x2)
        x4 = self.max_pool_2x2(x3)
        x5 = self.down_conv_3(x4)
        x6 = self.max_pool_2x2(x5)
        x7 = self.down_conv_4(x6)
        x8 = self.max_pool_2x2(x7)
        x9 = self.down_conv_5(x8)

        # decoder: upsample, concat the cropped skip map on the channel dim, convolve
        x = self.up_trans_1(x9)
        y = crop_img(x7, x)
        x = self.up_conv_1(torch.cat([x, y], 1))

        x = self.up_trans_2(x)
        y = crop_img(x5, x)
        x = self.up_conv_2(torch.cat([x, y], 1))

        x = self.up_trans_3(x)
        y = crop_img(x3, x)
        x = self.up_conv_3(torch.cat([x, y], 1))

        x = self.up_trans_4(x)
        y = crop_img(x1, x)
        x = self.up_conv_4(torch.cat([x, y], 1))

        return self.out(x)

In [None]:
# Seed the torch RNG so the weight initialization is reproducible.
torch.manual_seed(0)
UNet_model = UNet()

In [None]:
# Pixel-wise mean-squared error between predicted and target mask tensors.
# (nn.BCEWithLogitsLoss is the alternative for binary targets.)
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(
    params=UNet_model.parameters(),
    lr=1e-3,
    weight_decay=1e-3,
)

In [None]:
# Number of passes over the data, and the per-mini-batch loss log.
epochs, history = 5, []

In [None]:
# Mini-batch training: one optimizer step per slice of `batch_size` samples.
batch_size = 20
for i in range(epochs):
    epoch_losses = []
    # Cover every sample, including a final partial batch.
    for j in range(0, len(data), batch_size):
        # Forward pass
        y = UNet_model(data[j:j + batch_size, :1, :572, :572])

        # Loss against the masks at the same indices
        loss = criterion(y, data_mask[j:j + batch_size, :3, :388, :388])
        history.append(loss.item())
        epoch_losses.append(loss.item())

        # Clear gradients left over from the previous step
        optimizer.zero_grad()

        # Backpropagate
        loss.backward()

        # Update the model weights
        optimizer.step()

    # Report the mean loss over the epoch rather than only the last batch's.
    epoch_loss = sum(epoch_losses) / len(epoch_losses) if epoch_losses else float('nan')
    print("Epoches_Step: ", i + 1, "\t", "Loss_Value:", epoch_loss)

In [None]:
# Training loss for every mini-batch, in the order they were processed.
fig, ax = plt.subplots(figsize=(6, 4))
ax.plot(history)
ax.set_title('Loss by mini_batches')
ax.set_ylabel('Loss')
ax.set_xlabel('mini_batches number')
ax.grid()
plt.show()

In [None]:
# Show the model's prediction for the first sample of the (shuffled) training data.
# Inference only: no autograd graph is needed, which saves memory.
with torch.no_grad():
    y = UNet_model(data[0:1, :1, :572, :572])
# NOTE(review): the x12 factor presumably brightens the output for display — confirm.
img = y.numpy() * 12
IMG = img[0].T  # reverse the axes: (3, 388, 388) -> channels-last for imshow

plt.subplots(figsize=(4, 4))
plt.imshow(IMG)
plt.show()

In [None]:
# Load and preprocess one MRI image with the same steps as input_data().
# NOTE(review): 'yes/1 (7).jpg' is within the range input_data() loads for
# training, so a prediction on it is not a held-out evaluation.
i = 7
name_part = 'yes'
# Load the image as single-channel grayscale (flag 0)
s = 'images/brain_tumor_dataset/' + str(name_part) + '/1 (' + str(i) + ').jpg'
img = cv.imread(s, 0)
# cv.imread returns None instead of raising on a missing/unreadable file
if img is None:
    raise FileNotFoundError(f'Could not read image: {s}')

x_new = 572
y_new = 572
dsize = (x_new, y_new)
output = cv.resize(img, dsize, interpolation=cv.INTER_AREA)
# Transpose the 2-D array, then add batch and channel dims -> (1, 1, 572, 572)
output = torch.from_numpy(np.ascontiguousarray(output.T)).float()[None, None]
mean = output.mean()
max_ = output.max()
output_ = (output - mean) / max_
plt.imshow(output_[0][0].T, cmap='gray')
plt.show()

In [None]:
# Predict and display the output for the single preprocessed image `output_`.
j = 0
# Inference only: no autograd graph is needed.
with torch.no_grad():
    y = UNet_model(output_[j:j + 1, :1, :572, :572])
# NOTE(review): the x12 factor presumably brightens the output for display — confirm.
img = y.numpy() * 12
IMG = img[0].T  # reverse the axes: channels-last for imshow

plt.subplots(figsize=(4, 4))
plt.imshow(IMG)
plt.show()

In [None]:
# Print the name and shape of every entry in the model's state_dict.
print("Model's state_dict:")
# Build the state_dict once instead of once per loop iteration.
for param_tensor, value in UNet_model.state_dict().items():
    print(param_tensor, "\t", value.size())

In [None]:
# Print every top-level entry of the optimizer's state_dict.
print("Optimizer's state_dict:")
# Build the state_dict once instead of once per loop iteration.
for var_name, value in optimizer.state_dict().items():
    print(var_name, "\t", value)

In [None]:
# Persist the model weights (state_dict only; the optimizer state is not saved).
PATH = 'UNet_model.pt'
torch.save(obj=UNet_model.state_dict(), f=PATH)