In [None]:
from fastai.vision.all import *
from skimage import measure
from skimage.transform import rescale, resize
from skimage.util import crop, montage
from skimage.morphology import label, square, dilation, watershed
from skimage.io import imsave


from tqdm import tqdm
from PIL import Image
import torch
import torch.nn.functional as F

In [None]:
# Load the metadata CSV for the cropped images and masks.
df = pd.read_csv(
    '../input/cropped-imges-and-masks-without-val-leak/crops_with_ships.csv'
)
df.head()

### Tratamiento del CSV
Como podemos observar, en el CSV aparece una entrada por barco; por lo tanto, cuando hay más de un barco en la misma imagen, esta aparece en el CSV tantas veces como barcos contenga. Vamos a agrupar todos los barcos de cada imagen en una sola entrada, agrupando por `ImageId` y uniendo los `EncodedPixels` con un espacio como separador. Además, para facilitar el entrenamiento posterior, añadiremos un nuevo campo llamado `has_ships` que vale 1 si la imagen contiene barcos y 0 si no los contiene. (El CSV de recortes que se carga arriba ya incluye la columna `has_ships`, que se usa en la celda siguiente.) Referencia: https://blog.softhints.com/python-detect-prevent-typeerror/

In [None]:
# The segmenter is trained only on crops that contain ships: keep every row whose
# `has_ships` flag is not False (this is the exact complement of the `== False` rows).
df = df[df['has_ships'] != False]
df.head()


In [None]:

def image_open(img_path):
    """Read the image file at ``img_path`` with PIL and return its pixels as a NumPy array."""
    pil_image = Image.open(img_path)
    pixels = np.array(pil_image)
    return pixels

def apply_mask(image, mask, value=170, channel=0):
    """Return a copy of ``image`` with one channel painted wherever ``mask == 1``.

    Args:
        image: array indexed as ``image[i, j, channel]`` (needs a trailing channel axis).
        mask: 2-D array; only entries exactly equal to 1 are painted. It may be smaller
            than ``image``, in which case only the top-left ``mask.shape`` region is used.
        value: value written into the selected channel (default 170).
        channel: index of the channel to paint (default 0).

    Returns:
        A new array; ``image`` itself is not modified.

    Raises:
        ValueError: if ``mask`` is not 2-D.
        IndexError: if ``mask`` is larger than ``image`` along either spatial axis.
    """
    mask = np.asarray(mask)
    rows, cols = mask.shape  # ValueError for non-2-D masks, as before
    image_masked = np.copy(image)
    # Vectorized replacement for the former per-pixel Python loop. The old fancy index
    # `[0, 0]` simply wrote channel 0 twice; `channel` makes that choice explicit.
    region = image_masked[:rows, :cols]  # view into the copy
    region[mask == 1, channel] = value
    return image_masked


In [None]:
# Inspect one raw training image to check its layout before building the pipeline.
mascara = image_open('../input/airbus-ship-detection/train_v2/00003e153.jpg')
# The former `mascara.reshape(3,768,768)` discarded its result (a no-op). reshape also
# would not move channels: it reinterprets the buffer and scrambles pixels. Assuming the
# array is (H, W, C), a channel-first view is np.transpose(mascara, (2, 0, 1)).
np.shape(mascara)

## Segmentación


In [None]:
class Dice(Metric):
    """Dice coefficient for segmentation, computed on the foreground.

    Every non-zero class index counts as foreground. For binary (0/1) targets this is
    the usual Dice score. With more classes (e.g. 0 = background, 1 = ship,
    2 = border), multiplying and adding raw class indices would overcount and let the
    score exceed 1, so predictions and targets are binarized first.

    Args:
        axis: class axis of `learn.pred` over which the argmax is taken (default 1).
    """
    def __init__(self, axis=1): self.axis = axis

    def reset(self): self.inter, self.union = 0, 0

    def accumulate(self, learn):
        pred, targ = flatten_check(learn.pred.argmax(dim=self.axis), learn.y)
        # Presumably done to sidestep fastai tensor-subclass dispatch between mask and prediction types.
        pred, targ = TensorBase(pred), TensorBase(targ)
        pred_fg, targ_fg = (pred != 0).float(), (targ != 0).float()
        self.inter += (pred_fg * targ_fg).sum().item()
        # |A| + |B| (sum of foreground pixels in both), not a set union.
        self.union += (pred_fg + targ_fg).sum().item()

    @property
    def value(self): return 2. * self.inter / self.union if self.union > 0 else None

