In [None]:
import os
import numpy as np

from PIL import Image
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms as tsfm
from tqdm.notebook import tqdm

In [None]:
class CFG:
    """Notebook-wide settings: data locations, loader parameters and compute device."""

    # --- where inputs are read from and where masked images are written ---
    train_csv_path = '../input/plant-pathology-2021-fgvc8/train.csv'
    train_imgs_dir = '../input/pp2021-train-images-resized/224_square_not_crop'
    save_path = "/kaggle/working/images"

    # --- run parameters ---
    seed = 77
    batch_size = 32
    num_workers = 2

    # Use the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [None]:
from torchvision import models

def convrelu(in_channels, out_channels, kernel, padding):
    """Return a two-layer ``nn.Sequential``: a 2-D convolution followed by an in-place ReLU.

    Args:
        in_channels: Number of channels the convolution consumes.
        out_channels: Number of channels the convolution produces.
        kernel: Kernel size forwarded to ``nn.Conv2d``.
        padding: Padding forwarded to ``nn.Conv2d``.
    """
    conv = nn.Conv2d(in_channels, out_channels, kernel, padding=padding)
    activation = nn.ReLU(inplace=True)
    return nn.Sequential(conv, activation)

class ResNetUNet(nn.Module):
    """U-Net style network that uses a torchvision ResNet-18 as the encoder.

    Encoder stages are slices of the ResNet-18 children; every stage output is
    passed through a 1x1 conv+ReLU and concatenated with the bilinearly
    upsampled decoder features. A full-resolution branch of two 3x3 convs on
    the raw input is merged in at the last decoder step.

    ``forward`` returns the output of ``conv_last`` with no activation applied,
    i.e. ``n_class`` channels of raw scores.

    NOTE: ``base_model`` and the ``layerN`` attributes reference the same
    ResNet sub-modules, and all of them are registered on this module, so the
    attribute names determine the keys of the state dict. Renaming them would
    break loading of existing checkpoints.
    """

    def __init__(self, n_class):
        """Build encoder slices, skip 1x1 convs, decoder convs and the output head.

        Args:
            n_class: Number of output channels produced by ``conv_last``.
        """
        super().__init__()

        # Backbone is created without pretrained weights (pretrained=False).
        self.base_model = models.resnet18(pretrained=False)
        # Plain Python list of the backbone's direct children, sliced below.
        self.base_layers = list(self.base_model.children())

        # Encoder stages (shape comments are relative to the network input x).
        # Slice meaning follows torchvision's ResNet child ordering - TODO confirm
        # if the torchvision version changes.
        self.layer0 = nn.Sequential(*self.base_layers[:3]) # size=(N, 64, x.H/2, x.W/2)
        self.layer0_1x1 = convrelu(64, 64, 1, 0)
        self.layer1 = nn.Sequential(*self.base_layers[3:5]) # size=(N, 64, x.H/4, x.W/4)
        self.layer1_1x1 = convrelu(64, 64, 1, 0)
        self.layer2 = self.base_layers[5]  # size=(N, 128, x.H/8, x.W/8)
        self.layer2_1x1 = convrelu(128, 128, 1, 0)
        self.layer3 = self.base_layers[6]  # size=(N, 256, x.H/16, x.W/16)
        self.layer3_1x1 = convrelu(256, 256, 1, 0)
        self.layer4 = self.base_layers[7]  # size=(N, 512, x.H/32, x.W/32)
        self.layer4_1x1 = convrelu(512, 512, 1, 0)

        # One shared x2 bilinear upsampler, reused at every decoder step.
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)

        # Decoder convs: input channels = skip channels + channels coming from below.
        self.conv_up3 = convrelu(256 + 512, 512, 3, 1)
        self.conv_up2 = convrelu(128 + 512, 256, 3, 1)
        self.conv_up1 = convrelu(64 + 256, 256, 3, 1)
        self.conv_up0 = convrelu(64 + 256, 128, 3, 1)

        # Full-resolution branch applied directly to the input, and the conv that
        # fuses it with the last decoder output (64 + 128 channels in).
        self.conv_original_size0 = convrelu(3, 64, 3, 1)
        self.conv_original_size1 = convrelu(64, 64, 3, 1)
        self.conv_original_size2 = convrelu(64 + 128, 64, 3, 1)

        # 1x1 conv mapping the 64 fused channels to n_class output channels.
        self.conv_last = nn.Conv2d(64, n_class, 1)
        # NOTE(review): this dropout layer is created but never called in forward().
        self.dropout = nn.Dropout(p=0.25)
        
    def forward(self, input):
        """Run the encoder, then decode with skip connections back to input resolution.

        Args:
            input: Batch fed to both the ResNet encoder and the 3-channel
                full-resolution branch. Presumably H and W must be multiples
                of 32 so the upsampled features match the skip tensors in
                ``torch.cat`` - TODO confirm.

        Returns:
            Output of ``conv_last``: ``n_class`` channels, no activation applied.
        """
        # Full-resolution features computed straight from the input.
        x_original = self.conv_original_size0(input)
        x_original = self.conv_original_size1(x_original)

        # Encoder: each stage consumes the previous stage's output.
        layer0 = self.layer0(input)
        layer1 = self.layer1(layer0)
        layer2 = self.layer2(layer1)
        layer3 = self.layer3(layer2)
        layer4 = self.layer4(layer3)

        # Decoder step 3: upsample the deepest features and fuse with layer3.
        layer4 = self.layer4_1x1(layer4)
        x = self.upsample(layer4)
        layer3 = self.layer3_1x1(layer3)
        x = torch.cat([x, layer3], dim=1)
        x = self.conv_up3(x)

        # Decoder step 2: fuse with layer2.
        x = self.upsample(x)
        layer2 = self.layer2_1x1(layer2)
        x = torch.cat([x, layer2], dim=1)
        x = self.conv_up2(x)

        # Decoder step 1: fuse with layer1.
        x = self.upsample(x)
        layer1 = self.layer1_1x1(layer1)
        x = torch.cat([x, layer1], dim=1)
        x = self.conv_up1(x)

        # Decoder step 0: fuse with layer0.
        x = self.upsample(x)
        layer0 = self.layer0_1x1(layer0)
        x = torch.cat([x, layer0], dim=1)
        x = self.conv_up0(x)

        # Final upsample, then fuse with the full-resolution branch.
        x = self.upsample(x)
        x = torch.cat([x, x_original], dim=1)
        x = self.conv_original_size2(x)

        # Raw scores; callers apply any activation themselves.
        out = self.conv_last(x)

        return out

