In [None]:
!pip install -q "../input/pycocotools/pycocotools-2.0-cp37-cp37m-linux_x86_64.whl"

In [None]:
import sys
sys.path.append("../input/hpapytorchzoozip/pytorch_zoo-master")
sys.path.append('../input/timmlast')
sys.path.append('../input/ttach-kaggle/ttach')
import ttach
import timm
import pytorch_zoo

In [None]:
sys.path.append("../input/hpacellsegmentatormaster/HPA-Cell-Segmentation-master")
import hpacellseg.cellsegmentator as cellsegmentator
from hpacellseg.utils import label_cell, label_nuclei

In [None]:
import pandas as pd
import numpy as np
import os
import tqdm
import cv2
import matplotlib.pyplot as plt
import seaborn as sns
import gc
import base64
from pycocotools import _mask as coco_mask
import typing as t
import zlib

from torch.utils.data import Dataset, DataLoader
import albumentations as A
from albumentations.pytorch.transforms import ToTensorV2

import torch
import torch.nn as nn

from collections import OrderedDict
import ttach as tta

# Cell Classification

In [None]:
# Sample submission table; its `ID` column is iterated by the cell-extraction loop further down.
df = pd.read_csv('../input/hpa-single-cell-image-classification/sample_submission.csv')

In [None]:
# Competition data root; read_img and read_sample_image_seg build their file paths from it.
ROOT = '../input/hpa-single-cell-image-classification/'
# Sub-directory of ROOT that is passed to read_img as the image split.
train_or_test = 'test'

In [None]:
def get_cropped_cell(img, msk):
    """Mask ``img`` with ``msk`` and crop it to the mask's bounding box.

    Args:
        img: Image array whose last axis is the channel axis.
        msk: Mask over the leading two axes of ``img``; it is cast to ``int``
            and multiplied into every channel.

    Returns:
        The masked image restricted to the rows and columns spanned by the
        non-zero entries of ``msk`` (both ends inclusive).

    Raises:
        ValueError: If ``msk`` has no non-zero entry.
    """
    weights = msk.astype(int)[..., None]
    masked = img * weights
    # One index array per axis; only the first two (row, column) matter here.
    nonzero_idx = np.nonzero(weights)
    rows, cols = nonzero_idx[0], nonzero_idx[1]
    row_start, row_stop = rows.min(), rows.max() + 1
    col_start, col_stop = cols.min(), cols.max() + 1
    return masked[row_start:row_stop, col_start:col_stop]

In [None]:
def get_stats(cropped_cell):
    """Return per-channel first and second moments of a 3-channel crop.

    The crop is divided by 255 and flattened to ``(-1, 3)``; the mean of the
    values and the mean of the squared values are then taken per channel.

    Returns:
        Tuple ``(mean, mean_of_squares)``, each an array of length 3.
    """
    pixels = (cropped_cell / 255.0).reshape(-1, 3)
    first_moment = pixels.mean(0)
    second_moment = (pixels ** 2).mean(0)
    return first_moment, second_moment

In [None]:
def read_img(image_id, color, train_or_test='test', image_size=None):
    """Load one colour channel of an image and return it in 8-bit range.

    Args:
        image_id: Image identifier used as the file-name prefix.
        color: Channel suffix of the file name (the notebook passes
            "red", "green" and "blue").
        train_or_test: Sub-directory of ``ROOT`` to read from.
        image_size: If not ``None``, the image is resized to
            ``(image_size, image_size)``.

    Returns:
        The array as decoded by ``cv2.imread`` when its maximum is at most
        255; otherwise a ``uint8`` array obtained by dividing by 255 and
        saturating at 255.

    Raises:
        AssertionError: If the file does not exist.
        OSError: If the file exists but OpenCV cannot decode it.
    """
    filename = f'{ROOT}/{train_or_test}/{image_id}_{color}.png'
    assert os.path.exists(filename), f'not found {filename}'
    img = cv2.imread(filename, cv2.IMREAD_UNCHANGED)
    if img is None:
        # cv2.imread signals decode failures by returning None, not by raising.
        raise OSError(f'could not decode {filename}')
    if image_size is not None:
        img = cv2.resize(img, (image_size, image_size))
    if img.max() > 255:
        # 16-bit data: 65535 / 255 == 257, so values >= 65280 used to fall
        # outside the uint8 range and wrap around (bright pixels turned dark).
        # Clip before the cast so they saturate at 255 instead.
        img = np.clip(img / 255, 0, 255).astype('uint8')
    return img

