## 入力形式に対応したデータセットを作る
- case2 ~ case133までを学習用、case134以降を評価用に使う

参考にしたノートブック: https://github.com/akshaykvnit/pl-sem-seg/blob/master/pl_training.ipynb

In [None]:
# Number of training epochs handed to the Lightning Trainer (max_epochs).
EPOCHS = 1
# Model outputs are binarised with `output > THRESHOLD` when visualising
# predictions and when building the submission.
THRESHOLD = 0.99999
# Size passed to PIL's Image.resize for every scan and mask before it
# reaches the model.
INPUT_SIZE = (360, 360)

In [None]:
import os
import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)
from glob import glob

import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
from torch.nn import functional as F
from torchvision.models.segmentation import lraspp_mobilenet_v3_large, fcn_resnet50

import pytorch_lightning as pl

from PIL import Image, ImageOps
import matplotlib.pyplot as plt

In [None]:
def get_file_id(file_name):
    """Split a scan path into its (case, day, slice) identifiers.

    The path is expected to end in
    ``<case>_<day>/scans/slice_<number>_... .png``. Only those trailing
    components are inspected, so underscores in any parent directory
    (e.g. ``/data/my_input/...``) do not disturb the result.

    Args:
        file_name: path to a single scan image.

    Returns:
        Tuple ``(case, day, slice_id)`` of strings, e.g.
        ``("case101", "day20", "slice_0001")``.

    Raises:
        ValueError: if the ``<case>_<day>`` directory name has no underscore.
        IndexError: if the file name has no ``_``-separated slice number.
    """
    scans_dir, slice_file = os.path.split(file_name)
    # The directory that contains "scans" is named "<case>_<day>".
    case_day_dir = os.path.basename(os.path.dirname(scans_dir))
    case, day = case_day_dir.split("_")[:2]
    # File names start with "slice_<number>_"; keep only the number.
    slice_id = "slice_" + slice_file.split("_")[1]
    return case, day, slice_id

def get_file_dict(file_names):
    """Build a lookup from '<case>_<day>_<slice>' ids to file paths.

    Args:
        file_names: iterable of scan paths understood by get_file_id.

    Returns:
        dict mapping each id string to its path; when two paths produce
        the same id, the later one wins.
    """
    return {
        "{}_{}_{}".format(*get_file_id(path)): path
        for path in file_names
    }


In [None]:
def rle3_encode(img):
    '''
    Run-length encode the first three channels of a channel-first mask.

    img: numpy array indexed as img[channel], 1 - mask, 0 - background
    Returns a list of three strings, each "start length start length ..."
    with 1-based starts over the flattened channel ('' for an empty channel)
    '''
    encoded = []
    for channel_idx in range(3):
        # Pad with background on both sides so every run has a start and an end.
        flat = np.concatenate([[0], img[channel_idx].flatten(), [0]])
        change_points = np.flatnonzero(flat[1:] != flat[:-1]) + 1
        # Odd positions hold run ends; turn them into run lengths.
        change_points[1::2] -= change_points[::2]
        encoded.append(' '.join(map(str, change_points)))
    return encoded

def rle3_decode(mask_rle3, shape, input_size):
    '''
    Decode one run-length string per class into a single resized mask image.

    mask_rle3: list of run-length strings ("start length start length ...",
        1-based starts over the row-major flattened slice), one per class.
        A missing annotation (NaN, i.e. str(value) == 'nan') decodes to an
        all-background channel.
    shape: (height,width) of the slice the run-lengths refer to
    input_size: size handed to PIL's Image.resize, i.e. (width, height)
    Returns numpy array with the classes on the last axis,
    1 - mask, 0 - background
    '''
    height, width = shape
    channels = []
    for mask_rle in mask_rle3:
        flat = np.zeros(height * width, dtype=np.uint8)
        mask_rle = str(mask_rle)
        if mask_rle != 'nan':
            tokens = mask_rle.split()
            starts = np.asarray(tokens[0::2], dtype=int) - 1  # RLE starts are 1-based
            lengths = np.asarray(tokens[1::2], dtype=int)
            for lo, hi in zip(starts, starts + lengths):
                flat[lo:hi] = 1
        channels.append(flat.reshape(height, width))

    # Stack channel-last so PIL sees an ordinary (height, width, classes) image.
    # A transpose + mirror + rotate(90) round trip gives the same pixels only
    # when height == width; reshaping directly is also right for non-square
    # slices.
    stacked = np.stack(channels, axis=-1)
    res_img = Image.fromarray(stacked).resize(input_size)

    return np.asarray(res_img)

