### UNet, resnet34

In [None]:
!pip install torchsummary

In [None]:
!pip install segmentation_models_pytorch

In [None]:
import torch
from torch import nn
from torchsummary import summary
from torchvision import models, transforms, datasets
from torch.utils.data import Dataset, DataLoader
import segmentation_models_pytorch as smp

import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import cv2
import os

from sklearn.model_selection import train_test_split

In [None]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'
device

In [None]:
# Load the CSV index of the dataset; later cells read its 'filename' and
# 'mask' columns.
data = pd.read_csv(
    '../input/makeup-lips-segmentation-28k-samples/set-lipstick-original/list.csv'
)
data

In [None]:
# Directory with the input photos (joined with the 'filename' column).
IMAGES_PATH = '../input/makeup-lips-segmentation-28k-samples/set-lipstick-original/720p/'
# Directory with the segmentation masks (joined with the 'mask' column).
MASKS_PATH = '../input/makeup-lips-segmentation-28k-samples/set-lipstick-original/mask/'

In [None]:
# File names present in the image and mask folders, in that order.
img, mask = (os.listdir(folder) for folder in (IMAGES_PATH, MASKS_PATH))

Нужно исключить изображения, для которых нет маски. Фрагмент кода взят из ноутбука «Lips Segmentation LinkNet PyTorch».

In [None]:
# Reduce every file name to the digits it contains, so that an image and its
# mask can be matched by that numeric id.
imgs_set = {
    ''.join(filter(str.isdigit, name)) for name in os.listdir(IMAGES_PATH)
}
masks_set = {
    ''.join(filter(str.isdigit, name)) for name in os.listdir(MASKS_PATH)
}

In [None]:
# (ids with an image but no mask, ids with a mask but no image)
len(imgs_set - masks_set), len(masks_set - imgs_set)

In [None]:
# File names of images whose numeric id has no matching mask.
not_mask = [f'image{i}.jpg' for i in imgs_set - masks_set]

# Keep only the rows that do have a mask and renumber them from zero.
data = data[~data['filename'].isin(not_mask)].reset_index(drop=True)

Создадим класс для кастомного датасета:

In [None]:
SIZE = 256


class Pixel_Perfect_Lips_Segmentation(Dataset):
    """Image/mask pairs for lip segmentation, resized to SIZE x SIZE.

    Every row of ``data`` must provide a ``filename`` entry (image file
    inside IMAGES_PATH) and a ``mask`` entry (mask file inside MASKS_PATH).
    Both files are read with OpenCV as 3-channel images, so the returned
    mask has three channels as well.
    """

    def __init__(self, data, preprocessing=None):
        """Store the sample table and the optional image preprocessing.

        Args:
            data: pandas DataFrame with 'filename' and 'mask' columns.
            preprocessing: optional callable applied to the resized RGB
                image (float array, H x W x 3, values 0..255). When it is
                omitted the image is simply scaled to [0, 1].
        """
        self.data = data
        self.data_len = len(self.data.index)
        self.preprocessing = preprocessing

    @staticmethod
    def _load_rgb(path):
        """Read ``path`` as an RGB float array resized to SIZE x SIZE.

        Raises:
            FileNotFoundError: if OpenCV cannot read the file.
        """
        image = cv2.imread(path)
        if image is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise FileNotFoundError(f'Could not read image: {path}')
        # imread yields BGR channel order; swap to RGB for plotting and for
        # encoders that expect RGB input.
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        image = cv2.resize(image, (SIZE, SIZE))
        return image.astype('float')

    def __getitem__(self, idx):
        """Return ``(image, mask)`` tensors in channels-first layout.

        The image is float32. The mask is always scaled to [0, 1] and is
        never passed through ``preprocessing``: that function normalizes
        input images and would corrupt the target values.
        """
        row = self.data.iloc[idx]
        img = self._load_rgb(os.path.join(IMAGES_PATH, row['filename']))
        mask = self._load_rgb(os.path.join(MASKS_PATH, row['mask']))

        if self.preprocessing:
            img = torch.as_tensor(self.preprocessing(img))
        else:
            img = torch.as_tensor(img) / 255.0
        mask = torch.as_tensor(mask) / 255.0

        # H x W x C -> C x H x W, the layout the model expects.
        img = img.permute(2, 0, 1)
        mask = mask.permute(2, 0, 1)
        return (img.float(), mask)

    def __len__(self):
        """Return the number of rows in the sample table."""
        return self.data_len

In [None]:
# Sanity check on one raw sample: print the tensor shapes, then show the
# image and its mask side by side.
dataset = Pixel_Perfect_Lips_Segmentation(data)
img, masks = dataset[777]
print(img.shape, masks.shape)

fig, ax = plt.subplots(1, 2, figsize=(15, 7))
for axis, tensor in zip(ax, (img, masks)):
    # imshow wants channels last, the dataset returns channels first.
    axis.imshow(tensor.permute(1, 2, 0))
plt.show()

In [None]:
# 80% of the rows go to the training set, 20% to the test set.
train, test = train_test_split(data, test_size=0.2, random_state=9)

# Renumber the rows of both subsets from zero.
train = train.reset_index(drop=True)
test = test.reset_index(drop=True)

In [None]:
# (rows, columns) of the training and the test table.
tuple(part.shape for part in (train, test))

Оборачиваем каждую выборку в наш кастомный датасет:

In [None]:
# Training rows wrapped in the custom dataset (no preprocessing callable).
train_dataset = Pixel_Perfect_Lips_Segmentation(data=train)

In [None]:
# Test rows wrapped in the custom dataset (no preprocessing callable).
test_dataset = Pixel_Perfect_Lips_Segmentation(data=test)