In [None]:
def encode_binary_mask(mask: np.ndarray) -> t.Text:
  """Converts a binary mask into OID challenge encoding ascii text.

  Args:
    mask: Boolean array that is 2-D after ``np.squeeze``.

  Returns:
    ASCII string: the COCO RLE "counts" of the mask, zlib-compressed and
    then base64-encoded.

  Raises:
    ValueError: If ``mask`` is not of boolean dtype or is not 2-D after
      squeezing.
  """

  # check input mask --
  # `np.bool` was a deprecated alias that newer NumPy releases removed;
  # `np.bool_` is the actual boolean scalar type.
  if mask.dtype != np.bool_:
    raise ValueError(
        "encode_binary_mask expects a binary mask, received dtype == %s" %
        mask.dtype)

  mask = np.squeeze(mask)
  if len(mask.shape) != 2:
    # Wrap the shape in a 1-tuple: `"%s" % (a, b, c)` would treat the shape
    # as several format arguments and raise TypeError instead.
    raise ValueError(
        "encode_binary_mask expects a 2d mask, received shape == %s" %
        (mask.shape,))

  # convert input mask to expected COCO API input --
  mask_to_encode = mask.reshape(mask.shape[0], mask.shape[1], 1)
  mask_to_encode = mask_to_encode.astype(np.uint8)
  mask_to_encode = np.asfortranarray(mask_to_encode)

  # RLE encode mask --
  encoded_mask = coco_mask.encode(mask_to_encode)[0]["counts"]

  # compress and base64 encoding --
  binary_str = zlib.compress(encoded_mask, zlib.Z_BEST_COMPRESSION)
  base64_str = base64.b64encode(binary_str)
  return base64_str.decode('ascii')

In [None]:
import warnings
from torch.serialization import SourceChangeWarning
warnings.filterwarnings("ignore", category=SourceChangeWarning)
# Weight files for the two segmentation networks handed to CellSegmentator.
NUC_MODEL = "../input/hpacellsegmentatormodelweights/dpn_unet_nuclei_v1.pth"
CELL_MODEL = "../input/hpacellsegmentatormodelweights/dpn_unet_cell_3ch_v1.pth"
# Single shared segmentator instance; segmentCell() receives it as an argument.
# NOTE(review): the meaning of scale_factor / padding / multi_channel_model is
# defined by hpacellseg, not by this notebook — confirm against that package.
segmentator = cellsegmentator.CellSegmentator(
    NUC_MODEL,
    CELL_MODEL,
    scale_factor=0.25,
    device="cuda",  # requires a GPU runtime
    padding=True,
    multi_channel_model=True,
);


In [None]:
def read_sample_image_seg(filename, train_or_test='test'):

    '''
    Build the per-channel file paths of one sample for segmentation.
    No image is read here; only path strings are assembled.
    ---------------------------------
    Arguments:
    filename -- sample image id (file-name prefix, without channel suffix)
    train_or_test -- sub-directory of ROOT holding the images
                     (defaults to 'test', the previously hard-coded value)

    Returns:
    stacked_images -- [[red], [yellow], [blue]] single-element path lists,
                      in that order
    red, blue, yellow -- the three individual path strings
    '''

    # The trailing '' keeps the directory separator at the end of the prefix,
    # so the default produces exactly the same strings as before.
    prefix = os.path.join(ROOT, train_or_test, '') + filename
    red = prefix + "_red.png"
    blue = prefix + "_blue.png"
    yellow = prefix + "_yellow.png"

    stacked_images = [[red], [yellow], [blue]]
    return stacked_images, red, blue, yellow

# segment cell 
def segmentCell(image, segmentator):

    '''
    Segment the cells of one sample and return the labelled cell mask.
    ------------------------------------
    Argument:
    image -- three-element list; in this notebook it is the
             [[red], [yellow], [blue]] path lists built by
             read_sample_image_seg. Element 2 is passed on its own to the
             nuclei prediction, the whole list to the cell prediction.
    segmentator -- CellSegmentator class object

    Returns:
    cell_mask -- cell mask returned by label_cell for the first (only) sample
    '''

    nuc_segmentations = segmentator.pred_nuclei(image[2])
    cell_segmentations = segmentator.pred_cells(image)
    nuclei_mask, cell_mask = label_cell(nuc_segmentations[0], cell_segmentations[0])

    # Drop the intermediates first, then collect: calling gc.collect() before
    # the names are deleted cannot reclaim these objects.
    del nuc_segmentations, cell_segmentations, nuclei_mask
    gc.collect()

    return cell_mask