def IoU(input, target, smooth=1.):
    """Smoothed foreground Intersection over Union (IoU) metric.

    Args:
        input: model logits; the predicted class is the argmax over dim 1
            (assumes shape (bs, n_classes, H, W) — TODO confirm).
        target: class-index mask; a singleton dim 1, if present, is squeezed away.
        smooth: added to numerator and denominator (default 1.) so empty masks
            give 1 instead of 0/0.

    Every non-zero class counts as foreground. Binarizing is a no-op for 0/1 targets.
    With more classes (e.g. 0 = background, 1 = ship, 2 = border), multiplying and
    adding raw class indices would overcount and push the score above 1.

    Returns:
        0-d tensor ``(intersection + smooth) / (union + smooth)``.
    """
    pred_fg = (input.argmax(dim=1) != 0).float()
    targ_fg = (target.squeeze(1) != 0).float()

    intersection = (pred_fg * targ_fg).sum()
    union = (pred_fg + targ_fg).sum() - intersection
    return (intersection + smooth) / (union + smooth)

In [None]:
import albumentations

class AlbumentationsTransform(RandTransform):
    """fastai transform that applies one albumentations pipeline on the training split
    (split index 0) and another on every other split.

    NOTE(review): only `encodes(img: PILImage)` is defined, so (via fastai type dispatch)
    other item types such as masks presumably pass through untouched. Confirm before using
    this in a segmentation pipeline, since spatial augs would then misalign image and mask.
    """
    split_idx, order = None, 2

    def __init__(self, train_aug, valid_aug):
        # fastai helper: stores `train_aug` and `valid_aug` as attributes.
        store_attr()

    def before_call(self, b, split_idx):
        # Remember which split the current call belongs to; 0 selects `train_aug`.
        self.idx = split_idx

    def encodes(self, img: PILImage):
        pipeline = self.train_aug if self.idx == 0 else self.valid_aug
        augmented = pipeline(image=np.array(img))
        return PILImage.create(augmented['image'])

In [None]:
def get_train_aug():
    """Training-time albumentations pipeline: resize to 256x256, then random transpose,
    vertical flip and shift/scale/rotate (each of the random steps with p=0.5)."""
    steps = [
        albumentations.Resize(256, 256),
        albumentations.Transpose(p=0.5),
        albumentations.VerticalFlip(p=0.5),
        albumentations.ShiftScaleRotate(p=0.5),
    ]
    return albumentations.Compose(steps)
def get_valid_aug():
    """Validation-time albumentations pipeline: a plain resize to 256x256."""
    resize_only = [albumentations.Resize(256, 256)]
    return albumentations.Compose(resize_only, p=1.)

In [None]:
# Pair the training and validation albumentations pipelines in one fastai item transform.
# NOTE(review): this object is not passed to the SegmentationDataLoaders built below.
item_tfms = AlbumentationsTransform(train_aug=get_train_aug(), valid_aug=get_valid_aug())

# Función de pérdidas


In [None]:
def get_y(r):
    """Return the class-index mask for the crop image at path ``r``.

    The mask is read from ``<masks dir>/<r.stem>.tif``. Channel 0 divided by 255 marks
    ship pixels (label 1), and channel 1 scaled by 2/255 marks border pixels (label 2),
    so the result uses 0 = background, 1 = ship, 2 = border. Returns a float array.

    NOTE(review): assumes both channels hold only 0/255 values and never overlap. An
    overlapping pixel would get label 3, outside a 3-class (n_out=3) setup.
    """
    mask_file = os.path.join('../input/cropped-imges-and-masks-without-val-leak/masks', '{0}.tif'.format(r.stem))
    mask_rgb = image_open(mask_file)
    ship_layer = mask_rgb[:, :, 0] / 255
    border_layer = 2 * (mask_rgb[:, :, 1] / 255)
    return ship_layer + border_layer
    
