In [None]:
import sys
# Make the timm source tree under ../input importable (presumably a vendored copy so the
# notebook does not need to pip-install timm). Appended only once, so re-running is harmless.
timm_path = '../input/timm-pytorch-image-models/pytorch-image-models-master'
already_on_path = timm_path in sys.path
if not already_on_path:
    sys.path.append(timm_path)

In [None]:
import timm
import torch
import pytorch_lightning

# The notebook relies on APIs of these exact releases (e.g. Trainer(period=..., amp_backend=...)
# in PL 1.2); fail fast with a readable message when the environment differs.
assert timm.__version__ == '0.4.4', f'timm 0.4.4 required, found {timm.__version__}'
assert torch.__version__ == '1.7.0', f'torch 1.7.0 required, found {torch.__version__}'
assert pytorch_lightning.__version__ == '1.2.0', (
    f'pytorch_lightning 1.2.0 required, found {pytorch_lightning.__version__}')

# CONSTANTS

In [None]:
from pathlib import Path

# Side length (pixels) of the square images fed to the network (Resize / RandomResizedCrop).
IMAGE_SIZE = 512

# Label columns, in the order used for every target vector and for the 11 model outputs.
TARGET_COLS = ['ETT - Abnormal', 'ETT - Borderline',
       'ETT - Normal', 'NGT - Abnormal', 'NGT - Borderline',
       'NGT - Incompletely Imaged', 'NGT - Normal', 'CVC - Abnormal',
       'CVC - Borderline', 'CVC - Normal', 'Swan Ganz Catheter Present',
       ]

# Submission columns: the study id followed by the targets. Derived from TARGET_COLS so the
# two lists cannot drift apart (they used to be maintained by hand).
COLS = ['StudyInstanceUID'] + TARGET_COLS

# RGB colour painted over annotated points of each label (stage-1/2 "annotated" images).
COLOR_MAP = {'ETT - Abnormal': (255, 0, 0),
             'ETT - Borderline': (0, 255, 0),
             'ETT - Normal': (0, 0, 255),
             'NGT - Abnormal': (255, 255, 0),
             'NGT - Borderline': (255, 0, 255),
             'NGT - Incompletely Imaged': (0, 255, 255),
             'NGT - Normal': (128, 0, 0),
             'CVC - Abnormal': (0, 128, 0),
             'CVC - Borderline': (0, 0, 128),
             'CVC - Normal': (128, 128, 0),
             'Swan Ganz Catheter Present': (128, 0, 128),
            }
assert set(COLOR_MAP) == set(TARGET_COLS), 'every target label needs an annotation colour'

DATADIR = Path('../input/ranzcr-clip-catheter-line-classification')

TRAINDIR = DATADIR.joinpath('train')

# Derived from DATADIR like TRAINDIR instead of repeating the path as a string literal.
TESTDIR = DATADIR.joinpath('test')



In [None]:
# clean-up previous logs:

import os
import shutil

output_files = os.listdir('/kaggle/working')
print(output_files)

# Drop logs/checkpoints from a previous session; presumably so the hard-coded
# lightning_logs/version_0/1/2 paths used later refer to this run's three stages.
if 'lightning_logs' in output_files:
    shutil.rmtree(os.path.join('/kaggle/working', 'lightning_logs'))

# Re-list the directory to confirm the removal took effect.
output_files = os.listdir('/kaggle/working')
assert 'lightning_logs' not in output_files

# UTILS

In [None]:
from albumentations import (
    Compose, OneOf, Normalize, Resize, RandomResizedCrop, RandomCrop, HorizontalFlip, VerticalFlip, 
    RandomBrightness, RandomContrast, RandomBrightnessContrast, Rotate, ShiftScaleRotate, Cutout, 
    IAAAdditiveGaussianNoise, Transpose, HueSaturationValue, CoarseDropout
    )
from albumentations.pytorch import ToTensorV2
from albumentations import ImageOnlyTransform

In [None]:
def check_exists(path):
    """Guard against overwriting an existing file or directory.

    Returns None when ``path`` does not exist.

    Raises:
        ValueError: if ``path`` already exists.
    """
    if not os.path.exists(path):
        return
    raise ValueError("This path already exists")
        
        
def get_transforms(*, data):
    """Build the albumentations pipeline for one data split.

    Args:
        data: keyword-only. ``'train'`` for the augmenting pipeline; ``'valid'`` or
            ``'test'`` for a plain resize.

    Returns:
        A ``Compose`` that ends in ImageNet normalisation and ``ToTensorV2``. The train
        pipeline also registers ``image_annot`` as an additional image target. This lets
        S2_Dataset augment an image and its annotated copy in one call.

    Raises:
        ValueError: for any other ``data`` value. Previously ``None`` was returned, and the
            Datasets' ``if self.transform:`` check then silently skipped resizing and
            normalisation.
    """
    imagenet_mean = [0.485, 0.456, 0.406]
    imagenet_std = [0.229, 0.224, 0.225]

    if data == 'train':
        return Compose([
            RandomResizedCrop(IMAGE_SIZE, IMAGE_SIZE, scale=(0.85, 1.0)),
            HorizontalFlip(p=0.5),
            VerticalFlip(p=0.5),
            RandomBrightnessContrast(p=0.2, brightness_limit=(-0.2, 0.2), contrast_limit=(-0.2, 0.2)),
            HueSaturationValue(p=0.2, hue_shift_limit=0.2, sat_shift_limit=0.2, val_shift_limit=0.2),
            ShiftScaleRotate(p=0.2, shift_limit=0.0625, scale_limit=0.2, rotate_limit=20),
            CoarseDropout(p=0.2),
            Cutout(p=0.2, max_h_size=16, max_w_size=16, fill_value=(0., 0., 0.), num_holes=16),
            Normalize(mean=imagenet_mean, std=imagenet_std),
            ToTensorV2(),
        ], additional_targets={'image_annot': 'image'})

    if data in ('valid', 'test'):
        return Compose([
            Resize(IMAGE_SIZE, IMAGE_SIZE),
            Normalize(mean=imagenet_mean, std=imagenet_std),
            ToTensorV2(),
        ])

    raise ValueError(f"Unknown data split {data!r}; expected 'train', 'valid' or 'test'")