In [None]:
!mkdir -p cells

In [None]:
# Segment every test image, write one masked crop per detected cell into
# cells/, and record per-cell metadata (including the encoded mask).
x_tot, x2_tot = [], []
lbls = []
num_files = len(df)
all_cells = []


for idx in tqdm.tqdm(range(num_files)):
    image_id = df.iloc[idx].ID
    ryb, r, b, y = read_sample_image_seg(image_id)
    cell_mask = segmentCell(ryb, segmentator)

    red = read_img(image_id, "red", train_or_test, None)
    green = read_img(image_id, "green", train_or_test, None)
    blue = read_img(image_id, "blue", train_or_test, None)
    # Channels are stacked as (blue, green, red): cv2.imwrite expects BGR,
    # so the saved JPEG has red/green/blue in its R/G/B planes.
    stacked_image = np.transpose(np.array([blue, green, red]), (1, 2, 0))

    for j in range(1, np.max(cell_mask) + 1):
        bmask = (cell_mask == j)
        if not bmask.any():
            # A label id with no pixels has no bounding box; skip it instead
            # of letting get_cropped_cell fail on an empty reduction.
            continue
        enc = encode_binary_mask(bmask)
        cropped_cell = get_cropped_cell(stacked_image, bmask)
        fname = f'{image_id}_{j}.jpg'
        cv2.imwrite("cells/" + fname, cropped_cell)
        # Stats follow the channel order of the crop: (blue, green, red).
        x, x2 = get_stats(cropped_cell)
        x_tot.append(x)
        x2_tot.append(x2)
        all_cells.append({
            'image_id': image_id,
            'fname': fname,
            # Index 2 is red and index 0 is blue (see stacking order above);
            # these two columns used to be swapped.
            'r_mean': x[2],
            'g_mean': x[1],
            'b_mean': x[0],
            'cell_id': j,
            'size1': cropped_cell.shape[0],
            'size2': cropped_cell.shape[1],
            'enc': enc,
        })

#image stats (per channel, in blue/green/red order)
img_avr =  np.array(x_tot).mean(0)
img_std =  np.sqrt(np.array(x2_tot).mean(0) - img_avr**2)
cell_df = pd.DataFrame(all_cells)
cell_df.to_csv('cell_df.csv', index=False)
print('mean:',img_avr, ', std:', img_std)

In [None]:
# NOTE: rebinds `df` — from here on it is the per-cell table written by the
# extraction loop, no longer the sample submission.
df = pd.read_csv('./cell_df.csv')

In [None]:
df.head()

In [None]:
# Inference-time preprocessing for the cell crops: resize to 224x224,
# normalise with albumentations' defaults, convert to a torch tensor.
valid_transforms = A.Compose([
    A.Resize(height=224, width=224),
    A.Normalize(),
    ToTensorV2(),
])

class CellDataset(Dataset):
    """Dataset over the cropped single-cell JPEGs listed in a cell table.

    Each item is the image stored at
    ``<data_dir>/cells/<image_id>_<cell_id>.jpg``, converted from BGR to RGB
    and passed through ``transform`` when one is given. No label is returned.

    Args:
        data_dir: Directory that contains the ``cells`` folder.
        csv_file: DataFrame (despite the name) with ``image_id`` and
            ``cell_id`` columns.
        transform: Optional callable invoked as ``transform(image=img)`` and
            expected to return a mapping with an ``'image'`` entry. When
            ``None`` the RGB array is returned unchanged.
    """

    def __init__(self, data_dir, csv_file, transform=None):
        super().__init__()

        self.data_dir = data_dir
        self.df = csv_file
        self.transforms = transform
        self.img_ids = self.df['image_id'].values
        self.cell_ids = self.df['cell_id'].values

    def __len__(self):
        return len(self.img_ids)

    def get_image(self, index):
        """Load, colour-convert and transform the crop at position ``index``.

        Raises:
            FileNotFoundError: If the crop is missing or cannot be decoded.
        """
        image_id = self.img_ids[index]
        cell_id = self.cell_ids[index]

        img_path = os.path.join(self.data_dir, 'cells', image_id + '_' + str(cell_id) + '.jpg')

        img = cv2.imread(img_path)
        if img is None:
            # cv2.imread returns None rather than raising on a bad path.
            raise FileNotFoundError(f'could not read cell image {img_path}')
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        if self.transforms is not None:
            # The constructor's default of None used to crash on this call.
            img = self.transforms(image=img)['image']

        return img

    def __getitem__(self, index):
        return self.get_image(index)