In [None]:
def create_dataset(pd_table, input_size, mode):
    """Collect per-slice records (file path + RLE strings) for one split.

    The table is consumed three rows at a time; each triple must share one
    id and list the classes in the order large_bowel, small_bowel, stomach.
    File paths are looked up in the module-level ``file_dict``.

    Args:
        pd_table: DataFrame with 'id', 'class' and 'segmentation' columns.
        input_size: accepted for interface compatibility; not used here.
        mode: 'train' keeps cases numbered below 134, 'test' keeps cases
            numbered 134 and above; any other value keeps nothing.

    Returns:
        List of dicts with keys "file_name" and "mask" (the three
        segmentation values of the slice).
    """
    expected_classes = ['large_bowel', 'small_bowel', 'stomach']
    wants_train = mode == 'train'
    wants_test = mode == 'test'

    records = []
    # Walk the CSV rows in order, one slice (three class rows) per step.
    for start in range(0, len(pd_table), 3):
        rows = pd_table[start:start + 3]
        assert(len(rows['id'].unique())==1)
        assert(list(rows['class']) == expected_classes)

        file_id = list(rows['id'])[0]
        file_name = file_dict[file_id]
        # Ids start with "case<number>_"; strip the "case" prefix.
        case_id = int(file_id.split('_')[0][4:])
        in_train_range = case_id < 134

        if (wants_train and in_train_range) or (wants_test and not in_train_range):
            records.append({
                "file_name": file_name,
                "mask": list(rows['segmentation']),
            })

    return records



class MyDataset(Dataset):
    """Dataset yielding (image, mask) float32 array pairs for one split.

    Args:
        pd_table: annotation table forwarded to create_dataset.
        input_size: size handed to PIL's Image.resize for scans and masks.
        mode: split selector forwarded to create_dataset ('train' or 'test').
    """

    def __init__(self, pd_table, input_size, mode='train'):
        super().__init__()
        self.pd_table = pd_table
        self.input_size = input_size
        self.mode = mode
        self.data = create_dataset(pd_table, input_size, mode)
        self.len = len(self.data)

    def __len__(self):
        return self.len

    @staticmethod
    def _parse_slice_shape(file_name):
        """Return (height, width) encoded in a scan file name.

        Only the base name is parsed, so underscores in parent directories
        cannot shift the fields. The name is taken to follow the
        competition's ``slice_<n>_<width>_<height>_...`` convention.
        """
        fields = os.path.basename(file_name).split('_')
        width, height = int(fields[2]), int(fields[3])
        return height, width

    def __getitem__(self, index):
        """Load the scan and decode its mask.

        Both arrays are returned channel-first with the two spatial axes
        swapped relative to the PIL image (``transpose(2,1,0)``), so image
        and mask stay aligned with each other.
        """
        record = self.data[index]
        file_name = record["file_name"]
        # rle3_decode documents its shape argument as (height, width).
        shape = self._parse_slice_shape(file_name)
        # assumes the scan is a single-band image, so a channel axis is added
        image = np.asarray(Image.open(file_name).resize(self.input_size)).astype(np.float32)[:, :, np.newaxis].transpose(2,1,0)
        mask = rle3_decode(record["mask"], shape, self.input_size).astype(np.float32).transpose(2,1,0)
        return image, mask

In [None]:
# Directory holding the competition data (train.csv, scans, sample submission).
root_dir = "/kaggle/input/uw-madison-gi-tract-image-segmentation/"
pd_table = pd.read_csv("{}train.csv".format(root_dir))

# Index every training scan by its '<case>_<day>_<slice>' id.
# create_dataset reads this module-level file_dict, so it must exist before
# any MyDataset is constructed.
file_names = glob("{}train/*/*/scans/*.png".format(root_dir))
file_dict = get_file_dict(file_names)

# mode defaults to 'train'; mode='test' selects the held-out cases.
train_dataset = MyDataset(pd_table, INPUT_SIZE)
valid_dataset = MyDataset(pd_table, INPUT_SIZE, mode='test')

