## install packages

Kaggle notebook: https://www.kaggle.com/benayang/fastai-classifier-v2-563bc4

In [None]:
# Install the vendored pytorch_zoo package from a Kaggle dataset (offline install).
# NOTE(review): presumably required by hpacellseg imported below — confirm.
#!pip install iterative_stratification -q
#!pip install "../input/hpacellsegmentatormaster/HPA-Cell-Segmentation-master"
!pip install "../input/hpapytorchzoozip/pytorch_zoo-master"

In [None]:
# Make the vendored source copies of iterstrat and hpacellseg importable
# (their pip installs in the previous cell are commented out).
import sys
sys.path.append('../input/iterative-stratification/iterative-stratification-master')
sys.path.append('../input/hpacellsegmentatormaster/HPA-Cell-Segmentation-master')

In [None]:
from fastai.vision import *
from fastai.vision.all import *
from iterstrat.ml_stratifiers import MultilabelStratifiedKFold
import warnings
warnings.filterwarnings('ignore')

import hpacellseg.cellsegmentator as cellsegmentator
from hpacellseg.utils import label_cell, label_nuclei
import cv2

In [None]:
# Class names from https://www.kaggle.com/c/hpa-single-cell-image-classification/data
# class_names[i] is the name of label i; used below as target_names of the
# classification report and in plot titles.

specified_class_names = """0. Nucleoplasm
1. Nuclear membrane
2. Nucleoli
3. Nucleoli fibrillar center
4. Nuclear speckles
5. Nuclear bodies
6. Endoplasmic reticulum
7. Golgi apparatus
8. Intermediate filaments
9. Actin filaments
10. Microtubules
11. Mitotic spindle
12. Centrosome
13. Plasma membrane
14. Mitochondria
15. Aggresome
16. Cytosol
17. Vesicles and punctate cytosolic patterns
18. Negative"""

# Split only on the first '. ' and strip stray whitespace (the list used to carry a
# trailing space after "Actin filaments", which leaked into the class name).
class_names = [line.split('. ', 1)[1].strip() for line in specified_class_names.splitlines()]
class_names

## training setup

In [None]:
# Dataset root (RGB images live under rgb_train/) and its label table.
path = Path('../input') / 'hpa512x512dataset'
df = pd.read_csv(path / 'train.csv')

In [None]:
# --- experiment configuration ---
sample_size = 1   # fraction of the rows of train.csv to use (1 = all of them)
seed = 42
stats = ([0.07237246, 0.04476176, 0.07661699], [0.17179589, 0.10284516, 0.14199627])  # per-channel (mean, std)
item_tfms = RandomResizedCrop(448, min_scale=0.75, ratio=(1.,1.))
batch_tfms = [*aug_transforms(flip_vert=True, max_warp=0), Normalize.from_stats(*stats)]
bs = 32
lr = 3e-2
epochs = 5
cbs = None

# `path` and `df` are loaded in the previous cell; re-reading train.csv here was redundant.
labels = [str(i) for i in range(19)]
# One-hot encode the '|'-separated Label string: one int 0/1 column per class id.
# reindex guarantees all 19 columns exist even if a class never occurs.
df[labels] = df['Label'].str.get_dummies(sep='|').reindex(columns=labels, fill_value=0)
dfs = df.sample(frac=sample_size, random_state=seed).reset_index(drop=True)

In [None]:
# Assign every row to one of 5 folds, stratified over the multi-label one-hot targets.
y = dfs[labels].values
X = dfs['ID'].values
dfs['fold'] = np.nan

mskf = MultilabelStratifiedKFold(n_splits=5, shuffle=True, random_state=seed)
# Address the 'fold' column by name: on a re-run 'is_valid' already exists after it,
# so the previous `iloc[test_index, -1]` would have written into is_valid instead.
fold_col = dfs.columns.get_loc('fold')
for i, (_, test_index) in enumerate(mskf.split(X, y)):
    dfs.iloc[test_index, fold_col] = i

In [None]:
# sanity check: the stratified folds should be roughly the same size
dfs['fold'].value_counts().plot(kind='bar');

In [None]:
dfs['fold'] = dfs['fold'].astype('int')
# Fold 0 is the hold-out validation fold. Build the column in one vectorised assignment:
# the old chained assignment `dfs['is_valid'][mask] = True` raises SettingWithCopyWarning
# and is a silent no-op under pandas Copy-on-Write.
dfs['is_valid'] = dfs['fold'] == 0
dfs.head(3)

In [None]:
# IDs of images labelled with both class 0 and class 3 (not referenced elsewhere in this notebook)
cheat_id = dfs.loc[(dfs["0"] == 1) & (dfs["3"] == 1), "ID"].tolist()


In [None]:
def get_x(r):
    """Path of the RGB training image for dataframe row ``r`` (uses the global ``path``)."""
    image_file = f'{r["ID"]}.png'
    return path / "rgb_train" / image_file


def get_y(r):
    """Distinct label ids of row ``r``, parsed from its '|'-separated ``Label`` string."""
    return list({*r['Label'].split('|')})

In [None]:
# Multi-label image DataBlock: RGB images from rgb_train/, '|'-separated labels over the
# 19 class ids, train/valid split taken from the boolean `is_valid` column.
dblock = DataBlock(
    blocks=(ImageBlock, MultiCategoryBlock(vocab=labels)),
    get_x=get_x,
    get_y=get_y,
    splitter=ColSplitter(col='is_valid'),
    item_tfms=item_tfms,
    batch_tfms=batch_tfms,
)
dls = dblock.dataloaders(dfs, bs=bs)