# DATASET

In [None]:
import os
import numpy as np
import random
import ast
import cv2
import torch
from torch.utils.data import Dataset
from torchvision import transforms
from PIL import Image

In [None]:
class S1_Dataset(Dataset):
    """Stage-1 (teacher) dataset: training images with annotation squares burned in.

    Adapted from Y.Nakama https://www.kaggle.com/yasufuminakama

    Each item is ``(file_name, image, label)``:
    - ``image`` is the RGB image from ``TRAINDIR``, with a filled square of
      ``COLOR_MAP[label]`` painted around every annotated point of the study. It is
      optionally passed through ``transform``.
    - ``label`` is the float target vector for ``TARGET_COLS``.
    """

    def __init__(self, df, df_annotations, annot_size=50, transform=None):
        self.df = df
        self.df_annotations = df_annotations
        self.annot_size = annot_size
        self.file_names = df['StudyInstanceUID'].values
        self.labels = df[TARGET_COLS].values
        self.transform = transform

    def __len__(self):
        return len(self.df)

    @staticmethod
    def _paint_annotations(image, annotations, annot_size, color_map):
        """Paint an ``annot_size``-wide square around each annotated point, in place.

        Args:
            image: HxWxC array, modified in place and returned.
            annotations: rows with a ``label`` (key into ``color_map``) and ``data``
                (a Python-literal string of ``[x, y, ...]`` points).
            annot_size: square side length in pixels.
            color_map: label -> fill colour.

        Square bounds are clamped at 0. A negative slice start would otherwise wrap to the
        far edge and silently drop boxes near the top/left border.
        """
        half = annot_size // 2
        for label, data in zip(annotations['label'], annotations['data']):
            color = color_map[label]
            for point in ast.literal_eval(data):
                x, y = int(point[0]), int(point[1])
                image[max(y - half, 0):max(y + half, 0),
                      max(x - half, 0):max(x + half, 0),
                      :] = color
        return image

    def __getitem__(self, idx):
        file_name = self.file_names[idx]
        file_path = f'{TRAINDIR}/{file_name}.jpg'
        image = cv2.imread(file_path)
        # cv2.imread returns None for a missing/unreadable file instead of raising.
        if image is None:
            raise FileNotFoundError(f'Could not read image: {file_path}')
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        # Boolean mask instead of a string-built DataFrame.query expression.
        study_annots = self.df_annotations[self.df_annotations['StudyInstanceUID'] == file_name]
        self._paint_annotations(image, study_annots, self.annot_size, COLOR_MAP)
        if self.transform:
            augmented = self.transform(image=image)
            image = augmented['image']
        label = torch.tensor(self.labels[idx]).float()
        return file_name, image, label
    
    
class S2_Dataset(Dataset):
    """Stage-2 (distillation) dataset.

    Adapted from Y.Nakama https://www.kaggle.com/yasufuminakama

    With ``use_annot=True`` each item is ``(file_name, image, image_annot, labels)``.
    ``image_annot`` is a copy of ``image`` with annotation squares painted on. Both are
    passed to a single ``transform`` call (as ``image`` and ``image_annot``), so the
    transform must declare ``image_annot`` as an additional target;
    ``get_transforms(data='train')`` does.

    With ``use_annot=False`` items are ``(file_name, image, labels)``.
    """

    def __init__(self, df, df_annotations, use_annot=False, annot_size=50, transform=None):
        self.df = df
        self.df_annotations = df_annotations
        self.use_annot = use_annot
        self.annot_size = annot_size
        self.file_names = df['StudyInstanceUID'].values
        self.labels = df[TARGET_COLS].values
        self.transform = transform

    def __len__(self):
        return len(self.df)

    @staticmethod
    def _paint_annotations(image, annotations, annot_size, color_map):
        """Paint an ``annot_size``-wide square around each annotated point, in place.

        Args:
            image: HxWxC array, modified in place and returned.
            annotations: rows with a ``label`` (key into ``color_map``) and ``data``
                (a Python-literal string of ``[x, y, ...]`` points).
            annot_size: square side length in pixels.
            color_map: label -> fill colour.

        Square bounds are clamped at 0. A negative slice start would otherwise wrap to the
        far edge and silently drop boxes near the top/left border.
        """
        half = annot_size // 2
        for label, data in zip(annotations['label'], annotations['data']):
            color = color_map[label]
            for point in ast.literal_eval(data):
                x, y = int(point[0]), int(point[1])
                image[max(y - half, 0):max(y + half, 0),
                      max(x - half, 0):max(x + half, 0),
                      :] = color
        return image

    def __getitem__(self, idx):
        file_name = self.file_names[idx]
        file_path = f'{TRAINDIR}/{file_name}.jpg'
        image = cv2.imread(file_path)
        # cv2.imread returns None for a missing/unreadable file instead of raising.
        if image is None:
            raise FileNotFoundError(f'Could not read image: {file_path}')
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        labels = torch.tensor(self.labels[idx]).float()
        if self.use_annot:
            image_annot = image.copy()
            # Boolean mask instead of a string-built DataFrame.query expression.
            study_annots = self.df_annotations[self.df_annotations['StudyInstanceUID'] == file_name]
            self._paint_annotations(image_annot, study_annots, self.annot_size, COLOR_MAP)
            if self.transform:
                augmented = self.transform(image=image, image_annot=image_annot)
                image = augmented['image']
                image_annot = augmented['image_annot']
            return file_name, image, image_annot, labels
        if self.transform:
            augmented = self.transform(image=image)
            image = augmented['image']
        return file_name, image, labels
    