In [None]:
# Build the segmentation network with a single output channel and place it on the configured device.
seg_model = ResNetUNet(n_class=1).to(CFG.device)

In [None]:
# map_location remaps the checkpoint's tensors onto CFG.device, so weights that
# were saved from a GPU still load when the notebook falls back to the CPU.
seg_model.load_state_dict(
    torch.load('../input/leafsegweights/best_val_weights.pth', map_location=CFG.device)
)

In [None]:
# Load the training metadata table; PlantDataset reads image file names from its first column.
data_csv = pd.read_csv(CFG.train_csv_path)

In [None]:
"""
Define dataset class
"""
class PlantDataset(Dataset):
    """Dataset yielding ``(image_tensor, image_name)`` pairs for the listed images.

    Args:
        csv_file: Table (indexed with ``.iloc``) whose first column holds the
            image file names.
        image_loc: Directory that contains those image files.
    """

    def __init__(self, csv_file, image_loc):
        self.csv_file = csv_file
        self.image_loc = image_loc
        # Build the PIL -> tensor transform once instead of once per item.
        self._to_tensor = tsfm.ToTensor()

    def __len__(self):
        """Return the number of rows in the table, i.e. the number of images."""
        return len(self.csv_file)

    def __getitem__(self, idx):
        """Load image ``idx`` as an RGB tensor and return it with its file name."""
        img_name = self.csv_file.iloc[idx, 0]
        img_path = os.path.join(self.image_loc, img_name)

        # The context manager releases the file handle deterministically rather
        # than leaving it to Pillow's lazy close / garbage collection.
        # convert() returns an independent image that remains valid afterwards.
        with Image.open(img_path) as raw_img:
            rgb_img = raw_img.convert('RGB')

        return self._to_tensor(rgb_img), img_name

In [None]:
# Dataset over every row of the metadata table, reading images from the resized-image directory.
ds = PlantDataset(csv_file=data_csv, image_loc=CFG.train_imgs_dir)

In [None]:
# Sequential (unshuffled) batches; worker count comes from CFG instead of a
# duplicated literal so the config class stays the single source of truth.
ds_dataloader = DataLoader(
    ds,
    batch_size=CFG.batch_size,
    shuffle=False,
    num_workers=CFG.num_workers,
)

In [None]:
def reverse_transform(inp):
    """Convert a channel-first tensor into a channel-last ``uint8`` NumPy image.

    The tensor is viewed as a NumPy array, its axes are reordered from
    (0, 1, 2) to (1, 2, 0), values are clipped to [0, 1], scaled by 255 and
    truncated to ``uint8``. The input tensor is not modified.
    """
    channel_last = np.transpose(inp.numpy(), (1, 2, 0))
    scaled = np.clip(channel_last, 0, 1) * 255
    return scaled.astype(np.uint8)

In [None]:
def make_dir(dir_path):
    """Create ``dir_path`` (including missing parents) if it does not exist yet.

    Uses ``exist_ok=True`` rather than a separate existence check, so there is
    no window in which another process can create the directory between the
    check and the creation. Calling it on an existing directory is a no-op.

    Raises:
        FileExistsError: If ``dir_path`` exists but is not a directory.
    """
    os.makedirs(dir_path, exist_ok=True)

In [None]:
import matplotlib.pyplot as plt
import imageio

# A pixel is kept when its sigmoid score, rescaled to 0-255 by reverse_transform,
# is strictly greater than this value.
MASK_THRESHOLD = 220

# The output directory is loop-invariant: create it once, not once per image.
make_dir(CFG.save_path)

seg_model.eval()
# Inference only: without autograd no graph is built or kept alive per batch.
with torch.no_grad():
    for img_batch, img_names in tqdm(ds_dataloader):
        img_batch = img_batch.to(CFG.device)
        seg_batch = torch.sigmoid(seg_model(img_batch))

        for img, seg, filename in zip(img_batch, seg_batch, img_names):
            # 0/1 mask kept as uint8 so the masked image below stays uint8.
            # An int64 product would have to be converted by imageio on write,
            # which can be lossy or rescale intensities depending on version.
            mask_np = (reverse_transform(seg.cpu()) > MASK_THRESHOLD).astype(np.uint8)
            img_np = reverse_transform(img.cpu())

            # Multiplying by a 0/1 mask cannot overflow uint8; the single mask
            # channel broadcasts across the image channels.
            prod_img = img_np * mask_np

            imageio.imwrite(os.path.join(CFG.save_path, filename), prod_img)

In [None]:
import shutil
# Zip the contents of CFG.save_path into leaf-segmented-224.zip (path relative to the working directory).
shutil.make_archive(base_name="leaf-segmented-224", format='zip', root_dir=CFG.save_path)

In [None]:
!rm -rf ./images/*.jpg