In [None]:
# data_dir='' makes CellDataset read the crops from the relative ./cells folder.
test_dataset = CellDataset(data_dir='', csv_file=df, transform=valid_transforms)

In [None]:
# shuffle=False and drop_last=False keep one prediction per row of the cell
# table, in table order.
test_loader = DataLoader(
    test_dataset, batch_size=8, shuffle=False, num_workers=4, drop_last=False
)

In [None]:
class Net(nn.Module):
    """Thin wrapper around a ``timm`` model created without pretrained weights.

    Args:
        name: Architecture name passed to ``timm.create_model``.
        num_classes: Number of outputs of the classification head.
    """

    def __init__(self, name = 'efficientnet_b0', num_classes=19):
        super().__init__()
        self.model = timm.create_model(name, pretrained=False, num_classes=num_classes)

    def forward(self, x):
        """Return the wrapped model's output for ``x``."""
        return self.model(x)


In [None]:
from collections import OrderedDict
def update_state_dict(state_dict):
    """Return a copy of ``state_dict`` with the first dotted component of
    every key removed (``'a.b.c'`` becomes ``'b.c'``).

    A key without any dot maps to the empty string. Entry order is kept and
    values are passed through untouched.
    """
    return OrderedDict(
        (key.partition('.')[2], value) for key, value in state_dict.items()
    )

In [None]:
def _load_eval_model(arch, weights_path):
    """Build a ``Net`` of architecture ``arch``, load the checkpoint at
    ``weights_path`` (keys rewritten by ``update_state_dict``), move the model
    to the GPU and switch it to eval mode."""
    loaded = Net(name=arch)
    loaded.load_state_dict(update_state_dict(torch.load(weights_path)))
    loaded.cuda()
    loaded.eval()
    return loaded


# NOTE(review): `model_b1_f2` loads a "fold-4" dataset and `model_b1_f4` a
# "fold-44" dataset. The paths are kept exactly as they were — confirm that
# these dataset names really hold the fold-2 / fold-4 weights.
model_b1_f0 = _load_eval_model('efficientnet_b1', '../input/efficientnet-b1-224-fold-0/epoch1-valid_loss_epoch0.118.pth')
model_b1_f1 = _load_eval_model('efficientnet_b1', '../input/efficientnet-b1-224-fold-1/epoch1-valid_loss_epoch0.117.pth')
model_b1_f2 = _load_eval_model('efficientnet_b1', '../input/efficientnet-b1-224-fold-4/epoch1-valid_loss_epoch0.119.pth')
model_b1_f3 = _load_eval_model('efficientnet_b1', '../input/efficientnet-b1-224-fold-3/epoch1-valid_loss_epoch0.116.pth')
model_b1_f4 = _load_eval_model('efficientnet_b1', '../input/efficientnet-b1-224-fold-44/epoch1-valid_loss_epoch0.117.pth')

model_b0_f0 = _load_eval_model('efficientnet_b0', '../input/efficientnet-b0-224-fold-0/epoch1-valid_loss_epoch0.119.pth')
model_b0_f1 = _load_eval_model('efficientnet_b0', '../input/efficientnet-b0-224-fold-1/epoch1-valid_loss_epoch0.118.pth')
model_b0_f2 = _load_eval_model('efficientnet_b0', '../input/efficientnet-b0-224-fold-2/epoch1-valid_loss_epoch0.120.pth')
model_b0_f3 = _load_eval_model('efficientnet_b0', '../input/efficientnet-b0-224-fold-3/epoch1-valid_loss_epoch0.117.pth')
model_b0_f4 = _load_eval_model('efficientnet_b0', '../input/efficientnet-b0-224-fold-4/epoch8-valid_loss_epoch0.117.pth')

In [None]:
# Empty float tensor on the GPU; the inference loop concatenates batch outputs onto it.
pred = torch.FloatTensor().cuda()

