In [None]:
## Lesson 3 CamVid Tiramisu
%reload_ext autoreload
%autoreload 1
%matplotlib inline

import os

from fastai import *
from fastai.vision import *
from fastai.callbacks.hooks import *
#import matplotlib.pyplot as plt

# Dataset root (contains images/, masks/ and valid.txt). Set the DATA_DIR
# environment variable to use another location without editing the notebook.
path = Path(os.environ.get('DATA_DIR', '/mnt/c/School/Scripts/TestData/'))

In [None]:
from fastai.utils.show_install import *
# Diagnostic only (presumably prints library/hardware versions); nothing later
# in the notebook depends on this cell.
show_install()

In [None]:
# Dataset root contents (images/ and masks/ are used below).
path.ls()

In [None]:
# Images and masks live in sibling folders under `path`; get_y_fn relies on
# this layout and on matching file names to pair each image with its mask.
path_lbl = path/'masks'
path_img = path/'images'

In [None]:
# Quick look at the first few image files.
fnames = get_image_files(path_img)
fnames[:3]

In [None]:
# Mask files; each should share its file name with an image in images/.
lbl_names = get_image_files(path_lbl)
lbl_names[:3]

In [None]:
## Data

img_f = fnames[10]  # arbitrary sample; requires at least 11 files in images/
img = open_image(img_f)
img.show(figsize=(5,5))
#plt.show()

In [None]:
# The dataset root: get_y_fn builds mask paths as <this dir>/masks/<file name>.
img_f.parent.parent

In [None]:
# TODO Continue from here...
def get_y_fn(x):
    """Return the segmentation-mask path for image path ``x``.

    Masks live in a ``masks`` folder that is a sibling of the image's folder
    and share the image's file name, e.g. ``<root>/images/a.png`` ->
    ``<root>/masks/a.png``.
    """
    # No print here: this is called once per item while labelling the
    # dataset, so a debug print floods the notebook output.
    return x.parent.parent / 'masks' / x.name

codes = array(['Belt', 'Meat', 'Bone', 'Metal'])  # class names; a mask pixel value is an index into this

mask = open_mask(get_y_fn(img_f))
mask.show(figsize=(5,5), alpha=1)

src_size = np.array(mask.shape[1:])
# Print the class ids present instead of dumping the whole mask tensor.
mask_ids = mask.data.unique()
print(src_size, mask_ids)
# Every pixel value must be a valid index into `codes`. The per-class weight
# tensors used by the losses have one entry per class, so an out-of-range id
# (e.g. 255 in an 8-bit mask) would otherwise only fail later, during training.
assert int(mask_ids.max()) < len(codes), \
    f"mask {get_y_fn(img_f)} has class ids {mask_ids.tolist()}, but only {len(codes)} codes"

In [None]:
## Datasets
# Batch size and training size (half the native mask resolution, see src_size).
bs,size = 4,src_size//2

# valid.txt (list of validation file names) sits in the dataset root. Passing
# path= reads path/'valid.txt' rather than resolving the name against the item
# list's own folder (images/), and avoids repeating the absolute dataset path.
src = (SegmentationItemList.from_folder(path_img)
       .split_by_fname_file('valid.txt', path=path)
       .label_from_func(get_y_fn, classes=codes))
# tfm_y=True applies the same spatial transforms to the masks as to the images.
data = (src.transform(get_transforms(), size=size, tfm_y=True)
        .databunch(bs=bs)
        .normalize(imagenet_stats))

In [None]:
# Visual check that masks still line up with their images after the
# transforms (tfm_y=True in the Datasets cell applies them to masks too).
data.show_batch(1, figsize=(10,7))

In [None]:
# Class name -> integer id, in `codes` order (the pixel values used in masks).
name2id = {v:k for k,v in enumerate(codes)}

def error_measure(input, target, bone_weight=10., metal_weight=100.):
    """Weighted pixel-error metric that penalises missed contaminants.

    Returns the mismatch rate between ``input.argmax(dim=1)`` and ``target``
    over non-Belt pixels, plus ``bone_weight`` times the rate over Bone pixels
    and ``metal_weight`` times the rate over Metal pixels. A term whose pixel
    set is empty in the batch contributes 0 instead of NaN (e.g. a batch that
    contains only Belt pixels).

    Args:
        input: model output with class scores along dim 1.
        target: class-id mask with a singleton dim 1, which is squeezed away.
        bone_weight, metal_weight: multipliers for the Bone / Metal error
            rates (defaults match the original 10 and 100).
    """
    target = target.squeeze(1)
    wrong = input.argmax(dim=1) != target  # computed once, reused by every term

    def _rate(region):
        # Fraction of mispredicted pixels inside `region`; mean() of an empty
        # selection is NaN and would poison the epoch average, so return 0.
        picked = wrong[region]
        return picked.float().mean() if picked.numel() else input.new_zeros(())

    error = _rate(target != name2id['Belt'])
    error = error + bone_weight * _rate(target == name2id['Bone'])
    error = error + metal_weight * _rate(target == name2id['Metal'])
    return error