In [None]:
# preview a batch from the training split
dls.train.show_batch()

In [None]:
# preview a batch from the validation split (fold 0)
dls.valid.show_batch()

In [None]:
# training rows (folds 1-4) and validation rows (fold 0)
train_df, valid_df = dfs[~dfs['is_valid']], dfs[dfs['is_valid']]

In [None]:
def oversample(frame: pd.DataFrame, random_state=None, column: str = 'Label'):
    """Balance ``frame`` by resampling every distinct ``column`` value up to the largest group.

    Groups are formed by exact values of ``column``; for 'Label' that is the whole
    '|'-joined label combination, not individual classes. Each group is topped up with
    rows drawn with replacement until it matches the most frequent group. The original
    rows come first, followed by the extra samples.

    Args:
        frame: rows to balance (not modified).
        random_state: int seed or ``np.random.RandomState`` for reproducible sampling;
            None (default) keeps the previous unseeded behaviour.
        column: column whose values define the groups.

    Returns:
        A new DataFrame containing ``frame`` plus the resampled rows.
    """
    if frame.empty:
        return frame.copy()
    # one RNG shared across groups so each group gets different (but reproducible) draws
    if isinstance(random_state, (int, np.integer)):
        rng = np.random.RandomState(random_state)
    else:
        rng = random_state
    max_size = frame[column].value_counts().max()
    extra = [
        group.sample(max_size - len(group), replace=True, random_state=rng)
        for _, group in frame.groupby(column)
    ]
    return pd.concat([frame, *extra])

In [None]:
# Oversample ONLY the training split, so duplicated images never end up in validation.
# Groups are whole 'Label' strings (label combinations), see oversample().
oversampled_train_df = oversample(train_df) # <------- TRAINING DATAFRAME ONLY
# oversampled_train_df['Label'].value_counts()

In [None]:
# oversampled training rows first, then the untouched validation rows
oversampled_df = pd.concat([oversampled_train_df, valid_df])

In [None]:
# train_df / valid_df were already split off dfs above and dfs has not changed since,
# so recomputing them here was redundant; just check that the split is still consistent.
assert len(train_df) + len(valid_df) == len(dfs), "train/valid split no longer matches dfs"

In [None]:
def _plot_label_counts(frame, title, fname):
    """Bar chart of how often each distinct 'Label' string occurs in ``frame``; saved to ``fname``."""
    vals, counts = np.unique(frame['Label'], return_counts=True)
    fig, ax = plt.subplots(figsize=(50, 5))
    ax.bar(vals, counts, width=0.75)
    ax.tick_params(axis='x', labelrotation=90)
    ax.set_xlabel('Label combination')
    ax.set_ylabel('Number of images')
    ax.set_title(title)
    fig.savefig(fname)
    return fig


# label-combination frequencies after and before oversampling the training split
_plot_label_counts(oversampled_train_df, "Oversampled", 'Oversampled_class_dist.png')
_plot_label_counts(train_df, "Original", 'Original_class_dist.png');


In [None]:
# Same DataBlock, built from the oversampled training rows plus the original validation rows.
# Pass bs explicitly so the batch size matches `dls` instead of falling back to fastai's default.
oversampled_dls = dblock.dataloaders(oversampled_df, bs=bs)

In [None]:
# preview a training batch drawn from the oversampled data
oversampled_dls.train.show_batch()

In [None]:
# oversampled_dls.valid.show_batch()

## Training (already done; the commented-out cells below are kept for reference)

In [None]:
#train_labels = list(dls.train_ds.items.Label)
#unique_train_labels, counts = np.unique(train_labels,return_counts=True)
#class_weights = 1./counts
#class_weight_dict = dict(zip(unique_train_labels, class_weights))
#weights = [class_weight_dict[x] for x in train_labels]
#total_len_oversample = int(dls.train_ds.c*np.max(counts))

In [None]:
# learn = cnn_learner(oversampled_dls, resnet18, metrics=[accuracy_multi, APScoreMulti()]).to_fp16()

In [None]:
# resnet34 (lr_min=0.025118863582611083, lr_steep=0.03981071710586548)
#learn.lr_find()

In [None]:
# cbs=[SaveModelCallback()]
# #learn.fit_one_cycle(2, cbs=cbs)
# learn.fine_tune(epochs, base_lr=2.5e-2, cbs = cbs)
# learn.export('./resnet18_2.5e-2_oversample (4_26_21).pkl')

### Load the trained learner for analysis

In [None]:
# Trained ResNet-18 (oversampled training data, lr 2.5e-2) exported on 4/27/21.
# load_learner deserialises a pickle, so only point it at trusted exports like this one.
learn = load_learner('../input/resnet18-oversample/resnet18_2.5e-2_oversample(4_27_21).pkl')

#### Get predictions (slow, so the results were saved to disk and are reloaded below)

In [None]:
# Recomputing validation predictions is slow, and `preds1` is not used below (the cached
# predictions are loaded from disk instead), so only recompute on demand.
# NOTE(review): the cached file also holds per-item losses (read as preds[0][3] below),
# which this call does not request (with_loss=False) — adjust before regenerating the cache.
RECOMPUTE_PREDS = False
preds1 = (
    learn.get_preds(dl=oversampled_dls.valid, with_input=False, with_loss=False,
                    with_decoded=True, act=None)
    if RECOMPUTE_PREDS else None
)

