# **Imports**

In [None]:
import os
import cv2
import warnings
import random
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split

import torch
from torch import nn
from torch import optim
from torch.nn import CrossEntropyLoss
from torch.utils.data import DataLoader, Dataset

from albumentations import Normalize, Resize, Compose
from albumentations.pytorch import ToTensorV2

warnings.filterwarnings("ignore")

def fix_all_seeds(seed):
    """Seed every random number generator used in this notebook.

    Seeds NumPy, Python's ``random`` module and PyTorch (CPU and CUDA), and
    additionally asks cuDNN for deterministic kernels so that GPU runs with
    the same seed are repeatable.

    Args:
        seed: Integer seed applied to all generators.
    """
    np.random.seed(seed)
    random.seed(seed)
    # PYTHONHASHSEED is read at interpreter start-up, so this only reaches
    # processes launched after this point, not the already-running kernel.
    os.environ['PYTHONHASHSEED'] = str(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Seeding alone does not make cuDNN reproducible: disable the autotuner
    # (which may pick different kernels per run) and force deterministic ones.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

fix_all_seeds(42)

In [None]:
# Every input path below lives under this one dataset directory.
_DATASET_DIR = '../input/sartorius-cell-instance-segmentation'

SAMPLE_SUBMISSION = f'{_DATASET_DIR}/sample_submission.csv'
TRAIN_CSV = f'{_DATASET_DIR}/train.csv'
TRAIN_PATH = f'{_DATASET_DIR}/train'
TEST_PATH = f'{_DATASET_DIR}/test'
EXTRA_DATA_PATH = f'{_DATASET_DIR}/train_semi_supervised'

In [None]:
df_train = pd.read_csv(TRAIN_CSV)

# Collapse to one row per image id, keeping the first cell_type listed for it.
df_class = df_train.groupby("id")[['cell_type']].first().reset_index()

# Stratify so train and validation keep the same class proportions, and pin
# random_state so re-running this cell reproduces the exact same split
# (otherwise a re-run could move validation images into the training set).
df_class_train, df_class_val = train_test_split(
    df_class,
    test_size=0.2,
    stratify=df_class['cell_type'],
    random_state=42,
)

# Class proportions side by side; only the last expression of a cell is
# displayed, so all three distributions are gathered into one table.
pd.DataFrame({
    'all': df_class['cell_type'].value_counts(normalize=True),
    'train': df_class_train['cell_type'].value_counts(normalize=True),
    'val': df_class_val['cell_type'].value_counts(normalize=True),
}).round(2)

# **EfficientNet Model**

In [None]:
!pip install efficientnet_pytorch

In [None]:
# Imported here rather than at the top because the package is installed by the
# preceding cell.
from efficientnet_pytorch import EfficientNet

# Pretrained EfficientNet-B3 with a 3-way classification output.
model = EfficientNet.from_pretrained('efficientnet-b3', num_classes=3)

# **Semi-Supervised Data**

In [None]:
class CellClassificationDatasetExtraData(Dataset):
    """Classification dataset over the image files found in ``EXTRA_DATA_PATH``.

    The class of each image is parsed from its file name: the text before the
    first ``[`` is the cell type (``astros`` is treated as ``astro``).

    Each item is a dict with the transformed ``image`` and an integer
    ``label`` that indexes into ``self.labels``.
    """

    def __init__(self):
        self.base_path = EXTRA_DATA_PATH
        self.transforms = Compose([
            Resize(224, 224),
            Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225), p=1),
            ToTensorV2()
        ])
        # os.listdir returns entries in arbitrary order; sorting keeps the
        # index -> file mapping identical from run to run.
        self.files = sorted(os.listdir(self.base_path))
        self.labels = ['shsy5y', 'astro', 'cort']

    def _label_index(self, file):
        """Return the index in ``self.labels`` of the cell type named by ``file``.

        Raises:
            ValueError: If the file-name prefix is not a known cell type.
        """
        label = file.split("[")[0]
        if label == 'astros':
            label = 'astro'
        return self.labels.index(label)

    def __getitem__(self, idx):
        file = self.files[idx]
        image_path = os.path.join(self.base_path, file)

        # cv2.imread signals failure by returning None instead of raising.
        image = cv2.imread(image_path)
        if image is None:
            raise FileNotFoundError(f"Could not read image: {image_path}")
        image = self.transforms(image=image)['image']

        return {'image': image, 'label': self._label_index(file)}

    def __len__(self):
        return len(self.files)

# **Dataset**

