# Colab-efficientnet-lightning

GridMixup: [IlyaDobrynin/GridMixup](https://github.com/IlyaDobrynin/GridMixup)

Diffaug: [mit-han-lab/data-efficient-gans](https://github.com/mit-han-lab/data-efficient-gans)

AdamP: [clovaai/AdamP](https://github.com/clovaai/AdamP)

EfficientNet repo: [lukemelas/EfficientNet-PyTorch](https://github.com/lukemelas/EfficientNet-PyTorch)

Original repo: [bentrevett/pytorch-image-classification](https://github.com/bentrevett/pytorch-image-classification)

My fork: [styler00dollar/Colab-image-classification](https://github.com/styler00dollar/Colab-image-classification)

Currently trains without validation and reports training-only metrics. Saves the model as a PyTorch ``.pth`` state dict. Uses the GridMixup loss by default.

In [None]:
!nvidia-smi

In [None]:
#@title Install
!pip install efficientnet_pytorch
!pip install adamp
!pip install pytorch_lightning
!pip install tensorboardX

In [None]:
#@title print means and stds
import torch
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from tqdm import tqdm
train_dir = '/content/data/' #@param
train_data = datasets.ImageFolder(root = train_dir, 
                                  transform = transforms.ToTensor())

# Per-channel statistics over *all pixels* of the dataset, which is what
# transforms.Normalize expects. Averaging per-image stds (as done previously)
# ignores the variation between images and underestimates the dataset std.
# float64 accumulators keep the E[x^2] - E[x]^2 formula numerically stable.
channel_sum = torch.zeros(3, dtype=torch.float64)
channel_sq_sum = torch.zeros(3, dtype=torch.float64)
pixel_count = 0

for img, label in tqdm(train_data):
    img = img.double()
    channel_sum += img.sum(dim=(1, 2))
    channel_sq_sum += (img * img).sum(dim=(1, 2))
    pixel_count += img.shape[1] * img.shape[2]

mean64 = channel_sum / pixel_count
# clamp guards against tiny negative variances caused by rounding
var64 = (channel_sq_sum / pixel_count - mean64 ** 2).clamp(min=0)
means = mean64.float()
stds = var64.sqrt().float()

print("\n")
print(f'Calculated means: {means}')
print(f'Calculated stds: {stds}')

In [None]:
#@title init.py
import torch.nn.init as init

def weights_init(net, init_type = 'kaiming', init_gain = 0.02):
    """Initialize the weights of every submodule of ``net`` in place.

    Parameters:
        net (nn.Module)   -- network to be initialized (visited via ``net.apply``)
        init_type (str)   -- method for conv weights: normal | xavier | kaiming | orthogonal
        init_gain (float) -- std for 'normal', gain for 'xavier' and 'orthogonal';
                             ignored by 'kaiming'

    Modules whose class name contains 'Conv' get ``init_type`` applied to their
    weight (conv biases are left untouched). 'BatchNorm2d' modules get
    weight ~ N(1, 0.02) and bias = 0. 'Linear' modules get weight ~ N(0, 0.01)
    and bias = 0. Parameters that are absent (``BatchNorm2d(affine=False)``,
    ``Linear(bias=False)``) are skipped instead of crashing.

    Raises:
        NotImplementedError: if ``init_type`` is unknown and ``net`` contains a
            Conv module with a weight.
    """

    def init_func(m):
        classname = m.__class__.__name__

        if hasattr(m, 'weight') and classname.find('Conv') != -1:
            if m.weight is None:
                return
            if init_type == 'normal':
                init.normal_(m.weight.data, 0.0, init_gain)
            elif init_type == 'xavier':
                init.xavier_normal_(m.weight.data, gain = init_gain)
            elif init_type == 'kaiming':
                init.kaiming_normal_(m.weight.data, a = 0, mode = 'fan_in')
            elif init_type == 'orthogonal':
                init.orthogonal_(m.weight.data, gain = init_gain)
            else:
                raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
        elif classname.find('BatchNorm2d') != -1:
            # affine=False BatchNorm has weight/bias set to None
            if getattr(m, 'weight', None) is not None:
                init.normal_(m.weight.data, 1.0, 0.02)
            if getattr(m, 'bias', None) is not None:
                init.constant_(m.bias.data, 0.0)
        elif classname.find('Linear') != -1:
            init.normal_(m.weight, 0, 0.01)
            # Linear(bias=False) has bias set to None
            if getattr(m, 'bias', None) is not None:
                init.constant_(m.bias, 0)

    # Apply the initialization function <init_func>
    print('Initialization method [{:s}]'.format(init_type))
    net.apply(init_func)

In [None]:
#@title dataloader.py
from torchvision import transforms
from torch.utils.data import DataLoader
import pytorch_lightning as pl
import torchvision.datasets as datasets

class DataModule(pl.LightningDataModule):
    """ImageFolder-based train/validation/test loaders for the classifier.

    ``validation_path`` and ``test_path`` default to ``None``, meaning "reuse
    ``training_path``". Previously they were ignored entirely and the training
    folder was always used, so omitting them keeps that behaviour.
    """
    def __init__(self, training_path: str = './', validation_path: "str | None" = None, test_path: "str | None" = None, batch_size: int = 5, num_workers: int = 2, size = 256):
        super().__init__()
        self.training_dir = training_path
        self.validation_dir = training_path if validation_path is None else validation_path
        self.test_dir = training_path if test_path is None else test_path
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.size = size

        self.means = [0.6013, 0.5443, 0.5521] #@param {type:"raw"}
        self.std = [0.2496, 0.2509, 0.2433] #@param {type:"raw"}

    def setup(self, stage=None):
        normalize = transforms.Normalize(mean = self.means, std = self.std)

        # Random augmentation is only appropriate for training.
        train_tf = transforms.Compose([
            transforms.Resize(size=self.size),
            transforms.RandomRotation(5),
            transforms.RandomCrop(self.size),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ])
        # Deterministic preprocessing so validation/test metrics are reproducible.
        eval_tf = transforms.Compose([
            transforms.Resize(size=self.size),
            transforms.CenterCrop(size=self.size),
            transforms.ToTensor(),
            normalize,
        ])

        self.DS_train = datasets.ImageFolder(root = self.training_dir,
                                  transform = train_tf)
        self.DS_validation = datasets.ImageFolder(root = self.validation_dir,
                                  transform = eval_tf)
        self.DS_test = datasets.ImageFolder(root = self.test_dir,
                                  transform = eval_tf)

    def train_dataloader(self):
        return DataLoader(self.DS_train, batch_size=self.batch_size, num_workers=self.num_workers, shuffle=True)

    def val_dataloader(self):
        return DataLoader(self.DS_validation, batch_size=self.batch_size, num_workers=self.num_workers)

    def test_dataloader(self):
        return DataLoader(self.DS_test, batch_size=self.batch_size, num_workers=self.num_workers)

In [None]:
#@title GridMixupLoss.py
"""
gridmix_pytorch.py (27-3-20)
https://github.com/IlyaDobrynin/GridMixup/blob/main/gridmix/gridmix_pytorch.py
"""
import typing as t
import random
import numpy as np
import math
import torch
from torch import nn


class GridMixupLoss(nn.Module):
    """ Implementation of GridMixup loss

    :param alpha: Percent of the first image on the crop. Can be a scalar or Tuple[float, float]
                    - if scalar: lambda parameter gets from the beta-distribution np.random.beta(alpha, alpha)
                    - if Tuple[float, float]: lambda parameter gets from the uniform
                        distribution np.random.uniform(alpha[0], alpha[1])
    :param n_holes_x: Number of holes by OX
    :param hole_aspect_ratio: hole aspect ratio
    :param crop_area_ratio: Define percentage of the crop area
    :param crop_aspect_ratio: Define crop aspect ratio

    Scalar (int or float) range arguments are expanded to a degenerate
    ``(value, value)`` range, so e.g. ``hole_aspect_ratio=1`` works like ``1.``.
    """
    def __init__(
            self,
            alpha: t.Union[float, t.Tuple[float, float]] = (0.1, 0.9),
            n_holes_x: t.Union[int, t.Tuple[int, int]] = 20,
            hole_aspect_ratio: t.Union[float, t.Tuple[float, float]] = 1.,
            crop_area_ratio: t.Union[float, t.Tuple[float, float]] = 1.,
            crop_aspect_ratio: t.Union[float, t.Tuple[float, float]] = 1.,
    ):

        super().__init__()
        self.alpha = alpha
        self.n_holes_x = n_holes_x
        self.hole_aspect_ratio = hole_aspect_ratio
        self.crop_area_ratio = crop_area_ratio
        self.crop_aspect_ratio = crop_aspect_ratio
        if isinstance(self.n_holes_x, int):
            self.n_holes_x = (self.n_holes_x, self.n_holes_x)
        # Accept ints too: previously e.g. hole_aspect_ratio=1 stayed a scalar
        # and crashed when get_sample indexed [0]/[1].
        if isinstance(self.hole_aspect_ratio, (int, float)):
            self.hole_aspect_ratio = (self.hole_aspect_ratio, self.hole_aspect_ratio)
        if isinstance(self.crop_area_ratio, (int, float)):
            self.crop_area_ratio = (self.crop_area_ratio, self.crop_area_ratio)
        if isinstance(self.crop_aspect_ratio, (int, float)):
            self.crop_aspect_ratio = (self.crop_aspect_ratio, self.crop_aspect_ratio)

        self.loss = nn.CrossEntropyLoss()

    def __str__(self):
        return "gridmixup"

    @staticmethod
    def _get_random_crop(height: int, width: int, crop_area_ratio: float, crop_aspect_ratio: float) -> t.Tuple:
        """Return ``(x1, y1, x2, y2)`` of a randomly placed crop lying inside the image."""
        crop_area = int(height * width * crop_area_ratio)
        crop_width = int(np.sqrt(crop_area / crop_aspect_ratio))
        crop_height = int(crop_width * crop_aspect_ratio)
        # A crop whose aspect ratio differs from the image's can come out larger
        # than the image along one axis; that produced negative offsets, which
        # numpy slicing then wrapped around. Clamp it to the image.
        crop_width = min(crop_width, width)
        crop_height = min(crop_height, height)

        cx = np.random.random()
        cy = np.random.random()

        y1 = int((height - crop_height) * cy)
        y2 = y1 + crop_height
        x1 = int((width - crop_width) * cx)
        x2 = x1 + crop_width
        return x1, y1, x2, y2

    def _get_gridmask(
            self,
            image_shape: t.Tuple[int, int],
            crop_area_ratio: float,
            crop_aspect_ratio: float,
            lam: float,
            nx: int,
            ar: float,
    ) -> np.ndarray:
        """ Method make grid mask

        :param image_shape: Shape of the images
        :param lam: Lambda parameter
        :param crop_area_ratio: Ratio of the crop area
        :param crop_aspect_ratio: Aspect ratio of the crop
        :param nx: Amount of holes by width
        :param ar: Aspect ratio of the hole
        :return: Binary mask, where holes == 1, background == 0
        :raises ValueError: if nx is not between 1 and half the crop width
        """
        img_height, img_width = image_shape

        # Get coordinates of random box
        xc1, yc1, xc2, yc2 = self._get_random_crop(
            height=img_height,
            width=img_width,
            crop_area_ratio=crop_area_ratio,
            crop_aspect_ratio=crop_aspect_ratio
        )
        height = yc2 - yc1
        width = xc2 - xc1

        if not 1 <= nx <= width // 2:
            raise ValueError(
                f"The nx must be between 1 and {width // 2}.\n"
                f"Give: {nx}"
            )

        # Get patch width, height and ny (patch_height >= 1 avoids dividing by zero)
        patch_width = math.ceil(width / nx)
        patch_height = max(int(patch_width * ar), 1)
        ny = math.ceil(height / patch_height)

        # Calculate ratio of the hole - percent of hole pixels in the patch
        ratio = np.sqrt(1 - lam)

        # Get hole size
        hole_width = int(patch_width * ratio)
        hole_height = int(patch_height * ratio)

        # min 1 pixel and max patch length - 1
        hole_width = min(max(hole_width, 1), patch_width - 1)
        hole_height = min(max(hole_height, 1), patch_height - 1)

        # Make grid mask
        holes = []
        for i in range(nx + 1):
            for j in range(ny + 1):
                x1 = min(patch_width * i, width)
                y1 = min(patch_height * j, height)
                x2 = min(x1 + hole_width, width)
                y2 = min(y1 + hole_height, height)
                holes.append((x1, y1, x2, y2))

        mask = np.zeros(shape=image_shape, dtype=np.uint8)
        for x1, y1, x2, y2 in holes:
            mask[yc1+y1:yc1+y2, xc1+x1:xc1+x2] = 1
        return mask

    def get_sample(self, images: torch.Tensor, targets: torch.Tensor) -> t.Tuple[torch.Tensor, torch.Tensor]:
        """ Method returns augmented images and targets

        :param images: Batch of non-augmented images, (batch, channels, height, width)
        :param targets: Batch of non-augmented targets, (batch, 1)
        :return: Augmented images and a float64 tensor of shape (3, batch):
                 row 0 = original labels, row 1 = labels of the pasted images,
                 row 2 = lam (fraction of pixels kept from the original image)
        """
        # Get new indices
        indices = torch.randperm(images.size(0)).to(images.device)

        # Shuffle labels
        shuffled_targets = targets[indices].to(targets.device)

        # Get image shape
        height, width = images.shape[2:]

        # Get lambda
        if isinstance(self.alpha, (int, float)):
            lam = np.random.beta(self.alpha, self.alpha)
        else:
            lam = np.random.uniform(self.alpha[0], self.alpha[1])

        nx = random.randint(self.n_holes_x[0], self.n_holes_x[1])
        ar = np.random.uniform(self.hole_aspect_ratio[0], self.hole_aspect_ratio[1])
        crop_area_ratio = np.random.uniform(self.crop_area_ratio[0], self.crop_area_ratio[1])
        crop_aspect_ratio = np.random.uniform(self.crop_aspect_ratio[0], self.crop_aspect_ratio[1])
        mask = self._get_gridmask(
            image_shape=(height, width),
            crop_area_ratio=crop_area_ratio,
            crop_aspect_ratio=crop_aspect_ratio,
            lam=lam,
            nx=nx,
            ar=ar
        )
        # Adjust lambda to exactly match pixel ratio
        lam = 1 - (mask.sum() / (height * width))

        # Make shuffled images; the mask must live on the images' device
        mask = torch.from_numpy(mask).to(device=images.device, dtype=images.dtype)
        images = images * (1 - mask) + images[indices, ...] * mask

        # Prepare out labels (explicit float64 so lam is not truncated)
        lam_list = torch.full(tuple(targets.shape), float(lam), dtype=torch.float64, device=targets.device)
        out_targets = torch.cat([targets.double(), shuffled_targets.double(), lam_list], dim=1).transpose(0, 1)
        return images, out_targets

    def forward(self, preds: torch.Tensor, trues: torch.Tensor) -> torch.Tensor:
        """Lam-weighted cross-entropy against both label rows produced by get_sample."""
        lam = trues[-1][0].float()
        trues1, trues2 = trues[0].long(), trues[1].long()
        loss = self.loss(preds, trues1) * lam + self.loss(preds, trues2) * (1 - lam)
        return loss


In [None]:
#@title diffaug.py
"""
DiffAugment_pytorch.py (27-3-20)
https://github.com/mit-han-lab/data-efficient-gans/blob/master/DiffAugment_pytorch.py
"""
# Differentiable Augmentation for Data-Efficient GAN Training
# Shengyu Zhao, Zhijian Liu, Ji Lin, Jun-Yan Zhu, and Song Han
# https://arxiv.org/pdf/2006.10738

import torch
import torch.nn.functional as F


def DiffAugment(x, policy='', channels_first=True):
    """Apply the differentiable augmentations named in ``policy`` to a batch.

    :param x: image batch; channels-first unless ``channels_first`` is False,
              in which case it is treated as (N, H, W, C) and permuted to
              channels-first for the augmentations, then permuted back
    :param policy: comma-separated keys of ``AUGMENT_FNS`` (e.g. 'color,cutout');
                   an empty policy returns ``x`` itself, untouched
    :return: the augmented batch, made contiguous
    """
    if not policy:
        return x

    if not channels_first:
        x = x.permute(0, 3, 1, 2)
    for name in policy.split(','):
        for augment in AUGMENT_FNS[name]:
            x = augment(x)
    if not channels_first:
        x = x.permute(0, 2, 3, 1)
    return x.contiguous()


def rand_brightness(x):
    """Add one random offset in [-0.5, 0.5) to every value of each sample."""
    offset = torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) - 0.5
    return x + offset


def rand_saturation(x):
    """Scale each value's deviation from the mean over dim 1 by a per-sample factor in [0, 2).

    Dim 1 is the channel axis when ``x`` is channels-first, as DiffAugment
    arranges before calling this.
    """
    across_channels = x.mean(dim=1, keepdim=True)
    factor = 2 * torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device)
    return across_channels + factor * (x - across_channels)


def rand_contrast(x):
    """Scale each sample's deviation from its own mean (dims 1-3) by a factor in [0.5, 1.5)."""
    sample_mean = x.mean(dim=[1, 2, 3], keepdim=True)
    factor = torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) + 0.5
    return sample_mean + factor * (x - sample_mean)


def rand_translation(x, ratio=0.125):
    """Randomly shift each sample by an integer offset along both spatial axes.

    Offsets are drawn per sample from [-round(size * ratio), round(size * ratio)]
    for dims 2 and 3. Pixels shifted in from outside the image come from a
    one-pixel zero padding, so they are zero. Assumes ``x`` is (N, C, H, W),
    which DiffAugment guarantees before calling.
    """
    # int(v + 0.5) rounds to the nearest integer for non-negative v
    shift_x, shift_y = int(x.size(2) * ratio + 0.5), int(x.size(3) * ratio + 0.5)
    # one offset per sample, shaped (N, 1, 1) to broadcast over the index grids
    translation_x = torch.randint(-shift_x, shift_x + 1, size=[x.size(0), 1, 1], device=x.device)
    translation_y = torch.randint(-shift_y, shift_y + 1, size=[x.size(0), 1, 1], device=x.device)
    # (N, H, W) index grids for batch, row and column
    # NOTE(review): newer torch versions warn when meshgrid is called without `indexing=`
    grid_batch, grid_x, grid_y = torch.meshgrid(
        torch.arange(x.size(0), dtype=torch.long, device=x.device),
        torch.arange(x.size(2), dtype=torch.long, device=x.device),
        torch.arange(x.size(3), dtype=torch.long, device=x.device),
    )
    # +1 accounts for the one-pixel padding below; clamping sends out-of-range
    # indices to the (zero) padding border
    grid_x = torch.clamp(grid_x + translation_x + 1, 0, x.size(2) + 1)
    grid_y = torch.clamp(grid_y + translation_y + 1, 0, x.size(3) + 1)
    # pad the last two dims by 1 pixel on each side with zeros
    x_pad = F.pad(x, [1, 1, 1, 1, 0, 0, 0, 0])
    # gather in channels-last layout, then restore channels-first
    x = x_pad.permute(0, 2, 3, 1).contiguous()[grid_batch, grid_x, grid_y].permute(0, 3, 1, 2)
    return x


def rand_cutout(x, ratio=0.5):
    """Zero out one randomly placed rectangle per sample.

    The rectangle is round(H * ratio) x round(W * ratio), centred at a random
    position and clipped to the image bounds. Assumes ``x`` is (N, C, H, W);
    the same rectangle is applied to every channel of a sample.
    """
    # int(v + 0.5) rounds to the nearest integer for non-negative v
    cutout_size = int(x.size(2) * ratio + 0.5), int(x.size(3) * ratio + 0.5)
    # per-sample rectangle centre; the upper bound depends on the size's parity
    offset_x = torch.randint(0, x.size(2) + (1 - cutout_size[0] % 2), size=[x.size(0), 1, 1], device=x.device)
    offset_y = torch.randint(0, x.size(3) + (1 - cutout_size[1] % 2), size=[x.size(0), 1, 1], device=x.device)
    # index grids covering one cutout rectangle per sample
    grid_batch, grid_x, grid_y = torch.meshgrid(
        torch.arange(x.size(0), dtype=torch.long, device=x.device),
        torch.arange(cutout_size[0], dtype=torch.long, device=x.device),
        torch.arange(cutout_size[1], dtype=torch.long, device=x.device),
    )
    # shift the rectangle to its centre; clamping clips it to the image
    grid_x = torch.clamp(grid_x + offset_x - cutout_size[0] // 2, min=0, max=x.size(2) - 1)
    grid_y = torch.clamp(grid_y + offset_y - cutout_size[1] // 2, min=0, max=x.size(3) - 1)
    mask = torch.ones(x.size(0), x.size(2), x.size(3), dtype=x.dtype, device=x.device)
    mask[grid_batch, grid_x, grid_y] = 0
    # broadcast the (N, H, W) mask over channels
    x = x * mask.unsqueeze(1)
    return x


# Policy name -> augmentations applied in order by DiffAugment.
AUGMENT_FNS = dict(
    color=[rand_brightness, rand_saturation, rand_contrast],
    translation=[rand_translation],
    cutout=[rand_cutout],
)

In [None]:
#@title accuracy.py
from sklearn.metrics import accuracy_score

def calc_accuracy(preds: torch.Tensor, trues: torch.Tensor) -> float:
    """Return top-1 accuracy of ``preds`` against the dominant GridMixup label.

    :param preds: logits, (batch, num_classes)
    :param trues: stacked targets (3, batch) as built by GridMixupLoss.get_sample:
                  rows 0/1 are the labels of the two mixed images, row 2 is lam
    :return: fraction of samples whose argmax prediction equals the label of
             the image covering more than half the pixels (row 0 if lam > 0.5)
    """
    lam = float(trues[-1, 0])
    label_row = 0 if lam > 0.5 else 1
    # int64 labels: the former uint8 cast wrapped class ids >= 256 and could
    # make different classes compare equal
    true_labels = trues[label_row].long().cpu().numpy()
    # argmax of logits equals argmax of softmax, so no softmax is needed
    pred_labels = preds.detach().argmax(dim=1).cpu().numpy()
    return float(np.mean(true_labels == pred_labels))

In [None]:
#@title CustomTrainClass.py
from efficientnet_pytorch import EfficientNet
from adamp import AdamP
#from adamp import SGDP
import numpy as np

class CustomTrainClass(pl.LightningModule):
    """Fine-tunes a pretrained EfficientNet (``self.netD``) with GridMixup.

    Per-step loss/accuracy are accumulated and, at the end of every epoch,
    printed, logged, and a state-dict checkpoint of ``netD`` is written to the
    working directory. Validation and test steps are not implemented.
    """
    def __init__(self):
        super().__init__()
        model_train = 'efficientnet-b0' #@param ["efficientnet-b0", "efficientnet-b1", "efficientnet-b2", "efficientnet-b3", "efficientnet-b4", "efficientnet-b5", "efficientnet-b6", "efficientnet-b7"] {type:"string"}
        num_classes = 2 #@param
        supported_models = [f'efficientnet-b{i}' for i in range(8)]
        # fail early instead of leaving netD unset
        if model_train not in supported_models:
            raise ValueError(f"model_train must be one of {supported_models}, got {model_train!r}")
        self.netD = EfficientNet.from_pretrained(model_train, num_classes=num_classes)

        #weights_init(self.netD, 'kaiming') #only use this if there is no pretrain

        self.criterion = GridMixupLoss(
            alpha=(0.4, 0.7),
            hole_aspect_ratio=1.,
            crop_area_ratio=(0.5, 1),
            crop_aspect_ratio=(0.5, 2),
            n_holes_x=(2, 6)
        )
        self.accuracy = []
        self.losses = []
        self.diffaug_activate = False #@param

        #@markdown Supports ``'color,translation,cutout'``
        self.policy = 'color' #@param

        #@markdown Subtracted from the ImageFolder class labels (the original notebook used 1, presumably because an extra folder such as ``.ipynb_checkpoints`` sorts first; use 0 for a clean dataset folder)
        self.label_offset = 1 #@param {type:"integer"}

    def training_step(self, train_batch, batch_idx):
        inputs, targets = self.criterion.get_sample(images=train_batch[0], targets=train_batch[1].unsqueeze(-1))
        # targets is (3, batch): rows 0/1 are labels, row 2 is the mixing ratio
        # lam. Only the label rows are shifted; shifting lam too (the former
        # `targets - 1`) made it negative, producing negative losses and making
        # calc_accuracy always score against the shuffled labels.
        targets = torch.cat([targets[:2] - self.label_offset, targets[2:]], dim=0)

        if self.diffaug_activate:
            inputs = DiffAugment(inputs, policy=self.policy)
        preds = self.netD(inputs)

        # Calculate loss
        loss = self.criterion(preds, targets)

        # Lightning's logger replaces the never-defined tensorboardX `writer`
        self.log('train_loss', loss, logger=True)

        self.accuracy.append(calc_accuracy(preds, targets))
        self.losses.append(loss.item())
        return loss

    def configure_optimizers(self):
        #optimizer = torch.optim.Adam(self.netD.parameters(), lr=2e-3)
        optimizer = AdamP(self.netD.parameters(), lr=2e-4, betas=(0.9, 0.999), weight_decay=1e-2)
        #optimizer = SGDP(self.netD.parameters(), lr=0.1, weight_decay=1e-5, momentum=0.9, nesterov=True)
        return optimizer

    def training_epoch_end(self, training_step_outputs):
        loss_mean = np.mean(self.losses)
        accuracy_mean = np.mean(self.accuracy)
        print(f"'Epoch': {self.current_epoch}, 'loss': {loss_mean}, 'accuracy': {accuracy_mean}")

        # logging
        self.log('loss_mean', loss_mean, prog_bar=True, logger=True, on_epoch=True)
        self.log('accuracy_mean', accuracy_mean, prog_bar=True, logger=True, on_epoch=True)

        self.losses = []
        self.accuracy = []

        # save this module's network rather than reaching for the global `trainer`
        torch.save(self.netD.state_dict(), f"Checkpoint_{self.current_epoch}_{self.global_step}_loss_{loss_mean:3f}_acc_{accuracy_mean:3f}_D.pth")

    def validation_step(self, train_batch, train_idx):
        print("not implemented")

    def test_step(self, train_batch, train_idx):
        print("not implemented")

In [None]:
#@title training
dm = DataModule(
    batch_size=2,
    training_path='/content/data/',
    num_workers = 1,
    size = 256,
)
model = CustomTrainClass()
# Validation is skipped entirely via limit_val_batches=0.
# NOTE(review): `gpus` / `progress_bar_refresh_rate` target older Lightning
# releases; pin a compatible pytorch_lightning version if they are rejected.
trainer = pl.Trainer(
    limit_val_batches=0,
    gpus=1,
    max_epochs=150,
    progress_bar_refresh_rate=20,
    default_root_dir='/content/',
)
trainer.fit(model, dm)

# Test

In [None]:
#@title sort files with predictions (configured for 2 classes, uses efficientnet-b0 as base)
import torch
import glob
from efficientnet_pytorch import EfficientNet
import cv2
import torch.nn.functional as F
import shutil
import os
from tqdm import tqdm

# The checkpoint overwrites every weight, so build the bare architecture
# instead of downloading ImageNet weights that are discarded immediately.
model = EfficientNet.from_name('efficientnet-b0', num_classes=2)

model_path = '/content/Checkpoint_149_8849_loss_-1.678169_acc_0.855932_D.pth' #@param {type:"string"}
model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

rootdir = '/content/data' #@param {type:"string"}

# probably depends on how categories are alphabetically sorted
path0 = '/content/output/0' #@param {type:"string"}
path1 = '/content/output/1' #@param {type:"string"}

os.makedirs(path0, exist_ok=True)
os.makedirs(path1, exist_ok=True)

files = glob.glob(rootdir + '/**/*.png', recursive=True)
files.extend(glob.glob(rootdir + '/**/*.jpg', recursive=True))

model.to(device)
# Inference mode: disables dropout and makes BatchNorm use its running stats.
model.eval()

# Preprocessing must mirror the training DataModule: Resize(size) followed by
# Normalize(means, std). Keep these in sync with the values used for training.
size = 256 #@param {type:"integer"}
means = [0.6013, 0.5443, 0.5521] #@param {type:"raw"}
stds = [0.2496, 0.2509, 0.2433] #@param {type:"raw"}
mean_t = torch.tensor(means).view(1, 3, 1, 1).to(device)
std_t = torch.tensor(stds).view(1, 3, 1, 1).to(device)

with torch.no_grad():
    for f in tqdm(files):
        image = cv2.imread(f)
        if image is None:
            tqdm.write(f"Skipping unreadable image: {f}")
            continue
        # OpenCV decodes as BGR; the training images (ImageFolder/PIL) were RGB.
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        # Like transforms.Resize(size): scale the shorter side to `size`,
        # keeping the aspect ratio (applies to both up- and downscaling).
        height, width = image.shape[:2]
        if width <= height:
            width_resized = size
            height_resized = int(size * height / width)
        else:
            height_resized = size
            width_resized = int(size * width / height)
        if (width_resized, height_resized) != (width, height):
            interpolation = cv2.INTER_AREA if width_resized < width else cv2.INTER_LINEAR
            image = cv2.resize(image, (width_resized, height_resized), interpolation=interpolation)

        # HWC uint8 -> 1xCxHxW float in [0, 1], then the training normalization
        image = torch.from_numpy(image).permute(2, 0, 1).unsqueeze(0).float().div(255)
        image = image.to(device)
        image = (image - mean_t) / std_t

        logits = model(image)
        top_pred = int(logits.argmax(dim=1).item())

        if top_pred == 0:
            shutil.move(f, os.path.join(path0, os.path.basename(f)))
        elif top_pred == 1:
            shutil.move(f, os.path.join(path1, os.path.basename(f)))