In [None]:
# Wrap each fold model in a ttach classification wrapper using the flip
# transform alias; every wrapper gets its own transform object.
(
    tta1, tta2, tta3, tta4, tta5,
    tta6, tta7, tta8, tta9, tta10,
) = (
    tta.ClassificationTTAWrapper(fold_model, tta.aliases.flip_transform())
    for fold_model in (
        model_b1_f0, model_b1_f1, model_b1_f2, model_b1_f3, model_b1_f4,
        model_b0_f0, model_b0_f1, model_b0_f2, model_b0_f3, model_b0_f4,
    )
)

In [None]:
# Run the ten TTA-wrapped models on every batch and average their raw outputs
# (the sigmoid is applied afterwards, on the averaged values).
#
# Outputs are gathered in a local list and concatenated once at the end, so
# `pred` is rebuilt from scratch every time this cell runs. Appending to the
# `pred` created in an earlier cell duplicated predictions on a re-run.
ensemble = (tta1, tta2, tta3, tta4, tta5, tta6, tta7, tta8, tta9, tta10)
batch_outputs = []

with torch.no_grad():
    for inp in tqdm.tqdm(test_loader):
        bs, c, h, w = inp.size()
        input_var = inp.view(-1, c, h, w).cuda()

        # Same left-to-right summation order as the explicit tta1 + ... + tta10.
        output = ensemble[0](input_var)
        for member in ensemble[1:]:
            output = output + member(input_var)
        output = output / len(ensemble)

        batch_outputs.append(output.view(bs, -1))

# An empty loader yields the same empty GPU tensor the notebook started with.
pred = torch.cat(batch_outputs, 0) if batch_outputs else torch.FloatTensor().cuda()

In [None]:
# Move the averaged model outputs to the CPU and squash them to (0, 1).
pred_torch = torch.sigmoid(pred.cpu())

In [None]:
def isNaN(num):
    """Return the result of ``num != num``, which is true for a float NaN
    because NaN is the only float that compares unequal to itself."""
    return num != num

In [None]:
# Reload the per-cell table and add an empty `cls` column that the next cell
# fills with (class, confidence) pairs.
cell_df = pd.read_csv('./cell_df.csv')
cell_df['cls'] = ''

In [None]:
# With threshold 0.0 every class passes (a sigmoid output is > 0), so each
# cell ends up with one (class, confidence) pair per class.
threshold = 0.0

if pred_torch.shape[0] != len(cell_df):
    raise ValueError(
        f'{pred_torch.shape[0]} predictions for {len(cell_df)} cells; '
        'the prediction tensor and cell table are out of sync'
    )

cls_per_cell = []
for i in range(pred_torch.shape[0]):
    probs = pred_torch[i]
    # flatten() always yields a 1-D index tensor, so tolist() is a list even
    # for a single hit — no squeeze()/type() workaround needed.
    p = torch.nonzero(probs > threshold).flatten().tolist()
    if len(p) == 0:
        # Nothing above the threshold: fall back to the top-scoring class.
        cls = [(probs.argmax().item(), probs.max().item())]
    else:
        cls = [(x, probs[x].item()) for x in p]
    cls_per_cell.append(cls)

# Assign the whole column at once. `cell_df['cls'].loc[i] = ...` is chained
# assignment, which pandas may apply to a temporary copy rather than the frame.
cell_df['cls'] = pd.Series(cls_per_cell, index=cell_df.index, dtype=object)

In [None]:
def combine(r):
    """Format one cell's predictions as ``"<class> <confidence> <mask>"``
    triples joined by single spaces.

    Args:
        r: Two-element record ``(cls, enc)``: ``cls`` is an iterable of
            ``(class_id, confidence)`` pairs and ``enc`` the encoded mask
            string. May be a pandas Series (e.g. a row of
            ``cell_df[['cls', 'enc']]``) or a plain tuple/list.

    Returns:
        One triple per entry of ``cls``, each repeating ``enc``; the empty
        string when ``cls`` is empty.
    """
    if hasattr(r, 'iloc'):
        # A Series row is labelled 'cls'/'enc'; integer keys on a labelled
        # Series are deprecated as positions, so index by position explicitly.
        cls, enc = r.iloc[0], r.iloc[1]
    else:
        cls, enc = r[0], r[1]
    classes = [str(c[0]) + ' ' + str(c[1]) + ' ' + enc for c in cls]
    return ' '.join(classes)