#cambiar item_tfms para hacer data augmentation
#dblock = DataBlock(blocks=(ImageBlock,MaskBlock), get_x=get_x, get_y=get_y, item_tfms=Resize(256))
#dsets = dblock.datasets(masks_df)


In [None]:
# One Path per crop image listed in the metadata CSV (iterating the column directly
# avoids building a Series per row as iterrows does).
fnames = [
    Path(os.path.join('../input/cropped-imges-and-masks-without-val-leak/crops', img_name))
    for img_name in tqdm(df['img_name'])
]



In [None]:
# Segmentation DataLoaders over the crop paths; `get_y` supplies each crop's label mask.
# "crops" becomes the learner's working path (checkpoints are expected under crops/models).
dls = SegmentationDataLoaders.from_label_func(
    "crops",
    bs=32,
    fnames=fnames,
    label_func=get_y,
    # NOTE(review): the previous `splitter=RandomSplitter()` does not appear to be a parameter of
    # from_label_func, which builds its own RandomSplitter(valid_pct, seed=seed), so it was
    # likely ignored. State the 20% split explicitly and seed it so it is reproducible.
    valid_pct=0.2,
    seed=42,
    batch_tfms=aug_transforms(mult=2, flip_vert=True, max_warp=0),
)

In [None]:
# Preview a batch of augmented samples (unique=True presumably repeats a single item so the
# augmentations can be compared — confirm against the fastai docs).
dls.show_batch(unique=True, max_n=8)

In [None]:
class FocalLossFlat(CrossEntropyLossFlat):
    """
    Flattened cross-entropy loss with a focusing parameter `gamma` (focal loss,
    Lin et al., https://arxiv.org/pdf/1708.02002.pdf).

    The paper's class-balancing factor, alpha, can be supplied through the `weight`
    argument forwarded to `nn.CrossEntropyLoss`. NOTE(review): with a `weight`, exp(-CE)
    equals p_t**w_t rather than p_t, so the modulating factor is then only approximate.
    """
    y_int = True

    @use_kwargs_dict(keep=True, weight=None, ignore_index=-100, reduction='mean')
    def __init__(self, *args, gamma=2, axis=-1, **kwargs):
        self.gamma = gamma
        # Keep the caller's reduction for later; the parent is forced to per-element
        # ('none') so the focal factor can be applied before reducing.
        self.reduce = kwargs.pop('reduction', 'mean')
        super().__init__(*args, reduction='none', axis=axis, **kwargs)

    def __call__(self, inp, targ, **kwargs):
        per_elem_ce = super().__call__(inp, targ, **kwargs)
        p_true = torch.exp(-per_elem_ce)  # probability of the target class when unweighted
        focal = (1 - p_true) ** self.gamma * per_elem_ce
        if self.reduce == 'mean':
            return focal.mean()
        if self.reduce == 'sum':
            return focal.sum()
        return focal

In [None]:

# U-Net with a ResNet-18 encoder and 3 output channels, one per label produced by get_y
# (0 = background, 1 = ship, 2 = border).
learn = unet_learner(
    dls,
    resnet18,
    n_out=3,
    metrics=[Dice()],
    lr=1e-2,
    loss_func=FocalLossFlat(axis=1),
)


In [None]:
# Transfer-learning schedule: 4 frozen epochs followed by 90 epochs (per fastai's fine_tune).
# NOTE(review): fine_tune takes its own `base_lr` argument; confirm whether the lr=1e-2
# given to unet_learner is actually used here or should be passed explicitly.
learn.fine_tune(90,
                freeze_epochs=4)

In [None]:
# NOTE(review): the checkpoint name says "resnet16" and "80epochs", but the learner uses
# resnet18 and is trained with fine_tune(90, freeze_epochs=4). Rename only if nothing
# loads this checkpoint by its current name.
learn.save(
    'Unet_resnet16_seg_cropped_256_bs32_3channels_wo_leakage_80epochs'
)

In [None]:
# Loss curves recorded by the learner during training.
recorder = learn.recorder
recorder.plot_loss()

In [None]:
# Side-by-side view of targets and predictions for up to 20 validation samples.
learn.show_results(
    max_n=20,
    figsize=(15, 15),
)

In [None]:
# List saved checkpoints. The DataLoaders were built with path "crops", so learn.save
# output is expected under crops/models.
!ls crops/models