In [None]:
import os
import cv2
import torch
import pandas as pd
import numpy as np
from torch.utils.data.dataset import Dataset

class CellDataset(Dataset):
    """Dataset of single-cell image tiles with 19 multi-label targets.

    Every row of the label table describes one tile stored at
    ``<data_dir>/cells/<image_id>_<cell_id>.jpg`` and carries one 0/1 column
    per class (columns ``'0'`` .. ``'18'``).

    Args:
        data_dir: Directory that contains the ``cells`` sub-directory.
        csv_file: An already-loaded ``pandas.DataFrame`` (not a path) with the
            columns ``image_id``, ``cell_id`` and ``'0'`` .. ``'18'``.
        transform: Optional callable invoked as ``transform(image=rgb_array)``
            that returns a mapping with an ``'image'`` entry (the convention
            used by the albumentations pipelines in this notebook). When
            ``None`` the RGB array is returned unchanged.
        max_samples: Optional cap on the reported length, for quick debugging
            runs. ``None`` (default) exposes every row.
    """

    # One target column per class, in class-index order.
    LABEL_COLUMNS = [str(i) for i in range(19)]

    def __init__(self, data_dir, csv_file, transform=None, max_samples=None):
        super().__init__()

        self.data_dir = data_dir
        self.df = csv_file
        self.transforms = transform
        self.max_samples = max_samples
        # Plain arrays are indexed positionally, so the frame's index labels
        # (e.g. after a train/valid split) do not matter.
        self.cell_types = self.df[self.LABEL_COLUMNS].values
        self.img_ids = self.df['image_id'].values
        self.cell_ids = self.df['cell_id'].values

    def __len__(self):
        """Return the number of rows, capped by ``max_samples`` when it is set."""
        total = len(self.img_ids)
        if self.max_samples is None:
            return total
        return min(total, self.max_samples)

    def get_image(self, index):
        """Load tile ``index`` as RGB and run it through the transform.

        Raises:
            FileNotFoundError: if the image file is missing or unreadable
                (``cv2.imread`` signals that by returning ``None``).
        """
        image_id = self.img_ids[index]
        cell_id = self.cell_ids[index]

        img_path = os.path.join(self.data_dir, 'cells', f'{image_id}_{cell_id}.jpg')

        img = cv2.imread(img_path)
        if img is None:
            raise FileNotFoundError(f'Could not read image: {img_path}')
        # OpenCV decodes to BGR channel order; convert to RGB.
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        if self.transforms is None:
            return img
        return self.transforms(image=img)['image']

    def __getitem__(self, index):
        """Return ``(image, target)`` where target is a float tensor of 19 flags."""
        x = self.get_image(index)
        y = torch.from_numpy(self.cell_types[index]).float()
        return x, y

        

In [None]:
import sys
# Extend the module search path with a local timm source tree before importing
# it (presumably an offline copy attached as a Kaggle input dataset — TODO confirm).
sys.path.append('../input/timm-pytorch-image-models/pytorch-image-models-master')
import timm

In [None]:
!pip install iterative-stratification

In [None]:
import torch
import torch.nn as nn
import torchvision.models as models
import timm

class Net(nn.Module):
    """Thin wrapper around a timm backbone with a classification head.

    Args:
        name: timm architecture name passed to ``timm.create_model``.
        num_classes: Size of the output layer (number of logits per sample).
        pretrained: Forwarded to ``timm.create_model``. Defaults to ``False``,
            which matches the previous hard-coded behaviour.
    """

    def __init__(self, name='resnet34', num_classes=19, pretrained=False):
        super().__init__()
        self.model = timm.create_model(name, pretrained=pretrained, num_classes=num_classes)

    def forward(self, x):
        """Return the backbone's output for the batch ``x``."""
        return self.model(x)