#### load predictions

In [None]:
from sklearn.metrics import multilabel_confusion_matrix, precision_score,f1_score
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns


In [None]:
# Cached validation predictions from an earlier get_preds run.
# allow_pickle is needed because the saved array holds Python objects (tensors);
# only load trusted, self-generated files this way.
preds = np.load(
    '../input/resnet18oversamplepredictions/resnet18_2.5e-2_oversample_predictions.npy',
    allow_pickle=True,
)

In [None]:
# Predicted scores and ground-truth targets from the cached predictions.
# NOTE(review): targets are read from preds[1][0] here, while the threshold/report cells
# below read preds[0][1] (targets) and preds[0][3] (losses); confirm both index paths
# refer to the same validation targets.
pred_raw = preds[0][0]
y_true = preds[1][0]

In [None]:
# fast heatmap of labels
# k = 3
# fig,ax = plt.subplots(figsize=(k*2,k*3))
# sns.heatmap(pred_raw.numpy())
# plt.savefig("raw_pred_scores.png",dpi=300)

### plotting various metrics

In [None]:
# multi-label accuracy at threshold 0.9; sigmoid=False treats pred_raw as probabilities
# already (NOTE(review): presumably the activation was applied when the predictions were made — confirm)
accuracy_multi(pred_raw, y_true, thresh=0.9,sigmoid=False)

In [None]:
# Sweep the decision threshold and track weighted precision and F1 over all classes.
thresholds = np.arange(0.05, 1, 0.05)
probs = pred_raw.numpy()
pscore, f1 = [], []
for thresh in thresholds:
    y_pred = (probs > thresh).astype(int)
    f1 += [f1_score(y_true, y_pred, average="weighted")]
    pscore += [precision_score(y_true, y_pred, average="weighted")]

for series, name in ((pscore, "precision"), (f1, "f1")):
    plt.plot(thresholds, series, label=name)
plt.xlabel("threshold")
plt.legend()

In [None]:
# prediction heatmaps for various thresholds
# for i in range(5,10):
#     thresh = 0.1*i
    
#     y_pred = (pred_raw.numpy() > thresh).astype(int)
#     fig,ax = plt.subplots(figsize=(k*2,k*3))
#     sns.heatmap(y_pred,cbar=False)
#     plt.savefig(f"pres_thresh_{i}.png",dpi=300)

In [None]:
from sklearn.metrics import classification_report

# One sklearn classification report (as a dict) per threshold in [0.5, 1].
xs = np.linspace(0.5, 1, 100)
cl_reports = [
    classification_report(
        y_true=preds[0][1].numpy(),
        y_pred=(preds[0][0] > thresh).numpy().astype(int),
        output_dict=True,
        target_names=class_names,
    )
    for thresh in xs
]

In [None]:
# Metrics vs. threshold. "Dropout rate" = fraction of images with no class above the threshold.
accs = [accuracy_multi(preds[0][0], preds[0][1], thresh=t, sigmoid=False) for t in xs]
dropout = [np.mean(~(preds[0][0] > t).numpy().any(axis=1)) for t in xs]

fig = plt.figure(figsize=(5, 4), facecolor='white')
for key, name in (('f1-score', 'F1-Score'), ('recall', 'Recall'), ('precision', 'Precision')):
    plt.plot(xs, [rep['weighted avg'][key] for rep in cl_reports], label=name)
plt.plot(xs, accs, label='Multi-label Accuracy')
plt.plot(xs, dropout, label='Dropout Rate')
plt.xlabel('Classification Threshold')
plt.legend(loc='center', prop={"size": 10}, bbox_to_anchor=(0.5, 1.15), ncol=3)
plt.tight_layout()
plt.show()

### confusion matrix

In [None]:
# binarise the scores at 0.8 and build one 2x2 confusion matrix per class
y_pred = np.where(pred_raw.numpy() > 0.8, 1, 0)
confusion_matrix = multilabel_confusion_matrix(y_true, y_pred)

In [None]:
sns.set(rc={'figure.facecolor': 'white'})

def print_confusion_matrix(confusion_matrix, axes, class_label, class_names, fontsize=14):
    """Draw one confusion matrix as an annotated heatmap on ``axes`` (nothing is printed).

    Args:
        confusion_matrix: square array of cell values. Annotated with 2 decimals and
            coloured on a fixed 0..1 scale, so row-normalised fractions are expected.
        axes: matplotlib Axes to draw on.
        class_label: label id shown in the title.
        class_names: tick labels used for both axes (e.g. ["N", "Y"]).
        fontsize: font size of the tick labels.

    Returns:
        The Axes returned by ``sns.heatmap``.
    """
    df_cm = pd.DataFrame(confusion_matrix, index=class_names, columns=class_names)

    # The old try/except re-raised every ValueError as "values must be integers", which
    # is wrong for the float fractions passed here and hid the real cause; let seaborn's
    # own ValueError propagate instead.
    heatmap = sns.heatmap(df_cm, annot=True, vmin=0, vmax=1, fmt=".2f", cbar=False,
                          ax=axes, cmap="Blues_r")
    heatmap.yaxis.set_ticklabels(heatmap.yaxis.get_ticklabels(), rotation=0, ha='right', fontsize=fontsize)
    heatmap.xaxis.set_ticklabels(heatmap.xaxis.get_ticklabels(), rotation=45, ha='right', fontsize=fontsize)
    axes.set_ylabel('True label')
    axes.set_xlabel('Predicted label')
    axes.set_title("Confusion Matrix for class " + class_label)
    return heatmap