# One loader worker per CPU reported by the OS.
n_cpu = os.cpu_count()
# NOTE(review): SegModel builds its own loaders in train_dataloader() /
# test_dataloader(); confirm whether these two are still needed.
train_dataloader = DataLoader(train_dataset, batch_size=16, shuffle=True, num_workers=n_cpu)
valid_dataloader = DataLoader(valid_dataset, batch_size=1, shuffle=False, num_workers=n_cpu)

In [None]:
# Visual sanity check: show one scan next to its mask for each split.
for banner, split_dataset, sample_idx in (
    ("---\ntrain", train_dataset, 81),
    ("---\nvalid", valid_dataset, 62),
):
    print(banner)
    image, mask = split_dataset[sample_idx]
    # The dataset returns arrays with axes reversed relative to the image;
    # transpose(2,1,0) puts the channel axis last again for imshow.
    for panel, panel_data in enumerate((image, mask), start=1):
        plt.subplot(1, 2, panel)
        plt.imshow(panel_data.transpose(2, 1, 0))
    plt.show()

In [None]:
class SegModel(pl.LightningModule):
    """FCN-ResNet50 producing three independent per-pixel masks.

    The network takes single-channel input and emits three channels, one
    per class. ``forward`` returns sigmoid probabilities; training
    optimises binary cross-entropy on the raw logits, treating every
    class as its own yes/no decision per pixel.

    The datasets are taken from the module-level ``train_dataset`` and
    ``valid_dataset`` at construction time.
    """

    def __init__(self):
        super().__init__()
        self.batch_size = 16
        self.learning_rate = 1e-4
        self.net = fcn_resnet50(num_classes=3, pretrained_backbone=False)
        self.sigmoid = nn.Sigmoid()
        # Swap the stem convolution so the backbone accepts 1-channel scans
        # instead of RGB.
        self.net.backbone.conv1 = nn.Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
        self.trainset = train_dataset
        self.testset = valid_dataset

    def forward(self, x):
        """Return per-class probabilities in [0, 1] for a batch of scans."""
        return self.sigmoid(self.net(x)['out'])

    def training_step(self, batch, batch_nb):
        """Compute the multi-label BCE loss for one batch.

        Softmax cross-entropy would force the three classes to compete and
        contributes nothing on pixels whose target is all zeros, so the
        loss is per-class binary cross-entropy. It is computed from logits
        for numerical stability rather than from the sigmoid output.
        """
        img, mask = batch
        img = img.float()
        mask = mask.float()
        logits = self.net(img)['out']
        loss_val = F.binary_cross_entropy_with_logits(logits, mask)
        # Expose the value under the name 'loss' so callbacks that monitor
        # 'loss' can find it.
        self.log('loss', loss_val)
        return {'loss': loss_val}

    def configure_optimizers(self):
        """Adam on the network parameters with cosine annealing (T_max=10)."""
        opt = torch.optim.Adam(self.net.parameters(), lr=self.learning_rate)
        sch = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=10)
        return [opt], [sch]

    def train_dataloader(self):
        return DataLoader(self.trainset, batch_size=self.batch_size, shuffle=True)

    def test_dataloader(self):
        return DataLoader(self.testset, batch_size=1, shuffle=False)

In [None]:
model = SegModel()
# No dirpath: checkpoints go to Lightning's default
# <default_root_dir>/lightning_logs/version_N/checkpoints/, which is where
# the following cells look for them.
checkpoint_callback = pl.callbacks.ModelCheckpoint(save_last = True, verbose = True, monitor = 'loss', mode = 'min')
# Callback instances are registered through `callbacks`; the
# `checkpoint_callback` Trainer argument is not meant to receive an instance.
trainer = pl.Trainer(gpus = 1, max_epochs= EPOCHS, default_root_dir = '/kaggle/working', callbacks = [checkpoint_callback])
trainer.fit(model)

In [None]:
# List any checkpoint files present under version_0 of lightning_logs.
glob('/kaggle/working/lightning_logs/version_0/checkpoints/*.ckpt')

