# BYOL (Bootstrap Your Own Latent) self-supervised pre-training for Cassava leaf-disease classification

Reference: [Easy Self-Supervised Learning with BYOL](https://medium.com/the-dl/easy-self-supervised-learning-with-byol-53b8ad8185d)

In [1]:
# Uncomment to import a vendored PyTorch Lightning 1.1.0 from a local input
# directory (presumably for an offline Kaggle environment — confirm).
#import sys
#sys.path.append("../input/pytorchlightning110/pytorch-lightning-1.1.0")

import pytorch_lightning as pl
# NOTE: later cells use APIs removed in newer Lightning releases
# (e.g. `weights_summary`, `pytorch_lightning.metrics`); record the version.
print("pytorch_lightning version:", pl.__version__)

pytorch_lightning version: 1.1.0


In [2]:
import random
from typing import Callable, Tuple

from kornia import augmentation as aug
from kornia import filters
from kornia.geometry import transform as tf
import torch
from torch import nn, Tensor


class RandomApply(nn.Module):
    """Wrap a transform so that it only fires some of the time.

    Args:
        fn: Callable (typically an ``nn.Module``) applied to the input tensor.
        p: Probability of applying ``fn``; otherwise the input is returned as-is.
    """

    def __init__(self, fn: Callable, p: float):
        super().__init__()
        self.fn = fn
        self.p = p

    def forward(self, x: Tensor) -> Tensor:
        # A single draw per forward call, so the whole input tensor shares
        # one apply/skip decision.
        if random.random() <= self.p:
            return self.fn(x)
        return x


def default_augmentation(image_size: Tuple[int, int] = (600, 800), crop_size: Tuple[int, int] = (224, 224)) -> nn.Module:
    """Build the default kornia view-generation pipeline used by BYOL.

    Args:
        image_size: (height, width) the input is resized to before cropping.
        crop_size: (height, width) of the final random resized crop.

    Returns:
        An ``nn.Sequential`` that resizes, colour-jitters, randomly greys,
        flips, blurs, crops and finally normalises its tensor input.
    """
    imagenet_mean = torch.tensor([0.485, 0.456, 0.406])
    imagenet_std = torch.tensor([0.229, 0.224, 0.225])
    stages = [
        # Resize up to the original image size first, then crop from it.
        tf.Resize(size=image_size),
        RandomApply(aug.ColorJitter(0.8, 0.8, 0.8, 0.2), p=0.8),
        aug.RandomGrayscale(p=0.2),
        aug.RandomHorizontalFlip(),
        RandomApply(filters.GaussianBlur2d((3, 3), (1.5, 1.5)), p=0.1),
        aug.RandomResizedCrop(size=crop_size),
        aug.Normalize(mean=imagenet_mean, std=imagenet_std),
    ]
    return nn.Sequential(*stages)

## Encoder Wrapper

In [3]:
from typing import Union


def mlp(dim: int, projection_size: int = 256, hidden_size: int = 4096) -> nn.Module:
    """Return a two-layer MLP head: Linear -> BatchNorm1d -> ReLU -> Linear.

    Args:
        dim: Size of the input features.
        projection_size: Size of the output features.
        hidden_size: Width of the hidden layer.
    """
    hidden_block = [
        nn.Linear(dim, hidden_size),
        nn.BatchNorm1d(hidden_size),
        nn.ReLU(inplace=True),
    ]
    output_layer = nn.Linear(hidden_size, projection_size)
    return nn.Sequential(*hidden_block, output_layer)


class EncoderWrapper(nn.Module):
    """Expose a backbone's intermediate features through a projection MLP.

    A forward hook is registered on one layer of ``model``. Each forward pass
    of ``model`` triggers the hook, which flattens that layer's output and
    feeds it through ``projector``. ``forward`` returns that projection, not
    the model's own output.

    Args:
        model: Backbone network.
        projection_size: Output size of the projector MLP.
        hidden_size: Hidden width of the projector MLP.
        layer: Layer to hook. Either a name from ``model.named_modules()``
            (str), or an index into ``list(model.children())`` (int). The
            default -2 selects the second-to-last direct child.
    """

    def __init__(
        self,
        model: nn.Module,
        projection_size: int = 256,
        hidden_size: int = 4096,
        layer: Union[str, int] = -2,
    ):
        super().__init__()
        self.model = model
        self.projection_size = projection_size
        self.hidden_size = hidden_size
        self.layer = layer

        # The projector's input size is only known once the hooked layer has
        # produced an output, so the projector is built lazily by the hook.
        self._projector = None
        self._projector_dim = None
        # Latest projection written by the hook; returned by ``forward``.
        self._encoded = torch.empty(0)
        self._register_hook()

    @property
    def projector(self) -> nn.Module:
        """Projection MLP, created on first access from ``_projector_dim``.

        Assigning an ``nn.Module`` attribute registers it as a submodule. Once
        created, its parameters therefore appear in ``parameters()`` and
        ``state_dict()``. Until the first forward pass they do not.
        NOTE(review): ``mlp`` builds the layers on the default device, so the
        first forward presumably must happen before moving to GPU — confirm
        for new call sites.
        """
        if self._projector is None:
            self._projector = mlp(
                self._projector_dim, self.projection_size, self.hidden_size
            )
        return self._projector

    def _hook(self, _, __, output):
        # Forward-hook signature is (module, inputs, output); only output is used.
        output = output.flatten(start_dim=1)  # keep dim 0, flatten the rest
        if self._projector_dim is None:
            # First call: remember the flattened feature size for the projector.
            self._projector_dim = output.shape[-1]
        self._encoded = self.projector(output)

    def _register_hook(self):
        """Attach ``_hook`` to the layer selected by ``self.layer``.

        Raises KeyError for an unknown layer name, or IndexError for an
        out-of-range child index.
        """
        if isinstance(self.layer, str):
            # Look up by dotted name among all (recursively nested) modules.
            layer = dict([*self.model.named_modules()])[self.layer]
        else:
            # Index into the model's direct children only.
            layer = list(self.model.children())[self.layer]

        layer.register_forward_hook(self._hook)

    def forward(self, x: Tensor) -> Tensor:
        """Run ``model`` on ``x`` and return the hooked layer's projection."""
        _ = self.model(x)  # model output is discarded; the hook sets _encoded
        return self._encoded

## BYOL and Training Code

In [4]:
from copy import deepcopy
from itertools import chain
from typing import Dict, List

import pytorch_lightning as pl
from torch import optim
import torch.nn.functional as f


def normalized_mse(x: Tensor, y: Tensor) -> Tensor:
    """Squared distance between L2-normalised vectors along the last axis.

    For unit vectors ``||x - y||^2 == 2 - 2 * <x, y>``, so the result lies in
    [0, 4] and is independent of the inputs' norms. Only the last dimension
    is reduced; all leading dimensions are kept.
    """
    cosine = (f.normalize(x, dim=-1) * f.normalize(y, dim=-1)).sum(dim=-1)
    return 2 - 2 * cosine


class BYOL(pl.LightningModule):
    """Bootstrap Your Own Latent (BYOL) self-supervised pre-training module.

    ``augment`` produces two views of each batch. The online network
    (``encoder`` followed by ``predictor``) is trained to predict the
    ``target`` network's projection of the *other* view. The target is a copy
    of ``encoder`` that receives no gradients. After every optimizer step it
    is moved toward the online weights by an exponential moving average (EMA).

    Args:
        model: Backbone to pre-train.
        image_size: (H, W) used for the default augmentation's resize and for
            the dummy forward pass that builds the projector.
        crop_size: (H, W) of the default augmentation's random resized crop.
        hidden_layer: Layer of ``model`` whose output is projected (name or
            child index, see ``EncoderWrapper``).
        projection_size: Output size of the projector and of the predictor.
        hidden_size: Hidden width of the projector and predictor MLPs.
        augment_fn: Optional replacement for ``default_augmentation``.
        beta: EMA decay of the target network. Values closer to 1 make the
            target follow the online encoder more slowly.
        **hparams: Optional ``optimizer`` (class name in ``torch.optim``,
            default "Adam"), ``lr`` (default 1e-4) and ``weight_decay``
            (default 1e-6).
    """

    def __init__(
        self,
        model: nn.Module,
        image_size: Tuple[int, int] = (128, 128),
        crop_size: Tuple[int, int] = (128, 128),
        hidden_layer: Union[str, int] = -2,
        projection_size: int = 256,
        hidden_size: int = 4096,
        augment_fn: Callable = None,
        beta: float = 0.99,
        **hparams,
    ):
        super().__init__()
        self.augment = default_augmentation(image_size=image_size, crop_size=crop_size) if augment_fn is None else augment_fn
        self.beta = beta
        self.encoder = EncoderWrapper(
            model, projection_size, hidden_size, layer=hidden_layer
        )
        # BYOL's predictor is an MLP with the same shape as the projector.
        # (nn.Linear's third positional parameter is ``bias``, so hidden_size
        # must not be passed to a plain Linear layer.)
        self.predictor = mlp(projection_size, projection_size, hidden_size)
        self.hparams = hparams
        self._target = None

        # Dummy forward pass so EncoderWrapper's hook builds the projector now,
        # before the optimizer collects parameters. A batch of 2 is used
        # because BatchNorm1d in training mode rejects a single sample.
        self.encoder(torch.zeros(2, 3, *image_size))

        self.use_amp = True

    def forward(self, x: Tensor) -> Tensor:
        """Online network output: predictor(projector(backbone(x)))."""
        return self.predictor(self.encoder(x))

    @property
    def target(self):
        """EMA target encoder, created lazily as a deep copy of ``encoder``."""
        if self._target is None:
            self._target = deepcopy(self.encoder)
            # The target is only ever changed by update_target(), never by
            # back-propagation.
            for param in self._target.parameters():
                param.requires_grad_(False)
        return self._target

    @torch.no_grad()
    def update_target(self):
        """EMA step: ``target = beta * target + (1 - beta) * online``."""
        for p, pt in zip(self.encoder.parameters(), self.target.parameters()):
            pt.mul_(self.beta).add_(p, alpha=1 - self.beta)

    @staticmethod
    def _symmetric_loss(pred1: Tensor, pred2: Tensor, targ1: Tensor, targ2: Tensor) -> Tensor:
        """Mean BYOL loss: each view's prediction vs. the other view's target."""
        return torch.mean(normalized_mse(pred1, targ2) + normalized_mse(pred2, targ1))

    # --- Methods required for PyTorch Lightning only! ---

    def configure_optimizers(self):
        optimizer = getattr(optim, self.hparams.get("optimizer", "Adam"))
        lr = self.hparams.get("lr", 1e-4)
        weight_decay = self.hparams.get("weight_decay", 1e-6)
        return optimizer(self.parameters(), lr=lr, weight_decay=weight_decay)

    def on_before_zero_grad(self, optimizer):
        """Lightning hook (runs after each optimizer step): update the EMA target."""
        self.update_target()

    def training_step(self, batch, *_) -> Dict[str, Union[Tensor, Dict]]:
        x = batch[0]
        with torch.no_grad():
            x1, x2 = self.augment(x), self.augment(x)

        pred1, pred2 = self.forward(x1), self.forward(x2)
        with torch.no_grad():
            targ1, targ2 = self.target(x1), self.target(x2)
        loss = self._symmetric_loss(pred1, pred2, targ1, targ2)

        self.log("train_loss", loss, on_step=False, on_epoch=True, prog_bar=True, logger=True)
        return {"loss": loss}

    def validation_step(self, batch, *_) -> Dict[str, Union[Tensor, Dict]]:
        x = batch[0]
        x1, x2 = self.augment(x), self.augment(x)
        pred1, pred2 = self.forward(x1), self.forward(x2)
        targ1, targ2 = self.target(x1), self.target(x2)
        loss = self._symmetric_loss(pred1, pred2, targ1, targ2)
        self.log("val_loss", loss, on_epoch=True, prog_bar=True, logger=True)
        return {"val_loss": loss}

    def validation_epoch_end(self, outputs: List[Dict]) -> None:
        val_loss = sum(x["val_loss"] for x in outputs) / len(outputs)
        print("avg_val_loss:", val_loss)

        # Optionally save the backbone weights after each validation epoch:
        #model_path = 'BYOL_training_weight.pth'
        #torch.save(self.encoder.model.state_dict(), model_path)

## Supervised Training Module

In [5]:
from pytorch_lightning.metrics.functional import accuracy
from pytorch_lightning.callbacks import ModelCheckpoint,EarlyStopping


class SupervisedLightningModule(pl.LightningModule):
    """Supervised fine-tuning of a backbone with a freshly initialised ``fc`` head.

    Args:
        model: Backbone with an ``fc`` layer exposing ``in_features``.
            ``fc`` is replaced by ``nn.Linear(in_features, n_classes)``.
        n_classes: Number of target classes.
        **hparams: Optional settings:
            ``optimizer``: class name in ``torch.optim`` (default "Adam").
            ``lr``: learning rate (default 1e-4).
            ``weight_decay``: weight decay (default 1e-6).
            ``T_0``: cosine warm-restart period in epochs (default: the
            notebook-level ``supervised_max_epochs``).
    """

    def __init__(self, model: nn.Module, n_classes, **hparams):
        super().__init__()
        self.model = model

        self.model.fc = nn.Linear(self.model.fc.in_features, n_classes)
        # Keep the keyword hyper-parameters so configure_optimizers can read
        # them (same convention as the BYOL module).
        self.hparams = hparams

        self.use_amp = True

    def forward(self, x: Tensor) -> Tensor:
        return self.model(x)

    def configure_optimizers(self):
        optimizer_cls = getattr(optim, self.hparams.get("optimizer", "Adam"))
        lr = self.hparams.get("lr", 1e-4)
        weight_decay = self.hparams.get("weight_decay", 1e-6)
        optimizer = optimizer_cls(self.parameters(), lr=lr, weight_decay=weight_decay)
        # Restart period: explicit T_0 hparam, else the notebook-level
        # ``supervised_max_epochs`` global (only looked up when T_0 is absent).
        t_0 = self.hparams.get("T_0")
        if t_0 is None:
            t_0 = supervised_max_epochs
        scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
            optimizer, T_0=t_0, T_mult=1, eta_min=1e-6)
        return [optimizer], [scheduler]

    def _shared_step(self, batch):
        """Return (cross-entropy loss, accuracy) for one labelled batch."""
        x, y = batch
        y_hat = self.forward(x)
        return f.cross_entropy(y_hat, y), accuracy(y_hat, y)

    def training_step(self, batch, *_) -> Dict[str, Union[Tensor, Dict]]:
        loss, acc = self._shared_step(batch)
        self.log("train_acc", acc, on_step=False, on_epoch=True, prog_bar=True, logger=True)
        self.log("train_loss", loss, on_step=False, on_epoch=True, prog_bar=True, logger=True)
        return {"loss": loss, "acc": acc}

    def validation_step(self, batch, *_) -> Dict[str, Union[Tensor, Dict]]:
        loss, acc = self._shared_step(batch)
        self.log("val_acc", acc, on_epoch=True, prog_bar=True, logger=True)
        self.log("val_loss", loss, on_epoch=True, prog_bar=True, logger=True)
        return {"val_loss": loss, "val_acc": acc}