In [None]:
# Shuffled mini-batches of 24 samples for training.
train_data_loader = DataLoader(dataset=train_dataset, batch_size=24, shuffle=True)

In [None]:
# Fixed-order mini-batches of 8 samples for evaluation.
test_data_loader = DataLoader(dataset=test_dataset, batch_size=8, shuffle=False)

In [None]:
# Look at the first training batch only: shapes, value ranges of the first
# sample, and a plot of that sample with its mask.
for img, target in train_data_loader:
    print(img.shape, target.shape)
    for batch in (img, target):
        first = batch[0]
        print(first.min(), first.max())

    fig, ax = plt.subplots(1, 2, figsize=(15, 6))
    for axis, batch in zip(ax, (img, target)):
        # Channels first -> channels last for imshow.
        axis.imshow(batch[0].permute(1, 2, 0))
    break

In [None]:
# aux_params_1=dict(
#     pooling='max',             # one of 'avg', 'max'
#     dropout=0.5,               # dropout ratio, default is None
#     activation='sigmoid',      # activation function, default is None
#     classes=1,                 # define number of output labels
# )

In [None]:
# Build the model. Open question: how should aux_params be passed correctly?
# Every attempt fails with "TypeError: only integer tensors of a single
# element can be converted to an index".
BACKBONE = 'resnet34'
segmodel = smp.Unet(
    encoder_name=BACKBONE,
    classes=1,
    activation='sigmoid',
    # aux_params=aux_params_1,
).to(device)
# Input normalization function matching the encoder's ImageNet weights.
preprocess_input = smp.encoders.get_preprocessing_fn(BACKBONE, pretrained='imagenet')

После препроцессинга resnet34:

In [None]:
# Same preview as before, but with the encoder preprocessing applied; only
# the first channel of the image and of the mask is displayed.
dataset = Pixel_Perfect_Lips_Segmentation(data, preprocessing=preprocess_input)
img, masks = dataset[9]
print(img.shape, masks.shape)

fig, ax = plt.subplots(1, 2, figsize=(15, 7))
left, right = ax
left.imshow(img[0])
right.imshow(masks[0])
plt.show()

In [None]:
# Shapes and value ranges of the first sample in the first training batch.
# NOTE(review): train_data_loader wraps train_dataset, which was created
# without a preprocessing callable - confirm that this is intended.
for img, target in train_data_loader:
    print(img.shape, target.shape)
    for batch in (img, target):
        print(batch[0].min(), batch[0].max())
    break

In [None]:
# Dice loss for optimization, IoU as the reported metric.
criterion = smp.utils.losses.DiceLoss()
metrics = [smp.utils.metrics.IoU()]

# Adam over all model parameters.
optimizer = torch.optim.Adam(segmodel.parameters(), lr=1e-4)

In [None]:
# Epoch runners from segmentation_models_pytorch: each one iterates over a
# dataloader once. The training runner also gets the optimizer.
train_epoch = smp.utils.train.TrainEpoch(
    model=segmodel, loss=criterion, metrics=metrics,
    optimizer=optimizer, device=device, verbose=True,
)

valid_epoch = smp.utils.train.ValidEpoch(
    model=segmodel, loss=criterion, metrics=metrics,
    device=device, verbose=True,
)

In [None]:
# Train the model and keep the checkpoint with the best validation IoU.

max_score = 0

for i in range(1):
    print(f'Epoch: {i + 1}')
    train_logs = train_epoch.run(train_data_loader)
    valid_logs = valid_epoch.run(test_data_loader)

    # Per-epoch hook (checkpointing, learning-rate changes, ...): here the
    # whole model object is saved whenever the validation IoU improves.
    if valid_logs['iou_score'] > max_score:
        max_score = valid_logs['iou_score']
        torch.save(segmodel, './best_model.pth')
        print('Model saved!')

In [None]:
def get_orig(image):
    """Turn a channels-first image tensor into an array ready for imshow.

    The tensor is detached and moved to the CPU first, so tensors that
    require grad or live on the GPU are accepted as well.

    Args:
        image: tensor of shape (C, H, W).

    Returns:
        NumPy array of shape (H, W, C) with values clipped to [0, 1].
    """
    image = image.detach().cpu().permute(1, 2, 0).numpy()
    # Clip instead of rescaling: values outside [0, 1] are simply saturated.
    return np.clip(image, 0, 1)

In [None]:
# Show predictions next to the ground truth for the first test batches.
max_batches = 5          # batches 0..4 are visualized
samples_per_batch = 4    # at most this many samples from each batch

segmodel.eval()
with torch.no_grad():  # inference only, no autograd bookkeeping needed
    for batch_idx, (images, labels) in enumerate(test_data_loader):
        outputs = segmodel(images.to(device)).cpu().numpy()
        labels = labels.cpu().numpy()

        # Never index past the end of a short (last) batch.
        for j in range(min(samples_per_batch, len(images))):
            # One figure per sample: after plt.show() a figure cannot be
            # drawn on again.
            f, axarr = plt.subplots(1, 3, figsize=(15, 6))

            # The model has a single output channel; pick it explicitly
            # instead of squeeze(), which also drops a batch axis of size 1.
            axarr[0].imshow(outputs[j, 0])
            axarr[0].set_title('Guessed labels')

            # Channels first -> channels last for imshow.
            axarr[1].imshow(labels[j].transpose(1, 2, 0))
            axarr[1].set_title('Ground truth labels')

            original = get_orig(images[j].cpu())
            axarr[2].imshow(original)
            axarr[2].set_title('Original Images')
            plt.show()

        if batch_idx >= max_batches - 1:
            break