In [None]:
       
weights = torch.FloatTensor([0., 1., 10., 100.]) # ['Belt', 'Meat', 'Bone', 'Metal'])
def forward(y_hat, y)->torch.Tensor:
    """Class-weighted negative log-likelihood, summed over all pixels.

    Taking ``argmax`` of the scores (as an earlier version did) has no
    gradient, so backprop cannot work. Taking ``log`` of a predicted class
    index is also not a log-probability. Instead, use log-softmax over the
    class dim and pick the log-probability of each pixel's target class.

    Args:
        y_hat: raw class scores with classes along dim 1.
        y: integer class-id mask with a singleton dim 1 (squeezed here).
    Returns:
        Scalar tensor: -sum(weights[y] * log p(y)).
    """
    y = y.squeeze(1)
    log_probs = torch.log_softmax(y_hat, dim=1)
    # log-probability assigned to the true class at every pixel
    target_log_probs = log_probs.gather(1, y.unsqueeze(1)).squeeze(1)
    # `weights` is created on the CPU; move it next to the batch before indexing
    w = weights.to(device=y_hat.device, dtype=log_probs.dtype)
    return -(w[y] * target_log_probs).sum()

weight = torch.FloatTensor([0., 1., 10., 100.])  # per class, in `codes` order
def customLossFunction(input, target):
    """Class-weighted cross-entropy, averaged over the weighted pixels.

    Equivalent to ``F.nll_loss(torch.log_softmax(input, 1), target, weight)``
    with ``reduction='mean'``.

    Args:
        input: raw class scores with classes along dim 1.
        target: integer class-id mask with a singleton dim 1 (squeezed here).
    Returns:
        Scalar loss tensor. It is 0 (still attached to the graph) when every
        target pixel has weight 0, instead of NaN.
    """
    target = target.squeeze(1)
    # `weight` is created on the CPU; CUDA inputs need it on their device.
    w = weight.to(device=input.device, dtype=input.dtype)
    # The weighted mean divides by the summed weights of the target pixels; an
    # all-Belt batch (weight 0) would give 0/0 = NaN and corrupt training.
    if w[target].sum() == 0:
        return (input * 0).sum()
    return F.cross_entropy(input, target, weight=w, reduction='mean')
    
class CustomLoss():
    "Weighted Cross Entropy Loss."
    def __init__(self, weight=None):
        # Per-class weights in `codes` order (Belt, Meat, Bone, Metal); Belt's
        # weight of 0 means Belt pixels add nothing to the loss.
        if weight is None:
            self.weight = torch.FloatTensor([0., 1., 10., 100.])
        else:
            self.weight = torch.as_tensor(weight, dtype=torch.float)
        self.func = nn.CrossEntropyLoss(self.weight)
        functools.update_wrapper(self, self.func)

    def __repr__(self): return f"{self.__class__.__name__} of {self.func}"
    @property
    def reduction(self): return self.func.reduction
    @reduction.setter
    def reduction(self, v): self.func.reduction = v

    def __call__(self, input:Tensor, target:Tensor, **kwargs)->Tensor:
        """Apply the weighted loss; `target` has a singleton dim 1 that is squeezed."""
        target = target.squeeze(1)
        # The weight buffer starts on the CPU and is not moved with the model;
        # bring it to the batch's device (no-op when already there).
        self.func.to(input.device)
        # With 'mean' reduction an all-zero-weight batch is 0/0 = NaN; return a
        # zero that is still attached to the graph instead.
        if self.func.reduction == 'mean' and self.func.weight[target].sum() == 0:
            return (input * 0).sum()
        return self.func(input, target, **kwargs)

In [None]:
# Weight decay passed to unet_learner below.
wd=1e-2

In [None]:
# Inputs are half the native mask size (size = src_size//2 in the Datasets cell).
# NOTE(review): an earlier note said the image "has 3 channels by default (the
# other two are set to 1)", checked via data.batch_stats; meaning unclear —
# confirm the channel layout of the source images.
learn = unet_learner(data, models.resnet34, metrics=error_measure, wd=wd, loss_func=customLossFunction)  # alternative: loss_func=CustomLoss()