def accuracy(pred: Tensor, labels: Tensor) -> float:
    """Return top-1 accuracy of logits ``pred`` against ``labels`` as a Python float.

    NOTE: this definition shadows ``pytorch_lightning.metrics.functional.accuracy``,
    which is imported earlier in the notebook.
    """
    predicted = torch.argmax(pred, dim=-1)
    hits = predicted.eq(labels).float()
    return hits.mean().item()

# Cassava Dataset

In [6]:
import cv2
from PIL import Image
import pandas as pd
from torch.utils.data import Dataset,DataLoader

class CassavaDataset(Dataset):
    """Image dataset backed by a DataFrame of image file paths (and labels).

    Args:
        df: DataFrame with a ``file_path`` column, plus ``label`` when
            ``train`` is True.
        train: If True, items are ``(image, label)``; otherwise just the image.
        transforms: Albumentations-style callable, invoked as
            ``transforms(image=arr)["image"]``. If None, the RGB numpy array
            is returned unchanged.
    """

    def __init__(self, df: pd.DataFrame, train: bool = True, transforms=None):
        self.df = df
        self.train = train
        self.transforms = transforms

    def __getitem__(self, index):
        row = self.df.iloc[index]
        im_path = row['file_path']
        x = cv2.imread(im_path, cv2.IMREAD_COLOR)
        # cv2.imread reports a missing/unreadable file by returning None rather
        # than raising; fail here with the offending path.
        if x is None:
            raise FileNotFoundError(f"Could not read image: {im_path}")
        x = cv2.cvtColor(x, cv2.COLOR_BGR2RGB)  # OpenCV decodes as BGR

        if self.transforms is not None:
            x = self.transforms(image=x)['image']

        if self.train:
            return x, row['label']
        return x

    def __len__(self):
        return len(self.df)

