In [None]:
import config
import os
os.chdir(config.PROJECT_ROOT_PATH)

import torch
from torchvision.utils import draw_segmentation_masks
import matplotlib.pyplot as plt
import torch.nn.functional as F
import torchvision.transforms.functional as tvF

from src.unet import dataset, loss, unet, unet2

import GPUtil

import numpy as np
import skimage.measure
from PIL import Image
import cv2
from pytesseract import pytesseract

# Tell pytesseract which Tesseract executable to launch.
pytesseract.tesseract_cmd = r'/usr/bin/tesseract'
# NOTE(review): `!export` runs in a throwaway subshell, so it cannot change the environment
# of this kernel, and the spaces inside the path look like a formatting accident.
# If TESSDATA_PREFIX is really required, set it through os.environ (or the %env magic)
# once the correct tessdata location is confirmed — TODO confirm.
!export TESSDATA_PREFIX= / usr / local / share /

import warnings
# Silence every warning category for the rest of the session.
warnings.filterwarnings('ignore')

In [None]:
# !sudo rmmod nvidia_uvm
# !sudo modprobe nvidia_uvm

In [None]:
# print tensors at full size (no truncation) :)
# torch.set_printoptions(profile="full")
# torch.set_printoptions(linewidth=500)

In [None]:
# Use the GPU when CUDA is usable, otherwise stay on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
device

#### Отрисовка результатов

In [None]:
def plot_many(imgs):
    """Show one image tensor, or a list of them, side by side in a single row.

    Anything that is not a ``list`` is treated as a single image. Each tensor is
    detached, converted to a PIL image and drawn without axis ticks or labels.
    """
    images = imgs if isinstance(imgs, list) else [imgs]
    _, axes = plt.subplots(ncols=len(images), squeeze=False)
    for ax, tensor in zip(axes[0], images):
        pil_image = tvF.to_pil_image(tensor.detach())
        ax.imshow(np.asarray(pil_image))
        ax.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])


def triple_pic(x):
    """Scale ``x`` by 255, stack three copies along dim 0 and cast to ``uint8``."""
    scaled = x * 255
    stacked = torch.cat((scaled, scaled, scaled))
    return stacked.to(dtype=torch.uint8)

#### Тренировочный цикл