In [None]:
fig, ax = plt.subplots(4, 5, figsize=(20, 10))
axes_flat = ax.flatten()

n_plotted = 0
for axes, cfs_matrix, label in zip(axes_flat, confusion_matrix, labels):
    # Plot row-normalised fractions - otherwise the high number of true negatives would
    # obscure the patterns. A row with no samples (e.g. a class absent from the
    # validation targets) is shown as zeros instead of producing NaN from 0/0.
    row_totals = cfs_matrix.sum(axis=1, keepdims=True)
    norm_cfs_matrix = np.divide(cfs_matrix, row_totals,
                                out=np.zeros(cfs_matrix.shape, dtype=float),
                                where=row_totals != 0)
    print_confusion_matrix(norm_cfs_matrix, axes, label, ["N", "Y"])
    n_plotted += 1

# the 4x5 grid has more panels than classes; hide the unused ones
for unused_ax in axes_flat[n_plotted:]:
    unused_ax.set_axis_off()

fig.tight_layout()
fig.savefig('./confusion_matrix_probs_oversample.png', dpi=300)

## Look at top losses

In [None]:
# Per-image validation losses, and the true label strings of the 1000 worst images.
# NOTE(review): assumes valid_df rows are in the same order as the cached predictions.
losses = preds[0][3].numpy()
sorted_loss_idx = np.argsort(losses)[::-1]  # highest loss first
top_losses = valid_df['Label'].iloc[sorted_loss_idx[:1000]]
losses_df = pd.DataFrame({'cl': top_losses.to_numpy(), 'loss': losses[sorted_loss_idx[:1000]]})

In [None]:
# mean loss per true label string among the top-loss images
means = {
    cl: losses_df.loc[losses_df.cl == cl, 'loss'].mean()
    for cl in np.unique(losses_df['cl'])
}

In [None]:
plt.rcParams.update({'font.size': 15})

# distinct label strings among the top-loss images, most frequent first
vals, counts = np.unique(top_losses, return_counts=True)
idx = np.argsort(counts)[::-1]
vals_sorted, counts_sorted = vals[idx], counts[idx]

In [None]:
plt.rcParams.update({'font.size': 20})

fig, ax1 = plt.subplots(figsize=(7, 7))

# horizontal bars, most frequent label string at the top
ax1.barh(vals_sorted[:15], counts_sorted[:15])
ax1.invert_yaxis()
ax1.set_ylabel("Classes")
ax1.set_xlabel("Number of Images")
ax1.set_title("Top 15 Mislabeled Classes")

fig.set_facecolor('white')
fig.tight_layout()
fig.savefig("Top_15_Losses.png", dpi=300)

In [None]:
pred_thresh = 0.6

# One validation image for each of the label strings most common among the top losses.
num_cl = 15
top_cl = vals_sorted[:num_cl]
cl_thresh = 0.6

fig, axes = plt.subplots(3, 5, figsize=(10, 7), facecolor='white')
for ax, cl in zip(axes.flat, top_cl):
    # first validation image whose full label string equals this class
    # (was .index[1], the second match, which also failed for single-image classes)
    curr_idx = valid_df.index[valid_df.Label == cl][0]
    tag = valid_df.ID.loc[curr_idx]

    img, _, _, img_preds = learn.predict(path/"rgb_train"/f'{tag}.png', with_input=True)
    pred_cl = (img_preds > cl_thresh).numpy().astype(int)
    pred_cl = np.array2string(np.asarray(np.where(pred_cl > 0)))
    true_cl = valid_df.loc[curr_idx].Label

    # The decoded input is channels-first (C, H, W), as the original permute assumed;
    # imshow needs (H, W, C). permute(2, 1, 0) gave (W, H, C), i.e. a transposed image.
    ax.imshow(img.permute(1, 2, 0))
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_title(f'{true_cl} / {pred_cl}')

# hide panels left over when fewer than 15 classes are available
for ax in axes.flat[len(top_cl):]:
    ax.set_axis_off()

fig.suptitle(f'True Classes / Predicted Classes (Thresh={cl_thresh})', fontsize=20)
fig.tight_layout()
fig.savefig('./True_and_predicted_imgs_top1000.png', dpi=300)

## Grad-CAM helpers (hooks, image loading, heatmaps)

In [None]:
class Hook():
    """Context manager that records a module's forward output.

    A forward hook is registered on ``m`` at construction; every forward pass stores a
    detached clone of the module's output in ``self.stored``. The hook is removed when
    the ``with`` block exits.
    """

    def __init__(self, m):
        handle = m.register_forward_hook(self.hook_func)
        self.hook = handle

    def hook_func(self, module, inputs, output):
        # detach + clone: keep a copy independent of the autograd graph and later in-place ops
        self.stored = output.detach().clone()

    def __enter__(self, *args):
        return self

    def __exit__(self, *args):
        self.hook.remove()