In [7]:
import glob
import pytorch_lightning as pl
from torchvision.transforms import ToTensor
from sklearn.model_selection import StratifiedKFold

class CassavaDataModule(pl.LightningDataModule):
    """LightningDataModule for the Cassava leaf-disease images.

    One StratifiedKFold fold of ``train.csv`` is held out for validation.
    Optionally, images from the previous (2019) competition are appended to
    the training split for self-supervised pre-training.

    Args:
        train_torchvision_transform: Transform for training images (called
            albumentations-style by ``CassavaDataset``).
        valid_torchvision_transform: Transform for validation images.
        data_dir: Directory containing ``train.csv`` and ``train_images/``.
        old_data_dir: Root of the old competition data, with ``train/*/``,
            ``test/*/`` and ``extraimages/`` sub-directories of ``*jpg`` files.
        n_splits: Number of StratifiedKFold folds.
        random_state: Seed for the fold shuffle.
        batch_size: Batch size of both data loaders.
        num_workers: DataLoader worker processes.
        is_self_supervised: If True, append the old images with label -1.
        fold: Index of the fold used for validation (default 0, the first).
    """

    def __init__(self,
                 train_torchvision_transform,
                 valid_torchvision_transform,
                 #data_dir=r"C:\Users\81908\jupyter_notebook\pytorch_lightning_work\kaggle_Cassava\input\cassava-leaf-disease-classification",
                 #old_data_dir=r"C:\Users\81908\jupyter_notebook\pytorch_lightning_work\kaggle_Cassava\old_input\kaggle_upload",
                 data_dir=r"C:\Users\shingo\jupyter_notebook\pytorch_lightning_work\kaggle_Cassava\input\cassava-leaf-disease-classification",
                 old_data_dir=r"C:\Users\shingo\jupyter_notebook\pytorch_lightning_work\kaggle_Cassava\2019_compe_data\kaggle_upload",
                 n_splits=5,
                 random_state=0,
                 batch_size=128,
                 num_workers=4,
                 is_self_supervised=False,
                 fold=0,
                ):
        super().__init__()
        self.data_dir = data_dir
        self.old_data_dir = old_data_dir
        self.n_splits = n_splits
        self.random_state = random_state
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.train_transform = train_torchvision_transform
        self.valid_transform = valid_torchvision_transform
        self.is_self_supervised = is_self_supervised
        self.fold = fold
        self.train_df, self.valid_df = None, None

    def _build_split(self):
        """Read train.csv, select the validation fold and return (train_df, valid_df)."""
        if not 0 <= self.fold < self.n_splits:
            raise ValueError(f"fold must be in [0, {self.n_splits}), got {self.fold}")

        train = pd.read_csv(f"{self.data_dir}/train.csv")
        train["file_path"] = f"{self.data_dir}/train_images/" + train["image_id"]

        splitter = StratifiedKFold(n_splits=self.n_splits, shuffle=True, random_state=self.random_state)
        for fold_idx, (train_idx, valid_idx) in enumerate(splitter.split(train, train["label"])):
            if fold_idx == self.fold:
                break
        train_df, valid_df = train.iloc[train_idx], train.iloc[valid_idx]

        # Add the old competition's images (self-supervised use only), stored
        # with the placeholder label -1.
        if self.is_self_supervised:
            old_paths = (
                glob.glob(f"{self.old_data_dir}/train/*/*jpg")
                + glob.glob(f"{self.old_data_dir}/test/*/*jpg")
                + glob.glob(f"{self.old_data_dir}/extraimages/*jpg")
            )
            old_train = pd.DataFrame({
                "image_id": [None] * len(old_paths),
                "label": [-1] * len(old_paths),
                "file_path": old_paths,
            })
            train_df = pd.concat([train_df, old_train])

        if DEBUG:
            train_df = train_df.iloc[:150]
            valid_df = valid_df.iloc[:30]

        print("train_df.shape, valid_df.shape:", train_df.shape, valid_df.shape)
        return train_df, valid_df

    def prepare_data(self):
        self.train_df, self.valid_df = self._build_split()

    def setup(self, stage=None):
        # Lightning runs prepare_data on a single process only, so build the
        # split here if this process has not done so yet.
        if self.train_df is None or self.valid_df is None:
            self.train_df, self.valid_df = self._build_split()

        self.train_dataset = CassavaDataset(self.train_df,
                                            train=True,
                                            transforms=self.train_transform)
        # train=True so validation items also carry their labels.
        self.valid_dataset = CassavaDataset(self.valid_df,
                                            train=True,
                                            transforms=self.valid_transform)

    def train_dataloader(self):
        return DataLoader(self.train_dataset,
                          batch_size=self.batch_size,
                          num_workers=self.num_workers,
                          drop_last=True,
                          shuffle=True)

    def val_dataloader(self):
        return DataLoader(self.valid_dataset,
                          batch_size=self.batch_size,
                          num_workers=self.num_workers,
                          shuffle=False)