# x, y = data.one_batch(DatasetType.Train, True, True)
# data.one_batch

In [None]:
# TODO (original): set the new loss_func here. Note that loss_func is already
# passed to unet_learner above; this shows it next to data.loss_func to compare.
learn.loss_func, data.loss_func

In [None]:
# Learning-rate range test; presumably plotted by learn.recorder.plot() below.
lr_find(learn)

In [None]:
# Picked from the lr_find plot. NOTE(review): this cell runs before the plot
# cell it depends on; re-check the plot and update the value when data/loss change.
lr=1e-4

In [None]:
learn.recorder.plot()

In [None]:
# Stage 1: 10 epochs, one-cycle schedule at `lr` (unfreeze() is only called later).
learn.fit_one_cycle(10, slice(lr), pct_start=0.8)

In [None]:
# Checkpoint after stage 1.
learn.save('stage-1-big')

In [None]:
# Restore the stage-1 checkpoint (trailing ';' suppresses the cell output).
learn.load('stage-1-big');

In [None]:
# Stage 2: unfreeze the model (fastai: all layer groups become trainable).
learn.unfreeze()

In [None]:
# slice(low, high): fastai discriminative learning rates, from 1e-6 for the
# earliest layers up to lr/10 for the last ones.
lrs = slice(1e-6,lr/10)

In [None]:
learn.fit_one_cycle(10, lrs)

In [None]:
# Checkpoint after stage 2.
learn.save('stage-2-big')

In [None]:
learn.load('stage-2-big');

In [None]:
#learn.show_results(rows=8, figsize=(20,20)) # This shows incorrect predictions?
# NOTE(review): open question above; ground truth vs prediction is plotted manually further down.

In [None]:
# Model outputs (class scores along dim 1) and targets; presumably for the
# validation set (get_preds default) — confirm.
preds,y = learn.get_preds()

In [None]:
def getContanimentInfo(prediction, class_names=None):
    """Summarise which contaminant classes occur in a class-id mask.

    Args:
        prediction: tensor of integer class ids (any shape), e.g. a ground
            truth mask or an argmax'ed prediction.
        class_names: class names indexed by id; defaults to the notebook's
            ``codes``.
    Returns:
        "Bone and Metal", "Metal", "Bone", "just Meat" or "Nothing", checked
        in that priority order.
    """
    names = codes if class_names is None else class_names
    # Built locally so that an explicit class_names is honoured.
    name2id = {name: idx for idx, name in enumerate(names)}
    present = set(prediction.unique().tolist())  # ids found, computed once

    has_bone = name2id['Bone'] in present
    has_metal = name2id['Metal'] in present
    if has_bone and has_metal:
        return "Bone and Metal"
    if has_metal:
        return "Metal"
    if has_bone:
        return "Bone"
    if name2id['Meat'] in present:
        return "just Meat"
    return "Nothing"

In [None]:
# Side-by-side ground truth vs. prediction for the first few items.
# squeeze(1) (not torch.squeeze) drops only the singleton class dim that the
# losses also squeeze; a bare squeeze would also drop the batch dim for 1 item.
targets = y.squeeze(1)
pred_masks = preds.argmax(dim=1)  # per-pixel class id; computed once, not per item
n_show = min(8, len(targets))     # avoid IndexError with fewer than 8 items

print("########################")
print("Ground truth/Predictions")
print("########################")
for i in range(n_show):
    truth, prediction = targets[i], pred_masks[i]
    print("There is {}. Found {}.".format(getContanimentInfo(truth), getContanimentInfo(prediction)))
    fig, (ax_truth, ax_pred) = plt.subplots(1, 2, figsize=(10, 20))
    ax_truth.imshow(truth)
    ax_truth.set_title('Ground truth')
    ax_pred.imshow(prediction)
    ax_pred.set_title('Prediction')
    plt.show()
    

In [None]:
# TODO: Test this stuff
# NOTE(review): this repeats the earlier learn.get_preds() call; the existing preds, y could be reused.
preds,y = learn.get_preds()
# NOTE(review): ClassificationInterpretation's constructor differs between fastai v1
# releases and may not accept `loss_class`. It also presumably expects one label per
# item rather than per-pixel masks — confirm against the installed version (see the
# show_install output) before relying on these plots.
interp = ClassificationInterpretation(data, preds, y, loss_class=nn.CrossEntropyLoss)
interp.plot_top_losses(9, figsize=(10,10))
interp.plot_confusion_matrix()