In [None]:
class HookBwd():
    """Context manager that records the gradient w.r.t. a module's output.

    A backward hook is registered on ``m`` at construction; when gradients flow back
    through the module, ``self.stored`` receives a detached clone of ``grad_output[0]``.
    The hook is removed when the ``with`` block exits.
    """

    def __init__(self, m):
        # NOTE(review): register_backward_hook is deprecated in recent PyTorch in favour of
        # register_full_backward_hook; kept as-is to preserve the current behaviour.
        handle = m.register_backward_hook(self.hook_func)
        self.hook = handle

    def hook_func(self, module, grad_input, grad_output):
        self.stored = grad_output[0].detach().clone()

    def __enter__(self, *args):
        return self

    def __exit__(self, *args):
        self.hook.remove()

In [None]:
# Drop the leading column of dfs. It must be neither ID nor Label, because get_img()
# below reads both. NOTE(review): presumably an unnamed index column saved in train.csv — confirm.
# Guarded so re-running this cell cannot drop ID/Label, and reassigned rather than
# mutated in place.
if dfs.columns[0] not in ('ID', 'Label'):
    dfs = dfs.drop(columns=dfs.columns[0])
dfs.head(3)

In [None]:
def get_img(tag,show_img=False):
    """Load a training image, its ID and its integer labels from ``dfs``.

    Args:
        tag: either a positional row index into ``dfs`` (int, including numpy ints)
            or an image ID (str).
        show_img: if True, print an "ID/Label" header and display the image.

    Returns:
        (img, ID, label): the ``PILImage``, the image ID, and a numpy array of the
        integer class ids parsed from the '|'-separated Label string.

    Raises:
        TypeError: if ``tag`` is neither an integer nor a string (previously this fell
            through to an UnboundLocalError).
        KeyError: if a string ``tag`` is not an ID in ``dfs``.
    """
    import numbers

    if isinstance(tag, numbers.Integral) and not isinstance(tag, bool):
        sample = dfs.iloc[tag]
    elif isinstance(tag, str):
        matches = dfs[dfs.ID == tag]
        if matches.empty:
            raise KeyError(f'image ID {tag!r} not found in dfs')
        sample = matches.iloc[0]  # first row with this ID
    else:
        raise TypeError(f'tag must be an int row index or a str image ID, got {type(tag).__name__}')

    label = sample.Label
    ID = sample.ID
    img = PILImage.create(get_x(sample))

    if show_img:
        print(f'ID: {ID}, Label: {label}')
        img.show(figsize=(5,5))
    label = np.array([int(i) for i in label.split("|")])
    return img,ID,label

In [None]:
def visualize_grad_cam_cpu(x, cls,return_cam=True):
    """Plot the Grad-CAM heatmap of class ``cls`` over image batch ``x`` (on CPU) and return it.

    Args:
        x: single-image batch (batch dim first, channels-first) from a fastai DataLoader.
        cls: index of the output class to explain.
        return_cam: if True, return the CAM as a CPU tensor; otherwise return None.

    Returns:
        The raw (no ReLU applied) Grad-CAM map at the spatial resolution of
        ``learn.model[0]``'s output, or None when ``return_cam`` is False.
    """
    model = learn.model.eval()
    body = model[0]
    with HookBwd(body) as hookg:
        with Hook(body) as hook:
            output = model(x.cpu())
            act = hook.stored
        output[0, cls].backward()
        grad = hookg.stored
    # backward() also accumulates .grad on every parameter; clear it so repeated calls
    # (one per predicted class) don't keep piling up gradients.
    model.zero_grad()

    # Grad-CAM: weight each activation channel by its spatially averaged gradient
    weights = grad[0].mean(dim=[1, 2], keepdim=True)
    cam_map = (weights * act[0]).sum(0)

    # stretch the coarse map over the input's actual size (was hard-coded to 448x448)
    height, width = x.shape[-2:]
    _, ax = plt.subplots()
    x_dec = x[0, :, :, :].permute(1, 2, 0).cpu()
    x_dec.show(ctx=ax, alpha=0.5)
    ax.imshow(cam_map.detach().cpu(), alpha=0.6, extent=(0, width, height, 0),
              interpolation='bilinear', cmap='magma')

    if return_cam:
        return cam_map.detach().cpu()

In [None]:
# Legacy GPU variant; not used below (visualize_grad_cam_cpu is used instead).
def visualize_grad_cam(x, cls,show_img=False):
    """Overlay one Grad-CAM heatmap per requested class on the decoded image (GPU).

    Args:
        x: single-image batch from a fastai DataLoader (normalised, batch dim first).
        cls: a class index or a sequence of class indices.
        show_img: call ``plt.show()`` after drawing.

    The previous version backpropagated ``output[0, cls]`` once for the whole list (so
    every panel would show the same map), called ``fig.add_subplot(1, 1, c)`` with c
    starting at 0 (an invalid subplot index), and titled panels with ``class_names[c]``
    (the loop position rather than the class). Each class now gets its own backward pass.
    """
    classes = [int(cls)] if np.ndim(cls) == 0 else [int(c) for c in cls]
    model = learn.model.cuda().eval()
    x_gpu = x.cuda()
    x_img = TensorImage(dls.train.decode((x,))[0][0])
    img_hwc = x_img.numpy().transpose(1, 2, 0)
    height, width = x.shape[-2:]

    fig, axes = plt.subplots(1, len(classes), squeeze=False)
    for ax, c in zip(axes[0], classes):
        with HookBwd(model[0]) as hookg:
            with Hook(model[0]) as hook:
                output = model(x_gpu)
                act = hook.stored
            output[0, c].backward()
            grad = hookg.stored
        # Grad-CAM: channel weights = spatially averaged gradients
        weights = grad[0].mean(dim=[1, 2], keepdim=True)
        cam_map = (weights * act[0]).sum(0)

        ax.imshow(img_hwc)
        ax.imshow(cam_map.detach().cpu(), alpha=0.6, extent=(0, width, height, 0),
                  interpolation='bicubic', cmap='magma')
        ax.set_title(f'{c}: {class_names[c]}')

    plt.tight_layout()
    if show_img:
        plt.show()