# Parameters

In [8]:
import os, shutil
# Start each run with an empty Lightning/TensorBoard log directory by deleting
# logs from previous runs. Note that the folder is only recreated here if it
# already existed.
if os.path.exists("lightning_logs/"):
    shutil.rmtree("lightning_logs/")
    os.mkdir("lightning_logs/")

In [9]:
n_classes = 5  # number of classes for the supervised classifier head

# Images are resized to (im_h, im_w) and then cropped to (crop_h, crop_w).
im_h, im_w = 600, 800
crop_h, crop_w = 512, 512
batch_size = 3
#batch_size = 3 * 2
num_workers = 0  # 0 = DataLoader loads batches in the main process
supervised_max_epochs = 20  # also the cosine-restart period of the supervised scheduler
self_supervised_max_epochs = 100
patience = 15  # EarlyStopping patience, in epochs without val_acc improvement

# DEBUG shrinks image size, epochs (and the DataFrames) for a quick smoke run.
#DEBUG = True
DEBUG = False
if DEBUG:
    im_h, im_w = 224, 224
    crop_h, crop_w = 224, 224
    batch_size = 16
    #batch_size = 128
    supervised_max_epochs = 1
    #supervised_max_epochs = 5
    self_supervised_max_epochs = 1
    #self_supervised_max_epochs = 10
    print("DEBUG:", DEBUG)