In [None]:
def train_model(model, criterion, optimizer, dataloaders, n_epochs, epoch_size=100000, device=None):
    """Train ``model`` for ``n_epochs`` epochs, running a validation pass after each one.

    Args:
        model: network to optimise; it is updated in place and also returned.
        criterion: callable ``criterion(prediction, target)`` returning a scalar loss tensor.
        optimizer: optimizer already bound to the model's parameters.
        dataloaders: mapping with ``'train'`` and ``'valid'`` iterables yielding ``(x, y)`` batches.
        n_epochs: number of epochs to run.
        epoch_size: maximum number of training batches per epoch. Validation is capped at
            one fifth of that, but never fewer than one batch.
        device: device the batches are moved to before the forward pass. When omitted, the
            device of the model's first parameter is used (CPU for a parameter-free model).

    Returns:
        ``(results, model)``; ``results['train']`` and ``results['valid']`` hold the mean
        batch loss of every epoch (NaN when no batch was seen).
    """
    train_dataloader, valid_dataloader = dataloaders['train'], dataloaders['valid']

    if device is None:
        # Follow the model instead of relying on a notebook-level global.
        first_param = next(iter(model.parameters()), None)
        device = first_param.device if first_param is not None else torch.device('cpu')

    max_valid_batches = max(1, epoch_size // 5)

    results = {
        'train': [],
        'valid': []
    }

    for epoch in range(1, n_epochs + 1):
        GPUtil.showUtilization()
        print(f'EPOCH {epoch} STARTED')

        # training
        model.train()
        train_losses = []
        for batch_i, (x, y) in enumerate(train_dataloader):
            x, y = x.to(device), y.to(device)
            batch_loss = criterion(model(x), y)
            print('epoch:', epoch, 'train batch', batch_i, 'loss:', batch_loss.item())
            train_losses.append(batch_loss.item())
            optimizer.zero_grad()
            batch_loss.backward()
            optimizer.step()
            # batch_i is zero-based: stop once exactly `epoch_size` batches were processed.
            if batch_i + 1 >= epoch_size:
                break

        # validation (no gradients needed)
        model.eval()
        valid_losses = []
        with torch.no_grad():
            for batch_i, (x, y) in enumerate(valid_dataloader):
                x, y = x.to(device), y.to(device)
                batch_loss = criterion(model(x), y)
                print('epoch:', epoch, 'valid batch', batch_i, 'loss:', batch_loss.item())
                valid_losses.append(batch_loss.item())
                if batch_i + 1 >= max_valid_batches:
                    break

        # An empty loader yields NaN explicitly instead of a "mean of empty slice" warning.
        results['train'].append(np.mean(train_losses) if train_losses else np.float64('nan'))
        results['valid'].append(np.mean(valid_losses) if valid_losses else np.float64('nan'))
        print(results['train'][-1], results['valid'][-1])

    return results, model

### Создание и обучение модельки

In [None]:
RESIZE = 512
# -1 means "use every available image"
MAX_IMAGES_NUM = -1
N_EPOCHS = 10
BATCH_SIZE = 1
EPOCH_SIZE = 30

# Keyword arguments that are unpacked into dataset.create_dataloaders.
dataset_params = dict(
    resize=RESIZE,
    batch_size=BATCH_SIZE,
    datadir=config.DATA_PATH,
    classes_list=config.CLASSES_LIST,
    class_num_mapping=config.NUMS_CLASSES,
    orig_class_name=config.ORIG_CLASS_NAME,
    max_images=MAX_IMAGES_NUM,
)

In [None]:
# Release cached GPU memory and print the current utilisation before building anything.
torch.cuda.empty_cache()
GPUtil.showUtilization()

dataloaders = dataset.create_dataloaders(**dataset_params)

# The network is constructed with the number of entries in config.CLASSES_LIST.
model = unet2.UNet2(len(config.CLASSES_LIST))
model.to(device)

criterion = loss.JaccardLoss('multiclass')
optimizer = torch.optim.Adam(params=model.parameters(), lr=0.001)

In [None]:
# Train with the configured number of epochs and per-epoch batch cap.
results, model = train_model(
    model,
    criterion,
    optimizer,
    dataloaders,
    N_EPOCHS,
    EPOCH_SIZE,
)

# Free cached GPU memory again and report utilisation after training.
torch.cuda.empty_cache()
GPUtil.showUtilization()

In [None]:
# Make sure the checkpoint directory exists: on a fresh checkout there is none and
# torch.save would fail.
os.makedirs('model_states', exist_ok=True)
print(sorted(os.listdir('model_states')))
# Drop stale checkpoints. Same scope as `rm model_states/*`: visible regular files only.
for stale_name in os.listdir('model_states'):
    stale_path = os.path.join('model_states', stale_name)
    if not stale_name.startswith('.') and os.path.isfile(stale_path):
        os.remove(stale_path)
torch.save(model.state_dict(), 'model_states/model_state.st')
# To restore the weights later (path is relative to the project root):
# model = unet2.UNet2(3)
# # model.to(device)
# model.load_state_dict(torch.load('model_states/model_state.st',
#                                  map_location=torch.device('cpu')))
# model.eval()

### Тест

In [None]:
# Build a fresh set of dataloaders and keep only the 'train' one for the checks below.
# NOTE(review): this section is titled as a test but samples from the training split — confirm intent.
loader = dataset.create_dataloaders(**dataset_params)['train']

In [None]:
def process_prediction(t):
    """Turn raw network output into one boolean mask per class.

    Args:
        t: model output with the batch dimension first. Only the first sample is used and
            its leading dimension is treated as the class axis.

    Returns:
        Bool tensor with exactly as many channels as the class axis of ``t[0]``; a pixel is
        True in the channel of the class with the highest score.
    """
    logits = t[0]
    num_classes = logits.shape[0]
    # Softmax is monotonic, so it does not change which class wins the argmax.
    class_ids = logits.softmax(dim=0).argmax(dim=0)
    # Pass num_classes explicitly: by default one_hot sizes its output from the largest
    # index present, silently dropping the channels of classes that were never predicted.
    one_hot = F.one_hot(class_ids, num_classes=num_classes)
    # one_hot puts the class axis last; move it to the front.
    return one_hot.permute(2, 0, 1).bool()


def process_input(x):
    """Return the first element of ``x`` (for a batched tensor: its first sample)."""
    first_sample = x[0]
    return first_sample

def process_input_and_prediction(x, t):
    """Post-process an input batch and its prediction and move both results to the CPU.

    Returns:
        ``(process_input(x), process_prediction(t))``, each placed on ``'cpu'``.
    """
    image = process_input(x).to('cpu')
    masks = process_prediction(t).to('cpu')
    return image, masks



# Take a single batch from the loader; the targets are not needed here.
x, _ = next(iter(loader))
x = x.to('cpu')
model = model.to('cpu')
with torch.no_grad():
    t = model(x)

# Reduce both to the first sample and overlay the predicted masks on the picture.
x, t = process_input_and_prediction(x, t)
plot_many(
    draw_segmentation_masks(triple_pic(x), masks=t, alpha=0.7)
)


Маску каждого класса разбираем на связные компоненты.

In [None]:
# Label the connected components of every class mask and show the label maps in one row.
plt.figure(figsize=(13, 13))

n_classes = len(t)
# The loop uses its own names on purpose: the image tensor `x` is reused by later cells,
# so the loop must not rebind it.
for class_idx, class_mask in enumerate(t):
    component_labels = skimage.measure.label(class_mask.numpy())
    # The three-argument form keeps working with 10 or more classes, unlike the
    # three-digit integer shorthand.
    plt.subplot(1, n_classes, class_idx + 1)
    plt.imshow(component_labels, cmap='summer')
    plt.axis('off')

plt.tight_layout()
plt.show()

In [None]:
def get_masks_of_each_component(t, draw=False):
    """Split every class mask except class 0 into per-component masks.

    Args:
        t: tensor indexed by class along dim 0; each ``t[class_num]`` is labelled with
            ``skimage.measure.label``.
        draw: when True, print the label values of each class and plot its label map.

    Returns:
        Dict mapping class number to a list of boolean arrays, one per non-zero label.
        Each array is True everywhere *outside* that component.
    """
    per_class = t.to('cpu')

    masks = {}
    # Class 0 is skipped, exactly like label 0 inside every class.
    for class_num in range(1, per_class.shape[0]):
        labelled = skimage.measure.label(per_class[class_num].numpy())
        label_values = np.unique(labelled)
        masks[class_num] = [labelled != value for value in label_values if value != 0]

        if not draw:
            continue
        print(label_values)
        plt.figure(figsize=(13, 13))
        plt.subplot(131)
        plt.imshow(labelled, cmap='summer')
        plt.axis('off')
        plt.tight_layout()
        plt.show()

    return masks


# Component masks per class, computed without plotting.
masks = get_masks_of_each_component(t, draw=False)
# Hard-coded pick: the entry at index 3 of class 1's component list.
test_mask = masks[1][3]

In [None]:
def apply_mask_to_image(image, mask, draw=False):
    """Hide the masked pixels of ``image[0]`` and return the result as ``uint8``.

    Args:
        image: tensor whose first element along dim 0 is the picture to mask
            (presumably with values in [0, 1], since the result is scaled by 255 — confirm).
        mask: boolean array with the same shape as that picture; True marks pixels to hide
            (numpy masked-array convention).
        draw: when True, also show the masked picture with matplotlib.

    Returns:
        ``uint8`` numpy array: hidden pixels are filled with 1 before scaling (i.e. 255),
        all other pixels are the original values multiplied by 255.
    """
    picture = image[0].to('cpu').numpy()
    # filled() already returns a plain ndarray copy with hidden pixels replaced by 1.
    masked = np.ma.masked_array(picture, mask).filled(1)
    if draw:
        plt.imshow(masked, cmap='gray')
    return np.uint8(masked * 255)

# Mask the picture with the selected component mask and display the result in grayscale.
masked_image = apply_mask_to_image(x, test_mask, draw=False)
plt.imshow(masked_image, cmap='gray')

In [None]:
def extract_text_from_image(image, lang='eng'):
    """Run Tesseract OCR on an image array and return the recognised text.

    Args:
        image: array handed to ``Image.fromarray`` with mode ``'L'`` (8-bit grayscale).
        lang: language code passed to Tesseract; defaults to ``'eng'`` as before.

    Returns:
        The string produced by ``pytesseract.image_to_string``.
    """
    pil_image = Image.fromarray(image, 'L')
    return pytesseract.image_to_string(pil_image, lang=lang)


# text = extract_text_from_image(masked_image)
# print(text)
# OCR the whole (unmasked) picture: first element of x scaled by 255 and cast to uint8.
x_np = (x.numpy()[0] * 255).astype(np.uint8)
print(x_np.shape)
x_img = Image.fromarray(x_np, mode='L')
x_img.show()
text = pytesseract.image_to_string(x_img, lang='eng')
print(text)