## Single cell segmentation

In [None]:

# HPA cell segmentator built from the pretrained nuclei / cell weights.
NUC_MODEL = '../input/hpacellsegmentatormodelweights/dpn_unet_nuclei_v1.pth'
CELL_MODEL = '../input/hpacellsegmentatormodelweights/dpn_unet_cell_3ch_v1.pth'

segmentator_options = dict(
    scale_factor=0.25,
    device='cuda',
    padding=True,
    multi_channel_model=True,
)
segmentator = cellsegmentator.CellSegmentator(NUC_MODEL, CELL_MODEL, **segmentator_options)

In [None]:
# helper to get the channel files of the full-size images
# (*** the 512x512 versions don't segment well ***)
def build_image_names(image_id: str,dataset: str) -> list:
    """Per-channel PNG paths of ``image_id`` in the full-size competition ``dataset`` folder.

    Returns a list of three one-element lists in the order the segmentator expects:
    red (microtubules), yellow (endoplasmic reticulum), blue (nuclei).
    """
    root_dir = "/kaggle/input/hpa-single-cell-image-classification"
    channel_paths = (
        f'{root_dir}/{dataset}/{image_id}_red.png',     # microtubules
        f'{root_dir}/{dataset}/{image_id}_yellow.png',  # endoplasmic reticulum
        f'{root_dir}/{dataset}/{image_id}_blue.png',    # nuclei
    )
    return [[p] for p in channel_paths]

In [None]:
def get_mask(img_ID, size=512):
    """Cell label mask for ``img_ID``, downsized to ``size`` x ``size`` (default 512).

    Segmentation runs on the full-size channel images; the resulting integer cell mask
    is then resized to match the 512x512 images used for classification.
    """
    images = build_image_names(image_id=img_ID,dataset="train")
    # nuclei are predicted from the blue channel only; cells from all three channels
    nuc_segmentations = segmentator.pred_nuclei(images[2])
    cell_segmentations = segmentator.pred_cells(images)
    _, cell_mask = label_cell(nuc_segmentations[0], cell_segmentations[0])

    # Nearest-neighbour keeps the integer cell ids intact; the default bilinear
    # interpolation blended neighbouring ids into spurious labels along cell borders.
    return cv2.resize(cell_mask, dsize=(size, size), interpolation=cv2.INTER_NEAREST)

## Merging Grad-CAM maps with the cell segmentation

In [None]:
import numpy.ma as ma

In [None]:
# IDs of the validation images (candidates for `tag` below)
valid_ids = dfs.loc[dfs.is_valid, 'ID'].tolist()

# image to analyse, and the probability threshold for calling a class present
tag = "4f47bb78-bbb7-11e8-b2ba-ac1f6b6435d0"
#tag = '08705684-bbc5-11e8-b2bc-ac1f6b6435d0'
pred_thresh = 0.6

In [None]:
# load the chosen image (printing/showing it), then classify it with the trained learner
img, ID, true_label = get_img(tag, True)

true_cl, _, sample_pred = learn.predict(path / "rgb_train" / f'{tag}.png', with_input=False)
pred_label = np.flatnonzero(np.asarray(sample_pred > pred_thresh))
print(pred_label)

In [None]:
# single-image batch in the learner's input format, plus the 512x512 cell mask
(x,) = first(oversampled_dls.test_dl([img]))
mask_512 = get_mask(ID)

In [None]:
# For every predicted class: Grad-CAM map resized to 512x512, then for each segmented
# cell (ids 1..max(mask); 0 is background) the absolute value of the mean CAM value
# over that cell's pixels.
n_cells = np.max(mask_512)
int_dict = {"cell": range(1, n_cells + 1)}
for p_cls in pred_label:
    class_cam_512 = cv2.resize(visualize_grad_cam_cpu(x, p_cls).numpy(), dsize=(512, 512))
    int_dict[str(p_cls)] = [
        abs(np.mean(class_cam_512[mask_512 == cell_id]))
        for cell_id in range(1, n_cells + 1)
    ]

In [None]:
#make dataframe. Each row is a cell, and avg intensities for each class.
results = pd.DataFrame(data=int_dict)
# NOTE(review): this rebinds the global `labels` (earlier: the list '0'..'18' used as the
# DataBlock vocab and in the confusion-matrix plot) to the predicted class ids; re-running
# those earlier cells after this one will silently use the wrong list.
labels = np.array(list(results)[1:])
results.head(3)

In [None]:
# Generate cell-level labels: a cell gets class c if its Grad-CAM intensity for c
# (the per-cell value computed above) exceeds cam_thresh.
cam_thresh = 0.015
# Select the per-class intensity columns by name. This skips 'cell' and, on a re-run,
# the 'pred_label' column added below. It also avoids relying on tuple positions or on
# the (shadowed) global `labels`.
class_cols = [c for c in results.columns if c not in ('cell', 'pred_label')]
hits = results[class_cols].to_numpy() > cam_thresh
all_cell_labels = [[int(c) for c, hit in zip(class_cols, row) if hit] for row in hits]