In [10]:
# https://www.kaggle.com/mekhdigakhramanian/pytorch-efficientnet-baseline-inference-tta/data?select=tf_efficientnet_b3_ns_fold_0_0
# Reference for the augmentation settings below (link above).

import albumentations as A
from albumentations.pytorch.transforms import ToTensorV2

# Training pipeline. CassavaDataset applies it per image to an RGB uint8 array.
train_transform = A.Compose(
    [
        A.Resize(im_h, im_w),
        A.RandomResizedCrop(crop_h, crop_w),
        A.Transpose(p=0.5),
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.5),
        A.ShiftScaleRotate(p=0.5),
        # NOTE(review): for uint8 input albumentations interprets these limits
        # in absolute units (e.g. hue shift in OpenCV hue steps), so 0.2 may be
        # close to a no-op — confirm the intended magnitude.
        A.HueSaturationValue(
            hue_shift_limit=0.2, sat_shift_limit=0.2, val_shift_limit=0.2, p=0.5
        ),
        A.RandomBrightnessContrast(
            brightness_limit=(-0.1, 0.1), contrast_limit=(-0.1, 0.1), p=0.5
        ),
        # ImageNet mean/std normalisation.
        A.Normalize(
            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225], max_pixel_value=255.0, p=1.0,
        ),
        # NOTE(review): both dropout transforms run after Normalize, so holes
        # are filled in normalised space. A.Cutout is deprecated in newer
        # albumentations versions in favour of CoarseDropout. Confirm that both
        # are wanted.
        A.CoarseDropout(p=0.5),
        A.Cutout(p=0.5),
        ToTensorV2(p=1.0),  # HWC numpy array -> CHW torch tensor
    ],
    p=1.0,
)