In [None]:
model = SegModel()
# glob returns files in arbitrary order; sort so the chosen checkpoint is the
# same on every run (the lexicographically last file name).
ckpt_files = sorted(glob('/kaggle/working/lightning_logs/version_0/checkpoints/*.ckpt'))
if not ckpt_files:
    raise FileNotFoundError(
        "no checkpoint found under /kaggle/working/lightning_logs/version_0/checkpoints/"
    )
ckpt_file = ckpt_files[-1]
# Load onto the CPU regardless of the device the checkpoint was saved from.
checkpoint = torch.load(ckpt_file, map_location = 'cpu')
model.load_state_dict(checkpoint['state_dict'])
model.eval()

---

## 推論

In [None]:
# Four panels for one sample: input scan, raw model output, thresholded
# output, and the reference mask.
print("---\ntest")
image, mask = train_dataset[81]
plt.subplot(1,4,1)
plt.imshow(image.transpose(2,1,0))

# Inference only: skip building the autograd graph.
with torch.no_grad():
    predict = model(torch.tensor(image).unsqueeze(0))
# Drop the batch axis and move the channel axis last for imshow.
result = predict.cpu().detach().numpy().squeeze().transpose(2,1,0)
plt.subplot(1,4,2)
plt.imshow(result)

plt.subplot(1,4,3)
plt.imshow((result>THRESHOLD).astype(np.float32))

plt.subplot(1,4,4)
plt.imshow(mask.transpose(2,1,0))
plt.show()

## 提出用データ生成

In [None]:
sub_df = pd.read_csv('{}sample_submission.csv'.format(root_dir))
# An empty sample submission switches to debug mode, which runs on the
# training scans instead of the test scans.
debug = not len(sub_df)
scan_pattern = "{}train/*/*/scans/*.png" if debug else "{}test/*/*/scans/*.png"
test_fnames = glob(scan_pattern.format(root_dir))

In [None]:
# Predict every scan in test_fnames and write the run-length encoded masks
# to submission.csv.
pred_ids, pred_classes, pred_strings = [], [], []
with torch.no_grad():  # inference only
    for test_fname in test_fnames:
        case, day, slice_name = get_file_id(test_fname)
        file_id = "{}_{}_{}".format(case, day, slice_name)
        # Original slice size, parsed from the base name only
        # (slice_<n>_<width>_<height>_... per the competition's naming).
        name_fields = os.path.basename(test_fname).split('_')
        width, height = int(name_fields[2]), int(name_fields[3])

        image = np.asarray(Image.open(test_fname).resize(INPUT_SIZE)).astype(np.float32)[:, :, np.newaxis].transpose(2,1,0)
        predict = model(torch.tensor(image).unsqueeze(0))
        # Drop the batch axis and move channels last: (rows, cols, classes).
        probs = predict.cpu().numpy().squeeze(0).transpose(2,1,0)
        # Threshold BEFORE converting to uint8; casting probabilities in
        # [0, 1] to uint8 first would truncate almost all of them to 0.
        binary = np.ascontiguousarray((probs > THRESHOLD).astype(np.uint8))
        # Nearest-neighbour keeps the mask strictly 0/1 while restoring the
        # original slice size; PIL takes (width, height).
        resized = Image.fromarray(binary).resize((width, height), resample=Image.NEAREST)
        # rle3_encode flattens each channel row by row, so hand it
        # (classes, height, width).
        result = np.asarray(resized).astype(np.float32).transpose(2,0,1)
        rle_results = rle3_encode(result)
        for cls, rle_result in zip(['large_bowel', 'small_bowel', 'stomach'], rle_results):
            pred_ids.append(file_id)
            pred_classes.append(cls)
            pred_strings.append(rle_result)

pred_df = pd.DataFrame({
    "id":pred_ids,
    "class":pred_classes,
    "predicted":pred_strings
})

# Re-read the row template from root_dir (not a cwd-relative path) so the
# output keeps the template's row order.
if not debug:
    sub_df = pd.read_csv('{}sample_submission.csv'.format(root_dir))
    del sub_df['predicted']
else:
    # Debug mode: use the first 1000 slices (3 class rows each) of train.csv.
    sub_df = pd.read_csv('{}train.csv'.format(root_dir))[:1000*3]
    del sub_df['segmentation']

sub_df = sub_df.merge(pred_df, on=['id','class'])
sub_df.to_csv('submission.csv',index=False)
# `display` is provided by the notebook environment.
display(sub_df.head(5))