In [None]:
# attach the per-cell labels and display the final table
results = results.assign(pred_label=all_cell_labels)
results

### Report figures

In [None]:
# Grad-CAM maps, resized to 512x512, for the first two predicted classes (used in the report figure).
if len(pred_label) < 2:
    raise ValueError(f'need at least two predicted classes for the report figure, got {list(pred_label)}')

report_cams = []
for p_cls in pred_label[:2]:
    class_cam = visualize_grad_cam_cpu(x, p_cls)
    report_cams.append(cv2.resize(class_cam.numpy(), dsize=(512, 512)))
class_cam_512_0, class_cam_512_1 = report_cams
    

In [None]:
# hide background pixels (mask id 0) so only CAM values inside cells are shown
clipped_mask_0 = ma.masked_where(mask_512 == 0, class_cam_512_0)
clipped_mask_1 = ma.masked_where(mask_512 == 0, class_cam_512_1)
k = 10
fsize = 16
fig, ax = plt.subplots(2, 2, figsize=(k, k))

panels = [
    (ax[0, 0], mask_512, "viridis", "Cell Segmentation"),
    (ax[0, 1], class_cam_512_0, "inferno", f"Grad-CAM, class {pred_label[0]}"),
    (ax[1, 0], clipped_mask_0, "inferno", f"Segmented Grad-CAM, class {pred_label[0]}"),
    (ax[1, 1], clipped_mask_1, "inferno", f"Segmented Grad-CAM, class {pred_label[1]}"),
]
for panel_ax, image, cmap, title in panels:
    panel_ax.imshow(image, cmap=cmap)
    panel_ax.set_title(title, fontsize=fsize)

for panel_ax in ax.flat:
    panel_ax.grid(False)
    panel_ax.set_xticks([])
    panel_ax.set_yticks([])

## Training statistics of the un-augmented ResNet model

In [None]:
import matplotlib.pyplot as plt
import pandas as pd

# Per-epoch training log (epoch, train/valid loss, accuracy_multi, average_precision_score).
# NOTE(review): this rebinds `train_df`, which earlier held the training split of dfs;
# re-create the split before re-running earlier cells that use it.
train_df = pd.read_csv("../input/resnetnoaugtrainstats/resnet_no_aug_training_data.csv")

In [None]:
fig, axs = plt.subplots(1, 2, figsize=(15, 5))

# left panel: losses per epoch
for column, name in (("train_loss", "training loss"), ("valid_loss", "validation loss")):
    axs[0].plot(train_df.epoch, train_df[column], label=name)
axs[0].legend()

# right panel: metrics per epoch
for column, name in (("accuracy_multi", "accuracy"), ("average_precision_score", "avg precision")):
    axs[1].plot(train_df.epoch, train_df[column], label=name)
axs[1].legend()

In [None]:
# display the per-epoch training log
train_df

## Legacy code (everything below is old and no longer used)

