<a href="https://colab.research.google.com/github/yseeker/pytorch_templates/blob/main/template_lightning.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!nvidia-smi

In [None]:
!pip install gwpy --quiet

In [None]:
%%capture
!pip install wandb

class CFG:
    """Experiment configuration read by the data module, model and trainer.

    All values are plain class attributes; nothing is instantiated.
    """
    # --- experiment / logging ---
    project_name = 'project name'
    wandb_note = ''
    colab_or_kaggle = 'kaggle'

    # --- model ---
    pretrained_model_name = 'efficientnetv2_rw_s'
    pretrained = True
    # Local weight file that LitNN loads when `pretrained` is False.
    pretrained_path = '../input/timm_weight/efficientnet_b0_ra-3dd342df.pth'
    # Backward-compatible alias: the attribute was originally misspelled, while
    # the model code reads `CFG.pretrained_path`. Keep both names resolvable.
    prettained_path = pretrained_path
    input_channels = 3
    out_dim = 1

    # Built from the fields above, so those must be defined before this line.
    wandb_exp_name = f'{pretrained_model_name}_{colab_or_kaggle}_{wandb_note}'

    # --- training loop ---
    batch_size = 16
    epochs = 5
    num_of_fold = 5
    seed = 42
    num_workers = 8
    # NOTE(review): `fp16` is not read anywhere in the visible code; the Trainer
    # is configured from `precision` below. Confirm which one is intended.
    fp16 = True
    checkpoint_path = ''
    mixup_alpha = 1.0
    lr = 5e-4

    # --- early stopping (declared once; the original listed these twice) ---
    patience_mode = 'max'
    patience = 3
    delta = 0.002

    # --- pl.Trainer options ---
    gpus = 1
    amp_backend = 'native'
    precision = 32
    enable_benchmarking = True
    auto_find_lr = False

In [None]:
# load data
%%capture
# !unzip "/content/drive/MyDrive/kaggle/input/project/data.zip" -d "/content"

In [None]:
import os
import json
# Read the Kaggle API credentials from Drive and expose them through the
# environment, then switch to the working directory on Drive.
# The context manager closes the file handle as soon as it has been parsed.
with open("/content/drive/My Drive/kaggle/kaggle.json", 'r') as f:
    json_data = json.load(f)  # parse the file as JSON
os.environ['KAGGLE_USERNAME'] = json_data['username']
os.environ['KAGGLE_KEY'] = json_data['key']
os.chdir("/content/drive/My Drive/kaggle/working")

In [None]:
!pip install wandb
!pip install pytorch_lightning
!pip install timm
import collections
import os
import random
import numpy as np
import pandas as pd
from PIL import Image
from matplotlib import pyplot as plt
import seaborn as sns
import plotly.express as px
from tqdm import tqdm
import cv2

from sklearn.model_selection import StratifiedKFold
from sklearn import metrics

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam, SGD
from torch.utils.data import Dataset, DataLoader
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts, CosineAnnealingLR, ReduceLROnPlateau, OneCycleLR
from torch.optim.optimizer import Optimizer
import torchvision.utils as vutils

import pytorch_lightning as pl
from pytorch_lightning import seed_everything
from pytorch_lightning.metrics.functional import accuracy, f1, auroc
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping,LearningRateMonitor
from pytorch_lightning.loggers import WandbLogger

import timm
import albumentations as A

import warnings
warnings.filterwarnings("ignore")


In [None]:
# Load the label table and derive each sample's .npy path from its id:
# the first character of the id selects the sub-directory.
df = pd.read_csv('../input/***/train_labels.csv')
df['img_path'] = df['id'].map(lambda x: f'../input/***train/{x[0]}/{x}.npy')

# Plain arrays of paths / labels, indexed by the fold splitter below.
X = df['img_path'].values
Y = df['target'].values

# Stratified splitter with CFG.num_of_fold folds (no shuffling requested).
skf = StratifiedKFold(n_splits=CFG.num_of_fold)