In [None]:
# Per-cell prediction string: one "<class> <confidence> <mask>" triple per class.
cell_df['pred'] = cell_df[['cls', 'enc']].apply(combine, axis=1)
cell_df.head()

In [None]:
# Concatenate the per-cell strings of each image into one string per image.
subm = cell_df.groupby(['image_id'])['pred'].apply(lambda x: ' '.join(x)).reset_index()
# subm = subm.loc[3:]
subm.head()

In [None]:
# Fresh copy of the sample submission; it is the left side of the merge below
# and supplies the column set of the final frame.
sample_submission = pd.read_csv('../input/hpa-single-cell-image-classification/sample_submission.csv')
sample_submission.head()

In [None]:
# Left join: every sample-submission row is kept; images without any
# predicted cell get NaN in `pred`.
sub = pd.merge(sample_submission, subm, how="left", left_on='ID', right_on='image_id')
sub.head()

In [None]:
def isNaN(num):
    """Return the result of ``num != num`` (true for a float NaN)."""
    return num != num

# Overwrite the placeholder PredictionString wherever the merge found a
# prediction; rows whose `pred` is NaN keep the sample-submission value.
# One vectorised .loc assignment replaces the row loop, whose
# `sub.PredictionString.loc[i] = ...` was chained assignment and may write to
# a temporary copy instead of `sub`.
has_pred = sub['pred'].notna()
sub.loc[has_pred, 'PredictionString'] = sub.loc[has_pred, 'pred']

In [None]:
# Drop the merge helper columns; keep exactly the sample-submission columns.
sub = sub[sample_submission.columns]
sub.head()

In [None]:
# sub.to_csv('submission.csv', index=False)

In [None]:
sub.head()

In [None]:
!rm -r cells

In [None]:
!rm cell_df.csv

# Image Classification

In [None]:
# Alias (not a copy) of the cell-level submission; the image-level stage below
# merges its predictions into `ss_df`.
ss_df = sub

In [None]:
!pip install /kaggle/input/kerasapplications -q
!pip install /kaggle/input/efficientnet-keras-source-code/ -q --no-deps

In [None]:
import numpy as np
import pandas as pd
print("\n... INSTALLING AND IMPORTING CELL-PROFILER TOOL (HPACELLSEG) ...\n")
try:
    import hpacellseg.cellsegmentator as cellsegmentator
    from hpacellseg.utils import label_cell
except:
    !pip install -q "/kaggle/input/pycocotools/pycocotools-2.0-cp37-cp37m-linux_x86_64.whl"
    !pip install -q "/kaggle/input/hpapytorchzoozip/pytorch_zoo-master"
    !pip install -q "/kaggle/input/hpacellsegmentatormaster/HPA-Cell-Segmentation-master"
    import hpacellseg.cellsegmentator as cellsegmentator
    from hpacellseg.utils import label_cell

print("\n... OTHER IMPORTS STARTING ...\n")
print("\n\tVERSION INFORMATION")

# Machine Learning and Data Science Imports
import tensorflow as tf; print(f"\t\t– TENSORFLOW VERSION: {tf.__version__}");
import pandas as pd; pd.options.mode.chained_assignment = None;
import numpy as np; print(f"\t\t– NUMPY VERSION: {np.__version__}");
import torch

import pandas as pd
import os

import efficientnet.tfkeras as efn
import numpy as np
import pandas as pd
import tensorflow as tf

# Built In Imports
from collections import Counter
from datetime import datetime
import multiprocessing
from glob import glob
import warnings
import requests
import imageio
import IPython
import urllib
import zipfile
import pickle
import random
import shutil
import string
import math
import tqdm
import time
import gzip
import sys
import ast
import csv; csv.field_size_limit(sys.maxsize)
import io
import os
import gc
import re

# Visualization Imports
from matplotlib.colors import ListedColormap
import matplotlib.patches as patches
import plotly.graph_objects as go
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
import plotly.express as px
import seaborn as sns
from PIL import Image
import matplotlib; print(f"\t\t– MATPLOTLIB VERSION: {matplotlib.__version__}");
import plotly
import PIL
import cv2

# Submission Imports
from pycocotools import _mask as coco_mask
import typing as t
import base64
import zlib