In [None]:
def Components(Image, shade = 230, cutoff = 5000, distance = 20):
    """Coarse connected-component segmentation of bright runs in channel 2 of ``Image``.

    Pixels with ``Image[:, :, 2] > shade`` are grouped into horizontal runs. Runs in
    adjacent rows that overlap, and whose extents (end - start) both exceed ``distance``,
    are merged with a union-find over run indices. Components whose summed run weight
    exceeds ``cutoff`` are kept (at most the 100 heaviest).

    Args:
        Image: array indexed as (row, col, channel); only channel 2 is used.
        shade: intensity threshold for a pixel to count as foreground.
        cutoff: minimum summed run weight for a component to be kept.
        distance: minimum run extent for two overlapping runs to be merged.

    Returns:
        uint8 array of shape ``Image.shape[:2]``: 0 is background; with t kept
        components, the heaviest is painted t and the lightest kept one 1.
    """
    # Foreground mask of channel 2. dtype=int: the np.int alias was removed in
    # NumPy 1.24, where the old dtype=np.int raised AttributeError.
    V = np.array((Image[:,:,2]> shade), dtype = int)
    # zero the border columns so every run has both a start and an end inside W
    V[:,0] = 0
    V[:, -1] = 0
    # W[r, j] = V[r, j+1] - V[r, j]: +1 marks a run start, -1 a run end
    W = V[:,1:]- V[:,:-1]
    v = np.where(W==1)
    w = np.where(W == -1)

    # run i lies in row v[0][i] and spans W columns v[1][i] .. w[1][i]
    n = len(v[0])
    In = []
    Out = []
    # union-find parent array over run indices (each run starts as its own root)
    Comp = [i for i in range(n)]
    row = 0
    for i in range(n):
        new_row = 0 + v[0][i]

        if new_row == row:
            In.append(i)
        elif new_row == row+1:
            Out.append(i)

        else:
            # a later row started: merge overlapping, long-enough runs of the pending
            # pair of rows (In = runs of `row`, Out = runs of `row + 1`)
            for p in In:
                for q in Out:
                    a = v[1][p]
                    b = w[1][p]
                    c = v[1][q]
                    d = w[1][q]

                    if ((a <= c and c <=b) or (c <= a and a <= d)) and (b-a > distance and d-c > distance):

                        if Comp[p] !=  Comp[q]:
                            # find the roots of p (root1) and q (root3), then attach the
                            # larger root (and p, q) to the smaller one
                            root1 = p+0
                            root2 = Comp[p]
                            while root2 < root1:
                                root2, root1 = Comp[root2], root2
                            root3 = q+0
                            root4 = Comp[q]
                            while root4<root3:
                                root4, root3 = Comp[root4], root4
                            if root1 < root3:
                                Comp[root3] = root1
                                Comp[q] = root1
                                Comp[p] = root1
                            else:
                                Comp[root1] = root3
                                Comp[p] = root3
                                Comp[q] = root3

            if new_row == row+2:
                In, Out = Out, [i]
                row = row + 1
            else:
                In, Out = [i], []
                row = 0 + new_row
    # NOTE(review): the In/Out pair still pending here (the last two foreground rows)
    # is never compared, so runs in those two rows are not merged with each other.

    # path compression: point every run directly at its root
    for i in range(n):
        a = 0 + i
        b = Comp[i]
        while b < a:
            b, a  = Comp[b], b
        Comp[i] = b

    L1 = list(set(Comp))

    # summed run weight (b - a + 1) per component root
    D1 = {i:0 for i in L1}
    Total_Weight = 0
    for i in range(n):
        a = v[1][i]
        b = w[1][i]
        Total_Weight+=(b-a+1)
        D1[Comp[i]]+=(b-a+1)

    # keep components heavier than cutoff, heaviest first, at most 100
    L2 = [(-b,a) for a,b in D1.items() if b > cutoff]

    L2.sort()
    L2 = L2[:100]
    L2 = [a for b,a in L2]

    # D2: component root -> rank by weight (0 = heaviest)
    D2 = {}

    for i in range(len(L2)):
        D2[L2[i]] = i

    t = len(L2)
    x_max = [0 for _ in range(t)]
    x_min = [V.shape[1] for _ in range(t)]
    y_max = [0 for _ in range(t)]
    y_min = [V.shape[0] for _ in range(t)]

    Segmented_Image = np.zeros((Image.shape[0],Image.shape[1]), dtype = np.uint8)
    for i in range(n):
        if Comp[i] in L2:
            value = t - D2[Comp[i]]
            row = v[0][i]
            a = v[1][i]
            b = w[1][i]
            # NOTE(review): a/b are W column indices; the run occupies V columns a+1..b,
            # so this slice paints it one column to the left.
            Segmented_Image[row,a:b] = value


    return Segmented_Image

In [None]:
# display the class-index -> name list
class_names

In [None]:
def build_image_names(image_id: str) -> list:
    """Per-channel PNG paths of ``image_id`` in the 512x512 training set.

    Returns three one-element lists in the order the segmentator expects:
    red (microtubules), yellow (endoplasmic reticulum), blue (nuclei).

    NOTE(review): this redefines the earlier two-argument ``build_image_names`` used by
    ``get_mask``; after this cell runs, ``get_mask`` fails with a TypeError.
    """
    root = '../input/hpa512x512dataset/train/'
    # red channel = microtubules (the old comment wrongly said "mitchondria")
    mt = f'{root}{image_id}_red.png'

    # er is the endoplasmic reticulum
    er = f'{root}{image_id}_yellow.png'

    # nu is the nuclei
    nu = f'{root}{image_id}_blue.png'

    return [[mt], [er], [nu]]
def Factory_Segmentation(file_id):
    """Per-cell label mask for ``file_id`` from the HPA segmentator and ``label_cell``."""
    channel_files = build_image_names(file_id)
    cell_pred = segmentator.pred_cells(channel_files)[0]
    nuclei_pred = segmentator.pred_nuclei(channel_files[2])[0]
    _, fine_grained_segmentation = label_cell(nuclei_pred, cell_pred)
    return fine_grained_segmentation
def Coarse_Segmentation(file_id):
    """Coarse cell segmentation: ``Components`` run on the segmentator's cell prediction for ``file_id``."""
    cell_pred = segmentator.pred_cells(build_image_names(file_id))[0]
    return Components(cell_pred)
def color_image(file_id):
    """Stack the red/yellow/blue channel PNGs of ``file_id`` into one 3-channel image.

    Channel 0 of each file (read with OpenCV) is used; the result has the dtype and
    shape of ``cv2.imread``'s output for the red file.

    Raises:
        FileNotFoundError: if any channel file cannot be read (``cv2.imread`` returns
            None instead of raising, which previously surfaced as an obscure TypeError).
    """
    channels = []
    for (fname,) in build_image_names(file_id):
        arr = cv2.imread(fname)
        if arr is None:
            raise FileNotFoundError(f'could not read image file {fname!r}')
        channels.append(arr)

    image = np.zeros_like(channels[0])
    for idx, arr in enumerate(channels):
        image[:, :, idx] = arr[:, :, 0]
    return image

In [None]:
# run both segmentations on the first training image
file_0 = df.loc[0, 'ID']
coarse_segmentation_0 = Coarse_Segmentation(file_0)
precise_segmentation_0 = Factory_Segmentation(file_0)

In [None]:
# display the channel file paths used for the segmentation above
build_image_names(file_0)

In [None]:
# fine-grained cell mask next to the composite RGB image
_, seg_axes = plt.subplots(1, 2)
seg_axes[0].imshow(precise_segmentation_0)
seg_axes[1].imshow(color_image(file_0))