class S3_Dataset(Dataset):
    """Stage-3 dataset: plain training images (no annotations) with their labels.

    Items are ``(file_name, image, label)``. ``image`` is the RGB image from ``TRAINDIR``,
    optionally transformed; ``label`` is the float target vector for ``TARGET_COLS``.
    """

    def __init__(self, df, transform=None):
        self.df = df
        self.file_names = df['StudyInstanceUID'].values
        self.labels = df[TARGET_COLS].values
        self.transform = transform

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        file_name = self.file_names[idx]
        file_path = f'{TRAINDIR}/{file_name}.jpg'
        image = cv2.imread(file_path)
        # cv2.imread returns None for a missing/unreadable file. Without this check that
        # surfaces as an opaque error from cvtColor.
        if image is None:
            raise FileNotFoundError(f'Could not read image: {file_path}')
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        if self.transform:
            augmented = self.transform(image=image)
            image = augmented['image']
        label = torch.tensor(self.labels[idx]).float()
        return file_name, image, label
    
class TestDataset(Dataset):
    """Unlabelled test images from ``TESTDIR``.

    Items are ``(file_name, image)``, where ``file_name`` is the directory entry
    (extension included) and ``image`` is the RGB image, optionally transformed.
    """

    def __init__(self, transform):
        # Sorted so the item order, and hence the submission row order, is deterministic;
        # os.listdir() returns entries in arbitrary order.
        self.images = np.array(sorted(os.listdir(TESTDIR)))
        self.transform = transform

    def __getitem__(self, index):
        ID = self.images[index]
        file_name = os.path.join(TESTDIR, ID)
        image = cv2.imread(file_name)
        # cv2.imread returns None for a missing/unreadable file instead of raising.
        if image is None:
            raise FileNotFoundError(f'Could not read image: {file_name}')
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        if self.transform:
            augmented = self.transform(image=image)
            image = augmented['image']
        return ID, image

    def __len__(self):
        return len(self.images)

# DataModule

In [None]:
import os
import subprocess
import pandas as pd
import torch
import torch.nn as nn
from torch.optim import Adam
from torch.utils.data import DataLoader

import torch.nn.functional as F
import pytorch_lightning as pl

import timm
import gc

#from IPython import embed