# PRESETS
# Class names, indexed by class id.
LBL_NAMES = [
    "Nucleoplasm", "Nuclear Membrane", "Nucleoli", "Nucleoli Fibrillar Center",
    "Nuclear Speckles", "Nuclear Bodies", "Endoplasmic Reticulum", "Golgi Apparatus",
    "Intermediate Filaments", "Actin Filaments", "Microtubules", "Mitotic Spindle",
    "Centrosome", "Plasma Membrane", "Mitochondria", "Aggresome", "Cytosol",
    "Vesicles", "Negative",
]
# id -> name and name -> id lookups, in original and lower_snake_case spelling.
INT_2_STR = {idx: LBL_NAMES[idx] for idx in np.arange(19)}
INT_2_STR_LOWER = {
    idx: name.lower().replace(" ", "_") for idx, name in INT_2_STR.items()
}
STR_2_INT_LOWER = {name: idx for idx, name in INT_2_STR_LOWER.items()}
STR_2_INT = {name: idx for idx, name in INT_2_STR.items()}
# Plot styling: shared font and one "Spectral" palette colour per class.
FIG_FONT = dict(family="Helvetica, Arial", size=14, color="#7f7f7f")
LABEL_COLORS = [
    px.colors.label_rgb(px.colors.convert_to_RGB_255(rgb))
    for rgb in sns.color_palette("Spectral", len(LBL_NAMES))
]
LABEL_COL_MAP = {str(pos): colour for pos, colour in enumerate(LABEL_COLORS)}

print("\n\n... IMPORTS COMPLETE ...\n")

##### THIS IS FOR PROTOTYPING AND PUBLIC LB PROBING #####
ONLY_PUBLIC = True
##### THIS IS FOR PROTOTYPING AND PUBLIC LB PROBING#####

# When ONLY_PUBLIC is True this cell only prints a message; the TensorFlow
# memory-growth setup below is skipped.
if ONLY_PUBLIC:
    print("\n... ONLY INFERRING ON PUBLIC TEST DATA (USING PRE-PROCESSED DF) ...\n")
else:
    # Stop Tensorflow From Eating All The Memory
    gpus = tf.config.experimental.list_physical_devices('GPU')
    if gpus:
        try:
            # Currently, memory growth needs to be the same across GPUs
            for gpu in gpus:
                tf.config.experimental.set_memory_growth(gpu, True)
            logical_gpus = tf.config.experimental.list_logical_devices('GPU')
            print(len(gpus), "... Physical GPUs,", len(logical_gpus), "Logical GPUs ...\n")
        except RuntimeError as e:
            # Memory growth must be set before GPUs have been initialized
            print(e)

In [None]:
# Image-level stage: predict 19 class scores per image from the green channel
# and blend them into the per-cell confidences stored in `ss_df`.
#
# This cell used to branch on `sub_df.shape[0] != 559`, but both branches held
# the same code. Their only difference was the source of the ID column
# (sample submission vs. a copy of `ss_df`), and `ss_df` carries the sample
# submission's IDs in the same order, so one copy of the pipeline is enough.
sub_df = pd.read_csv('../input/hpa-single-cell-image-classification/sample_submission.csv')


def auto_select_accelerator():
    """Return a TPU distribution strategy when a TPU cluster can be resolved,
    otherwise TensorFlow's default strategy."""
    try:
        tpu = tf.distribute.cluster_resolver.TPUClusterResolver()
        tf.config.experimental_connect_to_cluster(tpu)
        tf.tpu.experimental.initialize_tpu_system(tpu)
        strategy = tf.distribute.experimental.TPUStrategy(tpu)
        print("Running on TPU:", tpu.master())
    except ValueError:
        strategy = tf.distribute.get_strategy()
    print(f"Running on {strategy.num_replicas_in_sync} replicas")

    return strategy


def build_decoder(with_labels=True, target_size=(300, 300), ext='jpg'):
    """Build a function that reads an image file, decodes it to 3 channels,
    scales it to [0, 1] and resizes it to ``target_size``.

    With ``with_labels=True`` the returned function takes ``(path, label)``
    and passes the label through; otherwise it takes only ``path``.
    Decoding an unsupported ``ext`` raises ``ValueError``.
    """
    def decode(path):
        file_bytes = tf.io.read_file(path)
        if ext == 'png':
            img = tf.image.decode_png(file_bytes, channels=3)
        elif ext in ['jpg', 'jpeg']:
            img = tf.image.decode_jpeg(file_bytes, channels=3)
        else:
            raise ValueError("Image extension not supported")

        img = tf.cast(img, tf.float32) / 255.0
        img = tf.image.resize(img, target_size)

        return img

    def decode_with_labels(path, label):
        return decode(path), label

    return decode_with_labels if with_labels else decode