# Deterministic validation pipeline: resize, centre crop, normalise.
valid_transform = A.Compose(
    [
        A.Resize(im_h, im_w),
        A.CenterCrop(crop_h, crop_w),
        A.Normalize(
            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225], max_pixel_value=255.0, p=1.0,
        ),
        ToTensorV2(p=1.0),
    ],
    p=1.0,
)

In [11]:
#from torchvision.models import resnet18, resnet50, resnext50_32x4d, resnext101_32x8d
# resnext50_32x4d: roughly 13 minutes per BYOL training epoch
# resnext101_32x8d: roughly 30 minutes per BYOL training epoch

# timm provides the seresnext50_32x4d backbone used in the cells below.
import timm

# Supervised Training without BYOL

In [12]:
%%time
#model = resnet50(pretrained=True)
model = timm.create_model("seresnext50_32x4d", pretrained=True)

model = SupervisedLightningModule(model, n_classes)

early_stopping = EarlyStopping("val_acc", 
                               patience=patience, 
                               mode='max')
model_checkpoint = ModelCheckpoint(
    monitor="val_acc",
    save_top_k=1,
    mode="max",
)
trainer = pl.Trainer(
    max_epochs=supervised_max_epochs, 
    gpus=-1, 
    weights_summary=None,
    callbacks=[model_checkpoint, early_stopping],
)