In [None]:
class Params:
    """Training settings shared by the notebook cells.

    Every value can be overridden by keyword; calling ``Params()`` with no
    arguments yields the same configuration as before.

    Args:
        epochs: Number of training epochs.
        batch_size: Batch size used by the data loaders.
        lr: Initial learning rate.
        n_workers: Number of DataLoader worker processes.
        data_dir: Root directory of the dataset.
    """

    def __init__(self,
                 epochs=2,
                 batch_size=32,
                 lr=2e-3,
                 n_workers=24,
                 data_dir='../input/hpa-cell-tiles-sample-balanced-dataset'):
        self.epochs = epochs
        self.batch_size = batch_size
        self.lr = lr
        self.n_workers = n_workers
        self.data_dir = data_dir

In [None]:
# Single shared configuration instance; later cells read it as the global `params`.
params = Params()

In [None]:
import pytorch_lightning as pl
from torch.utils.data import DataLoader
import torch
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import CSVLogger
from pytorch_lightning.utilities.seed import seed_everything
from argparse import ArgumentParser
from pytorch_lightning.callbacks import Callback
from tqdm import tqdm
from pytorch_lightning.callbacks.progress import ProgressBar
# from base_model import Net
# from dataset import CellDataset
from pytorch_lightning.callbacks import LearningRateMonitor
from pytorch_lightning.metrics import Accuracy
from pytorch_lightning.metrics import Recall
from pytorch_lightning.loggers import WandbLogger
import os
from iterstrat.ml_stratifiers import MultilabelStratifiedKFold
import albumentations as A
from albumentations.pytorch.transforms import ToTensorV2
import torchvision.transforms as transforms
import pandas as pd
import numpy as np

# Training pipeline: random rotation/flips for augmentation, then the same
# resize -> normalise -> tensor conversion that validation uses.
train_transforms = A.Compose([
    A.Rotate(),
    A.HorizontalFlip(p=0.5),
    A.VerticalFlip(p=0.5),
    A.Resize(width=224, height=224),
    A.Normalize(),
    ToTensorV2(),
])

# Validation pipeline: deterministic preprocessing only.
valid_transforms = A.Compose([
    A.Resize(width=224, height=224),
    A.Normalize(),
    # Must be an instance: passing the bare class makes Compose call the
    # constructor with the image data instead of applying the transform.
    ToTensorV2(),
])


# Read the label table from the same dataset root the image datasets use, so
# the CSV and the tiles cannot drift apart when `params.data_dir` changes.
df = pd.read_csv(os.path.join(params.data_dir, 'cell_df.csv'))

# Class ids are stored as '|'-separated strings in `image_labels`; expand them
# into one 0/1 column per class ('0' .. '18').
labels = [str(i) for i in range(19)]
# Split each label string once instead of once per class.
label_sets = df['image_labels'].map(lambda r: r.split('|'))
for lbl in labels:
    df[lbl] = label_sets.map(lambda parts, lbl=lbl: int(lbl in parts))

# Deterministic shuffle with a fresh 0..n-1 index.
dfs = df.sample(frac=1, random_state=42).reset_index(drop=True)

# Rows whose label string is exactly this one class (single-label cells).
unique_counts = {lbl: int((dfs['image_labels'] == lbl).sum()) for lbl in labels}

# Rows in which the class appears at all, alone or combined with others.
# The one-hot columns already hold this membership flag, so sum them instead
# of re-splitting every label string per class.
full_counts = {lbl: int(dfs[lbl].sum()) for lbl in labels}

# Sort by descending full count. The DataFrame is built straight from the
# tuples: routing them through np.array would coerce the counts to strings.
counts = sorted(
    zip(full_counts.keys(), full_counts.values(), unique_counts.values()),
    key=lambda row: -row[1],
)
counts = pd.DataFrame(counts, columns=['label', 'full_count', 'unique_count'])
# Only rendered by Jupyter when this is the last expression of a cell.
counts.set_index('label').T


nfold = 10
#seed = 42

# Multi-hot label matrix and the identifying columns handed to the splitter.
y = dfs[labels].values
X = dfs[['image_id', 'cell_id']].values