def build_augmenter(with_labels=True):
    """Build a random horizontal/vertical flip augmentation function
    (label-aware when ``with_labels=True``)."""
    def augment(img):
        img = tf.image.random_flip_left_right(img)
        img = tf.image.random_flip_up_down(img)
        return img

    def augment_with_labels(img, label):
        return augment(img), label

    return augment_with_labels if with_labels else augment


def build_dataset(paths, labels=None, bsize=32, cache=True,
                  decode_fn=None, augment_fn=None,
                  augment=True, repeat=True, shuffle=1024,
                  cache_dir=""):
    """Assemble a batched, prefetched ``tf.data`` pipeline over ``paths``.

    Steps, in order: decode, optional cache, optional augmentation, optional
    repeat, optional shuffle (``shuffle`` is the buffer size; a falsy value
    disables it), batch of ``bsize``, prefetch.
    """
    if cache_dir != "" and cache is True:
        os.makedirs(cache_dir, exist_ok=True)

    if decode_fn is None:
        decode_fn = build_decoder(labels is not None)

    if augment_fn is None:
        augment_fn = build_augmenter(labels is not None)

    AUTO = tf.data.experimental.AUTOTUNE
    slices = paths if labels is None else (paths, labels)

    dset = tf.data.Dataset.from_tensor_slices(slices)
    dset = dset.map(decode_fn, num_parallel_calls=AUTO)
    dset = dset.cache(cache_dir) if cache else dset
    dset = dset.map(augment_fn, num_parallel_calls=AUTO) if augment else dset
    dset = dset.repeat() if repeat else dset
    dset = dset.shuffle(shuffle) if shuffle else dset
    dset = dset.batch(bsize).prefetch(AUTO)

    return dset


COMPETITION_NAME = "hpa-single-cell-image-classification"
strategy = auto_select_accelerator()
BATCH_SIZE = strategy.num_replicas_in_sync * 16

IMSIZE = (224, 240, 260, 300, 380, 456, 528, 600)

load_dir = f"/kaggle/input/{COMPETITION_NAME}/"

# Keep only the first (ID) column and add one zero-filled score column per
# class, named '0' .. '18'.
sub_df = sub_df.drop(sub_df.columns[1:], axis=1)

for i in range(19):
    sub_df[f'{i}'] = pd.Series(np.zeros(sub_df.shape[0]))


test_paths = load_dir + "/test/" + sub_df['ID'] + '_green.png'
# Get the multi-labels
label_cols = sub_df.columns[1:]

test_decoder = build_decoder(with_labels=False, target_size=(IMSIZE[7], IMSIZE[7]))
dtest = build_dataset(
    test_paths, bsize=BATCH_SIZE, repeat=False,
    shuffle=False, augment=False, cache=False,
    decode_fn=test_decoder
)

with strategy.scope():
    model = tf.keras.models.load_model(
        '../input/hpa-classification-efnb7-train/model_green.h5'
    )

model.summary()
sub_df[label_cols] = model.predict(dtest, verbose=1)

ss_df = pd.merge(ss_df, sub_df, on = 'ID', how = 'left')

# Blend weights: image-level score vs. per-cell confidence.
IMAGE_WEIGHT = 0.6
CELL_WEIGHT = 0.4
# Sample-submission placeholder string; rows still holding it are left untouched.
EMPTY_PREDICTION = '0 1 eNoLCAgIMAEABJkBdQ=='

for i in range(ss_df.shape[0]):
    a = ss_df.loc[i,'PredictionString']
    if a == EMPTY_PREDICTION:
        continue
    # Tokens come in triples: class id, confidence, encoded mask.
    b = a.split()
    for j in range(len(b) // 3):
        k = int(b[3 * j])
        # The class id names its score column directly; ids outside 0..18
        # are left unchanged, as the former 19-way inner loop did.
        if 0 <= k < 19:
            c = b[3 * j + 1]
            b[3 * j + 1] = str(ss_df.loc[i,f'{k}'] * IMAGE_WEIGHT + float(c) * CELL_WEIGHT)

    ss_df.loc[i,'PredictionString'] = ' '.join(b)

ss_df = ss_df[['ID','ImageWidth','ImageHeight','PredictionString']]
ss_df.to_csv('submission.csv',index = False)