def set_seed(seed = 0):
    """Seed Python, NumPy and PyTorch RNGs and return a seeded RandomState.

    Also sets ``PYTHONHASHSEED`` in the environment and switches cuDNN to
    deterministic mode.

    Args:
        seed: Integer seed applied to every generator.

    Returns:
        A fresh ``np.random.RandomState`` initialised with ``seed``.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)

    # Global generators of the three libraries.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)

    torch.backends.cudnn.deterministic = True

    # Independent generator handed back to the caller.
    return np.random.RandomState(seed)
set_seed(42)

In [None]:
class ClassificationDataset():
    """Map-style dataset that loads ``.npy`` samples and standardises them.

    Args:
        image_paths: Indexable collection of paths to ``.npy`` files.
        targets: Indexable collection of labels, aligned with ``image_paths``.
        transform: Optional callable invoked as ``transform(image=arr)`` that
            returns a mapping with an ``"image"`` entry (the calling
            convention used by albumentations). ``None`` disables it.
    """

    def __init__(self, image_paths, targets, transform = None):
        self.image_paths = image_paths
        self.targets = targets
        # The original assigned None here, silently discarding the argument.
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    @staticmethod
    def _standardize(image, axis):
        """Return ``image`` with zero mean / unit std along ``axis``."""
        mean = np.mean(image, axis=axis, keepdims=True)
        std = np.std(image, axis=axis, keepdims=True)
        # A constant slice has std == 0; dividing by it would fill the sample
        # with NaN/inf. Dividing by 1 instead leaves that slice at zero.
        std[std == 0] = 1
        return (image - mean) / std

    def __getitem__(self, item):
        """Load sample ``item`` and return ``{'image': ..., 'targets': ...}``.

        Both values are float tensors; the image carries a leading axis of
        size 1 in front of the 2-D array.
        """
        targets = self.targets[item]

        # Keep every second entry along the first axis, stack the remainder
        # vertically, then swap the two axes of the resulting 2-D array.
        image1 = np.load(self.image_paths[item])[::2].astype(np.float32)
        image = np.vstack(image1).transpose((1, 0))

        # Standardise along axis 1 first, then along axis 0.
        image = self._standardize(image, axis=1)
        image = self._standardize(image, axis=0)
        image = image.astype(np.float32)

        if self.transform:
            # Applied to the 2-D array, before the leading axis is added.
            # NOTE(review): assumes the transform accepts a single-channel
            # (H, W) float32 array — confirm for the augmentations in use.
            image = self.transform(image=image)["image"]

        image = image[np.newaxis, ]

        return {'image' : torch.tensor(image, dtype=torch.float),
                'targets' : torch.tensor(targets, dtype=torch.float)}


class LitData(pl.LightningDataModule):
    """Data module holding one train/validation split of paths and labels.

    Args:
        train_images: Sample paths for the training subset.
        train_targets: Labels aligned with ``train_images``.
        valid_images: Sample paths for the validation subset.
        valid_targets: Labels aligned with ``valid_images``.
    """

    def __init__(self, train_images, train_targets, valid_images, valid_targets):
        super().__init__()
        self.train_images, self.train_targets = train_images, train_targets
        self.valid_images, self.valid_targets = valid_images, valid_targets

    def setup(self,stage=None):
        """Build the augmentation pipeline and both datasets."""
        augmentations = [
            A.Resize(p=1, height=512, width=512),
            A.HorizontalFlip(p=0.5),
            A.VerticalFlip(p=0.5),
            A.ShiftScaleRotate(
                p=0.5,
                scale_limit=0.02,
                rotate_limit=10,
                border_mode=cv2.BORDER_REPLICATE,
            ),
            A.MotionBlur(p=0.5),
        ]
        # The pipeline is stored on the module but is not handed to either
        # dataset below: both are created with transform=None.
        self.train_aug = A.Compose(augmentations)

        self.train_dataset = ClassificationDataset(
            image_paths=self.train_images,
            targets=self.train_targets,
            transform=None,
        )
        self.valid_dataset = ClassificationDataset(
            image_paths=self.valid_images,
            targets=self.valid_targets,
            transform=None,
        )

    def _make_loader(self, dataset, shuffle):
        """Wrap ``dataset`` in a DataLoader using the batch settings in CFG."""
        return DataLoader(
            dataset,
            batch_size=CFG.batch_size,
            shuffle=shuffle,
            num_workers=CFG.num_workers,
            pin_memory=True,
        )

    def train_dataloader(self):
        """Shuffled loader over the training dataset."""
        return self._make_loader(self.train_dataset, True)

    def val_dataloader(self):
        """Sequential (unshuffled) loader over the validation dataset."""
        return self._make_loader(self.valid_dataset, False)

In [None]:
class LitNN(pl.LightningModule):
    """Lightning module: a 1->3 channel conv stem feeding a timm backbone.

    Backbone name, pretrained flag, input channels, output size and learning
    rate are read from ``CFG``. The loss is ``BCEWithLogitsLoss`` and the
    logged metric is ROC AUC computed with ``sklearn.metrics.roc_auc_score``.
    Metrics are logged as ``'<stage>/metric'`` where stage is one of
    ``train_step``, ``train_epoch``, ``valid_step``, ``valid_epoch``.
    """

    def __init__(self):
        super().__init__()
        self.model = timm.create_model(CFG.pretrained_model_name,
                                       pretrained=CFG.pretrained,
                                       in_chans=CFG.input_channels)
        if not CFG.pretrained:
            # The config class has historically spelled this attribute
            # 'prettained_path'; accept either spelling so the lookup no
            # longer raises AttributeError.
            weights_path = getattr(CFG, 'pretrained_path', None)
            if weights_path is None:
                weights_path = CFG.prettained_path
            # map_location keeps loading independent of the saving device;
            # load_state_dict copies the values into the model's own tensors.
            self.model.load_state_dict(torch.load(weights_path, map_location='cpu'))
        self.model.classifier = nn.Linear(self.model.classifier.in_features, CFG.out_dim)
        # NOTE(review): kernel_size=3 with padding=3 enlarges each spatial
        # dimension by 4 pixels (padding=1 would preserve it) — confirm intent.
        self.conv1 = nn.Conv2d(1, 3,
                               kernel_size=3,
                               stride=1,
                               padding=3,
                               bias=False)
        self.criterion = nn.BCEWithLogitsLoss()
        # Kept as an attribute: the training script reads model.lr after tuning.
        self.lr = CFG.lr

    def forward(self, inputs):
        """Map single-channel inputs to 3 channels, then run the backbone."""
        x = self.conv1(inputs)
        outputs = self.model(x)
        return outputs

    def training_step(self, batch, batch_idx):
        """Compute the loss for one batch and log loss, lr and step AUC."""
        inputs, targets = batch['image'], batch['targets']
        preds = self(inputs)
        loss = self.criterion(preds, targets.view(-1, 1))
        self.log('train/loss', loss.detach(), prog_bar=False, logger=True)
        self.log('lr', self.optimizer.param_groups[0]['lr'], prog_bar=False, logger=True)
        self.calculate_metrics('train_step',
                               loss.cpu().detach().numpy(),
                               preds.cpu().detach().numpy(),
                               targets.cpu().detach().numpy())
        # Only the loss needs its graph for backprop; returning attached
        # predictions would keep every batch's graph alive until epoch end.
        return {"loss": loss, "preds": preds.detach(), "targets": targets}

    def training_epoch_end(self, training_step_outputs):
        self.calculate_metrics_epoch_end('train_epoch', training_step_outputs)

    def validation_step(self, batch, batch_idx):
        """Compute the loss for one validation batch and log step AUC."""
        inputs, targets = batch['image'], batch['targets']
        preds = self(inputs)
        loss = self.criterion(preds, targets.view(-1, 1))
        data = {"loss": loss, "preds": preds, "targets": targets}
        self.calculate_metrics('valid_step',
                               loss.cpu().detach().numpy(),
                               preds.cpu().detach().numpy(),
                               targets.cpu().detach().numpy())
        return data

    def validation_epoch_end(self, validation_step_outputs):
        self.calculate_metrics_epoch_end('valid_epoch', validation_step_outputs)

    def test_step(self, batch, batch_idx):
        """Return the raw model outputs for one test batch."""
        inputs = batch['image']
        preds_batch = self(inputs)
        return preds_batch

    def test_epoch_end(self, test_step_outputs):
        """Write all test predictions to the next free ``submission<N>.csv``.

        ``N`` is the number of existing ``submission*.csv`` files in the
        current directory, so earlier submissions are not overwritten.
        """
        # Local import: `glob` is not imported at the top of the file.
        import glob

        # The original referenced undefined names (`outputs`, `y_preds`).
        preds = torch.cat(test_step_outputs).detach().cpu().numpy()
        # Flatten to one value per row; assumes one output per sample
        # (CFG.out_dim == 1) — TODO confirm if out_dim changes.
        df = pd.DataFrame({'target': preds.ravel()})
        N = len(glob.glob('submission*.csv'))
        df.target.to_csv(f'submission{N}.csv')

    def configure_optimizers(self):
        """Adam with a cosine-annealing warm-restart schedule."""
        # Stored on self because training_step logs the current lr from it.
        self.optimizer = Adam(self.parameters(), lr=self.lr)
        self.scheduler = CosineAnnealingWarmRestarts(
            self.optimizer, T_0=10, T_mult=1, eta_min=1e-6, last_epoch=-1)
        return [self.optimizer], [self.scheduler]

    def calculate_metrics(self, stage, loss, preds, targets):
        """Log ROC AUC for ``stage`` as ``'<stage>/metric'``.

        ``loss`` is accepted for signature compatibility but not used.
        When the score cannot be computed the sentinel 1.1 (outside the valid
        AUC range) is logged instead.
        """
        try:
            roc_auc = metrics.roc_auc_score(targets, preds)
        except ValueError:
            # Presumably a batch containing a single class, for which AUC is
            # undefined; keep the original best-effort sentinel.
            roc_auc = 1.1
        self.log_dict({f'{stage}/metric': roc_auc}, prog_bar=True, logger=True)
        if stage == 'valid_epoch':
            print('valid_epoch : ', roc_auc)

    def calculate_metrics_epoch_end(self, stage, outputs):
        """Gather per-step outputs and log the metric over the whole epoch.

        Args:
            stage: Prefix used for the logged metric name.
            outputs: List of the dicts returned by the step methods
                (keys ``loss``, ``preds``, ``targets``).
        """
        if not outputs:
            return
        data = {}
        for output in outputs:
            for key, value in output.items():
                # setdefault + append keeps the first batch too; the original
                # if/else created the list and dropped that first value.
                data.setdefault(key, []).append(value.cpu().detach().numpy())
        for key in data:
            # Losses are scalars and cannot be concatenated; they stay a list.
            if key != 'loss':
                data[key] = np.concatenate(data[key])
        self.calculate_metrics(stage, **data)

In [None]:
# Cross-validation driver: one model, data module, logger and trainer per fold.
# X, Y and skf come from the label-loading cell above.
for fold_cnt, (train_index, test_index) in enumerate(skf.split(X, Y), 1):
    train_images, valid_images = X[train_index], X[test_index]
    train_targets, valid_targets = Y[train_index], Y[test_index]

    data_module = LitData(train_images, train_targets, valid_images, valid_targets)
    model = LitNN()

#     cpt = ModelCheckpoint(
#         save_top_k=CFG.epochs,
#         verbose=True,
#         monitor='valid_epoch/metric',
#         mode=CFG.patience_mode
#     )

    # LitNN logs its epoch-level validation AUC under 'valid_epoch/metric';
    # the previously monitored key 'valid/roc_auc' is never logged, so the
    # callback could not have found it once enabled.
    es = EarlyStopping(
        monitor='valid_epoch/metric',
        mode=CFG.patience_mode,
        patience= CFG.patience)

#     lr_monitor = LearningRateMonitor(
#         logging_interval='step')

    # NOTE(review): every fold uses the same run name (fold_cnt is unused) —
    # confirm whether folds should be logged as separate runs.
    wandb_logger = WandbLogger(
        name=CFG.wandb_exp_name,
        project= CFG.project_name,
        offline=False,
        log_model=False
    )

    # `es` is built above but only takes effect once `callbacks=[es]` is
    # uncommented here.
    Trainer = pl.Trainer(
        #checkpoint_callback=cpt,
        #callbacks=[es],
        max_epochs=CFG.epochs,
        amp_backend = CFG.amp_backend,
        gpus=CFG.gpus,
        precision=CFG.precision,
        logger=wandb_logger,
        auto_lr_find=CFG.auto_find_lr
    )
    if CFG.auto_find_lr:
        # tune() runs the lr finder, which writes its suggestion to model.lr.
        Trainer.tune(model, datamodule=data_module)
        print('best initial lr : ', model.lr)
    Trainer.fit(model, data_module)