# Introduction
This notebook series compares the score obtained when a separate model is prepared for each cell type against the score obtained without doing so.
1. Build cell type classifier (this notebook)
1. Build segmentation model (shown in <a href=https://www.kaggle.com/yoshikuwano/classified-by-cell-types-before-segmentation-2-2>2/2 notebook</a>)

# Imports

In [None]:
import os, glob, warnings, random
import cv2
import numpy as np
import pandas as pd
from tqdm.notebook import tqdm
from pprint import pprint
from sklearn.model_selection import train_test_split

import torch
from torch.utils.data import DataLoader, Dataset

from albumentations import Normalize, Resize, Compose
from albumentations.pytorch import ToTensorV2

warnings.filterwarnings("ignore")

def fix_all_seeds(seed):
    """Seed every random number generator this notebook relies on.

    Seeds Python's ``random`` module, NumPy's global generator and the
    PyTorch CPU/CUDA generators with the same value, and exports the value
    as the ``PYTHONHASHSEED`` environment variable.

    Args:
        seed: Integer seed applied to all generators.
    """
    # NOTE(review): setting PYTHONHASHSEED here presumably only matters for
    # processes started afterwards — confirm that is the intent.
    os.environ['PYTHONHASHSEED'] = str(seed)

    random.seed(seed)
    np.random.seed(seed)

    # Same seed for the default torch generator and the CUDA generator(s).
    torch_seeders = (
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in torch_seeders:
        apply_seed(seed)

# Seed all generators once, up front, with the fixed value 2021.
fix_all_seeds(2021)

# Parameters

In [None]:
# Root directory of the competition input data.
INPUT_PATH = "../input/sartorius-cell-instance-segmentation"

# Use the GPU when one is visible to torch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")
print('DEVICE: ', DEVICE)

# Input Data

In [None]:
# Load train.csv from the input directory.
df_train = pd.read_csv(INPUT_PATH + '/train.csv')
# Collapse to one row per image id, keeping the first cell_type seen for that
# id; reset_index() turns "id" back into a regular column.
df_train = df_train.groupby("id")[['cell_type']].first().reset_index()
display(df_train)

In [None]:
# Build one frame of (image_path, cell_type) covering both the labelled
# train images and the train_semi_supervised images.

# images
train_image_paths = [INPUT_PATH + f'/train/{i}.png' for i in df_train['id']]
# glob.glob returns matches in arbitrary (filesystem-dependent) order; sort so
# the row order of `df` -- and any seeded split made from it -- is stable.
semi_image_paths = sorted(glob.glob(INPUT_PATH + '/train_semi_supervised/*.png'))
train_image_paths.extend(semi_image_paths)

# labels
train_labels = df_train['cell_type'].to_list()
# The label of a semi-supervised image is the part of its file name before
# the first '['. os.path.basename handles the platform's path separator.
semi_labels = [os.path.basename(path).split('[')[0] for path in semi_image_paths]
# File names use 'astros'; map it to 'astro' so it matches the other labels.
semi_labels = ['astro' if label == 'astros' else label for label in semi_labels]
train_labels.extend(semi_labels)

df = pd.DataFrame({'image_path': train_image_paths, 'cell_type': train_labels})
display(df)

##  Dataset

In [None]:
# Target (height, width) passed to Resize.
IMAGE_RESIZE = (224, 224)
# Per-channel mean / std passed to Normalize.
RESNET_MEAN = (0.485, 0.456, 0.406)
RESNET_STD = (0.229, 0.224, 0.225)


class DatasetImageCelltype(Dataset):
    """Dataset yielding a transformed image and its integer cell-type id.

    Args:
        df: DataFrame with an ``image_path`` column (file paths readable by
            ``cv2.imread``) and a ``cell_type`` column whose values are among
            ``LABELS``.

    Each item is a dict ``{'image': <transformed image>, 'label': <int>}``.
    """

    # The position of a name in this tuple is its integer class id.
    LABELS = ('shsy5y', 'astro', 'cort')

    def __init__(self, df):
        self.df = df
        self.image_paths = df['image_path']
        self.labels = df['cell_type']
        # Build the pipeline once instead of on every __getitem__ call.
        self.transforms = Compose([Resize(IMAGE_RESIZE[0], IMAGE_RESIZE[1]),
                                   Normalize(mean=RESNET_MEAN, std=RESNET_STD, p=1),
                                   ToTensorV2()])

    def __getitem__(self, idx):
        """Return the item at positional index ``idx``.

        Raises:
            FileNotFoundError: if the image file cannot be read.
            ValueError: if the row's cell type is not in ``LABELS``.
        """
        # image
        image_path = self.image_paths.iloc[idx]
        # NOTE(review): cv2.imread yields channels in BGR order while the
        # normalisation constants look like RGB-ordered values -- confirm
        # whether a colour conversion is wanted (no effect on grayscale input).
        image = cv2.imread(image_path)
        if image is None:
            # cv2.imread signals failure by returning None rather than raising.
            raise FileNotFoundError(f'Could not read image: {image_path}')
        image = self.transforms(image=image)['image']
        # label
        label_id = self.LABELS.index(self.labels.iloc[idx])

        return {'image': image, 'label': label_id}

    def __len__(self):
        """Return the number of rows in the backing DataFrame."""
        return len(self.df)

In [None]:
# Split into train and validation
# Split into train and validation (80% / 20%).
# Pass an explicit random_state so the split does not depend on how much of
# NumPy's global random stream earlier cells happened to consume.
SPLIT_SEED = 2021
df_train, df_valid = train_test_split(df, test_size=0.20, random_state=SPLIT_SEED)

# Dataset
ds_train = DatasetImageCelltype(df_train)
ds_valid = DatasetImageCelltype(df_valid)
# Data loader: shuffle only the training set; validation keeps frame order.
dl_train = DataLoader(ds_train, batch_size=64, num_workers=0, pin_memory=True, shuffle=True)
dl_valid = DataLoader(ds_valid, batch_size=64, num_workers=0, pin_memory=True, shuffle=False)

print(f'Number of train dataset {len(ds_train)}')
print(f'Number of valid dataset {len(ds_valid)}')

# Classifier Model

## Modeling

In [None]:
from torchvision.models import resnet34

def Resnet(num_classes=3):
    """Return a pretrained ResNet-34 with a new classification head.

    Args:
        num_classes: Number of output classes of the replacement final
            layer. Defaults to 3.

    Returns:
        The torchvision ResNet-34 model whose ``fc`` layer has been replaced
        by a freshly initialised ``torch.nn.Linear``.
    """
    model = resnet34(pretrained=True)
    # Read the head's input width from the model instead of hard-coding it.
    model.fc = torch.nn.Linear(model.fc.in_features, num_classes)
    return model

model = Resnet()

## Train classifier

In [None]:
LEARNING_RATE = 5e-4
EPOCHS = 5

# Fine-tune `model` with Adam and cross-entropy loss, reporting the mean batch
# loss and the accuracy on the train and validation sets after every epoch.
model.to(DEVICE)
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
for epoch in range(1, EPOCHS + 1):
    print(f'Epoch: {epoch}/{EPOCHS}')

    # Training pass
    model.train()
    loss_train = 0.0
    correct_train = 0

    for data in tqdm(dl_train, total=len(dl_train), desc='[train]'):
        # Input
        images, labels = data['image'], data['label']
        images, labels = images.to(DEVICE), labels.to(DEVICE)
        # Forward
        optimizer.zero_grad()
        outputs = model(images)  # raw class scores (logits), one row per image
        loss = criterion(outputs, labels)
        # Back propagation
        loss.backward()
        optimizer.step()
        # Metric
        # Accumulate plain Python numbers: adding the loss tensor itself would
        # keep every batch's autograd graph alive until the end of the epoch.
        loss_train += loss.item()
        predictions = outputs.argmax(dim=1)  # predicted class index per image
        correct_train += (labels == predictions).sum().item()

    loss_train = loss_train / len(dl_train)
    acc_train = correct_train / len(ds_train)
    print(f'Train loss: {loss_train:.4f}, Train accuracy: {acc_train*100:.2f}%')

    # Validation
    model.eval()
    loss_valid = 0.0
    correct_valid = 0
    with torch.no_grad():
        for data in tqdm(dl_valid, total=len(dl_valid), desc='[valid]'):
            images, labels = data['image'], data['label']
            images, labels = images.to(DEVICE), labels.to(DEVICE)
            outputs = model(images)  # raw class scores (logits)
            loss_valid += criterion(outputs, labels).item()
            predictions = outputs.argmax(dim=1)  # predicted class index per image
            correct_valid += (labels == predictions).sum().item()

    loss_valid = loss_valid / len(dl_valid)
    acc_valid = correct_valid / len(ds_valid)
    print(f'Valid loss: {loss_valid:.4f}, Valid accuracy: {acc_valid*100:.2f}%\n')


In [None]:
# Save the whole model object (not just a state_dict) to the working directory.
# NOTE(review): 'crassifier' looks like a typo for 'classifier', but the
# companion notebook may load the file by this exact name -- confirm before
# renaming.
torch.save(model, 'resnet34_crassifier.bin')