In [None]:
class CLiP_DataModule(pl.LightningDataModule):
    """Lightning DataModule serving the RANZCR-CLiP data for one training stage.

    ``stage`` selects the dataset pair built in ``setup``:
    - 1: S1_Dataset (annotations painted in);
    - 2: S2_Dataset (annotated copy for train only);
    - 3: S3_Dataset (plain images).

    Only studies listed in ``train_annotations.csv`` are used. They are split 70/30 by
    StudyInstanceUID into train and validation, seeded with ``hparams.seed`` (default 22117).
    ``hparams`` must provide ``batch_size`` and ``num_workers`` for the dataloaders.
    """

    def __init__(self, hparams, stage):
        super().__init__()
        self.hparams = hparams
        self.expected_files = ['train', 'train_annotations.csv', 'test']
        self.stage = stage

    def prepare_data(self):
        """Fetch the competition files / ResNet200D weights if they are missing from DATADIR."""
        files_on_disk = os.listdir(DATADIR)
        if not all(file in files_on_disk for file in self.expected_files):
            # Imported here: only needed when the data is not already mounted. The name was
            # previously used without ever being imported.
            import zipfile
            import kaggle
            competition = 'ranzcr-clip-catheter-line-classification'
            kaggle.api.authenticate()
            # This is a competition, not a dataset: dataset_download_files() expects an
            # 'owner/dataset' reference. The competition endpoint saves '<competition>.zip'.
            kaggle.api.competition_download_files(competition, path=str(DATADIR))
            with zipfile.ZipFile(DATADIR.joinpath(f'{competition}.zip')) as archive:
                archive.extractall(DATADIR)
        if 'resnet200d_ra2-bdba9bf9.pth' not in files_on_disk:
            # NOTE(review): CustomResNet200D loads these weights from
            # ../input/resnet200d-pretrained-weight/, not DATADIR, and wget failures are
            # ignored here. Confirm this download is still wanted.
            subprocess.run(['wget', '-P', DATADIR, 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet200d_ra2-bdba9bf9.pth'])

    def setup(self, step):
        """Build train/val/test datasets. ``step`` is Lightning's stage name and is unused."""
        df = pd.read_csv(DATADIR.joinpath('train.csv'))
        annots_df = pd.read_csv(DATADIR.joinpath('train_annotations.csv'))

        # Split by study rather than by annotation row. Sampling rows left each train study
        # with only a random subset of its annotated lines.
        seed = getattr(self.hparams, 'seed', 22117)
        study_ids = annots_df['StudyInstanceUID'].drop_duplicates()
        train_ids = study_ids.sample(frac=.7, random_state=seed)
        in_train = annots_df['StudyInstanceUID'].isin(train_ids)
        train_annots = annots_df[in_train].reset_index(drop=True)
        val_annots = annots_df[~in_train].reset_index(drop=True)

        train_df = df[df['StudyInstanceUID'].isin(train_annots['StudyInstanceUID'])]
        val_df = df[df['StudyInstanceUID'].isin(val_annots['StudyInstanceUID'])]

        if self.stage == 1:
            self.train_ds = S1_Dataset(train_df, train_annots, transform=get_transforms(data='train'))
            self.val_ds = S1_Dataset(val_df, val_annots, transform=get_transforms(data='valid'))
        elif self.stage == 2:
            self.train_ds = S2_Dataset(train_df, train_annots, use_annot=True, transform=get_transforms(data='train'))
            self.val_ds = S2_Dataset(
                df=val_df,
                df_annotations=val_annots,
                use_annot=False,
                transform=get_transforms(data='valid')
            )
        elif self.stage == 3:
            self.train_ds = S3_Dataset(train_df, transform=get_transforms(data='train'))
            self.val_ds = S3_Dataset(val_df, transform=get_transforms(data='valid'))
        else:
            raise ValueError(f'Unknown stage {self.stage!r}; expected 1, 2 or 3')

        self.test_ds = TestDataset(transform=get_transforms(data='test'))

    def train_dataloader(self):
        # Shuffle: train_df otherwise arrives in CSV order every epoch.
        return DataLoader(self.train_ds, batch_size=self.hparams.batch_size,
                          num_workers=self.hparams.num_workers, shuffle=True)

    def val_dataloader(self):
        return DataLoader(self.val_ds, batch_size=self.hparams.batch_size, num_workers=self.hparams.num_workers)

    def test_dataloader(self):
        return DataLoader(self.test_ds, batch_size=self.hparams.batch_size, num_workers=self.hparams.num_workers)

In [None]:
# Smoke test: build the stage-2 DataModule with a minimal stand-in for hparams
# (only batch_size; the dataloaders are not requested here).
hparams = type('hparams', (), {'batch_size': 1})

dm = CLiP_DataModule(hparams=hparams, stage=2)
dm.setup(step='fit')


In [None]:
# Fetch one validation item; S2_Dataset with use_annot=False yields (file_name, image, labels).
batch = dm.val_ds[0]
print(f'{len(batch)}')

# STAGE 1

In [None]:
class CustomResNet200D(nn.Module):
    """timm ResNet200D backbone with its own average-pooling + linear head.

    ``forward`` returns ``(features, pooled_features, logits)``:
    - ``features``: the backbone output with timm's pooling/classifier replaced by Identity;
    - ``pooled_features``: that output average-pooled and flattened to ``(batch, n_features)``;
    - ``logits``: ``num_classes`` outputs of the linear head.

    Args:
        hparams: read only for ``pretraining == 'Nakama'`` (``hparams.restore``, a checkpoint path).
        pretraining:
            - ``'image-net'`` loads ResNet200D ImageNet weights into the backbone;
            - ``'Nakama'`` restores the whole network (backbone and head) from a checkpoint
              written by an earlier stage of this notebook;
            - anything else keeps timm's random initialisation.
        model_name: timm architecture name.
        num_classes: size of the output head.
        ckpt_prefix: attribute prefix (e.g. ``'student.'``) to take from a Lightning
            checkpoint. By default, the first prefix of ``_CKPT_PREFIXES`` present is used.
    """

    # Attribute names under which this network is nested in the notebook's Lightning modules
    # (Student.student, YasufumiNet.net, Teacher.teacher / Student.teacher), in priority
    # order. A stage-2 checkpoint holds both teacher and student; later stages continue
    # from the student.
    _CKPT_PREFIXES = ('student.', 'net.', 'teacher.')

    def __init__(self, hparams, pretraining, model_name='resnet200d', num_classes=11, ckpt_prefix=None):
        super().__init__()
        self.hparams = hparams
        self.model = timm.create_model(model_name)
        if pretraining == 'image-net':
            pretrained_path = '../input/resnet200d-pretrained-weight/resnet200d_ra2-bdba9bf9.pth'
            self.model.load_state_dict(torch.load(pretrained_path, map_location='cpu'))
            print(f'load {model_name} pretrained model')
        n_features = self.model.fc.in_features
        # Strip timm's pooling and classifier; pooling and head are owned by this module.
        self.model.global_pool = nn.Identity()
        self.model.fc = nn.Identity()
        self.pooling = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(n_features, num_classes)
        if pretraining == 'Nakama':
            # Restore after the head exists, so the trained head is restored too. The old
            # code gave the raw Lightning checkpoint (keys 'epoch', 'state_dict', ...) to
            # self.model with strict=False; no key matched, so nothing was ever loaded.
            checkpoint = torch.load(hparams.restore, map_location='cpu')
            self.load_state_dict(self._extract_state_dict(checkpoint, ckpt_prefix))
            print(f'restored {model_name} from {hparams.restore}')

    @classmethod
    def _extract_state_dict(cls, checkpoint, prefix=None):
        """Return this network's state dict from a saved checkpoint.

        Accepts a Lightning checkpoint (weights under ``'state_dict'``) or a bare state
        dict. Takes the keys under the first matching prefix (``prefix`` if given, else
        ``_CKPT_PREFIXES``) and strips that prefix. If no prefix matches, the dict is
        returned unchanged.
        """
        state_dict = checkpoint.get('state_dict', checkpoint)
        candidates = (prefix,) if prefix else cls._CKPT_PREFIXES
        for candidate in candidates:
            sub = {key[len(candidate):]: value
                   for key, value in state_dict.items() if key.startswith(candidate)}
            if sub:
                return sub
        return dict(state_dict)

    def forward(self, x):
        bs = x.size(0)
        features = self.model(x)
        pooled_features = self.pooling(features).view(bs, -1)
        output = self.fc(pooled_features)
        return features, pooled_features, output
    
    
class Teacher(pl.LightningModule):
    """Stage 1: train an ImageNet-initialised ResNet200D on annotated images with BCE.

    Batches are ``(file_names, images, targets)`` with 3-channel images. Losses are logged
    as ``train_loss`` / ``val_loss``; Lightning also derives ``val_loss_epoch``, which the
    LR scheduler monitors.
    """

    def __init__(self, hparams):
        super(Teacher, self).__init__()
        print('\n\t**** STAGE 1 ****\n')
        self.hparams = hparams
        self.teacher = CustomResNet200D(hparams, pretraining='image-net')
        self.sig = nn.Sigmoid()  # not used by any step; kept as a module attribute

    def forward(self, x):
        # (features, pooled_features, logits) straight from CustomResNet200D
        return self.teacher.forward(x)

    def _bce_step(self, batch, log_key):
        """Forward one batch, compute BCE and log it under ``log_key``."""
        _, images, targets = batch
        assert images.shape[1] == 3
        _, _, logits = self.forward(images)
        step_loss = self.loss(logits, targets)
        self.log(log_key, step_loss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
        return step_loss

    def training_step(self, batch, batch_idx):
        return self._bce_step(batch, 'train_loss')

    def validation_step(self, batch, batch_idx):
        step_loss = self._bce_step(batch, 'val_loss')
        gc.collect() #pytorch/issues/40911
        return step_loss

    def loss(self, y_hat, y):
        return F.binary_cross_entropy_with_logits(y_hat, y)

    def configure_optimizers(self):
        optimiser = Adam(self.parameters(), lr=self.hparams.lr,
                         betas=(0.9, 0.999), eps=1e-8, amsgrad=False)
        plateau = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimiser, mode='min', factor=0.1, patience=2, verbose=True,
            threshold=1e-04, threshold_mode='rel', cooldown=0, min_lr=1e-12, eps=1e-13)
        return [optimiser], [{'scheduler': plateau, 'monitor': 'val_loss_epoch'}]

In [None]:
from datetime import datetime
from torch.utils.data import DataLoader, random_split

import pytorch_lightning as pl
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
from pytorch_lightning.utilities.parsing import AttributeDict

from IPython import embed

In [None]:
def s1(hparams):
    """Train the stage-1 ``Teacher`` on images with annotations painted in.

    Args:
        hparams: AttributeDict with:
            - trainer settings: ``accel``, ``grad_cum``, ``autolr``,
              ``auto_scale_batch_size``, ``check_val_n``, ``gpus``, ``max_epochs``,
              ``precision``;
            - settings read by the DataModule and model: ``batch_size``,
              ``num_workers``, ``lr``;
            - optional ``seed``, passed to ``pl.seed_everything``;
            - optional ``min_delta``, the early-stopping threshold (default 1e-4).

    Checkpoints: weights only, top 5 by ``val_loss_epoch``, in the trainer's default log
    directory (``dirpath=None``).
    """
    # hparams.seed used to be defined but never applied.
    seed = getattr(hparams, 'seed', None)
    if seed is not None:
        pl.seed_everything(seed)

    early_stop_callback = EarlyStopping(
            monitor='val_loss_epoch',
            # Was min_delta=1. A mean BCE loss below 1 can never improve by a full 1.0, so
            # early stopping fired `patience` epochs after the first validation no matter what.
            min_delta=getattr(hparams, 'min_delta', 1e-4),
            patience=4,
            verbose=True,
            mode='min',
            strict=False
            )

    ckpt_callback = ModelCheckpoint(
            dirpath=None,
            monitor='val_loss_epoch',
            verbose=1,
            save_top_k=5,
            save_weights_only=True,
            mode='min',
            period=1,
            filename='s1_{epoch}-{train_loss_epoch:.3f}-{val_loss_epoch:.3f}'
            )

    # NOTE(review): in PL 1.x auto_lr_find / auto_scale_batch_size appear to run only via
    # trainer.tune(), not trainer.fit(). Confirm, and either call tune() or drop the flags.
    trainer = Trainer(
        accelerator=hparams.accel,
        accumulate_grad_batches=hparams.grad_cum,
        amp_backend='native',
        auto_lr_find=hparams.autolr,
        auto_scale_batch_size=hparams.auto_scale_batch_size,
        benchmark=True,
        callbacks=[ckpt_callback, early_stop_callback],
        check_val_every_n_epoch=hparams.check_val_n,
        gpus=hparams.gpus,
        max_epochs=hparams.max_epochs,
        precision=hparams.precision,
        progress_bar_refresh_rate=100,
        )

    clip_data = CLiP_DataModule(hparams, stage=1)

    model = Teacher(hparams)

    trainer.fit(model, datamodule=clip_data)

In [None]:
# Stage-1 configuration: per-step batch size 1 with 16 accumulated steps, 16-bit AMP.
hparams = AttributeDict(
    accel=None,
    autolr=True,
    auto_scale_batch_size='binsearch',
    batch_size=1,
    check_val_n=1,
    dev=False,
    gpus=-1,  # -1 = all visible GPUs
    grad_cum=16,
    lr=0.0001,
    max_epochs=6,
    num_workers=0,
    train_path=TRAINDIR,
    pl_ver=pl.__version__,
    precision=16,
    seed=22117,
    stage=1,
)

s1(hparams)

# STAGE 2

In [None]:
class Student(pl.LightningModule):
    """Stage 2: distil the stage-1 teacher into an ImageNet-initialised student.

    The teacher (restored from ``hparams.restore``) sees the annotated image; the student
    sees the plain one. The student is trained on
    ``loss_w[0] * MSE(student features, teacher features) + loss_w[1] * BCE(student logits, y)``.
    Validation scores the student alone with BCE.

    The teacher is frozen: it gets no gradients, is left out of the optimiser and stays in
    eval mode. Previously the MSE term back-propagated into the teacher and Adam updated
    it, pulling the distillation target towards the student.
    """

    def __init__(self, hparams):
        super(Student, self).__init__()
        print('\n\t**** STAGE 2 ****\n')
        self.hparams = hparams
        self.teacher = CustomResNet200D(hparams, pretraining='Nakama')
        self.student = CustomResNet200D(hparams, pretraining='image-net')
        for param in self.teacher.parameters():
            param.requires_grad = False
        self.teacher.eval()

    def train(self, mode=True):
        # Lightning switches the whole module to train mode each epoch. Keep the frozen
        # teacher in eval mode so its BatchNorm statistics stay fixed.
        super().train(mode)
        self.teacher.eval()
        return self

    def forward(self, x, x_annot):
        with torch.no_grad():
            teacher_fts, _, teacher_output = self.teacher.forward(x_annot)
        student_fts, _, student_output = self.student.forward(x)
        return teacher_fts, teacher_output, student_fts, student_output

    def training_step(self, batch, batch_idx):
        p, x, x_annot, y = batch
        teacher_fts, teacher_output, student_fts, student_output = self.forward(x, x_annot)
        train_loss = self.student_teacher_loss(student_output, y, student_fts, teacher_fts)
        self.log('train_loss', train_loss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
        return train_loss

    def validation_step(self, batch, batch_idx):
        p, x, y = batch
        student_fts, _, student_output = self.student.forward(x)
        val_loss = self.bce_loss(student_output, y)
        self.log('val_loss', val_loss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
        gc.collect() #pytorch/issues/40911
        return val_loss

    def student_teacher_loss(self, y_hat, y, student_fts, teacher_fts):
        """Weighted sum of feature MSE (student vs. teacher) and label BCE, weights ``hparams.loss_w``."""
        mse_loss = F.mse_loss(student_fts.view(-1), teacher_fts.view(-1))
        bce_loss = F.binary_cross_entropy_with_logits(y_hat, y)
        loss = self.hparams.loss_w[0] * mse_loss + self.hparams.loss_w[1] * bce_loss
        return loss

    def bce_loss(self, y_hat, y):
        loss = F.binary_cross_entropy_with_logits(y_hat, y)
        return loss

    def configure_optimizers(self):
        # Only the student is optimised; the teacher is a fixed target.
        optimiser = Adam(
                self.student.parameters(), lr=self.hparams.lr,
                betas=(0.9, 0.999), eps=1e-8,
                amsgrad=False
                )

        scheduler = {
            'scheduler': torch.optim.lr_scheduler.ReduceLROnPlateau(
                optimiser,
                mode='min',
                factor=0.1,
                patience=2,
                verbose=True,
                threshold=1e-04,
                threshold_mode='rel',
                cooldown=0,
                min_lr=1e-12,
                eps=1e-13),
            'monitor': 'val_loss_epoch'
        }
        return [optimiser], [scheduler]

In [None]:
import re

# Restore stage 2 from the stage-1 checkpoint with the lowest validation loss.
# os.listdir() order is arbitrary, so `[-1]` used to pick any one of the top-k files.
# The loss is read from the 'val_loss_epoch=<x>' part of s1()'s filename template.
# (`last_ckpt` keeps its old name but now holds the best file.)
s1_ckpt_dir = './lightning_logs/version_0/checkpoints'
s1_ckpt_names = sorted(name for name in os.listdir(s1_ckpt_dir) if name.endswith('.ckpt'))
if not s1_ckpt_names:
    raise FileNotFoundError(f'No stage-1 checkpoints in {s1_ckpt_dir}')
last_ckpt = min(
    s1_ckpt_names,
    key=lambda name: float((re.findall(r'val_loss_epoch=(\d+\.\d+)', name) or ['inf'])[0]),
)
s1_ckpt_path = os.path.join(s1_ckpt_dir, last_ckpt)
print(s1_ckpt_path)

In [None]:
def s2(hparams):
    """Train the stage-2 ``Student`` (distillation from the stage-1 teacher).

    Args:
        hparams: AttributeDict with:
            - trainer settings: ``accel``, ``grad_cum``, ``autolr``,
              ``auto_scale_batch_size``, ``check_val_n``, ``gpus``, ``max_epochs``,
              ``precision``;
            - settings read by the DataModule and model: ``batch_size``,
              ``num_workers``, ``lr``, ``loss_w``, ``restore`` (teacher checkpoint);
            - optional ``seed``, passed to ``pl.seed_everything``;
            - optional ``min_delta``, the early-stopping threshold (default 1e-4).

    Checkpoints: weights only, top 5 by ``val_loss_epoch``, in the trainer's default log
    directory (``dirpath=None``).
    """
    # hparams.seed used to be defined but never applied.
    seed = getattr(hparams, 'seed', None)
    if seed is not None:
        pl.seed_everything(seed)

    early_stop_callback = EarlyStopping(
            monitor='val_loss_epoch',
            # Was min_delta=1. A mean BCE loss below 1 can never improve by a full 1.0, so
            # early stopping fired `patience` epochs after the first validation no matter what.
            min_delta=getattr(hparams, 'min_delta', 1e-4),
            patience=4,
            verbose=True,
            mode='min',
            strict=False
            )

    ckpt_callback = ModelCheckpoint(
            dirpath=None,
            monitor='val_loss_epoch',
            verbose=1,
            save_top_k=5,
            save_weights_only=True,
            mode='min',
            period=1,
            filename='s2_{epoch}-{train_loss_epoch:.3f}-{val_loss_epoch:.3f}'
            )

    # NOTE(review): in PL 1.x auto_lr_find / auto_scale_batch_size appear to run only via
    # trainer.tune(), not trainer.fit(). Confirm, and either call tune() or drop the flags.
    trainer = Trainer(
        accelerator=hparams.accel,
        accumulate_grad_batches=hparams.grad_cum,
        amp_backend='native',
        auto_lr_find=hparams.autolr,
        auto_scale_batch_size=hparams.auto_scale_batch_size,
        benchmark=True,
        callbacks=[ckpt_callback, early_stop_callback],
        check_val_every_n_epoch=hparams.check_val_n,
        gpus=hparams.gpus,
        max_epochs=hparams.max_epochs,
        precision=hparams.precision,
        progress_bar_refresh_rate=100,
        )

    clip_data = CLiP_DataModule(hparams, stage=2)

    model = Student(hparams)

    trainer.fit(model, datamodule=clip_data)
    

In [None]:
# Stage-2 configuration. The teacher is restored from the selected stage-1 checkpoint, and
# loss_w weights (feature MSE, label BCE) in Student.student_teacher_loss.
hparams = AttributeDict(
    accel=None,
    autolr=True,
    auto_scale_batch_size='binsearch',
    batch_size=1,
    check_val_n=1,
    dev=False,
    gpus=-1,  # -1 = all visible GPUs
    grad_cum=16,
    loss_w=(0.5, 1),
    lr=0.0001,
    max_epochs=5,
    num_workers=0,
    train_path=TRAINDIR,
    pl_ver=pl.__version__,
    precision=16,
    restore=s1_ckpt_path,
    seed=22117,
    stage=2,
)

s2(hparams)

# STAGE 3

In [None]:
class YasufumiNet(pl.LightningModule):
    """Stage 3: fine-tune one ResNet200D, restored from ``hparams.restore``, with BCE.

    Also used at inference time by ``load_ensemble``. Batches are
    ``(file_names, images, targets)``; losses are logged as ``train_loss`` / ``val_loss``.
    """

    def __init__(self, hparams):
        super(YasufumiNet, self).__init__()
        print('\n\t**** STAGE 3 ****\n')
        self.hparams = hparams
        self.net = CustomResNet200D(hparams, pretraining='Nakama')

    def forward(self, x):
        # (features, pooled_features, logits) straight from CustomResNet200D
        return self.net(x)

    def _bce_step(self, batch, log_key):
        """Forward one batch, compute BCE and log it under ``log_key``."""
        _, images, targets = batch
        _, _, logits = self.forward(images)
        step_loss = self.loss(logits, targets)
        self.log(log_key, step_loss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
        return step_loss

    def training_step(self, batch, batch_idx):
        return self._bce_step(batch, 'train_loss')

    def validation_step(self, batch, batch_idx):
        step_loss = self._bce_step(batch, 'val_loss')
        gc.collect() #pytorch/issues/40911
        return step_loss

    def loss(self, y_hat, y):
        return F.binary_cross_entropy_with_logits(y_hat, y)

    def configure_optimizers(self):
        optimiser = Adam(self.parameters(), lr=self.hparams.lr,
                         betas=(0.9, 0.999), eps=1e-8, amsgrad=False)
        plateau = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimiser, mode='min', factor=0.1, patience=2, verbose=True,
            threshold=1e-04, threshold_mode='rel', cooldown=0, min_lr=1e-12, eps=1e-13)
        return [optimiser], [{'scheduler': plateau, 'monitor': 'val_loss_epoch'}]

In [None]:
import re

# Start stage 3 from the stage-2 checkpoint with the lowest validation loss.
# os.listdir() order is arbitrary, so `[-1]` used to pick any one of the top-k files.
# The loss is read from the 'val_loss_epoch=<x>' part of s2()'s filename template.
# (`last_ckpt` keeps its old name but now holds the best file.)
s2_ckpt_dir = './lightning_logs/version_1/checkpoints'
s2_ckpt_names = sorted(name for name in os.listdir(s2_ckpt_dir) if name.endswith('.ckpt'))
if not s2_ckpt_names:
    raise FileNotFoundError(f'No stage-2 checkpoints in {s2_ckpt_dir}')
last_ckpt = min(
    s2_ckpt_names,
    key=lambda name: float((re.findall(r'val_loss_epoch=(\d+\.\d+)', name) or ['inf'])[0]),
)
s2_ckpt_path = os.path.join(s2_ckpt_dir, last_ckpt)
print(s2_ckpt_path)

In [None]:
def s3(hparams):
    """Train the stage-3 ``YasufumiNet`` on plain images, starting from a stage-2 checkpoint.

    Args:
        hparams: AttributeDict with:
            - trainer settings: ``accel``, ``grad_cum``, ``autolr``,
              ``auto_scale_batch_size``, ``check_val_n``, ``gpus``, ``max_epochs``,
              ``precision``;
            - settings read by the DataModule and model: ``batch_size``,
              ``num_workers``, ``lr``, ``restore``;
            - optional ``seed``, passed to ``pl.seed_everything``;
            - optional ``min_delta``, the early-stopping threshold (default 1e-4).

    Checkpoints: weights only, top 5 by ``val_loss_epoch``, in the trainer's default log
    directory (``dirpath=None``).
    """
    # hparams.seed used to be defined but never applied.
    seed = getattr(hparams, 'seed', None)
    if seed is not None:
        pl.seed_everything(seed)

    early_stop_callback = EarlyStopping(
            monitor='val_loss_epoch',
            # Was min_delta=1. A mean BCE loss below 1 can never improve by a full 1.0, so
            # early stopping fired `patience` epochs after the first validation no matter what.
            min_delta=getattr(hparams, 'min_delta', 1e-4),
            patience=4,
            verbose=True,
            mode='min',
            strict=False
            )

    ckpt_callback = ModelCheckpoint(
            dirpath=None,
            monitor='val_loss_epoch',
            verbose=1,
            save_top_k=5,
            save_weights_only=True,
            mode='min',
            period=1,
            filename='s3_{epoch}-{train_loss_epoch:.3f}-{val_loss_epoch:.3f}'
            )

    # NOTE(review): in PL 1.x auto_lr_find / auto_scale_batch_size appear to run only via
    # trainer.tune(), not trainer.fit(). Confirm, and either call tune() or drop the flags.
    trainer = Trainer(
        accelerator=hparams.accel,
        accumulate_grad_batches=hparams.grad_cum,
        amp_backend='native',
        auto_lr_find=hparams.autolr,
        auto_scale_batch_size=hparams.auto_scale_batch_size,
        benchmark=True,
        callbacks=[ckpt_callback, early_stop_callback],
        check_val_every_n_epoch=hparams.check_val_n,
        gpus=hparams.gpus,
        max_epochs=hparams.max_epochs,
        precision=hparams.precision,
        profiler=False,
        progress_bar_refresh_rate=100,
        )

    clip_data = CLiP_DataModule(hparams, stage=3)

    model = YasufumiNet(hparams)

    trainer.fit(model, datamodule=clip_data)

In [None]:
# Stage-3 configuration: fine-tune from the selected stage-2 checkpoint.
# loss_w is carried over from stage 2; YasufumiNet does not read it.
hparams = AttributeDict(
    accel=None,
    autolr=True,
    auto_scale_batch_size='binsearch',
    batch_size=1,
    check_val_n=1,
    dev=False,
    gpus=-1,  # -1 = all visible GPUs
    grad_cum=16,
    loss_w=(0.5, 1),
    lr=0.0001,
    max_epochs=7,
    num_workers=0,
    train_path=TRAINDIR,
    pl_ver=pl.__version__,
    precision=16,
    restore=s2_ckpt_path,
    seed=22117,
    stage=3,
)

s3(hparams)

# INFERENCE

In [None]:
import os
import gc
from tqdm import tqdm
#from IPython import embed

def load_ensemble(path, n, device):
    """Build up to ``n`` stage-3 models from the ``.ckpt`` files in directory ``path``.

    Args:
        path: checkpoint directory.
        n: maximum number of models (the first ``n`` in sorted filename order).
        device: torch device each model is moved to.

    Returns:
        dict ``{'model1': net, 'model2': net, ...}`` of ``YasufumiNet`` instances in eval
        mode. It has one entry per checkpoint actually found. The previous version padded
        missing ones with ``None``, which crashed inference.
    """
    ckpt_paths = sorted(
        os.path.join(path, name) for name in os.listdir(path) if name.endswith('.ckpt')
    )[:n]
    ensemble = {}
    for i, ckpt_path in enumerate(ckpt_paths, start=1):
        model = YasufumiNet(AttributeDict({'restore': ckpt_path}))
        model.to(device)
        model.eval()
        ensemble[f'model{i}'] = model
    return ensemble


def inference(args, device):
    """
    Target vector:
    # 'ETT - Abnormal',
    # 'ETT - Borderline',
    # 'ETT - Normal',
    # 'NGT - Abnormal',
    # 'NGT - Borderline',
    # 'NGT - Incompletely Imaged',
    # 'NGT - Normal',
    # 'CVC - Abnormal',
    # 'CVC - Borderline',
    # 'CVC - Normal',
    # 'Swan Ganz Catheter Present',
    """
    
    csv_path = os.path.join(args.csv_out_path, args.version)
    #check_exists(csv_path)
    if not os.path.exists(args.csv_out_path):
        os.makedirs(args.csv_out_path)
            
    ensemble_dic = load_ensemble(args.ckpt_dir, args.best_n, device)
    
    test_loader = DataLoader(TestDataset(transform=get_transforms(data='test')), batch_size=1, num_workers=0)
        
    sigmoid = torch.nn.Sigmoid()
    
    df = pd.DataFrame(columns=COLS)
    
    with torch.no_grad():
        
        for batch in tqdm(test_loader, desc='infering ensemble', miniters=50):
            ID, x = batch
            x = x.cuda()
            y_hats = []

            for model in ensemble_dic.values():
                model.eval()
                model.to(device)
                #embed()
                _, _, output = model(x)
                positive_probability = sigmoid(output)
                y_hats.append(positive_probability.clone().detach())

            if args.voting == 'mean':
                probs = torch.cat(y_hats).mean(axis=0)

            elif args.voting == 'max':
                probs = torch.cat(y_hats).max(-2)

            preds = probs.detach().cpu().apply_(lambda x: x > .5).type(torch.int).tolist()

            d = dict({'StudyInstanceUID': ID[0]})
            for k, v in zip(COLS[1:], preds):
                d[k] = v

            df = df.append(d, ignore_index=True)

            df.to_csv(csv_path, index=False)
            gc.collect()

In [None]:
# Stage-3 checkpoint directory; inference() below ensembles the checkpoints found here.
parent_path = './lightning_logs/version_2/checkpoints'
checkpoint = sorted(os.listdir(parent_path), reverse=True)[0]
path = os.path.join(parent_path, checkpoint)
print(path)

In [None]:
# Ensemble up to 5 stage-3 checkpoints (mean vote) into ./submission.csv.
# test_dir is informational: inference() reads TESTDIR through TestDataset.
args = AttributeDict(
    best_n=5,
    ckpt_dir=parent_path,
    csv_out_path='./',
    version='submission.csv',
    voting='mean',
    test_dir=TESTDIR,
)

device = torch.device('cuda:0')
inference(args, device=device)