In [None]:
class CellClassificationDataset(Dataset):
    """Classification dataset over the labelled images in ``TRAIN_PATH``.

    Args:
        df: DataFrame with an ``id`` column (image file stem) and a
            ``cell_type`` column (one of ``self.labels``).
        image_size: Side length images are resized to. Defaults to 224 so the
            labelled images match the 224x224 resize applied to the extra
            (semi-supervised) images; the previous hard-coded 244 was
            inconsistent with that.

    Each item is a dict with the transformed ``image`` and an integer
    ``label`` that indexes into ``self.labels``.
    """

    def __init__(self, df, image_size=224):
        self.df = df
        self.base_path = TRAIN_PATH
        self.transforms = Compose([
            Resize(image_size, image_size),
            Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225), p=1),
            ToTensorV2()
        ])
        self.image_ids = df.id.unique().tolist()
        self.labels = ['shsy5y', 'astro', 'cort']

    def get_label_for_img(self, image_id):
        """Return the label index of the first row in ``self.df`` matching ``image_id``."""
        label = self.df.loc[self.df['id'] == image_id, 'cell_type'].iloc[0]
        return self.labels.index(label)

    def __getitem__(self, idx):
        image_id = self.image_ids[idx]
        image_path = os.path.join(self.base_path, image_id + ".png")

        # cv2.imread signals failure by returning None instead of raising.
        image = cv2.imread(image_path)
        if image is None:
            raise FileNotFoundError(f"Could not read image: {image_path}")
        image = self.transforms(image=image)['image']

        return {'image': image, 'label': self.get_label_for_img(image_id)}

    def __len__(self):
        return len(self.image_ids)

In [None]:
# Options common to all three loaders; only the batch size differs.
_loader_opts = dict(num_workers=8, pin_memory=True, shuffle=True)

# Labelled training split.
ds_train = CellClassificationDataset(df_class_train)
dl_train = DataLoader(ds_train, batch_size=64, **_loader_opts)

# Extra (semi-supervised) images, labelled via their file names.
ds_train_extra = CellClassificationDatasetExtraData()
dl_train_extra = DataLoader(ds_train_extra, batch_size=64, **_loader_opts)

# Held-out validation split.
ds_val = CellClassificationDataset(df_class_val)
dl_val = DataLoader(ds_val, batch_size=8, **_loader_opts)

# **Training**

In [None]:
# NOTE: EPOCHS is deliberately small here; run more epochs for a real training run.
LEARNING_RATE = 5e-4  # learning rate handed to the Adam optimizer
EPOCHS = 3  # number of iterations of the outer training loop

In [None]:
# Move the model to the GPU before the optimizer is constructed from its
# parameters in the next cell. Requires a CUDA device.
model.cuda()

In [None]:
# Denominators for averaging the metrics accumulated during an epoch.
n_samples_val = len(ds_val)                  # number of validation samples
# Number of validation batches: the length of the loader, not of the dataset
# (the dataset length counts samples and would shrink the averaged loss).
n_batches_val = len(dl_val)
n_batches_train = len(dl_train)              # number of training batches
n_batches_train_extra = len(dl_train_extra)  # number of extra-data batches

criterion = CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

In [None]:
def _run_training_pass(loader):
    """Train ``model`` for one pass over ``loader`` and return the mean batch loss."""
    running_loss = 0.0
    for batch in loader:
        images, labels = batch['image'].cuda(), batch['label'].cuda()

        # Clear gradients left over from the previous step, then update.
        optimizer.zero_grad()
        batch_loss = criterion(model(images), labels)
        batch_loss.backward()
        optimizer.step()

        running_loss += batch_loss.item()
    return running_loss / len(loader)


for epoch in range(1, EPOCHS + 1):
    print(f"Starting epoch: {epoch} / {EPOCHS}")

    model.train()

    # Train on the extra data first, then on the labelled training data.
    train_extra_loss = _run_training_pass(dl_train_extra)
    train_loss = _run_training_pass(dl_train)

    # Validate
    model.eval()
    val_loss_sum = 0.0
    correct = 0

    with torch.no_grad():
        for batch in dl_val:
            images, labels = batch['image'].cuda(), batch['label'].cuda()
            preds = model(images)
            correct += (preds.argmax(dim=1) == labels).sum().item()
            # criterion returns the batch mean; weight it by the batch size so
            # the epoch figure is a true per-sample average even when the last
            # batch is smaller. .item() keeps the accumulator a plain float.
            val_loss_sum += criterion(preds, labels).item() * labels.size(0)

    val_loss = val_loss_sum / n_samples_val
    acc = correct / n_samples_val

    print(f"Epoch: {epoch} - Train Extra Loss {train_extra_loss:.5f}. Train Loss {train_loss:.5f}. Val. Loss: {val_loss:.5f} Accuracy: {acc*100:.4f}%")

In [None]:
# Serializes the whole model object (not only its state_dict) to a file in the
# current working directory.
# NOTE(review): loading a fully pickled model presumably needs the
# efficientnet_pytorch package importable on the loading side — confirm with
# whatever consumes this file before switching to a state_dict.
torch.save(model, 'cell-classification-model.pth')