dm = CassavaDataModule(train_torchvision_transform=train_transform, 
                       valid_torchvision_transform=valid_transform, 
                       batch_size=batch_size, 
                       num_workers=num_workers)
trainer.fit(model, dm)

model_path = 'supervised_weight.pth'
torch.save(model.state_dict(), model_path)

model.cuda()
acc = sum([accuracy(model(x.cuda()), y.cuda()) for x, y in dm.val_dataloader()]) / len(dm.val_dataloader())
print(f"Accuracy: {acc:.3f}")

del model
torch.cuda.empty_cache()

GPU available: True, used: True
TPU available: None, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


train_df.shape, valid_df.shape: (17117, 3) (4280, 3)


HBox(children=(HTML(value='Validation sanity check'), FloatProgress(value=1.0, bar_style='info', layout=Layout…

HBox(children=(HTML(value='Training'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…


Accuracy: 0.887
Wall time: 8h 52min 56s


# Self-Supervised Training with BYOL

In [13]:
#%%time
##model = resnet50(pretrained=True)
#model = timm.create_model("seresnext50_32x4d", pretrained=True)
#
#byol = BYOL(model, image_size=(im_h, im_w), crop_size=(crop_h, crop_w))
#
#trainer = pl.Trainer(
#    max_epochs=self_supervised_max_epochs, 
#    gpus=-1,
#    accumulate_grad_batches=2048 // batch_size,
#    weights_summary=None,
#    callbacks=[model_checkpoint, early_stopping],
#)
#
#dm = CassavaDataModule(train_torchvision_transform=transform, 
#                       valid_torchvision_transform=transform, 
#                       batch_size=batch_size, 
#                       num_workers=num_workers, 
#                       is_self_supervised=True
#                      )
#trainer.fit(byol, dm)
#
#model_path = 'BYOL_weight.pth'
#torch.save(model.state_dict(), model_path)
# NOTE(review) before re-enabling this cell:
# - `transform` is not defined anywhere in this notebook; pass the intended
#   transform (e.g. train_transform / valid_transform).
# - `model_checkpoint` / `early_stopping` are the instances created for the
#   supervised run. They monitor "val_acc", which BYOL never logs (it logs
#   "train_loss" / "val_loss"). Presumably new callbacks monitoring
#   "val_loss" with mode="min" are wanted — confirm.
# - If `transform` includes A.Normalize (as train_transform/valid_transform do),
#   BYOL's default augmentation applies ColorJitter and Normalize again on top
#   of already-normalised tensors — confirm this is intended.

# Supervised Training Again (backbone initialised from BYOL weights)

In [14]:
%%time
#model = resnet50()
model = timm.create_model("seresnext50_32x4d")
model_path = r'C:\Users\shingo\jupyter_notebook\pytorch_lightning_work\kaggle_Cassava\notebook\byol\byol-pytorch_seresnet50_v2\byol-pytorch_seresnext50_32x4d_512\BYOL_weight.pth'
model.load_state_dict(torch.load(model_path))

model = SupervisedLightningModule(model, n_classes)

early_stopping = EarlyStopping("val_acc", 
                               patience=patience, 
                               mode='max')
model_checkpoint = ModelCheckpoint(
    monitor="val_acc",
    save_top_k=1,
    mode="max",
)
trainer = pl.Trainer(
    max_epochs=supervised_max_epochs, 
    gpus=-1,
    weights_summary=None,
    callbacks=[model_checkpoint, early_stopping],
)

dm = CassavaDataModule(train_torchvision_transform=train_transform, 
                       valid_torchvision_transform=valid_transform, 
                       batch_size=batch_size, 
                       num_workers=num_workers)
trainer.fit(model, dm)

model_path = 'load_BYOL_supervised_weight.pth'
torch.save(model.state_dict(), model_path)

model.cuda()
acc = sum([accuracy(model(x.cuda()), y.cuda()) for x, y in dm.val_dataloader()]) / len(dm.val_dataloader())
print(f"Accuracy: {acc:.3f}")

del model
torch.cuda.empty_cache()

GPU available: True, used: True
TPU available: None, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


train_df.shape, valid_df.shape: (17117, 3) (4280, 3)


HBox(children=(HTML(value='Validation sanity check'), FloatProgress(value=1.0, bar_style='info', layout=Layout…

HBox(children=(HTML(value='Training'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…


Accuracy: 0.892
Wall time: 8h 51min 43s


In [None]:
# tensorboard --logdir ./lightning_logs