# Placeholder column; each row receives the index of the fold whose held-out
# part contains it.
dfs['fold'] = np.nan
fold_position = dfs.columns.get_loc('fold')

mskf = MultilabelStratifiedKFold(n_splits=nfold, shuffle=False)
for fold_id, (_, held_out_rows) in enumerate(mskf.split(X, y)):
    dfs.iloc[held_out_rows, fold_position] = fold_id

# The placeholder made the column float; store the fold ids as integers.
dfs['fold'] = dfs['fold'].astype(int)


# Fold 0 is held out for validation; all other folds are used for training.
# The mask is assigned in a single step: the chained form
# `dfs['is_valid'][mask] = True` writes through an intermediate Series, which
# pandas flags with SettingWithCopyWarning and which silently does nothing
# under copy-on-write.
dfs['is_valid'] = dfs['fold'] == 0

train_df = dfs[~dfs['is_valid']]
valid_df = dfs[dfs['is_valid']]
print(len(train_df))
print(len(valid_df))
# def collate_fn(batch):
#     batch = list(filter(lambda x: x is not None, batch))
#     return torch.utils.data.dataloader.default_collate(batch)


class HPALit(pl.LightningModule):
    """Lightning module that trains ``Net`` with a BCE-with-logits loss.

    Configuration and data come from notebook-level globals: ``params``,
    ``train_df``, ``valid_df``, ``train_transforms`` and ``valid_transforms``.
    Subset accuracy and recall are accumulated during each epoch and logged
    together with the mean loss at epoch end.
    """

    def __init__(self):
        super().__init__()

        self.lr = params.lr
        self.model = Net()
        self.criterion = torch.nn.BCEWithLogitsLoss()

        self.train_dataset = CellDataset(
            data_dir=params.data_dir, csv_file=train_df, transform=train_transforms)
        self.val_dataset = CellDataset(
            data_dir=params.data_dir, csv_file=valid_df, transform=valid_transforms)

        # Separate metric objects per phase so their running state never mixes.
        self.train_accuracy = Accuracy(subset_accuracy=True)
        self.val_accuracy = Accuracy(subset_accuracy=True)
        self.train_recall = Recall()
        self.val_recall = Recall()

    def _make_loader(self, dataset, shuffle):
        """Build a DataLoader with the batch size and worker count from ``params``."""
        return DataLoader(
            dataset,
            batch_size=params.batch_size,
            shuffle=shuffle,
            num_workers=params.n_workers,
            pin_memory=True,
        )

    def _loss_and_metrics(self, batch, accuracy, recall):
        """Run the model on ``batch``, update both metrics and return the loss.

        The loss is computed on the raw model output, while the metrics
        receive its sigmoid and the targets cast to ``torch.int``.
        """
        inputs, targets = batch
        logits = self.model(inputs)

        loss = self.criterion(logits, targets)

        probabilities = torch.sigmoid(logits)
        int_targets = targets.type(torch.int)
        accuracy(probabilities, int_targets)
        recall(probabilities, int_targets)
        return loss

    def train_dataloader(self):
        """Shuffled loader over the training dataset."""
        return self._make_loader(self.train_dataset, shuffle=True)

    def val_dataloader(self):
        """Unshuffled loader over the validation dataset."""
        return self._make_loader(self.val_dataset, shuffle=False)

    def configure_optimizers(self):
        """Adam plus a ReduceLROnPlateau scheduler monitoring ``valid_loss_epoch``."""
        optimizer = torch.optim.Adam(self.model.parameters(), lr=self.lr)
        plateau = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, mode='min', factor=0.5, patience=5, eps=1e-6)
        scheduler_config = {
            'scheduler': plateau,
            'interval': 'epoch',
            'monitor': 'valid_loss_epoch',
        }
        return [optimizer], [scheduler_config]

    def training_step(self, batch, batch_idx):
        """Compute the training loss for one batch and update the train metrics."""
        loss = self._loss_and_metrics(batch, self.train_accuracy, self.train_recall)
        return {'loss': loss}

    def training_epoch_end(self, outputs):
        """Log the mean training loss and epoch metrics, then reset the metrics."""
        mean_loss = torch.stack([step['loss'] for step in outputs]).mean()
        self.log('train_loss_epoch', mean_loss)
        self.log('train_acc_epoch', self.train_accuracy.compute())
        self.log('train_recall_epoch', self.train_recall.compute())
        self.train_accuracy.reset()
        self.train_recall.reset()

    def validation_step(self, batch, batch_idx):
        """Compute the validation loss for one batch and update the val metrics."""
        loss = self._loss_and_metrics(batch, self.val_accuracy, self.val_recall)
        return {'valid_loss': loss}

    def validation_epoch_end(self, outputs):
        """Log the mean validation loss and epoch metrics, then reset the metrics."""
        mean_loss = torch.stack([step['valid_loss'] for step in outputs]).mean()
        self.log('valid_loss_epoch', mean_loss)
        self.log('valid_acc_epoch', self.val_accuracy.compute())
        self.log('valid_recall_epoch', self.val_recall.compute())
        self.val_accuracy.reset()
        self.val_recall.reset()



class MyPrintingCallback(Callback):
    """Print the epoch-level loss, accuracy and recall after each validation epoch."""

    # (printed title, key in ``trainer.callback_metrics``), in print order.
    _REPORTED = (
        ('Train Loss', 'train_loss_epoch'),
        ('Val Loss', 'valid_loss_epoch'),
        ('Train Accuracy', 'train_acc_epoch'),
        ('Val Accuracy', 'valid_acc_epoch'),
        ('Train Recall', 'train_recall_epoch'),
        ('Val Recall', 'valid_recall_epoch'),
    )

    def on_validation_epoch_end(self, trainer, pl_module):
        """Print every reported metric, or ``n/a`` for one not logged yet.

        Train metrics can be absent when this hook fires (for example when
        validation runs before any training epoch has finished), so a missing
        key is reported instead of raising ``KeyError``.
        """
        metrics = trainer.callback_metrics
        print('\n')
        for title, key in self._REPORTED:
            value = metrics.get(key)
            if value is None:
                print('{}: n/a'.format(title))
            else:
                # float() accepts both 0-dim tensors and plain numbers.
                print('{}: {:.3f}'.format(title, float(value)))


class LitProgressBar(ProgressBar):
    """Progress bar whose validation bar is created in the disabled state."""

    def init_validation_tqdm(self):
        """Return a tqdm instance constructed with ``disable=True``."""
        return tqdm(disable=True)

In [None]:
# Seed Python, NumPy and PyTorch before the model is built so weight
# initialisation, shuffling and augmentation are repeatable across runs.
# 42 matches the random_state used when shuffling the label table.
seed_everything(42)

# save_top_k=-1 keeps a checkpoint for every epoch.
checkpoint = ModelCheckpoint(
    dirpath='logs/resnet34-224',
    filename='{epoch}-{valid_loss_epoch:.3f}',
    save_top_k=-1,
    verbose=False,
)
printer = MyPrintingCallback()
logger = CSVLogger(save_dir="logs/resnet34-224", name="text_logs")
# Created but not attached to the trainer (see the `logger=` argument below).
wandb_logger = WandbLogger(name='resnet34-224', project='polish-pipeline')
bar = LitProgressBar()
lr_monitor = LearningRateMonitor(logging_interval='step')
model = HPALit()
trainer = pl.Trainer(
    progress_bar_refresh_rate=1,
    max_epochs=params.epochs,
    callbacks=[checkpoint, printer, bar, lr_monitor],
    gradient_clip_val=1,
    logger=[logger],# wandb_logger],
    # Use one GPU when available; fall back to CPU instead of failing.
    gpus=1 if torch.cuda.is_available() else 0,
    # accelerator='ddp',
    num_sanity_val_steps=0,
    #auto_lr_find=True,
)
#trainer.tune(model)
trainer.fit(model)